linfa_linear/glm/hyperparams.rs

1use crate::{glm::link::Link, LinearError, TweedieRegressor};
2use linfa::{Float, ParamGuard};
3#[cfg(feature = "serde")]
4use serde_crate::{Deserialize, Serialize};
5
6/// The set of hyperparameters that can be specified for the execution of the Tweedie Regressor.
7#[derive(Debug, Clone, PartialEq)]
8#[cfg_attr(
9    feature = "serde",
10    derive(Serialize, Deserialize),
11    serde(crate = "serde_crate")
12)]
13pub struct TweedieRegressorValidParams<F> {
14    alpha: F,
15    fit_intercept: bool,
16    power: F,
17    link: Option<Link>,
18    max_iter: usize,
19    tol: F,
20}
21
22impl<F: Float> TweedieRegressorValidParams<F> {
23    pub fn alpha(&self) -> F {
24        self.alpha
25    }
26
27    pub fn fit_intercept(&self) -> bool {
28        self.fit_intercept
29    }
30
31    pub fn power(&self) -> F {
32        self.power
33    }
34
35    pub fn link(&self) -> Link {
36        match self.link {
37            Some(x) => x,
38            None if self.power <= F::zero() => Link::Identity,
39            None => Link::Log,
40        }
41    }
42
43    pub fn max_iter(&self) -> usize {
44        self.max_iter
45    }
46
47    pub fn tol(&self) -> F {
48        self.tol
49    }
50}
51
52/// The set of hyperparameters that can be specified for the execution of the Tweedie Regressor.
53#[derive(Debug, Clone, PartialEq)]
54pub struct TweedieRegressorParams<F>(TweedieRegressorValidParams<F>);
55
56impl<F: Float> Default for TweedieRegressorParams<F> {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl<F: Float> TweedieRegressor<F> {
63    pub fn params() -> TweedieRegressorParams<F> {
64        TweedieRegressorParams::new()
65    }
66}
67
68impl<F: Float> TweedieRegressorParams<F> {
69    pub fn new() -> Self {
70        Self(TweedieRegressorValidParams {
71            alpha: F::one(),
72            fit_intercept: true,
73            power: F::one(),
74            link: None,
75            max_iter: 100,
76            tol: F::cast(1e-4),
77        })
78    }
79
80    /// Constant that multiplies with the penalty term and thus determines the
81    /// regularization strenght. `alpha` set to 0 is equivalent to unpenalized GLM.
82    pub fn alpha(mut self, alpha: F) -> Self {
83        self.0.alpha = alpha;
84        self
85    }
86
87    /// Specifies whether a bias or intercept should be added to the model
88    pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
89        self.0.fit_intercept = fit_intercept;
90        self
91    }
92
93    /// The power determines the underlying target distribution
94    pub fn power(mut self, power: F) -> Self {
95        self.0.power = power;
96        self
97    }
98
99    /// The link function of the GLM, for mapping from linear predictor `x @ coeff + intercept` to
100    /// the prediction. If no value is set, the link will be selected based on the following,
101    /// - [`identity`](Link::Identity) for Normal distribution (`power` = 0)
102    /// - [`log`](Link::Log) for Poisson, Gamma and Inverse Gaussian distributions (`power` >= 1)
103    pub fn link(mut self, link: Link) -> Self {
104        self.0.link = Some(link);
105        self
106    }
107
108    /// Maximum number of iterations for the LBFGS solver
109    pub fn max_iter(mut self, max_iter: usize) -> Self {
110        self.0.max_iter = max_iter;
111        self
112    }
113
114    /// Stopping criterion for the LBFGS solver
115    pub fn tol(mut self, tol: F) -> Self {
116        self.0.tol = tol;
117        self
118    }
119}
120
121impl<F: Float> ParamGuard for TweedieRegressorParams<F> {
122    type Checked = TweedieRegressorValidParams<F>;
123    type Error = LinearError<F>;
124
125    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
126        if self.0.alpha.is_sign_negative() {
127            Err(LinearError::InvalidPenalty(self.0.alpha))
128        } else if self.0.power > F::zero() && self.0.power < F::one() {
129            Err(LinearError::InvalidTweediePower(self.0.power))
130        } else {
131            Ok(&self.0)
132        }
133    }
134
135    fn check(self) -> Result<Self::Checked, Self::Error> {
136        self.check_ref()?;
137        Ok(self.0)
138    }
139}