1use crate::error::{PreprocessingError, Result};
4use approx::abs_diff_eq;
5use linfa::dataset::{AsTargets, DatasetBase, Float, WithLapack};
6use linfa::traits::{Fit, Transformer};
7#[cfg(not(feature = "blas"))]
8use linfa_linalg::norm::Norm;
9use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix2, Zip};
10#[cfg(feature = "blas")]
11use ndarray_linalg::norm::Norm;
12
13#[cfg(feature = "serde")]
14use serde_crate::{Deserialize, Serialize};
15
16#[cfg_attr(
17 feature = "serde",
18 derive(Serialize, Deserialize),
19 serde(crate = "serde_crate")
20)]
21#[derive(Clone, Debug, PartialEq, Eq)]
22pub enum ScalingMethod<F: Float> {
28 Standard(bool, bool),
29 MinMax(F, F),
30 MaxAbs,
31}
32
33impl<F: Float> ScalingMethod<F> {
34 pub(crate) fn fit<D: Data<Elem = F>>(
35 &self,
36 records: &ArrayBase<D, Ix2>,
37 ) -> Result<LinearScaler<F>> {
38 match self {
39 ScalingMethod::Standard(a, b) => Self::standardize(records, *a, *b),
40 ScalingMethod::MinMax(a, b) => Self::min_max(records, *a, *b),
41 ScalingMethod::MaxAbs => Self::max_abs(records),
42 }
43 }
44
45 fn standardize<D: Data<Elem = F>>(
46 records: &ArrayBase<D, Ix2>,
47 with_mean: bool,
48 with_std: bool,
49 ) -> Result<LinearScaler<F>> {
50 if records.dim().0 == 0 {
51 return Err(PreprocessingError::NotEnoughSamples);
52 }
53 let means = records.mean_axis(Axis(0)).unwrap();
55 let std_devs = if with_std {
56 records.std_axis(Axis(0), F::zero()).mapv(|s| {
57 if abs_diff_eq!(s, F::zero()) {
58 F::one()
60 } else {
61 F::one() / s
62 }
63 })
64 } else {
65 Array1::ones(records.dim().1)
66 };
67 Ok(LinearScaler {
68 offsets: means,
69 scales: std_devs,
70 method: ScalingMethod::Standard(with_mean, with_std),
71 })
72 }
73
74 fn min_max<D: Data<Elem = F>>(
75 records: &ArrayBase<D, Ix2>,
76 min: F,
77 max: F,
78 ) -> Result<LinearScaler<F>> {
79 if records.dim().0 == 0 {
80 return Err(PreprocessingError::NotEnoughSamples);
81 } else if min > max {
82 return Err(PreprocessingError::FlippedMinMaxRange);
83 }
84
85 let mins = records.fold_axis(
86 Axis(0),
87 F::infinity(),
88 |&x, &prev| if x < prev { x } else { prev },
89 );
90 let mut scales =
91 records.fold_axis(
92 Axis(0),
93 F::neg_infinity(),
94 |&x, &prev| if x > prev { x } else { prev },
95 );
96 Zip::from(&mut scales).and(&mins).for_each(|max, min| {
97 if abs_diff_eq!(*max - *min, F::zero()) {
98 *max = F::one();
100 } else {
101 *max = F::one() / (*max - *min);
102 }
103 });
104 Ok(LinearScaler {
105 offsets: mins,
106 scales,
107 method: ScalingMethod::MinMax(min, max),
108 })
109 }
110
111 fn max_abs<D: Data<Elem = F>>(records: &ArrayBase<D, Ix2>) -> Result<LinearScaler<F>> {
112 if records.dim().0 == 0 {
113 return Err(PreprocessingError::NotEnoughSamples);
114 }
115 let scales: Array1<F> = records.map_axis(Axis(0), |col| {
116 let norm_max = F::cast(col.with_lapack().norm_max());
117 if abs_diff_eq!(norm_max, F::zero()) {
118 F::one()
120 } else {
121 F::one() / norm_max
122 }
123 });
124
125 let offsets = Array1::zeros(records.dim().1);
126 Ok(LinearScaler {
127 offsets,
128 scales,
129 method: ScalingMethod::MaxAbs,
130 })
131 }
132}
133
134impl<F: Float> std::fmt::Display for ScalingMethod<F> {
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 match self {
137 ScalingMethod::Standard(with_mean, with_std) => write!(
138 f,
139 "Standard scaler (with_mean = {with_mean}, with_std = {with_std})"
140 ),
141 ScalingMethod::MinMax(min, max) => {
142 write!(f, "Min-Max scaler (min = {min}, max = {max})")
143 }
144 ScalingMethod::MaxAbs => write!(f, "MaxAbs scaler"),
145 }
146 }
147}
148
149#[cfg_attr(
167 feature = "serde",
168 derive(Serialize, Deserialize),
169 serde(crate = "serde_crate")
170)]
171#[derive(Debug, Clone, PartialEq, Eq)]
172pub struct LinearScalerParams<F: Float> {
173 method: ScalingMethod<F>,
174}
175
176impl<F: Float> LinearScalerParams<F> {
177 pub fn new(method: ScalingMethod<F>) -> Self {
179 Self { method }
180 }
181
182 pub fn method(mut self, method: ScalingMethod<F>) -> Self {
184 self.method = method;
185 self
186 }
187}
188
189impl<F: Float> LinearScaler<F> {
190 pub fn standard() -> LinearScalerParams<F> {
192 LinearScalerParams {
193 method: ScalingMethod::Standard(true, true),
194 }
195 }
196
197 pub fn standard_no_mean() -> LinearScalerParams<F> {
199 LinearScalerParams {
200 method: ScalingMethod::Standard(false, true),
201 }
202 }
203
204 pub fn standard_no_std() -> LinearScalerParams<F> {
206 LinearScalerParams {
207 method: ScalingMethod::Standard(true, false),
208 }
209 }
210
211 pub fn min_max() -> LinearScalerParams<F> {
213 LinearScalerParams {
214 method: ScalingMethod::MinMax(F::zero(), F::one()),
215 }
216 }
217
218 pub fn min_max_range(min: F, max: F) -> LinearScalerParams<F> {
222 LinearScalerParams {
223 method: ScalingMethod::MinMax(min, max),
224 }
225 }
226
227 pub fn max_abs() -> LinearScalerParams<F> {
229 LinearScalerParams {
230 method: ScalingMethod::MaxAbs,
231 }
232 }
233}
234
235impl<F: Float, D: Data<Elem = F>, T: AsTargets> Fit<ArrayBase<D, Ix2>, T, PreprocessingError>
236 for LinearScalerParams<F>
237{
238 type Object = LinearScaler<F>;
239
240 fn fit(&self, x: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
243 self.method.fit(x.records())
244 }
245}
246
247#[cfg_attr(
248 feature = "serde",
249 derive(Serialize, Deserialize),
250 serde(crate = "serde_crate")
251)]
252#[derive(Debug, Clone, PartialEq, Eq)]
253pub struct LinearScaler<F: Float> {
256 offsets: Array1<F>,
257 scales: Array1<F>,
258 method: ScalingMethod<F>,
259}
260
261impl<F: Float> LinearScaler<F> {
262 pub fn offsets(&self) -> &Array1<F> {
264 &self.offsets
265 }
266
267 pub fn scales(&self) -> &Array1<F> {
269 &self.scales
270 }
271
272 pub fn method(&self) -> &ScalingMethod<F> {
274 &self.method
275 }
276}
277
278impl<F: Float> Transformer<Array2<F>, Array2<F>> for LinearScaler<F> {
279 fn transform(&self, x: Array2<F>) -> Array2<F> {
282 if x.is_empty() {
283 return x;
284 }
285 let mut x = x;
286 Zip::from(x.columns_mut())
287 .and(self.offsets())
288 .and(self.scales())
289 .for_each(|mut col, &offset, &scale| {
290 if let ScalingMethod::Standard(false, _) = self.method {
291 col.mapv_inplace(|el| (el - offset) * scale + offset);
292 } else {
293 col.mapv_inplace(|el| (el - offset) * scale);
294 }
295 });
296 match &self.method {
297 ScalingMethod::MinMax(min, max) => x * (*max - *min) + *min,
298 _ => x,
299 }
300 }
301}
302
303impl<F: Float, D: Data<Elem = F>, T: AsTargets>
304 Transformer<DatasetBase<ArrayBase<D, Ix2>, T>, DatasetBase<Array2<F>, T>> for LinearScaler<F>
305{
306 fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
309 let feature_names = x.feature_names().to_vec();
310 let target_names = x.target_names().to_vec();
311 let (records, targets, weights) = (x.records, x.targets, x.weights);
312 let records = self.transform(records.to_owned());
313 DatasetBase::new(records, targets)
314 .with_weights(weights)
315 .with_feature_names(feature_names)
316 .with_target_names(target_names)
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::ScalingMethod;
    use crate::error::PreprocessingError;
    use crate::linear_scaling::{LinearScaler, LinearScalerParams};
    use approx::assert_abs_diff_eq;
    use linfa::dataset::DatasetBase;
    use linfa::traits::{Fit, Transformer};
    use ndarray::{array, Array2, Axis};

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<LinearScaler<f64>>();
        has_autotraits::<LinearScalerParams<f64>>();
        has_autotraits::<ScalingMethod<f64>>();
    }

    #[test]
    fn test_max_abs() {
        let dataset = array![[1., -1.], [2., -2.], [3., -3.], [4., -5.]].into();
        let scaler = LinearScaler::max_abs().fit(&dataset).unwrap();
        let scaled = scaler.transform(dataset);
        let col0 = scaled.records().column(0);
        let col1 = scaled.records().column(1);
        assert_abs_diff_eq!(col0, array![1. / 4., 2. / 4., 3. / 4., 1.]);
        assert_abs_diff_eq!(col1, array![-1. / 5., -2. / 5., -3. / 5., -1.]);
    }

    #[test]
    fn test_standard_scaler() {
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let scaler = LinearScaler::standard().fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![1., 0., 1. / 3.]);
        assert_abs_diff_eq!(
            *scaler.scales(),
            array![1. / 0.81, 1. / 0.81, 1. / 1.24],
            epsilon = 1e-2
        );
        let transformed = scaler.transform(dataset);
        let means = transformed.records().mean_axis(Axis(0)).unwrap();
        let std_devs = transformed.records().std_axis(Axis(0), 0.);
        assert_abs_diff_eq!(means, array![0., 0., 0.]);
        assert_abs_diff_eq!(std_devs, array![1., 1., 1.]);
    }

    #[test]
    fn test_standard_scaler_no_mean() {
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let scaler = LinearScaler::standard_no_mean().fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![1., 0., 1. / 3.]);
        assert_abs_diff_eq!(
            *scaler.scales(),
            array![1. / 0.81, 1. / 0.81, 1. / 1.24],
            epsilon = 1e-2
        );
        let transformed = scaler.transform(dataset);
        let means = transformed.records().mean_axis(Axis(0)).unwrap();
        let std_devs = transformed.records().std_axis(Axis(0), 0.);
        // Without mean removal the original means are preserved; only the spread changes.
        assert_abs_diff_eq!(means, array![1., 0., (1. / 3.)], epsilon = 1e-2);
        assert_abs_diff_eq!(std_devs, array![1., 1., 1.]);
    }

    #[test]
    fn test_standard_scaler_no_std() {
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let scaler = LinearScaler::standard_no_std().fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![1., 0., 1. / 3.]);
        assert_abs_diff_eq!(*scaler.scales(), array![1., 1., 1.],);
        let transformed = scaler.transform(dataset);
        let means = transformed.records().mean_axis(Axis(0)).unwrap();
        let std_devs = transformed.records().std_axis(Axis(0), 0.);
        assert_abs_diff_eq!(means, array![0., 0., 0.]);
        assert_abs_diff_eq!(std_devs, array![0.81, 0.81, 1.24], epsilon = 1e-2);
    }

    #[test]
    fn test_standard_scaler_no_both() {
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let scaler = LinearScalerParams::new(ScalingMethod::Standard(false, false))
            .fit(&dataset)
            .unwrap();

        let original_means = dataset.records().mean_axis(Axis(0)).unwrap();
        let original_stds = dataset.records().std_axis(Axis(0), 0.);

        assert_abs_diff_eq!(*scaler.offsets(), original_means);
        assert_abs_diff_eq!(*scaler.scales(), array![1., 1., 1.],);

        let transformed = scaler.transform(dataset);

        let means = transformed.records().mean_axis(Axis(0)).unwrap();
        let std_devs = transformed.records().std_axis(Axis(0), 0.);

        // Neither centring nor scaling: the data must come back (numerically) unchanged.
        assert_abs_diff_eq!(means, original_means);
        assert_abs_diff_eq!(std_devs, original_stds, epsilon = 1e-2);
    }

    #[test]
    fn test_min_max_scaler() {
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let scaler = LinearScaler::min_max().fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![0., -1., -1.]);
        assert_abs_diff_eq!(*scaler.scales(), array![1. / 2., 1. / 2., 1. / 3.]);
        let transformed = scaler.transform(dataset);
        let mins = transformed.records().fold_axis(
            Axis(0),
            f64::INFINITY,
            |&x, &prev| if x < prev { x } else { prev },
        );
        let maxes = transformed.records().fold_axis(
            Axis(0),
            f64::NEG_INFINITY,
            |&x, &prev| if x > prev { x } else { prev },
        );
        assert_abs_diff_eq!(maxes, array![1., 1., 1.]);
        assert_abs_diff_eq!(mins, array![0., 0., 0.]);
    }

    #[test]
    fn test_min_max_scaler_range() {
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let scaler = LinearScaler::min_max_range(5., 10.).fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![0., -1., -1.]);
        assert_abs_diff_eq!(*scaler.scales(), array![1. / 2., 1. / 2., 1. / 3.]);
        let transformed = scaler.transform(dataset);
        let mins = transformed.records().fold_axis(
            Axis(0),
            f64::INFINITY,
            |&x, &prev| if x < prev { x } else { prev },
        );
        let maxes = transformed.records().fold_axis(
            Axis(0),
            f64::NEG_INFINITY,
            |&x, &prev| if x > prev { x } else { prev },
        );
        assert_abs_diff_eq!(mins, array![5., 5., 5.]);
        assert_abs_diff_eq!(maxes, array![10., 10., 10.]);
    }

    #[test]
    fn test_standard_const_feature() {
        let dataset = array![[1., 2., 2.], [2., 2., 0.], [0., 2., -1.]].into();
        let scaler = LinearScaler::standard().fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![1., 2., 1. / 3.]);
        // The constant middle column gets a unit scale instead of 1 / 0.
        assert_abs_diff_eq!(
            *scaler.scales(),
            array![1. / 0.81, 1., 1. / 1.24],
            epsilon = 1e-2
        );
        let transformed = scaler.transform(dataset);
        let means = transformed.records().mean_axis(Axis(0)).unwrap();
        let std_devs = transformed.records().std_axis(Axis(0), 0.);
        assert_abs_diff_eq!(means, array![0., 0., 0.]);
        assert_abs_diff_eq!(std_devs, array![1., 0., 1.]);
    }

    #[test]
    fn test_max_abs_const_null_feature() {
        let dataset = array![[1., 0.], [2., 0.], [3., 0.], [4., 0.]].into();
        let scaler = LinearScaler::max_abs().fit(&dataset).unwrap();
        let scaled = scaler.transform(dataset);
        let col0 = scaled.records().column(0);
        let col1 = scaled.records().column(1);
        assert_abs_diff_eq!(col0, array![1. / 4., 2. / 4., 3. / 4., 1.]);
        assert_abs_diff_eq!(col1, array![0., 0., 0., 0.]);
    }

    #[test]
    fn test_min_max_scaler_const_feature() {
        let dataset = array![[1., -1., 2.], [2., 0., 2.], [0., 1., 2.]].into();
        let scaler = LinearScaler::min_max().fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![0., -1., 2.]);
        assert_abs_diff_eq!(*scaler.scales(), array![1. / 2., 1. / 2., 1.]);
        let transformed = scaler.transform(dataset);
        let mins = transformed.records().fold_axis(
            Axis(0),
            f64::INFINITY,
            |&x, &prev| if x < prev { x } else { prev },
        );
        let maxes = transformed.records().fold_axis(
            Axis(0),
            f64::NEG_INFINITY,
            |&x, &prev| if x > prev { x } else { prev },
        );
        // A constant column collapses onto the lower bound of the target range.
        assert_abs_diff_eq!(maxes, array![1., 1., 0.]);
        assert_abs_diff_eq!(mins, array![0., 0., 0.]);
    }

    #[test]
    fn test_empty_input() {
        let dataset: DatasetBase<Array2<f64>, _> =
            Array2::from_shape_vec((0, 0), vec![]).unwrap().into();
        let scaler = LinearScaler::standard().fit(&dataset);
        assert_eq!(
            scaler.err().unwrap().to_string(),
            "not enough samples".to_string()
        );
        let scaler = LinearScaler::standard_no_mean().fit(&dataset);
        assert_eq!(
            scaler.err().unwrap().to_string(),
            "not enough samples".to_string()
        );
        let scaler = LinearScaler::standard_no_std().fit(&dataset);
        assert_eq!(
            scaler.err().unwrap().to_string(),
            "not enough samples".to_string()
        );
        let scaler = LinearScaler::min_max().fit(&dataset);
        assert_eq!(
            scaler.err().unwrap().to_string(),
            "not enough samples".to_string()
        );
        let scaler = LinearScaler::max_abs().fit(&dataset);
        assert_eq!(
            scaler.err().unwrap().to_string(),
            "not enough samples".to_string()
        );
    }

    #[test]
    fn test_transform_empty_array() {
        let empty: Array2<f64> = Array2::from_shape_vec((0, 0), vec![]).unwrap();
        let dataset = array![[1., -1., 2.], [2., 0., 2.], [0., 1., 2.]].into();
        let scaler = LinearScaler::standard().fit(&dataset).unwrap();
        let transformed = scaler.transform(empty.clone());
        assert!(transformed.is_empty());
        let scaler = LinearScaler::standard_no_mean().fit(&dataset).unwrap();
        let transformed = scaler.transform(empty.clone());
        assert!(transformed.is_empty());
        let scaler = LinearScaler::standard_no_std().fit(&dataset).unwrap();
        let transformed = scaler.transform(empty.clone());
        assert!(transformed.is_empty());
        let scaler = LinearScaler::min_max().fit(&dataset).unwrap();
        let transformed = scaler.transform(empty.clone());
        assert!(transformed.is_empty());
        let scaler = LinearScaler::max_abs().fit(&dataset).unwrap();
        let transformed = scaler.transform(empty);
        assert!(transformed.is_empty());
    }

    #[test]
    fn test_retain_feature_names() {
        let dataset = linfa_datasets::diabetes();
        let original_feature_names = dataset.feature_names().to_vec();
        let transformed = LinearScaler::standard()
            .fit(&dataset)
            .unwrap()
            .transform(dataset);
        assert_eq!(original_feature_names, transformed.feature_names())
    }

    // Transforming data with a different number of columns than the fitted data must panic.

    #[test]
    #[should_panic]
    fn test_transform_wrong_size_array_standard() {
        let wrong_size = Array2::from_shape_vec((1, 2), vec![0., 0.]).unwrap();
        let dataset = array![[1., -1., 2.], [2., 0., 2.], [0., 1., 2.]].into();
        let scaler = LinearScaler::standard().fit(&dataset).unwrap();
        let _transformed = scaler.transform(wrong_size);
    }
    #[test]
    #[should_panic]
    fn test_transform_wrong_size_array_standard_no_mean() {
        let wrong_size = Array2::from_shape_vec((1, 2), vec![0., 0.]).unwrap();
        let dataset = array![[1., -1., 2.], [2., 0., 2.], [0., 1., 2.]].into();
        let scaler = LinearScaler::standard_no_mean().fit(&dataset).unwrap();
        let _transformed = scaler.transform(wrong_size);
    }
    #[test]
    #[should_panic]
    fn test_transform_wrong_size_array_standard_no_std() {
        let wrong_size = Array2::from_shape_vec((1, 2), vec![0., 0.]).unwrap();
        let dataset = array![[1., -1., 2.], [2., 0., 2.], [0., 1., 2.]].into();
        let scaler = LinearScaler::standard_no_std().fit(&dataset).unwrap();
        let _transformed = scaler.transform(wrong_size);
    }
    #[test]
    #[should_panic]
    fn test_transform_wrong_size_array_min_max() {
        let wrong_size = Array2::from_shape_vec((1, 2), vec![0., 0.]).unwrap();
        let dataset = array![[1., -1., 2.], [2., 0., 2.], [0., 1., 2.]].into();
        let scaler = LinearScaler::min_max().fit(&dataset).unwrap();
        let _transformed = scaler.transform(wrong_size);
    }
    #[test]
    #[should_panic]
    fn test_transform_wrong_size_array_max_abs() {
        let wrong_size = Array2::from_shape_vec((1, 2), vec![0., 0.]).unwrap();
        let dataset = array![[1., -1., 2.], [2., 0., 2.], [0., 1., 2.]].into();
        let scaler = LinearScaler::max_abs().fit(&dataset).unwrap();
        let _transformed = scaler.transform(wrong_size);
    }

    #[test]
    fn test_min_max_wrong_range() {
        // A flipped range must be reported as a specific error, not merely "some panic".
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let scaler = LinearScaler::min_max_range(10., 5.).fit(&dataset);
        assert!(matches!(
            scaler,
            Err(PreprocessingError::FlippedMinMaxRange)
        ));
    }
}