1use super::algorithm::{update_cluster_memberships, update_min_dists};
2use linfa::Float;
3use linfa_nn::distance::Distance;
4use ndarray::parallel::prelude::*;
5use ndarray::{s, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix2};
6use ndarray_rand::rand::distributions::{Distribution, WeightedIndex};
7use ndarray_rand::rand::Rng;
8use ndarray_rand::rand::{self, SeedableRng};
9use rand_xoshiro::Xoshiro256Plus;
10#[cfg(feature = "serde")]
11use serde_crate::{Deserialize, Serialize};
12use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
13
/// Strategy used to choose the starting centroids for K-means.
///
/// Whatever the strategy, initialization produces one centroid per row, with one column per
/// feature of the observations being clustered.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum KMeansInit<F: Float> {
    /// Use `n_clusters` distinct observations, picked uniformly at random, as the centroids
    /// (so at least `n_clusters` observations are required).
    Random,
    /// Use the given centroids as-is. The array must have `n_clusters` rows and as many columns
    /// as the observations; a shape mismatch panics when initialization runs.
    Precomputed(Array2<F>),
    /// K-means++ seeding: the first centroid is an observation chosen at random, and each
    /// following centroid is sampled with probability proportional to its distance (as reported
    /// by `update_min_dists`) from the nearest centroid already chosen.
    KMeansPlusPlus,
    /// K-means|| seeding, a parallel oversampling variant of K-means++: up to 8 rounds of
    /// independent, distance-proportional sampling build a pool of at most `8 * n_clusters`
    /// candidates. Each candidate is weighted by how many observations are closest to it, and
    /// weighted K-means++ over that pool picks the final centroids.
    KMeansPara,
}
36
37impl<F: Float> KMeansInit<F> {
38 pub(crate) fn run<R: Rng, D: Distance<F>>(
40 &self,
41 dist_fn: &D,
42 n_clusters: usize,
43 observations: ArrayView2<F>,
44 rng: &mut R,
45 ) -> Array2<F> {
46 match self {
47 Self::Random => random_init(n_clusters, observations, rng),
48 Self::KMeansPlusPlus => k_means_plusplus(dist_fn, n_clusters, observations, rng),
49 Self::KMeansPara => k_means_para(dist_fn, n_clusters, observations, rng),
50 Self::Precomputed(centroids) => {
51 assert_eq!(centroids.nrows(), n_clusters);
53 assert_eq!(centroids.ncols(), observations.ncols());
54 centroids.clone()
55 }
56 }
57 }
58}
59
60fn random_init<F: Float>(
62 n_clusters: usize,
63 observations: ArrayView2<F>,
64 rng: &mut impl Rng,
65) -> Array2<F> {
66 let (n_samples, _) = observations.dim();
67 let indices = rand::seq::index::sample(rng, n_samples, n_clusters).into_vec();
68 observations.select(Axis(0), &indices)
69}
70
71fn weighted_k_means_plusplus<F: Float, D: Distance<F>>(
75 dist_fn: &D,
76 n_clusters: usize,
77 observations: ArrayView2<F>,
78 weights: ArrayView1<F>,
79 rng: &mut impl Rng,
80) -> Array2<F> {
81 let (n_samples, n_features) = observations.dim();
82 assert_eq!(n_samples, weights.len());
83 assert_ne!(weights.sum(), F::zero());
84
85 let mut centroids = Array2::zeros((n_clusters, n_features));
86 let first_idx = WeightedIndex::new(weights.iter())
88 .expect("invalid weights")
89 .sample(rng);
90 centroids.row_mut(0).assign(&observations.row(first_idx));
91
92 let mut dists = Array1::zeros(n_samples);
93 for c_cnt in 1..n_clusters {
94 update_min_dists(
95 dist_fn,
96 ¢roids.slice(s![0..c_cnt, ..]),
97 &observations,
98 &mut dists,
99 );
100
101 dists *= &weights;
104 let centroid_idx = WeightedIndex::new(dists.iter())
105 .map(|idx| idx.sample(rng))
106 .unwrap_or(0);
109 centroids
110 .row_mut(c_cnt)
111 .assign(&observations.row(centroid_idx));
112 }
113 centroids
114}
115
116fn k_means_plusplus<F: Float, D: Distance<F>>(
118 dist_fn: &D,
119 n_clusters: usize,
120 observations: ArrayView2<F>,
121 rng: &mut impl Rng,
122) -> Array2<F> {
123 weighted_k_means_plusplus(
124 dist_fn,
125 n_clusters,
126 observations,
127 Array1::ones(observations.nrows()).view(),
128 rng,
129 )
130}
131
132fn k_means_para<R: Rng, F: Float, D: Distance<F>>(
138 dist_fn: &D,
139 n_clusters: usize,
140 observations: ArrayView2<F>,
141 rng: &mut R,
142) -> Array2<F> {
143 let n_rounds = 8;
147 let candidates_per_round = n_clusters;
148
149 let (n_samples, n_features) = observations.dim();
150 let mut candidates = Array2::zeros((n_clusters * n_rounds, n_features));
151
152 let first_idx = rng.gen_range(0..n_samples);
154 candidates.row_mut(0).assign(&observations.row(first_idx));
155 let mut n_candidates = 1;
156
157 let mut dists = Array1::zeros(n_samples);
158 'outer: for _ in 0..n_rounds {
159 let current_candidates = candidates.slice(s![0..n_candidates, ..]);
160 update_min_dists(dist_fn, ¤t_candidates, &observations, &mut dists);
161 let next_candidates_idx = sample_subsequent_candidates::<R, _>(
165 &dists,
166 F::cast(candidates_per_round),
167 rng.gen_range(0..u64::MAX),
168 );
169
170 for idx in next_candidates_idx.into_iter() {
173 candidates
174 .row_mut(n_candidates)
175 .assign(&observations.row(idx));
176 n_candidates += 1;
177 if n_candidates >= candidates.nrows() {
178 break 'outer;
179 }
180 }
181 }
182
183 let final_candidates = candidates.slice(s![0..n_candidates, ..]);
184 let weights = cluster_membership_counts(dist_fn, &final_candidates, &observations);
186
187 weighted_k_means_plusplus(dist_fn, n_clusters, final_candidates, weights.view(), rng)
190}
191
192#[allow(clippy::extra_unused_type_parameters)]
195fn sample_subsequent_candidates<R: Rng, F: Float>(
196 dists: &Array1<F>,
197 multiplier: F,
198 seed: u64,
199) -> Vec<usize> {
200 let cost = dists.sum();
202 let seed = AtomicU64::new(seed);
204
205 dists
211 .axis_iter(Axis(0))
212 .into_par_iter()
213 .enumerate()
214 .map_init(
215 || Xoshiro256Plus::seed_from_u64(seed.fetch_add(1, Relaxed)),
217 move |rng, (i, d)| {
218 let d = *d.into_scalar();
219 let rand = F::cast(rng.gen_range(0.0..1.0));
220 let prob = multiplier * d / cost;
221 (i, rand, prob)
222 },
223 )
224 .filter_map(|(i, rand, prob)| if rand < prob { Some(i) } else { None })
225 .collect()
226}
227
228fn cluster_membership_counts<F: Float, D: Distance<F>>(
230 dist_fn: &D,
231 centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
232 observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
233) -> Array1<F> {
234 let n_samples = observations.nrows();
235 let n_clusters = centroids.nrows();
236 let mut memberships = Array1::zeros(n_samples);
237 update_cluster_memberships(dist_fn, centroids, observations, &mut memberships);
238 let mut counts = Array1::zeros(n_clusters);
239 memberships.iter().for_each(|&c| counts[c] += F::one());
240 counts
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use approx::{abs_diff_eq, assert_abs_diff_eq, assert_abs_diff_ne};
    use linfa_nn::distance::{L1Dist, L2Dist};
    use ndarray::{array, concatenate, Array};
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::Normal;
    use ndarray_rand::RandomExt;
    use rand_xoshiro::Xoshiro256Plus;
    use std::collections::HashSet;

    // Compile-time check: `KMeansInit` must be `Send + Sync + Sized + Unpin`.
    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<KMeansInit<f64>>();
    }

    // `Precomputed` hands back exactly the centroids it was given.
    #[test]
    fn test_precomputed() {
        let mut rng = Xoshiro256Plus::seed_from_u64(40);
        let centroids = array![[0.0, 1.0], [40.0, 10.0]];
        let observations = array![[3.0, 4.0], [1.0, 3.0], [25.0, 15.0]];
        let c = KMeansInit::Precomputed(centroids.clone()).run(
            &L2Dist,
            2,
            observations.view(),
            &mut rng,
        );
        assert_abs_diff_eq!(c, centroids);
    }

    // The total cost is 0.9 and the multiplier is 8. Index 0 therefore has keep-probability 0,
    // while indices 1 and 2 have probabilities above 1, so the result does not depend on the RNG.
    #[test]
    fn test_sample_subsequent_candidates() {
        let dists = array![0.0, 0.4, 0.5];
        let candidates = sample_subsequent_candidates::<Xoshiro256Plus, _>(&dists, 8.0, 0);
        assert_eq!(candidates, vec![1, 2]);
    }

    // The same data gives different memberships under L2 and L1. Observation [3, 4] is closest to
    // centroid 0 under L2, but to centroid 2 under L1.
    #[test]
    fn test_cluster_membership_counts() {
        let centroids = array![[0.0, 1.0], [40.0, 10.0], [3.0, 9.0]];
        let observations = array![[3.0, 4.0], [1.0, 3.0], [25.0, 15.0]];

        let counts = cluster_membership_counts(&L2Dist, &centroids, &observations);
        assert_abs_diff_eq!(counts, array![2.0, 1.0, 0.0]);
        let counts = cluster_membership_counts(&L1Dist, &centroids, &observations);
        assert_abs_diff_eq!(counts, array![1.0, 1.0, 1.0]);
    }

    // Only observations 0 and 1 have non-zero weight, so they are the only possible centroids.
    // Either order is accepted.
    #[test]
    fn test_weighted_kmeans_plusplus() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let obs = Array::random_using((1000, 2), Normal::new(0.0, 100.).unwrap(), &mut rng);
        let mut weights = Array1::zeros(1000);
        weights[0] = 2.0;
        weights[1] = 3.0;
        let out = weighted_k_means_plusplus(&L2Dist, 2, obs.view(), weights.view(), &mut rng);
        let mut expected_centroids = {
            let mut arr = Array2::zeros((2, 2));
            arr.row_mut(0).assign(&obs.row(0));
            arr.row_mut(1).assign(&obs.row(1));
            arr
        };
        assert!(
            abs_diff_eq!(out, expected_centroids) || {
                // Try the swapped order before failing.
                expected_centroids.invert_axis(Axis(0));
                abs_diff_eq!(out, expected_centroids)
            }
        );
    }

    #[test]
    fn test_k_means_plusplus() {
        verify_init(KMeansInit::KMeansPlusPlus, L2Dist);
        verify_init(KMeansInit::KMeansPlusPlus, L1Dist);
    }

    #[test]
    fn test_k_means_para() {
        verify_init(KMeansInit::KMeansPara, L2Dist);
        verify_init(KMeansInit::KMeansPara, L1Dist);
    }

    // Shared checks for an init strategy:
    // 1. With a single observation and two clusters, both centroids equal that observation.
    // 2. On three well-separated Gaussian blobs, every centroid is an actual observation (not a
    //    leftover zero row), and the centroids cover all three blobs.
    fn verify_init<D: Distance<f64>>(init: KMeansInit<f64>, dist_fn: D) {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        // Fewer observations than requested clusters.
        let degenerate_data = array![[1.0, 2.0]];
        let out = init.run(&dist_fn, 2, degenerate_data.view(), &mut rng);
        assert_abs_diff_eq!(out, concatenate![Axis(0), degenerate_data, degenerate_data]);

        // Three blobs of 50 points each, centred far apart.
        let centroids = [20.0, -1000.0, 1000.0];
        let clusters: Vec<Array2<_>> = centroids
            .iter()
            .map(|&c| Array::random_using((50, 2), Normal::new(c, 1.).unwrap(), &mut rng))
            .collect();
        let obs = clusters.iter().fold(Array2::default((0, 2)), |a, b| {
            concatenate(Axis(0), &[a.view(), b.view()]).unwrap()
        });

        // Ask for one centroid per blob.
        let out = init.run(&dist_fn, centroids.len(), obs.view(), &mut rng);
        let mut cluster_ids = HashSet::new();
        for row in out.rows() {
            // A zero row would mean the centroid slot was never filled.
            assert_abs_diff_ne!(row, Array1::zeros(row.len()), epsilon = 1e-1);
            // Identify which blob this centroid was taken from; `unwrap` fails if it is not an
            // observation from any blob.
            let found = clusters
                .iter()
                .enumerate()
                .find_map(|(i, c)| {
                    if c.rows().into_iter().any(|cl| abs_diff_eq!(row, cl)) {
                        Some(i)
                    } else {
                        None
                    }
                })
                .unwrap();
            cluster_ids.insert(found);
        }
        // Every blob must be represented by exactly one centroid.
        assert_eq!(cluster_ids, [0, 1, 2].iter().copied().collect());
    }

    // Sums the per-observation minimum distances computed by `update_min_dists`; lower is better.
    macro_rules! calc_loss {
        ($dist_fn:expr, $centroids:expr, $observations:expr) => {{
            let mut dists = Array1::zeros($observations.nrows());
            update_min_dists(&$dist_fn, &$centroids, &$observations, &mut dists);
            dists.sum()
        }};
    }

    // On well-separated blobs, both k-means++ and k-means|| should beat random initialization.
    fn test_compare<D: Distance<f64>>(dist_fn: D) {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let centroids = [20.0, -1000.0, 1000.0];
        let clusters: Vec<Array2<_>> = centroids
            .iter()
            .map(|&c| Array::random_using((50, 2), Normal::new(c, 1.).unwrap(), &mut rng))
            .collect();
        let obs = clusters.iter().fold(Array2::default((0, 2)), |a, b| {
            concatenate(Axis(0), &[a.view(), b.view()]).unwrap()
        });

        let out_rand = random_init(3, obs.view(), &mut rng.clone());
        let out_pp = k_means_plusplus(&dist_fn, 3, obs.view(), &mut rng.clone());
        let out_para = k_means_para(&dist_fn, 3, obs.view(), &mut rng);
        // k-means++ loss should be lower than the random-init loss.
        assert!(calc_loss!(dist_fn, out_pp, obs) < calc_loss!(dist_fn, out_rand, obs));
        // k-means|| loss should be lower than the random-init loss.
        assert!(calc_loss!(dist_fn, out_para, obs) < calc_loss!(dist_fn, out_rand, obs));
    }

    #[test]
    fn test_compare_l2() {
        test_compare(L2Dist);
    }

    #[test]
    fn test_compare_l1() {
        test_compare(L1Dist);
    }
}