1use crate::{
6 dataset::{AsMultiTargets, AsSingleTargets, DatasetBase},
7 error::{Error, Result},
8 Float,
9};
10use ndarray::prelude::*;
11use ndarray::Data;
12use std::ops::{Div, Sub};
13
14pub trait SingleTargetRegression<F: Float, T: AsSingleTargets<Elem = F>>:
19 AsSingleTargets<Elem = F>
20{
21 fn max_error(&self, compare_to: &T) -> Result<F> {
23 let max_error = self
24 .as_single_targets()
25 .sub(&compare_to.as_single_targets())
26 .iter()
27 .map(|x| x.abs())
28 .fold(F::neg_infinity(), F::max);
29 Ok(max_error)
30 }
31 fn mean_absolute_error(&self, compare_to: &T) -> Result<F> {
33 self.as_single_targets()
34 .sub(&compare_to.as_single_targets())
35 .mapv_into(|x| x.abs())
36 .mean()
37 .ok_or(Error::NotEnoughSamples)
38 }
39
40 fn mean_squared_error(&self, compare_to: &T) -> Result<F> {
42 self.as_single_targets()
43 .sub(&compare_to.as_single_targets())
44 .mapv_into(|x| x * x)
45 .mean()
46 .ok_or(Error::NotEnoughSamples)
47 }
48
49 fn mean_squared_log_error(&self, compare_to: &T) -> Result<F> {
51 self.as_single_targets()
52 .mapv(|x| (F::one() + x).ln())
53 .mean_squared_error(&compare_to.as_single_targets().mapv(|x| (F::one() + x).ln()))
54 }
55
56 fn median_absolute_error(&self, compare_to: &T) -> Result<F> {
58 let mut abs_error = self
59 .as_single_targets()
60 .sub(&compare_to.as_single_targets())
61 .mapv_into(|x| x.abs())
62 .to_vec();
63 abs_error.sort_by(|a, b| a.partial_cmp(b).unwrap());
64 let mid = abs_error.len() / 2;
65 if abs_error.len() % 2 == 0 {
66 Ok((abs_error[mid - 1] + abs_error[mid]) / F::cast(2.0))
67 } else {
68 Ok(abs_error[mid])
69 }
70 }
71
72 fn mean_absolute_percentage_error(&self, compare_to: &T) -> Result<F> {
75 self.as_single_targets()
76 .sub(&compare_to.as_single_targets())
77 .div(self.as_single_targets())
78 .mapv_into(|x| x.abs())
79 .mean()
80 .ok_or(Error::NotEnoughSamples)
81 }
82
83 fn r2(&self, compare_to: &T) -> Result<F> {
89 let single_target_compare_to = compare_to.as_single_targets();
90 let mean = single_target_compare_to
91 .mean()
92 .ok_or(Error::NotEnoughSamples)?;
93
94 Ok(F::one()
95 - self
96 .as_single_targets()
97 .sub(&single_target_compare_to)
98 .mapv_into(|x| x * x)
99 .sum()
100 / (single_target_compare_to
101 .mapv(|x| (x - mean) * (x - mean))
102 .sum()
103 + F::cast(1e-10)))
104 }
105
106 fn explained_variance(&self, compare_to: &T) -> Result<F> {
108 let single_target_compare_to = compare_to.as_single_targets();
109 let diff = self.as_single_targets().sub(&single_target_compare_to);
110
111 let mean = single_target_compare_to
112 .mean()
113 .ok_or(Error::NotEnoughSamples)?;
114 let mean_error = diff.mean().ok_or(Error::NotEnoughSamples)?;
115
116 Ok(F::one()
117 - (diff.mapv_into(|x| x * x).sum() - mean_error)
118 / (single_target_compare_to
119 .mapv(|x| (x - mean) * (x - mean))
120 .sum()
121 + F::cast(1e-10)))
122 }
123}
124
125impl<F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = F>> SingleTargetRegression<F, T>
126 for ArrayBase<D, Ix1>
127{
128}
129
130impl<F: Float, T: AsSingleTargets<Elem = F>, T2: AsSingleTargets<Elem = F>, D: Data<Elem = F>>
131 SingleTargetRegression<F, T2> for DatasetBase<ArrayBase<D, Ix2>, T>
132{
133}
134
135pub trait MultiTargetRegression<F: Float, T: AsMultiTargets<Elem = F>>:
140 AsMultiTargets<Elem = F>
141{
142 fn max_error(&self, other: &T) -> Result<Array1<F>> {
144 self.as_multi_targets()
145 .axis_iter(Axis(1))
146 .zip(other.as_multi_targets().axis_iter(Axis(1)))
147 .map(|(a, b)| a.max_error(&b))
148 .collect()
149 }
150 fn mean_absolute_error(&self, other: &T) -> Result<Array1<F>> {
152 self.as_multi_targets()
153 .axis_iter(Axis(1))
154 .zip(other.as_multi_targets().axis_iter(Axis(1)))
155 .map(|(a, b)| a.mean_absolute_error(&b))
156 .collect()
157 }
158
159 fn mean_squared_error(&self, other: &T) -> Result<Array1<F>> {
161 self.as_multi_targets()
162 .axis_iter(Axis(1))
163 .zip(other.as_multi_targets().axis_iter(Axis(1)))
164 .map(|(a, b)| a.mean_squared_error(&b))
165 .collect()
166 }
167
168 fn mean_squared_log_error(&self, other: &T) -> Result<Array1<F>> {
170 self.as_multi_targets()
171 .axis_iter(Axis(1))
172 .zip(other.as_multi_targets().axis_iter(Axis(1)))
173 .map(|(a, b)| a.mean_squared_log_error(&b))
174 .collect()
175 }
176
177 fn median_absolute_error(&self, other: &T) -> Result<Array1<F>> {
179 self.as_multi_targets()
180 .axis_iter(Axis(1))
181 .zip(other.as_multi_targets().axis_iter(Axis(1)))
182 .map(|(a, b)| a.median_absolute_error(&b))
183 .collect()
184 }
185
186 fn mean_absolute_percentage_error(&self, other: &T) -> Result<Array1<F>> {
189 self.as_multi_targets()
190 .axis_iter(Axis(1))
191 .zip(other.as_multi_targets().axis_iter(Axis(1)))
192 .map(|(a, b)| a.mean_absolute_percentage_error(&b))
193 .collect()
194 }
195
196 fn r2(&self, other: &T) -> Result<Array1<F>> {
199 self.as_multi_targets()
200 .axis_iter(Axis(1))
201 .zip(other.as_multi_targets().axis_iter(Axis(1)))
202 .map(|(a, b)| a.r2(&b))
203 .collect()
204 }
205
206 fn explained_variance(&self, other: &T) -> Result<Array1<F>> {
208 self.as_multi_targets()
209 .axis_iter(Axis(1))
210 .zip(other.as_multi_targets().axis_iter(Axis(1)))
211 .map(|(a, b)| a.explained_variance(&b))
212 .collect()
213 }
214}
215
216impl<F: Float, D: Data<Elem = F>, T: AsMultiTargets<Elem = F>> MultiTargetRegression<F, T>
217 for ArrayBase<D, Ix2>
218{
219}
220
221impl<F: Float, T: AsMultiTargets<Elem = F>, T2: AsMultiTargets<Elem = F>, D: Data<Elem = F>>
222 MultiTargetRegression<F, T2> for DatasetBase<ArrayBase<D, Ix2>, T>
223{
224}
225
#[cfg(test)]
mod tests {
    use super::SingleTargetRegression;
    use crate::dataset::DatasetBase;
    use approx::assert_abs_diff_eq;
    use ndarray::prelude::*;

