linfa_preprocessing/
whitening.rs

//! Methods for decorrelating data
//!
//! Whitening refers to a collection of methods that, given as input a matrix `X` of records with
//! covariance matrix `sigma`, output a whitening matrix `W` such that `W.T` dot `W` = `sigma^-1`
//! (the inverse of the covariance matrix).
//! Applying the whitening matrix `W` to the input data (after subtracting the feature means)
//! gives a new data matrix `Y` of the same size as the input such that `Y` has an identity
//! (white) covariance matrix.
8
9use crate::error::{PreprocessingError, Result};
10use linfa::dataset::{AsTargets, Records, WithLapack, WithoutLapack};
11use linfa::traits::{Fit, Transformer};
12use linfa::{DatasetBase, Float};
13#[cfg(not(feature = "blas"))]
14use linfa_linalg::{
15    cholesky::{CholeskyInplace, InverseCInplace},
16    svd::SVD,
17};
18use ndarray::{Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix2};
19#[cfg(feature = "blas")]
20use ndarray_linalg::{
21    cholesky::{CholeskyInto, InverseCInto, UPLO},
22    svd::SVD,
23    Scalar,
24};
25
26#[cfg(feature = "serde")]
27use serde_crate::{Deserialize, Serialize};
28
/// Algorithm used by a [Whitener] to compute the whitening matrix `W`.
///
/// Every method produces a matrix `W` such that `W.T` dot `W` equals the inverse of the sample
/// covariance matrix of the training records (up to the clamping of near-zero variances applied
/// by the PCA and ZCA methods); the methods differ in the rotation applied to the whitened data.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WhiteningMethod {
    /// PCA whitening: projects the centered records onto their principal components and
    /// rescales each component to unit variance.
    Pca,
    /// ZCA (Mahalanobis) whitening: with `sigma = U diag(lambda) U.T`, uses
    /// `W = U diag(1 / sqrt(lambda)) U.T`, the symmetric inverse square root of `sigma`.
    Zca,
    /// Cholesky whitening: uses the upper-triangular Cholesky factor `W` of the inverse
    /// covariance matrix, `sigma^-1 = W.T W`.
    Cholesky,
}
40
41/// Struct that can be fitted to the input data to obtain the related whitening matrix.
42/// Fitting returns a [FittedWhitener] struct that can be used to
43/// apply the whitening transformation to the input data.
44#[cfg_attr(
45    feature = "serde",
46    derive(Serialize, Deserialize),
47    serde(crate = "serde_crate")
48)]
49#[derive(Debug, Clone, PartialEq, Eq)]
50pub struct Whitener {
51    method: WhiteningMethod,
52}
53
54impl Whitener {
55    /// Creates an instance of a Whitener that uses the PCA method
56    pub fn pca() -> Self {
57        Self {
58            method: WhiteningMethod::Pca,
59        }
60    }
61    /// Creates an instance of a Whitener that uses the ZCA (Mahalanobis) method
62    pub fn zca() -> Self {
63        Self {
64            method: WhiteningMethod::Zca,
65        }
66    }
67    /// Creates an instance of a Whitener that uses the cholesky decomposition of the inverse of the covariance matrix
68    pub fn cholesky() -> Self {
69        Self {
70            method: WhiteningMethod::Cholesky,
71        }
72    }
73
74    pub fn method(mut self, method: WhiteningMethod) -> Self {
75        self.method = method;
76        self
77    }
78}
79
80impl<F: Float, D: Data<Elem = F>, T: AsTargets> Fit<ArrayBase<D, Ix2>, T, PreprocessingError>
81    for Whitener
82{
83    type Object = FittedWhitener<F>;
84
85    fn fit(&self, x: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
86        if x.nsamples() == 0 {
87            return Err(PreprocessingError::NotEnoughSamples);
88        }
89        // safe because of above zero samples check
90        let mean = x.records().mean_axis(Axis(0)).unwrap();
91        let sigma = x.records() - &mean;
92
93        // add Lapack + Scalar trait bounds
94        let sigma = sigma.with_lapack();
95
96        let transformation_matrix = match self.method {
97            WhiteningMethod::Pca => {
98                let (_, s, v_t) = sigma.svd(false, true)?;
99
100                // Safe because the second argument in the above call is set to true
101                let mut v_t = v_t.unwrap().without_lapack();
102                #[cfg(feature = "blas")]
103                let s = s.mapv(Scalar::from_real);
104                let s = s.without_lapack();
105
106                let s = s.mapv(|x: F| x.max(F::cast(1e-8)));
107
108                let cov_scale = F::cast(x.nsamples() - 1).sqrt();
109                for (mut v_t, s) in v_t.axis_iter_mut(Axis(0)).zip(s.iter()) {
110                    v_t *= cov_scale / *s;
111                }
112
113                v_t
114            }
115            WhiteningMethod::Zca => {
116                let sigma = sigma.t().dot(&sigma) / F::Lapack::cast(x.nsamples() - 1);
117                let (u, s, _) = sigma.svd(true, false)?;
118
119                // Safe because the first argument in the above call is set to true
120                let u = u.unwrap().without_lapack();
121                #[cfg(feature = "blas")]
122                let s = s.mapv(Scalar::from_real);
123                let s = s.without_lapack();
124
125                let s = s.mapv(|x: F| (F::one() / x.sqrt()).max(F::cast(1e-8)));
126                let lambda: Array2<F> = Array2::<F>::eye(s.len()) * s;
127                u.dot(&lambda).dot(&u.t())
128            }
129            WhiteningMethod::Cholesky => {
130                let sigma = sigma.t().dot(&sigma) / F::Lapack::cast(x.nsamples() - 1);
131                // sigma must be positive definite for us to call cholesky on its inverse, so invc
132                // is allowed here
133                #[cfg(feature = "blas")]
134                let out = sigma
135                    .invc_into()?
136                    .cholesky_into(UPLO::Upper)?
137                    .without_lapack();
138                #[cfg(not(feature = "blas"))]
139                let mut sigma = sigma;
140                #[cfg(not(feature = "blas"))]
141                let out = sigma
142                    .invc_inplace()?
143                    .reversed_axes()
144                    .cholesky_into()?
145                    .reversed_axes()
146                    .without_lapack();
147                out
148            }
149        };
150
151        Ok(FittedWhitener {
152            transformation_matrix,
153            mean,
154        })
155    }
156}
157
158/// Struct that can be used to whiten data. Data will be scaled according to the whitening matrix learned
159/// during fitting.
160/// Obtained by fitting a [Whitener].
161///
162/// Transforming the data used during fitting will yield a scaled data matrix with
163/// unit diagonal covariance matrix.
164///
165/// ### Example
166///
167/// ```rust
168/// use linfa::traits::{Fit, Transformer};
169/// use linfa_preprocessing::whitening::Whitener;
170///
171/// // Load dataset
172/// let dataset = linfa_datasets::diabetes();
173/// // Learn whitening parameters
174/// let whitener = Whitener::pca().fit(&dataset).unwrap();
175/// // transform dataset according to whitening parameters
176/// let dataset = whitener.transform(dataset);
177/// ```
178#[cfg_attr(
179    feature = "serde",
180    derive(Serialize, Deserialize),
181    serde(crate = "serde_crate")
182)]
183#[derive(Debug, Clone, PartialEq, Eq)]
184pub struct FittedWhitener<F: Float> {
185    transformation_matrix: Array2<F>,
186    mean: Array1<F>,
187}
188
189impl<F: Float> FittedWhitener<F> {
190    /// The matrix used for scaling the data
191    pub fn transformation_matrix(&self) -> ArrayView2<F> {
192        self.transformation_matrix.view()
193    }
194
195    /// The means that will be subtracted to the features before scaling the data
196    pub fn mean(&self) -> ArrayView1<F> {
197        self.mean.view()
198    }
199}
200
201impl<F: Float> Transformer<Array2<F>, Array2<F>> for FittedWhitener<F> {
202    fn transform(&self, x: Array2<F>) -> Array2<F> {
203        (x - &self.mean).dot(&self.transformation_matrix.t())
204    }
205}
206
207impl<F: Float, D: Data<Elem = F>, T: AsTargets>
208    Transformer<DatasetBase<ArrayBase<D, Ix2>, T>, DatasetBase<Array2<F>, T>>
209    for FittedWhitener<F>
210{
211    fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
212        let feature_names = x.feature_names().to_vec();
213        let target_names = x.target_names().to_vec();
214        let (records, targets, weights) = (x.records, x.targets, x.weights);
215        let records = self.transform(records.to_owned());
216        DatasetBase::new(records, targets)
217            .with_weights(weights)
218            .with_feature_names(feature_names)
219            .with_target_names(target_names)
220    }
221}
222
#[cfg(test)]
mod tests {

