1use crate::error::{PreprocessingError, Result};
10use linfa::dataset::{AsTargets, Records, WithLapack, WithoutLapack};
11use linfa::traits::{Fit, Transformer};
12use linfa::{DatasetBase, Float};
13#[cfg(not(feature = "blas"))]
14use linfa_linalg::{
15 cholesky::{CholeskyInplace, InverseCInplace},
16 svd::SVD,
17};
18use ndarray::{Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix2};
19#[cfg(feature = "blas")]
20use ndarray_linalg::{
21 cholesky::{CholeskyInto, InverseCInto, UPLO},
22 svd::SVD,
23 Scalar,
24};
25
26#[cfg(feature = "serde")]
27use serde_crate::{Deserialize, Serialize};
28
/// Strategy used by the `Fit` implementation of `Whitener` to build the whitening matrix.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub enum WhiteningMethod {
    /// Matrix derived from the singular value decomposition of the centered records.
    Pca,
    /// Matrix derived from the singular value decomposition of the covariance matrix.
    Zca,
    /// Matrix derived from a Cholesky factorization of the inverted covariance matrix.
    Cholesky,
}
40
41#[cfg_attr(
45 feature = "serde",
46 derive(Serialize, Deserialize),
47 serde(crate = "serde_crate")
48)]
49#[derive(Debug, Clone, PartialEq, Eq)]
50pub struct Whitener {
51 method: WhiteningMethod,
52}
53
54impl Whitener {
55 pub fn pca() -> Self {
57 Self {
58 method: WhiteningMethod::Pca,
59 }
60 }
61 pub fn zca() -> Self {
63 Self {
64 method: WhiteningMethod::Zca,
65 }
66 }
67 pub fn cholesky() -> Self {
69 Self {
70 method: WhiteningMethod::Cholesky,
71 }
72 }
73
74 pub fn method(mut self, method: WhiteningMethod) -> Self {
75 self.method = method;
76 self
77 }
78}
79
80impl<F: Float, D: Data<Elem = F>, T: AsTargets> Fit<ArrayBase<D, Ix2>, T, PreprocessingError>
81 for Whitener
82{
83 type Object = FittedWhitener<F>;
84
85 fn fit(&self, x: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
86 if x.nsamples() == 0 {
87 return Err(PreprocessingError::NotEnoughSamples);
88 }
89 let mean = x.records().mean_axis(Axis(0)).unwrap();
91 let sigma = x.records() - &mean;
92
93 let sigma = sigma.with_lapack();
95
96 let transformation_matrix = match self.method {
97 WhiteningMethod::Pca => {
98 let (_, s, v_t) = sigma.svd(false, true)?;
99
100 let mut v_t = v_t.unwrap().without_lapack();
102 #[cfg(feature = "blas")]
103 let s = s.mapv(Scalar::from_real);
104 let s = s.without_lapack();
105
106 let s = s.mapv(|x: F| x.max(F::cast(1e-8)));
107
108 let cov_scale = F::cast(x.nsamples() - 1).sqrt();
109 for (mut v_t, s) in v_t.axis_iter_mut(Axis(0)).zip(s.iter()) {
110 v_t *= cov_scale / *s;
111 }
112
113 v_t
114 }
115 WhiteningMethod::Zca => {
116 let sigma = sigma.t().dot(&sigma) / F::Lapack::cast(x.nsamples() - 1);
117 let (u, s, _) = sigma.svd(true, false)?;
118
119 let u = u.unwrap().without_lapack();
121 #[cfg(feature = "blas")]
122 let s = s.mapv(Scalar::from_real);
123 let s = s.without_lapack();
124
125 let s = s.mapv(|x: F| (F::one() / x.sqrt()).max(F::cast(1e-8)));
126 let lambda: Array2<F> = Array2::<F>::eye(s.len()) * s;
127 u.dot(&lambda).dot(&u.t())
128 }
129 WhiteningMethod::Cholesky => {
130 let sigma = sigma.t().dot(&sigma) / F::Lapack::cast(x.nsamples() - 1);
131 #[cfg(feature = "blas")]
134 let out = sigma
135 .invc_into()?
136 .cholesky_into(UPLO::Upper)?
137 .without_lapack();
138 #[cfg(not(feature = "blas"))]
139 let mut sigma = sigma;
140 #[cfg(not(feature = "blas"))]
141 let out = sigma
142 .invc_inplace()?
143 .reversed_axes()
144 .cholesky_into()?
145 .reversed_axes()
146 .without_lapack();
147 out
148 }
149 };
150
151 Ok(FittedWhitener {
152 transformation_matrix,
153 mean,
154 })
155 }
156}
157
158#[cfg_attr(
179 feature = "serde",
180 derive(Serialize, Deserialize),
181 serde(crate = "serde_crate")
182)]
183#[derive(Debug, Clone, PartialEq, Eq)]
184pub struct FittedWhitener<F: Float> {
185 transformation_matrix: Array2<F>,
186 mean: Array1<F>,
187}
188
189impl<F: Float> FittedWhitener<F> {
190 pub fn transformation_matrix(&self) -> ArrayView2<F> {
192 self.transformation_matrix.view()
193 }
194
195 pub fn mean(&self) -> ArrayView1<F> {
197 self.mean.view()
198 }
199}
200
201impl<F: Float> Transformer<Array2<F>, Array2<F>> for FittedWhitener<F> {
202 fn transform(&self, x: Array2<F>) -> Array2<F> {
203 (x - &self.mean).dot(&self.transformation_matrix.t())
204 }
205}
206
207impl<F: Float, D: Data<Elem = F>, T: AsTargets>
208 Transformer<DatasetBase<ArrayBase<D, Ix2>, T>, DatasetBase<Array2<F>, T>>
209 for FittedWhitener<F>
210{
211 fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
212 let feature_names = x.feature_names().to_vec();
213 let target_names = x.target_names().to_vec();
214 let (records, targets, weights) = (x.records, x.targets, x.weights);
215 let records = self.transform(records.to_owned());
216 DatasetBase::new(records, targets)
217 .with_weights(weights)
218 .with_feature_names(feature_names)
219 .with_target_names(target_names)
220 }
221}
222
#[cfg(test)]
mod tests {

