linfa_clustering/k_means/init.rs

use super::algorithm::{update_cluster_memberships, update_min_dists};
use linfa::Float;
use linfa_nn::distance::Distance;
use ndarray::parallel::prelude::*;
use ndarray::{s, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix2};
use ndarray_rand::rand::distributions::{Distribution, WeightedIndex};
use ndarray_rand::rand::Rng;
use ndarray_rand::rand::{self, SeedableRng};
use rand_xoshiro::Xoshiro256Plus;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use std::sync::atomic::{AtomicU64, Ordering::Relaxed};

#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
/// Specifies the centroid initialization algorithm for KMeans.
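///
/// A minimal usage sketch. Note that `KMeans::params_with_rng` and `init_method` are assumed
/// from the wider `linfa_clustering` API rather than defined in this module:
///
/// ```ignore
/// use linfa::traits::Fit;
/// use linfa_clustering::{KMeans, KMeansInit};
///
/// // Request 3 clusters and seed them with the parallel K-means|| routine.
/// let model = KMeans::params_with_rng(3, rng)
///     .init_method(KMeansInit::KMeansPara)
///     .fit(&dataset)?;
/// ```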
pub enum KMeansInit<F: Float> {
    /// Pick random points as centroids.
    Random,
    /// Precomputed list of centroids, represented as an array of shape (n_centroids, n_features).
    Precomputed(Array2<F>),
    /// K-means++ algorithm. Using this over random initialization causes K-means to converge
    /// faster in almost all cases, since K-means++ produces better centroids.
    KMeansPlusPlus,
    /// K-means|| algorithm, a parallelized version of K-means++. Performs much better than
    /// K-means++ when the number of clusters is large (>100) while producing similar centroids, so
    /// use this for larger datasets. Details on the algorithm can be found
    /// [here](http://vldb.org/pvldb/vol5/p622_bahmanbahmani_vldb2012.pdf).
    KMeansPara,
}

impl<F: Float> KMeansInit<F> {
    /// Runs the chosen initialization routine
    pub(crate) fn run<R: Rng, D: Distance<F>>(
        &self,
        dist_fn: &D,
        n_clusters: usize,
        observations: ArrayView2<F>,
        rng: &mut R,
    ) -> Array2<F> {
        match self {
            Self::Random => random_init(n_clusters, observations, rng),
            Self::KMeansPlusPlus => k_means_plusplus(dist_fn, n_clusters, observations, rng),
            Self::KMeansPara => k_means_para(dist_fn, n_clusters, observations, rng),
            Self::Precomputed(centroids) => {
                // Check centroid dimensions
                assert_eq!(centroids.nrows(), n_clusters);
                assert_eq!(centroids.ncols(), observations.ncols());
                centroids.clone()
            }
        }
    }
}

/// Pick random points from the input matrix as centroids
fn random_init<F: Float>(
    n_clusters: usize,
    observations: ArrayView2<F>,
    rng: &mut impl Rng,
) -> Array2<F> {
    let (n_samples, _) = observations.dim();
    let indices = rand::seq::index::sample(rng, n_samples, n_clusters).into_vec();
    observations.select(Axis(0), &indices)
}

/// Selects centroids using the KMeans++ initialization algorithm. The weights determine the
/// likelihood of an input point being selected as a centroid relative to other points. The higher
/// the weight, the more likely the point will be selected as a centroid.
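///
/// As an illustrative example of the selection rule (the numbers are made up): with current
/// min-distances `[0.0, 4.0, 1.0]` and weights `[1.0, 1.0, 2.0]`, the sampling weights for the
/// next centroid become `[0.0, 4.0, 2.0]`, so point 1 is chosen with probability 4/6, point 2
/// with probability 2/6, and point 0 (which coincides with an existing centroid) never.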
fn weighted_k_means_plusplus<F: Float, D: Distance<F>>(
    dist_fn: &D,
    n_clusters: usize,
    observations: ArrayView2<F>,
    weights: ArrayView1<F>,
    rng: &mut impl Rng,
) -> Array2<F> {
    let (n_samples, n_features) = observations.dim();
    assert_eq!(n_samples, weights.len());
    assert_ne!(weights.sum(), F::zero());

    let mut centroids = Array2::zeros((n_clusters, n_features));
    // Select the 1st centroid from the input at random, based purely on the weights.
    let first_idx = WeightedIndex::new(weights.iter())
        .expect("invalid weights")
        .sample(rng);
    centroids.row_mut(0).assign(&observations.row(first_idx));

    let mut dists = Array1::zeros(n_samples);
    for c_cnt in 1..n_clusters {
        update_min_dists(
            dist_fn,
            &centroids.slice(s![0..c_cnt, ..]),
            &observations,
            &mut dists,
        );

        // The probability of a point being selected as the next centroid is proportional to its
        // distance from its closest centroid multiplied by its weight.
        dists *= &weights;
        let centroid_idx = WeightedIndex::new(dists.iter())
            .map(|idx| idx.sample(rng))
            // This only errs if all of `dists` is 0, which means every point is assigned to a
            // centroid, so extra centroids don't matter and can be any index.
            .unwrap_or(0);
        centroids
            .row_mut(c_cnt)
            .assign(&observations.row(centroid_idx));
    }
    centroids
}

/// KMeans++ initialization algorithm without biased weights
fn k_means_plusplus<F: Float, D: Distance<F>>(
    dist_fn: &D,
    n_clusters: usize,
    observations: ArrayView2<F>,
    rng: &mut impl Rng,
) -> Array2<F> {
    weighted_k_means_plusplus(
        dist_fn,
        n_clusters,
        observations,
        Array1::ones(observations.nrows()).view(),
        rng,
    )
}

/// KMeans|| initialization algorithm.
///
/// In each iteration, pick some new "candidate centroids" by sampling the probabilities of each
/// input point in parallel. The probability of a point becoming a centroid is the same as with
/// KMeans++. After multiple iterations, run weighted KMeans++ on the candidates to produce the
/// final set of centroids.
fn k_means_para<R: Rng, F: Float, D: Distance<F>>(
    dist_fn: &D,
    n_clusters: usize,
    observations: ArrayView2<F>,
    rng: &mut R,
) -> Array2<F> {
    // The product of these parameters must exceed n_clusters. The higher they are, the more
    // candidates are selected, which improves the quality of the centroids but increases running
    // time. The values provided here are "sweet spots" suggested by the paper.
    let n_rounds = 8;
    let candidates_per_round = n_clusters;
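    // With n_clusters = 100, for example, this yields up to 8 * 100 = 800 candidates, well
    // above the 100 centroids ultimately required.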

    let (n_samples, n_features) = observations.dim();
    let mut candidates = Array2::zeros((n_clusters * n_rounds, n_features));

    // Pick the 1st centroid randomly
    let first_idx = rng.gen_range(0..n_samples);
    candidates.row_mut(0).assign(&observations.row(first_idx));
    let mut n_candidates = 1;

    let mut dists = Array1::zeros(n_samples);
    'outer: for _ in 0..n_rounds {
        let current_candidates = candidates.slice(s![0..n_candidates, ..]);
        update_min_dists(dist_fn, &current_candidates, &observations, &mut dists);
        // Generate the next set of candidates from the input points, using the same probability
        // formula as KMeans++. On average this generates candidates equal to
        // `candidates_per_round`.
        let next_candidates_idx = sample_subsequent_candidates::<R, _>(
            &dists,
            F::cast(candidates_per_round),
            rng.gen_range(0..u64::MAX),
        );

        // Append the newly generated candidates to the current candidates, breaking out of the
        // loop if too many candidates have been found
        for idx in next_candidates_idx.into_iter() {
            candidates
                .row_mut(n_candidates)
                .assign(&observations.row(idx));
            n_candidates += 1;
            if n_candidates >= candidates.nrows() {
                break 'outer;
            }
        }
    }

    let final_candidates = candidates.slice(s![0..n_candidates, ..]);
    // Weigh the candidate centroids by the sizes of the clusters they form in the input points.
    let weights = cluster_membership_counts(dist_fn, &final_candidates, &observations);

    // The number of candidates is almost certainly higher than the number of centroids, so we
    // recluster the candidates into the right number of centroids using weighted KMeans++.
    weighted_k_means_plusplus(dist_fn, n_clusters, final_candidates, weights.view(), rng)
}

/// Generate candidate centroids by sampling each observation in parallel using a seedable RNG in
/// every thread. Average number of generated candidates should equal `multiplier`.
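/// Concretely, each observation `i` is kept independently with probability
/// `min(1, multiplier * dists[i] / dists.sum())`, so the expected number of candidates per call
/// is at most `multiplier`.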
#[allow(clippy::extra_unused_type_parameters)]
fn sample_subsequent_candidates<R: Rng, F: Float>(
    dists: &Array1<F>,
    multiplier: F,
    seed: u64,
) -> Vec<usize> {
    // This sum can also be parallelized
    let cost = dists.sum();
    // Using an atomic allows the seed to be modified while seeding RNGs in parallel
    let seed = AtomicU64::new(seed);

    // Use `map_init` to generate a unique RNG for each Rayon thread, allowing both RNG creation
    // and random number generation to be parallelized. Alternative approaches included generating
    // an RNG for every observation and sequentially taking `multiplier` samples from a weighted
    // index of `dists`. Generating for every observation was too slow, and the sequential approach
    // yielded lower-quality centroids, so this approach was chosen. See PR #108 for more details.
    dists
        .axis_iter(Axis(0))
        .into_par_iter()
        .enumerate()
        .map_init(
            // XXX we can use `jump` to have differently seeded RNGs instead of reseeding each time
            || Xoshiro256Plus::seed_from_u64(seed.fetch_add(1, Relaxed)),
            move |rng, (i, d)| {
                let d = *d.into_scalar();
                let rand = F::cast(rng.gen_range(0.0..1.0));
                let prob = multiplier * d / cost;
                (i, rand, prob)
            },
        )
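        // Keep point `i` iff its uniform draw fell below `multiplier * d / cost`; probabilities
        // above 1 mean the point is always kept.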
        .filter_map(|(i, rand, prob)| if rand < prob { Some(i) } else { None })
        .collect()
}

/// Returns the number of observation points that belong to each cluster.
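/// For instance, if the computed memberships come out as `[0, 0, 1]` with 3 centroids, the
/// returned counts are `[2.0, 1.0, 0.0]` (illustrative values; the unit test below exercises
/// this behaviour on real distances).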
fn cluster_membership_counts<F: Float, D: Distance<F>>(
    dist_fn: &D,
    centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
    observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
) -> Array1<F> {
    let n_samples = observations.nrows();
    let n_clusters = centroids.nrows();
    let mut memberships = Array1::zeros(n_samples);
    update_cluster_memberships(dist_fn, centroids, observations, &mut memberships);
    let mut counts = Array1::zeros(n_clusters);
    memberships.iter().for_each(|&c| counts[c] += F::one());
    counts
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::{abs_diff_eq, assert_abs_diff_eq, assert_abs_diff_ne};
    use linfa_nn::distance::{L1Dist, L2Dist};
    use ndarray::{array, concatenate, Array};
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::Normal;
    use ndarray_rand::RandomExt;
    use rand_xoshiro::Xoshiro256Plus;
    use std::collections::HashSet;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<KMeansInit<f64>>();
    }

    #[test]
    fn test_precomputed() {
        let mut rng = Xoshiro256Plus::seed_from_u64(40);
        let centroids = array![[0.0, 1.0], [40.0, 10.0]];
        let observations = array![[3.0, 4.0], [1.0, 3.0], [25.0, 15.0]];
        let c = KMeansInit::Precomputed(centroids.clone()).run(
            &L2Dist,
            2,
            observations.view(),
            &mut rng,
        );
        assert_abs_diff_eq!(c, centroids);
    }

    #[test]
    fn test_sample_subsequent_candidates() {
        let dists = array![0.0, 0.4, 0.5];
        let candidates = sample_subsequent_candidates::<Xoshiro256Plus, _>(&dists, 8.0, 0);
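        // With cost = 0.9, points 1 and 2 get selection probabilities 8 * 0.4 / 0.9 ≈ 3.6 and
        // 8 * 0.5 / 0.9 ≈ 4.4, both above 1, so they are always kept, while point 0 has
        // probability 0. The outcome is therefore deterministic regardless of the seed.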
        assert_eq!(candidates, vec![1, 2]);
    }

    #[test]
    fn test_cluster_membership_counts() {
        let centroids = array![[0.0, 1.0], [40.0, 10.0], [3.0, 9.0]];
        let observations = array![[3.0, 4.0], [1.0, 3.0], [25.0, 15.0]];

        let counts = cluster_membership_counts(&L2Dist, &centroids, &observations);
        assert_abs_diff_eq!(counts, array![2.0, 1.0, 0.0]);
        let counts = cluster_membership_counts(&L1Dist, &centroids, &observations);
        assert_abs_diff_eq!(counts, array![1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_weighted_kmeans_plusplus() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let obs = Array::random_using((1000, 2), Normal::new(0.0, 100.).unwrap(), &mut rng);
        let mut weights = Array1::zeros(1000);
        weights[0] = 2.0;
        weights[1] = 3.0;
        let out = weighted_k_means_plusplus(&L2Dist, 2, obs.view(), weights.view(), &mut rng);
        let mut expected_centroids = {
            let mut arr = Array2::zeros((2, 2));
            arr.row_mut(0).assign(&obs.row(0));
            arr.row_mut(1).assign(&obs.row(1));
            arr
        };
        assert!(
            abs_diff_eq!(out, expected_centroids) || {
                expected_centroids.invert_axis(Axis(0));
                abs_diff_eq!(out, expected_centroids)
            }
        );
    }

    #[test]
    fn test_k_means_plusplus() {
        verify_init(KMeansInit::KMeansPlusPlus, L2Dist);
        verify_init(KMeansInit::KMeansPlusPlus, L1Dist);
    }

    #[test]
    fn test_k_means_para() {
        verify_init(KMeansInit::KMeansPara, L2Dist);
        verify_init(KMeansInit::KMeansPara, L1Dist);
    }

    // Run general tests for a given init algorithm
    fn verify_init<D: Distance<f64>>(init: KMeansInit<f64>, dist_fn: D) {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        // Make sure we don't panic on degenerate data (n_clusters > n_samples)
        let degenerate_data = array![[1.0, 2.0]];
        let out = init.run(&dist_fn, 2, degenerate_data.view(), &mut rng);
        assert_abs_diff_eq!(out, concatenate![Axis(0), degenerate_data, degenerate_data]);

        // Build 3 well-separated clusters of points
        let centroids = [20.0, -1000.0, 1000.0];
        let clusters: Vec<Array2<_>> = centroids
            .iter()
            .map(|&c| Array::random_using((50, 2), Normal::new(c, 1.).unwrap(), &mut rng))
            .collect();
        let obs = clusters.iter().fold(Array2::default((0, 2)), |a, b| {
            concatenate(Axis(0), &[a.view(), b.view()]).unwrap()
        });

        // Ask for the right number of centroids
        let out = init.run(&dist_fn, centroids.len(), obs.view(), &mut rng);
        let mut cluster_ids = HashSet::new();
        for row in out.rows() {
            // Centroid should not be 0
            assert_abs_diff_ne!(row, Array1::zeros(row.len()), epsilon = 1e-1);
            // Find which of the 3 clusters the resulting centroid belongs to
            let found = clusters
                .iter()
                .enumerate()
                .find_map(|(i, c)| {
                    if c.rows().into_iter().any(|cl| abs_diff_eq!(row, cl)) {
                        Some(i)
                    } else {
                        None
                    }
                })
                .unwrap();
            cluster_ids.insert(found);
        }
        // The centroids should almost always span all 3 clusters
        assert_eq!(cluster_ids, [0, 1, 2].iter().copied().collect());
    }

    macro_rules! calc_loss {
        ($dist_fn:expr, $centroids:expr, $observations:expr) => {{
            let mut dists = Array1::zeros($observations.nrows());
            update_min_dists(&$dist_fn, &$centroids, &$observations, &mut dists);
            dists.sum()
        }};
    }

    fn test_compare<D: Distance<f64>>(dist_fn: D) {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let centroids = [20.0, -1000.0, 1000.0];
        let clusters: Vec<Array2<_>> = centroids
            .iter()
            .map(|&c| Array::random_using((50, 2), Normal::new(c, 1.).unwrap(), &mut rng))
            .collect();
        let obs = clusters.iter().fold(Array2::default((0, 2)), |a, b| {
            concatenate(Axis(0), &[a.view(), b.view()]).unwrap()
        });

        let out_rand = random_init(3, obs.view(), &mut rng.clone());
        let out_pp = k_means_plusplus(&dist_fn, 3, obs.view(), &mut rng.clone());
        let out_para = k_means_para(&dist_fn, 3, obs.view(), &mut rng);
        // The loss of KMeans++ should be lower than that of random_init
        assert!(calc_loss!(dist_fn, out_pp, obs) < calc_loss!(dist_fn, out_rand, obs));
        // The loss of KMeans|| should be lower than that of random_init
        assert!(calc_loss!(dist_fn, out_para, obs) < calc_loss!(dist_fn, out_rand, obs));
    }

    #[test]
    fn test_compare_l2() {
        test_compare(L2Dist);
    }

    #[test]
    fn test_compare_l1() {
        test_compare(L1Dist);
    }
}