1use crate::error::{PreprocessingError, Result};
4use approx::abs_diff_eq;
5use linfa::dataset::{AsTargets, DatasetBase, Float, WithLapack};
6use linfa::traits::{Fit, Transformer};
7#[cfg(not(feature = "blas"))]
8use linfa_linalg::norm::Norm;
9use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix2, Zip};
10#[cfg(feature = "blas")]
11use ndarray_linalg::norm::Norm;
12
13#[cfg(feature = "serde")]
14use serde_crate::{Deserialize, Serialize};
15
16#[cfg_attr(
17 feature = "serde",
18 derive(Serialize, Deserialize),
19 serde(crate = "serde_crate")
20)]
21#[derive(Clone, Debug, PartialEq, Eq)]
22pub enum ScalingMethod<F: Float> {
28 Standard(bool, bool),
29 MinMax(F, F),
30 MaxAbs,
31}
32
33impl<F: Float> ScalingMethod<F> {
34 pub(crate) fn fit<D: Data<Elem = F>>(
35 &self,
36 records: &ArrayBase<D, Ix2>,
37 ) -> Result<LinearScaler<F>> {
38 match self {
39 ScalingMethod::Standard(a, b) => Self::standardize(records, *a, *b),
40 ScalingMethod::MinMax(a, b) => Self::min_max(records, *a, *b),
41 ScalingMethod::MaxAbs => Self::max_abs(records),
42 }
43 }
44
45 fn standardize<D: Data<Elem = F>>(
46 records: &ArrayBase<D, Ix2>,
47 with_mean: bool,
48 with_std: bool,
49 ) -> Result<LinearScaler<F>> {
50 if records.dim().0 == 0 {
51 return Err(PreprocessingError::NotEnoughSamples);
52 }
53 let means = records.mean_axis(Axis(0)).unwrap();
55 let std_devs = if with_std {
56 records.std_axis(Axis(0), F::zero()).mapv(|s| {
57 if abs_diff_eq!(s, F::zero()) {
58 F::one()
60 } else {
61 F::one() / s
62 }
63 })
64 } else {
65 Array1::ones(records.dim().1)
66 };
67 Ok(LinearScaler {
68 offsets: means,
69 scales: std_devs,
70 method: ScalingMethod::Standard(with_mean, with_std),
71 })
72 }
73
74 fn min_max<D: Data<Elem = F>>(
75 records: &ArrayBase<D, Ix2>,
76 min: F,
77 max: F,
78 ) -> Result<LinearScaler<F>> {
79 if records.dim().0 == 0 {
80 return Err(PreprocessingError::NotEnoughSamples);
81 } else if min > max {
82 return Err(PreprocessingError::FlippedMinMaxRange);
83 }
84
85 let mins = records.fold_axis(
86 Axis(0),
87 F::infinity(),
88 |&x, &prev| if x < prev { x } else { prev },
89 );
90 let mut scales =
91 records.fold_axis(
92 Axis(0),
93 F::neg_infinity(),
94 |&x, &prev| if x > prev { x } else { prev },
95 );
96 Zip::from(&mut scales).and(&mins).for_each(|max, min| {
97 if abs_diff_eq!(*max - *min, F::zero()) {
98 *max = F::one();
100 } else {
101 *max = F::one() / (*max - *min);
102 }
103 });
104 Ok(LinearScaler {
105 offsets: mins,
106 scales,
107 method: ScalingMethod::MinMax(min, max),
108 })
109 }
110
111 fn max_abs<D: Data<Elem = F>>(records: &ArrayBase<D, Ix2>) -> Result<LinearScaler<F>> {
112 if records.dim().0 == 0 {
113 return Err(PreprocessingError::NotEnoughSamples);
114 }
115 let scales: Array1<F> = records.map_axis(Axis(0), |col| {
116 let norm_max = F::cast(col.with_lapack().norm_max());
117 if abs_diff_eq!(norm_max, F::zero()) {
118 F::one()
120 } else {
121 F::one() / norm_max
122 }
123 });
124
125 let offsets = Array1::zeros(records.dim().1);
126 Ok(LinearScaler {
127 offsets,
128 scales,
129 method: ScalingMethod::MaxAbs,
130 })
131 }
132}
133
134impl<F: Float> std::fmt::Display for ScalingMethod<F> {
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 match self {
137 ScalingMethod::Standard(with_mean, with_std) => write!(
138 f,
139 "Standard scaler (with_mean = {}, with_std = {})",
140 with_mean, with_std
141 ),
142 ScalingMethod::MinMax(min, max) => {
143 write!(f, "Min-Max scaler (min = {}, max = {})", min, max)
144 }
145 ScalingMethod::MaxAbs => write!(f, "MaxAbs scaler"),
146 }
147 }
148}
149
150#[cfg_attr(
168 feature = "serde",
169 derive(Serialize, Deserialize),
170 serde(crate = "serde_crate")
171)]
172#[derive(Debug, Clone, PartialEq, Eq)]
173pub struct LinearScalerParams<F: Float> {
174 method: ScalingMethod<F>,
175}
176
177impl<F: Float> LinearScalerParams<F> {
178 pub fn new(method: ScalingMethod<F>) -> Self {
180 Self { method }
181 }
182
183 pub fn method(mut self, method: ScalingMethod<F>) -> Self {
185 self.method = method;
186 self
187 }
188}
189
190impl<F: Float> LinearScaler<F> {
191 pub fn standard() -> LinearScalerParams<F> {
193 LinearScalerParams {
194 method: ScalingMethod::Standard(true, true),
195 }
196 }
197
198 pub fn standard_no_mean() -> LinearScalerParams<F> {
200 LinearScalerParams {
201 method: ScalingMethod::Standard(false, true),
202 }
203 }
204
205 pub fn standard_no_std() -> LinearScalerParams<F> {
207 LinearScalerParams {
208 method: ScalingMethod::Standard(true, false),
209 }
210 }
211
212 pub fn min_max() -> LinearScalerParams<F> {
214 LinearScalerParams {
215 method: ScalingMethod::MinMax(F::zero(), F::one()),
216 }
217 }
218
219 pub fn min_max_range(min: F, max: F) -> LinearScalerParams<F> {
223 LinearScalerParams {
224 method: ScalingMethod::MinMax(min, max),
225 }
226 }
227
228 pub fn max_abs() -> LinearScalerParams<F> {
230 LinearScalerParams {
231 method: ScalingMethod::MaxAbs,
232 }
233 }
234}
235
236impl<F: Float, D: Data<Elem = F>, T: AsTargets> Fit<ArrayBase<D, Ix2>, T, PreprocessingError>
237 for LinearScalerParams<F>
238{
239 type Object = LinearScaler<F>;
240
241 fn fit(&self, x: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
244 self.method.fit(x.records())
245 }
246}
247
248#[cfg_attr(
249 feature = "serde",
250 derive(Serialize, Deserialize),
251 serde(crate = "serde_crate")
252)]
253#[derive(Debug, Clone, PartialEq, Eq)]
254pub struct LinearScaler<F: Float> {
257 offsets: Array1<F>,
258 scales: Array1<F>,
259 method: ScalingMethod<F>,
260}
261
262impl<F: Float> LinearScaler<F> {
263 pub fn offsets(&self) -> &Array1<F> {
265 &self.offsets
266 }
267
268 pub fn scales(&self) -> &Array1<F> {
270 &self.scales
271 }
272
273 pub fn method(&self) -> &ScalingMethod<F> {
275 &self.method
276 }
277}
278
279impl<F: Float> Transformer<Array2<F>, Array2<F>> for LinearScaler<F> {
280 fn transform(&self, x: Array2<F>) -> Array2<F> {
283 if x.is_empty() {
284 return x;
285 }
286 let mut x = x;
287 Zip::from(x.columns_mut())
288 .and(self.offsets())
289 .and(self.scales())
290 .for_each(|mut col, &offset, &scale| {
291 if let ScalingMethod::Standard(false, _) = self.method {
292 col.mapv_inplace(|el| (el - offset) * scale + offset);
293 } else {
294 col.mapv_inplace(|el| (el - offset) * scale);
295 }
296 });
297 match &self.method {
298 ScalingMethod::MinMax(min, max) => x * (*max - *min) + *min,
299 _ => x,
300 }
301 }
302}
303
304impl<F: Float, D: Data<Elem = F>, T: AsTargets>
305 Transformer<DatasetBase<ArrayBase<D, Ix2>, T>, DatasetBase<Array2<F>, T>> for LinearScaler<F>
306{
307 fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
310 let feature_names = x.feature_names().to_vec();
311 let target_names = x.target_names().to_vec();
312 let (records, targets, weights) = (x.records, x.targets, x.weights);
313 let records = self.transform(records.to_owned());
314 DatasetBase::new(records, targets)
315 .with_weights(weights)
316 .with_feature_names(feature_names)
317 .with_target_names(target_names)
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::{LinearScaler, LinearScalerParams, ScalingMethod};
    use crate::error::PreprocessingError;
    use approx::assert_abs_diff_eq;
    use linfa::dataset::DatasetBase;
    use linfa::traits::{Fit, Transformer};
    use ndarray::{array, Array1, Array2, Axis};

    /// Per-column minimum of `records`.
    fn column_mins(records: &Array2<f64>) -> Array1<f64> {
        records.fold_axis(
            Axis(0),
            f64::INFINITY,
            |&acc, &el| if el < acc { el } else { acc },
        )
    }

    /// Per-column maximum of `records`.
    fn column_maxes(records: &Array2<f64>) -> Array1<f64> {
        records.fold_axis(
            Axis(0),
            f64::NEG_INFINITY,
            |&acc, &el| if el > acc { el } else { acc },
        )
    }

    /// One parameter set for every constructor offered by `LinearScaler`
    /// (default min-max range only).
    fn every_scaler_params() -> Vec<LinearScalerParams<f64>> {
        vec![
            LinearScaler::standard(),
            LinearScaler::standard_no_mean(),
            LinearScaler::standard_no_std(),
            LinearScaler::min_max(),
            LinearScaler::max_abs(),
        ]
    }

    /// Fits `params` on a three-feature dataset and transforms a two-feature
    /// array, which is expected to panic.
    fn transform_wrong_size(params: LinearScalerParams<f64>) {
        let wrong_size = Array2::from_shape_vec((1, 2), vec![0., 0.]).unwrap();
        let dataset: DatasetBase<Array2<f64>, _> =
            array![[1., -1., 2.], [2., 0., 2.], [0., 1., 2.]].into();
        let scaler = params.fit(&dataset).unwrap();
        let _transformed = scaler.transform(wrong_size);
    }

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<LinearScaler<f64>>();
        has_autotraits::<LinearScalerParams<f64>>();
        has_autotraits::<ScalingMethod<f64>>();
    }

    #[test]
    fn test_max_abs() {
        let dataset = array![[1., -1.], [2., -2.], [3., -3.], [4., -5.]].into();
        let scaler = LinearScaler::max_abs().fit(&dataset).unwrap();
        let scaled = scaler.transform(dataset);
        let col0 = scaled.records().column(0);
        let col1 = scaled.records().column(1);
        assert_abs_diff_eq!(col0, array![1. / 4., 2. / 4., 3. / 4., 1.]);
        assert_abs_diff_eq!(col1, array![-1. / 5., -2. / 5., -3. / 5., -1.]);
    }

    #[test]
    fn test_standard_scaler() {
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let scaler = LinearScaler::standard().fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![1., 0., 1. / 3.]);
        assert_abs_diff_eq!(
            *scaler.scales(),
            array![1. / 0.81, 1. / 0.81, 1. / 1.24],
            epsilon = 1e-2
        );
        let transformed = scaler.transform(dataset);
        let means = transformed.records().mean_axis(Axis(0)).unwrap();
        let std_devs = transformed.records().std_axis(Axis(0), 0.);
        assert_abs_diff_eq!(means, array![0., 0., 0.]);
        assert_abs_diff_eq!(std_devs, array![1., 1., 1.]);
    }

    #[test]
    fn test_standard_scaler_no_mean() {
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let scaler = LinearScaler::standard_no_mean().fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![1., 0., 1. / 3.]);
        assert_abs_diff_eq!(
            *scaler.scales(),
            array![1. / 0.81, 1. / 0.81, 1. / 1.24],
            epsilon = 1e-2
        );
        let transformed = scaler.transform(dataset);
        let means = transformed.records().mean_axis(Axis(0)).unwrap();
        let std_devs = transformed.records().std_axis(Axis(0), 0.);
        assert_abs_diff_eq!(means, array![1., 0., (1. / 3.)], epsilon = 1e-2);
        assert_abs_diff_eq!(std_devs, array![1., 1., 1.]);
    }

    #[test]
    fn test_standard_scaler_no_std() {
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let scaler = LinearScaler::standard_no_std().fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![1., 0., 1. / 3.]);
        assert_abs_diff_eq!(*scaler.scales(), array![1., 1., 1.],);
        let transformed = scaler.transform(dataset);
        let means = transformed.records().mean_axis(Axis(0)).unwrap();
        let std_devs = transformed.records().std_axis(Axis(0), 0.);
        assert_abs_diff_eq!(means, array![0., 0., 0.]);
        assert_abs_diff_eq!(std_devs, array![0.81, 0.81, 1.24], epsilon = 1e-2);
    }

    #[test]
    fn test_standard_scaler_no_both() {
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let scaler = LinearScalerParams::new(ScalingMethod::Standard(false, false))
            .fit(&dataset)
            .unwrap();

        let original_means = dataset.records().mean_axis(Axis(0)).unwrap();
        let original_stds = dataset.records().std_axis(Axis(0), 0.);

        assert_abs_diff_eq!(*scaler.offsets(), original_means);
        assert_abs_diff_eq!(*scaler.scales(), array![1., 1., 1.],);

        let transformed = scaler.transform(dataset);

        let means = transformed.records().mean_axis(Axis(0)).unwrap();
        let std_devs = transformed.records().std_axis(Axis(0), 0.);

        assert_abs_diff_eq!(means, original_means);
        assert_abs_diff_eq!(std_devs, original_stds, epsilon = 1e-2);
    }

    #[test]
    fn test_min_max_scaler() {
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let scaler = LinearScaler::min_max().fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![0., -1., -1.]);
        assert_abs_diff_eq!(*scaler.scales(), array![1. / 2., 1. / 2., 1. / 3.]);
        let transformed = scaler.transform(dataset);
        let mins = column_mins(transformed.records());
        let maxes = column_maxes(transformed.records());
        assert_abs_diff_eq!(maxes, array![1., 1., 1.]);
        assert_abs_diff_eq!(mins, array![0., 0., 0.]);
    }

    #[test]
    fn test_min_max_scaler_range() {
        let dataset = array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let scaler = LinearScaler::min_max_range(5., 10.).fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![0., -1., -1.]);
        assert_abs_diff_eq!(*scaler.scales(), array![1. / 2., 1. / 2., 1. / 3.]);
        let transformed = scaler.transform(dataset);
        let mins = column_mins(transformed.records());
        let maxes = column_maxes(transformed.records());
        assert_abs_diff_eq!(mins, array![5., 5., 5.]);
        assert_abs_diff_eq!(maxes, array![10., 10., 10.]);
    }

    #[test]
    fn test_standard_const_feature() {
        let dataset = array![[1., 2., 2.], [2., 2., 0.], [0., 2., -1.]].into();
        let scaler = LinearScaler::standard().fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![1., 2., 1. / 3.]);
        // The constant second feature keeps a scale of one.
        assert_abs_diff_eq!(
            *scaler.scales(),
            array![1. / 0.81, 1., 1. / 1.24],
            epsilon = 1e-2
        );
        let transformed = scaler.transform(dataset);
        let means = transformed.records().mean_axis(Axis(0)).unwrap();
        let std_devs = transformed.records().std_axis(Axis(0), 0.);
        assert_abs_diff_eq!(means, array![0., 0., 0.]);
        // The constant feature stays constant, hence zero standard deviation.
        assert_abs_diff_eq!(std_devs, array![1., 0., 1.]);
    }

    #[test]
    fn test_max_abs_const_null_feature() {
        let dataset = array![[1., 0.], [2., 0.], [3., 0.], [4., 0.]].into();
        let scaler = LinearScaler::max_abs().fit(&dataset).unwrap();
        let scaled = scaler.transform(dataset);
        let col0 = scaled.records().column(0);
        let col1 = scaled.records().column(1);
        assert_abs_diff_eq!(col0, array![1. / 4., 2. / 4., 3. / 4., 1.]);
        // The all-zero feature is left untouched.
        assert_abs_diff_eq!(col1, array![0., 0., 0., 0.]);
    }

    #[test]
    fn test_min_max_scaler_const_feature() {
        let dataset = array![[1., -1., 2.], [2., 0., 2.], [0., 1., 2.]].into();
        let scaler = LinearScaler::min_max().fit(&dataset).unwrap();
        assert_abs_diff_eq!(*scaler.offsets(), array![0., -1., 2.]);
        assert_abs_diff_eq!(*scaler.scales(), array![1. / 2., 1. / 2., 1.]);
        let transformed = scaler.transform(dataset);
        let mins = column_mins(transformed.records());
        let maxes = column_maxes(transformed.records());
        // The constant third feature is mapped onto the lower end of the range.
        assert_abs_diff_eq!(maxes, array![1., 1., 0.]);
        assert_abs_diff_eq!(mins, array![0., 0., 0.]);
    }

    #[test]
    fn test_empty_input() {
        let dataset: DatasetBase<Array2<f64>, _> =
            Array2::from_shape_vec((0, 0), vec![]).unwrap().into();
        for params in every_scaler_params() {
            let error = params.fit(&dataset).unwrap_err();
            assert_eq!(
                error.to_string(),
                "not enough samples",
                "unexpected error for {:?}",
                params
            );
        }
    }

    #[test]
    fn test_transform_empty_array() {
        let empty: Array2<f64> = Array2::from_shape_vec((0, 0), vec![]).unwrap();
        let dataset: DatasetBase<Array2<f64>, _> =
            array![[1., -1., 2.], [2., 0., 2.], [0., 1., 2.]].into();
        for params in every_scaler_params() {
            let scaler = params.fit(&dataset).unwrap();
            let transformed = scaler.transform(empty.clone());
            assert!(
                transformed.is_empty(),
                "non-empty output for {:?}",
                params
            );
        }
    }

    #[test]
    fn test_retain_feature_names() {
        let dataset = linfa_datasets::diabetes();
        let original_feature_names = dataset.feature_names().to_vec();
        let transformed = LinearScaler::standard()
            .fit(&dataset)
            .unwrap()
            .transform(dataset);
        assert_eq!(original_feature_names, transformed.feature_names())
    }

    #[test]
    #[should_panic]
    fn test_transform_wrong_size_array_standard() {
        transform_wrong_size(LinearScaler::standard());
    }

    #[test]
    #[should_panic]
    fn test_transform_wrong_size_array_standard_no_mean() {
        transform_wrong_size(LinearScaler::standard_no_mean());
    }

    #[test]
    #[should_panic]
    fn test_transform_wrong_size_array_standard_no_std() {
        transform_wrong_size(LinearScaler::standard_no_std());
    }

    #[test]
    #[should_panic]
    fn test_transform_wrong_size_array_min_max() {
        transform_wrong_size(LinearScaler::min_max());
    }

    #[test]
    #[should_panic]
    fn test_transform_wrong_size_array_max_abs() {
        transform_wrong_size(LinearScaler::max_abs());
    }

    #[test]
    fn test_min_max_wrong_range() {
        let dataset: DatasetBase<Array2<f64>, _> =
            array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]].into();
        let result = LinearScaler::min_max_range(10., 5.).fit(&dataset);
        // Check the specific error rather than relying on any panic from `unwrap`.
        assert!(matches!(
            result,
            Err(PreprocessingError::FlippedMinMaxRange)
        ));
    }
}