1#![doc = include_str!("../README.md")]
2
3use ndarray::Array2;
4use ndarray_rand::rand::Rng;
5use ndarray_rand::rand_distr::Normal;
6
7use linfa::{dataset::DatasetBase, traits::Transformer, Float, ParamGuard};
8
9mod error;
10mod hyperparams;
11
12pub use error::{Result, TSneError};
13pub use hyperparams::{TSneParams, TSneValidParams};
14
15impl<F: Float, R: Rng + Clone> Transformer<Array2<F>, Result<Array2<F>>> for TSneValidParams<F, R> {
16 fn transform(&self, mut data: Array2<F>) -> Result<Array2<F>> {
17 let (nfeatures, nsamples) = (data.ncols(), data.nrows());
18
19 if self.embedding_size() > nfeatures {
21 return Err(TSneError::EmbeddingSizeTooLarge);
22 }
23
24 if F::cast(nsamples - 1) < F::cast(3) * self.perplexity() {
25 return Err(TSneError::PerplexityTooLarge);
26 }
27
28 let preliminary_iter = match self.preliminary_iter() {
30 Some(x) => *x,
31 None => usize::min(self.max_iter() / 2, 250),
32 };
33
34 let data = data.as_slice_mut().unwrap();
35
36 let mut rng = self.rng().clone();
37 let normal = Normal::new(0.0, 1e-4 * 10e-4).unwrap();
38
39 let mut embedding: Vec<F> = (0..nsamples * self.embedding_size())
40 .map(|_| rng.sample(normal))
41 .map(F::cast)
42 .collect();
43
44 bhtsne::run(
45 data,
46 nsamples,
47 nfeatures,
48 &mut embedding,
49 self.embedding_size(),
50 self.perplexity(),
51 self.approx_threshold(),
52 true,
53 self.max_iter() as u64,
54 preliminary_iter as u64,
55 preliminary_iter as u64,
56 );
57
58 Array2::from_shape_vec((nsamples, self.embedding_size()), embedding).map_err(|e| e.into())
59 }
60}
61
62impl<F: Float, R: Rng + Clone> Transformer<Array2<F>, Result<Array2<F>>> for TSneParams<F, R> {
63 fn transform(&self, x: Array2<F>) -> Result<Array2<F>> {
64 self.check_ref()?.transform(x)
65 }
66}
67
68impl<T, F: Float, R: Rng + Clone>
69 Transformer<DatasetBase<Array2<F>, T>, Result<DatasetBase<Array2<F>, T>>>
70 for TSneValidParams<F, R>
71{
72 fn transform(&self, ds: DatasetBase<Array2<F>, T>) -> Result<DatasetBase<Array2<F>, T>> {
73 let DatasetBase {
74 records,
75 targets,
76 weights,
77 ..
78 } = ds;
79
80 self.transform(records)
81 .map(|new_records| DatasetBase::new(new_records, targets).with_weights(weights))
82 }
83}
84
85impl<T, F: Float, R: Rng + Clone>
86 Transformer<DatasetBase<Array2<F>, T>, Result<DatasetBase<Array2<F>, T>>> for TSneParams<F, R>
87{
88 fn transform(&self, ds: DatasetBase<Array2<F>, T>) -> Result<DatasetBase<Array2<F>, T>> {
89 self.check_ref()?.transform(ds)
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{Array, Array1, Axis};
    use ndarray_rand::{rand_distr::Normal, RandomExt};
    use rand::{rngs::SmallRng, SeedableRng};

    use linfa::{dataset::Dataset, metrics::SilhouetteScore};

    /// The parameter and error types must be `Send + Sync + Sized + Unpin`.
    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<TSneParams<f64, rand::distributions::Uniform<f64>>>();
        has_autotraits::<TSneValidParams<f64, rand::distributions::Uniform<f64>>>();
        has_autotraits::<TSneError>();
    }

    /// A 2-D exact embedding (threshold 0) of iris should score a silhouette above 0.6.
    #[test]
    fn iris_separate() -> Result<()> {
        let iris = linfa_datasets::iris();
        let seeded = SmallRng::seed_from_u64(42);

        let embedded = TSneParams::embedding_size_with_rng(2, seeded)
            .perplexity(10.0)
            .approx_threshold(0.0)
            .transform(iris)?;

        let score = embedded.silhouette_score()?;
        assert!(score > 0.6);

        Ok(())
    }

    /// Two well-separated Gaussian blobs should stay separated after embedding.
    #[test]
    fn blob_separate() -> Result<()> {
        let mut rng = SmallRng::seed_from_u64(42);
        // Both blobs are drawn from the same RNG, lower one first; the expected score below
        // depends on this order.
        let lower = Array::random_using((100, 2), Normal::new(-10., 0.5).unwrap(), &mut rng);
        let upper = Array::random_using((100, 2), Normal::new(10., 0.5).unwrap(), &mut rng);
        let entries: Array2<f64> = ndarray::concatenate(Axis(0), &[lower.view(), upper.view()])?;

        let labels: Array1<bool> = (0..200).map(|idx| idx < 100).collect();
        let dataset = Dataset::new(entries, labels);

        let embedded = TSneParams::embedding_size_with_rng(2, rng)
            .perplexity(60.0)
            .approx_threshold(0.0)
            .transform(dataset)?;

        assert_abs_diff_eq!(embedded.silhouette_score()?, 0.945, epsilon = 0.01);

        Ok(())
    }

    /// A negative perplexity must be rejected during parameter validation.
    #[test]
    #[should_panic(expected = "NegativePerplexity")]
    fn perplexity_panic() {
        let iris = linfa_datasets::iris();

        TSneParams::embedding_size(2)
            .perplexity(-10.0)
            .transform(iris)
            .unwrap();
    }

    /// A negative approximation threshold must be rejected during parameter validation.
    #[test]
    #[should_panic(expected = "NegativeApproximationThreshold")]
    fn approx_threshold_panic() {
        let iris = linfa_datasets::iris();

        TSneParams::embedding_size(2)
            .approx_threshold(-10.0)
            .transform(iris)
            .unwrap();
    }

    /// Embedding into more dimensions than iris has features must fail.
    #[test]
    #[should_panic(expected = "EmbeddingSizeTooLarge")]
    fn embedding_size_panic() {
        let iris = linfa_datasets::iris();

        TSneParams::embedding_size(5).transform(iris).unwrap();
    }
}