linfa/
metrics_regression.rs

//! Common metrics for regression
//!
//! This module implements common comparison metrics for continuous variables.
4
5use crate::{
6    dataset::{AsMultiTargets, AsSingleTargets, DatasetBase},
7    error::{Error, Result},
8    Float,
9};
10use ndarray::prelude::*;
11use ndarray::Data;
12use std::ops::{Div, Sub};
13
14/// Regression metrices trait for single targets.
15///
16/// It is possible to compute the listed mectrics between two 1D arrays.
17/// To compare bi-dimensional arrays use [`MultiTargetRegression`].
18pub trait SingleTargetRegression<F: Float, T: AsSingleTargets<Elem = F>>:
19    AsSingleTargets<Elem = F>
20{
21    /// Maximal error between two continuous variables
22    fn max_error(&self, compare_to: &T) -> Result<F> {
23        let max_error = self
24            .as_single_targets()
25            .sub(&compare_to.as_single_targets())
26            .iter()
27            .map(|x| x.abs())
28            .fold(F::neg_infinity(), F::max);
29        Ok(max_error)
30    }
31    /// Mean error between two continuous variables
32    fn mean_absolute_error(&self, compare_to: &T) -> Result<F> {
33        self.as_single_targets()
34            .sub(&compare_to.as_single_targets())
35            .mapv_into(|x| x.abs())
36            .mean()
37            .ok_or(Error::NotEnoughSamples)
38    }
39
40    /// Mean squared error between two continuous variables
41    fn mean_squared_error(&self, compare_to: &T) -> Result<F> {
42        self.as_single_targets()
43            .sub(&compare_to.as_single_targets())
44            .mapv_into(|x| x * x)
45            .mean()
46            .ok_or(Error::NotEnoughSamples)
47    }
48
49    /// Mean squared log error between two continuous variables
50    fn mean_squared_log_error(&self, compare_to: &T) -> Result<F> {
51        self.as_single_targets()
52            .mapv(|x| (F::one() + x).ln())
53            .mean_squared_error(&compare_to.as_single_targets().mapv(|x| (F::one() + x).ln()))
54    }
55
56    /// Median absolute error between two continuous variables
57    fn median_absolute_error(&self, compare_to: &T) -> Result<F> {
58        let mut abs_error = self
59            .as_single_targets()
60            .sub(&compare_to.as_single_targets())
61            .mapv_into(|x| x.abs())
62            .to_vec();
63        abs_error.sort_by(|a, b| a.partial_cmp(b).unwrap());
64        let mid = abs_error.len() / 2;
65        if abs_error.len() % 2 == 0 {
66            Ok((abs_error[mid - 1] + abs_error[mid]) / F::cast(2.0))
67        } else {
68            Ok(abs_error[mid])
69        }
70    }
71
72    /// Mean absolute percentage error between two continuous variables
73    /// MAPE = 1/N * SUM(abs((y_hat - y) / y))
74    fn mean_absolute_percentage_error(&self, compare_to: &T) -> Result<F> {
75        self.as_single_targets()
76            .sub(&compare_to.as_single_targets())
77            .div(self.as_single_targets())
78            .mapv_into(|x| x.abs())
79            .mean()
80            .ok_or(Error::NotEnoughSamples)
81    }
82
83    /// R squared coefficient, is the proportion of the variance in the dependent variable that is
84    /// predictable from the independent variable
85    // r2 = 1 - sum((pred_i - y_i)^2)/sum((mean_y - y_i)^2)
86    // if the mean is of `compare_to`, then the denominator
87    // should compare `compare_to` and the mean, and not self and the mean
88    fn r2(&self, compare_to: &T) -> Result<F> {
89        let single_target_compare_to = compare_to.as_single_targets();
90        let mean = single_target_compare_to
91            .mean()
92            .ok_or(Error::NotEnoughSamples)?;
93
94        Ok(F::one()
95            - self
96                .as_single_targets()
97                .sub(&single_target_compare_to)
98                .mapv_into(|x| x * x)
99                .sum()
100                / (single_target_compare_to
101                    .mapv(|x| (x - mean) * (x - mean))
102                    .sum()
103                    + F::cast(1e-10)))
104    }
105
106    /// Same as R-Squared but with biased variance
107    fn explained_variance(&self, compare_to: &T) -> Result<F> {
108        let single_target_compare_to = compare_to.as_single_targets();
109        let diff = self.as_single_targets().sub(&single_target_compare_to);
110
111        let mean = single_target_compare_to
112            .mean()
113            .ok_or(Error::NotEnoughSamples)?;
114        let mean_error = diff.mean().ok_or(Error::NotEnoughSamples)?;
115
116        Ok(F::one()
117            - (diff.mapv_into(|x| x * x).sum() - mean_error)
118                / (single_target_compare_to
119                    .mapv(|x| (x - mean) * (x - mean))
120                    .sum()
121                    + F::cast(1e-10)))
122    }
123}
124
125impl<F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = F>> SingleTargetRegression<F, T>
126    for ArrayBase<D, Ix1>
127{
128}
129
130impl<F: Float, T: AsSingleTargets<Elem = F>, T2: AsSingleTargets<Elem = F>, D: Data<Elem = F>>
131    SingleTargetRegression<F, T2> for DatasetBase<ArrayBase<D, Ix2>, T>
132{
133}
134
135/// Regression metrices trait for multiple targets.
136///
137/// It is possible to compute the listed mectrics between two 2D arrays.
138/// To compare single-dimensional arrays use [`SingleTargetRegression`].
139pub trait MultiTargetRegression<F: Float, T: AsMultiTargets<Elem = F>>:
140    AsMultiTargets<Elem = F>
141{
142    /// Maximal error between two continuous variables
143    fn max_error(&self, other: &T) -> Result<Array1<F>> {
144        self.as_multi_targets()
145            .axis_iter(Axis(1))
146            .zip(other.as_multi_targets().axis_iter(Axis(1)))
147            .map(|(a, b)| a.max_error(&b))
148            .collect()
149    }
150    /// Mean error between two continuous variables
151    fn mean_absolute_error(&self, other: &T) -> Result<Array1<F>> {
152        self.as_multi_targets()
153            .axis_iter(Axis(1))
154            .zip(other.as_multi_targets().axis_iter(Axis(1)))
155            .map(|(a, b)| a.mean_absolute_error(&b))
156            .collect()
157    }
158
159    /// Mean squared error between two continuous variables
160    fn mean_squared_error(&self, other: &T) -> Result<Array1<F>> {
161        self.as_multi_targets()
162            .axis_iter(Axis(1))
163            .zip(other.as_multi_targets().axis_iter(Axis(1)))
164            .map(|(a, b)| a.mean_squared_error(&b))
165            .collect()
166    }
167
168    /// Mean squared log error between two continuous variables
169    fn mean_squared_log_error(&self, other: &T) -> Result<Array1<F>> {
170        self.as_multi_targets()
171            .axis_iter(Axis(1))
172            .zip(other.as_multi_targets().axis_iter(Axis(1)))
173            .map(|(a, b)| a.mean_squared_log_error(&b))
174            .collect()
175    }
176
177    /// Median absolute error between two continuous variables
178    fn median_absolute_error(&self, other: &T) -> Result<Array1<F>> {
179        self.as_multi_targets()
180            .axis_iter(Axis(1))
181            .zip(other.as_multi_targets().axis_iter(Axis(1)))
182            .map(|(a, b)| a.median_absolute_error(&b))
183            .collect()
184    }
185
186    /// Mean absolute percentage error between two continuous variables
187    /// MAPE = 1/N * SUM(abs((y_hat - y) / y))
188    fn mean_absolute_percentage_error(&self, other: &T) -> Result<Array1<F>> {
189        self.as_multi_targets()
190            .axis_iter(Axis(1))
191            .zip(other.as_multi_targets().axis_iter(Axis(1)))
192            .map(|(a, b)| a.mean_absolute_percentage_error(&b))
193            .collect()
194    }
195
196    /// R squared coefficient, is the proportion of the variance in the dependent variable that is
197    /// predictable from the independent variable
198    fn r2(&self, other: &T) -> Result<Array1<F>> {
199        self.as_multi_targets()
200            .axis_iter(Axis(1))
201            .zip(other.as_multi_targets().axis_iter(Axis(1)))
202            .map(|(a, b)| a.r2(&b))
203            .collect()
204    }
205
206    /// Same as R-Squared but with biased variance
207    fn explained_variance(&self, other: &T) -> Result<Array1<F>> {
208        self.as_multi_targets()
209            .axis_iter(Axis(1))
210            .zip(other.as_multi_targets().axis_iter(Axis(1)))
211            .map(|(a, b)| a.explained_variance(&b))
212            .collect()
213    }
214}
215
216impl<F: Float, D: Data<Elem = F>, T: AsMultiTargets<Elem = F>> MultiTargetRegression<F, T>
217    for ArrayBase<D, Ix2>
218{
219}
220
221impl<F: Float, T: AsMultiTargets<Elem = F>, T2: AsMultiTargets<Elem = F>, D: Data<Elem = F>>
222    MultiTargetRegression<F, T2> for DatasetBase<ArrayBase<D, Ix2>, T>
223{
224}
225
#[cfg(test)]
mod tests {
    use super::SingleTargetRegression;
    use crate::dataset::DatasetBase;
    use approx::assert_abs_diff_eq;
    use ndarray::prelude::*;

