1use crate::error::{ReductionError, Result};
25#[cfg(not(feature = "blas"))]
26use linfa_linalg::{lobpcg::TruncatedSvd, Order};
27use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix2};
28#[cfg(feature = "blas")]
29use ndarray_linalg::{TruncatedOrder, TruncatedSvd};
30use rand::{prelude::SmallRng, SeedableRng};
31#[cfg(feature = "serde")]
32use serde_crate::{Deserialize, Serialize};
33
34use linfa::{
35 dataset::Records,
36 traits::{Fit, PredictInplace, Transformer},
37 DatasetBase, Float,
38};
39
/// Hyperparameters for fitting a [`Pca`] model.
///
/// Obtain a value via [`Pca::params`] and optionally toggle whitening with
/// [`PcaParams::whiten`] before calling `fit`.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PcaParams {
    /// Number of principal components to keep; `fit` rejects `0` and values larger than
    /// the number of features.
    embedding_size: usize,
    /// When set, each component is rescaled during `fit` so projections have unit variance.
    apply_whitening: bool,
}
51
52impl PcaParams {
53 pub fn whiten(mut self, apply: bool) -> Self {
58 self.apply_whitening = apply;
59
60 self
61 }
62}
63
64impl<T, D: Data<Elem = f64>> Fit<ArrayBase<D, Ix2>, T, ReductionError> for PcaParams {
77 type Object = Pca<f64>;
78
79 fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Pca<f64>> {
80 if dataset.nsamples() == 0 {
81 return Err(ReductionError::NotEnoughSamples);
82 } else if dataset.nfeatures() < self.embedding_size || self.embedding_size == 0 {
83 return Err(ReductionError::EmbeddingTooSmall(self.embedding_size));
84 }
85
86 let x = dataset.records();
87 let mean = x.mean_axis(Axis(0)).unwrap();
90 let x = x - &mean;
91
92 #[cfg(feature = "blas")]
94 let result =
95 TruncatedSvd::new(x, TruncatedOrder::Largest).decompose(self.embedding_size)?;
96 #[cfg(not(feature = "blas"))]
97 let result = TruncatedSvd::new_with_rng(x, Order::Largest, SmallRng::seed_from_u64(42))
98 .decompose(self.embedding_size)?;
99 let (_, sigma, mut v_t) = result.values_vectors();
101
102 let sigma = sigma.mapv(|x| x.max(1e-8));
104
105 if self.apply_whitening {
108 let cov_scale = (dataset.nsamples() as f64 - 1.).sqrt();
109 for (mut v_t, sigma) in v_t.axis_iter_mut(Axis(0)).zip(sigma.iter()) {
110 v_t *= cov_scale / *sigma;
111 }
112 }
113
114 Ok(Pca {
115 embedding: v_t,
116 sigma,
117 mean,
118 })
119 }
120}
121
122#[cfg_attr(
142 feature = "serde",
143 derive(Serialize, Deserialize),
144 serde(crate = "serde_crate")
145)]
146#[derive(Debug, Clone, PartialEq)]
147pub struct Pca<F> {
148 embedding: Array2<F>,
149 sigma: Array1<F>,
150 mean: Array1<F>,
151}
152
153impl Pca<f64> {
154 pub fn params(embedding_size: usize) -> PcaParams {
160 PcaParams {
161 embedding_size,
162 apply_whitening: false,
163 }
164 }
165
166 pub fn explained_variance(&self) -> Array1<f64> {
168 self.sigma.mapv(|x| x * x / (self.sigma.len() as f64 - 1.0))
169 }
170
171 pub fn explained_variance_ratio(&self) -> Array1<f64> {
173 let ex_var = self.sigma.mapv(|x| x * x / (self.sigma.len() as f64 - 1.0));
174 let sum_ex_var = ex_var.sum();
175
176 ex_var / sum_ex_var
177 }
178
179 pub fn components(&self) -> &Array2<f64> {
181 &self.embedding
182 }
183
184 pub fn mean(&self) -> &Array1<f64> {
186 &self.mean
187 }
188
189 pub fn singular_values(&self) -> &Array1<f64> {
191 &self.sigma
192 }
193
194 pub fn inverse_transform(
196 &self,
197 prediction: ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>>,
198 ) -> ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>> {
199 prediction.dot(&self.embedding) + &self.mean
200 }
201}
202
203impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array2<F>> for Pca<F> {
204 fn predict_inplace(&self, records: &ArrayBase<D, Ix2>, targets: &mut Array2<F>) {
205 assert_eq!(
206 targets.shape(),
207 &[records.nrows(), self.embedding.nrows()],
208 "The number of data points must match the number of output targets."
209 );
210 *targets = (records - &self.mean).dot(&self.embedding.t());
211 }
212
213 fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array2<F> {
214 Array2::zeros((x.nrows(), self.embedding.nrows()))
215 }
216}
217
218impl<F: Float, D: Data<Elem = F>, T>
219 Transformer<DatasetBase<ArrayBase<D, Ix2>, T>, DatasetBase<Array2<F>, T>> for Pca<F>
220{
221 fn transform(&self, ds: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
222 let DatasetBase {
223 records,
224 targets,
225 weights,
226 ..
227 } = ds;
228
229 let mut new_records = self.default_target(&records);
230 self.predict_inplace(&records, &mut new_records);
231
232 DatasetBase::new(new_records, targets).with_weights(weights)
233 }
234}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DiffusionMap, DiffusionMapParams, DiffusionMapValidParams};
    use approx::assert_abs_diff_eq;
    use linfa::{traits::Predict, Dataset};
    use ndarray::{array, Array2};
    use ndarray_rand::{
        rand_distr::{StandardNormal, Uniform},
        RandomExt,
    };
    use rand::{rngs::SmallRng, SeedableRng};

    /// Compile-time check that the public types are thread-safe and freely movable.
    #[test]
    fn autotraits() {
        fn require_auto_traits<T: Send + Sync + Sized + Unpin>() {}
        require_auto_traits::<DiffusionMap<f64>>();
        require_auto_traits::<DiffusionMapValidParams>();
        require_auto_traits::<DiffusionMapParams>();
        require_auto_traits::<ReductionError>();
        require_auto_traits::<PcaParams>();
        require_auto_traits::<Pca<f64>>();
    }

    /// Whitening two correlated features must yield an identity sample covariance.
    #[test]
    fn test_whitening_small() {
        let mut rng = SmallRng::seed_from_u64(42);

        // Uniform noise mixed by a scaled rotation so that the two features are correlated.
        let noise = Array2::random_using((300, 2), Uniform::new(-1.0f64, 1.), &mut rng);
        let mixing = array![[1., 1.], [-1., 1.]];
        let dataset = Dataset::from(noise.dot(&mixing));

        let model = Pca::params(2).whiten(true).fit(&dataset).unwrap();
        let projected = model.predict(&dataset);

        let scatter = projected.t().dot(&projected);
        assert_abs_diff_eq!(scatter / (300. - 1.), Array2::eye(2), epsilon = 1e-5);
    }

    /// Whitening must also decorrelate a truncated projection of high-dimensional noise.
    #[test]
    fn test_whitening_rand() {
        let mut rng = SmallRng::seed_from_u64(42);

        let samples = Array2::random_using((300, 50), Uniform::new(-1.0f64, 1.), &mut rng);
        let dataset = Dataset::from(samples);

        let model = Pca::params(10).whiten(true).fit(&dataset).unwrap();
        let projected = model.predict(&dataset);

        let scatter = projected.t().dot(&projected);
        assert_abs_diff_eq!(scatter / (300. - 1.), Array2::eye(10), epsilon = 1e-5);
    }

    /// The squared singular values of a scaled Gaussian matrix should follow the
    /// Marchenko-Pastur law with aspect ratio 500 / 1000 = 0.5.
    #[test]
    fn test_marchenko_pastur() {
        let mut rng = SmallRng::seed_from_u64(3);

        let data = Array2::random_using((1000, 500), StandardNormal, &mut rng);
        let dataset = Dataset::from(data / 1000f64.sqrt());

        let model = Pca::params(500).fit(&dataset).unwrap();
        let eigvals = model.singular_values().mapv(|x| x * x);

        // Edges of the Marchenko-Pastur support for ratio 0.5.
        let lower = (1. - 0.5f64.sqrt()).powf(2.0);
        let upper = (1. + 0.5f64.sqrt()).powf(2.0);

        // The largest and smallest eigenvalues should sit near the support edges.
        assert_abs_diff_eq!(upper, eigvals[0], epsilon = 0.1);
        assert_abs_diff_eq!(lower, eigvals[eigvals.len() - 1], epsilon = 0.1);

        // Walk the (descending) eigenvalues through histogram bins from the top down and
        // compare each bin's empirical density against the theoretical density.
        let bin_width = (2.8 - 0.1) / 28.;
        let mut idx = 0;
        'bins: for threshold in Array1::linspace(0.1, 2.8, 28).iter().rev() {
            let mut in_bin = 0;
            while eigvals[idx] >= *threshold {
                in_bin += 1;
                idx += 1;

                if idx == eigvals.len() {
                    break 'bins;
                }
            }

            let x = threshold + 0.05;
            let mp_density = ((upper - x) * (x - lower)).sqrt() / std::f64::consts::PI / x;
            let empirical = in_bin as f64 / 500. / bin_width;

            assert_abs_diff_eq!(mp_density, empirical, epsilon = 0.06);
        }
    }

    /// A rank-2 matrix built from two orthonormal directions must yield exactly two
    /// components that share the variance equally.
    #[test]
    fn test_explained_variance_cutoff() {
        let mut rng = SmallRng::seed_from_u64(42);

        let n = 500;
        // First random unit direction.
        let mut a = Array1::<f64>::random_using(n, StandardNormal, &mut rng);
        a /= (a.t().dot(&a)).sqrt();

        // Second direction: remove its projection onto `a`, then normalise (Gram-Schmidt).
        let mut b = Array1::random_using(n, StandardNormal, &mut rng);
        b -= &(b.t().dot(&a) * &a);
        b /= (b.t().dot(&b)).sqrt();

        // a a^T + b b^T has rank two.
        let rank_two =
            Array2::from_shape_fn((500, 500), |dim| a[dim.0] * a[dim.1] + b[dim.0] * b[dim.1]);

        let dataset = Dataset::from(rank_two);

        let model = Pca::params(10).fit(&dataset).unwrap();

        assert_eq!(model.explained_variance_ratio().len(), 2);
        assert_abs_diff_eq!(
            model.explained_variance_ratio(),
            array![1. / 2., 1. / 2.],
            epsilon = 1e-2
        );
    }

    /// For an identity matrix every retained component explains the same share of variance.
    #[test]
    fn test_explained_variance_diag() {
        let dataset = Dataset::from(Array2::from_diag(&array![1., 1., 1., 1.]));
        let model = Pca::params(3).fit(&dataset).unwrap();

        assert_abs_diff_eq!(
            model.explained_variance_ratio(),
            array![1. / 3., 1. / 3., 1. / 3.],
            epsilon = 1e-6
        );
    }
}