1use std::io::Read;
2
3use csv::ReaderBuilder;
4use flate2::read::GzDecoder;
5use linfa::Dataset;
6use ndarray::prelude::*;
7use ndarray_csv::{Array2Reader, ReadError};
8
9pub fn array_from_gz_csv<R: Read>(
11 gz: R,
12 has_headers: bool,
13 separator: u8,
14) -> Result<Array2<f64>, ReadError> {
15 let file = GzDecoder::new(gz);
17 array_from_csv(file, has_headers, separator)
18}
19
20pub fn array_from_csv<R: Read>(
22 csv: R,
23 has_headers: bool,
24 separator: u8,
25) -> Result<Array2<f64>, ReadError> {
26 let mut reader = ReaderBuilder::new()
28 .has_headers(has_headers)
29 .delimiter(separator)
30 .from_reader(csv);
31
32 reader.deserialize_array2_dynamic()
34}
35
/// Load the bundled iris dataset.
///
/// The gzip-compressed CSV `../data/iris.csv.gz` is embedded at compile time
/// and parsed with a header row and `,` as separator. Columns `0..4` become
/// the records and column `4` becomes the targets, cast from `f64` to `usize`.
///
/// # Panics
///
/// Panics if the embedded file cannot be parsed, or if it has fewer than
/// five columns.
#[cfg(feature = "iris")]
pub fn iris() -> Dataset<f64, usize, Ix1> {
    let raw = include_bytes!("../data/iris.csv.gz");
    let table = array_from_gz_csv(&raw[..], true, b',').unwrap();

    // First four columns are the measurements, the fifth is the class label.
    let records = table.slice(s![.., 0..4]).to_owned();
    let labels = table.column(4).to_owned();

    Dataset::new(records, labels)
        .map_targets(|label| *label as usize)
        .with_feature_names(vec![
            "sepal length",
            "sepal width",
            "petal length",
            "petal width",
        ])
}
54
/// Load the bundled diabetes dataset.
///
/// Records and targets live in two separate gzip-compressed CSV files
/// (`../data/diabetes_data.csv.gz` and `../data/diabetes_target.csv.gz`) that
/// are embedded at compile time. Both are parsed with a header row and `,` as
/// separator; the first column of the target file is used as the target.
///
/// # Panics
///
/// Panics if either embedded file cannot be parsed, or if the target file has
/// no columns.
#[cfg(feature = "diabetes")]
pub fn diabetes() -> Dataset<f64, f64, Ix1> {
    let raw_records = include_bytes!("../data/diabetes_data.csv.gz");
    let records = array_from_gz_csv(&raw_records[..], true, b',').unwrap();

    let raw_targets = include_bytes!("../data/diabetes_target.csv.gz");
    let target_table = array_from_gz_csv(&raw_targets[..], true, b',').unwrap();
    let targets = target_table.column(0).to_owned();

    // NOTE(review): "t-cells", "thyroid stimulating hormone" and "lamotrigine"
    // look like they follow an older description of this dataset; presumably
    // the columns are total serum cholesterol, total cholesterol / HDL and log
    // serum triglycerides -- confirm against the data source. The strings are
    // kept unchanged here because callers may match on them.
    let names = vec![
        "age",
        "sex",
        "body mass index",
        "blood pressure",
        "t-cells",
        "low-density lipoproteins",
        "high-density lipoproteins",
        "thyroid stimulating hormone",
        "lamotrigine",
        "blood sugar level",
    ];

    Dataset::new(records, targets).with_feature_names(names)
}
82
/// Load the bundled red wine quality dataset.
///
/// The gzip-compressed CSV `../data/winequality-red.csv.gz` is embedded at
/// compile time and parsed with a header row and `,` as separator. Columns
/// `0..11` become the records and column `11` becomes the targets, cast from
/// `f64` to `usize`.
///
/// # Panics
///
/// Panics if the embedded file cannot be parsed, or if it has fewer than
/// twelve columns.
#[cfg(feature = "winequality")]
pub fn winequality() -> Dataset<f64, usize, Ix1> {
    let raw = include_bytes!("../data/winequality-red.csv.gz");
    let table = array_from_gz_csv(&raw[..], true, b',').unwrap();

    // Eleven feature columns followed by the quality column.
    let records = table.slice(s![.., 0..11]).to_owned();
    let quality = table.column(11).to_owned();

    Dataset::new(records, quality)
        .map_targets(|score| *score as usize)
        .with_feature_names(vec![
            "fixed acidity",
            "volatile acidity",
            "citric acid",
            "residual sugar",
            "chlorides",
            "free sulfur dioxide",
            "total sulfur dioxide",
            "density",
            "pH",
            "sulphates",
            "alcohol",
        ])
}
112
/// Load the bundled Linnerud dataset.
///
/// The records come from `../data/linnerud_exercise.csv.gz` and the targets
/// from `../data/linnerud_physiological.csv.gz`. Both are embedded at compile
/// time and parsed with a header row and `,` as separator. Each parsed table
/// is used whole, so the targets keep all of their columns.
///
/// # Panics
///
/// Panics if either embedded file cannot be parsed.
#[cfg(feature = "linnerud")]
pub fn linnerud() -> Dataset<f64, f64> {
    let exercise_raw = include_bytes!("../data/linnerud_exercise.csv.gz");
    let exercise = array_from_gz_csv(&exercise_raw[..], true, b',').unwrap();

    let physiological_raw = include_bytes!("../data/linnerud_physiological.csv.gz");
    let physiological = array_from_gz_csv(&physiological_raw[..], true, b',').unwrap();

    Dataset::new(exercise, physiological)
        .with_feature_names(vec!["Chins", "Situps", "Jumps"])
        .with_target_names(vec!["Weight", "Waist", "Pulse"])
}
140
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use linfa::prelude::*;

    /// Checks shape, feature names, class balance and feature means of iris.
    #[cfg(feature = "iris")]
    #[test]
    fn test_iris() {
        let dataset = iris();

        // Expected (samples, features, targets).
        let shape = (dataset.nsamples(), dataset.nfeatures(), dataset.ntargets());
        assert_eq!(shape, (150, 4, 1));

        assert_eq!(
            dataset.feature_names(),
            &["sepal length", "sepal width", "petal length", "petal width"]
        );

        // Every class is expected to have the same frequency, so the
        // iteration order of the frequency map does not matter here.
        let class_counts: Array1<_> = dataset
            .label_frequencies()
            .into_iter()
            .map(|(_, count)| count)
            .collect();
        assert_abs_diff_eq!(class_counts, array![50., 50., 50.]);

        // Only exercises the call; the result is not inspected.
        let _pcc = dataset.pearson_correlation_with_p_value(100);

        // Per-feature mean over all samples.
        let feature_means = dataset.records().mean_axis(Axis(0)).unwrap();
        assert_abs_diff_eq!(
            feature_means,
            array![5.84, 3.05, 3.75, 1.20],
            epsilon = 0.01
        );
    }

    /// Checks shape and feature means of the diabetes dataset.
    #[cfg(feature = "diabetes")]
    #[test]
    fn test_diabetes() {
        let dataset = diabetes();

        // NOTE(review): 441 is what the bundled file yields when `diabetes()`
        // parses it with `has_headers = true`; if the file has no header row,
        // the first sample is being consumed as a header -- confirm.
        let shape = (dataset.nsamples(), dataset.nfeatures(), dataset.ntargets());
        assert_eq!(shape, (441, 10, 1));

        // Only exercises the call; the result is not inspected.
        let _pcc = dataset.pearson_correlation_with_p_value(100);

        // Every feature mean is expected to be within 0.005 of zero.
        let feature_means = dataset.records().mean_axis(Axis(0)).unwrap();
        assert_abs_diff_eq!(feature_means, Array1::zeros(10), epsilon = 0.005);
    }

    /// Checks shape, feature names and label frequencies of winequality.
    #[cfg(feature = "winequality")]
    #[test]
    fn test_winequality() {
        use approx::abs_diff_eq;

        let dataset = winequality();

        let shape = (dataset.nsamples(), dataset.nfeatures(), dataset.ntargets());
        assert_eq!(shape, (1599, 11, 1));

        let expected_names = vec![
            "fixed acidity",
            "volatile acidity",
            "citric acid",
            "residual sugar",
            "chlorides",
            "free sulfur dioxide",
            "total sulfur dioxide",
            "density",
            "pH",
            "sulphates",
            "alcohol",
        ];
        assert_eq!(dataset.feature_names(), expected_names);

        // (quality label, expected frequency) pairs.
        let expected_freqs = vec![
            (5, 681.0),
            (7, 199.0),
            (6, 638.0),
            (8, 18.0),
            (3, 10.0),
            (4, 53.0),
        ];

        // A label missing from the frequency map counts as a mismatch.
        let observed = dataset.label_frequencies();
        let all_match = expected_freqs.into_iter().all(|(label, count)| {
            observed
                .get(&label)
                .map_or(false, |seen| abs_diff_eq!(*seen, count))
        });
        assert!(all_match);

        // Only exercises the call; the result is not inspected.
        let _pcc = dataset.pearson_correlation_with_p_value(100);
    }

    /// Checks shape, feature/target names and target means of linnerud.
    #[cfg(feature = "linnerud")]
    #[test]
    fn test_linnerud() {
        let dataset = linnerud();

        let shape = (dataset.nsamples(), dataset.nfeatures(), dataset.ntargets());
        assert_eq!(shape, (20, 3, 3));

        let expected_features = vec!["Chins", "Situps", "Jumps"];
        assert_eq!(dataset.feature_names(), expected_features);

        let expected_targets = vec!["Weight", "Waist", "Pulse"];
        assert_eq!(dataset.target_names(), expected_targets);

        // Per-target mean over all samples, compared with the default epsilon.
        let target_means = dataset.targets().mean_axis(Axis(0)).unwrap();
        assert_abs_diff_eq!(target_means, array![178.6, 35.4, 56.1]);
    }
}