    /// Five two-feature records shared by the dataset-based tests.
    fn five_records() -> Array2<f64> {
        array![[0.0, 0.0], [0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4]]
    }

    /// Reference targets matching `five_records`.
    fn five_targets() -> Array1<f64> {
        array![0.0, 0.1, 0.2, 0.3, 0.4]
    }

    /// Predictions compared against `five_targets`.
    fn five_predictions() -> Array1<f64> {
        array![0.1, 0.3, 0.2, 0.5, 0.7]
    }

    // Comparing an array with itself must give the "perfect" value of every metric.
    #[test]
    fn test_same() {
        let ones: Array1<f32> = Array1::ones(100);

        assert_abs_diff_eq!(ones.max_error(&ones).unwrap(), 0.0f32);
        assert_abs_diff_eq!(ones.mean_absolute_error(&ones).unwrap(), 0.0f32);
        assert_abs_diff_eq!(ones.mean_squared_error(&ones).unwrap(), 0.0f32);
        assert_abs_diff_eq!(ones.mean_squared_log_error(&ones).unwrap(), 0.0f32);
        assert_abs_diff_eq!(ones.median_absolute_error(&ones).unwrap(), 0.0f32);
        assert_abs_diff_eq!(ones.r2(&ones).unwrap(), 1.0f32);
        assert_abs_diff_eq!(ones.explained_variance(&ones).unwrap(), 1.0f32);
        assert_abs_diff_eq!(
            ones.mean_absolute_percentage_error(&ones).unwrap(),
            0.0f32
        );
    }

