linfa_clustering/k_means/hyperparams.rs

use crate::KMeansParamsError;

use super::init::KMeansInit;
use linfa::prelude::*;
use linfa::Float;
use linfa_nn::distance::Distance;
use ndarray_rand::rand::Rng;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug, PartialEq)]
/// The set of hyperparameters that can be specified for the execution of
/// the [K-means algorithm](crate::KMeans).
pub struct KMeansValidParams<F: Float, R: Rng, D: Distance<F>> {
    /// Number of times the k-means algorithm will be run with different centroid seeds.
    n_runs: usize,
    /// The training is considered complete if the Euclidean distance
    /// between the old set of centroids and the new set of centroids
    /// after a training iteration is lower than or equal to `tolerance`.
    tolerance: F,
    /// We exit the training loop when the number of training iterations
    /// exceeds `max_n_iterations`, even if the `tolerance` convergence
    /// condition has not been met.
    max_n_iterations: u64,
    /// The number of clusters we will be looking for in the training dataset.
    n_clusters: usize,
    /// The initialization strategy used to initialize the centroids.
    init: KMeansInit<F>,
    /// The random number generator
    rng: R,
    /// Distance metric used in the centroid assignment step
    dist_fn: D,
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
/// A helper struct used to construct a set of [valid hyperparameters](KMeansValidParams) for
/// the [K-means algorithm](crate::KMeans) (using the builder pattern).
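///
/// A minimal sketch of building and validating a parameter set; the
/// concrete values are illustrative only:
///
/// ```no_run
/// use linfa::ParamGuard;
/// use linfa_clustering::KMeans;
///
/// // `params` starts from the documented defaults; `check` turns the
/// // builder into `KMeansValidParams` or reports why it is invalid.
/// let valid = KMeans::<f64, _>::params(3)
///     .n_runs(5)
///     .check()
///     .expect("valid hyperparameters");
/// assert_eq!(valid.n_clusters(), 3);
/// ```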
pub struct KMeansParams<F: Float, R: Rng, D: Distance<F>>(KMeansValidParams<F, R, D>);

impl<F: Float, R: Rng, D: Distance<F>> KMeansParams<F, R, D> {
    /// `new` lets us configure our training algorithm parameters:
    /// * we will be looking for `n_clusters` in the training dataset;
    /// * the training is considered complete if the Euclidean distance
    ///   between the old set of centroids and the new set of centroids
    ///   after a training iteration is lower than or equal to `tolerance`;
    /// * we exit the training loop when the number of training iterations
    ///   exceeds `max_n_iterations`, even if the `tolerance` convergence
    ///   condition has not been met;
    /// * as k-means convergence depends on the centroid initialization,
    ///   we run the algorithm `n_runs` times and keep the best output
    ///   in terms of inertia, i.e. the one that minimizes the sum of squared
    ///   Euclidean distances to the closest centroid over all observations.
    ///
    /// Defaults are provided if optional parameters are not specified:
    /// * `tolerance = 1e-4`
    /// * `max_n_iterations = 300`
    /// * `n_runs = 10`
    /// * `init = KMeansPlusPlus`
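    ///
    /// A minimal end-to-end sketch (the tiny dataset and the overridden
    /// values are illustrative, not recommendations):
    ///
    /// ```no_run
    /// use linfa::prelude::*;
    /// use linfa_clustering::KMeans;
    /// use ndarray::array;
    ///
    /// // Two obvious clusters in the plane.
    /// let records = array![[1.0, 2.0], [1.1, 1.9], [8.0, 8.0], [8.2, 7.9]];
    /// let dataset = DatasetBase::from(records);
    ///
    /// // Override a couple of defaults through the builder, then fit.
    /// let model = KMeans::params(2)
    ///     .max_n_iterations(100)
    ///     .tolerance(1e-3)
    ///     .fit(&dataset)
    ///     .expect("KMeans fitted");
    /// ```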
    pub fn new(n_clusters: usize, rng: R, dist_fn: D) -> Self {
        Self(KMeansValidParams {
            n_runs: 10,
            tolerance: F::cast(1e-4),
            max_n_iterations: 300,
            n_clusters,
            init: KMeansInit::KMeansPlusPlus,
            rng,
            dist_fn,
        })
    }

    /// Change the value of `n_runs`
    pub fn n_runs(mut self, n_runs: usize) -> Self {
        self.0.n_runs = n_runs;
        self
    }

    /// Change the value of `tolerance`
    pub fn tolerance(mut self, tolerance: F) -> Self {
        self.0.tolerance = tolerance;
        self
    }

    /// Change the value of `max_n_iterations`
    pub fn max_n_iterations(mut self, max_n_iterations: u64) -> Self {
        self.0.max_n_iterations = max_n_iterations;
        self
    }

    /// Change the value of `init`
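    ///
    /// A short sketch, assuming the `Random` variant of [`KMeansInit`]
    /// (k-means++ remains the default):
    ///
    /// ```no_run
    /// use linfa_clustering::{KMeans, KMeansInit};
    ///
    /// let params = KMeans::<f64, _>::params(3).init_method(KMeansInit::Random);
    /// ```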
    pub fn init_method(mut self, init: KMeansInit<F>) -> Self {
        self.0.init = init;
        self
    }
}

impl<F: Float, R: Rng, D: Distance<F>> ParamGuard for KMeansParams<F, R, D> {
    type Checked = KMeansValidParams<F, R, D>;
    type Error = KMeansParamsError;

    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
        if self.0.n_clusters == 0 {
            Err(KMeansParamsError::NClusters)
        } else if self.0.n_runs == 0 {
            Err(KMeansParamsError::NRuns)
        } else if self.0.tolerance <= F::zero() {
            Err(KMeansParamsError::Tolerance)
        } else if self.0.max_n_iterations == 0 {
            Err(KMeansParamsError::MaxIterations)
        } else {
            Ok(&self.0)
        }
    }

    fn check(self) -> Result<Self::Checked, Self::Error> {
        self.check_ref()?;
        Ok(self.0)
    }
}

impl<F: Float, R: Rng, D: Distance<F>> KMeansValidParams<F, R, D> {
    /// The final results will be the best output of `n_runs` consecutive runs in terms of inertia.
    pub fn n_runs(&self) -> usize {
        self.n_runs
    }

    /// The training is considered complete if the Euclidean distance
    /// between the old set of centroids and the new set of centroids
    /// after a training iteration is lower than or equal to `tolerance`.
    pub fn tolerance(&self) -> F {
        self.tolerance
    }

    /// We exit the training loop when the number of training iterations
    /// exceeds `max_n_iterations`, even if the `tolerance` convergence
    /// condition has not been met.
    pub fn max_n_iterations(&self) -> u64 {
        self.max_n_iterations
    }

    /// The number of clusters we will be looking for in the training dataset.
    pub fn n_clusters(&self) -> usize {
        self.n_clusters
    }

    /// Cluster initialization strategy
    pub fn init_method(&self) -> &KMeansInit<F> {
        &self.init
    }

    /// Returns the random number generator
    pub fn rng(&self) -> &R {
        &self.rng
    }

    /// Returns the distance metric
    pub fn dist_fn(&self) -> &D {
        &self.dist_fn
    }
}

#[cfg(test)]
mod tests {
    use crate::{KMeans, KMeansParams, KMeansParamsError, KMeansValidParams};
    use linfa::ParamGuard;
    use linfa_nn::distance::L2Dist;
    use rand_xoshiro::Xoshiro256Plus;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<KMeansParams<f64, Xoshiro256Plus, L2Dist>>();
        has_autotraits::<KMeansValidParams<f64, Xoshiro256Plus, L2Dist>>();
    }

    #[test]
    fn n_clusters_cannot_be_zero() {
        let res = KMeans::<f32, _>::params(0).check();
        assert!(matches!(res, Err(KMeansParamsError::NClusters)))
    }

    #[test]
    fn tolerance_has_to_be_positive() {
        let res = KMeans::params(1).tolerance(-1.).check();
        assert!(matches!(res, Err(KMeansParamsError::Tolerance)))
    }

    #[test]
    fn tolerance_cannot_be_zero() {
        let res = KMeans::params(1).tolerance(0.).check();
        assert!(matches!(res, Err(KMeansParamsError::Tolerance)))
    }

    #[test]
    fn max_n_iterations_cannot_be_zero() {
        let res = KMeans::params(1).tolerance(1.).max_n_iterations(0).check();
        assert!(matches!(res, Err(KMeansParamsError::MaxIterations)))
    }

    #[test]
    fn n_runs_cannot_be_zero() {
        let res = KMeans::params(1).tolerance(1.).n_runs(0).check();
        assert!(matches!(res, Err(KMeansParamsError::NRuns)))
    }
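
    // Sanity check: a well-formed parameter set passes validation and
    // round-trips its configured values through the accessors.
    #[test]
    fn valid_params_pass_check() {
        let valid = KMeans::<f64, _>::params(3)
            .n_runs(5)
            .max_n_iterations(50)
            .check()
            .expect("valid hyperparameters");
        assert_eq!(valid.n_clusters(), 3);
        assert_eq!(valid.n_runs(), 5);
        assert_eq!(valid.max_n_iterations(), 50);
    }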
}