linfa_pls/
pls_svd.rs

1use crate::errors::{PlsError, Result};
2use crate::{utils, Float};
3use linfa::dataset::{WithLapack, WithoutLapack};
4use linfa::{dataset::Records, traits::Fit, traits::Transformer, DatasetBase};
5#[cfg(not(feature = "blas"))]
6use linfa_linalg::svd::*;
7use ndarray::{s, Array1, Array2, ArrayBase, Data, Ix2};
8#[cfg(feature = "blas")]
9use ndarray_linalg::svd::*;
10#[cfg(feature = "serde")]
11use serde_crate::{Deserialize, Serialize};
12
/// Hyperparameters used to fit a PLS-SVD model.
///
/// Create a value with [`PlsSvdParams::new`] and adjust it with the
/// builder-style [`PlsSvdParams::scale`] setter.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PlsSvdParams {
    /// Number of components to keep; checked against the dataset when fitting.
    n_components: usize,
    /// Flag handed to `utils::center_scale_dataset` when fitting.
    scale: bool,
}

impl PlsSvdParams {
    /// Creates parameters requesting `n_components` components, with `scale` set to `true`.
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            scale: true,
        }
    }

    /// Replaces the `scale` flag and hands the parameters back (builder style).
    pub fn scale(self, scale: bool) -> Self {
        Self {
            scale,
            ..self
        }
    }
}

impl Default for PlsSvdParams {
    /// Two components, with `scale` set to `true`.
    fn default() -> Self {
        PlsSvdParams::new(2)
    }
}
43
44#[allow(clippy::many_single_char_names)]
45impl<F: Float, D: Data<Elem = F>> Fit<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>, PlsError>
46    for PlsSvdParams
47{
48    type Object = PlsSvd<F>;
49
50    fn fit(
51        &self,
52        dataset: &DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
53    ) -> Result<Self::Object> {
54        if dataset.nsamples() < 2 {
55            return Err(PlsError::NotEnoughSamplesError(
56                dataset.records().nsamples(),
57            ));
58        }
59        // we'll compute the SVD of the cross-covariance matrix = X.T.dot(Y)
60        // This matrix rank is at most min(n_samples, n_features, n_targets) so
61        // n_components cannot be bigger than that.
62
63        let rank_upper_bound = dataset
64            .nsamples()
65            .min(dataset.nfeatures())
66            .min(dataset.targets().ncols());
67        if 1 > self.n_components || self.n_components > rank_upper_bound {
68            return Err(PlsError::BadComponentNumberError {
69                upperbound: rank_upper_bound,
70                actual: self.n_components,
71            });
72        }
73        let (x, y, x_mean, y_mean, x_std, y_std) = utils::center_scale_dataset(dataset, self.scale);
74
75        // Compute SVD of cross-covariance matrix
76        let c = x.t().dot(&y);
77        let d = c.with_lapack().svd(true, true)?;
78        #[cfg(feature = "blas")]
79        let (u, _, vt) = d;
80        #[cfg(not(feature = "blas"))]
81        let (u, _, vt) = d.sort_svd_desc();
82        // safe unwraps because both parameters are set to true in above call
83        let u = u.unwrap().slice_move(s![.., ..self.n_components]);
84        let vt = vt.unwrap().slice_move(s![..self.n_components, ..]);
85        let (u, vt) = utils::svd_flip(u, vt);
86        let v = vt.reversed_axes();
87
88        let x_weights = u.without_lapack();
89        let y_weights = v.without_lapack();
90
91        Ok(PlsSvd {
92            x_mean,
93            x_std,
94            y_mean,
95            y_std,
96            x_weights,
97            y_weights,
98        })
99    }
100}
101#[derive(Debug, Clone, PartialEq)]
102pub struct PlsSvd<F: Float> {
103    x_mean: Array1<F>,
104    x_std: Array1<F>,
105    y_mean: Array1<F>,
106    y_std: Array1<F>,
107    x_weights: Array2<F>,
108    y_weights: Array2<F>,
109}
110
111impl<F: Float> PlsSvd<F> {
112    pub fn params(n_components: usize) -> PlsSvdParams {
113        PlsSvdParams {
114            n_components,
115            scale: true,
116        }
117    }
118
119    pub(crate) fn means(&self) -> (&Array1<F>, &Array1<F>) {
120        (&self.x_mean, &self.y_mean)
121    }
122
123    pub(crate) fn stds(&self) -> (&Array1<F>, &Array1<F>) {
124        (&self.x_std, &self.y_std)
125    }
126
127    pub fn weights(&self) -> (&Array2<F>, &Array2<F>) {
128        (&self.x_weights, &self.y_weights)
129    }
130}
131
132impl<F: Float, D: Data<Elem = F>>
133    Transformer<
134        DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
135        DatasetBase<Array2<F>, Array2<F>>,
136    > for PlsSvd<F>
137{
138    fn transform(
139        &self,
140        dataset: DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
141    ) -> DatasetBase<Array2<F>, Array2<F>> {
142        let (x_mean, y_mean) = &self.means();
143        let (x_std, y_std) = &self.stds();
144        let (x_weights, y_weights) = &self.weights();
145        let xr = (dataset.records() - *x_mean) / *x_std;
146        let x_scores = xr.dot(*x_weights);
147        let yr = (dataset.targets() - *y_mean) / *y_std;
148        let y_scores = yr.dot(*y_weights);
149        DatasetBase::new(x_scores, y_scores)
150    }
151}
152
#[cfg(test)]
mod test {
    use super::*;
    use approx::assert_abs_diff_eq;
    use linfa_datasets::linnerud;
    use ndarray::array;

    /// Compile-time check that the public types keep their auto traits.
    #[test]
    fn autotraits() {
        fn assert_autotraits<T: Send + Sync + Sized + Unpin>() {}
        assert_autotraits::<PlsSvd<f64>>();
        assert_autotraits::<PlsSvdParams>();
    }

    /// Fits three components on the linnerud dataset and compares the
    /// projected records with reference scores.
    #[test]
    fn test_svd() -> Result<()> {
        // values checked against scikit-learn 0.24.1 PlsSVD
        let expected_x = array![
            [-0.37144954, -0.0544441, -0.82290137],
            [-1.34032497, 0.19638169, -0.71715313],
            [-0.08234873, 0.58492788, 0.86557407],
            [-0.35496515, -0.62863268, 0.74383396],
            [0.46311708, -0.39856773, 0.39748814],
            [-1.30584148, -0.20072641, -0.3591439],
            [-0.86178968, -0.43791399, 0.2111225],
            [-0.79728366, -0.3790222, -0.32195725],
            [1.14229739, -0.93000533, 0.19761764],
            [3.03443501, 2.81149299, 0.22224139],
            [0.40921689, -0.84959246, 1.30923934],
            [1.40508381, 0.53658054, -0.09910248],
            [1.53073864, 0.29558804, -0.01949986],
            [-2.2227316, 0.19806308, -0.2536748],
            [-1.49897159, -0.4114628, 0.23494514],
            [1.3140941, 0.67110308, -0.2366431],
            [-1.88043225, -0.41844445, 0.04307104],
            [1.23661961, -0.09041449, -0.63734812],
            [1.60595982, -0.37158339, -0.01919568],
            [-1.42542371, -0.12332727, -0.73851355]
        ];

        let dataset = linnerud();
        let model = PlsSvd::<f64>::params(3).fit(&dataset)?;
        let projected = model.transform(dataset);

        assert_abs_diff_eq!(expected_x, projected.records(), epsilon = 1e-6);
        Ok(())
    }
}