linfa_nn/
balltree.rs

1#![allow(clippy::type_complexity)]
2use std::{cmp::Reverse, collections::BinaryHeap};
3
4use linfa::Float;
5use ndarray::{Array1, ArrayBase, Data, Ix2};
6use noisy_float::{checkers::FiniteChecker, NoisyFloat};
7#[cfg(feature = "serde")]
8use serde_crate::{Deserialize, Serialize};
9
10use crate::{
11    distance::Distance,
12    heap_elem::{MaxHeapElem, MinHeapElem},
13    BuildError, NearestNeighbour, NearestNeighbourBox, NearestNeighbourIndex, NnError, Point,
14};
15
16// Partition the points using median value
17fn partition<F: Float>(
18    mut points: Vec<(Point<F>, usize)>,
19) -> (Vec<(Point<F>, usize)>, Point<F>, Vec<(Point<F>, usize)>) {
20    debug_assert!(points.len() >= 2);
21
22    // Spread of a dimension is measured using range, which is suceptible to skew. It may be better
23    // to use STD or variance.
24    let max_spread_dim = (0..points[0].0.len())
25        .map(|dim| {
26            // Find the range of each dimension
27            let (max, min) = points
28                .iter()
29                .map(|p| p.0[dim])
30                .fold((F::neg_infinity(), F::infinity()), |(a, b), c| {
31                    (F::max(a, c), F::min(b, c))
32                });
33
34            (dim, NoisyFloat::<_, FiniteChecker>::new(max - min))
35        })
36        .max_by_key(|&(_, range)| range)
37        .expect("vec has no dimensions")
38        .0;
39
40    let mid = points.len() / 2;
41    // Compute median on the chosen dimension in linear time
42    let median = order_stat::kth_by(&mut points, mid, |p1, p2| {
43        p1.0[max_spread_dim]
44            .partial_cmp(&p2.0[max_spread_dim])
45            .expect("NaN in data")
46    })
47    .0
48    .reborrow();
49
50    let (mut left, mut right): (Vec<_>, Vec<_>) = points
51        .into_iter()
52        .partition(|pt| pt.0[max_spread_dim] < median[max_spread_dim]);
53    // We can get an empty left partition with degenerate data where all points are equal and
54    // gathered in the right partition.  This ensures that the larger partition will always shrink,
55    // guaranteeing algorithm termination.
56    if left.is_empty() {
57        left.push(right.pop().unwrap());
58    }
59    (left, median, right)
60}
61
62// Calculates radius of a bounding sphere
63fn calc_radius<'a, F: Float, D: Distance<F>>(
64    points: impl Iterator<Item = Point<'a, F>>,
65    center: Point<F>,
66    dist_fn: &D,
67) -> F {
68    let r_rad = points
69        .map(|pt| NoisyFloat::<_, FiniteChecker>::new(dist_fn.rdistance(pt, center)))
70        .max()
71        .unwrap()
72        .raw();
73    dist_fn.rdist_to_dist(r_rad)
74}
75
76#[derive(Debug, PartialEq, Clone)]
77enum BallTreeInner<'a, F: Float> {
78    // Leaf node sphere
79    Leaf {
80        center: Array1<F>,
81        radius: F,
82        points: Vec<(Point<'a, F>, usize)>,
83    },
84    // Sphere that encompasses both children
85    Branch {
86        center: Point<'a, F>,
87        radius: F,
88        left: Box<BallTreeInner<'a, F>>,
89        right: Box<BallTreeInner<'a, F>>,
90    },
91}
92
93impl<'a, F: Float> BallTreeInner<'a, F> {
94    fn new<D: Distance<F>>(
95        points: Vec<(Point<'a, F>, usize)>,
96        leaf_size: usize,
97        dist_fn: &D,
98    ) -> Self {
99        if points.len() <= leaf_size {
100            // Leaf node
101            if let Some(dim) = points.first().map(|p| p.0.len()) {
102                // Since we don't need to partition, we can center the sphere around the average of
103                // all points
104                let center = {
105                    let mut c = Array1::zeros(dim);
106                    points.iter().for_each(|p| c += &p.0);
107                    c / F::from(points.len()).unwrap()
108                };
109                let radius = calc_radius(
110                    points.iter().map(|p| p.0.reborrow()),
111                    center.view(),
112                    dist_fn,
113                );
114                BallTreeInner::Leaf {
115                    center,
116                    radius,
117                    points,
118                }
119            } else {
120                // In case of an empty tree
121                BallTreeInner::Leaf {
122                    center: Array1::zeros(0),
123                    points,
124                    radius: F::zero(),
125                }
126            }
127        } else {
128            // Non-leaf node
129            let (aps, center, bps) = partition(points);
130            debug_assert!(!aps.is_empty() && !bps.is_empty());
131            let radius = calc_radius(
132                aps.iter().chain(bps.iter()).map(|p| p.0.reborrow()),
133                center,
134                dist_fn,
135            );
136            let a_tree = BallTreeInner::new(aps, leaf_size, dist_fn);
137            let b_tree = BallTreeInner::new(bps, leaf_size, dist_fn);
138            BallTreeInner::Branch {
139                center,
140                radius,
141                left: Box::new(a_tree),
142                right: Box::new(b_tree),
143            }
144        }
145    }
146
147    fn rdistance<D: Distance<F>>(&self, p: Point<F>, dist_fn: &D) -> F {
148        let (center, radius) = match self {
149            BallTreeInner::Leaf { center, radius, .. } => (center.view(), radius),
150            BallTreeInner::Branch { center, radius, .. } => (center.reborrow(), radius),
151        };
152
153        // The distance to a sphere is the distance to its edge, so the distance between a point
154        // and a sphere will always be less than the distance between the point and anything inside
155        // the sphere
156        let border_dist = dist_fn.distance(p, center.reborrow()) - *radius;
157        dist_fn.dist_to_rdist(border_dist.max(F::zero()))
158    }
159}
160
161/// Spatial indexing structure created by [`BallTree`]
162#[derive(Debug, Clone, PartialEq)]
163pub struct BallTreeIndex<'a, F: Float, D: Distance<F>> {
164    tree: BallTreeInner<'a, F>,
165    dist_fn: D,
166    dim: usize,
167    len: usize,
168}
169
170impl<'a, F: Float, D: Distance<F>> BallTreeIndex<'a, F, D> {
171    /// Creates a `BallTreeIndex` using the K-D construction algorithm
172    pub fn new<DT: Data<Elem = F>>(
173        batch: &'a ArrayBase<DT, Ix2>,
174        leaf_size: usize,
175        dist_fn: D,
176    ) -> Result<Self, BuildError> {
177        let dim = batch.ncols();
178        let len = batch.nrows();
179        if leaf_size == 0 {
180            Err(BuildError::EmptyLeaf)
181        } else if dim == 0 {
182            Err(BuildError::ZeroDimension)
183        } else {
184            let points: Vec<_> = batch
185                .rows()
186                .into_iter()
187                .enumerate()
188                .map(|(i, pt)| (pt, i))
189                .collect();
190            Ok(BallTreeIndex {
191                tree: BallTreeInner::new(points, leaf_size, &dist_fn),
192                dist_fn,
193                dim,
194                len,
195            })
196        }
197    }
198
199    fn nn_helper(
200        &self,
201        point: Point<'_, F>,
202        k: usize,
203        max_radius: F,
204    ) -> Result<Vec<(Point<F>, usize)>, NnError> {
205        if self.dim != point.len() {
206            Err(NnError::WrongDimension)
207        } else if self.len == 0 {
208            Ok(Vec::new())
209        } else {
210            let mut out: BinaryHeap<MaxHeapElem<_, _>> = BinaryHeap::new();
211            let mut queue = BinaryHeap::new();
212            queue.push(MinHeapElem::new(
213                self.tree.rdistance(point, &self.dist_fn),
214                &self.tree,
215            ));
216
217            while let Some(MinHeapElem {
218                dist: Reverse(dist),
219                elem,
220            }) = queue.pop()
221            {
222                if dist >= max_radius || (out.len() == k && dist >= out.peek().unwrap().dist) {
223                    break;
224                }
225
226                match elem {
227                    BallTreeInner::Leaf { points, .. } => {
228                        for p in points {
229                            let dist = self.dist_fn.rdistance(point, p.0.reborrow());
230                            if dist < max_radius
231                                && (out.len() < k || out.peek().unwrap().dist > dist)
232                            {
233                                out.push(MaxHeapElem::new(dist, p));
234                                if out.len() > k {
235                                    out.pop();
236                                }
237                            }
238                        }
239                    }
240                    BallTreeInner::Branch { left, right, .. } => {
241                        let dl = left.rdistance(point, &self.dist_fn);
242                        let dr = right.rdistance(point, &self.dist_fn);
243
244                        if dl <= max_radius {
245                            queue.push(MinHeapElem::new(dl, left));
246                        }
247                        if dr <= max_radius {
248                            queue.push(MinHeapElem::new(dr, right));
249                        }
250                    }
251                }
252            }
253            Ok(out
254                .into_sorted_vec()
255                .into_iter()
256                .map(|e| e.elem)
257                .map(|(pt, i)| (pt.reborrow(), *i))
258                .collect())
259        }
260    }
261}
262
263impl<F: Float, D: Distance<F>> NearestNeighbourIndex<F> for BallTreeIndex<'_, F, D> {
264    fn k_nearest(&self, point: Point<'_, F>, k: usize) -> Result<Vec<(Point<F>, usize)>, NnError> {
265        self.nn_helper(point, k, F::infinity())
266    }
267
268    fn within_range(
269        &self,
270        point: Point<'_, F>,
271        range: F,
272    ) -> Result<Vec<(Point<F>, usize)>, NnError> {
273        let range = self.dist_fn.dist_to_rdist(range);
274        self.nn_helper(point, self.len, range)
275    }
276}
277
/// Ball tree: a space-partitioning data structure that organises its points into nested
/// hyperspheres called "balls".
///
/// Spatial queries run in `O(k * logN)` time, where `k` is the number of points returned by the
/// query. Calling `from_batch` returns a [`BallTreeIndex`].
///
/// Background is in the [Wikipedia article](https://en.wikipedia.org/wiki/Ball_tree). This
/// implementation is based on the [ball_tree](https://docs.rs/ball-tree/0.2.0/ball_tree/) crate.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct BallTree;