    /// Five two-column records shared by the dataset-backed tests.
    fn sample_records() -> Array2<f64> {
        array![[0.0, 0.0], [0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4]]
    }

    /// Reference targets matching `sample_records`.
    fn sample_targets() -> Array1<f64> {
        array![0.0, 0.1, 0.2, 0.3, 0.4]
    }

    /// Predictions that are compared against `sample_targets`.
    fn sample_prediction() -> Array1<f64> {
        array![0.1, 0.3, 0.2, 0.5, 0.7]
    }

    /// Comparing an array with itself gives the ideal score for every metric.
    #[test]
    fn test_same() {
        let ones: Array1<f32> = Array1::ones(100);

        assert_abs_diff_eq!(ones.max_error(&ones).unwrap(), 0.0f32);
        assert_abs_diff_eq!(ones.mean_absolute_error(&ones).unwrap(), 0.0f32);
        assert_abs_diff_eq!(ones.mean_squared_error(&ones).unwrap(), 0.0f32);
        assert_abs_diff_eq!(ones.mean_squared_log_error(&ones).unwrap(), 0.0f32);
        assert_abs_diff_eq!(ones.median_absolute_error(&ones).unwrap(), 0.0f32);
        assert_abs_diff_eq!(ones.r2(&ones).unwrap(), 1.0f32);
        assert_abs_diff_eq!(ones.explained_variance(&ones).unwrap(), 1.0f32);
        assert_abs_diff_eq!(
            ones.mean_absolute_percentage_error(&ones).unwrap(),
            0.0f32
        );
    }

