linfa_pls/lib.rs

//! # Partial Least Squares
//!
//! `linfa-pls` provides an implementation of methods in the PLS (Partial Least Squares) family.
//! PLS is a statistical method that finds a linear relationship between
//! input variables and output variables by projecting them onto a new subspace formed
//! by newly chosen variables (aka latent variables), which are linear
//! combinations of the input variables. The subspace is chosen to maximize the
//! covariance between responses and independent variables.
//!
//! This approach is particularly useful when the original data are characterized by
//! a large number of highly collinear variables measured on a small number of samples.
//!
//! The implementation is a port of the scikit-learn 0.24 cross-decomposition module.
//!
//! ## References
//!
//! * [A survey of Partial Least Squares (PLS) methods, with emphasis on the two-block case, JA Wegelin](https://stat.uw.edu/sites/default/files/files/reports/2000/tr371.pdf)
//! * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/cross_decomposition.html#cross-decomposition)
//!
//! ## Example
//!
//! ```rust, ignore
//! use linfa::prelude::*;
//! use linfa_pls::{errors::Result, PlsRegression};
//! use ndarray::array;
//!
//! // Load the linnerud dataset: 20 samples, 3 input features, 3 output features
//! let ds = linnerud();
//!
//! // Fit the PLS2 method using 2 principal components (latent variables)
//! let pls = PlsRegression::params(2).fit(&ds)?;
//!
//! // We can either apply the dimension reduction to a dataset ...
//! let reduced_ds = pls.transform(ds);
//!
//! // ... or predict outputs given a new input sample.
//! let exercices = array![[14., 146., 61.], [6., 80., 60.]];
//! let physio_measures = pls.predict(exercices);
//! ```
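//!
//! Once fitted, the model exposes its fitted matrices through accessor methods. The snippet
//! below is a minimal sketch reusing the `pls` model from the example above; the variable
//! names on the left-hand side are illustrative:
//!
//! ```rust, ignore
//! // Coefficients of the linear model, such that Y is approximated as Y = X · coefficients
//! let coefficients = pls.coefficients();
//! // Singular vectors of the cross-covariance matrices for records (X) and targets (Y)
//! let (x_weights, y_weights) = pls.weights();
//! // Projection matrices used to transform records and targets
//! let (x_rotations, y_rotations) = pls.rotations();
//! ```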
mod errors;
mod hyperparams;
mod pls_generic;
pub use pls_generic::*;
mod pls_svd;
mod utils;

use linfa::{traits::Fit, traits::PredictInplace, traits::Transformer, DatasetBase, Float};
use ndarray::{Array2, ArrayBase, Data, Ix2};

pub use errors::*;
pub use hyperparams::*;
pub use pls_svd::*;

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

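// The `pls_algo!` macro below generates, for each PLS flavor, a thin wrapper struct around the
// generic `Pls<F>` estimator (e.g. `PlsRegression<F>`), a `params` constructor, accessors for the
// fitted matrices, and `Fit`, `Transformer` and `PredictInplace` trait implementations that
// delegate to the wrapped estimator.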
macro_rules! pls_algo { ($name:ident) => {
    paste::item! {
        #[cfg_attr(
            feature = "serde",
            derive(Serialize, Deserialize),
            serde(crate = "serde_crate")
        )]
        #[derive(Debug, Clone, PartialEq)]
        pub struct [<Pls $name>]<F: Float>(Pls<F>);
        impl<F: Float> [<Pls $name>]<F> {

            pub fn params(n_components: usize) -> [<Pls $name Params>]<F> {
                [<Pls $name Params>]([<Pls $name ValidParams>](Pls::[<$name:lower>](n_components).0))
            }

            /// Singular vectors of the cross-covariance matrices
            pub fn weights(&self) -> (&Array2<F>, &Array2<F>) {
                self.0.weights()
            }

            /// Loadings of records and targets
            pub fn loadings(&self) -> (&Array2<F>, &Array2<F>) {
                self.0.loadings()
            }

            /// Projection matrices used to transform records and targets
            pub fn rotations(&self) -> (&Array2<F>, &Array2<F>) {
                self.0.rotations()
            }

            /// The coefficients of the linear model, such that `Y` is approximated as `Y = X · coefficients`
            pub fn coefficients(&self) -> &Array2<F> {
                self.0.coefficients()
            }

            /// Transform a dataset from the projected space back to the original space.
            pub fn inverse_transform(
                &self,
                dataset: DatasetBase<
                    ArrayBase<impl Data<Elem = F>, Ix2>,
                    ArrayBase<impl Data<Elem = F>, Ix2>,
                >,
            ) -> DatasetBase<Array2<F>, Array2<F>> {
                self.0.inverse_transform(dataset)
            }
        }

        impl<F: Float, D: Data<Elem = F>> Fit<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>, PlsError>
            for [<Pls $name ValidParams>]<F>
        {
            type Object = [<Pls $name>]<F>;
            fn fit(
                &self,
                dataset: &DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
            ) -> Result<Self::Object> {
                let pls = self.0.fit(dataset)?;
                Ok([<Pls $name>](pls))
            }
        }

        impl<F: Float, D: Data<Elem = F>> Transformer<
            DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
            DatasetBase<Array2<F>, Array2<F>>,
        > for [<Pls $name>]<F>
        {
            /// Apply dimension reduction to the given dataset
            fn transform(
                &self,
                dataset: DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
            ) -> DatasetBase<Array2<F>, Array2<F>> {
                self.0.transform(dataset)
            }
        }

        impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array2<F>> for [<Pls $name>]<F> {
            /// Given an input matrix `X`, with shape `(n_samples, n_features)`,
            /// `predict` returns the target variable according to the [<Pls $name>] method
            /// learned from the training data distribution.
            fn predict_inplace(&'_ self, x: &ArrayBase<D, Ix2>, y: &mut Array2<F>) {
                self.0.predict_inplace(x, y);
            }

            fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array2<F> {
                self.0.default_target(x)
            }
        }
    }
}}

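// Concrete estimators generated from the generic implementation. They correspond to
// scikit-learn's PLSRegression, PLSCanonical and CCA estimators; the SVD-based variant
// `PlsSvd` is provided separately by the `pls_svd` module.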
pls_algo!(Regression);
pls_algo!(Canonical);
pls_algo!(Cca);

#[cfg(test)]
mod test {
    use super::*;
    use approx::assert_abs_diff_eq;
    use linfa::{traits::Fit, traits::Predict, traits::Transformer};
    use linfa_datasets::linnerud;
    use ndarray::array;

    macro_rules! test_pls_algo {
        (Svd) => {
            paste::item! {
                #[test]
                fn [<test_pls_svd>]() -> Result<()> {
                    let ds = linnerud();
                    let pls = PlsSvd::<f64>::params(3).fit(&ds)?;
                    let _ds1 = pls.transform(ds);
                    Ok(())
                }
            }
        };

        ($name:ident, $expected:expr) => {
            paste::item! {
                #[test]
                fn [<test_pls_$name:lower>]() -> Result<()> {
                    let ds = linnerud();
                    let pls = [<Pls $name>]::<f64>::params(2).fit(&ds)?;
                    let _ds1 = pls.transform(ds);
                    let exercices = array![[14., 146., 61.], [6., 80., 60.]];
                    let physios = pls.predict(exercices);
                    assert_abs_diff_eq!($expected, physios.targets(), epsilon = 1e-2);
                    Ok(())
                }
            }
        };
    }

    // Prediction values were checked against scikit-learn 0.24.1
    test_pls_algo!(
        Canonical,
        array![
            [180.56979423, 33.29543984, 56.90850758],
            [190.854022, 38.91963398, 53.26914489]
        ]
    );
    test_pls_algo!(
        Regression,
        array![
            [172.39580643, 34.11919145, 57.15430526],
            [192.11167813, 38.05058858, 53.99844922]
        ]
    );
    test_pls_algo!(
        Cca,
        array![
            [181.56238421, 34.42502589, 57.31447865],
            [205.11767414, 40.23445194, 52.26494323]
        ]
    );
    test_pls_algo!(Svd);

    #[test]
    fn test_one_component_equivalence() -> Result<()> {
        // PlsRegression, PlsSvd and PlsCanonical should all be equivalent when n_components is 1
        let ds = linnerud();
        let regression = PlsRegression::params(1).fit(&ds)?.transform(linnerud());
        let canonical = PlsCanonical::params(1).fit(&ds)?.transform(linnerud());
        let svd = PlsSvd::<f64>::params(1).fit(&ds)?.transform(linnerud());

        assert_abs_diff_eq!(regression.records(), canonical.records(), epsilon = 1e-5);
        assert_abs_diff_eq!(svd.records(), canonical.records(), epsilon = 1e-5);
        Ok(())
    }
}