impl BallTree {
    /// Creates an instance of `BallTree`.
    ///
    /// `BallTree` is a unit struct with no configuration, so this is equivalent to
    /// `BallTree::default()`.
    pub fn new() -> Self {
        BallTree
    }
}
299
300impl NearestNeighbour for BallTree {
301    fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
302        &self,
303        batch: &'a ArrayBase<DT, Ix2>,
304        leaf_size: usize,
305        dist_fn: D,
306    ) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
307        BallTreeIndex::new(batch, leaf_size, dist_fn).map(|v| Box::new(v) as NearestNeighbourBox<F>)
308    }
309}
310
#[cfg(test)]
mod test {
    use approx::assert_abs_diff_eq;
    use ndarray::{arr1, arr2, stack, Array1, Array2, Axis};

    use crate::distance::L2Dist;

    use super::*;

    #[test]
    fn autotraits() {
        // Compile-time check that the public and internal types can cross thread boundaries.
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<BallTree>();
        has_autotraits::<BallTreeIndex<f64, L2Dist>>();
        has_autotraits::<BallTreeInner<f64>>();
    }

    /// Partitions the rows of `input` and checks the left half, the median and the right half.
    /// It also checks the radius of the ball centred on the median that encloses every input row.
    fn assert_partition(
        input: Array2<f64>,
        exp_left: Array2<f64>,
        exp_med: Array1<f64>,
        exp_right: Array2<f64>,
        exp_rad: f64,
    ) {
        let indexed: Vec<_> = input.rows().into_iter().zip(0..).collect();
        let (left, median, right) = partition(indexed.clone());

        let left_rows: Vec<_> = left.into_iter().map(|(row, _)| row).collect();
        let right_rows: Vec<_> = right.into_iter().map(|(row, _)| row).collect();
        assert_abs_diff_eq!(stack(Axis(0), &left_rows).unwrap(), exp_left);
        assert_abs_diff_eq!(median.to_owned(), exp_med);
        assert_abs_diff_eq!(stack(Axis(0), &right_rows).unwrap(), exp_right);

        let radius = calc_radius(indexed.iter().map(|(row, _)| row.reborrow()), median, &L2Dist);
        assert_abs_diff_eq!(radius, exp_rad);
    }

