linfa_pls/
pls_generic.rs

1use crate::errors::{PlsError, Result};
2use crate::utils;
3use crate::{PlsParams, PlsValidParams};
4
5use linfa::{
6    dataset::{Records, WithLapack, WithoutLapack},
7    traits::Fit,
8    traits::PredictInplace,
9    traits::Transformer,
10    Dataset, DatasetBase, Float,
11};
12#[cfg(not(feature = "blas"))]
13use linfa_linalg::svd::*;
14use ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
15#[cfg(feature = "blas")]
16use ndarray_linalg::svd::*;
17use ndarray_stats::QuantileExt;
18#[cfg(feature = "serde")]
19use serde_crate::{Deserialize, Serialize};
20
/// Fitted generic PLS model shared by the regression, canonical and CCA variants.
///
/// The component matrices are filled column by column during `fit`: column `k` of the
/// weights, loadings and (in test builds) scores holds the values of the `k`-th component.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct Pls<F: Float> {
    /// Offset subtracted from every record column before projecting or predicting
    /// (the `x_mean` returned by `utils::center_scale_dataset` during `fit`).
    x_mean: Array1<F>,
    /// Divisor applied to every centered record column
    /// (NOTE(review): presumably all ones when scaling is disabled — confirm in `utils`).
    x_std: Array1<F>,
    /// Offset subtracted from every target column in `transform`, added back by `predict`.
    y_mean: Array1<F>,
    /// Divisor applied to every centered target column; also folded into `coefficients`.
    y_std: Array1<F>,
    /// X weights, shape `(n_features, n_components)`.
    x_weights: Array2<F>, // U
    /// Y weights, shape `(n_targets, n_components)`.
    y_weights: Array2<F>, // V
    /// X scores of the training data, shape `(n_samples, n_components)`; test builds only.
    #[cfg(test)]
    x_scores: Array2<F>, // xi
    /// Y scores of the training data, shape `(n_samples, n_components)`; test builds only.
    #[cfg(test)]
    y_scores: Array2<F>, // Omega
    /// X loadings, shape `(n_features, n_components)`, used to deflate X.
    x_loadings: Array2<F>, // Gamma
    /// Y loadings, shape `(n_targets, n_components)`, used to deflate Y.
    y_loadings: Array2<F>, // Delta
    /// `x_weights . pinv(x_loadings^T . x_weights)`: maps standardized records to X scores.
    x_rotations: Array2<F>,
    /// `y_weights . pinv(y_loadings^T . y_weights)`: maps standardized targets to Y scores.
    y_rotations: Array2<F>,
    /// `x_rotations . y_loadings^T` with each column multiplied by `y_std`;
    /// shape `(n_features, n_targets)`, used by `predict`.
    coefficients: Array2<F>,
}
44
/// Strategy used to extract, for each component, the first pair of singular vectors of
/// the cross-covariance matrix `X'Y` of the current (deflated) residuals.
#[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
pub enum Algorithm {
    /// Iterative power method; bounded by the configured `max_iter` and `tolerance`.
    Nipals,
    /// Full singular value decomposition of `X'Y`, keeping the vectors associated with
    /// the largest singular value.
    Svd,
}
50
/// How the target block `Y` is deflated after each component has been extracted.
#[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
pub(crate) enum DeflationMode {
    /// Regress the `Y` residual on the X scores (PLS regression). The number of
    /// components is bounded by the number of record features.
    Regression,
    /// Regress the `Y` residual on its own Y scores (PLS canonical and CCA). The number
    /// of components is bounded by `min(n_samples, n_features, n_targets)` and the power
    /// method normalizes the Y weights.
    Canonical,
}
56
/// How the NIPALS power method updates the weights at each iteration
/// (the SVD algorithm does not read this setting).
#[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
pub(crate) enum Mode {
    /// Weights are obtained by regressing each block on the other block's current score.
    A,
    /// Weights are obtained through the pseudo-inverses of `X` and `Y`; selected by
    /// `Pls::cca`.
    B,
}
62
63/// Generic PLS algorithm.
64/// Main ref: Wegelin, a survey of Partial Least Squares (PLS) methods,
65/// with emphasis on the two-block case
66/// https://www.stat.washington.edu/research/reports/2000/tr371.pdf
67impl<F: Float> Pls<F> {
68    // Constructor for PlsRegression method
69    pub fn regression(n_components: usize) -> PlsParams<F> {
70        PlsParams::new(n_components)
71    }
72
73    // Constructor for PlsCanonical method
74    pub fn canonical(n_components: usize) -> PlsParams<F> {
75        PlsParams::new(n_components).deflation_mode(DeflationMode::Canonical)
76    }
77
78    // Constructor for PlsCca method
79    pub fn cca(n_components: usize) -> PlsParams<F> {
80        PlsParams::new(n_components)
81            .deflation_mode(DeflationMode::Canonical)
82            .mode(Mode::B)
83    }
84
85    pub fn weights(&self) -> (&Array2<F>, &Array2<F>) {
86        (&self.x_weights, &self.y_weights)
87    }
88
89    #[cfg(test)]
90    pub fn scores(&self) -> (&Array2<F>, &Array2<F>) {
91        (&self.x_scores, &self.y_scores)
92    }
93
94    pub fn loadings(&self) -> (&Array2<F>, &Array2<F>) {
95        (&self.x_loadings, &self.y_loadings)
96    }
97
98    pub fn rotations(&self) -> (&Array2<F>, &Array2<F>) {
99        (&self.x_rotations, &self.y_rotations)
100    }
101
102    pub fn coefficients(&self) -> &Array2<F> {
103        &self.coefficients
104    }
105
106    pub fn inverse_transform(
107        &self,
108        dataset: DatasetBase<
109            ArrayBase<impl Data<Elem = F>, Ix2>,
110            ArrayBase<impl Data<Elem = F>, Ix2>,
111        >,
112    ) -> DatasetBase<Array2<F>, Array2<F>> {
113        let mut x_orig = dataset.records().dot(&self.x_loadings.t());
114        x_orig = &x_orig * &self.x_std;
115        x_orig = &x_orig + &self.x_mean;
116        let mut y_orig = dataset.targets().dot(&self.y_loadings.t());
117        y_orig = &y_orig * &self.y_std;
118        y_orig = &y_orig + &self.y_mean;
119        Dataset::new(x_orig, y_orig)
120    }
121}
122
123impl<F: Float, D: Data<Elem = F>>
124    Transformer<
125        DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
126        DatasetBase<Array2<F>, Array2<F>>,
127    > for Pls<F>
128{
129    fn transform(
130        &self,
131        dataset: DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
132    ) -> DatasetBase<Array2<F>, Array2<F>> {
133        let mut x_norm = dataset.records() - &self.x_mean;
134        x_norm /= &self.x_std;
135        let mut y_norm = dataset.targets() - &self.y_mean;
136        y_norm /= &self.y_std;
137        // Apply rotations
138        let x_proj = x_norm.dot(&self.x_rotations);
139        let y_proj = y_norm.dot(&self.y_rotations);
140        Dataset::new(x_proj, y_proj)
141    }
142}
143
144impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array2<F>> for Pls<F> {
145    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array2<F>) {
146        assert_eq!(
147            y.shape(),
148            &[x.nrows(), self.coefficients.ncols()],
149            "The number of data points must match the number of output targets."
150        );
151
152        let mut x = x - &self.x_mean;
153        x /= &self.x_std;
154        *y = x.dot(&self.coefficients) + &self.y_mean;
155    }
156
157    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array2<F> {
158        Array2::zeros((x.nrows(), self.coefficients.ncols()))
159    }
160}
161
162impl<F: Float, D: Data<Elem = F>> Fit<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>, PlsError>
163    for PlsValidParams<F>
164{
165    type Object = Pls<F>;
166
167    fn fit(
168        &self,
169        dataset: &DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
170    ) -> Result<Self::Object> {
171        let records = dataset.records();
172        let targets = dataset.targets();
173
174        let n = records.nrows();
175        let p = records.ncols();
176        let q = targets.ncols();
177
178        if n < 2 {
179            return Err(PlsError::NotEnoughSamplesError(
180                dataset.records().nsamples(),
181            ));
182        }
183
184        let n_components = self.n_components();
185        let rank_upper_bound = match self.deflation_mode() {
186            DeflationMode::Regression => {
187                // With PLSRegression n_components is bounded by the rank of (x.T x)
188                // see Wegelin page 25
189                p
190            }
191            DeflationMode::Canonical => {
192                // With CCA and PLSCanonical, n_components is bounded by the rank of
193                // X and the rank of Y: see Wegelin page 12
194                n.min(p.min(q))
195            }
196        };
197
198        if 1 > n_components || n_components > rank_upper_bound {
199            return Err(PlsError::BadComponentNumberError {
200                upperbound: rank_upper_bound,
201                actual: n_components,
202            });
203        }
204        let norm_y_weights = self.deflation_mode() == DeflationMode::Canonical;
205        let (mut xk, mut yk, x_mean, y_mean, x_std, y_std) =
206            utils::center_scale_dataset(dataset, self.scale());
207
208        let mut x_weights = Array2::<F>::zeros((p, n_components)); // U
209        let mut y_weights = Array2::<F>::zeros((q, n_components)); // V
210        let mut x_scores = Array2::<F>::zeros((n, n_components)); // xi
211        let mut y_scores = Array2::<F>::zeros((n, n_components)); // Omega
212        let mut x_loadings = Array2::<F>::zeros((p, n_components)); // Gamma
213        let mut y_loadings = Array2::<F>::zeros((q, n_components)); // Delta
214        let mut n_iters = Array1::zeros(n_components);
215
216        // This whole thing corresponds to the algorithm in section 4.1 of the
217        // review from Wegelin. See above for a notation mapping from code to
218        // paper.
219        let eps = F::epsilon();
220        for k in 0..n_components {
221            // Find first left and right singular vectors of the x.T.dot(Y)
222            // cross-covariance matrix.
223
224            let (mut x_weights_k, mut y_weights_k) = match self.algorithm() {
225                Algorithm::Nipals => {
226                    // Replace columns that are all close to zero with zeros
227                    for mut yj in yk.columns_mut() {
228                        if *(yj.mapv(|y| y.abs()).max()?) < F::cast(10.) * eps {
229                            yj.assign(&Array1::zeros(yj.len()));
230                        }
231                    }
232
233                    let (x_weights_k, y_weights_k, n_iter) =
234                        self.get_first_singular_vectors_power_method(&xk, &yk, norm_y_weights)?;
235                    n_iters[k] = n_iter;
236                    (x_weights_k, y_weights_k)
237                }
238                Algorithm::Svd => self.get_first_singular_vectors_svd(&xk, &yk)?,
239            };
240            utils::svd_flip_1d(&mut x_weights_k, &mut y_weights_k);
241
242            // compute scores, i.e. the projections of x and Y
243            let x_scores_k = xk.dot(&x_weights_k);
244            let y_ss = if norm_y_weights {
245                F::one()
246            } else {
247                y_weights_k.dot(&y_weights_k)
248            };
249            let y_scores_k = yk.dot(&y_weights_k) / y_ss;
250
251            // Deflation: subtract rank-one approx to obtain xk+1 and yk+1
252            let x_loadings_k = x_scores_k.dot(&xk) / x_scores_k.dot(&x_scores_k);
253            xk = xk - utils::outer(&x_scores_k, &x_loadings_k); // outer product
254
255            let y_loadings_k = match self.deflation_mode() {
256                DeflationMode::Canonical => {
257                    // regress yk on y_score
258                    let y_loadings_k = y_scores_k.dot(&yk) / y_scores_k.dot(&y_scores_k);
259                    yk = yk - utils::outer(&y_scores_k, &y_loadings_k); // outer product
260                    y_loadings_k
261                }
262                DeflationMode::Regression => {
263                    // regress yk on x_score
264                    let y_loadings_k = x_scores_k.dot(&yk) / x_scores_k.dot(&x_scores_k);
265                    yk = yk - utils::outer(&x_scores_k, &y_loadings_k); // outer product
266                    y_loadings_k
267                }
268            };
269
270            x_weights.column_mut(k).assign(&x_weights_k);
271            y_weights.column_mut(k).assign(&y_weights_k);
272            x_scores.column_mut(k).assign(&x_scores_k);
273            y_scores.column_mut(k).assign(&y_scores_k);
274            x_loadings.column_mut(k).assign(&x_loadings_k);
275            y_loadings.column_mut(k).assign(&y_loadings_k);
276        }
277        // x was approximated as xi . Gamma.T + x_(R+1) xi . Gamma.T is a sum
278        // of n_components rank-1 matrices. x_(R+1) is whatever is left
279        // to fully reconstruct x, and can be 0 if x is of rank n_components.
280        // Similiarly, Y was approximated as Omega . Delta.T + Y_(R+1)
281
282        // Compute transformation matrices (rotations_). See User Guide.
283        let x_rotations = x_weights.dot(&utils::pinv2(x_loadings.t().dot(&x_weights).view(), None));
284        let y_rotations = y_weights.dot(&utils::pinv2(y_loadings.t().dot(&y_weights).view(), None));
285
286        let mut coefficients = x_rotations.dot(&y_loadings.t());
287        coefficients *= &y_std;
288
289        Ok(Pls {
290            x_mean,
291            x_std,
292            y_mean,
293            y_std,
294            x_weights,
295            y_weights,
296            #[cfg(test)]
297            x_scores,
298            #[cfg(test)]
299            y_scores,
300            x_loadings,
301            y_loadings,
302            x_rotations,
303            y_rotations,
304            coefficients,
305        })
306    }
307}
308
309impl<F: Float> PlsValidParams<F> {
310    /// Return the first left and right singular vectors of x'Y.
311    /// Provides an alternative to the svd(x'Y) and uses the power method instead.
312    fn get_first_singular_vectors_power_method(
313        &self,
314        x: &ArrayBase<impl Data<Elem = F>, Ix2>,
315        y: &ArrayBase<impl Data<Elem = F>, Ix2>,
316        norm_y_weights: bool,
317    ) -> Result<(Array1<F>, Array1<F>, usize)> {
318        let eps = F::epsilon();
319
320        let mut y_score = None;
321        for col in y.t().rows() {
322            if *col.mapv(|v| v.abs()).max().unwrap() > eps {
323                y_score = Some(col.to_owned());
324                break;
325            }
326        }
327        let mut y_score = y_score.ok_or(PlsError::PowerMethodConstantResidualError())?;
328
329        let mut x_pinv = None;
330        let mut y_pinv = None;
331        if self.mode() == Mode::B {
332            x_pinv = Some(utils::pinv2(x.view(), Some(F::cast(10.) * eps)));
333            y_pinv = Some(utils::pinv2(y.view(), Some(F::cast(10.) * eps)));
334        }
335
336        // init to big value for first convergence check
337        let mut x_weights_old = Array1::<F>::from_elem(x.ncols(), F::cast(100.));
338
339        let mut n_iter = 1;
340        let mut x_weights = Array1::<F>::ones(x.ncols());
341        let mut y_weights = Array1::<F>::ones(y.ncols());
342        let mut converged = false;
343        while n_iter < self.max_iter() {
344            x_weights = match self.mode() {
345                Mode::A => x.t().dot(&y_score) / y_score.dot(&y_score),
346                Mode::B => x_pinv.to_owned().unwrap().dot(&y_score),
347            };
348            x_weights /= x_weights.dot(&x_weights).sqrt() + eps;
349            let x_score = x.dot(&x_weights);
350
351            y_weights = match self.mode() {
352                Mode::A => y.t().dot(&x_score) / x_score.dot(&x_score),
353                Mode::B => y_pinv.to_owned().unwrap().dot(&x_score),
354            };
355
356            if norm_y_weights {
357                y_weights /= y_weights.dot(&y_weights).sqrt() + eps
358            }
359
360            let ya = y.dot(&y_weights);
361            let yb = y_weights.dot(&y_weights) + eps;
362            y_score = ya.mapv(|v| v / yb);
363
364            let x_weights_diff = &x_weights - &x_weights_old;
365            if x_weights_diff.dot(&x_weights_diff) < self.tolerance() || y.ncols() == 1 {
366                converged = true;
367                break;
368            } else {
369                x_weights_old = x_weights.to_owned();
370                n_iter += 1;
371            }
372        }
373        if n_iter == self.max_iter() && !converged {
374            Err(PlsError::PowerMethodNotConvergedError(self.max_iter()))
375        } else {
376            Ok((x_weights, y_weights, n_iter))
377        }
378    }
379
380    fn get_first_singular_vectors_svd(
381        &self,
382        x: &ArrayBase<impl Data<Elem = F>, Ix2>,
383        y: &ArrayBase<impl Data<Elem = F>, Ix2>,
384    ) -> Result<(Array1<F>, Array1<F>)> {
385        let c = x.t().dot(y);
386
387        let c = c.with_lapack();
388        let (u, s, vt) = c.svd(true, true)?;
389        // Extract the SVD component corresponding to the largest singular-value
390        // XXX We should compute the partial SVD instead of full SVD
391        let max = s.argmax()?;
392        let u = u.unwrap().column(max).to_owned().without_lapack();
393        let vt = vt.unwrap().row(max).to_owned().without_lapack();
394
395        Ok((u, vt))
396    }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use linfa::{dataset::Records, traits::Predict, ParamGuard};
    use linfa_datasets::linnerud;
    use ndarray::{array, concatenate, Array, Axis};
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::StandardNormal;
    use ndarray_rand::RandomExt;
    use rand_xoshiro::Xoshiro256Plus;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<PlsParams<f64>>();
        has_autotraits::<PlsValidParams<f64>>();
        has_autotraits::<Pls<f64>>();
        has_autotraits::<PlsError>();
    }

    /// Assert that the columns of `m` are orthonormal (`m^T m == I`).
    fn assert_matrix_orthonormal(m: &Array2<f64>) {
        assert_abs_diff_eq!(&m.t().dot(m), &Array::eye(m.ncols()), epsilon = 1e-7);
    }

    /// Assert that the columns of `m` are pairwise orthogonal (`m^T m` is diagonal).
    fn assert_matrix_orthogonal(m: &Array2<f64>) {
        let k = m.t().dot(m);
        assert_abs_diff_eq!(&k, &Array::from_diag(&k.diag()), epsilon = 1e-7);
    }

    #[test]
    fn test_pls_canonical_basics() -> Result<()> {
        // Basic checks for PLSCanonical
        let dataset = linnerud();
        let records = dataset.records();

        let pls = Pls::canonical(records.ncols()).fit(&dataset)?;

        let (x_weights, y_weights) = pls.weights();
        assert_matrix_orthonormal(x_weights);
        assert_matrix_orthonormal(y_weights);

        let (x_scores, y_scores) = pls.scores();
        assert_matrix_orthogonal(x_scores);
        assert_matrix_orthogonal(y_scores);

        // Check X = TP' and Y = UQ'
        let (p, q) = pls.loadings();
        let t = x_scores;
        let u = y_scores;

        // Need to scale first
        let (xc, yc, ..) = utils::center_scale_dataset(&dataset, true);
        assert_abs_diff_eq!(&xc, &t.dot(&p.t()), epsilon = 1e-7);
        assert_abs_diff_eq!(&yc, &u.dot(&q.t()), epsilon = 1e-7);

        // Check that rotations on training data lead to scores
        let ds = pls.transform(dataset);
        assert_abs_diff_eq!(ds.records(), x_scores, epsilon = 1e-7);
        assert_abs_diff_eq!(ds.targets(), y_scores, epsilon = 1e-7);

        Ok(())
    }

    #[test]
    fn test_sanity_check_pls_regression() {
        let dataset = linnerud();
        let pls = Pls::regression(3)
            .fit(&dataset)
            .expect("PLS fitting failed");

        // The results were checked against scikit-learn 0.24 PlsRegression
        let expected_x_weights = array![
            [0.61330704, -0.00443647, 0.78983213],
            [0.74697144, -0.32172099, -0.58183269],
            [0.25668686, 0.94682413, -0.19399983]
        ];

        let expected_x_loadings = array![
            [0.61470416, -0.24574278, 0.78983213],
            [0.65625755, -0.14396183, -0.58183269],
            [0.51733059, 1.00609417, -0.19399983]
        ];

        let expected_y_weights = array![
            [-0.32456184, 0.29892183, 0.20316322],
            [-0.42439636, 0.61970543, 0.19320542],
            [0.13143144, -0.26348971, -0.17092916]
        ];

        let expected_y_loadings = array![
            [-0.32456184, 0.29892183, 0.20316322],
            [-0.42439636, 0.61970543, 0.19320542],
            [0.13143144, -0.26348971, -0.17092916]
        ];
        assert_abs_diff_eq!(pls.x_weights, expected_x_weights, epsilon = 1e-6);
        assert_abs_diff_eq!(pls.x_loadings, expected_x_loadings, epsilon = 1e-6);
        assert_abs_diff_eq!(pls.y_weights, expected_y_weights, epsilon = 1e-6);
        assert_abs_diff_eq!(pls.y_loadings, expected_y_loadings, epsilon = 1e-6);
    }

    #[test]
    fn test_sanity_check_pls_regression_constant_column_y() {
        let mut dataset = linnerud();
        let nrows = dataset.targets.nrows();
        dataset.targets.column_mut(0).assign(&Array1::ones(nrows));
        let pls = Pls::regression(3)
            .fit(&dataset)
            .expect("PLS fitting failed");

        // The results were checked against scikit-learn 0.24 PlsRegression
        let expected_x_weights = array![
            [0.6273573, 0.007081799, 0.7786994],
            [0.7493417, -0.277612681, -0.6011807],
            [0.2119194, 0.960666981, -0.1794690]
        ];

        let expected_x_loadings = array![
            [0.6273512, -0.22464538, 0.7786994],
            [0.6643156, -0.09871193, -0.6011807],
            [0.5125877, 1.01407380, -0.1794690]
        ];

        let expected_y_loadings = array![
            [0.0000000, 0.0000000, 0.0000000],
            [-0.4357300, 0.5828479, 0.2174802],
            [0.1353739, -0.2486423, -0.1810386]
        ];
        assert_abs_diff_eq!(pls.x_weights, expected_x_weights, epsilon = 1e-6);
        assert_abs_diff_eq!(pls.x_loadings, expected_x_loadings, epsilon = 1e-6);
        // For the PLSRegression with default parameters, y_loadings == y_weights
        assert_abs_diff_eq!(pls.y_loadings, expected_y_loadings, epsilon = 1e-6);
        assert_abs_diff_eq!(pls.y_weights, expected_y_loadings, epsilon = 1e-6);
    }

    #[test]
    fn test_sanity_check_pls_canonical() -> Result<()> {
        // Sanity check for PLSCanonical
        // The results were checked against the R-package plspm
        let dataset = linnerud();
        let pls = Pls::canonical(dataset.records().ncols()).fit(&dataset)?;

        let expected_x_weights = array![
            [-0.61330704, 0.25616119, -0.74715187],
            [-0.74697144, 0.11930791, 0.65406368],
            [-0.25668686, -0.95924297, -0.11817271]
        ];

        let expected_x_rotations = array![
            [-0.61330704, 0.41591889, -0.62297525],
            [-0.74697144, 0.31388326, 0.77368233],
            [-0.25668686, -0.89237972, -0.24121788]
        ];

        let expected_y_weights = array![
            [0.58989127, 0.7890047, 0.1717553],
            [0.77134053, -0.61351791, 0.16920272],
            [-0.23887670, -0.03267062, 0.97050016]
        ];

        let expected_y_rotations = array![
            [0.58989127, 0.7168115, 0.30665872],
            [0.77134053, -0.70791757, 0.19786539],
            [-0.23887670, -0.00343595, 0.94162826]
        ];

        let (x_weights, y_weights) = pls.weights();
        let (x_rotations, y_rotations) = pls.rotations();
        assert_abs_diff_eq!(
            expected_x_rotations.mapv(|v: f64| v.abs()),
            x_rotations.mapv(|v| v.abs()),
            epsilon = 1e-7
        );
        assert_abs_diff_eq!(
            expected_x_weights.mapv(|v: f64| v.abs()),
            x_weights.mapv(|v| v.abs()),
            epsilon = 1e-7
        );
        assert_abs_diff_eq!(
            expected_y_rotations.mapv(|v: f64| v.abs()),
            y_rotations.mapv(|v| v.abs()),
            epsilon = 1e-7
        );
        assert_abs_diff_eq!(
            expected_y_weights.mapv(|v: f64| v.abs()),
            y_weights.mapv(|v| v.abs()),
            epsilon = 1e-7
        );

        // Signs may be flipped, but consistently between rotations and weights.
        let x_rotations_sign_flip = (x_rotations / &expected_x_rotations).mapv(|v| v.signum());
        let x_weights_sign_flip = (x_weights / &expected_x_weights).mapv(|v| v.signum());
        let y_rotations_sign_flip = (y_rotations / &expected_y_rotations).mapv(|v| v.signum());
        let y_weights_sign_flip = (y_weights / &expected_y_weights).mapv(|v| v.signum());
        assert_abs_diff_eq!(x_rotations_sign_flip, x_weights_sign_flip);
        assert_abs_diff_eq!(y_rotations_sign_flip, y_weights_sign_flip);

        assert_matrix_orthonormal(x_weights);
        assert_matrix_orthonormal(y_weights);

        let (x_scores, y_scores) = pls.scores();
        assert_matrix_orthogonal(x_scores);
        assert_matrix_orthogonal(y_scores);
        Ok(())
    }

    #[test]
    fn test_sanity_check_pls_canonical_random() {
        // Sanity check for PLSCanonical on random data
        // The results were checked against the R-package plspm
        let n = 500;
        let p_noise = 10;
        let q_noise = 5;

        // 2 latent variables, each shared by two columns of X and of Y.
        let mut rng = Xoshiro256Plus::seed_from_u64(100);
        let l1: Array1<f64> = Array1::random_using(n, StandardNormal, &mut rng);
        let l2: Array1<f64> = Array1::random_using(n, StandardNormal, &mut rng);
        let mut latents = Array::zeros((4, n));
        latents.row_mut(0).assign(&l1);
        latents.row_mut(1).assign(&l1);
        latents.row_mut(2).assign(&l2);
        latents.row_mut(3).assign(&l2);
        latents = latents.reversed_axes();

        let mut x = &latents + &Array2::<f64>::random_using((n, 4), StandardNormal, &mut rng);
        let mut y = latents + &Array2::<f64>::random_using((n, 4), StandardNormal, &mut rng);

        x = concatenate(
            Axis(1),
            &[
                x.view(),
                Array2::random_using((n, p_noise), StandardNormal, &mut rng).view(),
            ],
        )
        .unwrap();
        y = concatenate(
            Axis(1),
            &[
                y.view(),
                Array2::random_using((n, q_noise), StandardNormal, &mut rng).view(),
            ],
        )
        .unwrap();

        let ds = Dataset::new(x, y);
        let pls = Pls::canonical(3)
            .fit(&ds)
            .expect("PLS canonical fitting failed");

        let (x_weights, y_weights) = pls.weights();
        assert_matrix_orthonormal(x_weights);
        assert_matrix_orthonormal(y_weights);

        let (x_scores, y_scores) = pls.scores();
        assert_matrix_orthogonal(x_scores);
        assert_matrix_orthogonal(y_scores);
    }

    #[test]
    fn test_scale_and_stability() -> Result<()> {
        // scale=True is equivalent to scale=False on centered/scaled data
        // This allows to check numerical stability over platforms as well

        let ds = linnerud();
        let (x_s, y_s, ..) = utils::center_scale_dataset(&ds, true);
        let ds_s = Dataset::new(x_s, y_s);

        let ds_score = Pls::regression(2)
            .scale(true)
            .tolerance(1e-3)
            .fit(&ds)?
            .transform(ds.to_owned());
        let ds_s_score = Pls::regression(2)
            .scale(false)
            .tolerance(1e-3)
            .fit(&ds_s)?
            .transform(ds_s.to_owned());

        assert_abs_diff_eq!(ds_s_score.records(), ds_score.records(), epsilon = 1e-4);
        assert_abs_diff_eq!(ds_s_score.targets(), ds_score.targets(), epsilon = 1e-4);
        Ok(())
    }

    #[test]
    fn test_one_component_equivalence() -> Result<()> {
        // PlsRegression and PlsCanonical give the same X scores when n_components is 1
        let ds = linnerud();
        let ds2 = linnerud();
        let regression = Pls::regression(1).fit(&ds)?.transform(ds);
        let canonical = Pls::canonical(1).fit(&ds2)?.transform(ds2);

        assert_abs_diff_eq!(regression.records(), canonical.records(), epsilon = 1e-7);
        Ok(())
    }

    #[test]
    fn test_convergence_fail() {
        let ds = linnerud();
        assert!(
            Pls::canonical(ds.records().nfeatures())
                .max_iterations(2)
                .fit(&ds)
                .is_err(),
            "PLS power method should not converge, hence raise an error"
        );
    }

    #[test]
    fn test_bad_component_number() {
        let ds = linnerud();
        assert!(
            Pls::cca(ds.records().nfeatures() + 1).fit(&ds).is_err(),
            "n_components too large should raise an error"
        );
        assert!(
            Pls::canonical(0).fit(&ds).is_err(),
            "n_components=0 should raise an error"
        );
    }

    #[test]
    fn test_singular_value_helpers() -> Result<()> {
        // Make sure SVD and power method give approximately the same results
        let ds = linnerud();

        let (mut u1, mut v1, _) = PlsParams::new(2)
            .check()?
            .get_first_singular_vectors_power_method(ds.records(), ds.targets(), true)?;
        let (mut u2, mut v2) = PlsParams::new(2)
            .check()?
            .get_first_singular_vectors_svd(ds.records(), ds.targets())?;

        utils::svd_flip_1d(&mut u1, &mut v1);
        utils::svd_flip_1d(&mut u2, &mut v2);

        // Absolute tolerance: the power method only approximates the SVD vectors.
        let atol = 1e-1;
        assert_abs_diff_eq!(u1, u2, epsilon = atol);
        assert_abs_diff_eq!(v1, v2, epsilon = atol);
        Ok(())
    }

    macro_rules! test_pls_algo_nipals_svd {
        ($($name:ident, )*) => {
            paste::item! {
                $(
                    #[test]
                    fn [<test_pls_$name>]() -> Result<()> {
                        let ds = linnerud();
                        let pls = Pls::[<$name>](3).fit(&ds)?;
                        let ds1 = pls.transform(ds.to_owned());
                        let ds2 = Pls::[<$name>](3).algorithm(Algorithm::Svd).fit(&ds)?.transform(ds);
                        assert_abs_diff_eq!(ds1.records(), ds2.records(), epsilon=1e-2);
                        let exercises = array![[14., 146., 61.], [6., 80., 60.]];
                        let physios = pls.predict(exercises);
                        // One prediction row per input row, one column per target.
                        assert_eq!(physios.targets().dim(), (2, 3));
                        Ok(())
                    }
                )*
            }
        };
    }

    test_pls_algo_nipals_svd! {
        canonical, regression,
    }

    #[test]
    fn test_cca() {
        // values checked against scikit-learn 0.24.1 CCA
        let ds = linnerud();
        let cca = Pls::cca(3).fit(&ds).unwrap();
        let ds = cca.transform(ds);
        let expected_x = array![
            [0.09597886, 0.13862931, -1.0311966],
            [-0.7170194, 0.25195026, -0.83049671],
            [-0.76492193, 0.37601463, 1.20714686],
            [-0.03734329, -0.9746487, 0.79363542],
            [0.42809962, -0.50053551, 0.40089685],
            [-0.54141144, -0.29403268, -0.47221389],
            [-0.29901672, -0.67023009, 0.17945745],
            [-0.11425233, -0.43360723, -0.47235823],
            [1.29212153, -0.9373391, 0.02572464],
            [-0.17770025, 3.4785377, 0.8486413],
            [0.39344638, -1.28718499, 1.43816035],
            [0.52667844, 0.82080301, -0.02624471],
            [0.74616393, 0.54578854, 0.01825073],
            [-1.42623443, -0.00884605, -0.24019883],
            [-0.72026991, -0.73588273, 0.2241694],
            [0.4237932, 0.99977428, -0.1667137],
            [-0.88437821, -0.73784626, -0.01073894],
            [1.05159992, 0.26381077, -0.83138216],
            [1.26196754, -0.18618728, -0.12863494],
            [-0.53730151, -0.10896789, -0.92590428]
        ];
        assert_abs_diff_eq!(expected_x, ds.records(), epsilon = 1e-2);
    }

    #[test]
    fn test_transform_and_inverse() -> Result<()> {
        let ds = linnerud();
        let pls = Pls::canonical(3).fit(&ds)?;

        let ds_proj = pls.transform(ds);
        let ds_orig = pls.inverse_transform(ds_proj);

        let ds = linnerud();
        assert_abs_diff_eq!(ds.records(), ds_orig.records(), epsilon = 1e-6);
        assert_abs_diff_eq!(ds.targets(), ds_orig.targets(), epsilon = 1e-6);
        Ok(())
    }

    #[test]
    fn test_pls_constant_y() {
        // Checks constant residual error when y is constant.
        let n = 100;
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let x = Array2::<f64>::random_using((n, 3), StandardNormal, &mut rng);
        let y = Array2::zeros((n, 1));
        let ds = Dataset::new(x, y);
        assert!(matches!(
            Pls::regression(2).fit(&ds).unwrap_err(),
            PlsError::PowerMethodConstantResidualError()
        ));
    }
}