linfa/
metrics_clustering.rs

1//! Common metrics for clustering
2use crate::dataset::{AsSingleTargets, DatasetBase, Label, Labels, Records};
3use crate::error::Result;
4use crate::Float;
5use ndarray::{ArrayBase, ArrayView1, Data, Ix2};
6use std::collections::HashMap;
7use std::ops::Sub;
8
9/// Evaluates the quality of a clustering using euclidean distance.
10pub trait SilhouetteScore<F> {
11    /// Evaluates the quality of a clustering.
12    ///
13    /// Given a clustered dataset,
14    /// the silhouette score for each sample is computed as
15    /// the relative difference between the average distance
16    /// of the sample to other samples in the same cluster and
17    /// the minimum average distance of the sample to samples in
18    /// another cluster. This value goes from -1 to +1 when the point
19    /// is respectively closer (in average) to points in another cluster and to points in its own cluster.
20    ///
21    /// Finally, the silhouette score for the clustering is evaluated as the mean
22    /// silhouette score of each sample.
23    fn silhouette_score(&self) -> Result<F>;
24}
25
26struct DistanceCount<F> {
27    total_distance: F,
28    count: usize,
29}
30
31impl<F: Float> DistanceCount<F> {
32    /// Sets the total distance from the sample to this cluster to zero
33    pub fn reset(&mut self) {
34        self.total_distance = F::zero();
35    }
36
37    pub fn new(count: usize) -> DistanceCount<F> {
38        DistanceCount {
39            total_distance: F::zero(),
40            count,
41        }
42    }
43
44    /// Divides the total distance from the sample to this cluster by the number of samples in the cluster
45    pub fn mean_distance(&self) -> F {
46        self.total_distance / F::cast(self.count)
47    }
48
49    /// To be used in the cluster in which the sample is located. The distance from the sample to itself
50    /// is zero so it does not get added to the total distance. We can then just divide the total
51    /// distance by 1 - #samples in this cluster
52    pub fn same_label_mean_distance(&self) -> F {
53        if self.count == 1 {
54            return F::zero();
55        }
56        self.total_distance / F::cast(self.count - 1)
57    }
58
59    /// adds the distance of `other_sample` from `eval_sample` to the total distance of `eval_sample` from the current cluster
60    pub fn add_point(&mut self, eval_sample: ArrayView1<F>, other_sample: ArrayView1<F>) {
61        self.total_distance += eval_sample.sub(&other_sample).mapv(|x| x * x).sum().sqrt();
62    }
63}
64
65impl<F: Float, L: Label, D: Data<Elem = F>, T: AsSingleTargets<Elem = L> + Labels<Elem = L>>
66    SilhouetteScore<F> for DatasetBase<ArrayBase<D, Ix2>, T>
67{
68    fn silhouette_score(&self) -> Result<F> {
69        let mut labels: HashMap<L, DistanceCount<F>> = self
70            .label_count()
71            .remove(0)
72            .into_iter()
73            .map(|(label, count)| (label, DistanceCount::new(count)))
74            .collect();
75
76        // Single label dataset, all points are in the same cluster.
77        if labels.len() == 1 {
78            return Ok(F::one());
79        }
80
81        // Compute and sum silhouette score for each sample
82        let score = self
83            .sample_iter()
84            .map(|sample| {
85                // Loops through all samples in the dataset and adds
86                // the distance between them and `sample` to the cluster
87                // in which they belong
88
89                for other in self.sample_iter() {
90                    labels
91                        .get_mut(other.1.into_scalar())
92                        .unwrap()
93                        .add_point(sample.0, other.0);
94                }
95
96                // average distance from `sample` to points in its cluster
97                let mut a_x = F::zero();
98                // minimum average distance from `sample` to another cluster
99                // set to none so that it can be initialized as the first value
100                let mut b_x: Option<F> = None;
101
102                for (label, counter) in &mut labels {
103                    if sample.1.into_scalar() == label {
104                        // The cluster of `sample` averages by excluding `sample` from the counting
105                        a_x = counter.same_label_mean_distance();
106                    } else {
107                        // Keep the minimum average distance
108                        b_x = match b_x {
109                            None => Some(counter.mean_distance()),
110                            Some(v) => {
111                                if counter.mean_distance() < v {
112                                    Some(counter.mean_distance())
113                                } else {
114                                    Some(v)
115                                }
116                            }
117                        }
118                    }
119                    counter.reset()
120                }
121                // Since the single label case was taken care of earlier, here there are at least
122                // two clusters so `b_x` can't be `None`
123                let b_x = b_x.unwrap();
124
125                // s(x) = (b(x) - a(x)) / max{a(x), b(x)}
126                if a_x >= b_x {
127                    (b_x - a_x) / a_x
128                } else {
129                    (b_x - a_x) / b_x
130                }
131            })
132            .sum::<F>();
133        let score = score / F::cast(self.records().nsamples());
134        Ok(score)
135    }
136}
137
#[cfg(test)]
mod tests {
    use crate::metrics_clustering::SilhouetteScore;
    use crate::{Dataset, DatasetBase};
    use approx::assert_abs_diff_eq;
    use ndarray::{concatenate, Array, Array1, Axis, Ix1};

