linfa_logistic/
hyperparams.rs1use linfa::ParamGuard;
2use ndarray::{Array, Dimension};
3
4use crate::error::Error;
5use crate::float::Float;
6
7#[cfg(feature = "serde")]
8use serde_crate::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, PartialEq)]
13#[cfg_attr(
14 feature = "serde",
15 derive(Serialize, Deserialize),
16 serde(crate = "serde_crate"),
17 serde(bound(deserialize = "D: Deserialize<'de>"))
18)]
19pub struct LogisticRegressionParams<F: Float, D: Dimension>(LogisticRegressionValidParams<F, D>);
20
21#[derive(Debug, Clone, PartialEq)]
22#[cfg_attr(
23 feature = "serde",
24 derive(Serialize, Deserialize),
25 serde(crate = "serde_crate"),
26 serde(bound(deserialize = "D: Deserialize<'de>"))
27)]
28pub struct LogisticRegressionValidParams<F: Float, D: Dimension> {
29 pub(crate) alpha: F,
30 pub(crate) fit_intercept: bool,
31 pub(crate) max_iterations: u64,
32 pub(crate) gradient_tolerance: F,
33 pub(crate) initial_params: Option<Array<F, D>>,
34}
35
36impl<F: Float, D: Dimension> ParamGuard for LogisticRegressionParams<F, D> {
37 type Checked = LogisticRegressionValidParams<F, D>;
38 type Error = Error;
39
40 fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
41 if !self.0.alpha.is_finite() || self.0.alpha < F::zero() {
42 return Err(Error::InvalidAlpha);
43 }
44 if !self.0.gradient_tolerance.is_finite() || self.0.gradient_tolerance <= F::zero() {
45 return Err(Error::InvalidGradientTolerance);
46 }
47 if let Some(params) = self.0.initial_params.as_ref() {
48 if params.iter().any(|p| !p.is_finite()) {
49 return Err(Error::InvalidInitialParameters);
50 }
51 }
52 Ok(&self.0)
53 }
54
55 fn check(self) -> Result<Self::Checked, Self::Error> {
56 self.check_ref()?;
57 Ok(self.0)
58 }
59}
60
61impl<F: Float, D: Dimension> LogisticRegressionParams<F, D> {
62 pub fn new() -> Self {
64 Self(LogisticRegressionValidParams {
65 alpha: F::cast(1.0),
66 fit_intercept: true,
67 max_iterations: 100,
68 gradient_tolerance: F::cast(1e-4),
69 initial_params: None,
70 })
71 }
72
73 pub fn alpha(mut self, alpha: F) -> Self {
76 self.0.alpha = alpha;
77 self
78 }
79
80 pub fn with_intercept(mut self, fit_intercept: bool) -> Self {
82 self.0.fit_intercept = fit_intercept;
83 self
84 }
85
86 pub fn max_iterations(mut self, max_iterations: u64) -> Self {
89 self.0.max_iterations = max_iterations;
90 self
91 }
92
93 pub fn gradient_tolerance(mut self, gradient_tolerance: F) -> Self {
96 self.0.gradient_tolerance = gradient_tolerance;
97 self
98 }
99
100 pub fn initial_params(mut self, params: Array<F, D>) -> Self {
106 self.0.initial_params = Some(params);
107 self
108 }
109}