1use crate::errors::{PlsError, Result};
2use crate::{utils, Float};
3use linfa::dataset::{WithLapack, WithoutLapack};
4use linfa::{dataset::Records, traits::Fit, traits::Transformer, DatasetBase};
5#[cfg(not(feature = "blas"))]
6use linfa_linalg::svd::*;
7use ndarray::{s, Array1, Array2, ArrayBase, Data, Ix2};
8#[cfg(feature = "blas")]
9use ndarray_linalg::svd::*;
10#[cfg(feature = "serde")]
11use serde_crate::{Deserialize, Serialize};
12
/// Hyper-parameters for fitting a [`PlsSvd`] model.
///
/// Build with [`PlsSvdParams::new`] (or [`Default`], which asks for two
/// components) and then call `fit` on a dataset.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct PlsSvdParams {
    /// Number of singular-vector pairs kept; validated against the rank bound in `fit`.
    n_components: usize,
    /// Flag forwarded to `utils::center_scale_dataset` when fitting.
    scale: bool,
}
23
24impl PlsSvdParams {
25 pub fn new(n_components: usize) -> PlsSvdParams {
26 PlsSvdParams {
27 n_components,
28 scale: true,
29 }
30 }
31
32 pub fn scale(mut self, scale: bool) -> Self {
33 self.scale = scale;
34 self
35 }
36}
37
38impl Default for PlsSvdParams {
39 fn default() -> Self {
40 Self::new(2)
41 }
42}
43
44#[allow(clippy::many_single_char_names)]
45impl<F: Float, D: Data<Elem = F>> Fit<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>, PlsError>
46 for PlsSvdParams
47{
48 type Object = PlsSvd<F>;
49
50 fn fit(
51 &self,
52 dataset: &DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
53 ) -> Result<Self::Object> {
54 if dataset.nsamples() < 2 {
55 return Err(PlsError::NotEnoughSamplesError(
56 dataset.records().nsamples(),
57 ));
58 }
59 let rank_upper_bound = dataset
64 .nsamples()
65 .min(dataset.nfeatures())
66 .min(dataset.targets().ncols());
67 if 1 > self.n_components || self.n_components > rank_upper_bound {
68 return Err(PlsError::BadComponentNumberError {
69 upperbound: rank_upper_bound,
70 actual: self.n_components,
71 });
72 }
73 let (x, y, x_mean, y_mean, x_std, y_std) = utils::center_scale_dataset(dataset, self.scale);
74
75 let c = x.t().dot(&y);
77 let d = c.with_lapack().svd(true, true)?;
78 #[cfg(feature = "blas")]
79 let (u, _, vt) = d;
80 #[cfg(not(feature = "blas"))]
81 let (u, _, vt) = d.sort_svd_desc();
82 let u = u.unwrap().slice_move(s![.., ..self.n_components]);
84 let vt = vt.unwrap().slice_move(s![..self.n_components, ..]);
85 let (u, vt) = utils::svd_flip(u, vt);
86 let v = vt.reversed_axes();
87
88 let x_weights = u.without_lapack();
89 let y_weights = v.without_lapack();
90
91 Ok(PlsSvd {
92 x_mean,
93 x_std,
94 y_mean,
95 y_std,
96 x_weights,
97 y_weights,
98 })
99 }
100}
101#[derive(Debug, Clone, PartialEq)]
102pub struct PlsSvd<F: Float> {
103 x_mean: Array1<F>,
104 x_std: Array1<F>,
105 y_mean: Array1<F>,
106 y_std: Array1<F>,
107 x_weights: Array2<F>,
108 y_weights: Array2<F>,
109}
110
111impl<F: Float> PlsSvd<F> {
112 pub fn params(n_components: usize) -> PlsSvdParams {
113 PlsSvdParams {
114 n_components,
115 scale: true,
116 }
117 }
118
119 pub(crate) fn means(&self) -> (&Array1<F>, &Array1<F>) {
120 (&self.x_mean, &self.y_mean)
121 }
122
123 pub(crate) fn stds(&self) -> (&Array1<F>, &Array1<F>) {
124 (&self.x_std, &self.y_std)
125 }
126
127 pub fn weights(&self) -> (&Array2<F>, &Array2<F>) {
128 (&self.x_weights, &self.y_weights)
129 }
130}
131
132impl<F: Float, D: Data<Elem = F>>
133 Transformer<
134 DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
135 DatasetBase<Array2<F>, Array2<F>>,
136 > for PlsSvd<F>
137{
138 fn transform(
139 &self,
140 dataset: DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
141 ) -> DatasetBase<Array2<F>, Array2<F>> {
142 let (x_mean, y_mean) = &self.means();
143 let (x_std, y_std) = &self.stds();
144 let (x_weights, y_weights) = &self.weights();
145 let xr = (dataset.records() - *x_mean) / *x_std;
146 let x_scores = xr.dot(*x_weights);
147 let yr = (dataset.targets() - *y_mean) / *y_std;
148 let y_scores = yr.dot(*y_weights);
149 DatasetBase::new(x_scores, y_scores)
150 }
151}
152
#[cfg(test)]
mod test {
    use super::*;
    use approx::assert_abs_diff_eq;
    use linfa_datasets::linnerud;
    use ndarray::array;

    /// The model and its parameters must stay `Send + Sync + Sized + Unpin`.
    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<PlsSvd<f64>>();
        has_autotraits::<PlsSvdParams>();
    }

    /// A 3-component fit on linnerud must reproduce the reference x-scores.
    #[test]
    fn test_svd() -> Result<()> {
        let dataset = linnerud();
        let model = PlsSvd::<f64>::params(3).fit(&dataset)?;
        let scores = model.transform(dataset);

        let expected_x_scores = array![
            [-0.37144954, -0.0544441, -0.82290137],
            [-1.34032497, 0.19638169, -0.71715313],
            [-0.08234873, 0.58492788, 0.86557407],
            [-0.35496515, -0.62863268, 0.74383396],
            [0.46311708, -0.39856773, 0.39748814],
            [-1.30584148, -0.20072641, -0.3591439],
            [-0.86178968, -0.43791399, 0.2111225],
            [-0.79728366, -0.3790222, -0.32195725],
            [1.14229739, -0.93000533, 0.19761764],
            [3.03443501, 2.81149299, 0.22224139],
            [0.40921689, -0.84959246, 1.30923934],
            [1.40508381, 0.53658054, -0.09910248],
            [1.53073864, 0.29558804, -0.01949986],
            [-2.2227316, 0.19806308, -0.2536748],
            [-1.49897159, -0.4114628, 0.23494514],
            [1.3140941, 0.67110308, -0.2366431],
            [-1.88043225, -0.41844445, 0.04307104],
            [1.23661961, -0.09041449, -0.63734812],
            [1.60595982, -0.37158339, -0.01919568],
            [-1.42542371, -0.12332727, -0.73851355]
        ];
        assert_abs_diff_eq!(expected_x_scores, scores.records(), epsilon = 1e-6);
        Ok(())
    }
}