linfa_logistic/
hyperparams.rs1use linfa::ParamGuard;
2use ndarray::{Array, Dimension};
3
4use crate::error::Error;
5use crate::float::Float;
6
7#[cfg(feature = "serde")]
8use serde_crate::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, PartialEq)]
13#[cfg_attr(
14 feature = "serde",
15 derive(Serialize, Deserialize),
16 serde(crate = "serde_crate")
17)]
18pub struct LogisticRegressionParams<F: Float, D: Dimension>(LogisticRegressionValidParams<F, D>);
19
20#[derive(Debug, Clone, PartialEq)]
21#[cfg_attr(
22 feature = "serde",
23 derive(Serialize, Deserialize),
24 serde(crate = "serde_crate")
25)]
26pub struct LogisticRegressionValidParams<F: Float, D: Dimension> {
27 pub(crate) alpha: F,
28 pub(crate) fit_intercept: bool,
29 pub(crate) max_iterations: u64,
30 pub(crate) gradient_tolerance: F,
31 pub(crate) initial_params: Option<Array<F, D>>,
32}
33
34impl<F: Float, D: Dimension> ParamGuard for LogisticRegressionParams<F, D> {
35 type Checked = LogisticRegressionValidParams<F, D>;
36 type Error = Error;
37
38 fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
39 if !self.0.alpha.is_finite() || self.0.alpha < F::zero() {
40 return Err(Error::InvalidAlpha);
41 }
42 if !self.0.gradient_tolerance.is_finite() || self.0.gradient_tolerance <= F::zero() {
43 return Err(Error::InvalidGradientTolerance);
44 }
45 if let Some(params) = self.0.initial_params.as_ref() {
46 if params.iter().any(|p| !p.is_finite()) {
47 return Err(Error::InvalidInitialParameters);
48 }
49 }
50 Ok(&self.0)
51 }
52
53 fn check(self) -> Result<Self::Checked, Self::Error> {
54 self.check_ref()?;
55 Ok(self.0)
56 }
57}
58
59impl<F: Float, D: Dimension> LogisticRegressionParams<F, D> {
60 pub fn new() -> Self {
62 Self(LogisticRegressionValidParams {
63 alpha: F::cast(1.0),
64 fit_intercept: true,
65 max_iterations: 100,
66 gradient_tolerance: F::cast(1e-4),
67 initial_params: None,
68 })
69 }
70
71 pub fn alpha(mut self, alpha: F) -> Self {
74 self.0.alpha = alpha;
75 self
76 }
77
78 pub fn with_intercept(mut self, fit_intercept: bool) -> Self {
80 self.0.fit_intercept = fit_intercept;
81 self
82 }
83
84 pub fn max_iterations(mut self, max_iterations: u64) -> Self {
87 self.0.max_iterations = max_iterations;
88 self
89 }
90
91 pub fn gradient_tolerance(mut self, gradient_tolerance: F) -> Self {
94 self.0.gradient_tolerance = gradient_tolerance;
95 self
96 }
97
98 pub fn initial_params(mut self, params: Array<F, D>) -> Self {
104 self.0.initial_params = Some(params);
105 self
106 }
107}