linfa_elasticnet/hyperparams.rs
1#[cfg(feature = "serde")]
2use serde_crate::{Deserialize, Serialize};
3
4use linfa::{Float, ParamGuard};
5
6use crate::error::ElasticNetError;
7
8use super::Result;
9
/// Validated hyper-parameters shared by the single-task and multi-task Elastic-Net models.
///
/// The field set is identical for both variants; the `MULTI_TASK` const parameter only
/// distinguishes the single-task (`false`) and multi-task (`true`) flavours at the type level.
/// Instances are produced by validating an [`ElasticNetParamsBase`] through
/// [`ParamGuard::check`].
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ElasticNetValidParamsBase<F, const MULTI_TASK: bool> {
    /// Overall regularization strength, also known as `alpha`.
    penalty: F,
    /// Share of `penalty` applied to the L1 term; `1 - l1_ratio` goes to the L2 term.
    l1_ratio: F,
    /// Whether an intercept is fitted.
    with_intercept: bool,
    /// Maximum number of iterations of the optimization routine.
    max_iterations: u32,
    /// Minimum absolute change in any model parameter needed to keep iterating.
    tolerance: F,
}
23
24/// A verified hyper-parameter set ready for the estimation of a ElasticNet regression model
25///
26/// See [`ElasticNetParams`](crate::ElasticNetParams) for more information.
27pub type ElasticNetValidParams<F> = ElasticNetValidParamsBase<F, false>;
28
29/// A verified hyper-parameter set ready for the estimation of a multi-task ElasticNet regression model
30///
31/// See [`MultiTaskElasticNetParams`](crate::MultiTaskElasticNetParams) for more information.
32pub type MultiTaskElasticNetValidParams<F> = ElasticNetValidParamsBase<F, true>;
33
34impl<F: Float, const MULTI_TASK: bool> ElasticNetValidParamsBase<F, MULTI_TASK> {
35 pub fn penalty(&self) -> F {
36 self.penalty
37 }
38
39 pub fn l1_ratio(&self) -> F {
40 self.l1_ratio
41 }
42
43 pub fn with_intercept(&self) -> bool {
44 self.with_intercept
45 }
46
47 pub fn max_iterations(&self) -> u32 {
48 self.max_iterations
49 }
50
51 pub fn tolerance(&self) -> F {
52 self.tolerance
53 }
54}
55
56#[derive(Clone, Debug, PartialEq, Eq)]
57pub struct ElasticNetParamsBase<F, const MULTI_TASK: bool>(
58 ElasticNetValidParamsBase<F, MULTI_TASK>,
59);
60
61/// A hyper-parameter set for Elastic-Net
62///
63/// Configures and minimizes the following objective function:
64/// ```ignore
65/// 1 / (2 * n_samples) * ||y - Xw||^2_2
66/// + penalty * l1_ratio * ||w||_1
67/// + 0.5 * penalty * (1 - l1_ratio) * ||w||^2_2
68/// ```
69///
70/// The parameter set can be verified into a
71/// [`ElasticNetValidParams`](crate::hyperparams::ElasticNetValidParams) by calling
72/// [ParamGuard::check](Self::check()). It is also possible to directly fit a model with
73/// [Fit::fit](linfa::traits::Fit::fit) which implicitely verifies the parameter set prior to the
74/// model estimation and forwards any error.
75///
76/// # Parameters
77/// | Name | Default | Purpose | Range |
78/// | :--- | :--- | :---| :--- |
79/// | [penalty](Self::penalty()) | `1.0` | Overall parameter penalty | `[0, inf)` |
80/// | [l1_ratio](Self::l1_ratio()) | `0.5` | Distribution of penalty to L1 and L2 regularizations | `[0.0, 1.0]` |
81/// | [with_intercept](Self::with_intercept()) | `true` | Enable intercept | `false`, `true` |
82/// | [tolerance](Self::tolerance()) | `1e-4` | Absolute change of any of the parameters | `(0, inf)` |
83/// | [max_iterations](Self::max_iterations()) | `1000` | Maximum number of iterations | `[1, inf)` |
84///
85/// # Errors
86///
87/// The following errors can come from invalid hyper-parameters:
88///
89/// Returns [`InvalidPenalty`](ElasticNetError::InvalidPenalty) if the penalty is negative.
90///
91/// Returns [`InvalidL1Ratio`](ElasticNetError::InvalidL1Ratio) if the L1 ratio is not in unit.
92/// range
93///
94/// Returns [`InvalidTolerance`](ElasticNetError::InvalidTolerance) if the tolerance is negative.
95///
96/// # Example
97///
98/// ```rust
99/// use linfa_elasticnet::{ElasticNetParams, ElasticNetError};
100/// use linfa::prelude::*;
101/// use ndarray::array;
102///
103/// let ds = Dataset::new(array![[1.0, 0.0], [0.0, 1.0]], array![3.0, 2.0]);
104///
105/// // create a new parameter set with penalty equals `1e-5`
106/// let unchecked_params = ElasticNetParams::new()
107/// .penalty(1e-5);
108///
109/// // fit model with unchecked parameter set
110/// let model = unchecked_params.fit(&ds)?;
111///
112/// // transform into a verified parameter set
113/// let checked_params = unchecked_params.check()?;
114///
115/// // Regenerate model with the verified parameters, this only returns
116/// // errors originating from the fitting process
117/// let model = checked_params.fit(&ds)?;
118/// # Ok::<(), ElasticNetError>(())
119/// ```
120pub type ElasticNetParams<F> = ElasticNetParamsBase<F, false>;
121
122/// A hyper-parameter set for multi-task Elastic-Net
123///
124/// The multi-task version (Y becomes a measurement matrix) is also supported and
125/// solves the following objective function:
126/// ```ignore
127/// 1 / (2 * n_samples) * || Y - XW ||^2_F
128/// + penalty * l1_ratio * ||W||_2,1
129/// + 0.5 * penalty * (1 - l1_ratio) * ||W||^2_F
130/// ```
131///
132/// See [`ElasticNetParams`](crate::ElasticNetParams) for information on parameters and return
133/// values.
134///
135/// # Example
136///
137/// ```rust
138/// use linfa_elasticnet::{MultiTaskElasticNetParams, ElasticNetError};
139/// use linfa::prelude::*;
140/// use ndarray::array;
141///
142/// let ds = Dataset::new(array![[1.0, 0.0], [0.0, 1.0]], array![[3.0, 1.1], [2.0, 2.2]]);
143///
144/// // create a new parameter set with penalty equals `1e-5`
145/// let unchecked_params = MultiTaskElasticNetParams::new()
146/// .penalty(1e-5);
147///
148/// // fit model with unchecked parameter set
149/// let model = unchecked_params.fit(&ds)?;
150///
151/// // transform into a verified parameter set
152/// let checked_params = unchecked_params.check()?;
153///
154/// // Regenerate model with the verified parameters, this only returns
155/// // errors originating from the fitting process
156/// let model = checked_params.fit(&ds)?;
157/// # Ok::<(), ElasticNetError>(())
158/// ```
159pub type MultiTaskElasticNetParams<F> = ElasticNetParamsBase<F, true>;
160
161impl<F: Float, const MULTI_TASK: bool> Default for ElasticNetParamsBase<F, MULTI_TASK> {
162 fn default() -> Self {
163 Self::new()
164 }
165}
166
167/// Configure and fit a Elastic Net model
168impl<F: Float, const MULTI_TASK: bool> ElasticNetParamsBase<F, MULTI_TASK> {
169 /// Create default elastic net hyper parameters
170 ///
171 /// By default, an intercept will be fitted. To disable fitting an
172 /// intercept, call `.with_intercept(false)` before calling `.fit()`.
173 ///
174 /// To additionally normalize the feature matrix before fitting, call
175 /// `fit_intercept_and_normalize()` before calling `fit()`. The feature
176 /// matrix will not be normalized by default.
177 pub fn new() -> ElasticNetParamsBase<F, MULTI_TASK> {
178 Self(ElasticNetValidParamsBase {
179 penalty: F::one(),
180 l1_ratio: F::cast(0.5),
181 with_intercept: true,
182 max_iterations: 1000,
183 tolerance: F::cast(1e-4),
184 })
185 }
186
187 /// Set the overall parameter penalty parameter of the elastic net, otherwise known as `alpha`.
188 /// Use `l1_ratio` to configure how the penalty distributed to L1 and L2
189 /// regularization.
190 pub fn penalty(mut self, penalty: F) -> Self {
191 self.0.penalty = penalty;
192 self
193 }
194
195 /// Set l1_ratio parameter of the elastic net. Controls how the parameter
196 /// penalty is distributed to L1 and L2 regularization.
197 /// Setting `l1_ratio` to 1.0 is equivalent to a "Lasso" penalization,
198 /// setting it to 0.0 is equivalent to "Ridge" penalization.
199 ///
200 /// Defaults to `0.5` if not set
201 ///
202 /// `l1_ratio` must be between `0.0` and `1.0`.
203 pub fn l1_ratio(mut self, l1_ratio: F) -> Self {
204 self.0.l1_ratio = l1_ratio;
205 self
206 }
207
208 /// Configure the elastic net model to fit an intercept.
209 /// Defaults to `true` if not set.
210 pub fn with_intercept(mut self, with_intercept: bool) -> Self {
211 self.0.with_intercept = with_intercept;
212 self
213 }
214
215 /// Set the tolerance which is the minimum absolute change in any of the
216 /// model parameters needed for the parameter optimization to continue.
217 ///
218 /// Defaults to `1e-4` if not set
219 pub fn tolerance(mut self, tolerance: F) -> Self {
220 self.0.tolerance = tolerance;
221 self
222 }
223
224 /// Set the maximum number of iterations for the optimization routine.
225 ///
226 /// Defaults to `1000` if not set
227 pub fn max_iterations(mut self, max_iterations: u32) -> Self {
228 self.0.max_iterations = max_iterations;
229 self
230 }
231}
232
233impl<F: Float, const MULTI_TASK: bool> ParamGuard for ElasticNetParamsBase<F, MULTI_TASK> {
234 type Checked = ElasticNetValidParamsBase<F, MULTI_TASK>;
235 type Error = ElasticNetError;
236
237 /// Validate the hyper parameters
238 fn check_ref(&self) -> Result<&Self::Checked> {
239 if self.0.penalty.is_negative() {
240 Err(ElasticNetError::InvalidPenalty(
241 self.0.penalty.to_f32().unwrap(),
242 ))
243 } else if !(F::zero()..=F::one()).contains(&self.0.l1_ratio) {
244 Err(ElasticNetError::InvalidL1Ratio(
245 self.0.l1_ratio.to_f32().unwrap(),
246 ))
247 } else if self.0.tolerance.is_negative() {
248 Err(ElasticNetError::InvalidTolerance(
249 self.0.tolerance.to_f32().unwrap(),
250 ))
251 } else {
252 Ok(&self.0)
253 }
254 }
255
256 fn check(self) -> Result<Self::Checked> {
257 self.check_ref()?;
258 Ok(self.0)
259 }
260}