1use crate::{Algorithm, DeflationMode, Mode, PlsError};
2use linfa::{Float, ParamGuard};
3
/// Hyperparameters shared by every PLS variant in this crate.
///
/// This is the `Checked` type returned by the `ParamGuard` impls in this file.
/// Those impls reject a negative, NaN or infinite `tolerance` and a `max_iter`
/// of zero.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct PlsValidParams<F: Float> {
    // Number of components, as passed to `PlsParams::new`.
    n_components: usize,
    // Iteration cap; `PlsParams::new` sets 500 and validation rejects 0.
    max_iter: usize,
    // Tolerance; `PlsParams::new` sets 1e-6 and validation requires it to be
    // finite and non-negative.
    tolerance: F,
    // Scaling flag; `PlsParams::new` sets `true`.
    scale: bool,
    // Algorithm selector; `PlsParams::new` sets `Algorithm::Nipals`.
    algorithm: Algorithm,
    // Deflation mode; `PlsParams::new` sets `DeflationMode::Regression`.
    deflation_mode: DeflationMode,
    // PLS mode; `PlsParams::new` sets `Mode::A`.
    mode: Mode,
}
14
15impl<F: Float> PlsValidParams<F> {
16 pub fn n_components(&self) -> usize {
17 self.n_components
18 }
19
20 pub fn max_iter(&self) -> usize {
21 self.max_iter
22 }
23
24 pub fn tolerance(&self) -> F {
25 self.tolerance
26 }
27
28 pub fn scale(&self) -> bool {
29 self.scale
30 }
31
32 pub fn algorithm(&self) -> Algorithm {
33 self.algorithm
34 }
35
36 pub fn deflation_mode(&self) -> DeflationMode {
37 self.deflation_mode
38 }
39
40 pub fn mode(&self) -> Mode {
41 self.mode
42 }
43}
44
/// Unchecked, crate-internal builder around `PlsValidParams`.
///
/// Convert it with `ParamGuard::check` / `check_ref` (implemented below) to
/// obtain the validated parameters.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct PlsParams<F: Float>(pub(crate) PlsValidParams<F>);
47
48impl<F: Float> PlsParams<F> {
49 pub fn new(n_components: usize) -> PlsParams<F> {
50 Self(PlsValidParams {
51 n_components,
52 max_iter: 500,
53 tolerance: F::cast(1e-6),
54 scale: true,
55 algorithm: Algorithm::Nipals,
56 deflation_mode: DeflationMode::Regression,
57 mode: Mode::A,
58 })
59 }
60
61 #[cfg(test)]
62 pub fn max_iterations(mut self, max_iter: usize) -> Self {
63 self.0.max_iter = max_iter;
64 self
65 }
66
67 #[cfg(test)]
68 pub fn tolerance(mut self, tolerance: F) -> Self {
69 self.0.tolerance = tolerance;
70 self
71 }
72
73 #[cfg(test)]
74 pub fn scale(mut self, scale: bool) -> Self {
75 self.0.scale = scale;
76 self
77 }
78
79 #[cfg(test)]
80 pub fn algorithm(mut self, algorithm: Algorithm) -> Self {
81 self.0.algorithm = algorithm;
82 self
83 }
84
85 pub fn deflation_mode(mut self, deflation_mode: DeflationMode) -> Self {
86 self.0.deflation_mode = deflation_mode;
87 self
88 }
89
90 pub fn mode(mut self, mode: Mode) -> Self {
91 self.0.mode = mode;
92 self
93 }
94}
95
96impl<F: Float> ParamGuard for PlsParams<F> {
97 type Checked = PlsValidParams<F>;
98 type Error = PlsError;
99
100 fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
101 if self.0.tolerance.is_negative()
102 || self.0.tolerance.is_nan()
103 || self.0.tolerance.is_infinite()
104 {
105 Err(PlsError::InvalidTolerance(
106 self.0.tolerance().to_f32().unwrap(),
107 ))
108 } else if self.0.max_iter == 0 {
109 Err(PlsError::ZeroMaxIter)
110 } else {
111 Ok(&self.0)
112 }
113 }
114
115 fn check(self) -> Result<Self::Checked, Self::Error> {
116 self.check_ref()?;
117 Ok(self.0)
118 }
119}
120
/// Generates the public builder / validated-parameter pair for one PLS variant.
///
/// For an identifier `X` this expands (via `paste`) to:
/// * `PlsXParams<F>`: a builder exposing setters for the shared options;
/// * `PlsXValidParams<F>`: the checked form, wrapping `PlsValidParams<F>`;
/// * a `ParamGuard` impl linking the two. It applies the same checks as the
///   impl for `PlsParams`: tolerance non-negative and not NaN/infinite, and
///   `max_iter` non-zero.
///
/// Both generated types derive `Debug`, `Clone` and `PartialEq`, matching
/// `PlsParams` and `PlsValidParams`.
macro_rules! pls_algo { ($name:ident) => {
    paste::item! {
        /// Unchecked hyperparameters; validate with `ParamGuard::check` or
        /// `ParamGuard::check_ref`.
        #[derive(Debug, Clone, PartialEq)]
        pub struct [<Pls $name Params>]<F: Float>(pub(crate) [<Pls $name ValidParams>]<F>);

        /// Hyperparameters returned by the `ParamGuard` impl of the matching
        /// builder type.
        #[derive(Debug, Clone, PartialEq)]
        pub struct [<Pls $name ValidParams>]<F: Float>(pub(crate) PlsValidParams<F>);

        impl<F: Float> [<Pls $name Params>]<F> {
            /// Sets the maximum number of iterations.
            ///
            /// A value of `0` is rejected at validation time with
            /// `PlsError::ZeroMaxIter`.
            pub fn max_iterations(mut self, max_iter: usize) -> Self {
                self.0.0.max_iter = max_iter;
                self
            }

            /// Sets the tolerance.
            ///
            /// Negative, NaN or infinite values are rejected at validation time
            /// with `PlsError::InvalidTolerance`.
            pub fn tolerance(mut self, tolerance: F) -> Self {
                self.0.0.tolerance = tolerance;
                self
            }

            /// Sets the `scale` flag.
            pub fn scale(mut self, scale: bool) -> Self {
                self.0.0.scale = scale;
                self
            }

            /// Selects the `Algorithm` to use.
            pub fn algorithm(mut self, algorithm: Algorithm) -> Self {
                self.0.0.algorithm = algorithm;
                self
            }
        }

        impl<F: Float> ParamGuard for [<Pls $name Params>]<F> {
            type Checked = [<Pls $name ValidParams>]<F>;
            type Error = PlsError;

            fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
                let tolerance = self.0.0.tolerance;
                if tolerance.is_negative() || tolerance.is_nan() || tolerance.is_infinite() {
                    Err(PlsError::InvalidTolerance(tolerance.to_f32().unwrap()))
                } else if self.0.0.max_iter == 0 {
                    Err(PlsError::ZeroMaxIter)
                } else {
                    Ok(&self.0)
                }
            }

            fn check(self) -> Result<Self::Checked, Self::Error> {
                self.check_ref()?;
                Ok(self.0)
            }
        }
    }
}}
176
// Instantiate one builder / validated-params pair per variant name:
// `PlsRegressionParams`, `PlsCanonicalParams` and `PlsCcaParams`, each with
// its matching `*ValidParams` type and `ParamGuard` impl.
pls_algo!(Regression);
pls_algo!(Canonical);
pls_algo!(Cca);