linfa_linear/glm/
hyperparams.rs1use crate::{glm::link::Link, LinearError, TweedieRegressor};
2use linfa::{Float, ParamGuard};
3#[cfg(feature = "serde")]
4use serde_crate::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, PartialEq)]
8#[cfg_attr(
9 feature = "serde",
10 derive(Serialize, Deserialize),
11 serde(crate = "serde_crate")
12)]
13pub struct TweedieRegressorValidParams<F> {
14 alpha: F,
15 fit_intercept: bool,
16 power: F,
17 link: Option<Link>,
18 max_iter: usize,
19 tol: F,
20}
21
22impl<F: Float> TweedieRegressorValidParams<F> {
23 pub fn alpha(&self) -> F {
24 self.alpha
25 }
26
27 pub fn fit_intercept(&self) -> bool {
28 self.fit_intercept
29 }
30
31 pub fn power(&self) -> F {
32 self.power
33 }
34
35 pub fn link(&self) -> Link {
36 match self.link {
37 Some(x) => x,
38 None if self.power <= F::zero() => Link::Identity,
39 None => Link::Log,
40 }
41 }
42
43 pub fn max_iter(&self) -> usize {
44 self.max_iter
45 }
46
47 pub fn tol(&self) -> F {
48 self.tol
49 }
50}
51
52#[derive(Debug, Clone, PartialEq)]
54pub struct TweedieRegressorParams<F>(TweedieRegressorValidParams<F>);
55
56impl<F: Float> Default for TweedieRegressorParams<F> {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl<F: Float> TweedieRegressor<F> {
63 pub fn params() -> TweedieRegressorParams<F> {
64 TweedieRegressorParams::new()
65 }
66}
67
68impl<F: Float> TweedieRegressorParams<F> {
69 pub fn new() -> Self {
70 Self(TweedieRegressorValidParams {
71 alpha: F::one(),
72 fit_intercept: true,
73 power: F::one(),
74 link: None,
75 max_iter: 100,
76 tol: F::cast(1e-4),
77 })
78 }
79
80 pub fn alpha(mut self, alpha: F) -> Self {
83 self.0.alpha = alpha;
84 self
85 }
86
87 pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
89 self.0.fit_intercept = fit_intercept;
90 self
91 }
92
93 pub fn power(mut self, power: F) -> Self {
95 self.0.power = power;
96 self
97 }
98
99 pub fn link(mut self, link: Link) -> Self {
104 self.0.link = Some(link);
105 self
106 }
107
108 pub fn max_iter(mut self, max_iter: usize) -> Self {
110 self.0.max_iter = max_iter;
111 self
112 }
113
114 pub fn tol(mut self, tol: F) -> Self {
116 self.0.tol = tol;
117 self
118 }
119}
120
121impl<F: Float> ParamGuard for TweedieRegressorParams<F> {
122 type Checked = TweedieRegressorValidParams<F>;
123 type Error = LinearError<F>;
124
125 fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
126 if self.0.alpha.is_sign_negative() {
127 Err(LinearError::InvalidPenalty(self.0.alpha))
128 } else if self.0.power > F::zero() && self.0.power < F::one() {
129 Err(LinearError::InvalidTweediePower(self.0.power))
130 } else {
131 Ok(&self.0)
132 }
133 }
134
135 fn check(self) -> Result<Self::Checked, Self::Error> {
136 self.check_ref()?;
137 Ok(self.0)
138 }
139}