// linfa_nn/kdtree.rs

1use linfa::Float;
2use ndarray::{aview1, ArrayBase, Data, Ix2};
3#[cfg(feature = "serde")]
4use serde_crate::{Deserialize, Serialize};
5
6use crate::{
7    distance::Distance, BuildError, NearestNeighbour, NearestNeighbourBox, NearestNeighbourIndex,
8    NnError, Point,
9};
10
11/// Spatial indexing structure created by [`KdTree`]
12#[derive(Debug)]
13pub struct KdTreeIndex<'a, F: Float, D: Distance<F>>(
14    kdtree::KdTree<F, (Point<'a, F>, usize), &'a [F]>,
15    D,
16);
17
18impl<'a, F: Float, D: Distance<F>> KdTreeIndex<'a, F, D> {
19    /// Creates a new `KdTreeIndex`
20    pub fn new<DT: Data<Elem = F>>(
21        batch: &'a ArrayBase<DT, Ix2>,
22        leaf_size: usize,
23        dist_fn: D,
24    ) -> Result<Self, BuildError> {
25        if leaf_size == 0 {
26            Err(BuildError::EmptyLeaf)
27        } else if batch.ncols() == 0 {
28            Err(BuildError::ZeroDimension)
29        } else {
30            let mut tree = kdtree::KdTree::with_capacity(batch.ncols().max(1), leaf_size);
31            for (i, point) in batch.rows().into_iter().enumerate() {
32                tree.add(
33                    point.to_slice().expect("views should be contiguous"),
34                    (point, i),
35                )
36                .unwrap();
37            }
38            Ok(Self(tree, dist_fn))
39        }
40    }
41}
42
43impl From<kdtree::ErrorKind> for NnError {
44    fn from(err: kdtree::ErrorKind) -> Self {
45        match err {
46            kdtree::ErrorKind::WrongDimension => NnError::WrongDimension,
47            kdtree::ErrorKind::NonFiniteCoordinate => panic!("infinite value found"),
48            _ => unreachable!(),
49        }
50    }
51}
52
53impl<F: Float, D: Distance<F>> NearestNeighbourIndex<F> for KdTreeIndex<'_, F, D> {
54    fn k_nearest(&self, point: Point<'_, F>, k: usize) -> Result<Vec<(Point<F>, usize)>, NnError> {
55        Ok(self
56            .0
57            .nearest(
58                point.to_slice().expect("views should be contiguous"),
59                k,
60                &|a, b| self.1.rdistance(aview1(a), aview1(b)),
61            )?
62            .into_iter()
63            .map(|(_, (pt, pos))| (pt.reborrow(), *pos))
64            .collect())
65    }
66
67    fn within_range(
68        &self,
69        point: Point<'_, F>,
70        range: F,
71    ) -> Result<Vec<(Point<F>, usize)>, NnError> {
72        let range = self.1.dist_to_rdist(range);
73        Ok(self
74            .0
75            .within(
76                point.to_slice().expect("views should be contiguous"),
77                range,
78                &|a, b| self.1.rdistance(aview1(a), aview1(b)),
79            )?
80            .into_iter()
81            .map(|(_, (pt, pos))| (pt.reborrow(), *pos))
82            .collect())
83    }
84}
85
/// Implementation of K-D tree, a fast space-partitioning data structure. For each parent node,
/// the indexed points are split with a hyperplane into two child nodes. Thanks to this
/// tree-like structure, spatial queries typically run in `O(k * log N)` time on well-spread
/// data, where `k` is the number of points returned by the query. Unfavourable data can degrade
/// a query towards a linear scan.
///
/// `KdTree` itself is only a builder. The index it builds from a batch is a [`KdTreeIndex`],
/// returned boxed as a `NearestNeighbourBox`.
///
/// More details can be found [here](https://en.wikipedia.org/wiki/K-d_tree).
///
/// Unlike other `NearestNeighbour` implementations, `KdTree` requires that the indexed points
/// (the rows of the batch) be laid out contiguously in memory, and it will panic otherwise.
#[derive(Default, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct KdTree;
102
103impl KdTree {
104    /// Creates an instance of `KdTree`
105    pub fn new() -> Self {
106        Self
107    }
108}
109
110impl NearestNeighbour for KdTree {
111    fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
112        &self,
113        batch: &'a ArrayBase<DT, Ix2>,
114        leaf_size: usize,
115        dist_fn: D,
116    ) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
117        KdTreeIndex::new(batch, leaf_size, dist_fn).map(|v| Box::new(v) as NearestNeighbourBox<F>)
118    }
119}