linfa_logistic/hyperparams.rs

1use linfa::ParamGuard;
2use ndarray::{Array, Dimension};
3
4use crate::error::Error;
5use crate::float::Float;
6
7#[cfg(feature = "serde")]
8use serde_crate::{Deserialize, Serialize};
9
/// A generalized logistic regression type that specializes as either binomial logistic regression
/// or multinomial logistic regression.
///
/// This is the *unchecked* builder form of the hyperparameters: values are set
/// through the builder methods below and only validated when the [`ParamGuard`]
/// implementation converts it into a [`LogisticRegressionValidParams`] via
/// `check`/`check_ref`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct LogisticRegressionParams<F: Float, D: Dimension>(LogisticRegressionValidParams<F, D>);
19
/// The validated hyperparameter set for logistic regression.
///
/// Obtained from [`LogisticRegressionParams`] through the [`ParamGuard`]
/// `check`/`check_ref` methods; the invariants listed on each field below are
/// enforced there.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct LogisticRegressionValidParams<F: Float, D: Dimension> {
    // L2 regularization strength; validated to be finite and non-negative (default 1.0).
    pub(crate) alpha: F,
    // Whether an intercept term is fitted in addition to the feature weights (default true).
    pub(crate) fit_intercept: bool,
    // Upper bound on solver iterations (default 100).
    pub(crate) max_iterations: u64,
    // Convergence threshold on the gradient; validated to be finite and strictly positive
    // (default 1e-4).
    pub(crate) gradient_tolerance: F,
    // Optional warm-start parameters; when present, every entry is validated to be finite.
    // Dimension `D` distinguishes the binomial (1-D) from the multinomial (2-D) case.
    pub(crate) initial_params: Option<Array<F, D>>,
}
33
34impl<F: Float, D: Dimension> ParamGuard for LogisticRegressionParams<F, D> {
35    type Checked = LogisticRegressionValidParams<F, D>;
36    type Error = Error;
37
38    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
39        if !self.0.alpha.is_finite() || self.0.alpha < F::zero() {
40            return Err(Error::InvalidAlpha);
41        }
42        if !self.0.gradient_tolerance.is_finite() || self.0.gradient_tolerance <= F::zero() {
43            return Err(Error::InvalidGradientTolerance);
44        }
45        if let Some(params) = self.0.initial_params.as_ref() {
46            if params.iter().any(|p| !p.is_finite()) {
47                return Err(Error::InvalidInitialParameters);
48            }
49        }
50        Ok(&self.0)
51    }
52
53    fn check(self) -> Result<Self::Checked, Self::Error> {
54        self.check_ref()?;
55        Ok(self.0)
56    }
57}
58
59impl<F: Float, D: Dimension> LogisticRegressionParams<F, D> {
60    /// Creates a new LogisticRegression with default configuration.
61    pub fn new() -> Self {
62        Self(LogisticRegressionValidParams {
63            alpha: F::cast(1.0),
64            fit_intercept: true,
65            max_iterations: 100,
66            gradient_tolerance: F::cast(1e-4),
67            initial_params: None,
68        })
69    }
70
71    /// Set the regularization parameter `alpha` used for L2 regularization,
72    /// defaults to `1.0`.
73    pub fn alpha(mut self, alpha: F) -> Self {
74        self.0.alpha = alpha;
75        self
76    }
77
78    /// Configure if an intercept should be fitted, defaults to `true`.
79    pub fn with_intercept(mut self, fit_intercept: bool) -> Self {
80        self.0.fit_intercept = fit_intercept;
81        self
82    }
83
84    /// Configure the maximum number of iterations that the solver should perform,
85    /// defaults to `100`.
86    pub fn max_iterations(mut self, max_iterations: u64) -> Self {
87        self.0.max_iterations = max_iterations;
88        self
89    }
90
91    /// Configure the minimum change to the gradient to continue the solver,
92    /// defaults to `1e-4`.
93    pub fn gradient_tolerance(mut self, gradient_tolerance: F) -> Self {
94        self.0.gradient_tolerance = gradient_tolerance;
95        self
96    }
97
98    /// Configure the initial parameters from where the optimization starts.  The `params` array
99    /// must have the same number of rows as there are columns on the feature matrix `x` passed to
100    /// the `fit` method. If `with_intercept` is set, then it needs to have one more row. For
101    /// multinomial regression, `params` also must have the same number of columns as the number of
102    /// distinct classes in `y`.
103    pub fn initial_params(mut self, params: Array<F, D>) -> Self {
104        self.0.initial_params = Some(params);
105        self
106    }
107}