    #[test]
    fn test_max_error() {
        let lhs = array![0.0, 0.1, 0.2, 0.3, 0.4];
        let rhs = array![0.1, 0.3, 0.2, 0.5, 0.7];

        let worst = lhs.max_error(&rhs).unwrap();
        assert_abs_diff_eq!(worst, 0.3f32, epsilon = 1e-5);
    }

    #[test]
    fn test_median_absolute_error() {
        let lhs = array![0.0, 0.1, 0.2, 0.3, 0.4];
        let rhs = array![0.1, 0.3, 0.2, 0.5, 0.7];

        // absolute errors: 0.1, 0.2, 0.0, 0.2, 0.3 -> the median is 0.2
        let median = lhs.median_absolute_error(&rhs).unwrap();
        assert_abs_diff_eq!(median, 0.2f32, epsilon = 1e-5);
    }

    #[test]
    fn test_mean_squared_error() {
        let lhs = array![0.0, 0.1, 0.2, 0.3, 0.4];
        let rhs = array![0.1, 0.2, 0.3, 0.4, 0.5];

        let mse = lhs.mean_squared_error(&rhs).unwrap();
        assert_abs_diff_eq!(mse, 0.01, epsilon = 1e-5);
    }

    #[test]
    fn test_mean_absolute_percentage_error() {
        let lhs = array![0.5, 0.1, 0.2, 0.3, 0.4];
        let rhs = array![0.1, 0.2, 0.3, 0.4, 0.5];

        let mape = lhs.mean_absolute_percentage_error(&rhs).unwrap();
        assert_abs_diff_eq!(mape, 0.5766666666666667, epsilon = 1e-5);
    }

    // The tests below evaluate each metric twice, once on a plain array and
    // once on a dataset wrapping the same predictions, and require both
    // results to agree.

    #[test]
    fn test_max_error_for_single_targets() {
        let features = five_records();
        let reference: DatasetBase<_, _> = (features.view(), five_targets()).into();
        let predicted: Array1<f64> = five_predictions();
        let from_array = predicted.max_error(reference.targets()).unwrap();
        let predicted: DatasetBase<_, _> = (features.view(), predicted.view()).into();
        let from_dataset = predicted.max_error(reference.targets()).unwrap();
        assert_abs_diff_eq!(from_array, 0.3);
        assert_abs_diff_eq!(from_array, from_dataset);
    }

    #[test]
    fn test_mean_absolute_error_for_single_targets() {
        let features = five_records();
        let reference: DatasetBase<_, _> = (features.view(), five_targets()).into();
        let predicted = five_predictions();
        let from_array = predicted.mean_absolute_error(&reference).unwrap();
        let predicted: DatasetBase<_, _> = (features.view(), predicted).into();
        let from_dataset = predicted
            .mean_absolute_error(reference.targets())
            .unwrap();
        assert_abs_diff_eq!(from_array, 0.16);
        assert_abs_diff_eq!(from_array, from_dataset);
    }

