// linfa_pls/hyperparams.rs

1use crate::{Algorithm, DeflationMode, Mode, PlsError};
2use linfa::{Float, ParamGuard};
3
4#[derive(Debug, Clone, PartialEq)]
5pub(crate) struct PlsValidParams<F: Float> {
6    n_components: usize,
7    max_iter: usize,
8    tolerance: F,
9    scale: bool,
10    algorithm: Algorithm,
11    deflation_mode: DeflationMode,
12    mode: Mode,
13}
14
15impl<F: Float> PlsValidParams<F> {
16    pub fn n_components(&self) -> usize {
17        self.n_components
18    }
19
20    pub fn max_iter(&self) -> usize {
21        self.max_iter
22    }
23
24    pub fn tolerance(&self) -> F {
25        self.tolerance
26    }
27
28    pub fn scale(&self) -> bool {
29        self.scale
30    }
31
32    pub fn algorithm(&self) -> Algorithm {
33        self.algorithm
34    }
35
36    pub fn deflation_mode(&self) -> DeflationMode {
37        self.deflation_mode
38    }
39
40    pub fn mode(&self) -> Mode {
41        self.mode
42    }
43}
44
45#[derive(Debug, Clone, PartialEq)]
46pub(crate) struct PlsParams<F: Float>(pub(crate) PlsValidParams<F>);
47
48impl<F: Float> PlsParams<F> {
49    pub fn new(n_components: usize) -> PlsParams<F> {
50        Self(PlsValidParams {
51            n_components,
52            max_iter: 500,
53            tolerance: F::cast(1e-6),
54            scale: true,
55            algorithm: Algorithm::Nipals,
56            deflation_mode: DeflationMode::Regression,
57            mode: Mode::A,
58        })
59    }
60
61    #[cfg(test)]
62    pub fn max_iterations(mut self, max_iter: usize) -> Self {
63        self.0.max_iter = max_iter;
64        self
65    }
66
67    #[cfg(test)]
68    pub fn tolerance(mut self, tolerance: F) -> Self {
69        self.0.tolerance = tolerance;
70        self
71    }
72
73    #[cfg(test)]
74    pub fn scale(mut self, scale: bool) -> Self {
75        self.0.scale = scale;
76        self
77    }
78
79    #[cfg(test)]
80    pub fn algorithm(mut self, algorithm: Algorithm) -> Self {
81        self.0.algorithm = algorithm;
82        self
83    }
84
85    pub fn deflation_mode(mut self, deflation_mode: DeflationMode) -> Self {
86        self.0.deflation_mode = deflation_mode;
87        self
88    }
89
90    pub fn mode(mut self, mode: Mode) -> Self {
91        self.0.mode = mode;
92        self
93    }
94}
95
96impl<F: Float> ParamGuard for PlsParams<F> {
97    type Checked = PlsValidParams<F>;
98    type Error = PlsError;
99
100    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
101        if self.0.tolerance.is_negative()
102            || self.0.tolerance.is_nan()
103            || self.0.tolerance.is_infinite()
104        {
105            Err(PlsError::InvalidTolerance(
106                self.0.tolerance().to_f32().unwrap(),
107            ))
108        } else if self.0.max_iter == 0 {
109            Err(PlsError::ZeroMaxIter)
110        } else {
111            Ok(&self.0)
112        }
113    }
114
115    fn check(self) -> Result<Self::Checked, Self::Error> {
116        self.check_ref()?;
117        Ok(self.0)
118    }
119}
120
121macro_rules! pls_algo { ($name:ident) => {
122    paste::item! {
123        pub struct [<Pls $name Params>]<F: Float>(pub(crate) [<Pls $name ValidParams>]<F>);
124        pub struct [<Pls $name ValidParams>]<F: Float>(pub(crate) PlsValidParams<F>);
125
126        impl<F: Float> [<Pls $name Params>]<F> {
127            /// Set the maximum number of iterations of the power method when algorithm='Nipals'. Ignored otherwise.
128            pub fn max_iterations(mut self, max_iter: usize) -> Self {
129                self.0.0.max_iter = max_iter;
130                self
131            }
132
133            /// Set the tolerance used as convergence criteria in the power method: the algorithm
134            /// stops whenever the squared norm of u_i - u_{i-1} is less than tol, where u corresponds
135            /// to the left singular vector.
136            pub fn tolerance(mut self, tolerance: F) -> Self {
137                self.0.0.tolerance = tolerance;
138                self
139            }
140
141            /// Set whether to scale the dataset
142            pub fn scale(mut self, scale: bool) -> Self {
143                self.0.0.scale = scale;
144                self
145            }
146
147            /// Set the algorithm used to estimate the first singular vectors of the cross-covariance matrix.
148            /// `Nipals` uses the power method while `Svd` will compute the whole SVD.
149            pub fn algorithm(mut self, algorithm: Algorithm) -> Self {
150                self.0.0.algorithm = algorithm;
151                self
152            }
153        }
154
155        impl<F: Float> ParamGuard for [<Pls $name Params>]<F> {
156            type Checked = [<Pls $name ValidParams>]<F>;
157            type Error = PlsError;
158
159            fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
160                if self.0.0.tolerance.is_negative() || self.0.0.tolerance.is_nan() || self.0.0.tolerance.is_infinite() {
161                    Err(PlsError::InvalidTolerance(self.0.0.tolerance.to_f32().unwrap()))
162                } else if self.0.0.max_iter == 0 {
163                    Err(PlsError::ZeroMaxIter)
164                } else {
165                    Ok(&self.0)
166                }
167            }
168
169            fn check(self) -> Result<Self::Checked, Self::Error> {
170                self.check_ref()?;
171                Ok(self.0)
172            }
173        }
174    }
175}}
176
177pls_algo!(Regression);
178pls_algo!(Canonical);
179pls_algo!(Cca);