linfa_elasticnet/
hyperparams.rs

1#[cfg(feature = "serde")]
2use serde_crate::{Deserialize, Serialize};
3
4use linfa::{Float, ParamGuard};
5
6use crate::error::ElasticNetError;
7
8use super::Result;
9
/// Hyper-parameter storage shared by the single-task and multi-task elastic net variants.
///
/// This is the `Checked` type handed out by the [`ParamGuard`] implementation of
/// `ElasticNetParamsBase`. The `MULTI_TASK` const parameter only distinguishes the two
/// variants at the type level (see the `ElasticNetValidParams` and
/// `MultiTaskElasticNetValidParams` aliases); it does not add any field.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ElasticNetValidParamsBase<F, const MULTI_TASK: bool> {
    /// Overall parameter penalty; `ElasticNetParamsBase::new` initializes it to `1.0`.
    penalty: F,
    /// Share of the penalty assigned to the L1 term; initialized to `0.5`.
    l1_ratio: F,
    /// Whether an intercept is fitted; initialized to `true`.
    with_intercept: bool,
    /// Maximum number of iterations of the optimization routine; initialized to `1000`.
    max_iterations: u32,
    /// Convergence tolerance of the optimization routine; initialized to `1e-4`.
    tolerance: F,
}
23
/// A verified hyper-parameter set ready for the estimation of an ElasticNet regression model
///
/// This is [`ElasticNetValidParamsBase`] with `MULTI_TASK` set to `false`.
///
/// See [`ElasticNetParams`](crate::ElasticNetParams) for more information.
pub type ElasticNetValidParams<F> = ElasticNetValidParamsBase<F, false>;
28
/// A verified hyper-parameter set ready for the estimation of a multi-task ElasticNet regression model
///
/// This is [`ElasticNetValidParamsBase`] with `MULTI_TASK` set to `true`.
///
/// See [`MultiTaskElasticNetParams`](crate::MultiTaskElasticNetParams) for more information.
pub type MultiTaskElasticNetValidParams<F> = ElasticNetValidParamsBase<F, true>;
33
34impl<F: Float, const MULTI_TASK: bool> ElasticNetValidParamsBase<F, MULTI_TASK> {
35    pub fn penalty(&self) -> F {
36        self.penalty
37    }
38
39    pub fn l1_ratio(&self) -> F {
40        self.l1_ratio
41    }
42
43    pub fn with_intercept(&self) -> bool {
44        self.with_intercept
45    }
46
47    pub fn max_iterations(&self) -> u32 {
48        self.max_iterations
49    }
50
51    pub fn tolerance(&self) -> F {
52        self.tolerance
53    }
54}
55
/// Unchecked hyper-parameter set wrapping an [`ElasticNetValidParamsBase`].
///
/// The wrapped values are set through the builder methods of this type and verified by its
/// [`ParamGuard`] implementation, which hands out the inner value once the checks pass.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ElasticNetParamsBase<F, const MULTI_TASK: bool>(
    ElasticNetValidParamsBase<F, MULTI_TASK>,
);
60
/// A hyper-parameter set for Elastic-Net
///
/// Configures and minimizes the following objective function:
/// ```ignore
/// 1 / (2 * n_samples) * ||y - Xw||^2_2
///     + penalty * l1_ratio * ||w||_1
///     + 0.5 * penalty * (1 - l1_ratio) * ||w||^2_2
/// ```
///
/// The parameter set can be verified into an
/// [`ElasticNetValidParams`](crate::hyperparams::ElasticNetValidParams) by calling
/// [ParamGuard::check](Self::check()). It is also possible to directly fit a model with
/// [Fit::fit](linfa::traits::Fit::fit) which implicitly verifies the parameter set prior to the
/// model estimation and forwards any error.
///
/// # Parameters
/// | Name | Default | Purpose | Range |
/// | :--- | :--- | :---| :--- |
/// | [penalty](Self::penalty()) | `1.0` | Overall parameter penalty | `[0, inf)` |
/// | [l1_ratio](Self::l1_ratio()) | `0.5` | Distribution of penalty to L1 and L2 regularizations | `[0.0, 1.0]` |
/// | [with_intercept](Self::with_intercept()) | `true` | Enable intercept | `false`, `true` |
/// | [tolerance](Self::tolerance()) | `1e-4` | Absolute change of any of the parameters | `(0, inf)` |
/// | [max_iterations](Self::max_iterations()) | `1000` | Maximum number of iterations | `[1, inf)` |
///
/// # Errors
///
/// The following errors can come from invalid hyper-parameters:
///
/// Returns [`InvalidPenalty`](ElasticNetError::InvalidPenalty) if the penalty is negative.
///
/// Returns [`InvalidL1Ratio`](ElasticNetError::InvalidL1Ratio) if the L1 ratio is not in the
/// unit range.
///
/// Returns [`InvalidTolerance`](ElasticNetError::InvalidTolerance) if the tolerance is negative.
///
/// # Example
///
/// ```rust
/// use linfa_elasticnet::{ElasticNetParams, ElasticNetError};
/// use linfa::prelude::*;
/// use ndarray::array;
///
/// let ds = Dataset::new(array![[1.0, 0.0], [0.0, 1.0]], array![3.0, 2.0]);
///
/// // create a new parameter set with a penalty of `1e-5`
/// let unchecked_params = ElasticNetParams::new()
///     .penalty(1e-5);
///
/// // fit model with unchecked parameter set
/// let model = unchecked_params.fit(&ds)?;
///
/// // transform into a verified parameter set
/// let checked_params = unchecked_params.check()?;
///
/// // Regenerate model with the verified parameters, this only returns
/// // errors originating from the fitting process
/// let model = checked_params.fit(&ds)?;
/// # Ok::<(), ElasticNetError>(())
/// ```
pub type ElasticNetParams<F> = ElasticNetParamsBase<F, false>;
121
/// A hyper-parameter set for multi-task Elastic-Net
///
/// The multi-task version (Y becomes a measurement matrix) is also supported and
/// solves the following objective function:
/// ```ignore
/// 1 / (2 * n_samples) * || Y - XW ||^2_F
///     + penalty * l1_ratio * ||W||_2,1
///     + 0.5 * penalty * (1 - l1_ratio) * ||W||^2_F
/// ```
///
/// This is [`ElasticNetParamsBase`] with `MULTI_TASK` set to `true`.
///
/// See [`ElasticNetParams`](crate::ElasticNetParams) for information on parameters and return
/// values.
///
/// # Example
///
/// ```rust
/// use linfa_elasticnet::{MultiTaskElasticNetParams, ElasticNetError};
/// use linfa::prelude::*;
/// use ndarray::array;
///
/// let ds = Dataset::new(array![[1.0, 0.0], [0.0, 1.0]], array![[3.0, 1.1], [2.0, 2.2]]);
///
/// // create a new parameter set with a penalty of `1e-5`
/// let unchecked_params = MultiTaskElasticNetParams::new()
///     .penalty(1e-5);
///
/// // fit model with unchecked parameter set
/// let model = unchecked_params.fit(&ds)?;
///
/// // transform into a verified parameter set
/// let checked_params = unchecked_params.check()?;
///
/// // Regenerate model with the verified parameters, this only returns
/// // errors originating from the fitting process
/// let model = checked_params.fit(&ds)?;
/// # Ok::<(), ElasticNetError>(())
/// ```
pub type MultiTaskElasticNetParams<F> = ElasticNetParamsBase<F, true>;
160
161impl<F: Float, const MULTI_TASK: bool> Default for ElasticNetParamsBase<F, MULTI_TASK> {
162    fn default() -> Self {
163        Self::new()
164    }
165}
166
167/// Configure and fit a Elastic Net model
168impl<F: Float, const MULTI_TASK: bool> ElasticNetParamsBase<F, MULTI_TASK> {
169    /// Create default elastic net hyper parameters
170    ///
171    /// By default, an intercept will be fitted. To disable fitting an
172    /// intercept, call `.with_intercept(false)` before calling `.fit()`.
173    ///
174    /// To additionally normalize the feature matrix before fitting, call
175    /// `fit_intercept_and_normalize()` before calling `fit()`. The feature
176    /// matrix will not be normalized by default.
177    pub fn new() -> ElasticNetParamsBase<F, MULTI_TASK> {
178        Self(ElasticNetValidParamsBase {
179            penalty: F::one(),
180            l1_ratio: F::cast(0.5),
181            with_intercept: true,
182            max_iterations: 1000,
183            tolerance: F::cast(1e-4),
184        })
185    }
186
187    /// Set the overall parameter penalty parameter of the elastic net, otherwise known as `alpha`.
188    /// Use `l1_ratio` to configure how the penalty distributed to L1 and L2
189    /// regularization.
190    pub fn penalty(mut self, penalty: F) -> Self {
191        self.0.penalty = penalty;
192        self
193    }
194
195    /// Set l1_ratio parameter of the elastic net. Controls how the parameter
196    /// penalty is distributed to L1 and L2 regularization.
197    /// Setting `l1_ratio` to 1.0 is equivalent to a "Lasso" penalization,
198    /// setting it to 0.0 is equivalent to "Ridge" penalization.
199    ///
200    /// Defaults to `0.5` if not set
201    ///
202    /// `l1_ratio` must be between `0.0` and `1.0`.
203    pub fn l1_ratio(mut self, l1_ratio: F) -> Self {
204        self.0.l1_ratio = l1_ratio;
205        self
206    }
207
208    /// Configure the elastic net model to fit an intercept.
209    /// Defaults to `true` if not set.
210    pub fn with_intercept(mut self, with_intercept: bool) -> Self {
211        self.0.with_intercept = with_intercept;
212        self
213    }
214
215    /// Set the tolerance which is the minimum absolute change in any of the
216    /// model parameters needed for the parameter optimization to continue.
217    ///
218    /// Defaults to `1e-4` if not set
219    pub fn tolerance(mut self, tolerance: F) -> Self {
220        self.0.tolerance = tolerance;
221        self
222    }
223
224    /// Set the maximum number of iterations for the optimization routine.
225    ///
226    /// Defaults to `1000` if not set
227    pub fn max_iterations(mut self, max_iterations: u32) -> Self {
228        self.0.max_iterations = max_iterations;
229        self
230    }
231}
232
233impl<F: Float, const MULTI_TASK: bool> ParamGuard for ElasticNetParamsBase<F, MULTI_TASK> {
234    type Checked = ElasticNetValidParamsBase<F, MULTI_TASK>;
235    type Error = ElasticNetError;
236
237    /// Validate the hyper parameters
238    fn check_ref(&self) -> Result<&Self::Checked> {
239        if self.0.penalty.is_negative() {
240            Err(ElasticNetError::InvalidPenalty(
241                self.0.penalty.to_f32().unwrap(),
242            ))
243        } else if !(F::zero()..=F::one()).contains(&self.0.l1_ratio) {
244            Err(ElasticNetError::InvalidL1Ratio(
245                self.0.l1_ratio.to_f32().unwrap(),
246            ))
247        } else if self.0.tolerance.is_negative() {
248            Err(ElasticNetError::InvalidTolerance(
249                self.0.tolerance.to_f32().unwrap(),
250            ))
251        } else {
252            Ok(&self.0)
253        }
254    }
255
256    fn check(self) -> Result<Self::Checked> {
257        self.check_ref()?;
258        Ok(self.0)
259    }
260}