1mod errors;
41mod hyperparams;
42mod pls_generic;
43pub use pls_generic::*;
44mod pls_svd;
45mod utils;
46
47use linfa::{traits::Fit, traits::PredictInplace, traits::Transformer, DatasetBase, Float};
48use ndarray::{Array2, ArrayBase, Data, Ix2};
49
50pub use errors::*;
51pub use hyperparams::*;
52pub use pls_svd::*;
53
54#[cfg(feature = "serde")]
55use serde_crate::{Deserialize, Serialize};
56
57macro_rules! pls_algo { ($name:ident) => {
58 paste::item! {
59 #[cfg_attr(
60 feature = "serde",
61 derive(Serialize, Deserialize),
62 serde(crate = "serde_crate")
63 )]
64 #[derive(Debug, Clone, PartialEq)]
65 pub struct [<Pls $name>]<F: Float>(Pls<F>);
66 impl<F: Float> [<Pls $name>]<F> {
67
68 pub fn params(n_components: usize) -> [<Pls $name Params>]<F> {
69 [<Pls $name Params>]([<Pls $name ValidParams>](Pls::[<$name:lower>](n_components).0))
70 }
71
72 pub fn weights(&self) -> (&Array2<F>, &Array2<F>) {
74 self.0.weights()
75 }
76
77 pub fn loadings(&self) -> (&Array2<F>, &Array2<F>) {
79 self.0.loadings()
80 }
81
82 pub fn rotations(&self) -> (&Array2<F>, &Array2<F>) {
84 self.0.rotations()
85 }
86
87 pub fn coefficients(&self) -> &Array2<F> {
89 self.0.coefficients()
90 }
91
92 pub fn inverse_transform(
94 &self,
95 dataset: DatasetBase<
96 ArrayBase<impl Data<Elem = F>, Ix2>,
97 ArrayBase<impl Data<Elem = F>, Ix2>,
98 >,
99 ) -> DatasetBase<Array2<F>, Array2<F>> {
100 self.0.inverse_transform(dataset)
101 }
102 }
103
104 impl<F: Float, D: Data<Elem = F>> Fit<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>, PlsError>
105 for [<Pls $name ValidParams>]<F>
106 {
107 type Object = [<Pls $name>]<F>;
108 fn fit(
109 &self,
110 dataset: &DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
111 ) -> Result<Self::Object> {
112 let pls = self.0.fit(dataset)?;
113 Ok([<Pls $name>](pls))
114 }
115 }
116
117 impl<F: Float, D: Data<Elem = F>> Transformer<
118 DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
119 DatasetBase<Array2<F>, Array2<F>>,
120 > for [<Pls $name>]<F>
121 {
122 fn transform(
124 &self,
125 dataset: DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
126 ) -> DatasetBase<Array2<F>, Array2<F>> {
127 self.0.transform(dataset)
128 }
129 }
130
131 impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array2<F>> for [<Pls $name>]<F> {
132 fn predict_inplace(&'_ self, x: &ArrayBase<D, Ix2>, y: &mut Array2<F>) {
136 self.0.predict_inplace(x, y);
137 }
138
139 fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array2<F> {
140 self.0.default_target(x)
141 }
142 }
143 }
144}}
145
146pls_algo!(Regression);
147pls_algo!(Canonical);
148pls_algo!(Cca);
149
#[cfg(test)]
mod test {
    use super::*;
    use approx::assert_abs_diff_eq;
    use linfa::traits::{Fit, Predict, Transformer};
    use linfa_datasets::linnerud;
    use ndarray::array;

    // Generates one smoke/regression test per PLS variant.
    //
    // * `Svd` arm: only checks that fitting and transforming the linnerud
    //   dataset succeeds.
    // * `(name, expected)` arm: fits a two-component model, transforms the
    //   training data, then compares predictions for two fixed samples against
    //   `expected` with an absolute tolerance of 1e-2.
    macro_rules! test_pls_algo {
        (Svd) => {
            #[test]
            fn test_pls_svd() -> Result<()> {
                let dataset = linnerud();
                let model = PlsSvd::<f64>::params(3).fit(&dataset)?;
                let _projected = model.transform(dataset);
                Ok(())
            }
        };

        ($name:ident, $expected:expr) => {
            paste::item! {
                #[test]
                fn [<test_pls_$name:lower>]() -> Result<()> {
                    let dataset = linnerud();
                    let model = [<Pls $name>]::<f64>::params(2).fit(&dataset)?;
                    let _projected = model.transform(dataset);

                    // Two fixed input rows used as prediction inputs.
                    let exercises = array![[14., 146., 61.], [6., 80., 60.]];
                    let predicted = model.predict(exercises);
                    assert_abs_diff_eq!($expected, predicted.targets(), epsilon = 1e-2);
                    Ok(())
                }
            }
        };
    }

    test_pls_algo!(
        Canonical,
        array![
            [180.56979423, 33.29543984, 56.90850758],
            [190.854022, 38.91963398, 53.26914489]
        ]
    );
    test_pls_algo!(
        Regression,
        array![
            [172.39580643, 34.11919145, 57.15430526],
            [192.11167813, 38.05058858, 53.99844922]
        ]
    );
    test_pls_algo!(
        Cca,
        array![
            [181.56238421, 34.42502589, 57.31447865],
            [205.11767414, 40.23445194, 52.26494323]
        ]
    );
    test_pls_algo!(Svd);

    // With a single component, the records produced by the regression,
    // canonical and SVD variants are expected to agree to within 1e-5.
    #[test]
    fn test_one_component_equivalence() -> Result<()> {
        let train = linnerud();

        let regression_model = PlsRegression::params(1).fit(&train)?;
        let regression = regression_model.transform(linnerud());

        let canonical_model = PlsCanonical::params(1).fit(&train)?;
        let canonical = canonical_model.transform(linnerud());

        let svd_model = PlsSvd::<f64>::params(1).fit(&train)?;
        let svd = svd_model.transform(linnerud());

        assert_abs_diff_eq!(regression.records(), canonical.records(), epsilon = 1e-5);
        assert_abs_diff_eq!(svd.records(), canonical.records(), epsilon = 1e-5);
        Ok(())
    }
}