1use crate::error::FtrlError;
2use crate::hyperparams::FtrlValidParams;
3use crate::Ftrl;
4use linfa::dataset::{AsSingleTargets, Pr, Records};
5use linfa::traits::{FitWith, PredictInplace};
6use linfa::{DatasetBase, Float};
7use ndarray::{Array1, ArrayBase, ArrayView1, Data, Ix2, Zip};
8use rand::Rng;
9
10pub type Result<T> = std::result::Result<T, FtrlError>;
12
13impl<F, R, D, T> FitWith<'_, ArrayBase<D, Ix2>, T, FtrlError> for FtrlValidParams<F, R>
14where
15 F: Float,
16 R: Rng + Clone,
17 D: Data<Elem = F>,
18 T: AsSingleTargets<Elem = bool>,
19{
20 type ObjectIn = Option<Ftrl<F>>;
21 type ObjectOut = Ftrl<F>;
22
23 fn fit_with(
34 &self,
35 model_in: Self::ObjectIn,
36 dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
37 ) -> Result<Self::ObjectOut> {
38 let mut model_out =
39 model_in.unwrap_or_else(|| Ftrl::new(self.clone(), dataset.nfeatures()));
40 let probabilities = model_out.predict_probabilities(dataset.records());
41 let gradient = calculate_gradient(probabilities.view(), dataset);
42 let sigma = model_out.calculate_sigma(gradient.view());
43 model_out.update_params(gradient, sigma);
44 Ok(model_out)
45 }
46}
47
48impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<Pr>> for Ftrl<F> {
49 fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<Pr>) {
53 assert_eq!(
54 x.nrows(),
55 y.len(),
56 "The number of data points must match the number of output targets."
57 );
58
59 assert_eq!(
60 x.ncols(),
61 self.z.len(),
62 "Number of data features must match the number of features the model was trained with."
63 );
64
65 let probabilities = self.predict_probabilities(x);
66 Zip::from(&probabilities).and(y).for_each(|prob, out| {
67 *out = *prob;
68 });
69 }
70
71 fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<Pr> {
72 Array1::zeros(x.nrows()).mapv(Pr::new)
73 }
74}
75
76impl<F: Float> Ftrl<F> {
79 pub fn z(&self) -> &Array1<F> {
81 &self.z
82 }
83
84 pub fn n(&self) -> &Array1<F> {
86 &self.n
87 }
88
89 pub fn alpha(&self) -> F {
91 self.alpha
92 }
93
94 pub fn beta(&self) -> F {
96 self.beta
97 }
98
99 pub fn l1_ratio(&self) -> F {
101 self.l1_ratio
102 }
103
104 pub fn l2_ratio(&self) -> F {
106 self.l2_ratio
107 }
108
109 pub fn get_weights(&self) -> Array1<F> {
111 Zip::from(self.z.view())
112 .and(self.n.view())
113 .map_collect(|z, n| {
114 apply_proximal_to_weights(
115 *z,
116 *n,
117 self.alpha(),
118 self.beta(),
119 self.l1_ratio(),
120 self.l2_ratio(),
121 )
122 })
123 }
124
125 pub fn update<D: Data<Elem = F>, T: AsSingleTargets<Elem = bool>>(
128 &mut self,
129 dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
130 probabilities: ArrayView1<Pr>,
131 ) {
132 let gradient = calculate_gradient(probabilities.view(), dataset);
133 let sigma = self.calculate_sigma(gradient.view());
134 self.update_params(gradient, sigma);
135 }
136
137 fn predict_probabilities<D: Data<Elem = F>>(&self, x: &ArrayBase<D, Ix2>) -> Array1<Pr> {
138 let weights = self.get_weights();
139 let mut probabilities = x.dot(&weights);
140 probabilities.mapv_inplace(stable_sigmoid);
141 probabilities.mapv(|v| Pr::new(F::to_f32(&v).unwrap_or_default()))
142 }
143
144 fn calculate_sigma(&self, gradients: ArrayView1<F>) -> Array1<F> {
145 Zip::from(&self.n)
146 .and(gradients)
147 .map_collect(|n, grad| calculate_weight_in_average(*n, *grad, self.alpha()))
148 }
149
150 fn update_params(&mut self, gradient: Array1<F>, sigma: Array1<F>) {
151 let weights = self.get_weights();
152 self.z += &gradient;
153 self.z -= &(sigma * weights);
154 self.n += &(&gradient * &gradient);
155 }
156}
157
158fn calculate_gradient<F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = bool>>(
159 probabilities: ArrayView1<Pr>,
160 dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
161) -> Array1<F> {
162 let targets = dataset.as_single_targets();
163 let x = dataset.records();
164 let diff = Zip::from(&probabilities)
165 .and(targets)
166 .map_collect(|prob, y| {
167 let truth = if *y { F::one() } else { F::zero() };
168 F::cast(**prob) - truth
169 });
170 diff.dot(x)
171}
172
173fn calculate_weight_in_average<F: Float>(n: F, gradient: F, alpha: F) -> F {
174 (F::sqrt(n + gradient * gradient) - F::sqrt(n)) / alpha
175}
176
177fn stable_sigmoid<F: Float>(prediction: F) -> F {
179 let max_abs = F::cast(35.);
180 let prediction = prediction.min(max_abs).max(-max_abs);
181 if prediction.is_negative() {
182 negative_sigmoid(prediction)
183 } else {
184 positive_sigmoid(prediction)
185 }
186}
187
188fn positive_sigmoid<F: Float>(prediction: F) -> F {
189 F::one() / (F::one() + (-prediction).exp())
190}
191
192fn negative_sigmoid<F: Float>(prediction: F) -> F {
193 let exp = prediction.exp();
194 exp / (exp + F::one())
195}
196
197fn apply_proximal_to_weights<F: Float>(
198 z: F,
199 n: F,
200 alpha: F,
201 beta: F,
202 l1_ratio: F,
203 l2_ratio: F,
204) -> F {
205 let sign = z.signum();
206 if z * sign <= l1_ratio {
207 F::zero()
208 } else {
209 (sign * l1_ratio - z) / ((n.sqrt() + beta) / alpha + l2_ratio)
210 }
211}
212
#[cfg(test)]
mod test {
    use super::*;
    use crate::FtrlParams;
    use approx::assert_abs_diff_eq;
    use linfa::prelude::Predict;
    use linfa::{Dataset, ParamGuard};
    use ndarray::array;
    use rand::rngs::SmallRng;
    use rand::SeedableRng;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<FtrlParams<f64, rand_xoshiro::Xoshiro256Plus>>();
        has_autotraits::<FtrlValidParams<f64, rand_xoshiro::Xoshiro256Plus>>();
        has_autotraits::<Ftrl<f64>>();
        has_autotraits::<FtrlError>();
    }

    #[test]
    fn sigmoid_works() {
        // Large positive input saturates towards 1.
        let high = stable_sigmoid(100.0_f64);
        assert!(high > 0.9 && high <= 1.0);
        // Large negative input saturates towards 0 without underflowing to NaN.
        let low = stable_sigmoid(-1000.0_f64);
        assert!(low.is_finite() && low >= 0.0 && low < 0.1);
        // Exactly one half at the origin.
        assert_abs_diff_eq!(stable_sigmoid(0.0_f64), 0.5);
        // Symmetry: s(x) + s(-x) == 1 across the branch split.
        assert_abs_diff_eq!(
            stable_sigmoid(2.0_f64) + stable_sigmoid(-2.0_f64),
            1.0,
            epsilon = 1e-12
        );
    }

    #[test]
    fn calculate_weights_with_zero_outcome_works() {
        let z = 0.1;
        let n = 0.1;
        let alpha = 0.5;
        let beta = 0.5;
        let l1_ratio = 0.5;
        let l2_ratio = 0.5;
        let result = apply_proximal_to_weights(z, n, alpha, beta, l1_ratio, l2_ratio);
        assert_abs_diff_eq!(result, 0.0)
    }

    #[test]
    fn calculate_sigma_works() {
        let gradient: f64 = 0.5;
        let n: f64 = 0.11;
        let alpha = 0.5;
        let expected_result = ((0.11f64 + 0.25).sqrt() - 0.11f64.sqrt()) / 0.5;
        let result = calculate_weight_in_average(n, gradient, alpha);
        assert_abs_diff_eq!(result, expected_result)
    }

    #[test]
    fn calculate_weights_works() {
        let z = 0.5;
        let n: f64 = 0.16;
        let alpha = 0.5;
        let beta = 0.5;
        let l1_ratio = 0.1;
        let l2_ratio = 0.5;
        let expected_result = (0.1 - 0.5) / ((0.4 + 0.5) / 0.5 + 0.5);
        let result = apply_proximal_to_weights(z, n, alpha, beta, l1_ratio, l2_ratio);
        assert_abs_diff_eq!(result, expected_result)
    }

    #[test]
    fn calculate_gradient_works() {
        let probabilities = array![0.1, 0.3, 0.8].mapv(Pr::new);
        let dataset = Dataset::new(
            array![[0.0, 1.0], [2.0, 3.0], [1.0, 5.0]],
            array![false, false, true],
        );
        let result = calculate_gradient(probabilities.view(), &dataset);
        // Only f32 rounding of the stored probabilities separates the result from [0.4, 0.0].
        assert_abs_diff_eq!(result, array![0.4, 0.0], epsilon = 1e-6)
    }

    #[test]
    fn update_params_works() {
        let rng = SmallRng::seed_from_u64(42);
        let probabilities = array![0.1, 0.3, 0.8].mapv(Pr::new);
        let dataset = Dataset::new(
            array![[0.0, 1.0], [2.0, 3.0], [1.0, 5.0]],
            array![false, false, true],
        );
        let params = FtrlParams::default_with_rng(rng);
        let valid_params = params.check().unwrap();
        let mut model = Ftrl::new(valid_params.clone(), dataset.nfeatures());
        let initial_z = model.z().clone();
        let initial_n = model.n().clone();
        let weights = model.get_weights();
        let gradient = calculate_gradient(probabilities.view(), &dataset);
        let sigma = model.calculate_sigma(gradient.view());
        model.update_params(gradient.clone(), sigma.clone());
        let expected_z = initial_z + &gradient - sigma * weights;
        let expected_n = initial_n + &gradient.mapv(|grad: f64| grad.powf(2.));
        // Same element-wise arithmetic on both sides, so the match should be (near) exact.
        assert_abs_diff_eq!(model.z(), &expected_z, epsilon = 1e-8);
        assert_abs_diff_eq!(model.n(), &expected_n, epsilon = 1e-8)
    }

    #[test]
    fn predict_probabilities_works() {
        let rng = SmallRng::seed_from_u64(42);
        let dataset = Dataset::new(
            array![[0.0, 1.0], [2.0, 3.0], [1.0, 5.0]],
            array![false, false, true],
        );
        let params = FtrlParams::default_with_rng(rng);
        let valid_params = params.check().unwrap();
        let model = Ftrl::new(valid_params.clone(), dataset.nfeatures());
        let probabilities = model.predict_probabilities(dataset.records());
        assert!(probabilities
            .iter()
            .all(|prob| **prob >= 0. && **prob <= 1.));
    }

    #[test]
    fn update_works() {
        let probabilities = array![0.5, 0.3, 0.7].mapv(Pr::new);
        let dataset = Dataset::new(
            array![[0.0, 1.0], [2.0, 3.0], [1.0, 5.0]],
            array![false, false, true],
        );

        let mut model = Ftrl {
            alpha: 0.005,
            beta: 0.0,
            l1_ratio: 0.5,
            l2_ratio: 0.5,
            z: array![0.5, 0.7],
            n: array![0.0, 0.0],
        };
        model.update(&dataset, probabilities.view());
        // Residual error comes only from f32 rounding of the probabilities (~2e-6 on z).
        assert_abs_diff_eq!(model.n(), &array![0.09, 0.01], epsilon = 1e-4);
        assert_abs_diff_eq!(model.z(), &array![0.8, 8.6], epsilon = 1e-4);
    }

    #[test]
    fn ftrl_toy_example_works() {
        let alpha = 0.1;
        let beta = 0.0;
        let regularization = 0.5;
        let dataset = Dataset::new(
            array![[-1.0], [-2.0], [10.0], [9.0]],
            array![true, true, false, false],
        );
        let params = Ftrl::params()
            .l2_ratio(regularization)
            .l1_ratio(regularization)
            .alpha(alpha)
            .beta(beta);

        let model = Ftrl {
            alpha,
            beta,
            l1_ratio: regularization,
            l2_ratio: regularization,
            z: array![0.5],
            n: array![0.],
        };
        let model = params.fit_with(Some(model), &dataset).unwrap();
        let test_x = array![[11.0]];
        assert_abs_diff_eq!(
            model.predict(&test_x).mapv(|v| *v),
            array![0.25],
            epsilon = 1e-2
        );
    }

    #[test]
    fn ftrl_2d_toy_example_works() {
        let alpha = 0.01;
        let beta = 0.0;
        let regularization = 0.5;
        let dataset = Dataset::new(array![[0.0, -5.0], [10.0, 20.0]], array![true, false]);
        let params = Ftrl::params()
            .l2_ratio(regularization)
            .l1_ratio(regularization)
            .alpha(alpha)
            .beta(beta);

        let model = Ftrl {
            alpha,
            beta,
            l1_ratio: regularization,
            l2_ratio: regularization,
            z: array![0.5, 0.5],
            n: array![0.0, 0.0],
        };
        let model = params.fit_with(Some(model), &dataset).unwrap();
        let test_x = array![[-4.0, -10.0], [15.0, 25.0]];
        assert_abs_diff_eq!(
            model.predict(&test_x).mapv(|v| *v),
            array![0.53, 0.401],
            epsilon = 1e-2
        );
    }
}