    use super::*;
    use approx::assert_abs_diff_eq;

    use ndarray_rand::{
        rand::distributions::Uniform, rand::rngs::SmallRng, rand::SeedableRng, RandomExt,
    };

    /// Sample covariance of the rows of `x` (normalized by `nrows - 1`).
    fn cov<D: Data<Elem = f64>>(x: &ArrayBase<D, Ix2>) -> Array2<f64> {
        let centered = x - &x.mean_axis(Axis(0)).unwrap();
        let dof = (x.dim().0 - 1) as f64;
        centered.t().dot(&centered) / dof
    }

    /// Inverse of the sample covariance of `x`, computed with the active backend.
    fn inv_cov<D: Data<Elem = f64>>(x: &ArrayBase<D, Ix2>) -> Array2<f64> {
        #[cfg(feature = "blas")]
        let inverse = cov(x).invc_into().unwrap();
        #[cfg(not(feature = "blas"))]
        let inverse = cov(x).invc_inplace().unwrap();
        inverse
    }

    /// Deterministic 1000x7 matrix with entries drawn uniformly from [-30, 30).
    fn random_records(seed: u64) -> Array2<f64> {
        let mut rng = SmallRng::seed_from_u64(seed);
        Array2::random_using((1000, 7), Uniform::from(-30. ..30.), &mut rng)
    }

    /// Fits `whitener` on seeded random data and checks that `Wᵀ·W` matches the inverse
    /// covariance of that data within `epsilon`.
    fn check_inverse_covariance(whitener: Whitener, epsilon: f64) {
        let dataset = random_records(42).into();
        let fitted = whitener.fit(&dataset).unwrap();
        let inv_cov_est = fitted
            .transformation_matrix()
            .t()
            .dot(&fitted.transformation_matrix());
        let expected = inv_cov(dataset.records());
        assert_abs_diff_eq!(expected, inv_cov_est, epsilon = epsilon);
    }

    /// Fits `whitener` on seeded random data, whitens that data and checks that the
    /// covariance of the result is the identity matrix.
    fn check_identity_covariance(whitener: Whitener) {
        let dataset = random_records(64).into();
        let fitted = whitener.fit(&dataset).unwrap();
        let whitened = fitted.transform(dataset);
        let whitened_cov = cov(whitened.records());
        assert_abs_diff_eq!(
            whitened_cov,
            Array2::eye(whitened_cov.dim().0),
            epsilon = 1e-10
        )
    }

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<Whitener>();
        has_autotraits::<WhiteningMethod>();
        has_autotraits::<FittedWhitener<f64>>();
    }

    #[test]
    fn test_zca_matrix() {
        check_inverse_covariance(Whitener::zca(), 1e-9);
    }

    #[test]
    fn test_cholesky_matrix() {
        check_inverse_covariance(Whitener::cholesky(), 1e-10);
    }

    #[test]
    fn test_pca_matrix() {
        check_inverse_covariance(Whitener::pca(), 1e-10);
    }

    #[test]
    fn test_cholesky_whitening() {
        check_identity_covariance(Whitener::cholesky());
    }

    #[test]
    fn test_zca_whitening() {
        check_identity_covariance(Whitener::zca());
    }

    #[test]
    fn test_pca_whitening() {
        check_identity_covariance(Whitener::pca());
    }

    /// Whitening must not change the shape of either split.
    #[test]
    fn test_train_val_matrix() {
        let (train, val) = linfa_datasets::diabetes().split_with_ratio(0.9);
        let train_dim = train.records().dim();
        let val_dim = val.records().dim();
        let fitted = Whitener::pca().fit(&train).unwrap();
        let whitened_train = fitted.transform(train);
        let whitened_val = fitted.transform(val);
        assert_eq!(train_dim, whitened_train.records.dim());
        assert_eq!(val_dim, whitened_val.records.dim());
    }

    /// Feature names of the input dataset must survive the transformation.
    #[test]
    fn test_retain_feature_names() {
        let dataset = linfa_datasets::diabetes();
        let original_feature_names = dataset.feature_names().to_vec();
        let fitted = Whitener::cholesky().fit(&dataset).unwrap();
        let transformed = fitted.transform(dataset);
        assert_eq!(original_feature_names, transformed.feature_names())
    }

    #[test]
    #[should_panic]
    fn test_pca_fail_on_empty_input() {
        let empty: DatasetBase<Array2<f64>, _> = Array2::zeros((0, 0)).into();
        let _whitener = Whitener::pca().fit(&empty).unwrap();
    }

    #[test]
    #[should_panic]
    fn test_zca_fail_on_empty_input() {
        let empty: DatasetBase<Array2<f64>, _> = Array2::zeros((0, 0)).into();
        let _whitener = Whitener::zca().fit(&empty).unwrap();
    }

    #[test]
    #[should_panic]
    fn test_cholesky_fail_on_empty_input() {
        let empty: DatasetBase<Array2<f64>, _> = Array2::zeros((0, 0)).into();
        let _whitener = Whitener::cholesky().fit(&empty).unwrap();
    }
}