linfa_clustering/k_means/
algorithm.rs

use std::cmp::Ordering;
use std::fmt::Debug;

use crate::k_means::{KMeansParams, KMeansValidParams};
use crate::{k_means::errors::KMeansError, KMeansInit};
use crate::{IncrKMeansError, KMeansAlgorithm, KMeansParamsError};
use linfa::{prelude::*, DatasetBase, Float};
use linfa_nn::distance::{Distance, L2Dist};
use ndarray::{Array1, Array2, ArrayBase, ArrayView2, Axis, Data, DataMut, Ix1, Ix2, Zip};
use ndarray_rand::rand::{Rng, SeedableRng};
use rand_xoshiro::Xoshiro256Plus;

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug, PartialEq)]
/// K-means clustering aims to partition a set of unlabeled observations into clusters,
/// where each observation belongs to the cluster with the nearest mean.
///
/// The mean of the points within a cluster is called the *centroid*.
///
/// Given the set of centroids, you can assign an observation to a cluster
/// by choosing the nearest centroid.
///
/// We provide a modified version of the _standard algorithm_ (also known as Lloyd's Algorithm),
/// called m_k-means, which uses a slightly modified update step to avoid problems with empty
/// clusters. In addition to Lloyd's algorithm, we also provide Hamerly's accelerated algorithm,
/// which produces the same results but skips many distance computations using the triangle
/// inequality. We also provide an incremental version of the algorithm that runs on smaller
/// batches of input data.
///
/// More details on the algorithm can be found in the next section or
/// [here](https://en.wikipedia.org/wiki/K-means_clustering). Details on m_k-means can be found
/// [here](https://www.researchgate.net/publication/228414762_A_Modified_k-means_Algorithm_to_Avoid_Empty_Clusters).
/// Details on Hamerly's algorithm can be found [here](https://cs.baylor.edu/~hamerly/papers/sdm_2010.pdf).
///
/// ## Standard algorithm
///
/// K-means is an iterative algorithm: it progressively refines the choice of centroids.
///
/// It's guaranteed to converge, although it might not find the optimal set of centroids:
/// it can get stuck in a local minimum, and finding the global minimum is NP-hard.
///
/// There are three steps in the standard algorithm:
/// - initialisation step: select initial centroids using one of our provided algorithms;
/// - assignment step: assign each observation to the nearest cluster
///   (minimum distance between the observation and the cluster's centroid);
/// - update step: recompute the centroid of each cluster.
///
/// The initialisation step is a one-off, done at the very beginning.
/// Assignment and update are repeated in a loop until convergence is reached (either the
/// distance between the old and the new centroids, Euclidean for the default `L2Dist`, is
/// below `tolerance` or we reach `max_n_iterations`).
///
/// ## Hamerly's algorithm
///
/// Hamerly's algorithm is an exact accelerated variant of Lloyd's algorithm: given the same
/// initial centroids it converges to the same final centroids, but usually with a fraction of
/// the distance computations. For every observation it maintains an upper bound on the distance
/// to its currently assigned centroid and a lower bound on the distance to the closest
/// non-assigned centroid. At each iteration, these bounds together with the inter-centroid
/// distances are used to cheaply prove that an observation cannot have changed cluster, in
/// which case the exact distance is not recomputed at all.
///
/// Hamerly is typically faster than Lloyd when clusters are reasonably well separated and the
/// number of clusters is moderate; when clusters overlap heavily or `n_clusters` is very large,
/// the bookkeeping overhead can outweigh the savings. Hamerly requires a true metric for its
/// triangle-inequality bounds to hold, so any custom distance function used with it must satisfy
/// the metric axioms (`L2Dist`, `L1Dist` and `LInfDist` all qualify).
///
/// The algorithm variant is selected on [`KMeansParams`](crate::KMeansParams) via
/// [`algorithm`](crate::KMeansParams::algorithm) with a [`KMeansAlgorithm`](crate::KMeansAlgorithm)
/// value. Lloyd is the default; pass `KMeansAlgorithm::Hamerly` to opt in. Hamerly is only
/// available for the standard batch `fit`: the incremental `fit_with` path supports Lloyd only
/// and returns an error if Hamerly is configured.
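///
/// A minimal sketch of opting in, assuming `algorithm` is a chainable builder method like
/// `tolerance`; the four observations are made up for illustration:
///
/// ```
/// use linfa::traits::Fit;
/// use linfa::DatasetBase;
/// use linfa_clustering::{KMeans, KMeansAlgorithm};
/// use ndarray::array;
///
/// // Two tight, well-separated pairs of points (illustrative data only)
/// let observations = DatasetBase::from(array![[0., 0.], [0.1, 0.2], [10., 10.], [10.2, 9.9]]);
/// // Assumption: `algorithm` takes the `KMeansAlgorithm` value by value
/// let model = KMeans::params(2)
///     .algorithm(KMeansAlgorithm::Hamerly)
///     .fit(&observations)
///     .expect("KMeans fitted");
/// // One centroid per requested cluster, one column per feature
/// assert_eq!(model.centroids().dim(), (2, 2));
/// ```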
///
/// ## Incremental Algorithm
///
/// In addition to the standard algorithm, we also provide an incremental version of K-means known
/// as Mini-Batch K-means. In this algorithm, the dataset is divided into small batches, and the
/// assignment and update steps are performed on each batch instead of the entire dataset. The
/// update step also takes previous update steps into account when updating the centroids.
///
/// Due to using smaller batches, Mini-Batch K-means takes significantly less time to execute than
/// the standard K-means algorithm, although it may yield slightly worse centroids.
///
/// More details on Mini-Batch K-means can be found [here](https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf).
///
/// ## Parallelisation
///
/// The work performed by the assignment step does not require any coordination:
/// the closest centroid for each point can be computed independently from the
/// closest centroid for any of the remaining points.
///
/// This makes it a good candidate for parallel execution: `KMeans::fit` parallelises the
/// assignment step thanks to the `rayon` feature in `ndarray`.
///
/// The update step requires a bit more coordination (computing a rolling mean in
/// parallel) but it is still parallelisable.
/// Nonetheless, our first attempts have not improved performance
/// (most likely due to our strategy used to split work between threads), hence
/// the update step is currently executed on a single thread.
///
/// ## Tutorial
///
/// Let's walk through a training and prediction example.
///
/// ```
/// use linfa::DatasetBase;
/// use linfa::traits::{Fit, FitWith, Predict};
/// use linfa_clustering::{KMeansParams, KMeans, IncrKMeansError};
/// use linfa_datasets::generate;
/// use ndarray::{Axis, array, s};
/// use ndarray_rand::rand::SeedableRng;
/// use rand_xoshiro::Xoshiro256Plus;
/// use approx::assert_abs_diff_eq;
///
/// // Our random number generator, seeded for reproducibility
/// let seed = 42;
/// let mut rng = Xoshiro256Plus::seed_from_u64(seed);
///
/// // `expected_centroids` has shape `(n_centroids, n_features)`
/// // i.e. three points in the 2-dimensional plane
/// let expected_centroids = array![[0., 1.], [-10., 20.], [-1., 10.]];
/// // Let's generate a synthetic dataset: three blobs of observations
/// // (100 points each) centered around our `expected_centroids`
/// let data = generate::blobs(100, &expected_centroids, &mut rng);
/// let n_clusters = expected_centroids.len_of(Axis(0));
///
/// // Standard K-means
/// {
///     let observations = DatasetBase::from(data.clone());
///     // Let's configure and run our K-means algorithm
///     // We use the builder pattern to specify the hyperparameters
///     // `n_clusters` is the only mandatory parameter.
///     // If you don't specify the others (e.g. `n_runs`, `tolerance`, `max_n_iterations`)
///     // default values will be used.
///     let model = KMeans::params_with_rng(n_clusters, rng.clone())
///         .tolerance(1e-2)
///         .fit(&observations)
///         .expect("KMeans fitted");
///
///     // Once we have found our set of centroids, we can also assign new points to the
///     // nearest cluster
///     let new_observation = DatasetBase::from(array![[-9., 20.5]]);
///     // Predict returns the **index** of the nearest cluster
///     let dataset = model.predict(new_observation);
///     // We can retrieve the actual centroid of the closest cluster using `.centroids()`
///     let closest_centroid = &model.centroids().index_axis(Axis(0), dataset.targets()[0]);
///     assert_abs_diff_eq!(closest_centroid.to_owned(), &array![-10., 20.], epsilon = 1e-1);
/// }
///
/// // Incremental K-means
/// {
///     let batch_size = 100;
///     // Shuffling the dataset is one way of ensuring that the batches contain random points from
///     // the dataset, which is required for the algorithm to work properly
///     let observations = DatasetBase::from(data.clone()).shuffle(&mut rng);
///
///     let n_clusters = expected_centroids.nrows();
///     let clf = KMeans::params_with_rng(n_clusters, rng.clone()).tolerance(1e-3);
///
///     // Repeatedly run fit_with on every batch in the dataset until we have converged
///     let model = observations
///         .sample_chunks(batch_size)
///         .cycle()
///         .try_fold(None, |current, batch| {
///             match clf.fit_with(current, &batch) {
///                 // Early stop condition for the kmeans loop
///                 Ok(model) => Err(model),
///                 // Continue running if not converged
///                 Err(IncrKMeansError::NotConverged(model)) => Ok(Some(model)),
///                 Err(err) => panic!("unexpected kmeans error: {}", err),
///             }
///         })
///         .unwrap_err();
///
///     let new_observation = DatasetBase::from(array![[-9., 20.5]]);
///     let dataset = model.predict(new_observation);
///     let closest_centroid = &model.centroids().index_axis(Axis(0), dataset.targets()[0]);
///     assert_abs_diff_eq!(closest_centroid.to_owned(), &array![-10., 20.], epsilon = 1e-1);
/// }
/// ```
///
/*///
/// // The model can be serialised (and deserialised) to disk using serde
/// // We'll use the JSON format here for simplicity
/// let filename = "k_means_model.json";
/// let writer = std::fs::File::create(filename).expect("Failed to open file.");
/// serde_json::to_writer(writer, &model).expect("Failed to serialise model.");
///
/// let reader = std::fs::File::open(filename).expect("Failed to open file.");
/// let loaded_model: KMeans<f64> = serde_json::from_reader(reader).expect("Failed to deserialise model");
///
/// assert_abs_diff_eq!(model.centroids(), loaded_model.centroids(), epsilon = 1e-10);
/// assert_eq!(model.hyperparameters(), loaded_model.hyperparameters());
/// ```
*/
pub struct KMeans<F: Float, D: Distance<F>> {
    centroids: Array2<F>,
    cluster_count: Array1<F>,
    inertia: F,
    dist_fn: D,
}

impl<F: Float> KMeans<F, L2Dist> {
    /// Create hyperparameters for `nclusters` clusters, using the L2 distance and a
    /// `Xoshiro256Plus` random number generator seeded with 42.
    pub fn params(nclusters: usize) -> KMeansParams<F, Xoshiro256Plus, L2Dist> {
        KMeansParams::new(nclusters, Xoshiro256Plus::seed_from_u64(42), L2Dist)
    }

    /// Create hyperparameters for `nclusters` clusters, using the L2 distance and the
    /// given random number generator.
    pub fn params_with_rng<R: Rng>(nclusters: usize, rng: R) -> KMeansParams<F, R, L2Dist> {
        KMeansParams::new(nclusters, rng, L2Dist)
    }
}

impl<F: Float, D: Distance<F>> KMeans<F, D> {
    /// Create hyperparameters for `nclusters` clusters, using the given random number
    /// generator and distance function.
    pub fn params_with<R: Rng>(nclusters: usize, rng: R, dist_fn: D) -> KMeansParams<F, R, D> {
        KMeansParams::new(nclusters, rng, dist_fn)
    }

    /// Return the set of centroids as a 2-dimensional matrix with shape
    /// `(n_centroids, n_features)`.
    pub fn centroids(&self) -> &Array2<F> {
        &self.centroids
    }

    /// Return the number of training points belonging to each cluster
    pub fn cluster_count(&self) -> &Array1<F> {
        &self.cluster_count
    }

    /// Return the sum of squared distances between each training point and its closest centroid,
    /// averaged across all training points.  When training incrementally, this value is computed
    /// on the most recent batch.
    pub fn inertia(&self) -> F {
        self.inertia
    }
}

impl<F: Float, R: Rng + Clone, D: Distance<F>> KMeansValidParams<F, R, D> {
    /// Fit KMeans using Hamerly's accelerated algorithm.
    ///
    /// Uses the triangle inequality to skip unnecessary distance computations.
    /// Reference: <https://cs.baylor.edu/~hamerly/papers/sdm_2010.pdf>
    fn fit_hamerly<DA: Data<Elem = F>, T>(
        &self,
        dataset: &DatasetBase<ArrayBase<DA, Ix2>, T>,
    ) -> Result<KMeans<F, D>, KMeansError> {
        let mut rng = self.rng().clone();
        let observations = dataset.records().view();
        let mut min_inertia = F::infinity();
        let mut best_centroids = None;
        let mut best_memberships = None;

        for _ in 0..self.n_runs() {
            let centroids =
                self.init_method()
                    .run(self.dist_fn(), self.n_clusters(), observations, &mut rng);
            let mut hamerly = HamerlyAlgorithm::new(self.dist_fn(), observations, centroids);

            let mut n_iter = 0;
            let inertia = loop {
                // No need to reassign observations on the first iteration:
                // `HamerlyAlgorithm::new` has already assigned them
                if n_iter > 0 {
                    hamerly.reassign_observations();
                }
                n_iter += 1;

                let update = hamerly.recompute_centroids();

                if update.convergence_dist < self.tolerance() || n_iter == self.max_n_iterations() {
                    break hamerly.inertia();
                }

                hamerly.update_bounds(&update.distances_moved);
            };

            if inertia < min_inertia {
                min_inertia = inertia;
                let (centroids, memberships) = hamerly.into_parts();
                best_centroids = Some(centroids);
                best_memberships = Some(memberships);
            }
        }

        let memberships = best_memberships.unwrap_or_else(|| Array1::zeros(dataset.nsamples()));
        self.get_kmeans_result(dataset, min_inertia, best_centroids, memberships)
    }

    /// Fit KMeans with Lloyd's algorithm.
    fn fit_lloyd<DA: Data<Elem = F>, T>(
        &self,
        dataset: &DatasetBase<ArrayBase<DA, Ix2>, T>,
    ) -> Result<KMeans<F, D>, KMeansError> {
        let mut rng = self.rng().clone();
        let observations = dataset.records().view();
        let n_samples = dataset.nsamples();

        let mut min_inertia = F::infinity();
        let mut best_centroids = None;
        let mut memberships = Array1::zeros(n_samples);
        let mut dists = Array1::zeros(n_samples);

        let n_runs = self.n_runs();

        for _ in 0..n_runs {
            let mut centroids =
                self.init_method()
                    .run(self.dist_fn(), self.n_clusters(), observations, &mut rng);
            let mut n_iter = 0;
            let inertia = loop {
                update_memberships_and_dists(
                    self.dist_fn(),
                    &centroids,
                    &observations,
                    &mut memberships,
                    &mut dists,
                );
                let new_centroids = compute_centroids(&centroids, &observations, &memberships);
                let distance = self
                    .dist_fn()
                    .distance(centroids.view(), new_centroids.view());
                centroids = new_centroids;
                n_iter += 1;
                if distance < self.tolerance() || n_iter == self.max_n_iterations() {
                    break dists.sum();
                }
            };

            // We keep the centroids which minimize the inertia (defined as the sum of
            // the squared distances from each observation to its closest centroid)
            // over the n runs of the KMeans algorithm.
            if inertia < min_inertia {
                min_inertia = inertia;
                best_centroids = Some(centroids.clone());
            }
        }

        self.get_kmeans_result(dataset, min_inertia, best_centroids, memberships)
    }

    fn get_kmeans_result<DA: Data<Elem = F>, T>(
        &self,
        dataset: &DatasetBase<ArrayBase<DA, Ix2>, T>,
        min_inertia: F,
        best_centroids: Option<Array2<F>>,
        memberships: Array1<usize>,
    ) -> Result<KMeans<F, D>, KMeansError> {
        match best_centroids {
            Some(centroids) => {
                let mut cluster_count = Array1::zeros(self.n_clusters());
                memberships
                    .iter()
                    .for_each(|&c| cluster_count[c] += F::one());
                Ok(KMeans {
                    centroids,
                    cluster_count,
                    inertia: min_inertia / F::cast(dataset.nsamples()),
                    dist_fn: self.dist_fn().clone(),
                })
            }
            _ => Err(KMeansError::InertiaError),
        }
    }
}

impl<F: Float, R: Rng + Clone, DA: Data<Elem = F>, T, D: Distance<F>>
    Fit<ArrayBase<DA, Ix2>, T, KMeansError> for KMeansValidParams<F, R, D>
{
    type Object = KMeans<F, D>;

    /// Given an input matrix `observations`, with shape `(n_observations, n_features)`,
    /// `fit` identifies `n_clusters` centroids based on the training data distribution.
    ///
    /// An instance of `KMeans` is returned.
    fn fit(
        &self,
        dataset: &DatasetBase<ArrayBase<DA, Ix2>, T>,
    ) -> Result<Self::Object, KMeansError> {
        match self.algorithm() {
            KMeansAlgorithm::Lloyd => self.fit_lloyd(dataset),
            KMeansAlgorithm::Hamerly => self.fit_hamerly(dataset),
        }
    }
}

/// Outcome of one centroid update step.
struct CentroidUpdate<F> {
    /// Distance each centroid moved during the update, indexed by cluster.
    distances_moved: Array1<F>,
    /// Distance between the old and the new centroid matrices, compared against `tolerance`.
    convergence_dist: F,
}

/// Encapsulates all state and logic for a single Hamerly K-means run.
struct HamerlyAlgorithm<'a, F: Float, D: Distance<F>> {
    /// Distance metric used for all point-to-centroid comparisons.
    dist_fn: &'a D,
    /// Input data matrix, shape `(n_observations, n_features)`.
    observations: ArrayView2<'a, F>,
    /// Current centroid positions, shape `(n_clusters, n_features)`.
    centroids: Array2<F>,
    /// Cluster index assigned to each observation.
    memberships: Array1<usize>,
    /// Per-observation upper bound on the distance to its assigned centroid.
    upper_bounds: Array1<F>,
    /// Per-observation lower bound on the distance to the nearest non-assigned centroid.
    lower_bounds: Array1<F>,
    /// Number of observations currently assigned to each centroid.
    centroid_counts: Array1<usize>,
    /// Running coordinate sum of observations per centroid, shape `(n_clusters, n_features)`.
    centroid_sums: Array2<F>,
    /// Memberships before the most recent reassignment.
    prev_memberships: Array1<usize>,
}

impl<'a, F: Float, D: Distance<F>> HamerlyAlgorithm<'a, F, D> {
    /// Assigns every observation to its closest initial centroid and initialises the bounds
    /// and the per-centroid counts and sums from that assignment.
    fn new(dist_fn: &'a D, observations: ArrayView2<'a, F>, centroids: Array2<F>) -> Self {
        let n_observations = observations.nrows();
        let mut memberships = Array1::zeros(n_observations);
        let mut upper_bounds = Array1::zeros(n_observations);
        let mut lower_bounds = Array1::zeros(n_observations);

        Zip::from(observations.rows())
            .and(&mut memberships)
            .and(&mut upper_bounds)
            .and(&mut lower_bounds)
            .par_for_each(|obs, membership, upper, lower| {
                let (idx, closest_dist, second_dist) =
                    two_closest_centroids(dist_fn, &centroids, &obs);
                *membership = idx;
                *upper = closest_dist;
                *lower = second_dist;
            });

        let mut centroid_counts: Array1<usize> = Array1::zeros(centroids.nrows());
        let mut centroid_sums = Array2::zeros(centroids.dim());
        for (obs, &m) in observations.rows().into_iter().zip(memberships.iter()) {
            centroid_counts[m] += 1;
            let mut row = centroid_sums.row_mut(m);
            row += &obs;
        }

        let prev_memberships = Array1::zeros(n_observations);

        Self {
            dist_fn,
            observations,
            centroids,
            memberships,
            upper_bounds,
            lower_bounds,
            centroid_counts,
            centroid_sums,
            prev_memberships,
        }
    }

    /// For each centroid, returns the distance to its nearest other centroid. The closest match
    /// is the centroid itself, so the second-closest distance is the one kept.
    fn nearest_inter_centroid_distances(&self) -> Array1<F> {
        let mut dists = Array1::zeros(self.centroids.nrows());
        for (i, centroid) in self.centroids.rows().into_iter().enumerate() {
            let (_, _, second_dist) =
                two_closest_centroids(self.dist_fn, &self.centroids, &centroid);
            dists[i] = second_dist;
        }
        dists
    }

    /// Reassigns observations to their closest centroid, skipping every observation whose
    /// bounds prove that its assignment cannot have changed, then updates the per-centroid
    /// counts and sums for the observations that moved.
    fn reassign_observations(&mut self) {
        let nearest_center_dists = self.nearest_inter_centroid_distances();
        let centroids = &self.centroids;
        let observations = self.observations;
        let dist_fn = self.dist_fn;

        Zip::from(observations.rows())
            .and(&mut self.memberships)
            .and(&mut self.upper_bounds)
            .and(&mut self.lower_bounds)
            .and(&mut self.prev_memberships)
            .par_for_each(|obs, membership, upper, lower, prev_slot| {
                let current = *membership;
                *prev_slot = current;
                // The assignment is provably unchanged while the upper bound stays within
                // this threshold
                let threshold = F::max(nearest_center_dists[current] / F::cast(2), *lower);

                if *upper > threshold {
                    // Tighten the upper bound with the exact distance before searching
                    *upper = dist_fn.distance(obs.view(), centroids.row(current).view());

                    if *upper > threshold {
                        let (idx, closest_dist, second_dist) =
                            two_closest_centroids(dist_fn, centroids, &obs);
                        *membership = idx;
                        *upper = closest_dist;
                        *lower = second_dist;
                    }
                }
            });

        for (i, (&old_membership, &new_membership)) in self
            .prev_memberships
            .iter()
            .zip(self.memberships.iter())
            .enumerate()
        {
            if old_membership != new_membership {
                let observation = self.observations.row(i);
                self.centroid_counts[old_membership] -= 1;
                self.centroid_counts[new_membership] += 1;
                let mut old_centroid_sum = self.centroid_sums.row_mut(old_membership);
                old_centroid_sum -= &observation;
                let mut new_centroid_sum = self.centroid_sums.row_mut(new_membership);
                new_centroid_sum += &observation;
            }
        }
    }

    /// Recomputes centroids from the accumulated centroid sums and counts.
    fn recompute_centroids(&mut self) -> CentroidUpdate<F> {
        // m_k-means trick: the old centroid is treated as an extra point in each cluster,
        // as is done in Lloyd
        let mut new_centroids = &self.centroid_sums + &self.centroids;
        Zip::from(new_centroids.rows_mut())
            .and(&self.centroid_counts)
            .for_each(|mut centroid_sum, &n_members| {
                // + 1 because we have added the old centroid as an extra point
                centroid_sum /= F::cast(n_members + 1);
            });

        let mut distances_moved = Array1::zeros(self.centroids.nrows());
        Zip::from(&mut distances_moved)
            .and(self.centroids.rows())
            .and(new_centroids.rows())
            .for_each(|d, old, new| *d = self.dist_fn.distance(old, new));

        let convergence_dist = self
            .dist_fn
            .distance(self.centroids.view(), new_centroids.view());
        self.centroids = new_centroids;

        CentroidUpdate {
            distances_moved,
            convergence_dist,
        }
    }

    /// Loosens the bounds after the centroids have moved: each upper bound grows by the
    /// distance its assigned centroid moved, and each lower bound shrinks by the largest
    /// distance moved by any other centroid.
    fn update_bounds(&mut self, distances_moved: &Array1<F>) {
        let (farthest_moved_idx, second_farthest_moved_idx) = two_farthest_indices(distances_moved);
        Zip::from(&self.memberships)
            .and(&mut self.upper_bounds)
            .and(&mut self.lower_bounds)
            .par_for_each(|&centroid_idx, upper, lower| {
                *upper += distances_moved[centroid_idx];
                if centroid_idx == farthest_moved_idx {
                    *lower -= distances_moved[second_farthest_moved_idx];
                } else {
                    *lower -= distances_moved[farthest_moved_idx];
                }
            });
    }

    fn inertia(&self) -> F {
        compute_inertia(
            self.dist_fn,
            self.observations,
            &self.memberships,
            &self.centroids,
        )
    }

    fn into_parts(self) -> (Array2<F>, Array1<usize>) {
        (self.centroids, self.memberships)
    }
}

/// Returns the indices of the two centroids that moved the farthest.
///
/// For fewer than two elements the second index duplicates the first; callers
/// only read `second_farthest` when an observation's own centroid is the
/// farthest mover, which cannot happen when there is only one centroid.
fn two_farthest_indices<F: Float>(distances: &Array1<F>) -> (usize, usize) {
    if distances.len() < 2 {
        return (0, 0);
    }
    let (mut farthest, mut second_farthest) = if distances[1] >= distances[0] {
        (1, 0)
    } else {
        (0, 1)
    };
    for i in 2..distances.len() {
        if distances[i] >= distances[farthest] {
            second_farthest = farthest;
            farthest = i;
        } else if distances[i] > distances[second_farthest] {
            second_farthest = i;
        }
    }
    (farthest, second_farthest)
}

/// Computes total inertia: sum of squared distances from each observation to
/// its assigned centroid.
fn compute_inertia<F: Float, D: Distance<F>>(
    dist_fn: &D,
    observations: ArrayView2<F>,
    memberships: &Array1<usize>,
    centroids: &Array2<F>,
) -> F {
    observations
        .rows()
        .into_iter()
        .zip(memberships.iter())
        .map(|(obs, &m)| dist_fn.rdistance(obs.view(), centroids.row(m).view()))
        .fold(F::zero(), |acc, d| acc + d)
}

impl<'a, F: Float + Debug, R: Rng + Clone, DA: Data<Elem = F>, T, D: 'a + Distance<F> + Debug>
    FitWith<'a, ArrayBase<DA, Ix2>, T, IncrKMeansError<KMeans<F, D>>>
    for KMeansValidParams<F, R, D>
{
    type ObjectIn = Option<KMeans<F, D>>;
    type ObjectOut = KMeans<F, D>;

    /// Performs a single batch update of the Mini-Batch K-means algorithm.
    ///
    /// Given an input matrix `observations`, with shape `(n_batch, n_features)` and a previous
    /// `KMeans` model, the model's centroids are updated with the input matrix. If `model` is
    /// `None`, then it's initialized using the specified initialization algorithm. The updated
    /// model is returned as `Ok(model)` if the algorithm has converged, and as
    /// `Err(IncrKMeansError::NotConverged(model))` otherwise.
    ///
    /// Only [`KMeansAlgorithm::Lloyd`](crate::KMeansAlgorithm::Lloyd) is supported here: the
    /// Mini-Batch path always uses Lloyd's update. Configuring
    /// [`KMeansAlgorithm::Hamerly`](crate::KMeansAlgorithm::Hamerly) and then calling
    /// `fit_with` returns `IncrKMeansError::InvalidParams` wrapping
    /// [`KMeansParamsError::IncrementalHamerly`], because Hamerly's
    /// per-observation bounds rely on a persistent dataset across iterations and cannot
    /// amortise across independent mini-batches.
    fn fit_with(
        &self,
        model: Self::ObjectIn,
        dataset: &'a DatasetBase<ArrayBase<DA, Ix2>, T>,
    ) -> Result<Self::ObjectOut, IncrKMeansError<Self::ObjectOut>> {
        if *self.algorithm() == KMeansAlgorithm::Hamerly {
            return Err(IncrKMeansError::InvalidParams(
                KMeansParamsError::IncrementalHamerly,
            ));
        }
        let observations = dataset.records().view();
        let n_samples = dataset.nsamples();

        let mut model = match model {
            Some(model) => model,
            None => {
                let centroids = if let KMeansInit::Precomputed(centroids) = self.init_method() {
                    // If using precomputed centroids, don't run the init algorithm multiple times
                    centroids.clone()
                } else {
                    let mut rng = self.rng().clone();
                    let mut dists = Array1::zeros(n_samples);
                    // Initial centroids derived from the first batch by running the init algorithm
                    // n_runs times and taking the centroids with the lowest inertia
                    (0..self.n_runs())
                        .map(|_| {
                            let centroids = self.init_method().run(
                                self.dist_fn(),
                                self.n_clusters(),
                                observations,
                                &mut rng,
                            );
                            update_min_dists(self.dist_fn(), &centroids, &observations, &mut dists);
                            (centroids, dists.sum())
                        })
                        .min_by(|(_, d1), (_, d2)| {
                            if d1 < d2 {
                                Ordering::Less
                            } else {
                                Ordering::Greater
                            }
                        })
                        .unwrap()
                        .0
                };
                KMeans {
                    centroids,
                    cluster_count: Array1::zeros(self.n_clusters()),
                    inertia: F::zero(),
                    dist_fn: self.dist_fn().clone(),
                }
            }
        };

        let mut memberships = Array1::zeros(n_samples);
        let mut dists = Array1::zeros(n_samples);
        update_memberships_and_dists(
            self.dist_fn(),
            &model.centroids,
            &observations,
            &mut memberships,
            &mut dists,
        );
        let new_centroids = compute_centroids_incremental(
            &observations,
            &memberships,
            &model.centroids,
            &mut model.cluster_count,
        );
        model.inertia = dists.sum() / F::cast(n_samples);
        let dist = self
            .dist_fn()
            .distance(model.centroids.view(), new_centroids.view());
        model.centroids = new_centroids;

        if dist < self.tolerance() {
            Ok(model)
        } else {
            Err(IncrKMeansError::NotConverged(model))
        }
    }
}

718impl<F: Float, DA: Data<Elem = F>, D: Distance<F>> Transformer<&ArrayBase<DA, Ix2>, Array1<F>>
719    for KMeans<F, D>
720{
721    /// Given an input matrix `observations`, with shape `(n_observations, n_features)`,
722    /// `transform` returns, for each observation, the `rdistance` to its closest centroid (the squared distance for `L2Dist`).
723    fn transform(&self, observations: &ArrayBase<DA, Ix2>) -> Array1<F> {
724        let mut dists = Array1::zeros(observations.nrows());
725        update_min_dists(
726            &self.dist_fn,
727            &self.centroids,
728            &observations.view(),
729            &mut dists,
730        );
731        dists
732    }
733}
734
735impl<F: Float, DA: Data<Elem = F>, D: Distance<F>> PredictInplace<ArrayBase<DA, Ix2>, Array1<usize>>
736    for KMeans<F, D>
737{
738    /// Given an input matrix `observations`, with shape `(n_observations, n_features)`,
739    /// `predict` returns, for each observation, the index of the closest cluster/centroid.
740    ///
741    /// You can retrieve the centroid associated with an index using the
742    /// [`centroids` method](#method.centroids).
743    fn predict_inplace(&self, observations: &ArrayBase<DA, Ix2>, memberships: &mut Array1<usize>) {
744        assert_eq!(
745            observations.nrows(),
746            memberships.len(),
747            "The number of data points must match the number of memberships."
748        );
749
750        update_cluster_memberships(
751            &self.dist_fn,
752            &self.centroids,
753            &observations.view(),
754            memberships,
755        );
756    }
757
758    fn default_target(&self, x: &ArrayBase<DA, Ix2>) -> Array1<usize> {
759        Array1::zeros(x.nrows())
760    }
761}
762
763impl<F: Float, DA: Data<Elem = F>, D: Distance<F>> PredictInplace<ArrayBase<DA, Ix1>, usize>
764    for KMeans<F, D>
765{
766    /// Given one input observation, return the index of its closest cluster.
767    ///
768    /// You can retrieve the centroid associated with an index using the
769    /// [`centroids` method](#method.centroids).
770    fn predict_inplace(&self, observation: &ArrayBase<DA, Ix1>, membership: &mut usize) {
771        *membership = closest_centroid(&self.dist_fn, &self.centroids, observation).0;
772    }
773
774    fn default_target(&self, _x: &ArrayBase<DA, Ix1>) -> usize {
775        0
776    }
777}
778
779/// K-means is an iterative algorithm.
780/// We will perform the assignment and update steps until we are satisfied
781/// (according to our convergence criteria).
782///
783/// `compute_centroids` performs the update step: it returns a 2-dimensional array,
784/// where the i-th row is the new centroid of the i-th cluster.
785fn compute_centroids<F: Float>(
786    old_centroids: &Array2<F>,
787    // (n_observations, n_features)
788    observations: &ArrayBase<impl Data<Elem = F>, Ix2>,
789    // (n_observations,)
790    cluster_memberships: &ArrayBase<impl Data<Elem = usize>, Ix1>,
791) -> Array2<F> {
792    let n_clusters = old_centroids.nrows();
793    let mut counts: Array1<usize> = Array1::ones(n_clusters);
794    let mut centroids = Array2::zeros((n_clusters, observations.ncols()));
795
796    Zip::from(observations.rows())
797        .and(cluster_memberships)
798        .for_each(|observation, &cluster_membership| {
799            let mut centroid = centroids.row_mut(cluster_membership);
800            centroid += &observation;
801            counts[cluster_membership] += 1;
802        });
803    // m_k-means: Treat the old centroid like another point in the cluster
804    centroids += old_centroids;
805
806    Zip::from(centroids.rows_mut())
807        .and(&counts)
808        .for_each(|mut centroid, &cnt| centroid /= F::cast(cnt));
809    centroids
810}
811
812/// Returns new centroids, obtained by folding each observation into the running mean of its
813/// cluster, starting from the old centroids.
814/// Increments `counts` by the number of observations assigned to each cluster.
815fn compute_centroids_incremental<F: Float>(
816    observations: &ArrayBase<impl Data<Elem = F>, Ix2>,
817    cluster_memberships: &ArrayBase<impl Data<Elem = usize>, Ix1>,
818    old_centroids: &ArrayBase<impl Data<Elem = F>, Ix2>,
819    counts: &mut ArrayBase<impl DataMut<Elem = F>, Ix1>,
820) -> Array2<F> {
821    let mut centroids = old_centroids.to_owned();
822    // We can parallelize this
823    Zip::from(observations.rows())
824        .and(cluster_memberships)
825        .for_each(|obs, &c| {
826            // Computes centroids[c] += (observation - centroids[c]) / counts[c]
827            // If a cluster is empty for this batch, this closure is never called for it, so
828            // there is no chance of getting NaN.
829            counts[c] += F::one();
830            let shift = (&obs - &centroids.row(c)) / counts[c];
831            let mut centroid = centroids.row_mut(c);
832            centroid += &shift;
833        });
834    centroids
835}
836
837// Updates `cluster_memberships` with the index of the cluster each observation belongs to.
838pub(crate) fn update_cluster_memberships<F: Float, D: Distance<F>>(
839    dist_fn: &D,
840    centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
841    observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
842    cluster_memberships: &mut ArrayBase<impl DataMut<Elem = usize>, Ix1>,
843) {
844    Zip::from(observations.axis_iter(Axis(0)))
845        .and(cluster_memberships)
846        .par_for_each(|observation, cluster_membership| {
847            *cluster_membership = closest_centroid(dist_fn, centroids, &observation).0
848        });
849}
850
851// Updates `dists` with the distance of each observation from its closest centroid.
852pub(crate) fn update_min_dists<F: Float, D: Distance<F>>(
853    dist_fn: &D,
854    centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
855    observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
856    dists: &mut ArrayBase<impl DataMut<Elem = F>, Ix1>,
857) {
858    Zip::from(observations.axis_iter(Axis(0)))
859        .and(dists)
860        .par_for_each(|observation, dist| {
861            *dist = closest_centroid(dist_fn, centroids, &observation).1
862        });
863}
864
865/// Efficient combination of `update_cluster_memberships` and `update_min_dists`.
866pub(crate) fn update_memberships_and_dists<F: Float, D: Distance<F>>(
867    dist_fn: &D,
868    centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
869    observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
870    cluster_memberships: &mut ArrayBase<impl DataMut<Elem = usize>, Ix1>,
871    dists: &mut ArrayBase<impl DataMut<Elem = F>, Ix1>,
872) {
873    Zip::from(observations.axis_iter(Axis(0)))
874        .and(cluster_memberships)
875        .and(dists)
876        .par_for_each(|observation, cluster_membership, dist| {
877            let (m, d) = closest_centroid(dist_fn, centroids, &observation);
878            *cluster_membership = m;
879            *dist = d;
880        });
881}
882
883/// Given a matrix of centroids with shape (n_centroids, n_features) and an observation,
884/// return the index of the closest centroid (the index of the corresponding row in `centroids`),
885/// its distance, and the distance to the second-closest centroid.
886///
887/// Uses `distance` (not `rdistance`) because Hamerly's triangle-inequality bounds
888/// only hold under a true metric — do not "optimize" this to squared distance.
889fn two_closest_centroids<F: Float, D: Distance<F>>(
890    dist_fn: &D,
891    // (n_centroids, n_features)
892    centroids: &ArrayBase<impl Data<Elem = F>, Ix2>,
893    // (n_features)
894    observation: &ArrayBase<impl Data<Elem = F>, Ix1>,
895) -> (usize, F, F) {
896    if centroids.nrows() == 1 {
897        return (0, F::cast(0), F::cast(0));
898    }
899    let first_centroid = centroids.row(0);
900    let second_centroid = centroids.row(1);
901    let dist1 = dist_fn.distance(observation.view(), first_centroid.view());
902    let dist2 = dist_fn.distance(observation.view(), second_centroid.view());
903
904    let mut closest_index = if dist1 < dist2 { 0 } else { 1 };
905    let mut closest_distance = if dist1 < dist2 { dist1 } else { dist2 };
906    let mut second_closest_distance = if dist1 < dist2 { dist2 } else { dist1 };
907
908    for (centroid_index, centroid) in centroids.rows().into_iter().skip(2).enumerate() {
909        let distance = dist_fn.distance(observation.view(), centroid.view());
910        if closest_distance <= distance && distance < second_closest_distance {
911            second_closest_distance = distance;
912        } else if distance < closest_distance {
913            second_closest_distance = closest_distance;
914            closest_index = centroid_index + 2; // We skipped 2 centroids
915            closest_distance = distance;
916        }
917    }
918    (closest_index, closest_distance, second_closest_distance)
919}
920
921/// Given a matrix of centroids with shape (n_centroids, n_features) and an observation,
922/// return the index of the closest centroid (the corresponding row in `centroids`) and its `rdistance`.
923pub(crate) fn closest_centroid<F: Float, D: Distance<F>>(
924    dist_fn: &D,
925    // (n_centroids, n_features)
926    centroids: &ArrayBase<impl Data<Elem = F>, Ix2>,
927    // (n_features)
928    observation: &ArrayBase<impl Data<Elem = F>, Ix1>,
929) -> (usize, F) {
930    let iterator = centroids.rows().into_iter();
931
932    let first_centroid = centroids.row(0);
933    let (mut closest_index, mut minimum_distance) = (
934        0,
935        dist_fn.rdistance(first_centroid.view(), observation.view()),
936    );
937
938    for (centroid_index, centroid) in iterator.enumerate() {
939        let distance = dist_fn.rdistance(centroid.view(), observation.view());
940        if distance < minimum_distance {
941            closest_index = centroid_index;
942            minimum_distance = distance;
943        }
944    }
945    (closest_index, minimum_distance)
946}
947
948#[cfg(test)]
949mod tests {
950    use super::super::KMeansInit;
951    use super::*;
952    use crate::KMeansParamsError;
953    use approx::assert_abs_diff_eq;
954    use linfa_nn::distance::L1Dist;
955    use ndarray::{array, concatenate, Array, Array1, Array2, Axis};
956    use ndarray_rand::rand::prelude::ThreadRng;
957    use ndarray_rand::rand::SeedableRng;
958    use ndarray_rand::rand_distr::Uniform;
959    use ndarray_rand::RandomExt;
960
961    #[test]
962    fn autotraits() {
963        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
964        has_autotraits::<KMeans<f64, L2Dist>>();
965        has_autotraits::<KMeansAlgorithm>();
966        has_autotraits::<KMeansParamsError>();
967        has_autotraits::<KMeansError>();
968        has_autotraits::<IncrKMeansError<String>>();
969    }
970
971    fn function_test_1d(x: &Array2<f64>) -> Array2<f64> {
972        let mut y = Array2::zeros(x.dim());
973        Zip::from(&mut y).and(x).for_each(|yi, &xi| {
974            if xi < 0.4 {
975                *yi = xi * xi;
976            } else if (0.4..0.8).contains(&xi) {
977                *yi = 3. * xi + 1.;
978            } else {
979                *yi = f64::sin(10. * xi);
980            }
981        });
982        y
983    }
984
985    macro_rules! calc_inertia {
986        ($dist:expr, $centroids:expr, $obs:expr, $memberships:expr) => {
987            $obs.rows()
988                .into_iter()
989                .zip($memberships.iter())
990                .map(|(row, &c)| $dist.rdistance(row.view(), $centroids.row(c).view()))
991                .sum::<f64>()
992        };
993    }
994
995    macro_rules! calc_memberships {
996        ($dist:expr, $centroids:expr, $obs:expr) => {{
997            let mut memberships = Array1::zeros($obs.nrows());
998            update_cluster_memberships(&$dist, &$centroids, &$obs, &mut memberships);
999            memberships
1000        }};
1001    }
1002
1003    #[test]
1004    fn test_min_dists() {
1005        let centroids = array![[0.0, 1.0], [40.0, 10.0]];
1006        let observations = array![[3.0, 4.0], [1.0, 3.0], [25.0, 15.0]];
1007        let mut dists = Array1::zeros(observations.nrows());
1008
1009        update_min_dists(&L2Dist, &centroids, &observations, &mut dists);
1010        assert_abs_diff_eq!(dists, array![18.0, 5.0, 250.0]);
1011        update_min_dists(&L1Dist, &centroids, &observations, &mut dists);
1012        assert_abs_diff_eq!(dists, array![6.0, 3.0, 20.0]);
1013    }
1014
1015    fn test_n_runs<D: Distance<f64>>(dist_fn: D) {
1016        let mut rng = Xoshiro256Plus::seed_from_u64(42);
1017        let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
1018        let yt = function_test_1d(&xt);
1019        let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
1020
1021        for init in &[
1022            KMeansInit::Random,
1023            KMeansInit::KMeansPlusPlus,
1024            KMeansInit::KMeansPara,
1025        ] {
1026            // First clustering with a single run (`n_runs(1)`)
1027            let dataset = DatasetBase::from(data.clone());
1028            let model = KMeans::params_with(3, rng.clone(), dist_fn.clone())
1029                .n_runs(1)
1030                .init_method(init.clone())
1031                .fit(&dataset)
1032                .expect("KMeans fitted");
1033            let clusters = model.predict(dataset);
1034            let inertia = calc_inertia!(
1035                dist_fn,
1036                model.centroids(),
1037                clusters.records,
1038                clusters.targets
1039            );
1040            let total_dist = model.transform(&clusters.records.view()).sum();
1041            assert_abs_diff_eq!(inertia, total_dist, epsilon = 1e-5);
1042
1043            let single_cluster: usize = model.predict(&data.row(0));
1044            assert_abs_diff_eq!(single_cluster, clusters.targets[0]);
1045
1046            // Second clustering with 10 runs (default)
1047            let dataset2 = DatasetBase::from(clusters.records().clone());
1048            let model2 = KMeans::params_with(3, rng.clone(), dist_fn.clone())
1049                .init_method(init.clone())
1050                .fit(&dataset2)
1051                .expect("KMeans fitted");
1052            let clusters2 = model2.predict(dataset2);
1053            let inertia2 = calc_inertia!(
1054                dist_fn,
1055                model2.centroids(),
1056                clusters2.records,
1057                clusters2.targets
1058            );
1059            let total_dist2 = model2.transform(&clusters2.records.view()).sum();
1060            assert_abs_diff_eq!(inertia2, total_dist2, epsilon = 1e-5);
1061
1062            // Check we improve inertia (only really makes a difference for random init)
1063            if *init == KMeansInit::Random {
1064                assert!(inertia2 <= inertia);
1065            }
1066        }
1067    }
1068
1069    #[test]
1070    fn test_n_runs_l2dist() {
1071        test_n_runs(L2Dist);
1072    }
1073
1074    #[test]
1075    fn test_n_runs_l1dist() {
1076        test_n_runs(L1Dist);
1077    }
1078
1079    #[test]
1080    fn compute_centroids_works() {
1081        let cluster_size = 100;
1082        let n_features = 4;
1083
1084        // Let's set up a synthetic set of observations, composed of two clusters with known means
1085        let cluster_1: Array2<f64> =
1086            Array::random((cluster_size, n_features), Uniform::new(-100., 100.));
1087        let memberships_1 = Array1::zeros(cluster_size);
1088        let expected_centroid_1 = cluster_1.sum_axis(Axis(0)) / (cluster_size + 1) as f64;
1089
1090        let cluster_2: Array2<f64> =
1091            Array::random((cluster_size, n_features), Uniform::new(-100., 100.));
1092        let memberships_2 = Array1::ones(cluster_size);
1093        let expected_centroid_2 = cluster_2.sum_axis(Axis(0)) / (cluster_size + 1) as f64;
1094
1095        // `concatenate` combines arrays along a given axis: https://docs.rs/ndarray/0.13.0/ndarray/fn.concatenate.html
1096        let observations = concatenate(Axis(0), &[cluster_1.view(), cluster_2.view()]).unwrap();
1097        let memberships =
1098            concatenate(Axis(0), &[memberships_1.view(), memberships_2.view()]).unwrap();
1099
1100        // Does it work?
1101        let old_centroids = Array2::zeros((2, n_features));
1102        let centroids = compute_centroids(&old_centroids, &observations, &memberships);
1103        assert_abs_diff_eq!(
1104            centroids.index_axis(Axis(0), 0),
1105            expected_centroid_1,
1106            epsilon = 1e-5
1107        );
1108        assert_abs_diff_eq!(
1109            centroids.index_axis(Axis(0), 1),
1110            expected_centroid_2,
1111            epsilon = 1e-5
1112        );
1113
1114        assert_eq!(centroids.len_of(Axis(0)), 2);
1115    }
1116
1117    #[test]
1118    fn test_compute_extra_centroids() {
1119        let observations = array![[1.0, 2.0]];
1120        let memberships = array![0];
1121        // With m_k-means, an empty cluster should keep its old centroid
1122        let old_centroids = Array2::ones((2, 2));
1123        let centroids = compute_centroids(&old_centroids, &observations, &memberships);
1124        assert_abs_diff_eq!(centroids, array![[1.0, 1.5], [1.0, 1.0]]);
1125    }
1126
1127    #[test]
1128    // An observation is closest to itself.
1129    fn nothing_is_closer_than_self() {
1130        let n_centroids = 20;
1131        let n_features = 5;
1132        let mut rng = Xoshiro256Plus::seed_from_u64(42);
1133        let centroids: Array2<f64> = Array::random_using(
1134            (n_centroids, n_features),
1135            Uniform::new(-100., 100.),
1136            &mut rng,
1137        );
1138
1139        let expected_memberships = (0..n_centroids).collect::<Array1<_>>();
1140        assert_eq!(
1141            calc_memberships!(L2Dist, centroids, centroids),
1142            expected_memberships
1143        );
1144        assert_eq!(
1145            calc_memberships!(L1Dist, centroids, centroids),
1146            expected_memberships
1147        );
1148    }
1149
1150    #[test]
1151    fn oracle_test_for_closest_centroid() {
1152        let centroids = array![[0., 0.], [1., 2.], [20., 0.], [0., 20.],];
1153        let observations = array![[1., 0.6], [20., 2.], [20., 0.], [7., 20.],];
1154        let l2_memberships = array![0, 2, 2, 3];
1155        let l1_memberships = array![1, 2, 2, 3];
1156
1157        assert_eq!(
1158            calc_memberships!(L2Dist, centroids, observations),
1159            l2_memberships
1160        );
1161        assert_eq!(
1162            calc_memberships!(L1Dist, centroids, observations),
1163            l1_memberships
1164        );
1165    }
1166
1167    #[test]
1168    fn test_compute_centroids_incremental() {
1169        let observations = array![[-1.0, -3.0], [0., 0.], [3., 5.], [5., 5.]];
1170        let memberships = array![0, 0, 1, 1];
1171        let centroids = array![[-1., -1.], [3., 4.], [7., 8.]];
1172        let mut counts = array![3.0, 0.0, 1.0];
1173        let centroids =
1174            compute_centroids_incremental(&observations, &memberships, &centroids, &mut counts);
1175
1176        assert_abs_diff_eq!(centroids, array![[-4. / 5., -6. / 5.], [4., 5.], [7., 8.]]);
1177        assert_abs_diff_eq!(counts, array![5., 2., 1.]);
1178    }
1179
1180    #[test]
1181    fn test_incremental_kmeans() {
1182        let dataset1 = DatasetBase::from(array![[-1.0, -3.0], [0., 0.], [3., 5.], [5., 5.]]);
1183        let dataset2 = DatasetBase::from(array![[-5.0, -5.0], [0., 0.], [10., 10.]]);
1184        let model = KMeans {
1185            centroids: array![[-1., -1.], [3., 4.], [7., 8.]],
1186            cluster_count: array![0., 0., 0.],
1187            inertia: 0.0,
1188            dist_fn: L2Dist,
1189        };
1190        let rng = Xoshiro256Plus::seed_from_u64(45);
1191        let params = KMeans::params_with_rng(3, rng).tolerance(100.0);
1192
1193        // Should converge on first try
1194        let model = params.fit_with(Some(model), &dataset1).unwrap();
1195        assert_abs_diff_eq!(model.centroids(), &array![[-0.5, -1.5], [4., 5.], [7., 8.]]);
1196
1197        let model = params.fit_with(Some(model), &dataset2).unwrap();
1198        assert_abs_diff_eq!(
1199            model.centroids(),
1200            &array![[-6. / 4., -8. / 4.], [4., 5.], [10., 10.]]
1201        );
1202    }
1203
1204    #[test]
1205    fn fit_with_rejects_hamerly() {
1206        let rng = Xoshiro256Plus::seed_from_u64(45);
1207        let params = KMeans::params_with_rng(2, rng)
1208            .algorithm(KMeansAlgorithm::Hamerly)
1209            .init_method(KMeansInit::Precomputed(array![[0., 0.], [10., 10.]]));
1210        let data = DatasetBase::from(array![[1., 1.], [11., 11.]]);
1211        let err = params
1212            .fit_with(None, &data)
1213            .expect_err("Hamerly + fit_with must be rejected");
1214        assert!(matches!(
1215            err,
1216            IncrKMeansError::InvalidParams(KMeansParamsError::IncrementalHamerly)
1217        ));
1218    }
1219
1220    #[test]
1221    fn test_tolerance() {
1222        let rng = Xoshiro256Plus::seed_from_u64(45);
1223        // The "correct" centroid for the dataset is [6, 6], so the centroid distance from the
1224        // initial centroid in the first iteration should be around 8.48. With a tolerance of 8.5,
1225        // KMeans should converge on first iteration.
1226        let params = KMeans::params_with_rng(1, rng)
1227            .tolerance(8.5)
1228            .init_method(KMeansInit::Precomputed(array![[0., 0.]]));
1229        let data = DatasetBase::from(array![[1., 1.], [11., 11.]]);
1230        assert!(params.fit_with(None, &data).is_ok());
1231    }
1232
1233    #[test]
1234    fn test_max_n_iterations() {
1235        let mut rng = Xoshiro256Plus::seed_from_u64(42);
1236        let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
1237        let yt = function_test_1d(&xt);
1238        let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
1239        let dataset = DatasetBase::from(data.clone());
1240        // For data created using the above rng and seed, for 6 clusters, it would take 8 iterations to converge.
1241        // However, when specifying max_n_iterations as 5, the algorithm should stop early gracefully.
1242        let _model = KMeans::params_with(6, rng.clone(), L2Dist)
1243            .n_runs(1)
1244            .max_n_iterations(5)
1245            .init_method(KMeansInit::Random)
1246            .fit(&dataset)
1247            .expect("KMeans fitted");
1248    }
1249
1250    fn sort_centroids(c: &Array2<f64>) -> Array2<f64> {
1251        let mut rows: Vec<Vec<f64>> = c.rows().into_iter().map(|r| r.to_vec()).collect();
1252        rows.sort_by(|a, b| {
1253            for (x, y) in a.iter().zip(b.iter()) {
1254                match x.partial_cmp(y) {
1255                    Some(std::cmp::Ordering::Equal) => continue,
1256                    Some(ord) => return ord,
1257                    None => continue,
1258                }
1259            }
1260            std::cmp::Ordering::Equal
1261        });
1262        let flat: Vec<f64> = rows.into_iter().flatten().collect();
1263        Array2::from_shape_vec((c.nrows(), c.ncols()), flat).unwrap()
1264    }
1265
1266    fn hamerly_lloyd_equivalence<D: Distance<f64>>(dist_fn: D, init: KMeansInit<f64>) {
1267        let mut rng = Xoshiro256Plus::seed_from_u64(42);
1268        let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
1269        let yt = function_test_1d(&xt);
1270        let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
1271        let dataset = DatasetBase::from(data);
1272
1273        let model_lloyd = KMeans::params_with(6, rng.clone(), dist_fn.clone())
1274            .n_runs(3)
1275            .algorithm(KMeansAlgorithm::Lloyd)
1276            .init_method(init.clone())
1277            .fit(&dataset)
1278            .expect("Lloyd fitted");
1279        let model_hamerly = KMeans::params_with(6, rng.clone(), dist_fn)
1280            .n_runs(3)
1281            .algorithm(KMeansAlgorithm::Hamerly)
1282            .init_method(init)
1283            .fit(&dataset)
1284            .expect("Hamerly fitted");
1285
1286        assert_eq!(model_lloyd.centroids().nrows(), 6);
1287        assert_abs_diff_eq!(
1288            model_lloyd.inertia(),
1289            model_hamerly.inertia(),
1290            epsilon = 1e-4
1291        );
1292        assert_abs_diff_eq!(
1293            sort_centroids(model_lloyd.centroids()),
1294            sort_centroids(model_hamerly.centroids()),
1295            epsilon = 1e-4
1296        );
1297    }
1298
1299    #[test]
1300    fn hamerly_lloyd_equivalence_random_l2() {
1301        hamerly_lloyd_equivalence(L2Dist, KMeansInit::Random);
1302    }
1303
1304    #[test]
1305    fn hamerly_lloyd_equivalence_plusplus_l2() {
1306        hamerly_lloyd_equivalence(L2Dist, KMeansInit::KMeansPlusPlus);
1307    }
1308
1309    fn hamerly_lloyd_equivalence_para<D: Distance<f64>>(dist_fn: D) {
1310        // KMeansPara uses Rayon parallelism and is non-deterministic across concurrent test
1311        // runs. Pre-compute centroids deterministically and pass them as Precomputed so
1312        // both Lloyd and Hamerly start from the same initial centroids.
1313        let mut rng = Xoshiro256Plus::seed_from_u64(99);
1314        let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
1315        let yt = function_test_1d(&xt);
1316        let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
1317        let dataset = DatasetBase::from(data);
1318        let init = KMeansInit::Precomputed(KMeansInit::KMeansPlusPlus.run(
1319            &dist_fn,
1320            6,
1321            dataset.records().view(),
1322            &mut rng,
1323        ));
1324        hamerly_lloyd_equivalence(dist_fn, init);
1325    }
1326
1327    #[test]
1328    fn hamerly_lloyd_equivalence_para_l2() {
1329        hamerly_lloyd_equivalence_para(L2Dist);
1330    }
1331
1332    #[test]
1333    fn hamerly_lloyd_equivalence_random_l1() {
1334        hamerly_lloyd_equivalence(L1Dist, KMeansInit::Random);
1335    }
1336
1337    #[test]
1338    fn hamerly_lloyd_equivalence_plusplus_l1() {
1339        hamerly_lloyd_equivalence(L1Dist, KMeansInit::KMeansPlusPlus);
1340    }
1341
1342    #[test]
1343    fn hamerly_lloyd_equivalence_para_l1() {
1344        hamerly_lloyd_equivalence_para(L1Dist);
1345    }
1346
1347    #[test]
1348    fn test_two_closest_centroids_l2() {
1349        let centroids = array![[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]];
1350        let obs = array![1.0, 1.0];
1351        let (idx, closest, second) = two_closest_centroids(&L2Dist, &centroids, &obs);
1352        assert_eq!(idx, 0);
1353        assert_abs_diff_eq!(closest, f64::sqrt(2.0), epsilon = 1e-10);
1354        assert_abs_diff_eq!(second, f64::sqrt(82.0), epsilon = 1e-10);
1355    }
1356
1357    #[test]
1358    fn test_two_closest_centroids_l1() {
1359        let centroids = array![[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]];
1360        let obs = array![1.0, 1.0];
1361        let (idx, closest, second) = two_closest_centroids(&L1Dist, &centroids, &obs);
1362        assert_eq!(idx, 0);
1363        assert_abs_diff_eq!(closest, 2.0, epsilon = 1e-10);
1364        assert_abs_diff_eq!(second, 10.0, epsilon = 1e-10);
1365    }
1366
1367    #[test]
1368    fn test_two_closest_centroids_single() {
1369        let centroids = array![[5.0, 5.0]];
1370        let obs = array![1.0, 1.0];
1371        let (idx, closest, second) = two_closest_centroids(&L2Dist, &centroids, &obs);
1372        assert_eq!(idx, 0);
1373        assert_abs_diff_eq!(closest, 0.0);
1374        assert_abs_diff_eq!(second, 0.0);
1375    }
1376
1377    #[test]
1378    fn test_two_closest_centroids_obs_is_centroid() {
1379        let centroids = array![[0.0, 0.0], [3.0, 4.0], [10.0, 0.0]];
1380        let obs = array![3.0, 4.0];
1381        let (idx, closest, second) = two_closest_centroids(&L2Dist, &centroids, &obs);
1382        assert_eq!(idx, 1);
1383        assert_abs_diff_eq!(closest, 0.0, epsilon = 1e-10);
1384        assert_abs_diff_eq!(second, 5.0, epsilon = 1e-10);
1385    }
1386
1387    #[test]
1388    fn test_two_closest_centroids_equidistant() {
1389        let centroids = array![[2.0, 0.0], [0.0, 2.0]];
1390        let obs = array![1.0, 1.0];
1391        let (idx, closest, second) = two_closest_centroids(&L2Dist, &centroids, &obs);
1392        // When equidistant, index 1 is chosen because `if dist1 < dist2` is false
1393        assert_eq!(idx, 1);
1394        assert_abs_diff_eq!(closest, f64::sqrt(2.0), epsilon = 1e-10);
1395        assert_abs_diff_eq!(second, f64::sqrt(2.0), epsilon = 1e-10);
1396    }
1397
1398    #[test]
1399    fn test_two_farthest_indices() {
1400        // Distinct values
1401        assert_eq!(two_farthest_indices(&array![1.0, 5.0, 3.0, 2.0]), (1, 2));
1402
1403        // All equal: repeated >= swaps chain through all indices
1404        assert_eq!(two_farthest_indices(&array![3.0, 3.0, 3.0]), (2, 1));
1405
1406        // Two elements
1407        assert_eq!(two_farthest_indices(&array![2.0, 7.0]), (1, 0));
1408        assert_eq!(two_farthest_indices(&array![7.0, 2.0]), (0, 1));
1409
1410        // Largest at end
1411        assert_eq!(two_farthest_indices(&array![8.0, 1.0, 2.0, 9.0]), (3, 0));
1412
1413        // Largest at start: second must be the actual runner-up
1414        assert_eq!(two_farthest_indices(&array![9.0, 1.0, 2.0, 8.0]), (0, 3));
1415
1416        // Single element degenerates to (0, 0)
1417        assert_eq!(two_farthest_indices(&array![1.0]), (0, 0));
1418    }
1419
1420    #[test]
1421    fn test_recompute_centroids() {
1422        let obs = array![[0.0, 0.0]];
1423        let centroids = array![[0.0, 0.0], [0.0, 0.0]];
1424        let mut hamerly = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids);
1425        // m_k-means: new = (sums + old) / (counts + 1) = [8/4, 12/4], [15/3, 30/3]
1426        hamerly.centroid_sums = array![[8.0, 12.0], [15.0, 30.0]];
1427        hamerly.centroid_counts = array![3_usize, 2];
1428        hamerly.recompute_centroids();
1429        assert_abs_diff_eq!(
1430            hamerly.centroids,
1431            array![[2.0, 3.0], [5.0, 10.0]],
1432            epsilon = 1e-10
1433        );
1434
1435        // Empty cluster: (0 + old) / (0 + 1) = old, so the centroid is preserved.
1436        let centroids2 = array![[7.0, 7.0], [0.0, 0.0]];
1437        let mut hamerly2 = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids2);
1438        hamerly2.centroid_sums = array![[0.0, 0.0], [15.0, 30.0]];
1439        hamerly2.centroid_counts = array![0_usize, 2];
1440        hamerly2.recompute_centroids();
1441        assert_abs_diff_eq!(
1442            hamerly2.centroids,
1443            array![[7.0, 7.0], [5.0, 10.0]],
1444            epsilon = 1e-10
1445        );
1446    }

    #[test]
    fn test_recompute_centroids_distances_moved() {
        let obs = array![[0.0, 0.0]];
        let centroids = array![[0.0, 0.0], [10.0, 0.0]];
        let mut hamerly = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids);
        // m_k-means: new = (sums + old) / (counts + 1) = [2/2, 0/2], [20/2, 6/2]
        //   → [1.0, 0.0] and [10.0, 3.0], moved 1.0 and 3.0 respectively
        hamerly.centroid_sums = array![[2.0, 0.0], [10.0, 6.0]];
        hamerly.centroid_counts = array![1_usize, 1];
        let update = hamerly.recompute_centroids();
        assert_abs_diff_eq!(update.distances_moved, array![1.0, 3.0], epsilon = 1e-10);

        // No movement
        let centroids2 = array![[5.0, 5.0], [10.0, 10.0]];
        let mut hamerly2 = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids2);
        hamerly2.centroid_sums = array![[5.0, 5.0], [10.0, 10.0]];
        hamerly2.centroid_counts = array![1_usize, 1];
        let update2 = hamerly2.recompute_centroids();
        assert_abs_diff_eq!(update2.distances_moved, array![0.0, 0.0], epsilon = 1e-10);
    }
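
    // A minimal, std-only sketch of the m_k-means update rule quoted in the
    // comments above, new = (sum + old) / (count + 1), for a single centroid.
    // `m_k_means_update_sketch` is a hypothetical helper written for
    // illustration only: it is not part of linfa and does not call
    // `HamerlyAlgorithm::recompute_centroids`. It assumes Euclidean movement.
    fn m_k_means_update_sketch(sum: &[f64], old: &[f64], count: usize) -> (Vec<f64>, f64) {
        let denom = (count + 1) as f64;
        let new: Vec<f64> = sum.iter().zip(old).map(|(s, o)| (s + o) / denom).collect();
        // Euclidean distance between the old and the new centroid.
        let moved = new
            .iter()
            .zip(old)
            .map(|(n, o)| (n - o).powi(2))
            .sum::<f64>()
            .sqrt();
        (new, moved)
    }

    #[test]
    fn test_m_k_means_update_sketch() {
        // Same numbers as the second centroid above: (10 + 10) / 2 and (6 + 0) / 2.
        let (new, moved) = m_k_means_update_sketch(&[10.0, 6.0], &[10.0, 0.0], 1);
        assert_eq!(new, vec![10.0, 3.0]);
        assert!((moved - 3.0).abs() < 1e-10);
        // Empty cluster: (0 + old) / 1 leaves the centroid where it was.
        let (kept, still) = m_k_means_update_sketch(&[0.0, 0.0], &[7.0, 7.0], 0);
        assert_eq!(kept, vec![7.0, 7.0]);
        assert_eq!(still, 0.0);
    }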

    #[test]
    fn test_nearest_inter_centroid_distances() {
        let obs = array![[0.0, 0.0]];
        let centroids = array![[0.0, 0.0], [3.0, 0.0], [0.0, 4.0]];
        let hamerly = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids);
        let dists = hamerly.nearest_inter_centroid_distances();
        assert_abs_diff_eq!(dists, array![3.0, 3.0, 4.0], epsilon = 1e-10);

        // Two centroids: symmetric
        let centroids2 = array![[0.0, 0.0], [5.0, 0.0]];
        let hamerly2 = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids2);
        let dists2 = hamerly2.nearest_inter_centroid_distances();
        assert_abs_diff_eq!(dists2, array![5.0, 5.0], epsilon = 1e-10);
    }

    #[test]
    fn test_hamerly_strategy_new() {
        let obs = array![[0.0, 0.0], [1.0, 0.0], [10.0, 10.0]];
        let centroids = array![[0.0, 0.0], [10.0, 10.0]];
        let hamerly = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids);
        assert_eq!(hamerly.memberships, array![0_usize, 0, 1]);
        assert_eq!(hamerly.centroid_counts, array![2_usize, 1]);
        assert_abs_diff_eq!(
            hamerly.centroid_sums,
            array![[1.0, 0.0], [10.0, 10.0]],
            epsilon = 1e-10
        );
    }

    #[test]
    fn test_update_bounds_oracle() {
        let obs = array![[0.0, 0.0], [10.0, 0.0], [0.0, 0.0]];
        let centroids = array![[0.0, 0.0], [10.0, 0.0]];
        let mut hamerly = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids);
        hamerly.memberships = array![0_usize, 1, 0];
        hamerly.upper_bounds = array![5.0, 3.0, 4.0];
        hamerly.lower_bounds = array![2.0, 1.0, 3.0];
        let distances_moved = array![1.0, 0.5];
        hamerly.update_bounds(&distances_moved);
        // Expected: upper = [5 + 1.0, 3 + 0.5, 4 + 1.0], lower = [2 - 0.5, 1 - 1.0, 3 - 0.5]
        assert_abs_diff_eq!(hamerly.upper_bounds, array![6.0, 3.5, 5.0], epsilon = 1e-10);
        assert_abs_diff_eq!(hamerly.lower_bounds, array![1.5, 0.0, 2.5], epsilon = 1e-10);
    }
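
    // A std-only sketch of the bound update that the test above expects. The
    // rule is inferred from the expected values, not read from `update_bounds`:
    // each upper bound grows by the movement of the observation's own centroid,
    // and each lower bound shrinks by the largest movement among the other
    // centroids. `update_bounds_sketch` is a hypothetical helper, not linfa API,
    // and it does not clamp the lower bounds.
    fn update_bounds_sketch(
        memberships: &[usize],
        upper: &mut [f64],
        lower: &mut [f64],
        moved: &[f64],
    ) {
        for (i, &own) in memberships.iter().enumerate() {
            upper[i] += moved[own];
            // Largest movement among centroids other than the assigned one.
            let max_other = moved
                .iter()
                .enumerate()
                .filter(|&(j, _)| j != own)
                .map(|(_, &m)| m)
                .fold(0.0, f64::max);
            lower[i] -= max_other;
        }
    }

    #[test]
    fn test_update_bounds_sketch() {
        // Three centroids, so "largest movement among the others" is visible.
        let mut upper = [5.0, 3.0];
        let mut lower = [4.0, 3.0];
        update_bounds_sketch(&[0, 2], &mut upper, &mut lower, &[1.0, 0.5, 2.0]);
        assert_eq!(upper, [6.0, 5.0]);
        assert_eq!(lower, [2.0, 2.0]);
    }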

    #[test]
    fn test_compute_inertia() {
        let obs = array![[0.0, 0.0], [3.0, 4.0]];
        let memberships = array![0_usize, 0];
        let centroids = array![[1.0, 1.0]];
        let inertia = compute_inertia(&L2Dist, obs.view(), &memberships, &centroids);
        // rdistance: (0-1)^2+(0-1)^2 + (3-1)^2+(4-1)^2 = 2 + 13 = 15
        assert_abs_diff_eq!(inertia, 15.0, epsilon = 1e-10);
    }
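
    // A std-only sketch of the quantity checked above: inertia as the sum of
    // squared Euclidean distances from each observation to its assigned
    // centroid. `inertia_sketch` is a hypothetical helper for illustration; it
    // covers the 2-D L2 case only and is not the crate's `compute_inertia`.
    fn inertia_sketch(obs: &[[f64; 2]], memberships: &[usize], centroids: &[[f64; 2]]) -> f64 {
        obs.iter()
            .zip(memberships)
            .map(|(o, &m)| {
                let c = centroids[m];
                (o[0] - c[0]).powi(2) + (o[1] - c[1]).powi(2)
            })
            .sum()
    }

    #[test]
    fn test_inertia_sketch() {
        // Same data as above: 2 + 13 = 15.
        let inertia = inertia_sketch(&[[0.0, 0.0], [3.0, 4.0]], &[0, 0], &[[1.0, 1.0]]);
        assert_eq!(inertia, 15.0);
    }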

    fn test_n_runs_hamerly<D: Distance<f64>>(dist_fn: D) {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
        let yt = function_test_1d(&xt);
        let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();

        for init in &[
            KMeansInit::Random,
            KMeansInit::KMeansPlusPlus,
            KMeansInit::KMeansPara,
        ] {
            let dataset = DatasetBase::from(data.clone());
            let model = KMeans::params_with(3, rng.clone(), dist_fn.clone())
                .n_runs(1)
                .algorithm(KMeansAlgorithm::Hamerly)
                .init_method(init.clone())
                .fit(&dataset)
                .expect("KMeans fitted");
            let clusters = model.predict(dataset);
            let inertia = calc_inertia!(
                dist_fn,
                model.centroids(),
                clusters.records,
                clusters.targets
            );
            let total_dist = model.transform(&clusters.records.view()).sum();
            assert_abs_diff_eq!(inertia, total_dist, epsilon = 1e-5);

            let single_cluster: usize = model.predict(&data.row(0));
            assert_abs_diff_eq!(single_cluster, clusters.targets[0]);

            let dataset2 = DatasetBase::from(clusters.records().clone());
            let model2 = KMeans::params_with(3, rng.clone(), dist_fn.clone())
                .algorithm(KMeansAlgorithm::Hamerly)
                .init_method(init.clone())
                .fit(&dataset2)
                .expect("KMeans fitted");
            let clusters2 = model2.predict(dataset2);
            let inertia2 = calc_inertia!(
                dist_fn,
                model2.centroids(),
                clusters2.records,
                clusters2.targets
            );
            let total_dist2 = model2.transform(&clusters2.records.view()).sum();
            assert_abs_diff_eq!(inertia2, total_dist2, epsilon = 1e-5);

            // model2 keeps the default number of runs instead of n_runs(1), so it
            // should not do worse than model; this is asserted for random init only.
            if *init == KMeansInit::Random {
                assert!(inertia2 <= inertia);
            }
        }
    }

    #[test]
    fn test_n_runs_hamerly_l2dist() {
        test_n_runs_hamerly(L2Dist);
    }

    #[test]
    fn test_n_runs_hamerly_l1dist() {
        test_n_runs_hamerly(L1Dist);
    }

    #[test]
    fn test_hamerly_precomputed_centroids() {
        let rng = Xoshiro256Plus::seed_from_u64(42);
        let data = array![
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [10.0, 10.0],
            [11.0, 10.0],
            [10.0, 11.0]
        ];
        let init_centroids = array![[0.0, 0.0], [10.0, 10.0]];
        let dataset = DatasetBase::from(data);

        let model_lloyd = KMeans::params_with(2, rng.clone(), L2Dist)
            .n_runs(1)
            .algorithm(KMeansAlgorithm::Lloyd)
            .init_method(KMeansInit::Precomputed(init_centroids.clone()))
            .fit(&dataset)
            .expect("Lloyd fitted");
        let model_hamerly = KMeans::params_with(2, rng.clone(), L2Dist)
            .n_runs(1)
            .algorithm(KMeansAlgorithm::Hamerly)
            .init_method(KMeansInit::Precomputed(init_centroids))
            .fit(&dataset)
            .expect("Hamerly fitted");

        assert_abs_diff_eq!(
            model_lloyd.centroids(),
            model_hamerly.centroids(),
            epsilon = 1e-1
        );
        assert_abs_diff_eq!(
            model_lloyd.inertia(),
            model_hamerly.inertia(),
            epsilon = 1e-1
        );
    }

    #[test]
    fn test_hamerly_single_cluster() {
        let rng = Xoshiro256Plus::seed_from_u64(42);
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let dataset = DatasetBase::from(data);
        let model = KMeans::params_with_rng(1, rng)
            .algorithm(KMeansAlgorithm::Hamerly)
            .fit(&dataset)
            .expect("KMeans fitted");
        // With one cluster the centroid should settle at the data mean, [4, 5].
        assert_abs_diff_eq!(model.centroids(), &array![[4.0, 5.0]], epsilon = 1e-4);
    }

    #[test]
    fn test_hamerly_n_clusters_eq_n_samples() {
        let rng = Xoshiro256Plus::seed_from_u64(42);
        let data = array![[1.0, 2.0], [10.0, 20.0], [-5.0, -5.0], [100.0, 0.0]];
        let dataset = DatasetBase::from(data.clone());
        // Every observation starts as its own centroid, so the inertia should be zero.
        let model = KMeans::params_with_rng(4, rng)
            .algorithm(KMeansAlgorithm::Hamerly)
            .init_method(KMeansInit::Precomputed(data))
            .fit(&dataset)
            .expect("KMeans fitted");
        assert_abs_diff_eq!(model.inertia(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_hamerly_single_observation() {
        let rng = Xoshiro256Plus::seed_from_u64(42);
        let data = array![[3.0, 7.0]];
        let dataset = DatasetBase::from(data);
        let model = KMeans::params_with_rng(1, rng)
            .algorithm(KMeansAlgorithm::Hamerly)
            .fit(&dataset)
            .expect("KMeans fitted");
        assert_abs_diff_eq!(model.centroids(), &array![[3.0, 7.0]], epsilon = 1e-10);
        assert_abs_diff_eq!(model.inertia(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_hamerly_identical_data() {
        let rng = Xoshiro256Plus::seed_from_u64(42);
        let data = array![[5.0, 5.0], [5.0, 5.0], [5.0, 5.0], [5.0, 5.0]];
        let dataset = DatasetBase::from(data);
        let model = KMeans::params_with_rng(1, rng)
            .algorithm(KMeansAlgorithm::Hamerly)
            .fit(&dataset)
            .expect("KMeans fitted");
        assert_abs_diff_eq!(model.centroids(), &array![[5.0, 5.0]], epsilon = 1e-10);
        assert_abs_diff_eq!(model.inertia(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_hamerly_high_dimensionality() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let data: Array2<f64> = Array::random_using((200, 50), Uniform::new(-100., 100.), &mut rng);
        let dataset = DatasetBase::from(data);

        let model_lloyd = KMeans::params_with(5, rng.clone(), L2Dist)
            .n_runs(1)
            .algorithm(KMeansAlgorithm::Lloyd)
            .init_method(KMeansInit::Random)
            .fit(&dataset)
            .expect("Lloyd fitted");
        let model_hamerly = KMeans::params_with(5, rng.clone(), L2Dist)
            .n_runs(1)
            .algorithm(KMeansAlgorithm::Hamerly)
            .init_method(KMeansInit::Random)
            .fit(&dataset)
            .expect("Hamerly fitted");

        assert_abs_diff_eq!(
            model_lloyd.inertia(),
            model_hamerly.inertia(),
            epsilon = 1e-5
        );
        assert_abs_diff_eq!(
            model_lloyd.centroids(),
            model_hamerly.centroids(),
            epsilon = 1e-5
        );
    }

    #[test]
    fn test_hamerly_max_n_iterations() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
        let yt = function_test_1d(&xt);
        let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
        let dataset = DatasetBase::from(data);
        // Smoke test: fitting should succeed when capped at 5 iterations.
        let _model = KMeans::params_with(6, rng.clone(), L2Dist)
            .n_runs(1)
            .max_n_iterations(5)
            .algorithm(KMeansAlgorithm::Hamerly)
            .init_method(KMeansInit::Random)
            .fit(&dataset)
            .expect("KMeans fitted");
    }

    #[test]
    fn test_hamerly_tolerance() {
        let rng = Xoshiro256Plus::seed_from_u64(45);
        let data = DatasetBase::from(array![[1., 1.], [11., 11.]]);
        let model = KMeans::params_with_rng(1, rng)
            .tolerance(8.5)
            .algorithm(KMeansAlgorithm::Hamerly)
            .init_method(KMeansInit::Precomputed(array![[0., 0.]]))
            .fit(&data)
            .expect("KMeans fitted");
        // One m_k-means step from [0, 0]: (1 + 11 + 0) / (2 + 1) = 4 per coordinate.
        // The expected centroid implies the fit stops after that step under this tolerance.
        assert_abs_diff_eq!(model.centroids(), &array![[4., 4.]], epsilon = 1e-1);
    }

    #[test]
    fn test_hamerly_predict_transform_consistency() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
        let yt = function_test_1d(&xt);
        let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
        let dataset = DatasetBase::from(data);

        let model = KMeans::params_with(3, rng.clone(), L2Dist)
            .algorithm(KMeansAlgorithm::Hamerly)
            .fit(&dataset)
            .expect("Hamerly fitted");

        let clusters = model.predict(dataset);
        assert!(clusters.targets.iter().all(|&c| c < 3));

        let inertia = calc_inertia!(
            L2Dist,
            model.centroids(),
            clusters.records,
            clusters.targets
        );
        let total_dist = model.transform(&clusters.records.view()).sum();
        assert_abs_diff_eq!(inertia, total_dist, epsilon = 1e-5);
    }

    fn fittable<T: Fit<Array2<f64>, (), KMeansError>>(_: T) {}
    #[test]
    fn thread_rng_fittable() {
        fittable(KMeans::params_with_rng(1, ThreadRng::default()));
    }
}