//! Hyperparameter types for K-means clustering.
//!
//! File: linfa_clustering/k_means/hyperparams.rs
1use crate::{KMeansAlgorithm, KMeansParamsError};
2
3use super::init::KMeansInit;
4use linfa::prelude::*;
5use linfa::Float;
6use linfa_nn::distance::Distance;
7use ndarray_rand::rand::Rng;
8#[cfg(feature = "serde")]
9use serde_crate::{Deserialize, Serialize};
10
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug, PartialEq)]
/// The set of hyperparameters that can be specified for the execution of
/// the [K-means algorithm](crate::KMeans).
///
/// Values of this type are only obtainable by validating a [`KMeansParams`]
/// builder via `ParamGuard::check`, so the invariants enforced there
/// (non-zero `n_clusters`, `n_runs` and `max_n_iterations`, strictly
/// positive `tolerance`) always hold here.
pub struct KMeansValidParams<F: Float, R: Rng, D: Distance<F>> {
    /// Number of times the k-means algorithm will be run with different centroid seeds.
    n_runs: usize,
    /// The training is considered complete if the euclidean distance
    /// between the old set of centroids and the new set of centroids
    /// after a training iteration is lower or equal than `tolerance`.
    tolerance: F,
    /// We exit the training loop when the number of training iterations
    /// exceeds `max_n_iterations` even if the `tolerance` convergence
    /// condition has not been met.
    max_n_iterations: u64,
    /// The number of clusters we will be looking for in the training dataset.
    n_clusters: usize,
    /// The initialization strategy used to initialize the centroids.
    init: KMeansInit<F>,
    /// The random number generator.
    rng: R,
    /// Distance metric used in the centroid assignment step.
    dist_fn: D,
    /// Algorithm variant used for the assignment step.
    algorithm: KMeansAlgorithm,
}
41
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
/// A helper struct used to construct a set of [valid hyperparameters](KMeansValidParams) for
/// the [K-means algorithm](crate::KMeans) (using the builder pattern).
pub struct KMeansParams<F: Float, R: Rng, D: Distance<F>>(KMeansValidParams<F, R, D>);
51
52impl<F: Float, R: Rng, D: Distance<F>> KMeansParams<F, R, D> {
53    /// `new` lets us configure our training algorithm parameters:
54    /// * we will be looking for `n_clusters` in the training dataset;
55    /// * the training is considered complete if the euclidean distance
56    ///   between the old set of centroids and the new set of centroids
57    ///   after a training iteration is lower or equal than `tolerance`;
58    /// * we exit the training loop when the number of training iterations
59    ///   exceeds `max_n_iterations` even if the `tolerance` convergence
60    ///   condition has not been met.
61    /// * As KMeans convergence depends on centroids initialization
62    ///   we run the algorithm `n_runs` times and we keep the best outputs
63    ///   in terms of inertia that the ones which minimizes the sum of squared
64    ///   euclidean distances to the closest centroid for all observations.
65    ///
66    /// Defaults are provided if optional parameters are not specified:
67    /// * `tolerance = 1e-4`
68    /// * `max_n_iterations = 300`
69    /// * `n_runs = 10`
70    /// * `init = KMeansPlusPlus`
71    pub fn new(n_clusters: usize, rng: R, dist_fn: D) -> Self {
72        Self(KMeansValidParams {
73            n_runs: 10,
74            tolerance: F::cast(1e-4),
75            max_n_iterations: 300,
76            n_clusters,
77            init: KMeansInit::KMeansPlusPlus,
78            rng,
79            dist_fn,
80            algorithm: KMeansAlgorithm::Lloyd,
81        })
82    }
83
84    /// Change the value of `n_runs`
85    pub fn n_runs(mut self, n_runs: usize) -> Self {
86        self.0.n_runs = n_runs;
87        self
88    }
89
90    /// Change the value of `tolerance`
91    pub fn tolerance(mut self, tolerance: F) -> Self {
92        self.0.tolerance = tolerance;
93        self
94    }
95
96    /// Change the value of `max_n_iterations`
97    pub fn max_n_iterations(mut self, max_n_iterations: u64) -> Self {
98        self.0.max_n_iterations = max_n_iterations;
99        self
100    }
101
102    /// Change the value of `init`
103    pub fn init_method(mut self, init: KMeansInit<F>) -> Self {
104        self.0.init = init;
105        self
106    }
107
108    /// Select the variant used for the assignment step.
109    ///
110    /// See [`KMeansAlgorithm`] for the available variants and when to prefer each.
111    /// Defaults to [`KMeansAlgorithm::Lloyd`]. This setting only affects batch `fit`;
112    /// `fit_with` (Mini-Batch K-means) always uses Lloyd's update and will reject
113    /// `Hamerly` with [`KMeansParamsError::IncrementalHamerly`](crate::KMeansParamsError::IncrementalHamerly).
114    pub fn algorithm(mut self, algorithm: KMeansAlgorithm) -> Self {
115        self.0.algorithm = algorithm;
116        self
117    }
118}
119
120impl<F: Float, R: Rng, D: Distance<F>> ParamGuard for KMeansParams<F, R, D> {
121    type Checked = KMeansValidParams<F, R, D>;
122    type Error = KMeansParamsError;
123
124    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
125        if self.0.n_clusters == 0 {
126            Err(KMeansParamsError::NClusters)
127        } else if self.0.n_runs == 0 {
128            Err(KMeansParamsError::NRuns)
129        } else if self.0.tolerance <= F::zero() {
130            Err(KMeansParamsError::Tolerance)
131        } else if self.0.max_n_iterations == 0 {
132            Err(KMeansParamsError::MaxIterations)
133        } else {
134            Ok(&self.0)
135        }
136    }
137
138    fn check(self) -> Result<Self::Checked, Self::Error> {
139        self.check_ref()?;
140        Ok(self.0)
141    }
142}
143
impl<F: Float, R: Rng, D: Distance<F>> KMeansValidParams<F, R, D> {
    /// The final results will be the best output of `n_runs` consecutive runs in terms of inertia.
    pub fn n_runs(&self) -> usize {
        self.n_runs
    }

