linfa_ftrl/
lib.rs

1#![doc = include_str!("../README.md")]
2
3mod algorithm;
4mod error;
5mod hyperparams;
6
7use crate::hyperparams::FtrlValidParams;
8pub use algorithm::Result;
9pub use error::FtrlError;
10pub use hyperparams::FtrlParams;
11use linfa::Float;
12use ndarray::Array1;
13use ndarray_rand::RandomExt;
14use rand::{distributions::Uniform, Rng};
15use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256Plus};
16#[cfg(feature = "serde")]
17use serde_crate::{Deserialize, Serialize};
18
19#[derive(Debug, Clone)]
20#[cfg_attr(
21    feature = "serde",
22    derive(Serialize, Deserialize),
23    serde(crate = "serde_crate")
24)]
25pub struct Ftrl<F: Float> {
26    /// FTRL (Follow The Regularized Leader - proximal) is a linear model for CTR prediction in online learning settings.
27    /// It stores z and n values, which are later used to calculate weights at incremental model fit and during prediction.
28    /// It is a special type of linear model with sigmoid function which uses L1 and L2 regularization.
29    /// ```rust
30    /// use linfa::Dataset;
31    /// use ndarray::array;
32    /// use linfa_ftrl::Ftrl;
33    /// use linfa::prelude::*;
34    /// let dataset = Dataset::new(array![[0.], [1.]], array![true, false]);
35    /// let params = Ftrl::params();
36    /// let model = params.fit_with(None, &dataset).unwrap();
37    /// let predictions = model.predict(&dataset);
38    /// ```
39    alpha: F,
40    beta: F,
41    l1_ratio: F,
42    l2_ratio: F,
43    z: Array1<F>,
44    n: Array1<F>,
45}
46
47impl<F: Float> Ftrl<F> {
48    /// Create a default parameter set for construction of Follow The Regularized Leader - proximal model
49    /// The description can be found [here](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf)
50    ///
51    /// It requires data preprocessing done in the separate step.
52    ///
53    /// Create default hyperparameters. Random number generator will default to rand_xoshiro::Xoshiro256Plus
54    pub fn params() -> FtrlParams<F, Xoshiro256Plus> {
55        FtrlParams::default_with_rng(Xoshiro256Plus::seed_from_u64(42))
56    }
57
58    /// Create default hyperparameters with custom random number generator
59    pub fn params_with_rng<R: Rng>(rng: R) -> FtrlParams<F, R> {
60        FtrlParams::default_with_rng(rng)
61    }
62
63    /// Create a new model with given parameters, number of features and custom random number generator
64    pub fn new<R: Rng + Clone>(params: FtrlValidParams<F, R>, nfeatures: usize) -> Ftrl<F> {
65        let mut rng = params.rng.clone();
66        Self {
67            alpha: params.alpha,
68            beta: params.beta,
69            l1_ratio: params.l1_ratio,
70            l2_ratio: params.l2_ratio,
71            n: Array1::zeros(nfeatures),
72            z: Array1::random_using(nfeatures, Uniform::new(F::zero(), F::one()), &mut rng),
73        }
74    }
75}