// linfa_preprocessing/norm_scaling.rs

//! Sample normalization methods
2use linfa::dataset::{AsTargets, DatasetBase, Float, WithLapack, WithoutLapack};
3use linfa::traits::Transformer;
4#[cfg(not(feature = "blas"))]
5use linfa_linalg::norm::Norm;
6use ndarray::{Array2, ArrayBase, Axis, Data, Ix2, Zip};
7#[cfg(feature = "blas")]
8use ndarray_linalg::norm::Norm;
9
10#[cfg(feature = "serde")]
11use serde_crate::{Deserialize, Serialize};
12
/// The norm that [`NormScaler`] uses to measure the length of each sample.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(crate = "serde_crate"))]
enum Norms {
    /// l1 norm: sum of the absolute values of the sample.
    L1,
    /// l2 norm: Euclidean length of the sample.
    L2,
    /// Max norm: largest absolute value in the sample.
    Max,
}
24
25/// Norm scaler: scales all samples in a dataset to have unit norm, according to the specified norm
26/// measure
27///
28/// ### Example
29///
30/// ```rust
31/// use linfa::traits::Transformer;
32/// use linfa_preprocessing::norm_scaling::NormScaler;
33///
34/// // Load dataset
35/// let dataset = linfa_datasets::diabetes();
36/// // Initialize scaler
37/// let scaler = NormScaler::l2();
38/// // Scale dataset
39/// let dataset = scaler.transform(dataset);
40/// ```
41#[cfg_attr(
42    feature = "serde",
43    derive(Serialize, Deserialize),
44    serde(crate = "serde_crate")
45)]
46#[derive(Debug, Clone, PartialEq, Eq)]
47pub struct NormScaler {
48    norm: Norms,
49}
50
51impl NormScaler {
52    /// Initializes a norm scaler that uses l2 norm
53    pub fn l2() -> Self {
54        Self { norm: Norms::L2 }
55    }
56
57    /// Initializes a norm scaler that uses l1 norm
58    pub fn l1() -> Self {
59        Self { norm: Norms::L1 }
60    }
61
62    /// Initializes a norm scaler that uses max norm
63    pub fn max() -> Self {
64        Self { norm: Norms::Max }
65    }
66}
67
68impl<F: Float> Transformer<Array2<F>, Array2<F>> for NormScaler {
69    /// Scales all samples in the array of shape (nsamples, nfeatures) to have unit norm.
70    fn transform(&self, x: Array2<F>) -> Array2<F> {
71        // add Lapack trait bound
72        let x = x.with_lapack();
73
74        let norms = match &self.norm {
75            Norms::L1 => x.map_axis(Axis(1), |row| F::cast(row.norm_l1())),
76            Norms::L2 => x.map_axis(Axis(1), |row| F::cast(row.norm_l2())),
77            Norms::Max => x.map_axis(Axis(1), |row| F::cast(row.norm_max())),
78        };
79
80        // remove Lapack trait bound
81        let mut x = x.without_lapack();
82
83        Zip::from(x.rows_mut())
84            .and(&norms)
85            .for_each(|mut row, &norm| {
86                row.mapv_inplace(|el| el / norm);
87            });
88        x
89    }
90}
91
92impl<F: Float, D: Data<Elem = F>, T: AsTargets>
93    Transformer<DatasetBase<ArrayBase<D, Ix2>, T>, DatasetBase<Array2<F>, T>> for NormScaler
94{
95    /// Substitutes the records of the dataset with their scaled versions with unit norm.
96    fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
97        let feature_names = x.feature_names().to_vec();
98        let target_names = x.target_names().to_vec();
99        let (records, targets, weights) = (x.records, x.targets, x.weights);
100        let records = self.transform(records.to_owned());
101        DatasetBase::new(records, targets)
102            .with_weights(weights)
103            .with_feature_names(feature_names)
104            .with_target_names(target_names)
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::NormScaler;
    use approx::assert_abs_diff_eq;
    use linfa::dataset::DatasetBase;
    use linfa::traits::Transformer;
    use ndarray::{array, Array2};

    /// 3x3 sample matrix shared by the per-norm tests.
    fn sample_records() -> Array2<f64> {
        array![[1., -1., 2.], [2., 0., 0.], [0., 1., -1.]]
    }

    /// Normalizes the sample dataset with `scaler` and compares it against `expected`.
    fn check_against(scaler: NormScaler, expected: Array2<f64>) {
        let normalized = scaler.transform(DatasetBase::from(sample_records()));
        assert_abs_diff_eq!(*normalized.records(), expected, epsilon = 1e-2);
    }

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<NormScaler>();
    }

    #[test]
    fn test_norm_l2() {
        check_against(
            NormScaler::l2(),
            array![[0.4, -0.4, 0.81], [1., 0., 0.], [0., 0.7, -0.7]],
        );
    }

    #[test]
    fn test_norm_l1() {
        check_against(
            NormScaler::l1(),
            array![[0.25, -0.25, 0.5], [1., 0., 0.], [0., 0.5, -0.5]],
        );
    }

    #[test]
    fn test_norm_max() {
        check_against(
            NormScaler::max(),
            array![[0.5, -0.5, 1.], [1., 0., 0.], [0., 1., -1.]],
        );
    }

    #[test]
    fn test_no_input() {
        let input: Array2<f64> = Array2::zeros((0, 0));
        let ground_truth: Array2<f64> = Array2::zeros((0, 0));
        for scaler in &[NormScaler::max(), NormScaler::l1(), NormScaler::l2()] {
            assert_abs_diff_eq!(scaler.transform(input.clone()), ground_truth);
        }
    }

    #[test]
    fn test_retain_feature_names() {
        let dataset = linfa_datasets::diabetes();
        let names_before = dataset.feature_names().to_vec();
        let transformed = NormScaler::l2().transform(dataset);
        assert_eq!(names_before, transformed.feature_names())
    }
}