    /// Builds a two-feature record matrix in which both columns hold `values`.
    fn mirrored_columns(values: Array1<f64>) -> ndarray::Array2<f64> {
        let column = values.insert_axis(Axis(1));
        concatenate![Axis(1), column, column]
    }

    #[test]
    fn test_silhouette_score() {
        // Two clusters that lie very far apart, one label each.
        // Silhouette should rate this clustering as nearly perfect (score close to +1).
        let far_apart = mirrored_columns(concatenate![
            Axis(0),
            Array::linspace(0f64, 1f64, 10),
            Array::linspace(10000f64, 10001f64, 10)
        ]);
        let far_apart_labels =
            concatenate![Axis(0), Array1::from_elem(10, 0), Array1::from_elem(10, 1)];
        let far_apart_dataset: Dataset<_, _, Ix1> = (far_apart, far_apart_labels).into();
        let far_apart_score = far_apart_dataset.silhouette_score().unwrap();
        assert_abs_diff_eq!(far_apart_score, 1f64, epsilon = 1e-3);

        // Each cluster is split into two halves that are very far from each other,
        // while every half sits right next to a half of the other cluster.
        // Bad but not terrible: the score is expected to be slightly negative.
        let interleaved = mirrored_columns(concatenate![
            Axis(0),
            Array::linspace(0f64, 1f64, 5),
            Array::linspace(1f64, 2f64, 5),
            Array::linspace(10000f64, 10001f64, 5),
            Array::linspace(10001f64, 10002f64, 5)
        ]);
        let interleaved_labels = concatenate![
            Axis(0),
            Array1::from_elem(5, 0),
            Array1::from_elem(5, 1),
            Array1::from_elem(5, 0),
            Array1::from_elem(5, 1)
        ];
        let interleaved_dataset: Dataset<_, _, Ix1> = (interleaved, interleaved_labels).into();
        let interleaved_score = interleaved_dataset.silhouette_score().unwrap();
        assert!(interleaved_score < 0f64);

        // Many clusters scattered over evenly spaced points: a very bad clustering,
        // so a strongly negative score is expected.
        let scattered = mirrored_columns(Array::linspace(0f64, 10f64, 100));
        let scattered_labels = Array1::from_shape_fn(100, |i| (i + 3) % 48);
        let scattered_dataset: Dataset<_, _, Ix1> = (scattered, scattered_labels).into();
        let scattered_score = scattered_dataset.silhouette_score().unwrap();
        assert!(scattered_score < -0.5f64);
    }

    #[test]
    fn test_empty_labels_as_single_label() {
        // A dataset built from records alone is scored like a single-label dataset
        let points = Array::linspace(0f64, 1f64, 10).insert_axis(Axis(1));
        let unlabeled: DatasetBase<_, _> = points.into();
        let unlabeled_score = unlabeled.silhouette_score().unwrap();
        assert_abs_diff_eq!(unlabeled_score, 1f64, epsilon = 1e-5);
    }
}