linfa_clustering/k_means/
hyperparams.rs1use crate::KMeansParamsError;
2
3use super::init::KMeansInit;
4use linfa::prelude::*;
5use linfa::Float;
6use linfa_nn::distance::Distance;
7use ndarray_rand::rand::Rng;
8#[cfg(feature = "serde")]
9use serde_crate::{Deserialize, Serialize};
10
11#[cfg_attr(
12 feature = "serde",
13 derive(Serialize, Deserialize),
14 serde(crate = "serde_crate")
15)]
16#[derive(Clone, Debug, PartialEq)]
17pub struct KMeansValidParams<F: Float, R: Rng, D: Distance<F>> {
20 n_runs: usize,
22 tolerance: F,
26 max_n_iterations: u64,
30 n_clusters: usize,
32 init: KMeansInit<F>,
34 rng: R,
36 dist_fn: D,
38}
39
40#[derive(Clone, Debug, PartialEq)]
41#[cfg_attr(
42 feature = "serde",
43 derive(Serialize, Deserialize),
44 serde(crate = "serde_crate")
45)]
46pub struct KMeansParams<F: Float, R: Rng, D: Distance<F>>(KMeansValidParams<F, R, D>);
49
50impl<F: Float, R: Rng, D: Distance<F>> KMeansParams<F, R, D> {
51 pub fn new(n_clusters: usize, rng: R, dist_fn: D) -> Self {
70 Self(KMeansValidParams {
71 n_runs: 10,
72 tolerance: F::cast(1e-4),
73 max_n_iterations: 300,
74 n_clusters,
75 init: KMeansInit::KMeansPlusPlus,
76 rng,
77 dist_fn,
78 })
79 }
80
81 pub fn n_runs(mut self, n_runs: usize) -> Self {
83 self.0.n_runs = n_runs;
84 self
85 }
86
87 pub fn tolerance(mut self, tolerance: F) -> Self {
89 self.0.tolerance = tolerance;
90 self
91 }
92
93 pub fn max_n_iterations(mut self, max_n_iterations: u64) -> Self {
95 self.0.max_n_iterations = max_n_iterations;
96 self
97 }
98
99 pub fn init_method(mut self, init: KMeansInit<F>) -> Self {
101 self.0.init = init;
102 self
103 }
104}
105
106impl<F: Float, R: Rng, D: Distance<F>> ParamGuard for KMeansParams<F, R, D> {
107 type Checked = KMeansValidParams<F, R, D>;
108 type Error = KMeansParamsError;
109
110 fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
111 if self.0.n_clusters == 0 {
112 Err(KMeansParamsError::NClusters)
113 } else if self.0.n_runs == 0 {
114 Err(KMeansParamsError::NRuns)
115 } else if self.0.tolerance <= F::zero() {
116 Err(KMeansParamsError::Tolerance)
117 } else if self.0.max_n_iterations == 0 {
118 Err(KMeansParamsError::MaxIterations)
119 } else {
120 Ok(&self.0)
121 }
122 }
123
124 fn check(self) -> Result<Self::Checked, Self::Error> {
125 self.check_ref()?;
126 Ok(self.0)
127 }
128}
129
130impl<F: Float, R: Rng, D: Distance<F>> KMeansValidParams<F, R, D> {
131 pub fn n_runs(&self) -> usize {
133 self.n_runs
134 }
135
136 pub fn tolerance(&self) -> F {
140 self.tolerance
141 }
142
143 pub fn max_n_iterations(&self) -> u64 {
147 self.max_n_iterations
148 }
149
150 pub fn n_clusters(&self) -> usize {
152 self.n_clusters
153 }
154
155 pub fn init_method(&self) -> &KMeansInit<F> {
157 &self.init
158 }
159
160 pub fn rng(&self) -> &R {
162 &self.rng
163 }
164
165 pub fn dist_fn(&self) -> &D {
167 &self.dist_fn
168 }
169}
170
#[cfg(test)]
mod tests {
    use crate::{KMeans, KMeansParams, KMeansParamsError, KMeansValidParams};
    use linfa::ParamGuard;
    use linfa_nn::distance::L2Dist;
    use rand_xoshiro::Xoshiro256Plus;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<KMeansParams<f64, Xoshiro256Plus, L2Dist>>();
        has_autotraits::<KMeansValidParams<f64, Xoshiro256Plus, L2Dist>>();
    }

    #[test]
    fn default_params_pass_check() {
        let params = KMeans::<f64, _>::params(3)
            .check()
            .expect("default hyperparameters should be valid");
        assert_eq!(params.n_clusters(), 3);
        assert_eq!(params.n_runs(), 10);
        assert_eq!(params.max_n_iterations(), 300);
        assert!((params.tolerance() - 1e-4).abs() < f64::EPSILON);
    }

    #[test]
    fn n_clusters_cannot_be_zero() {
        let res = KMeans::<f32, _>::params(0).check();
        assert!(matches!(res, Err(KMeansParamsError::NClusters)))
    }

    #[test]
    fn n_clusters_is_checked_before_n_runs() {
        let res = KMeans::<f64, _>::params(0).n_runs(0).check();
        assert!(matches!(res, Err(KMeansParamsError::NClusters)))
    }

    #[test]
    fn tolerance_has_to_be_positive() {
        let res = KMeans::params(1).tolerance(-1.).check();
        assert!(matches!(res, Err(KMeansParamsError::Tolerance)))
    }

    #[test]
    fn tolerance_cannot_be_zero() {
        let res = KMeans::params(1).tolerance(0.).check();
        assert!(matches!(res, Err(KMeansParamsError::Tolerance)))
    }

    #[test]
    fn max_n_iterations_cannot_be_zero() {
        // The float literal in `tolerance(1.)` also pins `F = f64` for inference.
        let res = KMeans::params(1).tolerance(1.).max_n_iterations(0).check();
        assert!(matches!(res, Err(KMeansParamsError::MaxIterations)))
    }

    #[test]
    fn n_runs_cannot_be_zero() {
        let res = KMeans::params(1).tolerance(1.).n_runs(0).check();
        assert!(matches!(res, Err(KMeansParamsError::NRuns)))
    }
}