    #[test]
    fn partition_test() {
        // Each case: (input, expected left, expected median, expected right, expected radius)
        let cases = vec![
            // Two points, already in order
            (
                arr2(&[[0.0, 1.0], [2.0, 3.0]]),
                arr2(&[[0.0, 1.0]]),
                arr1(&[2.0, 3.0]),
                arr2(&[[2.0, 3.0]]),
                8.0f64.sqrt(),
            ),
            // Two points, reversed
            (
                arr2(&[[2.0, 3.0], [0.0, 1.0]]),
                arr2(&[[0.0, 1.0]]),
                arr1(&[2.0, 3.0]),
                arr2(&[[2.0, 3.0]]),
                8.0f64.sqrt(),
            ),
            // Split happens along the dimension with the highest spread
            (
                arr2(&[[0.3, 5.0], [4.5, 7.0], [8.1, 1.5]]),
                arr2(&[[0.3, 5.0]]),
                arr1(&[4.5, 7.0]),
                arr2(&[[4.5, 7.0], [8.1, 1.5]]),
                43.21f64.sqrt(),
            ),
            // Degenerate data: all points identical
            (
                arr2(&[[1.4, 4.3], [1.4, 4.3], [1.4, 4.3], [1.4, 4.3]]),
                arr2(&[[1.4, 4.3]]),
                arr1(&[1.4, 4.3]),
                arr2(&[[1.4, 4.3], [1.4, 4.3], [1.4, 4.3]]),
                0.0,
            ),
        ];

        for (input, exp_left, exp_med, exp_right, exp_rad) in cases {
            assert_partition(input, exp_left, exp_med, exp_right, exp_rad);
        }
    }
}