    /// The training is considered complete if the euclidean distance
    /// between the old set of centroids and the new set of centroids
    /// after a training iteration is lower or equal than `tolerance`.
    pub fn tolerance(&self) -> F {
        self.tolerance
    }

    /// We exit the training loop when the number of training iterations
    /// exceeds `max_n_iterations` even if the `tolerance` convergence
    /// condition has not been met.
    pub fn max_n_iterations(&self) -> u64 {
        self.max_n_iterations
    }

    /// The number of clusters we will be looking for in the training dataset.
    pub fn n_clusters(&self) -> usize {
        self.n_clusters
    }

    /// Cluster initialization strategy used to pick the starting centroids.
    pub fn init_method(&self) -> &KMeansInit<F> {
        &self.init
    }

    /// Returns a reference to the random number generator.
    pub fn rng(&self) -> &R {
        &self.rng
    }

    /// Returns the distance metric used in the centroid assignment step.
    pub fn dist_fn(&self) -> &D {
        &self.dist_fn
    }

    /// The [`KMeansAlgorithm`] variant used by batch `fit` for the assignment step.
    pub fn algorithm(&self) -> &KMeansAlgorithm {
        &self.algorithm
    }
}
189
#[cfg(test)]
mod tests {
    use crate::{KMeans, KMeansParams, KMeansParamsError, KMeansValidParams};
    use linfa::ParamGuard;
    use linfa_nn::distance::L2Dist;
    use rand_xoshiro::Xoshiro256Plus;

    /// Both the builder and the validated parameter set must remain
    /// `Send + Sync + Sized + Unpin`.
    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<KMeansParams<f64, Xoshiro256Plus, L2Dist>>();
        has_autotraits::<KMeansValidParams<f64, Xoshiro256Plus, L2Dist>>();
    }

    #[test]
    fn n_clusters_cannot_be_zero() {
        assert!(matches!(
            KMeans::<f32, _>::params(0).check(),
            Err(KMeansParamsError::NClusters)
        ))
    }

    #[test]
    fn tolerance_has_to_positive() {
        assert!(matches!(
            KMeans::params(1).tolerance(-1.).check(),
            Err(KMeansParamsError::Tolerance)
        ))
    }

    #[test]
    fn tolerance_cannot_be_zero() {
        assert!(matches!(
            KMeans::params(1).tolerance(0.).check(),
            Err(KMeansParamsError::Tolerance)
        ))
    }

    #[test]
    fn max_n_iterations_cannot_be_zero() {
        assert!(matches!(
            KMeans::params(1).tolerance(1.).max_n_iterations(0).check(),
            Err(KMeansParamsError::MaxIterations)
        ))
    }

    #[test]
    fn n_runs_cannot_be_zero() {
        assert!(matches!(
            KMeans::params(1).tolerance(1.).n_runs(0).check(),
            Err(KMeansParamsError::NRuns)
        ))
    }
}