    use super::*;
    use approx::assert_abs_diff_eq;

    use ndarray_rand::{
        rand::distributions::Uniform, rand::rngs::SmallRng, rand::SeedableRng, RandomExt,
    };

    /// Sample covariance (normalised by `n - 1`) of the rows of `x`.
    fn cov<D: Data<Elem = f64>>(x: &ArrayBase<D, Ix2>) -> Array2<f64> {
        let mean = x.mean_axis(Axis(0)).unwrap();
        let centered = x - &mean;
        centered.t().dot(&centered) / ((x.nrows() - 1) as f64)
    }

    /// Inverse of the sample covariance of the rows of `x`.
    fn inv_cov<D: Data<Elem = f64>>(x: &ArrayBase<D, Ix2>) -> Array2<f64> {
        #[cfg(feature = "blas")]
        let inv = cov(x).invc_into().unwrap();
        #[cfg(not(feature = "blas"))]
        let inv = cov(x).invc_inplace().unwrap();
        inv
    }

    /// Reproducible 1000 x 7 matrix of samples drawn uniformly from `[-30, 30)`.
    fn random_records(seed: u64) -> Array2<f64> {
        let mut rng = SmallRng::seed_from_u64(seed);
        Array2::random_using((1000, 7), Uniform::from(-30. ..30.), &mut rng)
    }