    #[test]
    fn test_max_error() {
        let lhs: Array1<f32> = array![0.0, 0.1, 0.2, 0.3, 0.4];
        let rhs: Array1<f32> = array![0.1, 0.3, 0.2, 0.5, 0.7];

        assert_abs_diff_eq!(lhs.max_error(&rhs).unwrap(), 0.3f32, epsilon = 1e-5);
    }

    #[test]
    fn test_median_absolute_error() {
        let lhs: Array1<f32> = array![0.0, 0.1, 0.2, 0.3, 0.4];
        let rhs: Array1<f32> = array![0.1, 0.3, 0.2, 0.5, 0.7];

        assert_abs_diff_eq!(
            lhs.median_absolute_error(&rhs).unwrap(),
            0.2f32,
            epsilon = 1e-5
        );
    }

    #[test]
    fn test_mean_squared_error() {
        let lhs: Array1<f64> = array![0.0, 0.1, 0.2, 0.3, 0.4];
        let rhs: Array1<f64> = array![0.1, 0.2, 0.3, 0.4, 0.5];

        assert_abs_diff_eq!(lhs.mean_squared_error(&rhs).unwrap(), 0.01, epsilon = 1e-5);
    }

    #[test]
    fn test_mean_absolute_percentage_error() {
        let lhs: Array1<f64> = array![0.5, 0.1, 0.2, 0.3, 0.4];
        let rhs: Array1<f64> = array![0.1, 0.2, 0.3, 0.4, 0.5];

        assert_abs_diff_eq!(
            lhs.mean_absolute_percentage_error(&rhs).unwrap(),
            0.5766666666666667,
            epsilon = 1e-5
        );
    }

    // The remaining tests compute each metric twice: once with a plain array as the
    // receiver and once with the same values wrapped as the targets of a dataset. Both
    // routes must agree.

    #[test]
    fn test_max_error_for_single_targets() {
        let records = sample_records();
        let truth: DatasetBase<_, _> = (records.view(), sample_targets()).into();
        let prediction: Array1<f64> = sample_prediction();

        let from_array = prediction.max_error(truth.targets()).unwrap();
        let wrapped: DatasetBase<_, _> = (records.view(), prediction.view()).into();
        let from_dataset = wrapped.max_error(truth.targets()).unwrap();

        assert_abs_diff_eq!(from_array, 0.3);
        assert_abs_diff_eq!(from_array, from_dataset);
    }

    #[test]
    fn test_mean_absolute_error_for_single_targets() {
        let records = sample_records();
        let truth: DatasetBase<_, _> = (records.view(), sample_targets()).into();
        let prediction = sample_prediction();

        // The array route is compared against the whole dataset here, not just its targets.
        let from_array = prediction.mean_absolute_error(&truth).unwrap();
        let wrapped: DatasetBase<_, _> = (records.view(), prediction).into();
        let from_dataset = wrapped.mean_absolute_error(truth.targets()).unwrap();

        assert_abs_diff_eq!(from_array, 0.16);
        assert_abs_diff_eq!(from_array, from_dataset);
    }

    #[test]
    fn test_mean_squared_error_for_single_targets() {
        let records = sample_records();
        let truth: DatasetBase<_, _> = (records.view(), sample_targets()).into();
        let prediction = sample_prediction();

        let from_array = prediction.mean_squared_error(truth.targets()).unwrap();
        let wrapped: DatasetBase<_, _> = (records.view(), prediction).into();
        let from_dataset = wrapped.mean_squared_error(truth.targets()).unwrap();

        assert_abs_diff_eq!(from_array, 0.036);
        assert_abs_diff_eq!(from_array, from_dataset);
    }

