1use super::algorithm::{update_cluster_memberships, update_min_dists};
2use linfa::Float;
3use linfa_nn::distance::Distance;
4use ndarray::parallel::prelude::*;
5use ndarray::{s, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix2};
6use ndarray_rand::rand::distributions::{Distribution, WeightedIndex};
7use ndarray_rand::rand::Rng;
8use ndarray_rand::rand::{self, SeedableRng};
9use rand_xoshiro::Xoshiro256Plus;
10#[cfg(feature = "serde")]
11use serde_crate::{Deserialize, Serialize};
12use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
13
/// Strategy used to choose the initial centroids for K-means.
///
/// Consumed by `KMeansInit::run`, which returns an `(n_clusters, n_features)` matrix of
/// starting centroids, one centroid per row.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum KMeansInit<F: Float> {
    /// Pick `n_clusters` distinct observation rows uniformly at random.
    Random,
    /// Use the given centroids as-is. The matrix must have `n_clusters` rows and as many
    /// columns as the observations; `run` panics otherwise.
    Precomputed(Array2<F>),
    /// k-means++ seeding.
    ///
    /// The first centroid is drawn uniformly at random. Each further centroid is drawn with
    /// probability proportional to the observation's distance (as computed by
    /// `update_min_dists`) to the closest centroid chosen so far.
    KMeansPlusPlus,
    /// k-means|| seeding.
    ///
    /// Several rounds of oversampling collect a pool of candidate centroids. Each candidate
    /// is weighted by how many observations are closest to it, and the pool is reduced to
    /// `n_clusters` centroids with weighted k-means++.
    KMeansPara,
}
36
/// Selects the iteration algorithm used when fitting K-means.
///
/// NOTE(review): neither variant is referenced in this file. The descriptions below are
/// inferred from the variant names; confirm them against the `algorithm` module.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum KMeansAlgorithm {
    /// Presumably Lloyd's classic assign-then-update iteration (confirm).
    Lloyd,
    /// Presumably Hamerly's accelerated variant, which keeps distance bounds to skip some
    /// centroid comparisons (confirm).
    Hamerly,
}
80
81impl<F: Float> KMeansInit<F> {
82 pub(crate) fn run<R: Rng, D: Distance<F>>(
84 &self,
85 dist_fn: &D,
86 n_clusters: usize,
87 observations: ArrayView2<F>,
88 rng: &mut R,
89 ) -> Array2<F> {
90 match self {
91 Self::Random => random_init(n_clusters, observations, rng),
92 Self::KMeansPlusPlus => k_means_plusplus(dist_fn, n_clusters, observations, rng),
93 Self::KMeansPara => k_means_para(dist_fn, n_clusters, observations, rng),
94 Self::Precomputed(centroids) => {
95 assert_eq!(centroids.nrows(), n_clusters);
97 assert_eq!(centroids.ncols(), observations.ncols());
98 centroids.clone()
99 }
100 }
101 }
102}
103
104fn random_init<F: Float>(
106 n_clusters: usize,
107 observations: ArrayView2<F>,
108 rng: &mut impl Rng,
109) -> Array2<F> {
110 let (n_samples, _) = observations.dim();
111 let indices = rand::seq::index::sample(rng, n_samples, n_clusters).into_vec();
112 observations.select(Axis(0), &indices)
113}
114
115fn weighted_k_means_plusplus<F: Float, D: Distance<F>>(
119 dist_fn: &D,
120 n_clusters: usize,
121 observations: ArrayView2<F>,
122 weights: ArrayView1<F>,
123 rng: &mut impl Rng,
124) -> Array2<F> {
125 let (n_samples, n_features) = observations.dim();
126 assert_eq!(n_samples, weights.len());
127 assert_ne!(weights.sum(), F::zero());
128
129 let mut centroids = Array2::zeros((n_clusters, n_features));
130 let first_idx = WeightedIndex::new(weights.iter())
132 .expect("invalid weights")
133 .sample(rng);
134 centroids.row_mut(0).assign(&observations.row(first_idx));
135
136 let mut dists = Array1::zeros(n_samples);
137 for c_cnt in 1..n_clusters {
138 update_min_dists(
139 dist_fn,
140 ¢roids.slice(s![0..c_cnt, ..]),
141 &observations,
142 &mut dists,
143 );
144
145 dists *= &weights;
148 let centroid_idx = WeightedIndex::new(dists.iter())
149 .map(|idx| idx.sample(rng))
150 .unwrap_or(0);
153 centroids
154 .row_mut(c_cnt)
155 .assign(&observations.row(centroid_idx));
156 }
157 centroids
158}
159
160fn k_means_plusplus<F: Float, D: Distance<F>>(
162 dist_fn: &D,
163 n_clusters: usize,
164 observations: ArrayView2<F>,
165 rng: &mut impl Rng,
166) -> Array2<F> {
167 weighted_k_means_plusplus(
168 dist_fn,
169 n_clusters,
170 observations,
171 Array1::ones(observations.nrows()).view(),
172 rng,
173 )
174}
175
176fn k_means_para<R: Rng, F: Float, D: Distance<F>>(
182 dist_fn: &D,
183 n_clusters: usize,
184 observations: ArrayView2<F>,
185 rng: &mut R,
186) -> Array2<F> {
187 let n_rounds = 8;
191 let candidates_per_round = n_clusters;
192
193 let (n_samples, n_features) = observations.dim();
194 let mut candidates = Array2::zeros((n_clusters * n_rounds, n_features));
195
196 let first_idx = rng.gen_range(0..n_samples);
198 candidates.row_mut(0).assign(&observations.row(first_idx));
199 let mut n_candidates = 1;
200
201 let mut dists = Array1::zeros(n_samples);
202 'outer: for _ in 0..n_rounds {
203 let current_candidates = candidates.slice(s![0..n_candidates, ..]);
204 update_min_dists(dist_fn, ¤t_candidates, &observations, &mut dists);
205 let next_candidates_idx = sample_subsequent_candidates::<R, _>(
209 &dists,
210 F::cast(candidates_per_round),
211 rng.gen_range(0..u64::MAX),
212 );
213
214 for idx in next_candidates_idx.into_iter() {
217 candidates
218 .row_mut(n_candidates)
219 .assign(&observations.row(idx));
220 n_candidates += 1;
221 if n_candidates >= candidates.nrows() {
222 break 'outer;
223 }
224 }
225 }
226
227 let final_candidates = candidates.slice(s![0..n_candidates, ..]);
228 let weights = cluster_membership_counts(dist_fn, &final_candidates, &observations);
230
231 weighted_k_means_plusplus(dist_fn, n_clusters, final_candidates, weights.view(), rng)
234}
235
236#[allow(clippy::extra_unused_type_parameters)]
239fn sample_subsequent_candidates<R: Rng, F: Float>(
240 dists: &Array1<F>,
241 multiplier: F,
242 seed: u64,
243) -> Vec<usize> {
244 let cost = dists.sum();
246 let seed = AtomicU64::new(seed);
248
249 dists
255 .axis_iter(Axis(0))
256 .into_par_iter()
257 .enumerate()
258 .map_init(
259 || Xoshiro256Plus::seed_from_u64(seed.fetch_add(1, Relaxed)),
261 move |rng, (i, d)| {
262 let d = *d.into_scalar();
263 let rand = F::cast(rng.gen_range(0.0..1.0));
264 let prob = multiplier * d / cost;
265 (i, rand, prob)
266 },
267 )
268 .filter_map(|(i, rand, prob)| if rand < prob { Some(i) } else { None })
269 .collect()
270}
271
272fn cluster_membership_counts<F: Float, D: Distance<F>>(
274 dist_fn: &D,
275 centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
276 observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
277) -> Array1<F> {
278 let n_samples = observations.nrows();
279 let n_clusters = centroids.nrows();
280 let mut memberships = Array1::zeros(n_samples);
281 update_cluster_memberships(dist_fn, centroids, observations, &mut memberships);
282 let mut counts = Array1::zeros(n_clusters);
283 memberships.iter().for_each(|&c| counts[c] += F::one());
284 counts
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use approx::{abs_diff_eq, assert_abs_diff_eq, assert_abs_diff_ne};
    use linfa_nn::distance::{L1Dist, L2Dist};
    use ndarray::{array, concatenate, Array};
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::Normal;
    use ndarray_rand::RandomExt;
    use rand_xoshiro::Xoshiro256Plus;
    use std::collections::HashSet;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<KMeansInit<f64>>();
    }

    #[test]
    fn test_precomputed() {
        let mut rng = Xoshiro256Plus::seed_from_u64(40);
        let centroids = array![[0.0, 1.0], [40.0, 10.0]];
        let observations = array![[3.0, 4.0], [1.0, 3.0], [25.0, 15.0]];
        let c = KMeansInit::Precomputed(centroids.clone()).run(
            &L2Dist,
            2,
            observations.view(),
            &mut rng,
        );
        assert_abs_diff_eq!(c, centroids);
    }

    #[test]
    fn test_sample_subsequent_candidates() {
        // With multiplier 8 both non-zero entries have probability > 1, so they are always
        // kept; index 0 has zero distance and is never kept.
        let dists = array![0.0, 0.4, 0.5];
        let candidates = sample_subsequent_candidates::<Xoshiro256Plus, _>(&dists, 8.0, 0);
        assert_eq!(candidates, vec![1, 2]);
    }

    #[test]
    fn test_cluster_membership_counts() {
        let centroids = array![[0.0, 1.0], [40.0, 10.0], [3.0, 9.0]];
        let observations = array![[3.0, 4.0], [1.0, 3.0], [25.0, 15.0]];

        let counts = cluster_membership_counts(&L2Dist, &centroids, &observations);
        assert_abs_diff_eq!(counts, array![2.0, 1.0, 0.0]);
        let counts = cluster_membership_counts(&L1Dist, &centroids, &observations);
        assert_abs_diff_eq!(counts, array![1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_weighted_kmeans_plusplus() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let obs = Array::random_using((1000, 2), Normal::new(0.0, 100.).unwrap(), &mut rng);
        // Only observations 0 and 1 carry weight, so they must be the two chosen centroids
        // (in either order).
        let mut weights = Array1::zeros(1000);
        weights[0] = 2.0;
        weights[1] = 3.0;
        let out = weighted_k_means_plusplus(&L2Dist, 2, obs.view(), weights.view(), &mut rng);
        let mut expected_centroids = {
            let mut arr = Array2::zeros((2, 2));
            arr.row_mut(0).assign(&obs.row(0));
            arr.row_mut(1).assign(&obs.row(1));
            arr
        };
        assert!(
            abs_diff_eq!(out, expected_centroids) || {
                expected_centroids.invert_axis(Axis(0));
                abs_diff_eq!(out, expected_centroids)
            }
        );
    }

    #[test]
    fn test_k_means_plusplus() {
        verify_init(KMeansInit::KMeansPlusPlus, L2Dist);
        verify_init(KMeansInit::KMeansPlusPlus, L1Dist);
    }

    #[test]
    fn test_k_means_para() {
        verify_init(KMeansInit::KMeansPara, L2Dist);
        verify_init(KMeansInit::KMeansPara, L1Dist);
    }

    fn verify_init<D: Distance<f64>>(init: KMeansInit<f64>, dist_fn: D) {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        // Degenerate data: one observation but two requested clusters. The single row must
        // simply be repeated.
        let degenerate_data = array![[1.0, 2.0]];
        let out = init.run(&dist_fn, 2, degenerate_data.view(), &mut rng);
        assert_abs_diff_eq!(out, concatenate![Axis(0), degenerate_data, degenerate_data]);

        // Three well-separated Gaussian blobs of 50 points each.
        let centroids = [20.0, -1000.0, 1000.0];
        let clusters: Vec<Array2<_>> = centroids
            .iter()
            .map(|&c| Array::random_using((50, 2), Normal::new(c, 1.).unwrap(), &mut rng))
            .collect();
        let obs = clusters.iter().fold(Array2::default((0, 2)), |a, b| {
            concatenate(Axis(0), &[a.view(), b.view()]).unwrap()
        });

        let out = init.run(&dist_fn, centroids.len(), obs.view(), &mut rng);
        // Each chosen centroid must be a non-zero row that is an actual observation, and
        // together the centroids must cover all three blobs.
        let mut cluster_ids = HashSet::new();
        for row in out.rows() {
            assert_abs_diff_ne!(row, Array1::zeros(row.len()), epsilon = 1e-1);
            let found = clusters
                .iter()
                .enumerate()
                .find_map(|(i, c)| {
                    if c.rows().into_iter().any(|cl| abs_diff_eq!(row, cl)) {
                        Some(i)
                    } else {
                        None
                    }
                })
                .unwrap();
            cluster_ids.insert(found);
        }
        assert_eq!(cluster_ids, [0, 1, 2].iter().copied().collect());
    }

    /// Sum over observations of the `update_min_dists` value against the given centroids.
    macro_rules! calc_loss {
        ($dist_fn:expr, $centroids:expr, $observations:expr) => {{
            let mut dists = Array1::zeros($observations.nrows());
            update_min_dists(&$dist_fn, &$centroids, &$observations, &mut dists);
            dists.sum()
        }};
    }

    /// On well-separated blobs, k-means++ and k-means|| must both beat random initialisation.
    fn test_compare<D: Distance<f64>>(dist_fn: D) {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let centroids = [20.0, -1000.0, 1000.0];
        let clusters: Vec<Array2<_>> = centroids
            .iter()
            .map(|&c| Array::random_using((50, 2), Normal::new(c, 1.).unwrap(), &mut rng))
            .collect();
        let obs = clusters.iter().fold(Array2::default((0, 2)), |a, b| {
            concatenate(Axis(0), &[a.view(), b.view()]).unwrap()
        });

        let out_rand = random_init(3, obs.view(), &mut rng.clone());
        let out_pp = k_means_plusplus(&dist_fn, 3, obs.view(), &mut rng.clone());
        let out_para = k_means_para(&dist_fn, 3, obs.view(), &mut rng);
        assert!(calc_loss!(dist_fn, out_pp, obs) < calc_loss!(dist_fn, out_rand, obs));
        assert!(calc_loss!(dist_fn, out_para, obs) < calc_loss!(dist_fn, out_rand, obs));
    }

    #[test]
    fn test_compare_l2() {
        test_compare(L2Dist);
    }

    #[test]
    fn test_compare_l1() {
        test_compare(L1Dist);
    }
}