    /// Asserts that `W.T` dot `W` matches the inverse sample covariance of the training data.
    fn assert_inverse_covariance(whitener: Whitener, epsilon: f64) {
        let dataset = random_records(42).into();
        let fitted = whitener.fit(&dataset).unwrap();
        let w = fitted.transformation_matrix();
        let inv_cov_est = w.t().dot(&w);
        assert_abs_diff_eq!(inv_cov(dataset.records()), inv_cov_est, epsilon = epsilon);
    }

    /// Asserts that whitening the training data yields an identity sample covariance.
    fn assert_identity_covariance(whitener: Whitener) {
        let dataset = random_records(64).into();
        let fitted = whitener.fit(&dataset).unwrap();
        let whitened = fitted.transform(dataset);
        let cov = cov(whitened.records());
        assert_abs_diff_eq!(cov, Array2::eye(cov.dim().0), epsilon = 1e-10)
    }

    /// Fits `whitener` on a dataset without samples; the `unwrap` is expected to panic.
    fn fit_on_empty_input(whitener: Whitener) {
        let dataset: DatasetBase<Array2<f64>, _> = Array2::zeros((0, 0)).into();
        let _fitted = whitener.fit(&dataset).unwrap();
    }

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<Whitener>();
        has_autotraits::<WhiteningMethod>();
        has_autotraits::<FittedWhitener<f64>>();
    }

    #[test]
    fn test_zca_matrix() {
        assert_inverse_covariance(Whitener::zca(), 1e-9);
    }

    #[test]
    fn test_cholesky_matrix() {
        assert_inverse_covariance(Whitener::cholesky(), 1e-10);
    }

    #[test]
    fn test_pca_matrix() {
        assert_inverse_covariance(Whitener::pca(), 1e-10);
    }

    #[test]
    fn test_cholesky_whitening() {
        assert_identity_covariance(Whitener::cholesky());
    }

    #[test]
    fn test_zca_whitening() {
        assert_identity_covariance(Whitener::zca());
    }

    #[test]
    fn test_pca_whitening() {
        assert_identity_covariance(Whitener::pca());
    }

    #[test]
    fn test_train_val_matrix() {
        let (train, val) = linfa_datasets::diabetes().split_with_ratio(0.9);
        let train_dim = train.records().dim();
        let val_dim = val.records().dim();
        let whitener = Whitener::pca().fit(&train).unwrap();
        let whitened_train = whitener.transform(train);
        let whitened_val = whitener.transform(val);
        assert_eq!(train_dim, whitened_train.records.dim());
        assert_eq!(val_dim, whitened_val.records.dim());
    }

    #[test]
    fn test_retain_feature_names() {
        let dataset = linfa_datasets::diabetes();
        let expected_names = dataset.feature_names().to_vec();
        let fitted = Whitener::cholesky().fit(&dataset).unwrap();
        let transformed = fitted.transform(dataset);
        assert_eq!(expected_names, transformed.feature_names())
    }

    #[test]
    #[should_panic]
    fn test_pca_fail_on_empty_input() {
        fit_on_empty_input(Whitener::pca());
    }

    #[test]
    #[should_panic]
    fn test_zca_fail_on_empty_input() {
        fit_on_empty_input(Whitener::zca());
    }

    #[test]
    #[should_panic]
    fn test_cholesky_fail_on_empty_input() {
        fit_on_empty_input(Whitener::cholesky());
    }
}