    #[test]
    fn test_mean_absolute_percentage_error_for_single_targets() {
        let records = sample_records();
        let truth: DatasetBase<_, _> = (records.view(), sample_targets()).into();
        let prediction = sample_prediction();

        let from_array = prediction
            .mean_absolute_percentage_error(truth.targets())
            .unwrap();
        let wrapped: DatasetBase<_, _> = (records.view(), prediction).into();
        let from_dataset = wrapped
            .mean_absolute_percentage_error(truth.targets())
            .unwrap();

        assert_abs_diff_eq!(from_array, 0.49904761904761896);
        assert_abs_diff_eq!(from_array, from_dataset);
    }

    #[test]
    fn test_mean_squared_log_error_for_single_targets() {
        let records = sample_records();
        let truth: DatasetBase<_, _> = (records.view(), sample_targets()).into();
        let prediction = sample_prediction();

        let from_array = prediction.mean_squared_log_error(truth.targets()).unwrap();
        let wrapped: DatasetBase<_, _> = (records.view(), prediction).into();
        let from_dataset = wrapped.mean_squared_log_error(truth.targets()).unwrap();

        assert_abs_diff_eq!(from_array, 0.019_033, epsilon = 1e-5);
        assert_abs_diff_eq!(from_array, from_dataset);
    }

    #[test]
    fn test_median_absolute_error_for_single_targets() {
        // Even number of samples: the median is the average of the two middle errors.
        let records = array![[0.0, 0.0], [0.1, 0.1], [0.2, 0.2], [0.3, 0.3],];
        let truth: DatasetBase<_, _> = (records.view(), array![0.0, 0.1, 0.2, 0.4]).into();
        let prediction = array![0.1, 0.3, 0.2, 0.7];

        let from_array = prediction.median_absolute_error(truth.targets()).unwrap();
        let wrapped: DatasetBase<_, _> = (records.view(), prediction).into();
        let from_dataset = wrapped.median_absolute_error(truth.targets()).unwrap();

        assert_abs_diff_eq!(from_array, 0.15, epsilon = 1e-5);
        assert_abs_diff_eq!(from_array, from_dataset);

        // Odd number of samples: the median is the single middle error.
        let records = sample_records();
        let truth: DatasetBase<_, _> = (records.view(), sample_targets()).into();
        let prediction = array![0.1, 0.3, 0.2, 0.51, 0.7];

        let from_array = prediction.median_absolute_error(&truth).unwrap();
        let wrapped: DatasetBase<_, _> = (records.view(), prediction).into();
        let from_dataset = wrapped.median_absolute_error(&truth).unwrap();

        assert_abs_diff_eq!(from_array, 0.2, epsilon = 1e-5);
        assert_abs_diff_eq!(from_array, from_dataset);
    }

    #[test]
    fn test_r2_for_single_targets() {
        let records = sample_records();
        let truth: DatasetBase<_, _> = (records.view(), sample_targets()).into();
        let prediction = sample_prediction();

        let from_array = prediction.r2(truth.targets()).unwrap();
        let wrapped: DatasetBase<_, _> = (records.view(), prediction).into();
        let from_dataset = wrapped.r2(truth.targets()).unwrap();

        assert_abs_diff_eq!(from_array, -0.8, epsilon = 1e-5);
        assert_abs_diff_eq!(from_array, from_dataset);
    }

    #[test]
    fn test_explained_variance_for_single_targets() {
        let records = sample_records();
        let truth: DatasetBase<_, _> = (records.view(), sample_targets()).into();
        let prediction = sample_prediction();

        let from_array = prediction.explained_variance(truth.targets()).unwrap();
        let wrapped: DatasetBase<_, _> = (records.view(), prediction).into();
        let from_dataset = wrapped.explained_variance(&truth).unwrap();

        assert_abs_diff_eq!(from_array, 0.8, epsilon = 1e-5);
        assert_abs_diff_eq!(from_array, from_dataset);
    }
}