linfa_ftrl/
hyperparams.rs1use crate::error::FtrlError;
2use linfa::{Float, ParamGuard};
3use rand::Rng;
4#[cfg(feature = "serde")]
5use serde_crate::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, PartialEq)]
8#[cfg_attr(
9 feature = "serde",
10 derive(Serialize, Deserialize),
11 serde(crate = "serde_crate")
12)]
13pub struct FtrlParams<F: Float, R: Rng>(pub(crate) FtrlValidParams<F, R>);
14
15#[derive(Debug, Clone, PartialEq)]
19#[cfg_attr(
20 feature = "serde",
21 derive(Serialize, Deserialize),
22 serde(crate = "serde_crate")
23)]
24pub struct FtrlValidParams<F: Float, R: Rng> {
25 pub(crate) alpha: F,
26 pub(crate) beta: F,
27 pub(crate) l1_ratio: F,
28 pub(crate) l2_ratio: F,
29 pub(crate) rng: R,
30}
31
32impl<F: Float, R: Rng> ParamGuard for FtrlParams<F, R> {
33 type Checked = FtrlValidParams<F, R>;
34 type Error = FtrlError;
35
36 fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
38 if !(F::zero()..=F::one()).contains(&self.0.l1_ratio) {
39 Err(FtrlError::InvalidL1Ratio(self.0.l1_ratio.to_f32().unwrap()))
40 } else if !(F::zero()..=F::one()).contains(&self.0.l2_ratio) {
41 Err(FtrlError::InvalidL2Ratio(self.0.l2_ratio.to_f32().unwrap()))
42 } else if !&self.0.alpha.is_finite() || self.0.alpha.is_negative() {
43 Err(FtrlError::InvalidAlpha(self.0.alpha.to_f32().unwrap()))
44 } else if !&self.0.beta.is_finite() || self.0.beta.is_negative() {
45 Err(FtrlError::InvalidBeta(self.0.beta.to_f32().unwrap()))
46 } else {
47 Ok(&self.0)
48 }
49 }
50
51 fn check(self) -> Result<Self::Checked, Self::Error> {
52 self.check_ref()?;
53 Ok(self.0)
54 }
55}
56
57impl<F: Float, R: Rng> FtrlValidParams<F, R> {
58 pub fn alpha(&self) -> F {
59 self.alpha
60 }
61
62 pub fn beta(&self) -> F {
63 self.beta
64 }
65
66 pub fn l1_ratio(&self) -> F {
67 self.l1_ratio
68 }
69
70 pub fn l2_ratio(&self) -> F {
71 self.l2_ratio
72 }
73
74 pub fn rng(&self) -> &R {
75 &self.rng
76 }
77}
78
79impl<F: Float, R: Rng> FtrlParams<F, R> {
80 pub fn new(alpha: F, beta: F, l1_ratio: F, l2_ratio: F, rng: R) -> Self {
82 Self(FtrlValidParams {
83 alpha,
84 beta,
85 l1_ratio,
86 l2_ratio,
87 rng,
88 })
89 }
90
91 pub fn default_with_rng(rng: R) -> Self {
93 Self(FtrlValidParams {
94 alpha: F::cast(0.005),
95 beta: F::cast(0.0),
96 l1_ratio: F::cast(0.5),
97 l2_ratio: F::cast(0.5),
98 rng,
99 })
100 }
101
102 pub fn alpha(mut self, alpha: F) -> Self {
108 self.0.alpha = alpha;
109 self
110 }
111
112 pub fn beta(mut self, beta: F) -> Self {
118 self.0.beta = beta;
119 self
120 }
121
122 pub fn l1_ratio(mut self, l1_ratio: F) -> Self {
128 self.0.l1_ratio = l1_ratio;
129 self
130 }
131
132 pub fn l2_ratio(mut self, l2_ratio: F) -> Self {
139 self.0.l2_ratio = l2_ratio;
140 self
141 }
142
143 pub fn rng(mut self, rng: R) -> Self {
149 self.0.rng = rng;
150 self
151 }
152}