//! Hyperparameter definitions for `linfa_logistic` (`hyperparams.rs`).

1use linfa::ParamGuard;
2use ndarray::{Array, Dimension};
3
4use crate::error::Error;
5use crate::float::Float;
6
7#[cfg(feature = "serde")]
8use serde_crate::{Deserialize, Serialize};
9
/// A generalized logistic regression type that specializes as either binomial logistic regression
/// or multinomial logistic regression.
///
/// This is the *unchecked* builder form: the setters accept any value, and the
/// contained [`LogisticRegressionValidParams`] is only handed out after the
/// `ParamGuard::check`/`check_ref` validation pass succeeds.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate"),
    serde(bound(deserialize = "D: Deserialize<'de>"))
)]
pub struct LogisticRegressionParams<F: Float, D: Dimension>(LogisticRegressionValidParams<F, D>);
20
/// The validated hyperparameter set of a logistic regression model, obtained by
/// running `ParamGuard::check`/`check_ref` on a [`LogisticRegressionParams`].
///
/// The dimension `D` presumably selects binomial (1-D parameter vector) versus
/// multinomial (2-D parameter matrix) regression — TODO confirm against the
/// fitting code, which is not visible in this file.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate"),
    serde(bound(deserialize = "D: Deserialize<'de>"))
)]
pub struct LogisticRegressionValidParams<F: Float, D: Dimension> {
    // L2 regularization strength; validated to be finite and non-negative.
    pub(crate) alpha: F,
    // Whether an intercept term is fitted in addition to the feature weights.
    pub(crate) fit_intercept: bool,
    // Upper bound on the number of solver iterations.
    pub(crate) max_iterations: u64,
    // Convergence threshold on the gradient; validated to be finite and strictly positive.
    pub(crate) gradient_tolerance: F,
    // Optional warm-start parameters; validated to contain only finite values.
    pub(crate) initial_params: Option<Array<F, D>>,
}
35
36impl<F: Float, D: Dimension> ParamGuard for LogisticRegressionParams<F, D> {
37    type Checked = LogisticRegressionValidParams<F, D>;
38    type Error = Error;
39
40    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
41        if !self.0.alpha.is_finite() || self.0.alpha < F::zero() {
42            return Err(Error::InvalidAlpha);
43        }
44        if !self.0.gradient_tolerance.is_finite() || self.0.gradient_tolerance <= F::zero() {
45            return Err(Error::InvalidGradientTolerance);
46        }
47        if let Some(params) = self.0.initial_params.as_ref() {
48            if params.iter().any(|p| !p.is_finite()) {
49                return Err(Error::InvalidInitialParameters);
50            }
51        }
52        Ok(&self.0)
53    }
54
55    fn check(self) -> Result<Self::Checked, Self::Error> {
56        self.check_ref()?;
57        Ok(self.0)
58    }
59}
60
61impl<F: Float, D: Dimension> LogisticRegressionParams<F, D> {
62    /// Creates a new LogisticRegression with default configuration.
63    pub fn new() -> Self {
64        Self(LogisticRegressionValidParams {
65            alpha: F::cast(1.0),
66            fit_intercept: true,
67            max_iterations: 100,
68            gradient_tolerance: F::cast(1e-4),
69            initial_params: None,
70        })
71    }
72
73    /// Set the regularization parameter `alpha` used for L2 regularization,
74    /// defaults to `1.0`.
75    pub fn alpha(mut self, alpha: F) -> Self {
76        self.0.alpha = alpha;
77        self
78    }
79
80    /// Configure if an intercept should be fitted, defaults to `true`.
81    pub fn with_intercept(mut self, fit_intercept: bool) -> Self {
82        self.0.fit_intercept = fit_intercept;
83        self
84    }
85
86    /// Configure the maximum number of iterations that the solver should perform,
87    /// defaults to `100`.
88    pub fn max_iterations(mut self, max_iterations: u64) -> Self {
89        self.0.max_iterations = max_iterations;
90        self
91    }
92
93    /// Configure the minimum change to the gradient to continue the solver,
94    /// defaults to `1e-4`.
95    pub fn gradient_tolerance(mut self, gradient_tolerance: F) -> Self {
96        self.0.gradient_tolerance = gradient_tolerance;
97        self
98    }
99
100    /// Configure the initial parameters from where the optimization starts.  The `params` array
101    /// must have the same number of rows as there are columns on the feature matrix `x` passed to
102    /// the `fit` method. If `with_intercept` is set, then it needs to have one more row. For
103    /// multinomial regression, `params` also must have the same number of columns as the number of
104    /// distinct classes in `y`.
105    pub fn initial_params(mut self, params: Array<F, D>) -> Self {
106        self.0.initial_params = Some(params);
107        self
108    }
109}