// linfa_ftrl/algorithm.rs

use crate::error::FtrlError;
use crate::hyperparams::FtrlValidParams;
use crate::Ftrl;
use linfa::dataset::{AsSingleTargets, Pr, Records};
use linfa::traits::{FitWith, PredictInplace};
use linfa::{DatasetBase, Float};
use ndarray::{Array1, ArrayBase, ArrayView1, Data, Ix2, Zip};
use rand::Rng;

/// Simplified `Result` using [`FtrlError`](crate::FtrlError) as error type
pub type Result<T> = std::result::Result<T, FtrlError>;

impl<F, R, D, T> FitWith<'_, ArrayBase<D, Ix2>, T, FtrlError> for FtrlValidParams<F, R>
where
    F: Float,
    R: Rng + Clone,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = bool>,
{
    type ObjectIn = Option<Ftrl<F>>;
    type ObjectOut = Ftrl<F>;

    /// Fit a Follow The Regularized Leader - Proximal (FTRL-Proximal) model given a feature
    /// matrix `x` and a target variable `y`.
    ///
    /// The feature matrix `x` must have shape `(n_samples, n_features)`.
    ///
    /// The target variable `y` must have shape `(n_samples)`.
    ///
    /// Returns a fitted `Ftrl` object which contains the fitted
    /// parameters and can be used to `predict` values of the target variable
    /// for new feature values.
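    ///
    /// A minimal usage sketch, mirroring the toy examples in the tests below (the hyperparameter
    /// values are illustrative, not recommendations):
    ///
    /// ```ignore
    /// use linfa::prelude::*;
    /// use linfa_ftrl::Ftrl;
    /// use ndarray::array;
    ///
    /// let dataset = Dataset::new(array![[-1.0], [-2.0], [10.0]], array![true, true, false]);
    /// let params = Ftrl::params()
    ///     .alpha(0.1)
    ///     .beta(0.0)
    ///     .l1_ratio(0.5)
    ///     .l2_ratio(0.5);
    /// // `None` starts a fresh model; pass `Some(model)` to keep training an existing one.
    /// let model = params.fit_with(None, &dataset).unwrap();
    /// let predictions = model.predict(&array![[11.0]]);
    /// ```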
    fn fit_with(
        &self,
        model_in: Self::ObjectIn,
        dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
    ) -> Result<Self::ObjectOut> {
        let mut model_out =
            model_in.unwrap_or_else(|| Ftrl::new(self.clone(), dataset.nfeatures()));
        let probabilities = model_out.predict_probabilities(dataset.records());
        let gradient = calculate_gradient(probabilities.view(), dataset);
        let sigma = model_out.calculate_sigma(gradient.view());
        model_out.update_params(gradient, sigma);
        Ok(model_out)
    }
}

impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<Pr>> for Ftrl<F> {
    /// Given an input matrix `X` with shape `(n_samples, n_features)`,
    /// `predict` returns the target variable according to the parameters
    /// learned from the training data distribution.
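    ///
    /// A minimal sketch (assuming a fitted `model: Ftrl<f64>` and features `x: Array2<f64>`):
    ///
    /// ```ignore
    /// let mut out = model.default_target(&x);
    /// model.predict_inplace(&x, &mut out);
    /// ```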
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<Pr>) {
        assert_eq!(
            x.nrows(),
            y.len(),
            "The number of data points must match the number of output targets."
        );

        assert_eq!(
            x.ncols(),
            self.z.len(),
            "Number of data features must match the number of features the model was trained with."
        );

        let probabilities = self.predict_probabilities(x);
        Zip::from(&probabilities).and(y).for_each(|prob, out| {
            *out = *prob;
        });
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<Pr> {
        Array1::zeros(x.nrows()).mapv(Pr::new)
    }
}

/// View the fitted parameters and make predictions with a fitted
/// Follow The Regularized Leader - Proximal (FTRL-Proximal) model
impl<F: Float> Ftrl<F> {
    /// Get Z values
    pub fn z(&self) -> &Array1<F> {
        &self.z
    }

    /// Get N values
    pub fn n(&self) -> &Array1<F> {
        &self.n
    }

    /// Get the alpha parameter
    pub fn alpha(&self) -> F {
        self.alpha
    }

    /// Get the beta parameter
    pub fn beta(&self) -> F {
        self.beta
    }

    /// Get the L1 regularization value
    pub fn l1_ratio(&self) -> F {
        self.l1_ratio
    }

    /// Get the L2 regularization value
    pub fn l2_ratio(&self) -> F {
        self.l2_ratio
    }

    /// Calculate weights for model prediction
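    ///
    /// Each weight is computed per coordinate from the accumulated state (`z`, `n`) via the
    /// FTRL-Proximal closed form in `apply_proximal_to_weights`: zero if `|z_i| <= l1_ratio`,
    /// otherwise `-(z_i - sign(z_i) * l1_ratio) / ((beta + sqrt(n_i)) / alpha + l2_ratio)`.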
    pub fn get_weights(&self) -> Array1<F> {
        Zip::from(self.z.view())
            .and(self.n.view())
            .map_collect(|z, n| {
                apply_proximal_to_weights(
                    *z,
                    *n,
                    self.alpha(),
                    self.beta(),
                    self.l1_ratio(),
                    self.l2_ratio(),
                )
            })
    }

    /// Update the model parameters asynchronously: store the predicted probabilities together
    /// with the corresponding features, then call this method once the boolean outcomes are observed.
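    ///
    /// A minimal sketch of the intended flow (variable names are illustrative):
    ///
    /// ```ignore
    /// // Score a batch now and keep the probabilities around ...
    /// let probabilities = model.predict(dataset.records());
    /// // ... then fold them back into the model once the boolean outcomes are known.
    /// model.update(&dataset, probabilities.view());
    /// ```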
    pub fn update<D: Data<Elem = F>, T: AsSingleTargets<Elem = bool>>(
        &mut self,
        dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
        probabilities: ArrayView1<Pr>,
    ) {
        let gradient = calculate_gradient(probabilities.view(), dataset);
        let sigma = self.calculate_sigma(gradient.view());
        self.update_params(gradient, sigma);
    }

    fn predict_probabilities<D: Data<Elem = F>>(&self, x: &ArrayBase<D, Ix2>) -> Array1<Pr> {
        let weights = self.get_weights();
        let mut probabilities = x.dot(&weights);
        probabilities.mapv_inplace(stable_sigmoid);
        probabilities.mapv(|v| Pr::new(F::to_f32(&v).unwrap_or_default()))
    }

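    /// Per-coordinate learning-rate adjustment `sigma_i = (sqrt(n_i + g_i^2) - sqrt(n_i)) / alpha`
    /// used when folding the current weights back into `z`.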
    fn calculate_sigma(&self, gradients: ArrayView1<F>) -> Array1<F> {
        Zip::from(&self.n)
            .and(gradients)
            .map_collect(|n, grad| calculate_weight_in_average(*n, *grad, self.alpha()))
    }

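    /// Accumulate the batch gradient into `z` (with the proximal correction `sigma * w`)
    /// and the squared gradient into `n`.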
    fn update_params(&mut self, gradient: Array1<F>, sigma: Array1<F>) {
        let weights = self.get_weights();
        self.z += &gradient;
        self.z -= &(sigma * weights);
        self.n += &(&gradient * &gradient);
    }
}

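/// Batch gradient of the logistic loss with respect to the weights: `x^T * (p - y)`,
/// where `p` are the predicted probabilities and `y` the boolean targets cast to 0/1.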
fn calculate_gradient<F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = bool>>(
    probabilities: ArrayView1<Pr>,
    dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
) -> Array1<F> {
    let targets = dataset.as_single_targets();
    let x = dataset.records();
    let diff = Zip::from(&probabilities)
        .and(targets)
        .map_collect(|prob, y| {
            let truth = if *y { F::one() } else { F::zero() };
            F::cast(**prob) - truth
        });
    diff.dot(x)
}

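/// Learning-rate weight for a single coordinate: `(sqrt(n + g^2) - sqrt(n)) / alpha`.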
fn calculate_weight_in_average<F: Float>(n: F, gradient: F, alpha: F) -> F {
    (F::sqrt(n + gradient * gradient) - F::sqrt(n)) / alpha
}

/// Numerically stable sigmoid: clamps the input and branches on its sign to avoid
/// numerical overflow in `exp` for float-type data.
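///
/// For `x >= 0` this evaluates `1 / (1 + exp(-x))`; for `x < 0` it uses the algebraically
/// equivalent form `exp(x) / (1 + exp(x))`, so `exp` only ever sees a non-positive argument.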
fn stable_sigmoid<F: Float>(prediction: F) -> F {
    let max_abs = F::cast(35.);
    let prediction = prediction.min(max_abs).max(-max_abs);
    if prediction.is_negative() {
        negative_sigmoid(prediction)
    } else {
        positive_sigmoid(prediction)
    }
}

fn positive_sigmoid<F: Float>(prediction: F) -> F {
    F::one() / (F::one() + (-prediction).exp())
}

fn negative_sigmoid<F: Float>(prediction: F) -> F {
    let exp = prediction.exp();
    exp / (exp + F::one())
}

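/// Closed-form per-coordinate FTRL-Proximal weight: the L1 threshold keeps a coordinate at
/// exactly zero while `|z| <= l1_ratio`; above it the weight is shrunk by both the adaptive
/// learning-rate term `(sqrt(n) + beta) / alpha` and the L2 penalty `l2_ratio`.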
fn apply_proximal_to_weights<F: Float>(
    z: F,
    n: F,
    alpha: F,
    beta: F,
    l1_ratio: F,
    l2_ratio: F,
) -> F {
    let sign = z.signum();
    if z * sign <= l1_ratio {
        F::zero()
    } else {
        (sign * l1_ratio - z) / ((n.sqrt() + beta) / alpha + l2_ratio)
    }
}

#[cfg(test)]
mod test {
    extern crate linfa;
    use super::*;
    use crate::algorithm::test::linfa::prelude::Predict;
    use crate::FtrlParams;
    use approx::assert_abs_diff_eq;
    use linfa::{Dataset, ParamGuard};
    use ndarray::array;
    use rand::rngs::SmallRng;
    use rand::SeedableRng;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<FtrlParams<f64, rand_xoshiro::Xoshiro256Plus>>();
        has_autotraits::<FtrlValidParams<f64, rand_xoshiro::Xoshiro256Plus>>();
        has_autotraits::<Ftrl<f64>>();
        has_autotraits::<FtrlError>();
    }

    #[test]
    fn sigmoid_works() {
        let value = 100.;
        let result = stable_sigmoid(value);
        assert!(result > 0.9)
    }

    #[test]
    fn calculate_weights_with_zero_outcome_works() {
        let z = 0.1;
        let n = 0.1;
        let alpha = 0.5;
        let beta = 0.5;
        let l1_ratio = 0.5;
        let l2_ratio = 0.5;
        let result = apply_proximal_to_weights(z, n, alpha, beta, l1_ratio, l2_ratio);
        assert_abs_diff_eq!(result, 0.0)
    }

    #[test]
    fn calculate_sigma_works() {
        let gradient: f64 = 0.5;
        let n: f64 = 0.11;
        let alpha = 0.5;
        let expected_result = ((0.11f64 + 0.25).sqrt() - 0.11f64.sqrt()) / 0.5;
        let result = calculate_weight_in_average(n, gradient, alpha);
        assert_abs_diff_eq!(result, expected_result)
    }

    #[test]
    fn calculate_weights_works() {
        let z = 0.5;
        let n: f64 = 0.16;
        let alpha = 0.5;
        let beta = 0.5;
        let l1_ratio = 0.1;
        let l2_ratio = 0.5;
        let expected_result = (0.1 - 0.5) / ((0.4 + 0.5) / 0.5 + 0.5);
        let result = apply_proximal_to_weights(z, n, alpha, beta, l1_ratio, l2_ratio);
        assert_abs_diff_eq!(result, expected_result)
    }

    #[test]
    fn calculate_gradient_works() {
        let probabilities = array![0.1, 0.3, 0.8].mapv(Pr::new);
        let dataset = Dataset::new(
            array![[0.0, 1.0], [2.0, 3.0], [1.0, 5.0]],
            array![false, false, true],
        );
        let result = calculate_gradient(probabilities.view(), &dataset);
        assert_abs_diff_eq!(result, array![0.4, 0.0], epsilon = 1e-1)
    }

    #[test]
    fn update_params_works() {
        let rng = SmallRng::seed_from_u64(42);
        let probabilities = array![0.1, 0.3, 0.8].mapv(Pr::new);
        let dataset = Dataset::new(
            array![[0.0, 1.0], [2.0, 3.0], [1.0, 5.0]],
            array![false, false, true],
        );
        let params = FtrlParams::default_with_rng(rng);
        let valid_params = params.check().unwrap();
        let mut model = Ftrl::new(valid_params.clone(), dataset.nfeatures());
        let initial_z = model.z().clone();
        let initial_n = model.n().clone();
        let weights = model.get_weights();
        let gradient = calculate_gradient(probabilities.view(), &dataset);
        let sigma = model.calculate_sigma(gradient.view());
        model.update_params(gradient.clone(), sigma.clone());
        let expected_z = initial_z + &gradient - sigma * weights;
        let expected_n = initial_n + &gradient.mapv(|grad: f64| grad.powf(2.));
        assert_abs_diff_eq!(model.z(), &expected_z, epsilon = 1e-1);
        assert_abs_diff_eq!(model.n(), &expected_n, epsilon = 1e-1)
    }

    #[test]
    fn predict_probabilities_works() {
        let rng = SmallRng::seed_from_u64(42);
        let dataset = Dataset::new(
            array![[0.0, 1.0], [2.0, 3.0], [1.0, 5.0]],
            array![false, false, true],
        );
        let params = FtrlParams::default_with_rng(rng);
        let valid_params = params.check().unwrap();
        let model = Ftrl::new(valid_params.clone(), dataset.nfeatures());
        let probabilities = model.predict_probabilities(dataset.records());
        assert!(probabilities
            .iter()
            .all(|prob| **prob >= 0. && **prob <= 1.));
    }

    #[test]
    fn update_works() {
        let probabilities = array![0.5, 0.3, 0.7].mapv(Pr::new);
        let dataset = Dataset::new(
            array![[0.0, 1.0], [2.0, 3.0], [1.0, 5.0]],
            array![false, false, true],
        );

        // Initialize model this way to control random z values
        let mut model = Ftrl {
            alpha: 0.005,
            beta: 0.0,
            l1_ratio: 0.5,
            l2_ratio: 0.5,
            z: array![0.5, 0.7],
            n: array![0.0, 0.0],
        };
        model.update(&dataset, probabilities.view());
        assert_abs_diff_eq!(model.n(), &array![0.09, 0.01], epsilon = 1e-2);
        assert_abs_diff_eq!(model.z(), &array![0.8, 8.6], epsilon = 1e-2);
    }

    #[test]
    fn ftrl_toy_example_works() {
        let alpha = 0.1;
        let beta = 0.0;
        let regularization = 0.5;
        let dataset = Dataset::new(
            array![[-1.0], [-2.0], [10.0], [9.0]],
            array![true, true, false, false],
        );
        let params = Ftrl::params()
            .l2_ratio(regularization)
            .l1_ratio(regularization)
            .alpha(alpha)
            .beta(beta);

        // Initialize model this way to control random z values
        let model = Ftrl {
            alpha,
            beta,
            l1_ratio: regularization,
            l2_ratio: regularization,
            z: array![0.5],
            n: array![0.],
        };
        let model = params.fit_with(Some(model), &dataset).unwrap();
        let test_x = array![[11.0]];
        assert_abs_diff_eq!(
            model.predict(&test_x).mapv(|v| *v),
            array![0.25],
            epsilon = 1e-2
        );
    }

    #[test]
    fn ftrl_2d_toy_example_works() {
        let alpha = 0.01;
        let beta = 0.0;
        let regularization = 0.5;
        let dataset = Dataset::new(array![[0.0, -5.0], [10.0, 20.0]], array![true, false]);
        let params = Ftrl::params()
            .l2_ratio(regularization)
            .l1_ratio(regularization)
            .alpha(alpha)
            .beta(beta);

        // Initialize model this way to control random z values
        let model = Ftrl {
            alpha,
            beta,
            l1_ratio: regularization,
            l2_ratio: regularization,
            z: array![0.5, 0.5],
            n: array![0.0, 0.0],
        };
        let model = params.fit_with(Some(model), &dataset).unwrap();
        let test_x = array![[-4.0, -10.0], [15.0, 25.0]];
        assert_abs_diff_eq!(
            model.predict(&test_x).mapv(|v| *v),
            array![0.53, 0.401],
            epsilon = 1e-2
        );
    }
}