// linfa_tsne/hyperparams.rs — hyper-parameter definitions for the t-SNE transformer
1use linfa::{Float, ParamGuard};
2use ndarray_rand::rand::{rngs::SmallRng, Rng, SeedableRng};
3
4use crate::TSneError;
5
/// The t-SNE algorithm is a statistical method for visualizing high-dimensional data by
/// giving each datapoint a location in a two or three-dimensional map.
///
/// The t-SNE algorithm comprises two main stages. First, t-SNE constructs a probability
/// distribution over pairs of high-dimensional objects in such a way that similar objects
/// are assigned a higher probability while dissimilar points are assigned a lower probability.
/// Second, t-SNE defines a similar probability distribution over the points in the low-dimensional
/// map, and it minimizes the Kullback–Leibler divergence (KL divergence) between the two
/// distributions with respect to the locations of the points in the map.
///
/// This crate wraps the [bhtsne](https://github.com/frjnn/bhtsne) crate for the linfa project. It
/// implements the exact t-SNE, as well as the Barnes-Hut approximation.
///
/// # Examples
///
/// ```no_run
/// use linfa::traits::Transformer;
/// use linfa_tsne::TSneParams;
///
/// let ds = linfa_datasets::iris();
///
/// let ds = TSneParams::embedding_size(2)
///     .perplexity(10.0)
///     .approx_threshold(0.6)
///     .transform(ds);
/// ```
///
/// A verified hyper-parameter set ready for prediction
#[derive(Debug, Clone, PartialEq)]
pub struct TSneValidParams<F, R> {
    /// Dimensionality of the embedded (output) space, typically 2 or 3.
    embedding_size: usize,
    /// Barnes-Hut approximation threshold (theta); 0 disables approximation.
    approx_threshold: F,
    /// Perplexity, loosely the effective number of neighbors per point.
    perplexity: F,
    /// Maximum number of gradient-descent iterations.
    max_iter: usize,
    /// Iterations during which the P distribution is exaggerated;
    /// `None` means the count is estimated by the algorithm.
    preliminary_iter: Option<usize>,
    /// Random number generator used to initialize the embedding.
    rng: R,
}
43
44impl<F: Float, R> TSneValidParams<F, R> {
45 pub fn embedding_size(&self) -> usize {
46 self.embedding_size
47 }
48
49 pub fn approx_threshold(&self) -> F {
50 self.approx_threshold
51 }
52
53 pub fn perplexity(&self) -> F {
54 self.perplexity
55 }
56
57 pub fn max_iter(&self) -> usize {
58 self.max_iter
59 }
60
61 pub fn preliminary_iter(&self) -> &Option<usize> {
62 &self.preliminary_iter
63 }
64
65 pub fn rng(&self) -> &R {
66 &self.rng
67 }
68}
69
/// An unchecked t-SNE hyper-parameter set; validated into
/// `TSneValidParams` through the `ParamGuard` implementation below.
#[derive(Debug, Clone, PartialEq)]
pub struct TSneParams<F, R>(TSneValidParams<F, R>);
72
73impl<F: Float> TSneParams<F, SmallRng> {
74 /// Create a t-SNE param set with given embedding size
75 ///
76 /// # Defaults to:
77 /// * `approx_threshold`: 0.5
78 /// * `perplexity`: 5.0
79 /// * `max_iter`: 2000
80 /// * `rng`: SmallRng with seed 42
81 pub fn embedding_size(embedding_size: usize) -> TSneParams<F, SmallRng> {
82 Self::embedding_size_with_rng(embedding_size, SmallRng::seed_from_u64(42))
83 }
84}
85
86impl<F: Float, R: Rng + Clone> TSneParams<F, R> {
87 /// Create a t-SNE param set with given embedding size and random number generator
88 ///
89 /// # Defaults to:
90 /// * `approx_threshold`: 0.5
91 /// * `perplexity`: 5.0
92 /// * `max_iter`: 2000
93 pub fn embedding_size_with_rng(embedding_size: usize, rng: R) -> TSneParams<F, R> {
94 Self(TSneValidParams {
95 embedding_size,
96 rng,
97 approx_threshold: F::cast(0.5),
98 perplexity: F::cast(5.0),
99 max_iter: 2000,
100 preliminary_iter: None,
101 })
102 }
103
104 /// Set the approximation threshold of the Barnes Hut algorithm
105 ///
106 /// The threshold decides whether a cluster centroid can be used as a summary for the whole
107 /// area. This was proposed by Barnes and Hut and compares the ratio of cell radius and
108 /// distance to a factor theta. This threshold lies in range (0, inf) where a value of 0
109 /// disables approximation and a positive value approximates the gradient with the cell center.
110 pub fn approx_threshold(mut self, threshold: F) -> Self {
111 self.0.approx_threshold = threshold;
112
113 self
114 }
115
116 /// Set the perplexity of the t-SNE algorithm
117 pub fn perplexity(mut self, perplexity: F) -> Self {
118 self.0.perplexity = perplexity;
119
120 self
121 }
122
123 /// Set the maximal number of iterations
124 pub fn max_iter(mut self, max_iter: usize) -> Self {
125 self.0.max_iter = max_iter;
126
127 self
128 }
129
130 /// Set the number of iterations after which the true P distribution is used
131 ///
132 /// At the beginning of the training process it is useful to multiply the P distribution values
133 /// by a certain factor (here 12x) to get the global view right. After this number of iterations
134 /// the true P distribution value is used. If None the number is estimated.
135 pub fn preliminary_iter(mut self, num_iter: usize) -> Self {
136 self.0.preliminary_iter = Some(num_iter);
137
138 self
139 }
140}
141
142impl<F: Float, R> ParamGuard for TSneParams<F, R> {
143 type Checked = TSneValidParams<F, R>;
144 type Error = TSneError;
145
146 /// Validates parameters
147 fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
148 if self.0.perplexity.is_negative() {
149 Err(TSneError::NegativePerplexity)
150 } else if self.0.approx_threshold.is_negative() {
151 Err(TSneError::NegativeApproximationThreshold)
152 } else {
153 Ok(&self.0)
154 }
155 }
156
157 fn check(self) -> Result<Self::Checked, Self::Error> {
158 self.check_ref()?;
159 Ok(self.0)
160 }
161}