1use linfa::Float;
2use ndarray::{aview1, ArrayBase, Data, Ix2};
3#[cfg(feature = "serde")]
4use serde_crate::{Deserialize, Serialize};
5
6use crate::{
7 distance::Distance, BuildError, NearestNeighbour, NearestNeighbourBox, NearestNeighbourIndex,
8 NnError, Point,
9};
10
11#[derive(Debug)]
13pub struct KdTreeIndex<'a, F: Float, D: Distance<F>>(
14 kdtree::KdTree<F, (Point<'a, F>, usize), &'a [F]>,
15 D,
16);
17
18impl<'a, F: Float, D: Distance<F>> KdTreeIndex<'a, F, D> {
19 pub fn new<DT: Data<Elem = F>>(
21 batch: &'a ArrayBase<DT, Ix2>,
22 leaf_size: usize,
23 dist_fn: D,
24 ) -> Result<Self, BuildError> {
25 if leaf_size == 0 {
26 Err(BuildError::EmptyLeaf)
27 } else if batch.ncols() == 0 {
28 Err(BuildError::ZeroDimension)
29 } else {
30 let mut tree = kdtree::KdTree::with_capacity(batch.ncols().max(1), leaf_size);
31 for (i, point) in batch.rows().into_iter().enumerate() {
32 tree.add(
33 point.to_slice().expect("views should be contiguous"),
34 (point, i),
35 )
36 .unwrap();
37 }
38 Ok(Self(tree, dist_fn))
39 }
40 }
41}
42
43impl From<kdtree::ErrorKind> for NnError {
44 fn from(err: kdtree::ErrorKind) -> Self {
45 match err {
46 kdtree::ErrorKind::WrongDimension => NnError::WrongDimension,
47 kdtree::ErrorKind::NonFiniteCoordinate => panic!("infinite value found"),
48 _ => unreachable!(),
49 }
50 }
51}
52
53impl<F: Float, D: Distance<F>> NearestNeighbourIndex<F> for KdTreeIndex<'_, F, D> {
54 fn k_nearest(&self, point: Point<'_, F>, k: usize) -> Result<Vec<(Point<F>, usize)>, NnError> {
55 Ok(self
56 .0
57 .nearest(
58 point.to_slice().expect("views should be contiguous"),
59 k,
60 &|a, b| self.1.rdistance(aview1(a), aview1(b)),
61 )?
62 .into_iter()
63 .map(|(_, (pt, pos))| (pt.reborrow(), *pos))
64 .collect())
65 }
66
67 fn within_range(
68 &self,
69 point: Point<'_, F>,
70 range: F,
71 ) -> Result<Vec<(Point<F>, usize)>, NnError> {
72 let range = self.1.dist_to_rdist(range);
73 Ok(self
74 .0
75 .within(
76 point.to_slice().expect("views should be contiguous"),
77 range,
78 &|a, b| self.1.rdistance(aview1(a), aview1(b)),
79 )?
80 .into_iter()
81 .map(|(_, (pt, pos))| (pt.reborrow(), *pos))
82 .collect())
83 }
84}
85
/// Nearest-neighbour algorithm that builds k-d tree indices.
///
/// A stateless, zero-sized handle. The actual index is created from a batch of points through
/// its nearest-neighbour trait implementation.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct KdTree;
102
103impl KdTree {
104 pub fn new() -> Self {
106 Self
107 }
108}
109
110impl NearestNeighbour for KdTree {
111 fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
112 &self,
113 batch: &'a ArrayBase<DT, Ix2>,
114 leaf_size: usize,
115 dist_fn: D,
116 ) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
117 KdTreeIndex::new(batch, leaf_size, dist_fn).map(|v| Box::new(v) as NearestNeighbourBox<F>)
118 }
119}