linfa/
metrics_clustering.rs1use crate::dataset::{AsSingleTargets, DatasetBase, Label, Labels, Records};
3use crate::error::Result;
4use crate::Float;
5use ndarray::{ArrayBase, ArrayView1, Data, Ix2};
6use std::collections::HashMap;
7use std::ops::Sub;
8
9pub trait SilhouetteScore<F> {
11 fn silhouette_score(&self) -> Result<F>;
24}
25
/// Per-label accumulator of the distances between one evaluated sample and the samples
/// carrying that label.
struct DistanceCount<F> {
    /// Number of samples with the associated label, as supplied to `new`; never modified afterwards.
    count: usize,
    /// Sum of the distances added since construction or the last `reset`.
    total_distance: F,
}
30
31impl<F: Float> DistanceCount<F> {
32 pub fn reset(&mut self) {
34 self.total_distance = F::zero();
35 }
36
37 pub fn new(count: usize) -> DistanceCount<F> {
38 DistanceCount {
39 total_distance: F::zero(),
40 count,
41 }
42 }
43
44 pub fn mean_distance(&self) -> F {
46 self.total_distance / F::cast(self.count)
47 }
48
49 pub fn same_label_mean_distance(&self) -> F {
53 if self.count == 1 {
54 return F::zero();
55 }
56 self.total_distance / F::cast(self.count - 1)
57 }
58
59 pub fn add_point(&mut self, eval_sample: ArrayView1<F>, other_sample: ArrayView1<F>) {
61 self.total_distance += eval_sample.sub(&other_sample).mapv(|x| x * x).sum().sqrt();
62 }
63}
64
65impl<F: Float, L: Label, D: Data<Elem = F>, T: AsSingleTargets<Elem = L> + Labels<Elem = L>>
66 SilhouetteScore<F> for DatasetBase<ArrayBase<D, Ix2>, T>
67{
68 fn silhouette_score(&self) -> Result<F> {
69 let mut labels: HashMap<L, DistanceCount<F>> = self
70 .label_count()
71 .remove(0)
72 .into_iter()
73 .map(|(label, count)| (label, DistanceCount::new(count)))
74 .collect();
75
76 if labels.len() == 1 {
78 return Ok(F::one());
79 }
80
81 let score = self
83 .sample_iter()
84 .map(|sample| {
85 for other in self.sample_iter() {
90 labels
91 .get_mut(other.1.into_scalar())
92 .unwrap()
93 .add_point(sample.0, other.0);
94 }
95
96 let mut a_x = F::zero();
98 let mut b_x: Option<F> = None;
101
102 for (label, counter) in &mut labels {
103 if sample.1.into_scalar() == label {
104 a_x = counter.same_label_mean_distance();
106 } else {
107 b_x = match b_x {
109 None => Some(counter.mean_distance()),
110 Some(v) => {
111 if counter.mean_distance() < v {
112 Some(counter.mean_distance())
113 } else {
114 Some(v)
115 }
116 }
117 }
118 }
119 counter.reset()
120 }
121 let b_x = b_x.unwrap();
124
125 if a_x >= b_x {
127 (b_x - a_x) / a_x
128 } else {
129 (b_x - a_x) / b_x
130 }
131 })
132 .sum::<F>();
133 let score = score / F::cast(self.records().nsamples());
134 Ok(score)
135 }
136}
137
#[cfg(test)]
mod tests {
    use crate::metrics_clustering::SilhouetteScore;
    use crate::{Dataset, DatasetBase};
    use approx::assert_abs_diff_eq;
    use ndarray::{concatenate, Array, Array1, Array2, Axis, Ix1};

    /// Builds a record matrix with two identical feature columns, both equal to `values`.
    fn duplicated_feature(values: Array1<f64>) -> Array2<f64> {
        let column = values.insert_axis(Axis(1));
        concatenate![Axis(1), column, column]
    }

    #[test]
    fn test_silhouette_score() {
        // Two compact groups far apart, each with its own label: near-perfect separation.
        let records = duplicated_feature(concatenate![
            Axis(0),
            Array::linspace(0f64, 1f64, 10),
            Array::linspace(10000f64, 10001f64, 10)
        ]);
        let targets = concatenate![Axis(0), Array1::from_elem(10, 0), Array1::from_elem(10, 1)];
        let well_separated: Dataset<_, _, Ix1> = (records, targets).into();
        let score = well_separated.silhouette_score().unwrap();
        assert_abs_diff_eq!(score, 1f64, epsilon = 1e-3);

        // Each label is split across both distant groups, so the labelling contradicts the
        // geometry and the score must be negative.
        let records = duplicated_feature(concatenate![
            Axis(0),
            Array::linspace(0f64, 1f64, 5),
            Array::linspace(1f64, 2f64, 5),
            Array::linspace(10000f64, 10001f64, 5),
            Array::linspace(10001f64, 10002f64, 5)
        ]);
        let targets = concatenate![
            Axis(0),
            Array1::from_elem(5, 0),
            Array1::from_elem(5, 1),
            Array1::from_elem(5, 0),
            Array1::from_elem(5, 1)
        ];
        let mixed_up: Dataset<_, _, Ix1> = (records, targets).into();
        let score = mixed_up.silhouette_score().unwrap();
        assert!(score < 0f64);

        // 48 labels assigned cyclically over evenly spaced points: strongly negative.
        let records = duplicated_feature(Array::linspace(0f64, 10f64, 100));
        let targets = Array1::from_shape_fn(100, |i| (i + 3) % 48);
        let scattered: Dataset<_, _, Ix1> = (records, targets).into();
        let score = scattered.silhouette_score().unwrap();
        assert!(score < -0.5f64)
    }

    #[test]
    fn test_empty_labels_as_single_label() {
        // A dataset built without explicit targets is expected to score like a
        // single-label dataset, i.e. 1.
        let records = Array::linspace(0f64, 1f64, 10).insert_axis(Axis(1));
        let unlabelled: DatasetBase<_, _> = records.into();
        let score = unlabelled.silhouette_score().unwrap();
        assert_abs_diff_eq!(score, 1f64, epsilon = 1e-5);
    }
}