linfa_datasets/
generate.rs1use linfa::Dataset;
4use ndarray::{s, Array, Array2, ArrayBase, Data, Ix1, Ix2};
5use ndarray_rand::{
6 rand::Rng,
7 rand_distr::{Distribution, StandardNormal},
8 RandomExt,
9};
10
11pub fn blobs(
13 blob_size: usize,
14 blob_centroids: &ArrayBase<impl Data<Elem = f64>, Ix2>,
15 rng: &mut impl Rng,
16) -> Array2<f64> {
17 blobs_with_distribution(blob_size, blob_centroids, StandardNormal, rng)
18}
19
20pub fn blobs_with_distribution(
29 blob_size: usize,
30 blob_centroids: &ArrayBase<impl Data<Elem = f64>, Ix2>,
31 distribution: impl Distribution<f64> + Clone,
32 rng: &mut impl Rng,
33) -> Array2<f64> {
34 let (n_centroids, n_features) = blob_centroids.dim();
35 let mut blobs: Array2<f64> = Array2::zeros((n_centroids * blob_size, n_features));
36
37 for (blob_index, blob_centroid) in blob_centroids.rows().into_iter().enumerate() {
38 let blob = make_blob(blob_size, &blob_centroid, distribution.clone(), rng);
39
40 let indexes = s![blob_index * blob_size..(blob_index + 1) * blob_size, ..];
41 blobs.slice_mut(indexes).assign(&blob);
42 }
43 blobs
44}
45
46fn make_blob(
50 blob_size: usize,
51 blob_centroid: &ArrayBase<impl Data<Elem = f64>, Ix1>,
52 distribution: impl Distribution<f64>,
53 rng: &mut impl Rng,
54) -> Array2<f64> {
55 let shape = (blob_size, blob_centroid.len());
56 let origin_blob: Array2<f64> = Array::random_using(shape, distribution, rng);
57 origin_blob + blob_centroid
58}
59
60pub fn make_dataset<X, Y>(
72 num_rows: usize,
73 num_feats: usize,
74 num_targets: usize,
75 feat_distr: X,
76 target_distr: Y,
77) -> Dataset<f64, f64>
78where
79 X: Distribution<f64>,
80 Y: Distribution<f64>,
81{
82 let features = Array::random((num_rows, num_feats), feat_distr);
83 let targets = Array::random((num_rows, num_targets), target_distr);
84
85 Dataset::new(features, targets)
86}