// linfa_tsne/hyperparams.rs

1use linfa::{Float, ParamGuard};
2use ndarray_rand::rand::{rngs::SmallRng, Rng, SeedableRng};
3
4use crate::TSneError;
5
/// t-SNE is a statistical method that visualizes high-dimensional data by assigning
/// every datapoint a location in a two- or three-dimensional map.
///
/// The algorithm proceeds in two stages. It first constructs a probability
/// distribution over pairs of high-dimensional objects, assigning similar objects a
/// high probability and dissimilar ones a low probability. It then defines an
/// analogous distribution over the points of the low-dimensional map and minimizes
/// the Kullback–Leibler divergence (KL divergence) between the two distributions
/// with respect to the locations of the mapped points.
///
/// This crate wraps the [bhtsne](https://github.com/frjnn/bhtsne) crate for the
/// linfa project, providing both the exact t-SNE algorithm and its Barnes-Hut
/// approximation.
///
/// # Examples
///
/// ```no_run
/// use linfa::traits::Transformer;
/// use linfa_tsne::TSneParams;
///
/// let ds = linfa_datasets::iris();
///
/// let ds = TSneParams::embedding_size(2)
///     .perplexity(10.0)
///     .approx_threshold(0.6)
///     .transform(ds);
/// ```
///
/// A verified hyper-parameter set ready for prediction
#[derive(Debug, Clone, PartialEq)]
pub struct TSneValidParams<F, R> {
    // Dimensionality of the produced embedding (typically 2 or 3).
    embedding_size: usize,
    // Barnes-Hut threshold theta in (0, inf); 0 disables the approximation.
    approx_threshold: F,
    // Perplexity of the conditional distributions over neighbors.
    perplexity: F,
    // Maximal number of optimization iterations.
    max_iter: usize,
    // Number of iterations run with the exaggerated P distribution;
    // estimated automatically when `None`.
    preliminary_iter: Option<usize>,
    // Source of randomness for the algorithm.
    rng: R,
}
43
44impl<F: Float, R> TSneValidParams<F, R> {
45    pub fn embedding_size(&self) -> usize {
46        self.embedding_size
47    }
48
49    pub fn approx_threshold(&self) -> F {
50        self.approx_threshold
51    }
52
53    pub fn perplexity(&self) -> F {
54        self.perplexity
55    }
56
57    pub fn max_iter(&self) -> usize {
58        self.max_iter
59    }
60
61    pub fn preliminary_iter(&self) -> &Option<usize> {
62        &self.preliminary_iter
63    }
64
65    pub fn rng(&self) -> &R {
66        &self.rng
67    }
68}
69
/// An unchecked t-SNE hyper-parameter set; validated into a
/// `TSneValidParams` via the `ParamGuard` implementation below.
#[derive(Debug, Clone, PartialEq)]
pub struct TSneParams<F, R>(TSneValidParams<F, R>);
72
73impl<F: Float> TSneParams<F, SmallRng> {
74    /// Create a t-SNE param set with given embedding size
75    ///
76    /// # Defaults to:
77    ///  * `approx_threshold`: 0.5
78    ///  * `perplexity`: 5.0
79    ///  * `max_iter`: 2000
80    ///  * `rng`: SmallRng with seed 42
81    pub fn embedding_size(embedding_size: usize) -> TSneParams<F, SmallRng> {
82        Self::embedding_size_with_rng(embedding_size, SmallRng::seed_from_u64(42))
83    }
84}
85
86impl<F: Float, R: Rng + Clone> TSneParams<F, R> {
87    /// Create a t-SNE param set with given embedding size and random number generator
88    ///
89    /// # Defaults to:
90    ///  * `approx_threshold`: 0.5
91    ///  * `perplexity`: 5.0
92    ///  * `max_iter`: 2000
93    pub fn embedding_size_with_rng(embedding_size: usize, rng: R) -> TSneParams<F, R> {
94        Self(TSneValidParams {
95            embedding_size,
96            rng,
97            approx_threshold: F::cast(0.5),
98            perplexity: F::cast(5.0),
99            max_iter: 2000,
100            preliminary_iter: None,
101        })
102    }
103
104    /// Set the approximation threshold of the Barnes Hut algorithm
105    ///
106    /// The threshold decides whether a cluster centroid can be used as a summary for the whole
107    /// area. This was proposed by Barnes and Hut and compares the ratio of cell radius and
108    /// distance to a factor theta. This threshold lies in range (0, inf) where a value of 0
109    /// disables approximation and a positive value approximates the gradient with the cell center.
110    pub fn approx_threshold(mut self, threshold: F) -> Self {
111        self.0.approx_threshold = threshold;
112
113        self
114    }
115
116    /// Set the perplexity of the t-SNE algorithm
117    pub fn perplexity(mut self, perplexity: F) -> Self {
118        self.0.perplexity = perplexity;
119
120        self
121    }
122
123    /// Set the maximal number of iterations
124    pub fn max_iter(mut self, max_iter: usize) -> Self {
125        self.0.max_iter = max_iter;
126
127        self
128    }
129
130    /// Set the number of iterations after which the true P distribution is used
131    ///
132    /// At the beginning of the training process it is useful to multiply the P distribution values
133    /// by a certain factor (here 12x) to get the global view right. After this number of iterations
134    /// the true P distribution value is used. If None the number is estimated.
135    pub fn preliminary_iter(mut self, num_iter: usize) -> Self {
136        self.0.preliminary_iter = Some(num_iter);
137
138        self
139    }
140}
141
142impl<F: Float, R> ParamGuard for TSneParams<F, R> {
143    type Checked = TSneValidParams<F, R>;
144    type Error = TSneError;
145
146    /// Validates parameters
147    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
148        if self.0.perplexity.is_negative() {
149            Err(TSneError::NegativePerplexity)
150        } else if self.0.approx_threshold.is_negative() {
151            Err(TSneError::NegativeApproximationThreshold)
152        } else {
153            Ok(&self.0)
154        }
155    }
156
157    fn check(self) -> Result<Self::Checked, Self::Error> {
158        self.check_ref()?;
159        Ok(self.0)
160    }
161}