linfa_clustering/k_means/init.rs

use super::algorithm::{update_cluster_memberships, update_min_dists};
use linfa::Float;
use linfa_nn::distance::Distance;
use ndarray::parallel::prelude::*;
use ndarray::{s, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix2};
use ndarray_rand::rand::distributions::{Distribution, WeightedIndex};
use ndarray_rand::rand::Rng;
use ndarray_rand::rand::{self, SeedableRng};
use rand_xoshiro::Xoshiro256Plus;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use std::sync::atomic::{AtomicU64, Ordering::Relaxed};

#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
/// Specifies the centroid initialization algorithm for KMeans.
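///
/// # Example (sketch)
///
/// A minimal sketch of picking an initializer; it assumes the `KMeans` builder API
/// exposed by this crate (`KMeans::params` plus `KMeansParams::init_method`):
///
/// ```ignore
/// use linfa_clustering::{KMeans, KMeansInit};
///
/// // K-means|| keeps initialization fast when the number of clusters is large.
/// let params = KMeans::params(300).init_method(KMeansInit::KMeansPara);
/// ```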
pub enum KMeansInit<F: Float> {
    /// Pick random points as centroids.
    Random,
    /// Precomputed list of centroids, represented as an array of shape
    /// (n_centroids, n_features).
    Precomputed(Array2<F>),
    /// K-means++ algorithm. Using this over random initialization causes K-means to converge
    /// faster in almost all cases, since K-means++ produces better centroids.
    KMeansPlusPlus,
    /// K-means|| algorithm, a parallelized version of K-means++. It performs much better than
    /// K-means++ when the number of clusters is large (>100) while producing similar centroids,
    /// so use it for larger datasets. Details on the algorithm can be found
    /// [here](http://vldb.org/pvldb/vol5/p622_bahmanbahmani_vldb2012.pdf).
    KMeansPara,
}

#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
/// Specifies the algorithm used for the KMeans assignment step.
///
/// Both variants minimise the same objective and, given identical initial centroids,
/// converge to the same result. They only differ in how the assignment step is computed.
/// Select a variant via [`KMeansParams::algorithm`](crate::KMeansParams::algorithm).
///
/// This setting only applies to batch `fit`. The incremental Mini-Batch K-means path
/// (`fit_with`) always uses Lloyd's update, and configuring `Hamerly` alongside
/// `fit_with` is rejected with
/// [`KMeansParamsError::IncrementalHamerly`](crate::KMeansParamsError::IncrementalHamerly).
pub enum KMeansAlgorithm {
    /// Standard Lloyd's algorithm (also known as the "naive" algorithm).
    ///
    /// On every iteration, computes the distance from each observation to every centroid
    /// to determine the closest one. Simple and predictable; work per iteration is
    /// `O(n_observations * n_clusters * n_features)`.
    ///
    /// Default variant. Works with any [`Distance`](linfa_nn::distance::Distance).
    Lloyd,
    /// Hamerly's accelerated algorithm.
    ///
    /// Uses the triangle inequality together with per-observation upper/lower distance
    /// bounds to skip most distance computations once the algorithm has stabilised.
    /// Produces the same result as Lloyd's algorithm given the same initial centroids,
    /// and is typically substantially faster for well-separated clusters with a moderate
    /// number of centroids. For heavily overlapping clusters or very large `n_clusters`
    /// the bookkeeping overhead can make Lloyd a better choice.
    ///
    /// Because the bounds rely on the triangle inequality, the supplied distance
    /// function must be a true metric. `L2Dist`, `L1Dist` and `LInfDist` satisfy this.
    ///
    /// Only supported in batch `fit`; not available for Mini-Batch `fit_with`.
    ///
    /// Reference: <https://cs.baylor.edu/~hamerly/papers/sdm_2010.pdf>
    Hamerly,
}

impl<F: Float> KMeansInit<F> {
    /// Runs the chosen initialization routine.
    pub(crate) fn run<R: Rng, D: Distance<F>>(
        &self,
        dist_fn: &D,
        n_clusters: usize,
        observations: ArrayView2<F>,
        rng: &mut R,
    ) -> Array2<F> {
        match self {
            Self::Random => random_init(n_clusters, observations, rng),
            Self::KMeansPlusPlus => k_means_plusplus(dist_fn, n_clusters, observations, rng),
            Self::KMeansPara => k_means_para(dist_fn, n_clusters, observations, rng),
            Self::Precomputed(centroids) => {
                // Check centroid dimensions
                assert_eq!(centroids.nrows(), n_clusters);
                assert_eq!(centroids.ncols(), observations.ncols());
                centroids.clone()
            }
        }
    }
}

/// Pick random points from the input matrix as centroids.
fn random_init<F: Float>(
    n_clusters: usize,
    observations: ArrayView2<F>,
    rng: &mut impl Rng,
) -> Array2<F> {
    let (n_samples, _) = observations.dim();
    let indices = rand::seq::index::sample(rng, n_samples, n_clusters).into_vec();
    observations.select(Axis(0), &indices)
}

/// Selects centroids using the KMeans++ initialization algorithm. The weights determine the
/// likelihood of an input point being selected as a centroid relative to other points. The
/// higher the weight, the more likely the point is to be selected as a centroid.
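/// Concretely, the probability of selecting observation `x` as the next centroid is
/// proportional to `weights[x] * min_dist(x, centroids_so_far)`, as computed in the
/// loop body below.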
fn weighted_k_means_plusplus<F: Float, D: Distance<F>>(
    dist_fn: &D,
    n_clusters: usize,
    observations: ArrayView2<F>,
    weights: ArrayView1<F>,
    rng: &mut impl Rng,
) -> Array2<F> {
    let (n_samples, n_features) = observations.dim();
    assert_eq!(n_samples, weights.len());
    assert_ne!(weights.sum(), F::zero());

    let mut centroids = Array2::zeros((n_clusters, n_features));
    // Select the 1st centroid from the input at random, based purely on the weights.
    let first_idx = WeightedIndex::new(weights.iter())
        .expect("invalid weights")
        .sample(rng);
    centroids.row_mut(0).assign(&observations.row(first_idx));

    let mut dists = Array1::zeros(n_samples);
    for c_cnt in 1..n_clusters {
        update_min_dists(
            dist_fn,
            &centroids.slice(s![0..c_cnt, ..]),
            &observations,
            &mut dists,
        );

        // The probability of a point being selected as the next centroid is proportional to its
        // distance from its closest centroid multiplied by its weight.
        dists *= &weights;
        let centroid_idx = WeightedIndex::new(dists.iter())
            .map(|idx| idx.sample(rng))
            // This only errors if all of `dists` is 0, which means every point already coincides
            // with a centroid, so the extra centroids don't matter and can be any index.
            .unwrap_or(0);
        centroids
            .row_mut(c_cnt)
            .assign(&observations.row(centroid_idx));
    }
    centroids
}

/// KMeans++ initialization algorithm without biased weights.
fn k_means_plusplus<F: Float, D: Distance<F>>(
    dist_fn: &D,
    n_clusters: usize,
    observations: ArrayView2<F>,
    rng: &mut impl Rng,
) -> Array2<F> {
    weighted_k_means_plusplus(
        dist_fn,
        n_clusters,
        observations,
        Array1::ones(observations.nrows()).view(),
        rng,
    )
}

/// KMeans|| initialization algorithm.
/// In each iteration, pick some new "candidate centroids" by sampling each input point in
/// parallel, with the same selection probability as in KMeans++. After multiple rounds, run
/// weighted KMeans++ on the candidates to produce the final set of centroids.
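/// With the parameters below (8 rounds, `n_clusters` candidates per round), at most
/// `8 * n_clusters` candidates are collected; e.g. 24 candidates for 3 clusters.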
fn k_means_para<R: Rng, F: Float, D: Distance<F>>(
    dist_fn: &D,
    n_clusters: usize,
    observations: ArrayView2<F>,
    rng: &mut R,
) -> Array2<F> {
    // The product of these parameters must exceed n_clusters. The higher they are, the more
    // candidates are selected, which improves the quality of the centroids but increases running
    // time. The values provided here are "sweet spots" suggested by the paper.
    let n_rounds = 8;
    let candidates_per_round = n_clusters;

    let (n_samples, n_features) = observations.dim();
    let mut candidates = Array2::zeros((n_clusters * n_rounds, n_features));

    // Pick the 1st centroid randomly
    let first_idx = rng.gen_range(0..n_samples);
    candidates.row_mut(0).assign(&observations.row(first_idx));
    let mut n_candidates = 1;

    let mut dists = Array1::zeros(n_samples);
    'outer: for _ in 0..n_rounds {
        let current_candidates = candidates.slice(s![0..n_candidates, ..]);
        update_min_dists(dist_fn, &current_candidates, &observations, &mut dists);
        // Generate the next set of candidates from the input points, using the same probability
        // formula as KMeans++. On average this generates a number of candidates equal to
        // `candidates_per_round`.
        let next_candidates_idx = sample_subsequent_candidates::<R, _>(
            &dists,
            F::cast(candidates_per_round),
            rng.gen_range(0..u64::MAX),
        );

        // Append the newly generated candidates to the current candidates, breaking out of the
        // loop if too many candidates have been found.
        for idx in next_candidates_idx.into_iter() {
            candidates
                .row_mut(n_candidates)
                .assign(&observations.row(idx));
            n_candidates += 1;
            if n_candidates >= candidates.nrows() {
                break 'outer;
            }
        }
    }

    let final_candidates = candidates.slice(s![0..n_candidates, ..]);
    // Weigh the candidate centroids by the sizes of the clusters they form in the input points.
    let weights = cluster_membership_counts(dist_fn, &final_candidates, &observations);

    // The number of candidates is almost certainly higher than the number of centroids, so we
    // recluster the candidates into the right number of centroids using weighted KMeans++.
    weighted_k_means_plusplus(dist_fn, n_clusters, final_candidates, weights.view(), rng)
}

/// Generate candidate centroids by sampling each observation in parallel, using a seedable RNG
/// in every thread. The average number of generated candidates should equal `multiplier`.
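/// Each observation `i` is kept independently with probability
/// `multiplier * dists[i] / dists.sum()` (a value of 1 or more always keeps the point).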
#[allow(clippy::extra_unused_type_parameters)]
fn sample_subsequent_candidates<R: Rng, F: Float>(
    dists: &Array1<F>,
    multiplier: F,
    seed: u64,
) -> Vec<usize> {
    // This sum can also be parallelized
    let cost = dists.sum();
    // Using an atomic allows the seed to be modified while seeding RNGs in parallel
    let seed = AtomicU64::new(seed);

    // Use `map_init` to generate a unique RNG for each Rayon thread, allowing both RNG creation
    // and random number generation to be parallelized. Alternative approaches included generating
    // an RNG for every observation and sequentially taking `multiplier` samples from a weighted
    // index of `dists`. Generating for every observation was too slow, and the sequential approach
    // yielded lower-quality centroids, so this approach was chosen. See PR #108 for more details.
    dists
        .axis_iter(Axis(0))
        .into_par_iter()
        .enumerate()
        .map_init(
            // XXX we can use `jump` to have differently seeded RNGs instead of reseeding each time
            || Xoshiro256Plus::seed_from_u64(seed.fetch_add(1, Relaxed)),
            move |rng, (i, d)| {
                let d = *d.into_scalar();
                let rand = F::cast(rng.gen_range(0.0..1.0));
                let prob = multiplier * d / cost;
                (i, rand, prob)
            },
        )
        .filter_map(|(i, rand, prob)| if rand < prob { Some(i) } else { None })
        .collect()
}

/// Returns the number of observation points that belong to each cluster.
fn cluster_membership_counts<F: Float, D: Distance<F>>(
    dist_fn: &D,
    centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
    observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
) -> Array1<F> {
    let n_samples = observations.nrows();
    let n_clusters = centroids.nrows();
    let mut memberships = Array1::zeros(n_samples);
    update_cluster_memberships(dist_fn, centroids, observations, &mut memberships);
    let mut counts = Array1::zeros(n_clusters);
    memberships.iter().for_each(|&c| counts[c] += F::one());
    counts
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::{abs_diff_eq, assert_abs_diff_eq, assert_abs_diff_ne};
    use linfa_nn::distance::{L1Dist, L2Dist};
    use ndarray::{array, concatenate, Array};
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::Normal;
    use ndarray_rand::RandomExt;
    use rand_xoshiro::Xoshiro256Plus;
    use std::collections::HashSet;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<KMeansInit<f64>>();
    }

    #[test]
    fn test_precomputed() {
        let mut rng = Xoshiro256Plus::seed_from_u64(40);
        let centroids = array![[0.0, 1.0], [40.0, 10.0]];
        let observations = array![[3.0, 4.0], [1.0, 3.0], [25.0, 15.0]];
        let c = KMeansInit::Precomputed(centroids.clone()).run(
            &L2Dist,
            2,
            observations.view(),
            &mut rng,
        );
        assert_abs_diff_eq!(c, centroids);
    }

    #[test]
    fn test_sample_subsequent_candidates() {
        let dists = array![0.0, 0.4, 0.5];
        let candidates = sample_subsequent_candidates::<Xoshiro256Plus, _>(&dists, 8.0, 0);
        assert_eq!(candidates, vec![1, 2]);
    }

    #[test]
    fn test_cluster_membership_counts() {
        let centroids = array![[0.0, 1.0], [40.0, 10.0], [3.0, 9.0]];
        let observations = array![[3.0, 4.0], [1.0, 3.0], [25.0, 15.0]];

        let counts = cluster_membership_counts(&L2Dist, &centroids, &observations);
        assert_abs_diff_eq!(counts, array![2.0, 1.0, 0.0]);
        let counts = cluster_membership_counts(&L1Dist, &centroids, &observations);
        assert_abs_diff_eq!(counts, array![1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_weighted_kmeans_plusplus() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let obs = Array::random_using((1000, 2), Normal::new(0.0, 100.).unwrap(), &mut rng);
        let mut weights = Array1::zeros(1000);
        weights[0] = 2.0;
        weights[1] = 3.0;
        let out = weighted_k_means_plusplus(&L2Dist, 2, obs.view(), weights.view(), &mut rng);
        let mut expected_centroids = {
            let mut arr = Array2::zeros((2, 2));
            arr.row_mut(0).assign(&obs.row(0));
            arr.row_mut(1).assign(&obs.row(1));
            arr
        };
        assert!(
            abs_diff_eq!(out, expected_centroids) || {
                expected_centroids.invert_axis(Axis(0));
                abs_diff_eq!(out, expected_centroids)
            }
        );
    }

    #[test]
    fn test_k_means_plusplus() {
        verify_init(KMeansInit::KMeansPlusPlus, L2Dist);
        verify_init(KMeansInit::KMeansPlusPlus, L1Dist);
    }

    #[test]
    fn test_k_means_para() {
        verify_init(KMeansInit::KMeansPara, L2Dist);
        verify_init(KMeansInit::KMeansPara, L1Dist);
    }

    // Run general tests for a given init algorithm
    fn verify_init<D: Distance<f64>>(init: KMeansInit<f64>, dist_fn: D) {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        // Make sure we don't panic on degenerate data (n_clusters > n_samples)
        let degenerate_data = array![[1.0, 2.0]];
        let out = init.run(&dist_fn, 2, degenerate_data.view(), &mut rng);
        assert_abs_diff_eq!(out, concatenate![Axis(0), degenerate_data, degenerate_data]);

        // Build 3 separated clusters of points
        let centroids = [20.0, -1000.0, 1000.0];
        let clusters: Vec<Array2<_>> = centroids
            .iter()
            .map(|&c| Array::random_using((50, 2), Normal::new(c, 1.).unwrap(), &mut rng))
            .collect();
        let obs = clusters.iter().fold(Array2::default((0, 2)), |a, b| {
            concatenate(Axis(0), &[a.view(), b.view()]).unwrap()
        });

        // Look for the right number of centroids
        let out = init.run(&dist_fn, centroids.len(), obs.view(), &mut rng);
        let mut cluster_ids = HashSet::new();
        for row in out.rows() {
            // Centroid should not be 0
            assert_abs_diff_ne!(row, Array1::zeros(row.len()), epsilon = 1e-1);
            // Find the resultant centroid in 1 of the 3 clusters
            let found = clusters
                .iter()
                .enumerate()
                .find_map(|(i, c)| {
                    if c.rows().into_iter().any(|cl| abs_diff_eq!(row, cl)) {
                        Some(i)
                    } else {
                        None
                    }
                })
                .unwrap();
            cluster_ids.insert(found);
        }
        // Centroids should almost always span all 3 clusters
        assert_eq!(cluster_ids, [0, 1, 2].iter().copied().collect());
    }

    macro_rules! calc_loss {
        ($dist_fn:expr, $centroids:expr, $observations:expr) => {{
            let mut dists = Array1::zeros($observations.nrows());
            update_min_dists(&$dist_fn, &$centroids, &$observations, &mut dists);
            dists.sum()
        }};
    }

    fn test_compare<D: Distance<f64>>(dist_fn: D) {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let centroids = [20.0, -1000.0, 1000.0];
        let clusters: Vec<Array2<_>> = centroids
            .iter()
            .map(|&c| Array::random_using((50, 2), Normal::new(c, 1.).unwrap(), &mut rng))
            .collect();
        let obs = clusters.iter().fold(Array2::default((0, 2)), |a, b| {
            concatenate(Axis(0), &[a.view(), b.view()]).unwrap()
        });

        let out_rand = random_init(3, obs.view(), &mut rng.clone());
        let out_pp = k_means_plusplus(&dist_fn, 3, obs.view(), &mut rng.clone());
        let out_para = k_means_para(&dist_fn, 3, obs.view(), &mut rng);
        // Loss of KMeans++ should be better than that of random_init
        assert!(calc_loss!(dist_fn, out_pp, obs) < calc_loss!(dist_fn, out_rand, obs));
        // Loss of KMeans|| should be better than that of random_init
        assert!(calc_loss!(dist_fn, out_para, obs) < calc_loss!(dist_fn, out_rand, obs));
    }

    #[test]
    fn test_compare_l2() {
        test_compare(L2Dist);
    }

    #[test]
    fn test_compare_l1() {
        test_compare(L1Dist);
    }
}