use crate::errors::{PlsError, Result};
use crate::utils;
use crate::{PlsParams, PlsValidParams};

use linfa::{
    dataset::{Records, WithLapack, WithoutLapack},
    traits::Fit,
    traits::PredictInplace,
    traits::Transformer,
    Dataset, DatasetBase, Float,
};
#[cfg(not(feature = "blas"))]
use linfa_linalg::svd::*;
use ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
#[cfg(feature = "blas")]
use ndarray_linalg::svd::*;
use ndarray_stats::QuantileExt;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
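/// Fitted state shared by the PLS family of models: the per-feature means and
/// standard deviations used for centering/scaling, the weights, loadings and
/// rotations of the extracted components (scores are kept in test builds only),
/// and the regression coefficients.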
pub(crate) struct Pls<F: Float> {
    x_mean: Array1<F>,
    x_std: Array1<F>,
    y_mean: Array1<F>,
    y_std: Array1<F>,
    x_weights: Array2<F>,
    y_weights: Array2<F>,
    #[cfg(test)]
    x_scores: Array2<F>,
    #[cfg(test)]
    y_scores: Array2<F>,
    x_loadings: Array2<F>,
    y_loadings: Array2<F>,
    x_rotations: Array2<F>,
    y_rotations: Array2<F>,
    coefficients: Array2<F>,
}

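/// Algorithm used to extract each PLS component: the iterative NIPALS power
/// method or a full SVD of the cross-covariance matrix.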
#[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
pub enum Algorithm {
    Nipals,
    Svd,
}

#[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
pub(crate) enum DeflationMode {
    Regression,
    Canonical,
}

#[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
pub(crate) enum Mode {
    A,
    B,
}

impl<F: Float> Pls<F> {
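    /// Configure a PLS in regression mode: components maximize the covariance
    /// between records and targets, and the targets are deflated with the
    /// x-scores (`DeflationMode::Regression`).
    ///
    /// A minimal usage sketch, mirroring the tests at the bottom of this file
    /// (`linfa_datasets::linnerud` is assumed to be available):
    ///
    /// ```ignore
    /// use linfa::traits::{Fit, Predict};
    ///
    /// let dataset = linfa_datasets::linnerud();
    /// // two components, fitted with the default NIPALS algorithm
    /// let pls = Pls::regression(2).fit(&dataset).expect("PLS fit failed");
    /// let predictions = pls.predict(dataset.records().to_owned());
    /// ```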
    pub fn regression(n_components: usize) -> PlsParams<F> {
        PlsParams::new(n_components)
    }

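    /// Configure a PLS in canonical (symmetric) mode: records and targets play a
    /// symmetric role and each block is deflated with its own scores
    /// (`DeflationMode::Canonical`).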
    pub fn canonical(n_components: usize) -> PlsParams<F> {
        PlsParams::new(n_components).deflation_mode(DeflationMode::Canonical)
    }

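    /// Configure a CCA-flavoured PLS: canonical deflation combined with `Mode::B`,
    /// where the weight updates use the pseudo-inverses of the records and targets
    /// rather than their transposes.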
    pub fn cca(n_components: usize) -> PlsParams<F> {
        PlsParams::new(n_components)
            .deflation_mode(DeflationMode::Canonical)
            .mode(Mode::B)
    }

    pub fn weights(&self) -> (&Array2<F>, &Array2<F>) {
        (&self.x_weights, &self.y_weights)
    }

    #[cfg(test)]
    pub fn scores(&self) -> (&Array2<F>, &Array2<F>) {
        (&self.x_scores, &self.y_scores)
    }

    pub fn loadings(&self) -> (&Array2<F>, &Array2<F>) {
        (&self.x_loadings, &self.y_loadings)
    }

    pub fn rotations(&self) -> (&Array2<F>, &Array2<F>) {
        (&self.x_rotations, &self.y_rotations)
    }

    pub fn coefficients(&self) -> &Array2<F> {
        &self.coefficients
    }

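    /// Map scores back to the original feature/target space using the loadings and
    /// then undo the centering and scaling applied during fitting. This inverts
    /// `transform` up to the information discarded by keeping only `n_components`
    /// components.
    ///
    /// A round-trip sketch (assuming the `linfa_datasets::linnerud` dataset used by
    /// the tests below), similar to `test_transform_and_inverse`:
    ///
    /// ```ignore
    /// use linfa::traits::{Fit, Transformer};
    ///
    /// let dataset = linfa_datasets::linnerud();
    /// let pls = Pls::canonical(3).fit(&dataset).expect("PLS fit failed");
    /// let projected = pls.transform(dataset);
    /// let reconstructed = pls.inverse_transform(projected);
    /// ```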
    pub fn inverse_transform(
        &self,
        dataset: DatasetBase<
            ArrayBase<impl Data<Elem = F>, Ix2>,
            ArrayBase<impl Data<Elem = F>, Ix2>,
        >,
    ) -> DatasetBase<Array2<F>, Array2<F>> {
        let mut x_orig = dataset.records().dot(&self.x_loadings.t());
        x_orig = &x_orig * &self.x_std;
        x_orig = &x_orig + &self.x_mean;
        let mut y_orig = dataset.targets().dot(&self.y_loadings.t());
        y_orig = &y_orig * &self.y_std;
        y_orig = &y_orig + &self.y_mean;
        Dataset::new(x_orig, y_orig)
    }
}

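// Project records and targets onto the fitted components: both blocks are centered
// and scaled with the statistics stored at fit time, then multiplied by the rotation
// matrices, i.e. x_proj = ((x - x_mean) / x_std) . x_rotations, and likewise for y.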
impl<F: Float, D: Data<Elem = F>>
    Transformer<
        DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
        DatasetBase<Array2<F>, Array2<F>>,
    > for Pls<F>
{
    fn transform(
        &self,
        dataset: DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
    ) -> DatasetBase<Array2<F>, Array2<F>> {
        let mut x_norm = dataset.records() - &self.x_mean;
        x_norm /= &self.x_std;
        let mut y_norm = dataset.targets() - &self.y_mean;
        y_norm /= &self.y_std;
        let x_proj = x_norm.dot(&self.x_rotations);
        let y_proj = y_norm.dot(&self.y_rotations);
        Dataset::new(x_proj, y_proj)
    }
}

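// Predict targets from records with the fitted coefficients:
// y_pred = ((x - x_mean) / x_std) . coefficients + y_mean. The coefficients were
// already multiplied by y_std at fit time, so the output needs no further rescaling.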
impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array2<F>> for Pls<F> {
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array2<F>) {
        assert_eq!(
            y.shape(),
            &[x.nrows(), self.coefficients.ncols()],
            "The number of data points must match the number of output targets."
        );

        let mut x = x - &self.x_mean;
        x /= &self.x_std;
        *y = x.dot(&self.coefficients) + &self.y_mean;
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array2<F> {
        Array2::zeros((x.nrows(), self.coefficients.ncols()))
    }
}

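// Fitting follows the classical two-block PLS scheme: center (and optionally scale)
// both blocks, then for each component find the dominant pair of singular vectors of
// the cross-covariance matrix of the current residuals (via NIPALS iteration or SVD),
// compute scores and loadings, deflate the residuals, and finally derive the rotation
// matrices and regression coefficients from the accumulated weights and loadings.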
impl<F: Float, D: Data<Elem = F>> Fit<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>, PlsError>
    for PlsValidParams<F>
{
    type Object = Pls<F>;

    fn fit(
        &self,
        dataset: &DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
    ) -> Result<Self::Object> {
        let records = dataset.records();
        let targets = dataset.targets();

        let n = records.nrows();
        let p = records.ncols();
        let q = targets.ncols();

        if n < 2 {
            return Err(PlsError::NotEnoughSamplesError(
                dataset.records().nsamples(),
            ));
        }

        let n_components = self.n_components();
        let rank_upper_bound = match self.deflation_mode() {
            // regression mode: at most as many components as there are features
            DeflationMode::Regression => p,
            // canonical mode: bounded by the smallest of samples, features and targets
            DeflationMode::Canonical => n.min(p.min(q)),
        };

        if 1 > n_components || n_components > rank_upper_bound {
            return Err(PlsError::BadComponentNumberError {
                upperbound: rank_upper_bound,
                actual: n_components,
            });
        }
        let norm_y_weights = self.deflation_mode() == DeflationMode::Canonical;
        let (mut xk, mut yk, x_mean, y_mean, x_std, y_std) =
            utils::center_scale_dataset(dataset, self.scale());

        let mut x_weights = Array2::<F>::zeros((p, n_components));
        let mut y_weights = Array2::<F>::zeros((q, n_components));
        let mut x_scores = Array2::<F>::zeros((n, n_components));
        let mut y_scores = Array2::<F>::zeros((n, n_components));
        let mut x_loadings = Array2::<F>::zeros((p, n_components));
        let mut y_loadings = Array2::<F>::zeros((q, n_components));
        let mut n_iters = Array1::zeros(n_components);

        let eps = F::epsilon();
        for k in 0..n_components {
            // find the first pair of singular vectors of the cross-covariance
            // matrix of the current residuals, xk.t().dot(&yk)
            let (mut x_weights_k, mut y_weights_k) = match self.algorithm() {
                Algorithm::Nipals => {
                    // zero out near-constant target columns so the power method
                    // does not iterate on numerical noise
                    for mut yj in yk.columns_mut() {
                        if *(yj.mapv(|y| y.abs()).max()?) < F::cast(10.) * eps {
                            yj.assign(&Array1::zeros(yj.len()));
                        }
                    }

                    let (x_weights_k, y_weights_k, n_iter) =
                        self.get_first_singular_vectors_power_method(&xk, &yk, norm_y_weights)?;
                    n_iters[k] = n_iter;
                    (x_weights_k, y_weights_k)
                }
                Algorithm::Svd => self.get_first_singular_vectors_svd(&xk, &yk)?,
            };
            utils::svd_flip_1d(&mut x_weights_k, &mut y_weights_k);

            // project the residuals onto the weights to obtain the scores
            let x_scores_k = xk.dot(&x_weights_k);
            let y_ss = if norm_y_weights {
                F::one()
            } else {
                y_weights_k.dot(&y_weights_k)
            };
            let y_scores_k = yk.dot(&y_weights_k) / y_ss;

            // deflate xk by removing its rank-one approximation along this component
            let x_loadings_k = x_scores_k.dot(&xk) / x_scores_k.dot(&x_scores_k);
            xk = xk - utils::outer(&x_scores_k, &x_loadings_k);
            let y_loadings_k = match self.deflation_mode() {
                DeflationMode::Canonical => {
                    // canonical mode deflates yk with its own scores
                    let y_loadings_k = y_scores_k.dot(&yk) / y_scores_k.dot(&y_scores_k);
                    yk = yk - utils::outer(&y_scores_k, &y_loadings_k);
                    y_loadings_k
                }
                DeflationMode::Regression => {
                    // regression mode deflates yk with the x-scores
                    let y_loadings_k = x_scores_k.dot(&yk) / x_scores_k.dot(&x_scores_k);
                    yk = yk - utils::outer(&x_scores_k, &y_loadings_k);
                    y_loadings_k
                }
            };

            x_weights.column_mut(k).assign(&x_weights_k);
            y_weights.column_mut(k).assign(&y_weights_k);
            x_scores.column_mut(k).assign(&x_scores_k);
            y_scores.column_mut(k).assign(&y_scores_k);
            x_loadings.column_mut(k).assign(&x_loadings_k);
            y_loadings.column_mut(k).assign(&y_loadings_k);
        }
        let x_rotations =
            x_weights.dot(&utils::pinv2(x_loadings.t().dot(&x_weights).view(), None));
        let y_rotations =
            y_weights.dot(&utils::pinv2(y_loadings.t().dot(&y_weights).view(), None));

        let mut coefficients = x_rotations.dot(&y_loadings.t());
        coefficients *= &y_std;

        Ok(Pls {
            x_mean,
            x_std,
            y_mean,
            y_std,
            x_weights,
            y_weights,
            #[cfg(test)]
            x_scores,
            #[cfg(test)]
            y_scores,
            x_loadings,
            y_loadings,
            x_rotations,
            y_rotations,
            coefficients,
        })
    }
}

impl<F: Float> PlsValidParams<F> {
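    /// Inner loop of NIPALS: estimate the first pair of weight vectors by an
    /// alternating power iteration, regressing the x-weights on the current
    /// y-score and the y-weights on the resulting x-score until the change in the
    /// x-weights falls below the configured tolerance. In `Mode::A` this converges
    /// to the first singular vectors of `x.t().dot(y)`; in `Mode::B` the updates
    /// use the pseudo-inverses of `x` and `y` instead of their transposes.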
    fn get_first_singular_vectors_power_method(
        &self,
        x: &ArrayBase<impl Data<Elem = F>, Ix2>,
        y: &ArrayBase<impl Data<Elem = F>, Ix2>,
        norm_y_weights: bool,
    ) -> Result<(Array1<F>, Array1<F>, usize)> {
        let eps = F::epsilon();

        // start from the first target column that is not (numerically) constant
        let mut y_score = None;
        for col in y.t().rows() {
            if *col.mapv(|v| v.abs()).max().unwrap() > eps {
                y_score = Some(col.to_owned());
                break;
            }
        }
        let mut y_score = y_score.ok_or(PlsError::PowerMethodConstantResidualError())?;

        let mut x_pinv = None;
        let mut y_pinv = None;
        if self.mode() == Mode::B {
            x_pinv = Some(utils::pinv2(x.view(), Some(F::cast(10.) * eps)));
            y_pinv = Some(utils::pinv2(y.view(), Some(F::cast(10.) * eps)));
        }

        let mut x_weights_old = Array1::<F>::from_elem(x.ncols(), F::cast(100.));

        let mut n_iter = 1;
        let mut x_weights = Array1::<F>::ones(x.ncols());
        let mut y_weights = Array1::<F>::ones(y.ncols());
        let mut converged = false;
        while n_iter < self.max_iter() {
            x_weights = match self.mode() {
                Mode::A => x.t().dot(&y_score) / y_score.dot(&y_score),
                Mode::B => x_pinv.as_ref().unwrap().dot(&y_score),
            };
            x_weights /= x_weights.dot(&x_weights).sqrt() + eps;
            let x_score = x.dot(&x_weights);

            y_weights = match self.mode() {
                Mode::A => y.t().dot(&x_score) / x_score.dot(&x_score),
                Mode::B => y_pinv.as_ref().unwrap().dot(&x_score),
            };

            if norm_y_weights {
                y_weights /= y_weights.dot(&y_weights).sqrt() + eps
            }

            let ya = y.dot(&y_weights);
            let yb = y_weights.dot(&y_weights) + eps;
            y_score = ya.mapv(|v| v / yb);

            let x_weights_diff = &x_weights - &x_weights_old;
            if x_weights_diff.dot(&x_weights_diff) < self.tolerance() || y.ncols() == 1 {
                converged = true;
                break;
            } else {
                x_weights_old = x_weights.to_owned();
                n_iter += 1;
            }
        }
        if n_iter == self.max_iter() && !converged {
            Err(PlsError::PowerMethodNotConvergedError(self.max_iter()))
        } else {
            Ok((x_weights, y_weights, n_iter))
        }
    }

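    /// Compute the first left and right singular vectors of the cross-covariance
    /// matrix `x.t().dot(y)` with a full SVD and return the pair associated with
    /// the largest singular value.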
    fn get_first_singular_vectors_svd(
        &self,
        x: &ArrayBase<impl Data<Elem = F>, Ix2>,
        y: &ArrayBase<impl Data<Elem = F>, Ix2>,
    ) -> Result<(Array1<F>, Array1<F>)> {
        let c = x.t().dot(y);

        let c = c.with_lapack();
        let (u, s, vt) = c.svd(true, true)?;
        let max = s.argmax()?;
        let u = u.unwrap().column(max).to_owned().without_lapack();
        let vt = vt.unwrap().row(max).to_owned().without_lapack();

        Ok((u, vt))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use linfa::{dataset::Records, traits::Predict, ParamGuard};
    use linfa_datasets::linnerud;
    use ndarray::{array, concatenate, Array, Axis};
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::StandardNormal;
    use ndarray_rand::RandomExt;
    use rand_xoshiro::Xoshiro256Plus;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<PlsParams<f64>>();
        has_autotraits::<PlsValidParams<f64>>();
        has_autotraits::<Pls<f64>>();
        has_autotraits::<PlsError>();
    }

    fn assert_matrix_orthonormal(m: &Array2<f64>) {
        assert_abs_diff_eq!(&m.t().dot(m), &Array::eye(m.ncols()), epsilon = 1e-7);
    }

    fn assert_matrix_orthogonal(m: &Array2<f64>) {
        let k = m.t().dot(m);
        assert_abs_diff_eq!(&k, &Array::from_diag(&k.diag()), epsilon = 1e-7);
    }

    #[test]
    fn test_pls_canonical_basics() -> Result<()> {
        let dataset = linnerud();
        let records = dataset.records();

        let pls = Pls::canonical(records.ncols()).fit(&dataset)?;

        let (x_weights, y_weights) = pls.weights();
        assert_matrix_orthonormal(x_weights);
        assert_matrix_orthonormal(y_weights);

        let (x_scores, y_scores) = pls.scores();
        assert_matrix_orthogonal(x_scores);
        assert_matrix_orthogonal(y_scores);

        // scores and loadings must reconstruct the centered and scaled data
        let (p, q) = pls.loadings();
        let t = x_scores;
        let u = y_scores;

        let (xc, yc, ..) = utils::center_scale_dataset(&dataset, true);
        assert_abs_diff_eq!(&xc, &t.dot(&p.t()), epsilon = 1e-7);
        assert_abs_diff_eq!(&yc, &u.dot(&q.t()), epsilon = 1e-7);

        // transforming the training data must reproduce the scores
        let ds = pls.transform(dataset);
        assert_abs_diff_eq!(ds.records(), x_scores, epsilon = 1e-7);
        assert_abs_diff_eq!(ds.targets(), y_scores, epsilon = 1e-7);

        Ok(())
    }

    #[test]
    fn test_sanity_check_pls_regression() {
        let dataset = linnerud();
        let pls = Pls::regression(3)
            .fit(&dataset)
            .expect("PLS fitting failed");

        // expected values for the first three components
        let expected_x_weights = array![
            [0.61330704, -0.00443647, 0.78983213],
            [0.74697144, -0.32172099, -0.58183269],
            [0.25668686, 0.94682413, -0.19399983]
        ];

        let expected_x_loadings = array![
            [0.61470416, -0.24574278, 0.78983213],
            [0.65625755, -0.14396183, -0.58183269],
            [0.51733059, 1.00609417, -0.19399983]
        ];

        let expected_y_weights = array![
            [-0.32456184, 0.29892183, 0.20316322],
            [-0.42439636, 0.61970543, 0.19320542],
            [0.13143144, -0.26348971, -0.17092916]
        ];

        let expected_y_loadings = array![
            [-0.32456184, 0.29892183, 0.20316322],
            [-0.42439636, 0.61970543, 0.19320542],
            [0.13143144, -0.26348971, -0.17092916]
        ];
        assert_abs_diff_eq!(pls.x_weights, expected_x_weights, epsilon = 1e-6);
        assert_abs_diff_eq!(pls.x_loadings, expected_x_loadings, epsilon = 1e-6);
        assert_abs_diff_eq!(pls.y_weights, expected_y_weights, epsilon = 1e-6);
        assert_abs_diff_eq!(pls.y_loadings, expected_y_loadings, epsilon = 1e-6);
    }

    #[test]
    fn test_sanity_check_pls_regression_constant_column_y() {
        let mut dataset = linnerud();
        let nrows = dataset.targets.nrows();
        dataset.targets.column_mut(0).assign(&Array1::ones(nrows));
        let pls = Pls::regression(3)
            .fit(&dataset)
            .expect("PLS fitting failed");

        let expected_x_weights = array![
            [0.6273573, 0.007081799, 0.7786994],
            [0.7493417, -0.277612681, -0.6011807],
            [0.2119194, 0.960666981, -0.1794690]
        ];

        let expected_x_loadings = array![
            [0.6273512, -0.22464538, 0.7786994],
            [0.6643156, -0.09871193, -0.6011807],
            [0.5125877, 1.01407380, -0.1794690]
        ];

        let expected_y_loadings = array![
            [0.0000000, 0.0000000, 0.0000000],
            [-0.4357300, 0.5828479, 0.2174802],
            [0.1353739, -0.2486423, -0.1810386]
        ];
        assert_abs_diff_eq!(pls.x_weights, expected_x_weights, epsilon = 1e-6);
        assert_abs_diff_eq!(pls.x_loadings, expected_x_loadings, epsilon = 1e-6);
        // the fitted y weights are expected to match the y loadings here
        assert_abs_diff_eq!(pls.y_loadings, expected_y_loadings, epsilon = 1e-6);
        assert_abs_diff_eq!(pls.y_weights, expected_y_loadings, epsilon = 1e-6);
    }

    #[test]
    fn test_sanity_check_pls_canonical() -> Result<()> {
        let dataset = linnerud();
        let pls = Pls::canonical(dataset.records().ncols()).fit(&dataset)?;

        let expected_x_weights = array![
            [-0.61330704, 0.25616119, -0.74715187],
            [-0.74697144, 0.11930791, 0.65406368],
            [-0.25668686, -0.95924297, -0.11817271]
        ];

        let expected_x_rotations = array![
            [-0.61330704, 0.41591889, -0.62297525],
            [-0.74697144, 0.31388326, 0.77368233],
            [-0.25668686, -0.89237972, -0.24121788]
        ];

        let expected_y_weights = array![
            [0.58989127, 0.7890047, 0.1717553],
            [0.77134053, -0.61351791, 0.16920272],
            [-0.23887670, -0.03267062, 0.97050016]
        ];

        let expected_y_rotations = array![
            [0.58989127, 0.7168115, 0.30665872],
            [0.77134053, -0.70791757, 0.19786539],
            [-0.23887670, -0.00343595, 0.94162826]
        ];

        // compare up to a global sign flip per component
        let (x_weights, y_weights) = pls.weights();
        let (x_rotations, y_rotations) = pls.rotations();
        assert_abs_diff_eq!(
            expected_x_rotations.mapv(|v: f64| v.abs()),
            x_rotations.mapv(|v| v.abs()),
            epsilon = 1e-7
        );
        assert_abs_diff_eq!(
            expected_x_weights.mapv(|v: f64| v.abs()),
            x_weights.mapv(|v| v.abs()),
            epsilon = 1e-7
        );
        assert_abs_diff_eq!(
            expected_y_rotations.mapv(|v: f64| v.abs()),
            y_rotations.mapv(|v| v.abs()),
            epsilon = 1e-7
        );
        assert_abs_diff_eq!(
            expected_y_weights.mapv(|v: f64| v.abs()),
            y_weights.mapv(|v| v.abs()),
            epsilon = 1e-7
        );

        // the sign flips of weights and rotations must be consistent
        let x_rotations_sign_flip = (x_rotations / &expected_x_rotations).mapv(|v| v.signum());
        let x_weights_sign_flip = (x_weights / &expected_x_weights).mapv(|v| v.signum());
        let y_rotations_sign_flip = (y_rotations / &expected_y_rotations).mapv(|v| v.signum());
        let y_weights_sign_flip = (y_weights / &expected_y_weights).mapv(|v| v.signum());
        assert_abs_diff_eq!(x_rotations_sign_flip, x_weights_sign_flip);
        assert_abs_diff_eq!(y_rotations_sign_flip, y_weights_sign_flip);

        assert_matrix_orthonormal(x_weights);
        assert_matrix_orthonormal(y_weights);

        let (x_scores, y_scores) = pls.scores();
        assert_matrix_orthogonal(x_scores);
        assert_matrix_orthogonal(y_scores);
        Ok(())
    }

    #[test]
    fn test_sanity_check_pls_canonical_random() {
        // data generated from two shared latent variables plus independent noise
        let n = 500;
        let p_noise = 10;
        let q_noise = 5;

        let mut rng = Xoshiro256Plus::seed_from_u64(100);
        let l1: Array1<f64> = Array1::random_using(n, StandardNormal, &mut rng);
        let l2: Array1<f64> = Array1::random_using(n, StandardNormal, &mut rng);
        let mut latents = Array::zeros((4, n));
        latents.row_mut(0).assign(&l1);
        latents.row_mut(1).assign(&l1);
        latents.row_mut(2).assign(&l2);
        latents.row_mut(3).assign(&l2);
        latents = latents.reversed_axes();

        let mut x = &latents + &Array2::<f64>::random_using((n, 4), StandardNormal, &mut rng);
        let mut y = latents + &Array2::<f64>::random_using((n, 4), StandardNormal, &mut rng);

        x = concatenate(
            Axis(1),
            &[
                x.view(),
                Array2::random_using((n, p_noise), StandardNormal, &mut rng).view(),
            ],
        )
        .unwrap();
        y = concatenate(
            Axis(1),
            &[
                y.view(),
                Array2::random_using((n, q_noise), StandardNormal, &mut rng).view(),
            ],
        )
        .unwrap();

        let ds = Dataset::new(x, y);
        let pls = Pls::canonical(3)
            .fit(&ds)
            .expect("PLS canonical fitting failed");

        let (x_weights, y_weights) = pls.weights();
        assert_matrix_orthonormal(x_weights);
        assert_matrix_orthonormal(y_weights);

        let (x_scores, y_scores) = pls.scores();
        assert_matrix_orthogonal(x_scores);
        assert_matrix_orthogonal(y_scores);
    }

    #[test]
    fn test_scale_and_stability() -> Result<()> {
        // fitting pre-scaled data without scaling must give the same scores as
        // fitting the raw data with scaling enabled
        let ds = linnerud();
        let (x_s, y_s, ..) = utils::center_scale_dataset(&ds, true);
        let ds_s = Dataset::new(x_s, y_s);

        let ds_score = Pls::regression(2)
            .scale(true)
            .tolerance(1e-3)
            .fit(&ds)?
            .transform(ds.to_owned());
        let ds_s_score = Pls::regression(2)
            .scale(false)
            .tolerance(1e-3)
            .fit(&ds_s)?
            .transform(ds_s.to_owned());

        assert_abs_diff_eq!(ds_s_score.records(), ds_score.records(), epsilon = 1e-4);
        assert_abs_diff_eq!(ds_s_score.targets(), ds_score.targets(), epsilon = 1e-4);
        Ok(())
    }

    #[test]
    fn test_one_component_equivalence() -> Result<()> {
        // with a single component, regression and canonical modes agree on the records
        let ds = linnerud();
        let ds2 = linnerud();
        let regression = Pls::regression(1).fit(&ds)?.transform(ds);
        let canonical = Pls::canonical(1).fit(&ds2)?.transform(ds2);

        assert_abs_diff_eq!(regression.records(), canonical.records(), epsilon = 1e-7);
        Ok(())
    }

    #[test]
    fn test_convergence_fail() {
        let ds = linnerud();
        assert!(
            Pls::canonical(ds.records().nfeatures())
                .max_iterations(2)
                .fit(&ds)
                .is_err(),
            "PLS power method should not converge, hence raise an error"
        );
    }

    #[test]
    fn test_bad_component_number() {
        let ds = linnerud();
        assert!(
            Pls::cca(ds.records().nfeatures() + 1).fit(&ds).is_err(),
            "n_components too large should raise an error"
        );
        assert!(
            Pls::canonical(0).fit(&ds).is_err(),
            "n_components=0 should raise an error"
        );
    }

    #[test]
    fn test_singular_value_helpers() -> Result<()> {
        let ds = linnerud();

        // after sign normalization, the power method and the direct SVD must agree
        // on the first singular vectors up to a loose tolerance
        let (mut u1, mut v1, _) = PlsParams::new(2)
            .check()?
            .get_first_singular_vectors_power_method(ds.records(), ds.targets(), true)?;
        let (mut u2, mut v2) = PlsParams::new(2)
            .check()?
            .get_first_singular_vectors_svd(ds.records(), ds.targets())?;

        utils::svd_flip_1d(&mut u1, &mut v1);
        utils::svd_flip_1d(&mut u2, &mut v2);

        let rtol = 1e-1;
        assert_abs_diff_eq!(u1, u2, epsilon = rtol);
        assert_abs_diff_eq!(v1, v2, epsilon = rtol);
        Ok(())
    }

    macro_rules! test_pls_algo_nipals_svd {
        ($($name:ident, )*) => {
            paste::item! {
                $(
                    #[test]
                    fn [<test_pls_$name>]() -> Result<()> {
                        let ds = linnerud();
                        let pls = Pls::[<$name>](3).fit(&ds)?;
                        let ds1 = pls.transform(ds.to_owned());
                        let ds2 = Pls::[<$name>](3).algorithm(Algorithm::Svd).fit(&ds)?.transform(ds);
                        assert_abs_diff_eq!(ds1.records(), ds2.records(), epsilon = 1e-2);
                        let exercises = array![[14., 146., 61.], [6., 80., 60.]];
                        let physios = pls.predict(exercises);
                        println!("Physiologicals = {:?}", physios.targets());
                        Ok(())
                    }
                )*
            }
        };
    }

    test_pls_algo_nipals_svd! {
        canonical, regression,
    }

    #[test]
    fn test_cca() {
        let ds = linnerud();
        let cca = Pls::cca(3).fit(&ds).unwrap();
        let ds = cca.transform(ds);
        let expected_x = array![
            [0.09597886, 0.13862931, -1.0311966],
            [-0.7170194, 0.25195026, -0.83049671],
            [-0.76492193, 0.37601463, 1.20714686],
            [-0.03734329, -0.9746487, 0.79363542],
            [0.42809962, -0.50053551, 0.40089685],
            [-0.54141144, -0.29403268, -0.47221389],
            [-0.29901672, -0.67023009, 0.17945745],
            [-0.11425233, -0.43360723, -0.47235823],
            [1.29212153, -0.9373391, 0.02572464],
            [-0.17770025, 3.4785377, 0.8486413],
            [0.39344638, -1.28718499, 1.43816035],
            [0.52667844, 0.82080301, -0.02624471],
            [0.74616393, 0.54578854, 0.01825073],
            [-1.42623443, -0.00884605, -0.24019883],
            [-0.72026991, -0.73588273, 0.2241694],
            [0.4237932, 0.99977428, -0.1667137],
            [-0.88437821, -0.73784626, -0.01073894],
            [1.05159992, 0.26381077, -0.83138216],
            [1.26196754, -0.18618728, -0.12863494],
            [-0.53730151, -0.10896789, -0.92590428]
        ];
        assert_abs_diff_eq!(expected_x, ds.records(), epsilon = 1e-2);
    }

    #[test]
    fn test_transform_and_inverse() -> Result<()> {
        let ds = linnerud();
        let pls = Pls::canonical(3).fit(&ds)?;

        let ds_proj = pls.transform(ds);
        let ds_orig = pls.inverse_transform(ds_proj);

        let ds = linnerud();
        assert_abs_diff_eq!(ds.records(), ds_orig.records(), epsilon = 1e-6);
        assert_abs_diff_eq!(ds.targets(), ds_orig.targets(), epsilon = 1e-6);
        Ok(())
    }

    #[test]
    fn test_pls_constant_y() {
        // an all-constant target must be rejected instead of silently fitted
        let n = 100;
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let x = Array2::<f64>::random_using((n, 3), StandardNormal, &mut rng);
        let y = Array2::zeros((n, 1));
        let ds = Dataset::new(x, y);
        assert!(matches!(
            Pls::regression(2).fit(&ds).unwrap_err(),
            PlsError::PowerMethodConstantResidualError()
        ));
    }
}