linfa_hierarchical/
lib.rs

1//! # Hierarchical Clustering
2//!
3//! `linfa-hierarchical` provides an implementation of agglomerative hierarchical clustering.
//! In this clustering algorithm, each point is first considered as a separate cluster. During each
//! step, two clusters are merged into a new cluster, until a stopping criterion is reached. The
//! distance between the points is computed as the negative-log transform of the similarity kernel.
7//!
8//! _Documentation_: [latest](https://docs.rs/linfa-hierarchical).
9//!
10//! ## The big picture
11//!
12//! `linfa-hierarchical` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem,
13//! a wider effort to bootstrap a toolkit for classical Machine Learning implemented in pure Rust,
14//! akin in spirit to Python's `scikit-learn`.
15//!
16//! ## Current state
17//!
18//! `linfa-hierarchical` implements agglomerative hierarchical clustering with support of the
19//! [kodama](https://docs.rs/kodama/0.2.3/kodama/) crate.
20
21use std::collections::HashMap;
22
23use kodama::linkage;
24pub use kodama::Method;
25
26use linfa::param_guard::TransformGuard;
27use linfa::traits::Transformer;
28use linfa::Float;
29use linfa::{dataset::DatasetBase, ParamGuard};
30use linfa_kernel::Kernel;
31
32pub use error::{HierarchicalError, Result};
33
34mod error;
35
/// Criterion when to stop merging
///
/// The criterion defines at which point the merging process should stop. This can be either when
/// a certain number of clusters is reached, or when the distance becomes larger than a maximal
/// distance.
#[derive(Clone, Debug, PartialEq)]
pub enum Criterion<F: Float> {
    /// Stop merging as soon as the number of remaining clusters is at most this value.
    ///
    /// A value of zero is rejected when the parameters are checked.
    NumClusters(usize),
    /// Stop merging as soon as the dissimilarity of the next merge step reaches this value.
    ///
    /// Negative, NaN and infinite values are rejected when the parameters are checked.
    Distance(F),
}
46
/// Agglomerative hierarchical clustering
///
/// In this clustering algorithm, each point is first considered as a separate cluster. During each
/// step, two clusters are merged into a new cluster, until a stopping criterion is reached. The
/// distance between the points is computed as the negative-log transform of the similarity kernel.
///
/// This is the unchecked parameter set: it wraps a [`ValidHierarchicalCluster`], which is only
/// handed out after the stopping criterion has been validated through the `ParamGuard`
/// implementation. The derived `Default` forwards to the default of the wrapped parameters.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct HierarchicalCluster<T: Float>(ValidHierarchicalCluster<T>);
54
/// Checked version of [`HierarchicalCluster`]
///
/// Holds the parameters actually used by the `transform` implementations.
#[derive(Clone, Debug, PartialEq)]
pub struct ValidHierarchicalCluster<T: Float> {
    // linkage method forwarded to kodama's `linkage` function
    method: Method,
    // criterion deciding when the merging process stops
    stopping: Criterion<T>,
}
61
62impl<F: Float> ParamGuard for HierarchicalCluster<F> {
63    type Checked = ValidHierarchicalCluster<F>;
64    type Error = HierarchicalError<F>;
65
66    fn check_ref(&self) -> std::result::Result<&Self::Checked, Self::Error> {
67        match self.0.stopping {
68            Criterion::NumClusters(0) => Err(HierarchicalError::InvalidStoppingCondition(
69                self.0.stopping.clone(),
70            )),
71            Criterion::Distance(x) if x.is_negative() || x.is_nan() || x.is_infinite() => Err(
72                HierarchicalError::InvalidStoppingCondition(self.0.stopping.clone()),
73            ),
74            _ => Ok(&self.0),
75        }
76    }
77
78    fn check(self) -> std::result::Result<Self::Checked, Self::Error> {
79        self.check_ref()?;
80        Ok(self.0)
81    }
82}
// Marker implementation with an empty body. NOTE(review): presumably this is what lets callers
// invoke `transform` directly on the unchecked parameters and receive a `Result` (as the tests
// in this file do) — confirm against linfa's `TransformGuard` documentation.
impl<F: Float> TransformGuard for HierarchicalCluster<F> {}
84
85impl<F: Float> HierarchicalCluster<F> {
86    /// Select a merging method
87    pub fn with_method(mut self, method: Method) -> HierarchicalCluster<F> {
88        self.0.method = method;
89        self
90    }
91
92    /// Stop merging when a certain number of clusters are reached
93    ///
94    /// In the fitting process points are merged until a certain criterion is reached. With this
95    /// option the merging process will stop, when the number of clusters drops below this value.
96    pub fn num_clusters(mut self, num_clusters: usize) -> HierarchicalCluster<F> {
97        self.0.stopping = Criterion::NumClusters(num_clusters);
98        self
99    }
100
101    /// Stop merging when a certain distance is reached
102    ///
103    /// In the fitting process points are merged until a certain criterion is reached. With this
104    /// option the merging process will stop, then the distance exceeds this value.
105    pub fn max_distance(mut self, max_distance: F) -> HierarchicalCluster<F> {
106        self.0.stopping = Criterion::Distance(max_distance);
107        self
108    }
109}
110
111impl<F: Float> Transformer<Kernel<F>, DatasetBase<Kernel<F>, Vec<usize>>>
112    for ValidHierarchicalCluster<F>
113{
114    /// Perform hierarchical clustering of a similarity matrix
115    ///
116    /// Returns the class id for each data point
117    fn transform(&self, kernel: Kernel<F>) -> DatasetBase<Kernel<F>, Vec<usize>> {
118        // ignore all similarities below this value
119        let threshold = F::cast(1e-6);
120
121        // transform similarities to distances with log transformation
122        let mut distance = kernel
123            .to_upper_triangle()
124            .into_iter()
125            .map(|x| {
126                if x > threshold {
127                    -x.ln()
128                } else {
129                    -threshold.ln()
130                }
131            })
132            .collect::<Vec<_>>();
133
134        // call kodama linkage function
135        let num_observations = kernel.size();
136        let res = linkage(&mut distance, num_observations, self.method);
137
138        // post-process results, iterate through merging step until threshold is reached
139        // at the beginning every node is in its own cluster
140        let mut clusters = (0..num_observations)
141            .map(|x| (x, vec![x]))
142            .collect::<HashMap<_, _>>();
143
144        // counter for new clusters, which are formed as unions of previous ones
145        let mut ct = num_observations;
146
147        for step in res.steps() {
148            let should_stop = match self.stopping {
149                Criterion::NumClusters(max_clusters) => clusters.len() <= max_clusters,
150                Criterion::Distance(dis) => step.dissimilarity >= dis,
151            };
152
153            // break if one of the two stopping condition is reached
154            if should_stop {
155                break;
156            }
157
158            // combine ids from both clusters
159            let mut ids = Vec::with_capacity(2);
160            let mut cl = clusters.remove(&step.cluster1).unwrap();
161            ids.append(&mut cl);
162            let mut cl = clusters.remove(&step.cluster2).unwrap();
163            ids.append(&mut cl);
164
165            // insert into hashmap and increase counter
166            clusters.insert(ct, ids);
167            ct += 1;
168        }
169
170        // flatten resulting clusters and reverse index
171        let mut tmp = vec![0; num_observations];
172        for (i, (_, ids)) in clusters.into_iter().enumerate() {
173            for id in ids {
174                tmp[id] = i;
175            }
176        }
177
178        // return node_index -> cluster_index map
179        DatasetBase::new(kernel, tmp)
180    }
181}
182
183impl<F: Float, T> Transformer<DatasetBase<Kernel<F>, T>, DatasetBase<Kernel<F>, Vec<usize>>>
184    for ValidHierarchicalCluster<F>
185{
186    /// Perform hierarchical clustering of a similarity matrix
187    ///
188    /// Returns the class id for each data point
189    fn transform(&self, dataset: DatasetBase<Kernel<F>, T>) -> DatasetBase<Kernel<F>, Vec<usize>> {
190        self.transform(dataset.records)
191    }
192}
193
194/// Initialize hierarchical clustering, which averages in-cluster points and stops when two
195/// clusters are reached.
196impl<T: Float> Default for ValidHierarchicalCluster<T> {
197    fn default() -> Self {
198        Self {
199            method: Method::Average,
200            stopping: Criterion::NumClusters(2),
201        }
202    }
203}
204
#[cfg(test)]
mod tests {
    use crate::HierarchicalError;
    use linfa::traits::Transformer;
    use linfa_kernel::{Kernel, KernelMethod};
    use ndarray::{Array, Axis};
    use ndarray_rand::{rand_distr::Normal, RandomExt};

