linfa_ftrl/lib.rs
1#![doc = include_str!("../README.md")]
2
3mod algorithm;
4mod error;
5mod hyperparams;
6
7use crate::hyperparams::FtrlValidParams;
8pub use algorithm::Result;
9pub use error::FtrlError;
10pub use hyperparams::FtrlParams;
11use linfa::Float;
12use ndarray::Array1;
13use ndarray_rand::RandomExt;
14use rand::{distributions::Uniform, Rng};
15use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256Plus};
16#[cfg(feature = "serde")]
17use serde_crate::{Deserialize, Serialize};
18
19#[derive(Debug, Clone)]
20#[cfg_attr(
21 feature = "serde",
22 derive(Serialize, Deserialize),
23 serde(crate = "serde_crate")
24)]
25pub struct Ftrl<F: Float> {
26 /// FTRL (Follow The Regularized Leader - proximal) is a linear model for CTR prediction in online learning settings.
27 /// It stores z and n values, which are later used to calculate weights at incremental model fit and during prediction.
28 /// It is a special type of linear model with sigmoid function which uses L1 and L2 regularization.
29 /// ```rust
30 /// use linfa::Dataset;
31 /// use ndarray::array;
32 /// use linfa_ftrl::Ftrl;
33 /// use linfa::prelude::*;
34 /// let dataset = Dataset::new(array![[0.], [1.]], array![true, false]);
35 /// let params = Ftrl::params();
36 /// let model = params.fit_with(None, &dataset).unwrap();
37 /// let predictions = model.predict(&dataset);
38 /// ```
39 alpha: F,
40 beta: F,
41 l1_ratio: F,
42 l2_ratio: F,
43 z: Array1<F>,
44 n: Array1<F>,
45}
46
47impl<F: Float> Ftrl<F> {
48 /// Create a default parameter set for construction of Follow The Regularized Leader - proximal model
49 /// The description can be found [here](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf)
50 ///
51 /// It requires data preprocessing done in the separate step.
52 ///
53 /// Create default hyperparameters. Random number generator will default to rand_xoshiro::Xoshiro256Plus
54 pub fn params() -> FtrlParams<F, Xoshiro256Plus> {
55 FtrlParams::default_with_rng(Xoshiro256Plus::seed_from_u64(42))
56 }
57
58 /// Create default hyperparameters with custom random number generator
59 pub fn params_with_rng<R: Rng>(rng: R) -> FtrlParams<F, R> {
60 FtrlParams::default_with_rng(rng)
61 }
62
63 /// Create a new model with given parameters, number of features and custom random number generator
64 pub fn new<R: Rng + Clone>(params: FtrlValidParams<F, R>, nfeatures: usize) -> Ftrl<F> {
65 let mut rng = params.rng.clone();
66 Self {
67 alpha: params.alpha,
68 beta: params.beta,
69 l1_ratio: params.l1_ratio,
70 l2_ratio: params.l2_ratio,
71 n: Array1::zeros(nfeatures),
72 z: Array1::random_using(nfeatures, Uniform::new(F::zero(), F::one()), &mut rng),
73 }
74 }
75}