// linfa_ftrl/hyperparams.rs

1use crate::error::FtrlError;
2use linfa::{Float, ParamGuard};
3use rand::Rng;
4#[cfg(feature = "serde")]
5use serde_crate::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, PartialEq)]
8#[cfg_attr(
9    feature = "serde",
10    derive(Serialize, Deserialize),
11    serde(crate = "serde_crate")
12)]
13pub struct FtrlParams<F: Float, R: Rng>(pub(crate) FtrlValidParams<F, R>);
14
/// A verified hyper-parameter set ready for the estimation of a Follow the Regularized
/// Leader - Proximal (FTRL-Proximal) model.
///
/// Values of this type are produced by the `ParamGuard` checks on `FtrlParams`. The
/// constraints documented on each field have therefore been verified.
///
/// See [`FtrlParams`](crate::FtrlParams) for more information.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct FtrlValidParams<F: Float, R: Rng> {
    /// Learning rate; finite and non-negative once checked.
    pub(crate) alpha: F,
    /// Beta parameter; finite and non-negative once checked.
    pub(crate) beta: F,
    /// Share of the penalty distributed to L1 regularization; within `[0, 1]` once checked.
    pub(crate) l1_ratio: F,
    /// Share of the penalty distributed to L2 regularization; within `[0, 1]` once checked.
    pub(crate) l2_ratio: F,
    /// Random number generator used to initialize the Z values.
    pub(crate) rng: R,
}
31
32impl<F: Float, R: Rng> ParamGuard for FtrlParams<F, R> {
33    type Checked = FtrlValidParams<F, R>;
34    type Error = FtrlError;
35
36    /// Validate the hyper parameters
37    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
38        if !(F::zero()..=F::one()).contains(&self.0.l1_ratio) {
39            Err(FtrlError::InvalidL1Ratio(self.0.l1_ratio.to_f32().unwrap()))
40        } else if !(F::zero()..=F::one()).contains(&self.0.l2_ratio) {
41            Err(FtrlError::InvalidL2Ratio(self.0.l2_ratio.to_f32().unwrap()))
42        } else if !&self.0.alpha.is_finite() || self.0.alpha.is_negative() {
43            Err(FtrlError::InvalidAlpha(self.0.alpha.to_f32().unwrap()))
44        } else if !&self.0.beta.is_finite() || self.0.beta.is_negative() {
45            Err(FtrlError::InvalidBeta(self.0.beta.to_f32().unwrap()))
46        } else {
47            Ok(&self.0)
48        }
49    }
50
51    fn check(self) -> Result<Self::Checked, Self::Error> {
52        self.check_ref()?;
53        Ok(self.0)
54    }
55}
56
57impl<F: Float, R: Rng> FtrlValidParams<F, R> {
58    pub fn alpha(&self) -> F {
59        self.alpha
60    }
61
62    pub fn beta(&self) -> F {
63        self.beta
64    }
65
66    pub fn l1_ratio(&self) -> F {
67        self.l1_ratio
68    }
69
70    pub fn l2_ratio(&self) -> F {
71        self.l2_ratio
72    }
73
74    pub fn rng(&self) -> &R {
75        &self.rng
76    }
77}
78
79impl<F: Float, R: Rng> FtrlParams<F, R> {
80    /// Create new hyperparameters with pre-defined values
81    pub fn new(alpha: F, beta: F, l1_ratio: F, l2_ratio: F, rng: R) -> Self {
82        Self(FtrlValidParams {
83            alpha,
84            beta,
85            l1_ratio,
86            l2_ratio,
87            rng,
88        })
89    }
90
91    /// Create new hyperparameters with pre-defined random number generator
92    pub fn default_with_rng(rng: R) -> Self {
93        Self(FtrlValidParams {
94            alpha: F::cast(0.005),
95            beta: F::cast(0.0),
96            l1_ratio: F::cast(0.5),
97            l2_ratio: F::cast(0.5),
98            rng,
99        })
100    }
101
102    /// Set the learning rate.
103    ///
104    /// Defaults to `0.005` if not set
105    ///
106    /// `alpha` must be positive and finite
107    pub fn alpha(mut self, alpha: F) -> Self {
108        self.0.alpha = alpha;
109        self
110    }
111
112    /// Set the beta parameter.
113    ///
114    /// Defaults to `0.0` if not set
115    ///
116    /// `beta` must be positive and finite
117    pub fn beta(mut self, beta: F) -> Self {
118        self.0.beta = beta;
119        self
120    }
121
122    /// Set l1_ratio parameter. Controls how the parameter
123    ///
124    /// Defaults to `0.5` if not set
125    ///
126    /// `l1_ratio` must be between `0.0` and `1.0`.
127    pub fn l1_ratio(mut self, l1_ratio: F) -> Self {
128        self.0.l1_ratio = l1_ratio;
129        self
130    }
131
132    /// Set l2_ratio parameter. Controls how the parameter
133    /// penalty is distributed to L2 regularization.
134    ///
135    /// Defaults to `0.5` if not set
136    ///
137    /// `l2_ratio` must be between `0.0` and `1.0`.
138    pub fn l2_ratio(mut self, l2_ratio: F) -> Self {
139        self.0.l2_ratio = l2_ratio;
140        self
141    }
142
143    /// Set random number generator. Used to initialize Z values
144    ///
145    /// Defaults to Xoshiro256Plus
146    ///
147    /// `rng` must have Clone trait implemented.
148    pub fn rng(mut self, rng: R) -> Self {
149        self.0.rng = rng;
150        self
151    }
152}