1use std::cmp::Ordering;
2use std::fmt::Debug;
3
4use crate::k_means::{KMeansParams, KMeansValidParams};
5use crate::{k_means::errors::KMeansError, KMeansInit};
6use crate::{IncrKMeansError, KMeansAlgorithm, KMeansParamsError};
7use linfa::{prelude::*, DatasetBase, Float};
8use linfa_nn::distance::{Distance, L2Dist};
9use ndarray::{Array1, Array2, ArrayBase, ArrayView2, Axis, Data, DataMut, Ix1, Ix2, Zip};
10use ndarray_rand::rand::{Rng, SeedableRng};
11use rand_xoshiro::Xoshiro256Plus;
12
13#[cfg(feature = "serde")]
14use serde_crate::{Deserialize, Serialize};
15
16#[cfg_attr(
17 feature = "serde",
18 derive(Serialize, Deserialize),
19 serde(crate = "serde_crate")
20)]
21#[derive(Clone, Debug, PartialEq)]
22pub struct KMeans<F: Float, D: Distance<F>> {
203 centroids: Array2<F>,
204 cluster_count: Array1<F>,
205 inertia: F,
206 dist_fn: D,
207}
208
209impl<F: Float> KMeans<F, L2Dist> {
210 pub fn params(nclusters: usize) -> KMeansParams<F, Xoshiro256Plus, L2Dist> {
211 KMeansParams::new(nclusters, Xoshiro256Plus::seed_from_u64(42), L2Dist)
212 }
213
214 pub fn params_with_rng<R: Rng>(nclusters: usize, rng: R) -> KMeansParams<F, R, L2Dist> {
215 KMeansParams::new(nclusters, rng, L2Dist)
216 }
217}
218
219impl<F: Float, D: Distance<F>> KMeans<F, D> {
220 pub fn params_with<R: Rng>(nclusters: usize, rng: R, dist_fn: D) -> KMeansParams<F, R, D> {
221 KMeansParams::new(nclusters, rng, dist_fn)
222 }
223
224 pub fn centroids(&self) -> &Array2<F> {
227 &self.centroids
228 }
229
230 pub fn cluster_count(&self) -> &Array1<F> {
232 &self.cluster_count
233 }
234
235 pub fn inertia(&self) -> F {
239 self.inertia
240 }
241}
242
243impl<F: Float, R: Rng + Clone, D: Distance<F>> KMeansValidParams<F, R, D> {
244 fn fit_hamerly<DA: Data<Elem = F>, T>(
249 &self,
250 dataset: &DatasetBase<ArrayBase<DA, Ix2>, T>,
251 ) -> Result<KMeans<F, D>, KMeansError> {
252 let mut rng = self.rng().clone();
253 let observations = dataset.records().view();
254 let mut min_inertia = F::infinity();
255 let mut best_centroids = None;
256 let mut best_memberships = None;
257
258 for _ in 0..self.n_runs() {
259 let centroids =
260 self.init_method()
261 .run(self.dist_fn(), self.n_clusters(), observations, &mut rng);
262 let mut hamerly = HamerlyAlgorithm::new(self.dist_fn(), observations, centroids);
263
264 let mut n_iter = 0;
265 let inertia = loop {
266 if n_iter > 0 {
268 hamerly.reassign_observations();
269 }
270 n_iter += 1;
271
272 let update = hamerly.recompute_centroids();
273
274 if update.convergence_dist < self.tolerance() || n_iter == self.max_n_iterations() {
275 break hamerly.inertia();
276 }
277
278 hamerly.update_bounds(&update.distances_moved);
279 };
280
281 if inertia < min_inertia {
282 min_inertia = inertia;
283 let (centroids, memberships) = hamerly.into_parts();
284 best_centroids = Some(centroids);
285 best_memberships = Some(memberships);
286 }
287 }
288
289 let memberships = best_memberships.unwrap_or_else(|| Array1::zeros(dataset.nsamples()));
290 self.get_kmeans_result(dataset, min_inertia, best_centroids, memberships)
291 }
292
293 fn fit_lloyd<DA: Data<Elem = F>, T>(
295 &self,
296 dataset: &DatasetBase<ArrayBase<DA, Ix2>, T>,
297 ) -> Result<KMeans<F, D>, KMeansError> {
298 let mut rng = self.rng().clone();
299 let observations = dataset.records().view();
300 let n_samples = dataset.nsamples();
301
302 let mut min_inertia = F::infinity();
303 let mut best_centroids = None;
304 let mut memberships = Array1::zeros(n_samples);
305 let mut dists = Array1::zeros(n_samples);
306
307 let n_runs = self.n_runs();
308
309 for _ in 0..n_runs {
310 let mut centroids =
311 self.init_method()
312 .run(self.dist_fn(), self.n_clusters(), observations, &mut rng);
313 let mut n_iter = 0;
314 let inertia = loop {
315 update_memberships_and_dists(
316 self.dist_fn(),
317 ¢roids,
318 &observations,
319 &mut memberships,
320 &mut dists,
321 );
322 let new_centroids = compute_centroids(¢roids, &observations, &memberships);
323 let distance = self
324 .dist_fn()
325 .distance(centroids.view(), new_centroids.view());
326 centroids = new_centroids;
327 n_iter += 1;
328 if distance < self.tolerance() || n_iter == self.max_n_iterations() {
329 break dists.sum();
330 }
331 };
332
333 if inertia < min_inertia {
337 min_inertia = inertia;
338 best_centroids = Some(centroids.clone());
339 }
340 }
341
342 self.get_kmeans_result(dataset, min_inertia, best_centroids, memberships)
343 }
344
345 fn get_kmeans_result<DA: Data<Elem = F>, T>(
346 &self,
347 dataset: &DatasetBase<ArrayBase<DA, Ix2>, T>,
348 min_inertia: F,
349 best_centroids: Option<Array2<F>>,
350 memberships: Array1<usize>,
351 ) -> Result<KMeans<F, D>, KMeansError> {
352 match best_centroids {
353 Some(centroids) => {
354 let mut cluster_count = Array1::zeros(self.n_clusters());
355 memberships
356 .iter()
357 .for_each(|&c| cluster_count[c] += F::one());
358 Ok(KMeans {
359 centroids,
360 cluster_count,
361 inertia: min_inertia / F::cast(dataset.nsamples()),
362 dist_fn: self.dist_fn().clone(),
363 })
364 }
365 _ => Err(KMeansError::InertiaError),
366 }
367 }
368}
369
370impl<F: Float, R: Rng + Clone, DA: Data<Elem = F>, T, D: Distance<F>>
371 Fit<ArrayBase<DA, Ix2>, T, KMeansError> for KMeansValidParams<F, R, D>
372{
373 type Object = KMeans<F, D>;
374
375 fn fit(
380 &self,
381 dataset: &DatasetBase<ArrayBase<DA, Ix2>, T>,
382 ) -> Result<Self::Object, KMeansError> {
383 match self.algorithm() {
384 KMeansAlgorithm::Lloyd => self.fit_lloyd(dataset),
385 KMeansAlgorithm::Hamerly => self.fit_hamerly(dataset),
386 }
387 }
388}
389
390struct CentroidUpdate<F> {
391 distances_moved: Array1<F>,
392 convergence_dist: F,
393}
394
395struct HamerlyAlgorithm<'a, F: Float, D: Distance<F>> {
397 dist_fn: &'a D,
399 observations: ArrayView2<'a, F>,
401 centroids: Array2<F>,
403 memberships: Array1<usize>,
405 upper_bounds: Array1<F>,
407 lower_bounds: Array1<F>,
409 centroid_counts: Array1<usize>,
411 centroid_sums: Array2<F>,
413 prev_memberships: Array1<usize>,
415}
416
417impl<'a, F: Float, D: Distance<F>> HamerlyAlgorithm<'a, F, D> {
418 fn new(dist_fn: &'a D, observations: ArrayView2<'a, F>, centroids: Array2<F>) -> Self {
419 let n_observations = observations.nrows();
420 let mut memberships = Array1::zeros(n_observations);
421 let mut upper_bounds = Array1::zeros(n_observations);
422 let mut lower_bounds = Array1::zeros(n_observations);
423
424 Zip::from(observations.rows())
425 .and(&mut memberships)
426 .and(&mut upper_bounds)
427 .and(&mut lower_bounds)
428 .par_for_each(|obs, membership, upper, lower| {
429 let (idx, closest_dist, second_dist) =
430 two_closest_centroids(dist_fn, ¢roids, &obs);
431 *membership = idx;
432 *upper = closest_dist;
433 *lower = second_dist;
434 });
435
436 let mut centroid_counts: Array1<usize> = Array1::zeros(centroids.nrows());
437 let mut centroid_sums = Array2::zeros(centroids.dim());
438 for (obs, &m) in observations.rows().into_iter().zip(memberships.iter()) {
439 centroid_counts[m] += 1;
440 let mut row = centroid_sums.row_mut(m);
441 row += &obs;
442 }
443
444 let prev_memberships = Array1::zeros(n_observations);
445
446 Self {
447 dist_fn,
448 observations,
449 centroids,
450 memberships,
451 upper_bounds,
452 lower_bounds,
453 centroid_counts,
454 centroid_sums,
455 prev_memberships,
456 }
457 }
458
459 fn nearest_inter_centroid_distances(&self) -> Array1<F> {
460 let mut dists = Array1::zeros(self.centroids.nrows());
461 for (i, centroid) in self.centroids.rows().into_iter().enumerate() {
462 let (_, _, second_dist) =
463 two_closest_centroids(self.dist_fn, &self.centroids, ¢roid);
464 dists[i] = second_dist;
465 }
466 dists
467 }
468
469 fn reassign_observations(&mut self) {
470 let nearest_center_dists = self.nearest_inter_centroid_distances();
471 let centroids = &self.centroids;
472 let observations = self.observations;
473 let dist_fn = self.dist_fn;
474
475 Zip::from(observations.rows())
476 .and(&mut self.memberships)
477 .and(&mut self.upper_bounds)
478 .and(&mut self.lower_bounds)
479 .and(&mut self.prev_memberships)
480 .par_for_each(|obs, membership, upper, lower, prev_slot| {
481 let current = *membership;
482 *prev_slot = current;
483 let threshold = F::max(nearest_center_dists[current] / F::cast(2), *lower);
484
485 if *upper > threshold {
486 *upper = dist_fn.distance(obs.view(), centroids.row(current).view());
487
488 if *upper > threshold {
489 let (idx, closest_dist, second_dist) =
490 two_closest_centroids(dist_fn, centroids, &obs);
491 *membership = idx;
492 *upper = closest_dist;
493 *lower = second_dist;
494 }
495 }
496 });
497
498 for (i, (&old_membership, &new_membership)) in self
499 .prev_memberships
500 .iter()
501 .zip(self.memberships.iter())
502 .enumerate()
503 {
504 if old_membership != new_membership {
505 let observation = self.observations.row(i);
506 self.centroid_counts[old_membership] -= 1;
507 self.centroid_counts[new_membership] += 1;
508 let mut old_centroid_sum = self.centroid_sums.row_mut(old_membership);
509 old_centroid_sum -= &observation;
510 let mut new_centroid_sum = self.centroid_sums.row_mut(new_membership);
511 new_centroid_sum += &observation;
512 }
513 }
514 }
515
516 fn recompute_centroids(&mut self) -> CentroidUpdate<F> {
518 let mut new_centroids = &self.centroid_sums + &self.centroids;
520 Zip::from(new_centroids.rows_mut())
521 .and(&self.centroid_counts)
522 .for_each(|mut centroid_sum, &n_members| {
523 centroid_sum /= F::cast(n_members + 1);
525 });
526
527 let mut distances_moved = Array1::zeros(self.centroids.nrows());
528 Zip::from(&mut distances_moved)
529 .and(self.centroids.rows())
530 .and(new_centroids.rows())
531 .for_each(|d, old, new| *d = self.dist_fn.distance(old, new));
532
533 let convergence_dist = self
534 .dist_fn
535 .distance(self.centroids.view(), new_centroids.view());
536 self.centroids = new_centroids;
537
538 CentroidUpdate {
539 distances_moved,
540 convergence_dist,
541 }
542 }
543
544 fn update_bounds(&mut self, distances_moved: &Array1<F>) {
545 let (farthest_moved_idx, second_farthest_moved_idx) = two_farthest_indices(distances_moved);
546 Zip::from(&self.memberships)
547 .and(&mut self.upper_bounds)
548 .and(&mut self.lower_bounds)
549 .par_for_each(|¢roid_idx, upper, lower| {
550 *upper += distances_moved[centroid_idx];
551 if centroid_idx == farthest_moved_idx {
552 *lower -= distances_moved[second_farthest_moved_idx];
553 } else {
554 *lower -= distances_moved[farthest_moved_idx];
555 }
556 });
557 }
558
559 fn inertia(&self) -> F {
560 compute_inertia(
561 self.dist_fn,
562 self.observations,
563 &self.memberships,
564 &self.centroids,
565 )
566 }
567
568 fn into_parts(self) -> (Array2<F>, Array1<usize>) {
569 (self.centroids, self.memberships)
570 }
571}
572
573fn two_farthest_indices<F: Float>(distances: &Array1<F>) -> (usize, usize) {
579 if distances.len() < 2 {
580 return (0, 0);
581 }
582 let (mut farthest, mut second_farthest) = if distances[1] >= distances[0] {
583 (1, 0)
584 } else {
585 (0, 1)
586 };
587 for i in 2..distances.len() {
588 if distances[i] >= distances[farthest] {
589 second_farthest = farthest;
590 farthest = i;
591 } else if distances[i] > distances[second_farthest] {
592 second_farthest = i;
593 }
594 }
595 (farthest, second_farthest)
596}
597
598fn compute_inertia<F: Float, D: Distance<F>>(
601 dist_fn: &D,
602 observations: ArrayView2<F>,
603 memberships: &Array1<usize>,
604 centroids: &Array2<F>,
605) -> F {
606 observations
607 .rows()
608 .into_iter()
609 .zip(memberships.iter())
610 .map(|(obs, &m)| dist_fn.rdistance(obs.view(), centroids.row(m).view()))
611 .fold(F::zero(), |acc, d| acc + d)
612}
613
614impl<'a, F: Float + Debug, R: Rng + Clone, DA: Data<Elem = F>, T, D: 'a + Distance<F> + Debug>
615 FitWith<'a, ArrayBase<DA, Ix2>, T, IncrKMeansError<KMeans<F, D>>>
616 for KMeansValidParams<F, R, D>
617{
618 type ObjectIn = Option<KMeans<F, D>>;
619 type ObjectOut = KMeans<F, D>;
620
621 fn fit_with(
636 &self,
637 model: Self::ObjectIn,
638 dataset: &'a DatasetBase<ArrayBase<DA, Ix2>, T>,
639 ) -> Result<Self::ObjectOut, IncrKMeansError<Self::ObjectOut>> {
640 if *self.algorithm() == KMeansAlgorithm::Hamerly {
641 return Err(IncrKMeansError::InvalidParams(
642 KMeansParamsError::IncrementalHamerly,
643 ));
644 }
645 let observations = dataset.records().view();
646 let n_samples = dataset.nsamples();
647
648 let mut model = match model {
649 Some(model) => model,
650 None => {
651 let centroids = if let KMeansInit::Precomputed(centroids) = self.init_method() {
652 centroids.clone()
654 } else {
655 let mut rng = self.rng().clone();
656 let mut dists = Array1::zeros(n_samples);
657 (0..self.n_runs())
660 .map(|_| {
661 let centroids = self.init_method().run(
662 self.dist_fn(),
663 self.n_clusters(),
664 observations,
665 &mut rng,
666 );
667 update_min_dists(self.dist_fn(), ¢roids, &observations, &mut dists);
668 (centroids, dists.sum())
669 })
670 .min_by(|(_, d1), (_, d2)| {
671 if d1 < d2 {
672 Ordering::Less
673 } else {
674 Ordering::Greater
675 }
676 })
677 .unwrap()
678 .0
679 };
680 KMeans {
681 centroids,
682 cluster_count: Array1::zeros(self.n_clusters()),
683 inertia: F::zero(),
684 dist_fn: self.dist_fn().clone(),
685 }
686 }
687 };
688
689 let mut memberships = Array1::zeros(n_samples);
690 let mut dists = Array1::zeros(n_samples);
691 update_memberships_and_dists(
692 self.dist_fn(),
693 &model.centroids,
694 &observations,
695 &mut memberships,
696 &mut dists,
697 );
698 let new_centroids = compute_centroids_incremental(
699 &observations,
700 &memberships,
701 &model.centroids,
702 &mut model.cluster_count,
703 );
704 model.inertia = dists.sum() / F::cast(n_samples);
705 let dist = self
706 .dist_fn()
707 .distance(model.centroids.view(), new_centroids.view());
708 model.centroids = new_centroids;
709
710 if dist < self.tolerance() {
711 Ok(model)
712 } else {
713 Err(IncrKMeansError::NotConverged(model))
714 }
715 }
716}
717
718impl<F: Float, DA: Data<Elem = F>, D: Distance<F>> Transformer<&ArrayBase<DA, Ix2>, Array1<F>>
719 for KMeans<F, D>
720{
721 fn transform(&self, observations: &ArrayBase<DA, Ix2>) -> Array1<F> {
724 let mut dists = Array1::zeros(observations.nrows());
725 update_min_dists(
726 &self.dist_fn,
727 &self.centroids,
728 &observations.view(),
729 &mut dists,
730 );
731 dists
732 }
733}
734
735impl<F: Float, DA: Data<Elem = F>, D: Distance<F>> PredictInplace<ArrayBase<DA, Ix2>, Array1<usize>>
736 for KMeans<F, D>
737{
738 fn predict_inplace(&self, observations: &ArrayBase<DA, Ix2>, memberships: &mut Array1<usize>) {
744 assert_eq!(
745 observations.nrows(),
746 memberships.len(),
747 "The number of data points must match the number of memberships."
748 );
749
750 update_cluster_memberships(
751 &self.dist_fn,
752 &self.centroids,
753 &observations.view(),
754 memberships,
755 );
756 }
757
758 fn default_target(&self, x: &ArrayBase<DA, Ix2>) -> Array1<usize> {
759 Array1::zeros(x.nrows())
760 }
761}
762
763impl<F: Float, DA: Data<Elem = F>, D: Distance<F>> PredictInplace<ArrayBase<DA, Ix1>, usize>
764 for KMeans<F, D>
765{
766 fn predict_inplace(&self, observation: &ArrayBase<DA, Ix1>, membership: &mut usize) {
771 *membership = closest_centroid(&self.dist_fn, &self.centroids, observation).0;
772 }
773
774 fn default_target(&self, _x: &ArrayBase<DA, Ix1>) -> usize {
775 0
776 }
777}
778
779fn compute_centroids<F: Float>(
786 old_centroids: &Array2<F>,
787 observations: &ArrayBase<impl Data<Elem = F>, Ix2>,
789 cluster_memberships: &ArrayBase<impl Data<Elem = usize>, Ix1>,
791) -> Array2<F> {
792 let n_clusters = old_centroids.nrows();
793 let mut counts: Array1<usize> = Array1::ones(n_clusters);
794 let mut centroids = Array2::zeros((n_clusters, observations.ncols()));
795
796 Zip::from(observations.rows())
797 .and(cluster_memberships)
798 .for_each(|observation, &cluster_membership| {
799 let mut centroid = centroids.row_mut(cluster_membership);
800 centroid += &observation;
801 counts[cluster_membership] += 1;
802 });
803 centroids += old_centroids;
805
806 Zip::from(centroids.rows_mut())
807 .and(&counts)
808 .for_each(|mut centroid, &cnt| centroid /= F::cast(cnt));
809 centroids
810}
811
812fn compute_centroids_incremental<F: Float>(
816 observations: &ArrayBase<impl Data<Elem = F>, Ix2>,
817 cluster_memberships: &ArrayBase<impl Data<Elem = usize>, Ix1>,
818 old_centroids: &ArrayBase<impl Data<Elem = F>, Ix2>,
819 counts: &mut ArrayBase<impl DataMut<Elem = F>, Ix1>,
820) -> Array2<F> {
821 let mut centroids = old_centroids.to_owned();
822 Zip::from(observations.rows())
824 .and(cluster_memberships)
825 .for_each(|obs, &c| {
826 counts[c] += F::one();
830 let shift = (&obs - ¢roids.row(c)) / counts[c];
831 let mut centroid = centroids.row_mut(c);
832 centroid += &shift;
833 });
834 centroids
835}
836
837pub(crate) fn update_cluster_memberships<F: Float, D: Distance<F>>(
839 dist_fn: &D,
840 centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
841 observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
842 cluster_memberships: &mut ArrayBase<impl DataMut<Elem = usize>, Ix1>,
843) {
844 Zip::from(observations.axis_iter(Axis(0)))
845 .and(cluster_memberships)
846 .par_for_each(|observation, cluster_membership| {
847 *cluster_membership = closest_centroid(dist_fn, centroids, &observation).0
848 });
849}
850
851pub(crate) fn update_min_dists<F: Float, D: Distance<F>>(
853 dist_fn: &D,
854 centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
855 observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
856 dists: &mut ArrayBase<impl DataMut<Elem = F>, Ix1>,
857) {
858 Zip::from(observations.axis_iter(Axis(0)))
859 .and(dists)
860 .par_for_each(|observation, dist| {
861 *dist = closest_centroid(dist_fn, centroids, &observation).1
862 });
863}
864
865pub(crate) fn update_memberships_and_dists<F: Float, D: Distance<F>>(
867 dist_fn: &D,
868 centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
869 observations: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
870 cluster_memberships: &mut ArrayBase<impl DataMut<Elem = usize>, Ix1>,
871 dists: &mut ArrayBase<impl DataMut<Elem = F>, Ix1>,
872) {
873 Zip::from(observations.axis_iter(Axis(0)))
874 .and(cluster_memberships)
875 .and(dists)
876 .par_for_each(|observation, cluster_membership, dist| {
877 let (m, d) = closest_centroid(dist_fn, centroids, &observation);
878 *cluster_membership = m;
879 *dist = d;
880 });
881}
882
883fn two_closest_centroids<F: Float, D: Distance<F>>(
890 dist_fn: &D,
891 centroids: &ArrayBase<impl Data<Elem = F>, Ix2>,
893 observation: &ArrayBase<impl Data<Elem = F>, Ix1>,
895) -> (usize, F, F) {
896 if centroids.nrows() == 1 {
897 return (0, F::cast(0), F::cast(0));
898 }
899 let first_centroid = centroids.row(0);
900 let second_centroid = centroids.row(1);
901 let dist1 = dist_fn.distance(observation.view(), first_centroid.view());
902 let dist2 = dist_fn.distance(observation.view(), second_centroid.view());
903
904 let mut closest_index = if dist1 < dist2 { 0 } else { 1 };
905 let mut closest_distance = if dist1 < dist2 { dist1 } else { dist2 };
906 let mut second_closest_distance = if dist1 < dist2 { dist2 } else { dist1 };
907
908 for (centroid_index, centroid) in centroids.rows().into_iter().skip(2).enumerate() {
909 let distance = dist_fn.distance(observation.view(), centroid.view());
910 if closest_distance <= distance && distance < second_closest_distance {
911 second_closest_distance = distance;
912 } else if distance < closest_distance {
913 second_closest_distance = closest_distance;
914 closest_index = centroid_index + 2; closest_distance = distance;
916 }
917 }
918 (closest_index, closest_distance, second_closest_distance)
919}
920
921pub(crate) fn closest_centroid<F: Float, D: Distance<F>>(
924 dist_fn: &D,
925 centroids: &ArrayBase<impl Data<Elem = F>, Ix2>,
927 observation: &ArrayBase<impl Data<Elem = F>, Ix1>,
929) -> (usize, F) {
930 let iterator = centroids.rows().into_iter();
931
932 let first_centroid = centroids.row(0);
933 let (mut closest_index, mut minimum_distance) = (
934 0,
935 dist_fn.rdistance(first_centroid.view(), observation.view()),
936 );
937
938 for (centroid_index, centroid) in iterator.enumerate() {
939 let distance = dist_fn.rdistance(centroid.view(), observation.view());
940 if distance < minimum_distance {
941 closest_index = centroid_index;
942 minimum_distance = distance;
943 }
944 }
945 (closest_index, minimum_distance)
946}
947
948#[cfg(test)]
949mod tests {
950 use super::super::KMeansInit;
951 use super::*;
952 use crate::KMeansParamsError;
953 use approx::assert_abs_diff_eq;
954 use linfa_nn::distance::L1Dist;
955 use ndarray::{array, concatenate, Array, Array1, Array2, Axis};
956 use ndarray_rand::rand::prelude::ThreadRng;
957 use ndarray_rand::rand::SeedableRng;
958 use ndarray_rand::rand_distr::Uniform;
959 use ndarray_rand::RandomExt;
960
961 #[test]
962 fn autotraits() {
963 fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
964 has_autotraits::<KMeans<f64, L2Dist>>();
965 has_autotraits::<KMeansAlgorithm>();
966 has_autotraits::<KMeansParamsError>();
967 has_autotraits::<KMeansError>();
968 has_autotraits::<IncrKMeansError<String>>();
969 }
970
971 fn function_test_1d(x: &Array2<f64>) -> Array2<f64> {
972 let mut y = Array2::zeros(x.dim());
973 Zip::from(&mut y).and(x).for_each(|yi, &xi| {
974 if xi < 0.4 {
975 *yi = xi * xi;
976 } else if (0.4..0.8).contains(&xi) {
977 *yi = 3. * xi + 1.;
978 } else {
979 *yi = f64::sin(10. * xi);
980 }
981 });
982 y
983 }
984
985 macro_rules! calc_inertia {
986 ($dist:expr, $centroids:expr, $obs:expr, $memberships:expr) => {
987 $obs.rows()
988 .into_iter()
989 .zip($memberships.iter())
990 .map(|(row, &c)| $dist.rdistance(row.view(), $centroids.row(c).view()))
991 .sum::<f64>()
992 };
993 }
994
995 macro_rules! calc_memberships {
996 ($dist:expr, $centroids:expr, $obs:expr) => {{
997 let mut memberships = Array1::zeros($obs.nrows());
998 update_cluster_memberships(&$dist, &$centroids, &$obs, &mut memberships);
999 memberships
1000 }};
1001 }
1002
1003 #[test]
1004 fn test_min_dists() {
1005 let centroids = array![[0.0, 1.0], [40.0, 10.0]];
1006 let observations = array![[3.0, 4.0], [1.0, 3.0], [25.0, 15.0]];
1007 let mut dists = Array1::zeros(observations.nrows());
1008
1009 update_min_dists(&L2Dist, ¢roids, &observations, &mut dists);
1010 assert_abs_diff_eq!(dists, array![18.0, 5.0, 250.0]);
1011 update_min_dists(&L1Dist, ¢roids, &observations, &mut dists);
1012 assert_abs_diff_eq!(dists, array![6.0, 3.0, 20.0]);
1013 }
1014
1015 fn test_n_runs<D: Distance<f64>>(dist_fn: D) {
1016 let mut rng = Xoshiro256Plus::seed_from_u64(42);
1017 let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
1018 let yt = function_test_1d(&xt);
1019 let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
1020
1021 for init in &[
1022 KMeansInit::Random,
1023 KMeansInit::KMeansPlusPlus,
1024 KMeansInit::KMeansPara,
1025 ] {
1026 let dataset = DatasetBase::from(data.clone());
1028 let model = KMeans::params_with(3, rng.clone(), dist_fn.clone())
1029 .n_runs(1)
1030 .init_method(init.clone())
1031 .fit(&dataset)
1032 .expect("KMeans fitted");
1033 let clusters = model.predict(dataset);
1034 let inertia = calc_inertia!(
1035 dist_fn,
1036 model.centroids(),
1037 clusters.records,
1038 clusters.targets
1039 );
1040 let total_dist = model.transform(&clusters.records.view()).sum();
1041 assert_abs_diff_eq!(inertia, total_dist, epsilon = 1e-5);
1042
1043 let single_cluster: usize = model.predict(&data.row(0));
1044 assert_abs_diff_eq!(single_cluster, clusters.targets[0]);
1045
1046 let dataset2 = DatasetBase::from(clusters.records().clone());
1048 let model2 = KMeans::params_with(3, rng.clone(), dist_fn.clone())
1049 .init_method(init.clone())
1050 .fit(&dataset2)
1051 .expect("KMeans fitted");
1052 let clusters2 = model2.predict(dataset2);
1053 let inertia2 = calc_inertia!(
1054 dist_fn,
1055 model2.centroids(),
1056 clusters2.records,
1057 clusters2.targets
1058 );
1059 let total_dist2 = model2.transform(&clusters2.records.view()).sum();
1060 assert_abs_diff_eq!(inertia2, total_dist2, epsilon = 1e-5);
1061
1062 if *init == KMeansInit::Random {
1064 assert!(inertia2 <= inertia);
1065 }
1066 }
1067 }
1068
1069 #[test]
1070 fn test_n_runs_l2dist() {
1071 test_n_runs(L2Dist);
1072 }
1073
1074 #[test]
1075 fn test_n_runs_l1dist() {
1076 test_n_runs(L1Dist);
1077 }
1078
1079 #[test]
1080 fn compute_centroids_works() {
1081 let cluster_size = 100;
1082 let n_features = 4;
1083
1084 let cluster_1: Array2<f64> =
1086 Array::random((cluster_size, n_features), Uniform::new(-100., 100.));
1087 let memberships_1 = Array1::zeros(cluster_size);
1088 let expected_centroid_1 = cluster_1.sum_axis(Axis(0)) / (cluster_size + 1) as f64;
1089
1090 let cluster_2: Array2<f64> =
1091 Array::random((cluster_size, n_features), Uniform::new(-100., 100.));
1092 let memberships_2 = Array1::ones(cluster_size);
1093 let expected_centroid_2 = cluster_2.sum_axis(Axis(0)) / (cluster_size + 1) as f64;
1094
1095 let observations = concatenate(Axis(0), &[cluster_1.view(), cluster_2.view()]).unwrap();
1097 let memberships =
1098 concatenate(Axis(0), &[memberships_1.view(), memberships_2.view()]).unwrap();
1099
1100 let old_centroids = Array2::zeros((2, n_features));
1102 let centroids = compute_centroids(&old_centroids, &observations, &memberships);
1103 assert_abs_diff_eq!(
1104 centroids.index_axis(Axis(0), 0),
1105 expected_centroid_1,
1106 epsilon = 1e-5
1107 );
1108 assert_abs_diff_eq!(
1109 centroids.index_axis(Axis(0), 1),
1110 expected_centroid_2,
1111 epsilon = 1e-5
1112 );
1113
1114 assert_eq!(centroids.len_of(Axis(0)), 2);
1115 }
1116
1117 #[test]
1118 fn test_compute_extra_centroids() {
1119 let observations = array![[1.0, 2.0]];
1120 let memberships = array![0];
1121 let old_centroids = Array2::ones((2, 2));
1123 let centroids = compute_centroids(&old_centroids, &observations, &memberships);
1124 assert_abs_diff_eq!(centroids, array![[1.0, 1.5], [1.0, 1.0]]);
1125 }
1126
1127 #[test]
1128 fn nothing_is_closer_than_self() {
1130 let n_centroids = 20;
1131 let n_features = 5;
1132 let mut rng = Xoshiro256Plus::seed_from_u64(42);
1133 let centroids: Array2<f64> = Array::random_using(
1134 (n_centroids, n_features),
1135 Uniform::new(-100., 100.),
1136 &mut rng,
1137 );
1138
1139 let expected_memberships = (0..n_centroids).collect::<Array1<_>>();
1140 assert_eq!(
1141 calc_memberships!(L2Dist, centroids, centroids),
1142 expected_memberships
1143 );
1144 assert_eq!(
1145 calc_memberships!(L1Dist, centroids, centroids),
1146 expected_memberships
1147 );
1148 }
1149
1150 #[test]
1151 fn oracle_test_for_closest_centroid() {
1152 let centroids = array![[0., 0.], [1., 2.], [20., 0.], [0., 20.],];
1153 let observations = array![[1., 0.6], [20., 2.], [20., 0.], [7., 20.],];
1154 let l2_memberships = array![0, 2, 2, 3];
1155 let l1_memberships = array![1, 2, 2, 3];
1156
1157 assert_eq!(
1158 calc_memberships!(L2Dist, centroids, observations),
1159 l2_memberships
1160 );
1161 assert_eq!(
1162 calc_memberships!(L1Dist, centroids, observations),
1163 l1_memberships
1164 );
1165 }
1166
1167 #[test]
1168 fn test_compute_centroids_incremental() {
1169 let observations = array![[-1.0, -3.0], [0., 0.], [3., 5.], [5., 5.]];
1170 let memberships = array![0, 0, 1, 1];
1171 let centroids = array![[-1., -1.], [3., 4.], [7., 8.]];
1172 let mut counts = array![3.0, 0.0, 1.0];
1173 let centroids =
1174 compute_centroids_incremental(&observations, &memberships, ¢roids, &mut counts);
1175
1176 assert_abs_diff_eq!(centroids, array![[-4. / 5., -6. / 5.], [4., 5.], [7., 8.]]);
1177 assert_abs_diff_eq!(counts, array![5., 2., 1.]);
1178 }
1179
1180 #[test]
1181 fn test_incremental_kmeans() {
1182 let dataset1 = DatasetBase::from(array![[-1.0, -3.0], [0., 0.], [3., 5.], [5., 5.]]);
1183 let dataset2 = DatasetBase::from(array![[-5.0, -5.0], [0., 0.], [10., 10.]]);
1184 let model = KMeans {
1185 centroids: array![[-1., -1.], [3., 4.], [7., 8.]],
1186 cluster_count: array![0., 0., 0.],
1187 inertia: 0.0,
1188 dist_fn: L2Dist,
1189 };
1190 let rng = Xoshiro256Plus::seed_from_u64(45);
1191 let params = KMeans::params_with_rng(3, rng).tolerance(100.0);
1192
1193 let model = params.fit_with(Some(model), &dataset1).unwrap();
1195 assert_abs_diff_eq!(model.centroids(), &array![[-0.5, -1.5], [4., 5.], [7., 8.]]);
1196
1197 let model = params.fit_with(Some(model), &dataset2).unwrap();
1198 assert_abs_diff_eq!(
1199 model.centroids(),
1200 &array![[-6. / 4., -8. / 4.], [4., 5.], [10., 10.]]
1201 );
1202 }
1203
1204 #[test]
1205 fn fit_with_rejects_hamerly() {
1206 let rng = Xoshiro256Plus::seed_from_u64(45);
1207 let params = KMeans::params_with_rng(2, rng)
1208 .algorithm(KMeansAlgorithm::Hamerly)
1209 .init_method(KMeansInit::Precomputed(array![[0., 0.], [10., 10.]]));
1210 let data = DatasetBase::from(array![[1., 1.], [11., 11.]]);
1211 let err = params
1212 .fit_with(None, &data)
1213 .expect_err("Hamerly + fit_with must be rejected");
1214 assert!(matches!(
1215 err,
1216 IncrKMeansError::InvalidParams(KMeansParamsError::IncrementalHamerly)
1217 ));
1218 }
1219
1220 #[test]
1221 fn test_tolerance() {
1222 let rng = Xoshiro256Plus::seed_from_u64(45);
1223 let params = KMeans::params_with_rng(1, rng)
1227 .tolerance(8.5)
1228 .init_method(KMeansInit::Precomputed(array![[0., 0.]]));
1229 let data = DatasetBase::from(array![[1., 1.], [11., 11.]]);
1230 assert!(params.fit_with(None, &data).is_ok());
1231 }
1232
1233 #[test]
1234 fn test_max_n_iterations() {
1235 let mut rng = Xoshiro256Plus::seed_from_u64(42);
1236 let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
1237 let yt = function_test_1d(&xt);
1238 let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
1239 let dataset = DatasetBase::from(data.clone());
1240 let _model = KMeans::params_with(6, rng.clone(), L2Dist)
1243 .n_runs(1)
1244 .max_n_iterations(5)
1245 .init_method(KMeansInit::Random)
1246 .fit(&dataset)
1247 .expect("KMeans fitted");
1248 }
1249
1250 fn sort_centroids(c: &Array2<f64>) -> Array2<f64> {
1251 let mut rows: Vec<Vec<f64>> = c.rows().into_iter().map(|r| r.to_vec()).collect();
1252 rows.sort_by(|a, b| {
1253 for (x, y) in a.iter().zip(b.iter()) {
1254 match x.partial_cmp(y) {
1255 Some(std::cmp::Ordering::Equal) => continue,
1256 Some(ord) => return ord,
1257 None => continue,
1258 }
1259 }
1260 std::cmp::Ordering::Equal
1261 });
1262 let flat: Vec<f64> = rows.into_iter().flatten().collect();
1263 Array2::from_shape_vec((c.nrows(), c.ncols()), flat).unwrap()
1264 }
1265
1266 fn hamerly_lloyd_equivalence<D: Distance<f64>>(dist_fn: D, init: KMeansInit<f64>) {
1267 let mut rng = Xoshiro256Plus::seed_from_u64(42);
1268 let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
1269 let yt = function_test_1d(&xt);
1270 let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
1271 let dataset = DatasetBase::from(data);
1272
1273 let model_lloyd = KMeans::params_with(6, rng.clone(), dist_fn.clone())
1274 .n_runs(3)
1275 .algorithm(KMeansAlgorithm::Lloyd)
1276 .init_method(init.clone())
1277 .fit(&dataset)
1278 .expect("Lloyd fitted");
1279 let model_hamerly = KMeans::params_with(6, rng.clone(), dist_fn)
1280 .n_runs(3)
1281 .algorithm(KMeansAlgorithm::Hamerly)
1282 .init_method(init)
1283 .fit(&dataset)
1284 .expect("Hamerly fitted");
1285
1286 assert_eq!(model_lloyd.centroids().nrows(), 6);
1287 assert_abs_diff_eq!(
1288 model_lloyd.inertia(),
1289 model_hamerly.inertia(),
1290 epsilon = 1e-4
1291 );
1292 assert_abs_diff_eq!(
1293 sort_centroids(model_lloyd.centroids()),
1294 sort_centroids(model_hamerly.centroids()),
1295 epsilon = 1e-4
1296 );
1297 }
1298
1299 #[test]
1300 fn hamerly_lloyd_equivalence_random_l2() {
1301 hamerly_lloyd_equivalence(L2Dist, KMeansInit::Random);
1302 }
1303
1304 #[test]
1305 fn hamerly_lloyd_equivalence_plusplus_l2() {
1306 hamerly_lloyd_equivalence(L2Dist, KMeansInit::KMeansPlusPlus);
1307 }
1308
1309 fn hamerly_lloyd_equivalence_para<D: Distance<f64>>(dist_fn: D) {
1310 let mut rng = Xoshiro256Plus::seed_from_u64(99);
1314 let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
1315 let yt = function_test_1d(&xt);
1316 let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
1317 let dataset = DatasetBase::from(data);
1318 let init = KMeansInit::Precomputed(KMeansInit::KMeansPlusPlus.run(
1319 &dist_fn,
1320 6,
1321 dataset.records().view(),
1322 &mut rng,
1323 ));
1324 hamerly_lloyd_equivalence(dist_fn, init);
1325 }
1326
/// Precomputed-centroid equivalence check under `L2Dist`.
#[test]
fn hamerly_lloyd_equivalence_para_l2() {
    let dist_fn = L2Dist;
    hamerly_lloyd_equivalence_para(dist_fn);
}
1331
/// Lloyd/Hamerly equivalence under `L1Dist` with `KMeansInit::Random`.
#[test]
fn hamerly_lloyd_equivalence_random_l1() {
    let init = KMeansInit::Random;
    hamerly_lloyd_equivalence(L1Dist, init);
}
1336
/// Lloyd/Hamerly equivalence under `L1Dist` with `KMeansInit::KMeansPlusPlus`.
#[test]
fn hamerly_lloyd_equivalence_plusplus_l1() {
    let init = KMeansInit::KMeansPlusPlus;
    hamerly_lloyd_equivalence(L1Dist, init);
}
1341
/// Precomputed-centroid equivalence check under `L1Dist`.
#[test]
fn hamerly_lloyd_equivalence_para_l1() {
    let dist_fn = L1Dist;
    hamerly_lloyd_equivalence_para(dist_fn);
}
1346
/// Under `L2Dist`, the observation `(1, 1)` is closest to centroid 0 at distance `sqrt(2)`;
/// the other two centroids are both at `sqrt(82)`, which is the expected runner-up distance.
#[test]
fn test_two_closest_centroids_l2() {
    let centroids = array![[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]];
    let obs = array![1.0, 1.0];
    let (idx, closest, second) = two_closest_centroids(&L2Dist, &centroids, &obs);
    assert_eq!(idx, 0);
    assert_abs_diff_eq!(closest, f64::sqrt(2.0), epsilon = 1e-10);
    assert_abs_diff_eq!(second, f64::sqrt(82.0), epsilon = 1e-10);
}
1356
/// Under `L1Dist`, the observation `(1, 1)` is closest to centroid 0 at distance 2; the other
/// two centroids are both at distance 10, which is the expected runner-up distance.
#[test]
fn test_two_closest_centroids_l1() {
    let centroids = array![[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]];
    let obs = array![1.0, 1.0];
    let (idx, closest, second) = two_closest_centroids(&L1Dist, &centroids, &obs);
    assert_eq!(idx, 0);
    assert_abs_diff_eq!(closest, 2.0, epsilon = 1e-10);
    assert_abs_diff_eq!(second, 10.0, epsilon = 1e-10);
}
1366
/// With a single centroid the returned index must be 0 and both distances are expected to
/// be 0.
// NOTE(review): 0.0 is not the geometric distance from (1, 1) to (5, 5); presumably
// `two_closest_centroids` special-cases a lone centroid — confirm against its implementation.
#[test]
fn test_two_closest_centroids_single() {
    let centroids = array![[5.0, 5.0]];
    let obs = array![1.0, 1.0];
    let (idx, closest, second) = two_closest_centroids(&L2Dist, &centroids, &obs);
    assert_eq!(idx, 0);
    assert_abs_diff_eq!(closest, 0.0);
    assert_abs_diff_eq!(second, 0.0);
}
1376
/// When the observation coincides with centroid 1, the closest distance is 0 and the
/// runner-up is centroid 0 at distance 5 (the 3-4-5 triangle).
#[test]
fn test_two_closest_centroids_obs_is_centroid() {
    let centroids = array![[0.0, 0.0], [3.0, 4.0], [10.0, 0.0]];
    let obs = array![3.0, 4.0];
    let (idx, closest, second) = two_closest_centroids(&L2Dist, &centroids, &obs);
    assert_eq!(idx, 1);
    assert_abs_diff_eq!(closest, 0.0, epsilon = 1e-10);
    assert_abs_diff_eq!(second, 5.0, epsilon = 1e-10);
}
1386
/// Both centroids lie at `sqrt(2)` from the observation; the assertion pins the tie to
/// index 1, and both reported distances must equal `sqrt(2)`.
#[test]
fn test_two_closest_centroids_equidistant() {
    let centroids = array![[2.0, 0.0], [0.0, 2.0]];
    let obs = array![1.0, 1.0];
    let (idx, closest, second) = two_closest_centroids(&L2Dist, &centroids, &obs);
    assert_eq!(idx, 1);
    assert_abs_diff_eq!(closest, f64::sqrt(2.0), epsilon = 1e-10);
    assert_abs_diff_eq!(second, f64::sqrt(2.0), epsilon = 1e-10);
}
1397
/// Table-driven check of `two_farthest_indices`: each entry pairs an input array with the
/// expected index pair.
#[test]
fn test_two_farthest_indices() {
    let cases = vec![
        // Distinct values.
        (array![1.0, 5.0, 3.0, 2.0], (1, 2)),
        // All values equal.
        (array![3.0, 3.0, 3.0], (2, 1)),
        // Two elements, in both orders.
        (array![2.0, 7.0], (1, 0)),
        (array![7.0, 2.0], (0, 1)),
        // Two largest values at the ends, in both orders.
        (array![8.0, 1.0, 2.0, 9.0], (3, 0)),
        (array![9.0, 1.0, 2.0, 8.0], (0, 3)),
        // Single element: both indices are 0.
        (array![1.0], (0, 0)),
    ];

    for (values, expected) in &cases {
        assert_eq!(two_farthest_indices(values), *expected);
    }
}
1419
/// Overrides the running sums and counts of a `HamerlyAlgorithm` and checks the centroids
/// produced by `recompute_centroids`.
///
/// The expected values are consistent with `(old centroid + sum) / (count + 1)`; a cluster
/// with count 0 therefore keeps its previous centroid.
#[test]
fn test_recompute_centroids() {
    let obs = array![[0.0, 0.0]];

    // Both clusters populated, old centroids at the origin.
    let start_at_origin = array![[0.0, 0.0], [0.0, 0.0]];
    let mut populated = HamerlyAlgorithm::new(&L2Dist, obs.view(), start_at_origin);
    populated.centroid_sums = array![[8.0, 12.0], [15.0, 30.0]];
    populated.centroid_counts = array![3_usize, 2];
    populated.recompute_centroids();
    let expected_populated = array![[2.0, 3.0], [5.0, 10.0]];
    assert_abs_diff_eq!(populated.centroids, expected_populated, epsilon = 1e-10);

    // First cluster empty: its centroid must stay at (7, 7).
    let start_with_offset = array![[7.0, 7.0], [0.0, 0.0]];
    let mut with_empty = HamerlyAlgorithm::new(&L2Dist, obs.view(), start_with_offset);
    with_empty.centroid_sums = array![[0.0, 0.0], [15.0, 30.0]];
    with_empty.centroid_counts = array![0_usize, 2];
    with_empty.recompute_centroids();
    let expected_with_empty = array![[7.0, 7.0], [5.0, 10.0]];
    assert_abs_diff_eq!(with_empty.centroids, expected_with_empty, epsilon = 1e-10);
}
1447
/// Checks the `distances_moved` field of the value returned by `recompute_centroids`.
///
/// The expected movements are consistent with new centroids of `(old + sum) / (count + 1)`:
/// `(0,0) -> (1,0)` and `(10,0) -> (10,3)` in the first case, no movement in the second.
#[test]
fn test_recompute_centroids_distances_moved() {
    let obs = array![[0.0, 0.0]];

    // Centroids that move by 1 and 3 respectively.
    let mut moving = HamerlyAlgorithm::new(&L2Dist, obs.view(), array![[0.0, 0.0], [10.0, 0.0]]);
    moving.centroid_sums = array![[2.0, 0.0], [10.0, 6.0]];
    moving.centroid_counts = array![1_usize, 1];
    let moved = moving.recompute_centroids();
    assert_abs_diff_eq!(moved.distances_moved, array![1.0, 3.0], epsilon = 1e-10);

    // Sums equal to the old centroids leave every centroid in place.
    let mut stationary =
        HamerlyAlgorithm::new(&L2Dist, obs.view(), array![[5.0, 5.0], [10.0, 10.0]]);
    stationary.centroid_sums = array![[5.0, 5.0], [10.0, 10.0]];
    stationary.centroid_counts = array![1_usize, 1];
    let unmoved = stationary.recompute_centroids();
    assert_abs_diff_eq!(unmoved.distances_moved, array![0.0, 0.0], epsilon = 1e-10);
}
1468
/// For each centroid, `nearest_inter_centroid_distances` must report the distance to its
/// closest other centroid.
#[test]
fn test_nearest_inter_centroid_distances() {
    let obs = array![[0.0, 0.0]];

    // Right triangle with legs 3 and 4: nearest neighbours are at 3, 3 and 4.
    let triangle = array![[0.0, 0.0], [3.0, 0.0], [0.0, 4.0]];
    let three_centroids = HamerlyAlgorithm::new(&L2Dist, obs.view(), triangle);
    let triangle_dists = three_centroids.nearest_inter_centroid_distances();
    assert_abs_diff_eq!(triangle_dists, array![3.0, 3.0, 4.0], epsilon = 1e-10);

    // Two centroids: each one's nearest neighbour is the other, 5 apart.
    let pair = array![[0.0, 0.0], [5.0, 0.0]];
    let two_centroids = HamerlyAlgorithm::new(&L2Dist, obs.view(), pair);
    let pair_dists = two_centroids.nearest_inter_centroid_distances();
    assert_abs_diff_eq!(pair_dists, array![5.0, 5.0], epsilon = 1e-10);
}
1483
/// `HamerlyAlgorithm::new` must assign each observation to its nearest centroid and
/// initialise the per-cluster counts and coordinate sums accordingly.
#[test]
fn test_hamerly_strategy_new() {
    let observations = array![[0.0, 0.0], [1.0, 0.0], [10.0, 10.0]];
    let initial_centroids = array![[0.0, 0.0], [10.0, 10.0]];
    let state = HamerlyAlgorithm::new(&L2Dist, observations.view(), initial_centroids);

    // The first two observations belong to cluster 0, the last one to cluster 1.
    assert_eq!(state.memberships, array![0_usize, 0, 1]);
    assert_eq!(state.centroid_counts, array![2_usize, 1]);

    let expected_sums = array![[1.0, 0.0], [10.0, 10.0]];
    assert_abs_diff_eq!(state.centroid_sums, expected_sums, epsilon = 1e-10);
}
1497
/// Checks `update_bounds` against hand-computed values.
///
/// The expected upper bounds equal the old bound plus the movement of the observation's own
/// centroid. The expected lower bounds equal the old bound minus the movement of the other
/// centroid.
#[test]
fn test_update_bounds_oracle() {
    let observations = array![[0.0, 0.0], [10.0, 0.0], [0.0, 0.0]];
    let initial_centroids = array![[0.0, 0.0], [10.0, 0.0]];
    let mut state = HamerlyAlgorithm::new(&L2Dist, observations.view(), initial_centroids);

    // Overwrite the state computed by `new` with known memberships and bounds.
    state.memberships = array![0_usize, 1, 0];
    state.upper_bounds = array![5.0, 3.0, 4.0];
    state.lower_bounds = array![2.0, 1.0, 3.0];

    let moved = array![1.0, 0.5];
    state.update_bounds(&moved);

    let expected_upper = array![6.0, 3.5, 5.0];
    let expected_lower = array![1.5, 0.0, 2.5];
    assert_abs_diff_eq!(state.upper_bounds, expected_upper, epsilon = 1e-10);
    assert_abs_diff_eq!(state.lower_bounds, expected_lower, epsilon = 1e-10);
}
1511
/// Both observations belong to the single centroid `(1, 1)`.
///
/// The expected inertia of 15 equals the sum of squared Euclidean distances:
/// `(1 + 1) + (4 + 9)`.
#[test]
fn test_compute_inertia() {
    let obs = array![[0.0, 0.0], [3.0, 4.0]];
    let memberships = array![0_usize, 0];
    let centroids = array![[1.0, 1.0]];
    let inertia = compute_inertia(&L2Dist, obs.view(), &memberships, &centroids);
    assert_abs_diff_eq!(inertia, 15.0, epsilon = 1e-10);
}
1521
1522 fn test_n_runs_hamerly<D: Distance<f64>>(dist_fn: D) {
1523 let mut rng = Xoshiro256Plus::seed_from_u64(42);
1524 let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
1525 let yt = function_test_1d(&xt);
1526 let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
1527
1528 for init in &[
1529 KMeansInit::Random,
1530 KMeansInit::KMeansPlusPlus,
1531 KMeansInit::KMeansPara,
1532 ] {
1533 let dataset = DatasetBase::from(data.clone());
1534 let model = KMeans::params_with(3, rng.clone(), dist_fn.clone())
1535 .n_runs(1)
1536 .algorithm(KMeansAlgorithm::Hamerly)
1537 .init_method(init.clone())
1538 .fit(&dataset)
1539 .expect("KMeans fitted");
1540 let clusters = model.predict(dataset);
1541 let inertia = calc_inertia!(
1542 dist_fn,
1543 model.centroids(),
1544 clusters.records,
1545 clusters.targets
1546 );
1547 let total_dist = model.transform(&clusters.records.view()).sum();
1548 assert_abs_diff_eq!(inertia, total_dist, epsilon = 1e-5);
1549
1550 let single_cluster: usize = model.predict(&data.row(0));
1551 assert_abs_diff_eq!(single_cluster, clusters.targets[0]);
1552
1553 let dataset2 = DatasetBase::from(clusters.records().clone());
1554 let model2 = KMeans::params_with(3, rng.clone(), dist_fn.clone())
1555 .algorithm(KMeansAlgorithm::Hamerly)
1556 .init_method(init.clone())
1557 .fit(&dataset2)
1558 .expect("KMeans fitted");
1559 let clusters2 = model2.predict(dataset2);
1560 let inertia2 = calc_inertia!(
1561 dist_fn,
1562 model2.centroids(),
1563 clusters2.records,
1564 clusters2.targets
1565 );
1566 let total_dist2 = model2.transform(&clusters2.records.view()).sum();
1567 assert_abs_diff_eq!(inertia2, total_dist2, epsilon = 1e-5);
1568
1569 if *init == KMeansInit::Random {
1570 assert!(inertia2 <= inertia);
1571 }
1572 }
1573 }
1574
/// Runs the Hamerly `n_runs` checks with `L2Dist`.
#[test]
fn test_n_runs_hamerly_l2dist() {
    let dist_fn = L2Dist;
    test_n_runs_hamerly(dist_fn);
}
1579
/// Runs the Hamerly `n_runs` checks with `L1Dist`.
#[test]
fn test_n_runs_hamerly_l1dist() {
    let dist_fn = L1Dist;
    test_n_runs_hamerly(dist_fn);
}
1584
/// Starting from the same precomputed centroids on two well-separated groups of points,
/// Lloyd and Hamerly must end with matching centroids and inertia (tolerance `1e-1`).
#[test]
fn test_hamerly_precomputed_centroids() {
    let rng = Xoshiro256Plus::seed_from_u64(42);
    // Three points near the origin and three near (10, 10).
    let dataset = DatasetBase::from(array![
        [0.0, 0.0],
        [1.0, 0.0],
        [0.0, 1.0],
        [10.0, 10.0],
        [11.0, 10.0],
        [10.0, 11.0]
    ]);
    let init_centroids = array![[0.0, 0.0], [10.0, 10.0]];

    let fit_with = |algorithm: KMeansAlgorithm| {
        KMeans::params_with(2, rng.clone(), L2Dist)
            .n_runs(1)
            .algorithm(algorithm)
            .init_method(KMeansInit::Precomputed(init_centroids.clone()))
            .fit(&dataset)
    };
    let lloyd = fit_with(KMeansAlgorithm::Lloyd).expect("Lloyd fitted");
    let hamerly = fit_with(KMeansAlgorithm::Hamerly).expect("Hamerly fitted");

    assert_abs_diff_eq!(lloyd.centroids(), hamerly.centroids(), epsilon = 1e-1);
    assert_abs_diff_eq!(lloyd.inertia(), hamerly.inertia(), epsilon = 1e-1);
}
1623
/// With one cluster, the fitted centroid must be the mean of the data, `(4, 5)`, within `1e-4`.
#[test]
fn test_hamerly_single_cluster() {
    let dataset = DatasetBase::from(array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]);
    let params = KMeans::params_with_rng(1, Xoshiro256Plus::seed_from_u64(42))
        .algorithm(KMeansAlgorithm::Hamerly);
    let fitted = params.fit(&dataset).expect("KMeans fitted");

    let expected_mean = array![[4.0, 5.0]];
    assert_abs_diff_eq!(fitted.centroids(), &expected_mean, epsilon = 1e-4);
}
1635
/// With as many clusters as samples, and the samples themselves as precomputed centroids,
/// the fitted inertia must be 0.
#[test]
fn test_hamerly_n_clusters_eq_n_samples() {
    let points = array![[1.0, 2.0], [10.0, 20.0], [-5.0, -5.0], [100.0, 0.0]];
    let dataset = DatasetBase::from(points.clone());
    let seeded_rng = Xoshiro256Plus::seed_from_u64(42);

    let fitted = KMeans::params_with_rng(4, seeded_rng)
        .algorithm(KMeansAlgorithm::Hamerly)
        .init_method(KMeansInit::Precomputed(points))
        .fit(&dataset)
        .expect("KMeans fitted");

    assert_abs_diff_eq!(fitted.inertia(), 0.0, epsilon = 1e-10);
}
1648
/// A dataset with one observation and one cluster must yield that observation as the
/// centroid, with inertia 0.
#[test]
fn test_hamerly_single_observation() {
    let dataset = DatasetBase::from(array![[3.0, 7.0]]);
    let seeded_rng = Xoshiro256Plus::seed_from_u64(42);

    let fitted = KMeans::params_with_rng(1, seeded_rng)
        .algorithm(KMeansAlgorithm::Hamerly)
        .fit(&dataset)
        .expect("KMeans fitted");

    let only_point = array![[3.0, 7.0]];
    assert_abs_diff_eq!(fitted.centroids(), &only_point, epsilon = 1e-10);
    assert_abs_diff_eq!(fitted.inertia(), 0.0, epsilon = 1e-10);
}
1661
/// Four identical observations and one cluster: the centroid must be that shared point and
/// the inertia 0.
#[test]
fn test_hamerly_identical_data() {
    let dataset = DatasetBase::from(array![[5.0, 5.0], [5.0, 5.0], [5.0, 5.0], [5.0, 5.0]]);
    let seeded_rng = Xoshiro256Plus::seed_from_u64(42);

    let fitted = KMeans::params_with_rng(1, seeded_rng)
        .algorithm(KMeansAlgorithm::Hamerly)
        .fit(&dataset)
        .expect("KMeans fitted");

    let shared_point = array![[5.0, 5.0]];
    assert_abs_diff_eq!(fitted.centroids(), &shared_point, epsilon = 1e-10);
    assert_abs_diff_eq!(fitted.inertia(), 0.0, epsilon = 1e-10);
}
1674
/// On 200 random observations with 50 features, Lloyd and Hamerly, run once each from clones
/// of the same RNG with random initialization, must agree on inertia and on centroids
/// (row-for-row, tolerance `1e-5`).
#[test]
fn test_hamerly_high_dimensionality() {
    let mut rng = Xoshiro256Plus::seed_from_u64(42);
    let samples: Array2<f64> =
        Array::random_using((200, 50), Uniform::new(-100., 100.), &mut rng);
    let dataset = DatasetBase::from(samples);

    let lloyd_params = KMeans::params_with(5, rng.clone(), L2Dist)
        .n_runs(1)
        .algorithm(KMeansAlgorithm::Lloyd)
        .init_method(KMeansInit::Random);
    let lloyd = lloyd_params.fit(&dataset).expect("Lloyd fitted");

    let hamerly_params = KMeans::params_with(5, rng.clone(), L2Dist)
        .n_runs(1)
        .algorithm(KMeansAlgorithm::Hamerly)
        .init_method(KMeansInit::Random);
    let hamerly = hamerly_params.fit(&dataset).expect("Hamerly fitted");

    assert_abs_diff_eq!(lloyd.inertia(), hamerly.inertia(), epsilon = 1e-5);
    assert_abs_diff_eq!(lloyd.centroids(), hamerly.centroids(), epsilon = 1e-5);
}
1705
/// Fitting with `max_n_iterations(5)` must still return `Ok`; nothing about the resulting
/// model is inspected.
#[test]
fn test_hamerly_max_n_iterations() {
    let mut rng = Xoshiro256Plus::seed_from_u64(42);
    let xs = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
    let ys = function_test_1d(&xs);
    let dataset = DatasetBase::from(concatenate(Axis(1), &[xs.view(), ys.view()]).unwrap());

    let capped_params = KMeans::params_with(6, rng.clone(), L2Dist)
        .n_runs(1)
        .max_n_iterations(5)
        .algorithm(KMeansAlgorithm::Hamerly)
        .init_method(KMeansInit::Random);
    let _fitted = capped_params.fit(&dataset).expect("KMeans fitted");
}
1721
/// With a tolerance of 8.5 and a precomputed start at the origin, the fit on `(1, 1)` and
/// `(11, 11)` is expected to end with the centroid near `(4, 4)` (tolerance `1e-1`).
#[test]
fn test_hamerly_tolerance() {
    let seeded_rng = Xoshiro256Plus::seed_from_u64(45);
    let dataset = DatasetBase::from(array![[1., 1.], [11., 11.]]);
    let start = KMeansInit::Precomputed(array![[0., 0.]]);

    let fitted = KMeans::params_with_rng(1, seeded_rng)
        .tolerance(8.5)
        .algorithm(KMeansAlgorithm::Hamerly)
        .init_method(start)
        .fit(&dataset)
        .expect("KMeans fitted");

    let expected_centroid = array![[4., 4.]];
    assert_abs_diff_eq!(fitted.centroids(), &expected_centroid, epsilon = 1e-1);
}
1734
/// After a Hamerly fit with three clusters, every predicted label must be below 3, and the
/// inertia from `calc_inertia!` must match the summed `transform` output within `1e-5`.
#[test]
fn test_hamerly_predict_transform_consistency() {
    let mut rng = Xoshiro256Plus::seed_from_u64(42);
    let xs = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
    let ys = function_test_1d(&xs);
    let dataset = DatasetBase::from(concatenate(Axis(1), &[xs.view(), ys.view()]).unwrap());

    let fitted = KMeans::params_with(3, rng.clone(), L2Dist)
        .algorithm(KMeansAlgorithm::Hamerly)
        .fit(&dataset)
        .expect("Hamerly fitted");

    let labelled = fitted.predict(dataset);
    // No label may fall outside the three fitted clusters.
    assert!(!labelled.targets.iter().any(|&label| label >= 3));

    let inertia = calc_inertia!(
        L2Dist,
        fitted.centroids(),
        labelled.records,
        labelled.targets
    );
    let summed_transform = fitted.transform(&labelled.records.view()).sum();
    assert_abs_diff_eq!(inertia, summed_transform, epsilon = 1e-5);
}
1760
1761 fn fittable<T: Fit<Array2<f64>, (), KMeansError>>(_: T) {}
/// Parameters built with a `ThreadRng` must satisfy the `Fit` bound required by `fittable`.
#[test]
fn thread_rng_fittable() {
    let params = KMeans::params_with_rng(1, ThreadRng::default());
    fittable(params);
}
1766}