linfa_hierarchical/lib.rs

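//! Agglomerative hierarchical clustering for the linfa ecosystem, operating on
//! precomputed kernel (similarity) matrices.
//!
//! A minimal usage sketch; the Gaussian kernel and all parameter values below
//! are illustrative only, mirroring the tests at the bottom of this file:
//!
//! ```rust,ignore
//! use linfa::traits::Transformer;
//! use linfa_hierarchical::HierarchicalCluster;
//! use linfa_kernel::{Kernel, KernelMethod};
//! use ndarray::array;
//!
//! let data = array![[0.0, 0.0], [0.1, 0.1], [5.0, 5.0], [5.1, 5.0]];
//! let kernel = Kernel::params()
//!     .method(KernelMethod::Gaussian(5.0))
//!     .transform(data.view());
//!
//! // Merge until two clusters remain, then read the assigned labels.
//! let dataset = HierarchicalCluster::default()
//!     .num_clusters(2)
//!     .transform(kernel)
//!     .unwrap();
//! println!("{:?}", dataset.targets());
//! ```
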
use std::collections::HashMap;

use kodama::linkage;
pub use kodama::Method;

use linfa::param_guard::TransformGuard;
use linfa::traits::Transformer;
use linfa::Float;
use linfa::{dataset::DatasetBase, ParamGuard};
use linfa_kernel::Kernel;

pub use error::{HierarchicalError, Result};

mod error;

/// Stopping criterion for the agglomerative merging process: either merge
/// until a given number of clusters remains, or until the dissimilarity of
/// the next merge would reach a maximal distance.
#[derive(Clone, Debug, PartialEq)]
pub enum Criterion<F: Float> {
    NumClusters(usize),
    Distance(F),
}

/// Agglomerative hierarchical clustering, parameterized by the linkage method
/// and the stopping criterion. Parameters are validated by `ParamGuard::check`
/// before the clustering runs.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct HierarchicalCluster<T: Float>(ValidHierarchicalCluster<T>);

/// The checked (validated) form of [`HierarchicalCluster`].
#[derive(Clone, Debug, PartialEq)]
pub struct ValidHierarchicalCluster<T: Float> {
    method: Method,
    stopping: Criterion<T>,
}

impl<F: Float> ParamGuard for HierarchicalCluster<F> {
    type Checked = ValidHierarchicalCluster<F>;
    type Error = HierarchicalError<F>;

    fn check_ref(&self) -> std::result::Result<&Self::Checked, Self::Error> {
        match self.0.stopping {
            Criterion::NumClusters(0) => Err(HierarchicalError::InvalidStoppingCondition(
                self.0.stopping.clone(),
            )),
            Criterion::Distance(x) if x.is_negative() || x.is_nan() || x.is_infinite() => Err(
                HierarchicalError::InvalidStoppingCondition(self.0.stopping.clone()),
            ),
            _ => Ok(&self.0),
        }
    }

    fn check(self) -> std::result::Result<Self::Checked, Self::Error> {
        self.check_ref()?;
        Ok(self.0)
    }
}

impl<F: Float> TransformGuard for HierarchicalCluster<F> {}

impl<F: Float> HierarchicalCluster<F> {
    /// Select the linkage method used when merging clusters.
    pub fn with_method(mut self, method: Method) -> HierarchicalCluster<F> {
        self.0.method = method;
        self
    }

    /// Stop merging once the given number of clusters is reached.
    pub fn num_clusters(mut self, num_clusters: usize) -> HierarchicalCluster<F> {
        self.0.stopping = Criterion::NumClusters(num_clusters);
        self
    }

    /// Stop merging once the dissimilarity of the next merge reaches `max_distance`.
    pub fn max_distance(mut self, max_distance: F) -> HierarchicalCluster<F> {
        self.0.stopping = Criterion::Distance(max_distance);
        self
    }
}

impl<F: Float> Transformer<Kernel<F>, DatasetBase<Kernel<F>, Vec<usize>>>
    for ValidHierarchicalCluster<F>
{
    fn transform(&self, kernel: Kernel<F>) -> DatasetBase<Kernel<F>, Vec<usize>> {
        // Similarities below this threshold are clamped before the log transform.
        let threshold = F::cast(1e-6);

        // Turn the (upper-triangular) similarity matrix into a condensed
        // dissimilarity matrix via a negative log transform.
        let mut distance = kernel
            .to_upper_triangle()
            .into_iter()
            .map(|x| {
                if x > threshold {
                    -x.ln()
                } else {
                    -threshold.ln()
                }
            })
            .collect::<Vec<_>>();

        let num_observations = kernel.size();
        // Run agglomerative clustering with the configured linkage method.
        let res = linkage(&mut distance, num_observations, self.method);

        // Replay the merge steps; initially every observation is its own cluster.
        let mut clusters = (0..num_observations)
            .map(|x| (x, vec![x]))
            .collect::<HashMap<_, _>>();

        // Counter for newly formed clusters (unions of previous ones).
        let mut ct = num_observations;

        for step in res.steps() {
            let should_stop = match self.stopping {
                Criterion::NumClusters(max_clusters) => clusters.len() <= max_clusters,
                Criterion::Distance(dis) => step.dissimilarity >= dis,
            };

            // Stop once the criterion is fulfilled.
            if should_stop {
                break;
            }

            // Combine the member ids of both merged clusters ...
            let mut ids = Vec::with_capacity(2);
            let mut cl = clusters.remove(&step.cluster1).unwrap();
            ids.append(&mut cl);
            let mut cl = clusters.remove(&step.cluster2).unwrap();
            ids.append(&mut cl);

            // ... and insert them as a new cluster.
            clusters.insert(ct, ids);
            ct += 1;
        }

        // Flatten the remaining clusters into one label per observation.
        let mut tmp = vec![0; num_observations];
        for (i, (_, ids)) in clusters.into_iter().enumerate() {
            for id in ids {
                tmp[id] = i;
            }
        }

        // Return the kernel with cluster assignments as targets.
        DatasetBase::new(kernel, tmp)
    }
}

/// Transform a dataset by clustering its kernel records; any existing targets
/// are replaced by the cluster assignments.
impl<F: Float, T> Transformer<DatasetBase<Kernel<F>, T>, DatasetBase<Kernel<F>, Vec<usize>>>
    for ValidHierarchicalCluster<F>
{
    fn transform(&self, dataset: DatasetBase<Kernel<F>, T>) -> DatasetBase<Kernel<F>, Vec<usize>> {
        self.transform(dataset.records)
    }
}

impl<T: Float> Default for ValidHierarchicalCluster<T> {
    /// Default: average linkage, stop at two clusters.
    fn default() -> Self {
        Self {
            method: Method::Average,
            stopping: Criterion::NumClusters(2),
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::HierarchicalError;
    use linfa::traits::Transformer;
    use linfa_kernel::{Kernel, KernelMethod};
    use ndarray::{Array, Axis};
    use ndarray_rand::{rand_distr::Normal, RandomExt};

    use super::{Criterion, HierarchicalCluster, ValidHierarchicalCluster};

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<Criterion<f64>>();
        has_autotraits::<HierarchicalCluster<f64>>();
        has_autotraits::<ValidHierarchicalCluster<f64>>();
        has_autotraits::<HierarchicalError<f64>>();
    }

    #[test]
    fn test_blobs() {
        // Create two blobs of ten points each, centered at (-1, -1) and (1, 1).
        let npoints = 10;
        let entries = ndarray::concatenate(
            Axis(0),
            &[
                Array::random((npoints, 2), Normal::new(-1., 0.1).unwrap()).view(),
                Array::random((npoints, 2), Normal::new(1., 0.1).unwrap()).view(),
            ],
        )
        .unwrap();

        let kernel = Kernel::params()
            .method(KernelMethod::Gaussian(5.0))
            .transform(entries.view());

        let kernel = HierarchicalCluster::default()
            .max_distance(0.1)
            .transform(kernel)
            .unwrap();

        // All points of the first blob should share one cluster id ...
        let ids = kernel.targets();
        let first_cluster_id = &ids[0];
        assert!(ids
            .iter()
            .take(npoints)
            .all(|item| item == first_cluster_id));

        // ... and all points of the second blob another one.
        let second_cluster_id = &ids[npoints];
        assert!(ids
            .iter()
            .skip(npoints)
            .all(|item| item == second_cluster_id));

        assert_ne!(first_cluster_id, second_cluster_id);

        // The same holds when stopping at exactly two clusters.
        let kernel = HierarchicalCluster::default()
            .num_clusters(2)
            .transform(kernel)
            .unwrap();

        let ids = kernel.targets();
        let first_cluster_id = &ids[0];
        assert!(ids
            .iter()
            .take(npoints)
            .all(|item| item == first_cluster_id));

        let second_cluster_id = &ids[npoints];
        assert!(ids
            .iter()
            .skip(npoints)
            .all(|item| item == second_cluster_id));

        assert_ne!(first_cluster_id, second_cluster_id);
    }
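
    // A minimal sketch exercising the `ParamGuard` validation above: a cluster
    // count of zero and negative or non-finite distances should be rejected by
    // `check`, while sensible values pass. Parameter values are illustrative.
    #[test]
    fn rejects_invalid_stopping_criteria() {
        use linfa::ParamGuard;

        assert!(HierarchicalCluster::<f64>::default()
            .num_clusters(0)
            .check()
            .is_err());
        assert!(HierarchicalCluster::default()
            .max_distance(-1.0f64)
            .check()
            .is_err());
        assert!(HierarchicalCluster::default()
            .max_distance(f64::NAN)
            .check()
            .is_err());

        // A positive distance and a non-zero cluster count pass validation.
        assert!(HierarchicalCluster::default()
            .max_distance(0.5f64)
            .check()
            .is_ok());
        assert!(HierarchicalCluster::<f64>::default()
            .num_clusters(2)
            .check()
            .is_ok());
    }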

    #[test]
    fn test_noise() {
        // Gaussian noise with no cluster structure: only check that the
        // transform runs and print the resulting assignments.
        let data = Array::random((100, 2), Normal::new(0., 1.0).unwrap());

        let kernel = Kernel::params()
            .method(KernelMethod::Linear)
            .transform(data.view());

        let predictions = HierarchicalCluster::default()
            .max_distance(3.0)
            .transform(kernel)
            .unwrap();

        dbg!(&predictions.targets());
    }
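
    // A minimal sketch of the stopping-criterion edge cases, assuming the same
    // Gaussian kernel setup as `test_blobs`: with a single requested cluster all
    // observations share one label, while requesting as many clusters as there
    // are observations leaves every point in its own cluster.
    #[test]
    fn test_cluster_count_edge_cases() {
        let npoints = 10;
        let data = Array::random((npoints, 2), Normal::new(0., 1.0).unwrap());

        // One cluster: every label is identical.
        let kernel = Kernel::params()
            .method(KernelMethod::Gaussian(5.0))
            .transform(data.view());
        let single = HierarchicalCluster::default()
            .num_clusters(1)
            .transform(kernel)
            .unwrap();
        let ids = single.targets();
        assert!(ids.iter().all(|id| id == &ids[0]));

        // As many clusters as observations: no merges happen, all labels differ.
        let kernel = Kernel::params()
            .method(KernelMethod::Gaussian(5.0))
            .transform(data.view());
        let trivial = HierarchicalCluster::default()
            .num_clusters(npoints)
            .transform(kernel)
            .unwrap();
        let mut ids = trivial.targets().clone();
        ids.sort_unstable();
        ids.dedup();
        assert_eq!(ids.len(), npoints);
    }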
}