    use super::{Criterion, HierarchicalCluster, ValidHierarchicalCluster};

    /// Assert that the first `npoints` ids share one label, the remaining ids share another
    /// label, and that the two labels differ.
    fn assert_two_blobs(ids: &[usize], npoints: usize) {
        let (first, second) = ids.split_at(npoints);

        // check that all assigned ids are equal for the first cluster
        assert!(first.iter().all(|item| *item == first[0]));

        // and for the second
        assert!(second.iter().all(|item| *item == second[0]));

        // the cluster ids shouldn't be equal
        assert_ne!(first[0], second[0]);
    }

    /// All public types must be `Send + Sync + Sized + Unpin`.
    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<Criterion<f64>>();
        has_autotraits::<HierarchicalCluster<f64>>();
        has_autotraits::<ValidHierarchicalCluster<f64>>();
        has_autotraits::<HierarchicalError<f64>>();
    }

    /// Two well separated blobs must be recovered with both stopping criteria.
    #[test]
    fn test_blobs() {
        // we have 10 observations per cluster
        let npoints = 10;
        // generate data
        let entries = ndarray::concatenate(
            Axis(0),
            &[
                Array::random((npoints, 2), Normal::new(-1., 0.1).unwrap()).view(),
                Array::random((npoints, 2), Normal::new(1., 0.1).unwrap()).view(),
            ],
        )
        .unwrap();

        let kernel = Kernel::params()
            .method(KernelMethod::Gaussian(5.0))
            .transform(entries.view());

        // perform hierarchical clustering until the maximal distance is reached
        let kernel = HierarchicalCluster::default()
            .max_distance(0.1)
            .transform(kernel)
            .unwrap();

        assert_two_blobs(kernel.targets(), npoints);

        // perform hierarchical clustering until we have two clusters left
        let kernel = HierarchicalCluster::default()
            .num_clusters(2)
            .transform(kernel)
            .unwrap();

        assert_two_blobs(kernel.targets(), npoints);
    }

    /// Clustering pure noise must succeed and label every observation.
    #[test]
    fn test_noise() {
        // generate 100 normal distributed points
        let npoints = 100;
        let data = Array::random((npoints, 2), Normal::new(0., 1.0).unwrap());

        let kernel = Kernel::params()
            .method(KernelMethod::Linear)
            .transform(data.view());

        let predictions = HierarchicalCluster::default()
            .max_distance(3.0)
            .transform(kernel)
            .unwrap();

        // every observation receives exactly one cluster id
        assert_eq!(predictions.targets().len(), npoints);
    }
}