linfa_tsne/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use ndarray::Array2;
4use ndarray_rand::rand::Rng;
5use ndarray_rand::rand_distr::Normal;
6
7use linfa::{dataset::DatasetBase, traits::Transformer, Float, ParamGuard};
8
9mod error;
10mod hyperparams;
11
12pub use error::{Result, TSneError};
13pub use hyperparams::{TSneParams, TSneValidParams};
14
15impl<F: Float, R: Rng + Clone> Transformer<Array2<F>, Result<Array2<F>>> for TSneValidParams<F, R> {
16    fn transform(&self, mut data: Array2<F>) -> Result<Array2<F>> {
17        let (nfeatures, nsamples) = (data.ncols(), data.nrows());
18
19        // validate parameter-data constraints
20        if self.embedding_size() > nfeatures {
21            return Err(TSneError::EmbeddingSizeTooLarge);
22        }
23
24        if F::cast(nsamples - 1) < F::cast(3) * self.perplexity() {
25            return Err(TSneError::PerplexityTooLarge);
26        }
27
28        // estimate number of preliminary iterations if not given
29        let preliminary_iter = match self.preliminary_iter() {
30            Some(x) => *x,
31            None => usize::min(self.max_iter() / 2, 250),
32        };
33
34        let data = data.as_slice_mut().unwrap();
35
36        let mut rng = self.rng().clone();
37        let normal = Normal::new(0.0, 1e-4 * 10e-4).unwrap();
38
39        let mut embedding: Vec<F> = (0..nsamples * self.embedding_size())
40            .map(|_| rng.sample(normal))
41            .map(F::cast)
42            .collect();
43
44        bhtsne::run(
45            data,
46            nsamples,
47            nfeatures,
48            &mut embedding,
49            self.embedding_size(),
50            self.perplexity(),
51            self.approx_threshold(),
52            true,
53            self.max_iter() as u64,
54            preliminary_iter as u64,
55            preliminary_iter as u64,
56        );
57
58        Array2::from_shape_vec((nsamples, self.embedding_size()), embedding).map_err(|e| e.into())
59    }
60}
61
62impl<F: Float, R: Rng + Clone> Transformer<Array2<F>, Result<Array2<F>>> for TSneParams<F, R> {
63    fn transform(&self, x: Array2<F>) -> Result<Array2<F>> {
64        self.check_ref()?.transform(x)
65    }
66}
67
68impl<T, F: Float, R: Rng + Clone>
69    Transformer<DatasetBase<Array2<F>, T>, Result<DatasetBase<Array2<F>, T>>>
70    for TSneValidParams<F, R>
71{
72    fn transform(&self, ds: DatasetBase<Array2<F>, T>) -> Result<DatasetBase<Array2<F>, T>> {
73        let DatasetBase {
74            records,
75            targets,
76            weights,
77            ..
78        } = ds;
79
80        self.transform(records)
81            .map(|new_records| DatasetBase::new(new_records, targets).with_weights(weights))
82    }
83}
84
85impl<T, F: Float, R: Rng + Clone>
86    Transformer<DatasetBase<Array2<F>, T>, Result<DatasetBase<Array2<F>, T>>> for TSneParams<F, R>
87{
88    fn transform(&self, ds: DatasetBase<Array2<F>, T>) -> Result<DatasetBase<Array2<F>, T>> {
89        self.check_ref()?.transform(ds)
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{Array, Array1, Axis};
    use ndarray_rand::{rand_distr::Normal, RandomExt};
    use rand::{rngs::SmallRng, SeedableRng};

    use linfa::{dataset::Dataset, metrics::SilhouetteScore};

    /// Compile-time check that the public types are `Send + Sync + Sized + Unpin`.
    #[test]
    fn autotraits() {
        fn assert_auto_traits<T: Send + Sync + Sized + Unpin>() {}

        type Sampler = rand::distributions::Uniform<f64>;
        assert_auto_traits::<TSneParams<f64, Sampler>>();
        assert_auto_traits::<TSneValidParams<f64, Sampler>>();
        assert_auto_traits::<TSneError>();
    }

    /// A seeded 2-D embedding of iris should score a silhouette above 0.6.
    #[test]
    fn iris_separate() -> Result<()> {
        let iris = linfa_datasets::iris();
        let seeded = SmallRng::seed_from_u64(42);

        let embedded = TSneParams::embedding_size_with_rng(2, seeded)
            .perplexity(10.0)
            .approx_threshold(0.0)
            .transform(iris)?;

        let score = embedded.silhouette_score()?;
        assert!(score > 0.6);

        Ok(())
    }

    /// Two Gaussian blobs centred at -10 and +10 should stay separated after embedding.
    #[test]
    fn blob_separate() -> Result<()> {
        let mut rng = SmallRng::seed_from_u64(42);

        // Draw the left blob before the right one; the expected score below depends on this
        // seeded sampling order.
        let left: Array2<f64> =
            Array::random_using((100, 2), Normal::new(-10., 0.5).unwrap(), &mut rng);
        let right: Array2<f64> =
            Array::random_using((100, 2), Normal::new(10., 0.5).unwrap(), &mut rng);
        let records: Array2<f64> = ndarray::concatenate(Axis(0), &[left.view(), right.view()])?;

        let labels: Array1<bool> = (0..200).map(|idx| idx < 100).collect();
        let dataset = Dataset::new(records, labels);

        let embedded = TSneParams::embedding_size_with_rng(2, rng)
            .perplexity(60.0)
            .approx_threshold(0.0)
            .transform(dataset)?;

        let score = embedded.silhouette_score()?;
        assert_abs_diff_eq!(score, 0.945, epsilon = 0.01);

        Ok(())
    }

    /// A negative perplexity must surface a `NegativePerplexity` error.
    #[test]
    #[should_panic(expected = "NegativePerplexity")]
    fn perplexity_panic() {
        let iris = linfa_datasets::iris();

        TSneParams::embedding_size(2)
            .perplexity(-10.0)
            .transform(iris)
            .unwrap();
    }

    /// A negative approximation threshold must surface a `NegativeApproximationThreshold` error.
    #[test]
    #[should_panic(expected = "NegativeApproximationThreshold")]
    fn approx_threshold_panic() {
        let iris = linfa_datasets::iris();

        TSneParams::embedding_size(2)
            .approx_threshold(-10.0)
            .transform(iris)
            .unwrap();
    }

    /// Requesting more embedding dimensions than the input has features must fail.
    #[test]
    #[should_panic(expected = "EmbeddingSizeTooLarge")]
    fn embedding_size_panic() {
        let iris = linfa_datasets::iris();

        TSneParams::embedding_size(5).transform(iris).unwrap();
    }
}