linfa_datasets/
dataset.rs

1use std::io::Read;
2
3use csv::ReaderBuilder;
4use flate2::read::GzDecoder;
5use linfa::Dataset;
6use ndarray::prelude::*;
7use ndarray_csv::{Array2Reader, ReadError};
8
9/// Convert Gzipped CSV bytes into 2D array
10pub fn array_from_gz_csv<R: Read>(
11    gz: R,
12    has_headers: bool,
13    separator: u8,
14) -> Result<Array2<f64>, ReadError> {
15    // unzip file
16    let file = GzDecoder::new(gz);
17    array_from_csv(file, has_headers, separator)
18}
19
20/// Convert CSV bytes into 2D array
21pub fn array_from_csv<R: Read>(
22    csv: R,
23    has_headers: bool,
24    separator: u8,
25) -> Result<Array2<f64>, ReadError> {
26    // parse CSV
27    let mut reader = ReaderBuilder::new()
28        .has_headers(has_headers)
29        .delimiter(separator)
30        .from_reader(csv);
31
32    // extract ndarray
33    reader.deserialize_array2_dynamic()
34}
35
#[cfg(feature = "iris")]
/// Load the iris flower dataset bundled with this crate.
///
/// The gzipped CSV is embedded into the binary at compile time, so no file is read at runtime.
/// Columns `0..4` are used as the features (sepal length, sepal width, petal length, petal
/// width) and column `4` as the class label, which is cast from `f64` to `usize`.
///
/// # Panics
///
/// Panics if the embedded file cannot be decompressed and parsed as numeric CSV, or if it has
/// fewer than five columns. Either case means the shipped data file is broken, not that the
/// caller did something wrong.
// The `.csv` data is two dimensional: Axis(0) denotes y-axis (rows), Axis(1) denotes x-axis (columns)
pub fn iris() -> Dataset<f64, usize, Ix1> {
    let data = include_bytes!("../data/iris.csv.gz");
    let array = array_from_gz_csv(&data[..], true, b',')
        .expect("embedded iris dataset must be valid gzipped CSV");

    // first four columns are the measurements, the fifth is the class label
    let (data, targets) = (
        array.slice(s![.., 0..4]).to_owned(),
        array.column(4).to_owned(),
    );

    let feature_names = vec!["sepal length", "sepal width", "petal length", "petal width"];

    Dataset::new(data, targets)
        // labels are stored as floats in the CSV; truncate them to integer class ids
        .map_targets(|x| *x as usize)
        .with_feature_names(feature_names)
}
54
#[cfg(feature = "diabetes")]
/// Load the diabetes dataset bundled with this crate.
///
/// Records and targets live in two separate gzipped CSV files that are embedded into the binary
/// at compile time. The first column of the target file is used as the single regression
/// target.
///
/// # Panics
///
/// Panics if either embedded file cannot be decompressed and parsed as numeric CSV, or if the
/// target file has no columns. Either case means the shipped data file is broken, not that the
/// caller did something wrong.
// NOTE(review): the test in this file expects 441 samples while the scikit-learn version of this
// dataset has 442 — presumably `has_headers = true` swallows the first data row of a header-less
// file; confirm against the data files.
pub fn diabetes() -> Dataset<f64, f64, Ix1> {
    let data = include_bytes!("../data/diabetes_data.csv.gz");
    let data = array_from_gz_csv(&data[..], true, b',')
        .expect("embedded diabetes records must be valid gzipped CSV");

    let targets = include_bytes!("../data/diabetes_target.csv.gz");
    let targets = array_from_gz_csv(&targets[..], true, b',')
        .expect("embedded diabetes targets must be valid gzipped CSV")
        .column(0)
        .to_owned();

    // NOTE(review): scikit-learn describes the corresponding columns as tc = total serum
    // cholesterol, tch = total cholesterol / HDL and ltg = (possibly) log of serum triglycerides,
    // so "t-cells", "thyroid stimulating hormone" and "lamotrigine" look mislabelled. The strings
    // are kept unchanged because callers may match on them — confirm before renaming.
    let feature_names = vec![
        "age",
        "sex",
        "body mass index",
        "blood pressure",
        "t-cells",
        "low-density lipoproteins",
        "high-density lipoproteins",
        "thyroid stimulating hormone",
        "lamotrigine",
        "blood sugar level",
    ];

    Dataset::new(data, targets).with_feature_names(feature_names)
}
82
#[cfg(feature = "winequality")]
/// Load the red wine quality dataset bundled with this crate.
///
/// The gzipped CSV is embedded into the binary at compile time, so no file is read at runtime.
/// Columns `0..11` are used as the features and column `11` as the quality label, which is cast
/// from `f64` to `usize`.
///
/// # Panics
///
/// Panics if the embedded file cannot be decompressed and parsed as numeric CSV, or if it has
/// fewer than twelve columns. Either case means the shipped data file is broken, not that the
/// caller did something wrong.
pub fn winequality() -> Dataset<f64, usize, Ix1> {
    let data = include_bytes!("../data/winequality-red.csv.gz");
    let array = array_from_gz_csv(&data[..], true, b',')
        .expect("embedded winequality dataset must be valid gzipped CSV");

    // first eleven columns are the measurements, the twelfth is the quality score
    let (data, targets) = (
        array.slice(s![.., 0..11]).to_owned(),
        array.column(11).to_owned(),
    );

    let feature_names = vec![
        "fixed acidity",
        "volatile acidity",
        "citric acid",
        "residual sugar",
        "chlorides",
        "free sulfur dioxide",
        "total sulfur dioxide",
        "density",
        "pH",
        "sulphates",
        "alcohol",
    ];

    Dataset::new(data, targets)
        // scores are stored as floats in the CSV; truncate them to integer labels
        .map_targets(|x| *x as usize)
        .with_feature_names(feature_names)
}
112
#[cfg(feature = "linnerud")]
/// Load the Linnerud physical exercise dataset bundled with this crate.
///
/// Linnerud dataset contains 20 samples collected from 20 middle-aged men in a fitness club.
/// Both gzipped CSV files are embedded into the binary at compile time, so no file is read at
/// runtime.
///
/// ## Features:
/// 3 exercises measurements: Chins, Situps, Jumps
///
/// ## Targets:
/// 3 physiological measurements: Weight, Waist, Pulse
///
/// # Panics
///
/// Panics if either embedded file cannot be decompressed and parsed as numeric CSV. That means
/// the shipped data file is broken, not that the caller did something wrong.
///
/// # Reference:
/// Tenenhaus (1998). La regression PLS: theorie et pratique. Paris: Editions Technip. Table p 15.
pub fn linnerud() -> Dataset<f64, f64> {
    let input_data = include_bytes!("../data/linnerud_exercise.csv.gz");
    let input_array = array_from_gz_csv(&input_data[..], true, b',')
        .expect("embedded linnerud exercise data must be valid gzipped CSV");

    let output_data = include_bytes!("../data/linnerud_physiological.csv.gz");
    let output_array = array_from_gz_csv(&output_data[..], true, b',')
        .expect("embedded linnerud physiological data must be valid gzipped CSV");

    let feature_names = vec!["Chins", "Situps", "Jumps"];
    let target_names = vec!["Weight", "Waist", "Pulse"];

    Dataset::new(input_array, output_array)
        .with_feature_names(feature_names)
        .with_target_names(target_names)
}
140
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use linfa::prelude::*;

    #[cfg(feature = "iris")]
    #[test]
    fn test_iris() {
        let dataset = iris();

        // dimensions: samples x features, single target
        let shape = (dataset.nsamples(), dataset.nfeatures(), dataset.ntargets());
        assert_eq!(shape, (150, 4, 1));

        // the four measurements must be named in column order
        assert_eq!(
            dataset.feature_names(),
            &["sepal length", "sepal width", "petal length", "petal width"]
        );

        // every class is expected to appear equally often
        let class_counts = dataset
            .label_frequencies()
            .into_iter()
            .map(|entry| entry.1)
            .collect::<Array1<_>>();
        assert_abs_diff_eq!(class_counts, array![50., 50., 50.]);

        // correlation analysis is only exercised, its result is not checked yet
        let _pcc = dataset.pearson_correlation_with_p_value(100);
        // TODO: wait for pearson correlation to accept rng
        // assert_abs_diff_eq!(pcc.get_p_values().unwrap()[5], 0.04, epsilon = 0.04);

        // column-wise mean of the records
        let feature_means = dataset.records().mean_axis(Axis(0)).unwrap();
        assert_abs_diff_eq!(
            feature_means,
            array![5.84, 3.05, 3.75, 1.20],
            epsilon = 0.01
        );
    }

    #[cfg(feature = "diabetes")]
    #[test]
    fn test_diabetes() {
        let dataset = diabetes();

        // dimensions: samples x features, single target
        let shape = (dataset.nsamples(), dataset.nfeatures(), dataset.ntargets());
        assert_eq!(shape, (441, 10, 1));

        // correlation analysis is only exercised; the check that T-Cells and low-density
        // lipoproteins are correlated is currently disabled
        let _pcc = dataset.pearson_correlation_with_p_value(100);
        //assert_abs_diff_eq!(pcc.get_p_values().unwrap()[30], 0.02, epsilon = 0.02);

        // the records should be normalized, so every column mean is close to zero
        let feature_means = dataset.records().mean_axis(Axis(0)).unwrap();
        assert_abs_diff_eq!(feature_means, Array1::zeros(10), epsilon = 0.005);
    }

    #[cfg(feature = "winequality")]
    #[test]
    fn test_winequality() {
        use approx::abs_diff_eq;

        let dataset = winequality();

        // dimensions: samples x features, single target
        let shape = (dataset.nsamples(), dataset.nfeatures(), dataset.ntargets());
        assert_eq!(shape, (1599, 11, 1));

        // the eleven measurements must be named in column order
        let expected_names = vec![
            "fixed acidity",
            "volatile acidity",
            "citric acid",
            "residual sugar",
            "chlorides",
            "free sulfur dioxide",
            "total sulfur dioxide",
            "density",
            "pH",
            "sulphates",
            "alcohol",
        ];
        assert_eq!(dataset.feature_names(), expected_names);

        // expected number of samples per quality score
        let expected_counts = vec![
            (5, 681.0),
            (7, 199.0),
            (6, 638.0),
            (8, 18.0),
            (3, 10.0),
            (4, 53.0),
        ];

        // every expected score has to be present with exactly the expected count
        let observed_counts = dataset.label_frequencies();
        assert!(expected_counts.into_iter().all(|(score, count)| {
            observed_counts
                .get(&score)
                .map_or(false, |observed| abs_diff_eq!(*observed, count))
        }));

        // correlation analysis is only exercised; the check that fixed acidity and citric acid
        // are correlated is currently disabled
        let _pcc = dataset.pearson_correlation_with_p_value(100);
        //assert_abs_diff_eq!(pcc.get_p_values().unwrap()[1], 0.05, epsilon = 0.05);
    }

    #[cfg(feature = "linnerud")]
    #[test]
    fn test_linnerud() {
        let dataset = linnerud();

        // dimensions: samples x features x targets
        let shape = (dataset.nsamples(), dataset.nfeatures(), dataset.ntargets());
        assert_eq!(shape, (20, 3, 3));

        // exercise measurements
        let expected_features = vec!["Chins", "Situps", "Jumps"];
        assert_eq!(dataset.feature_names(), expected_features);

        // physiological measurements
        let expected_targets = vec!["Weight", "Waist", "Pulse"];
        assert_eq!(dataset.target_names(), expected_targets);

        // column-wise mean of the targets: Weight, Waist, Pulse
        let target_means = dataset.targets().mean_axis(Axis(0)).unwrap();
        assert_abs_diff_eq!(target_means, array![178.6, 35.4, 56.1]);
    }
}