linfa_clustering/k_means/
hyperparams.rs1use crate::{KMeansAlgorithm, KMeansParamsError};
2
3use super::init::KMeansInit;
4use linfa::prelude::*;
5use linfa::Float;
6use linfa_nn::distance::Distance;
7use ndarray_rand::rand::Rng;
8#[cfg(feature = "serde")]
9use serde_crate::{Deserialize, Serialize};
10
11#[cfg_attr(
12 feature = "serde",
13 derive(Serialize, Deserialize),
14 serde(crate = "serde_crate")
15)]
16#[derive(Clone, Debug, PartialEq)]
17pub struct KMeansValidParams<F: Float, R: Rng, D: Distance<F>> {
20 n_runs: usize,
22 tolerance: F,
26 max_n_iterations: u64,
30 n_clusters: usize,
32 init: KMeansInit<F>,
34 rng: R,
36 dist_fn: D,
38 algorithm: KMeansAlgorithm,
40}
41
42#[derive(Clone, Debug, PartialEq)]
43#[cfg_attr(
44 feature = "serde",
45 derive(Serialize, Deserialize),
46 serde(crate = "serde_crate")
47)]
48pub struct KMeansParams<F: Float, R: Rng, D: Distance<F>>(KMeansValidParams<F, R, D>);
51
52impl<F: Float, R: Rng, D: Distance<F>> KMeansParams<F, R, D> {
53 pub fn new(n_clusters: usize, rng: R, dist_fn: D) -> Self {
72 Self(KMeansValidParams {
73 n_runs: 10,
74 tolerance: F::cast(1e-4),
75 max_n_iterations: 300,
76 n_clusters,
77 init: KMeansInit::KMeansPlusPlus,
78 rng,
79 dist_fn,
80 algorithm: KMeansAlgorithm::Lloyd,
81 })
82 }
83
84 pub fn n_runs(mut self, n_runs: usize) -> Self {
86 self.0.n_runs = n_runs;
87 self
88 }
89
90 pub fn tolerance(mut self, tolerance: F) -> Self {
92 self.0.tolerance = tolerance;
93 self
94 }
95
96 pub fn max_n_iterations(mut self, max_n_iterations: u64) -> Self {
98 self.0.max_n_iterations = max_n_iterations;
99 self
100 }
101
102 pub fn init_method(mut self, init: KMeansInit<F>) -> Self {
104 self.0.init = init;
105 self
106 }
107
108 pub fn algorithm(mut self, algorithm: KMeansAlgorithm) -> Self {
115 self.0.algorithm = algorithm;
116 self
117 }
118}
119
120impl<F: Float, R: Rng, D: Distance<F>> ParamGuard for KMeansParams<F, R, D> {
121 type Checked = KMeansValidParams<F, R, D>;
122 type Error = KMeansParamsError;
123
124 fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
125 if self.0.n_clusters == 0 {
126 Err(KMeansParamsError::NClusters)
127 } else if self.0.n_runs == 0 {
128 Err(KMeansParamsError::NRuns)
129 } else if self.0.tolerance <= F::zero() {
130 Err(KMeansParamsError::Tolerance)
131 } else if self.0.max_n_iterations == 0 {
132 Err(KMeansParamsError::MaxIterations)
133 } else {
134 Ok(&self.0)
135 }
136 }
137
138 fn check(self) -> Result<Self::Checked, Self::Error> {
139 self.check_ref()?;
140 Ok(self.0)
141 }
142}
143
144impl<F: Float, R: Rng, D: Distance<F>> KMeansValidParams<F, R, D> {
145 pub fn n_runs(&self) -> usize {
147 self.n_runs
148 }
149
150 pub fn tolerance(&self) -> F {
154 self.tolerance
155 }
156
157 pub fn max_n_iterations(&self) -> u64 {
161 self.max_n_iterations
162 }
163
164 pub fn n_clusters(&self) -> usize {
166 self.n_clusters
167 }
168
169 pub fn init_method(&self) -> &KMeansInit<F> {
171 &self.init
172 }
173
174 pub fn rng(&self) -> &R {
176 &self.rng
177 }
178
179 pub fn dist_fn(&self) -> &D {
181 &self.dist_fn
182 }
183
184 pub fn algorithm(&self) -> &KMeansAlgorithm {
186 &self.algorithm
187 }
188}
189
#[cfg(test)]
mod tests {
    use crate::{KMeans, KMeansParams, KMeansParamsError, KMeansValidParams};
    use linfa::ParamGuard;
    use linfa_nn::distance::L2Dist;
    use rand_xoshiro::Xoshiro256Plus;

    // Compile-time check: both parameter types must be Send + Sync + Sized + Unpin.
    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<KMeansParams<f64, Xoshiro256Plus, L2Dist>>();
        has_autotraits::<KMeansValidParams<f64, Xoshiro256Plus, L2Dist>>();
    }

    // A cluster count of zero must be reported as `NClusters`.
    #[test]
    fn n_clusters_cannot_be_zero() {
        let outcome = KMeans::<f32, _>::params(0).check();
        assert!(matches!(outcome, Err(KMeansParamsError::NClusters)));
    }

    // A negative tolerance must be reported as `Tolerance`.
    #[test]
    fn tolerance_has_to_positive() {
        let outcome = KMeans::params(1).tolerance(-1.).check();
        assert!(matches!(outcome, Err(KMeansParamsError::Tolerance)));
    }

    // A tolerance of exactly zero must be reported as `Tolerance`.
    #[test]
    fn tolerance_cannot_be_zero() {
        let outcome = KMeans::params(1).tolerance(0.).check();
        assert!(matches!(outcome, Err(KMeansParamsError::Tolerance)));
    }

    // With every other field valid, zero iterations must be reported as
    // `MaxIterations`.
    #[test]
    fn max_n_iterations_cannot_be_zero() {
        let builder = KMeans::params(1).tolerance(1.).max_n_iterations(0);
        let outcome = builder.check();
        assert!(matches!(outcome, Err(KMeansParamsError::MaxIterations)));
    }

    // With every other field valid, zero runs must be reported as `NRuns`.
    #[test]
    fn n_runs_cannot_be_zero() {
        let builder = KMeans::params(1).tolerance(1.).n_runs(0);
        let outcome = builder.check();
        assert!(matches!(outcome, Err(KMeansParamsError::NRuns)));
    }
}