    #[test]
    fn test_mean_squared_error_for_single_targets() {
        let features = five_records();
        let reference: DatasetBase<_, _> = (features.view(), five_targets()).into();
        let predicted = five_predictions();
        let from_array = predicted.mean_squared_error(reference.targets()).unwrap();
        let predicted: DatasetBase<_, _> = (features.view(), predicted).into();
        let from_dataset = predicted.mean_squared_error(reference.targets()).unwrap();
        assert_abs_diff_eq!(from_array, 0.036);
        assert_abs_diff_eq!(from_array, from_dataset);
    }

    #[test]
    fn test_mean_absolute_percentage_error_for_single_targets() {
        let features = five_records();
        let reference: DatasetBase<_, _> = (features.view(), five_targets()).into();
        let predicted = five_predictions();
        let from_array = predicted
            .mean_absolute_percentage_error(reference.targets())
            .unwrap();
        let predicted: DatasetBase<_, _> = (features.view(), predicted).into();
        let from_dataset = predicted
            .mean_absolute_percentage_error(reference.targets())
            .unwrap();
        assert_abs_diff_eq!(from_array, 0.49904761904761896);
        assert_abs_diff_eq!(from_array, from_dataset);
    }

    #[test]
    fn test_mean_squared_log_error_for_single_targets() {
        let features = five_records();
        let reference: DatasetBase<_, _> = (features.view(), five_targets()).into();
        let predicted = five_predictions();
        let from_array = predicted
            .mean_squared_log_error(reference.targets())
            .unwrap();
        let predicted: DatasetBase<_, _> = (features.view(), predicted).into();
        let from_dataset = predicted
            .mean_squared_log_error(reference.targets())
            .unwrap();
        assert_abs_diff_eq!(from_array, 0.019_033, epsilon = 1e-5);
        assert_abs_diff_eq!(from_array, from_dataset);
    }

    #[test]
    fn test_median_absolute_error_for_single_targets() {
        // Even number of absolute errors: the two central values are averaged.
        let features = array![[0.0, 0.0], [0.1, 0.1], [0.2, 0.2], [0.3, 0.3],];
        let reference: DatasetBase<_, _> =
            (features.view(), array![0.0, 0.1, 0.2, 0.4]).into();
        let predicted = array![0.1, 0.3, 0.2, 0.7];
        let from_array = predicted
            .median_absolute_error(reference.targets())
            .unwrap();
        let predicted: DatasetBase<_, _> = (features.view(), predicted).into();
        let from_dataset = predicted
            .median_absolute_error(reference.targets())
            .unwrap();
        assert_abs_diff_eq!(from_array, 0.15, epsilon = 1e-5);
        assert_abs_diff_eq!(from_array, from_dataset);

        // Odd number of absolute errors: the central value is returned as is.
        let features = five_records();
        let reference: DatasetBase<_, _> = (features.view(), five_targets()).into();
        let predicted = array![0.1, 0.3, 0.2, 0.51, 0.7];
        let from_array = predicted.median_absolute_error(&reference).unwrap();
        let predicted: DatasetBase<_, _> = (features.view(), predicted).into();
        let from_dataset = predicted.median_absolute_error(&reference).unwrap();
        assert_abs_diff_eq!(from_array, 0.2, epsilon = 1e-5);
        assert_abs_diff_eq!(from_array, from_dataset);
    }

    #[test]
    fn test_r2_for_single_targets() {
        let features = five_records();
        let reference: DatasetBase<_, _> = (features.view(), five_targets()).into();
        let predicted = five_predictions();
        let from_array = predicted.r2(reference.targets()).unwrap();
        let predicted: DatasetBase<_, _> = (features.view(), predicted).into();
        let from_dataset = predicted.r2(reference.targets()).unwrap();
        assert_abs_diff_eq!(from_array, -0.8, epsilon = 1e-5);
        assert_abs_diff_eq!(from_array, from_dataset);
    }

    #[test]
    fn test_explained_variance_for_single_targets() {
        let features = five_records();
        let reference: DatasetBase<_, _> = (features.view(), five_targets()).into();
        let predicted = five_predictions();
        let from_array = predicted.explained_variance(reference.targets()).unwrap();
        let predicted: DatasetBase<_, _> = (features.view(), predicted).into();
        let from_dataset = predicted.explained_variance(&reference).unwrap();
        assert_abs_diff_eq!(from_array, 0.8, epsilon = 1e-5);
        assert_abs_diff_eq!(from_array, from_dataset);
    }
}