1#![allow(clippy::type_complexity)]
2use std::{cmp::Reverse, collections::BinaryHeap};
3
4use linfa::Float;
5use ndarray::{Array1, ArrayBase, Data, Ix2};
6use noisy_float::{checkers::FiniteChecker, NoisyFloat};
7#[cfg(feature = "serde")]
8use serde_crate::{Deserialize, Serialize};
9
10use crate::{
11 distance::Distance,
12 heap_elem::{MaxHeapElem, MinHeapElem},
13 BuildError, NearestNeighbour, NearestNeighbourBox, NearestNeighbourIndex, NnError, Point,
14};
15
16fn partition<F: Float>(
18 mut points: Vec<(Point<F>, usize)>,
19) -> (Vec<(Point<F>, usize)>, Point<F>, Vec<(Point<F>, usize)>) {
20 debug_assert!(points.len() >= 2);
21
22 let max_spread_dim = (0..points[0].0.len())
25 .map(|dim| {
26 let (max, min) = points
28 .iter()
29 .map(|p| p.0[dim])
30 .fold((F::neg_infinity(), F::infinity()), |(a, b), c| {
31 (F::max(a, c), F::min(b, c))
32 });
33
34 (dim, NoisyFloat::<_, FiniteChecker>::new(max - min))
35 })
36 .max_by_key(|&(_, range)| range)
37 .expect("vec has no dimensions")
38 .0;
39
40 let mid = points.len() / 2;
41 let median = order_stat::kth_by(&mut points, mid, |p1, p2| {
43 p1.0[max_spread_dim]
44 .partial_cmp(&p2.0[max_spread_dim])
45 .expect("NaN in data")
46 })
47 .0
48 .reborrow();
49
50 let (mut left, mut right): (Vec<_>, Vec<_>) = points
51 .into_iter()
52 .partition(|pt| pt.0[max_spread_dim] < median[max_spread_dim]);
53 if left.is_empty() {
57 left.push(right.pop().unwrap());
58 }
59 (left, median, right)
60}
61
62fn calc_radius<'a, F: Float, D: Distance<F>>(
64 points: impl Iterator<Item = Point<'a, F>>,
65 center: Point<F>,
66 dist_fn: &D,
67) -> F {
68 let r_rad = points
69 .map(|pt| NoisyFloat::<_, FiniteChecker>::new(dist_fn.rdistance(pt, center)))
70 .max()
71 .unwrap()
72 .raw();
73 dist_fn.rdist_to_dist(r_rad)
74}
75
/// A node of the ball tree.
///
/// Each node stores a `center` and a `radius`. `radius` is the largest distance, under the
/// tree's distance function, from `center` to any point stored beneath the node (zero for an
/// empty leaf).
#[derive(Debug, PartialEq, Clone)]
enum BallTreeInner<'a, F: Float> {
    /// Terminal node that stores its points directly, each paired with its row index.
    Leaf {
        /// Mean of the leaf's points, or a zero-length array when the leaf has no points.
        center: Array1<F>,
        radius: F,
        points: Vec<(Point<'a, F>, usize)>,
    },
    /// Interior node whose points are split between two child subtrees.
    Branch {
        /// Median point chosen by `partition`, borrowed from the indexed data.
        center: Point<'a, F>,
        radius: F,
        left: Box<BallTreeInner<'a, F>>,
        right: Box<BallTreeInner<'a, F>>,
    },
}
92
93impl<'a, F: Float> BallTreeInner<'a, F> {
94 fn new<D: Distance<F>>(
95 points: Vec<(Point<'a, F>, usize)>,
96 leaf_size: usize,
97 dist_fn: &D,
98 ) -> Self {
99 if points.len() <= leaf_size {
100 if let Some(dim) = points.first().map(|p| p.0.len()) {
102 let center = {
105 let mut c = Array1::zeros(dim);
106 points.iter().for_each(|p| c += &p.0);
107 c / F::from(points.len()).unwrap()
108 };
109 let radius = calc_radius(
110 points.iter().map(|p| p.0.reborrow()),
111 center.view(),
112 dist_fn,
113 );
114 BallTreeInner::Leaf {
115 center,
116 radius,
117 points,
118 }
119 } else {
120 BallTreeInner::Leaf {
122 center: Array1::zeros(0),
123 points,
124 radius: F::zero(),
125 }
126 }
127 } else {
128 let (aps, center, bps) = partition(points);
130 debug_assert!(!aps.is_empty() && !bps.is_empty());
131 let radius = calc_radius(
132 aps.iter().chain(bps.iter()).map(|p| p.0.reborrow()),
133 center,
134 dist_fn,
135 );
136 let a_tree = BallTreeInner::new(aps, leaf_size, dist_fn);
137 let b_tree = BallTreeInner::new(bps, leaf_size, dist_fn);
138 BallTreeInner::Branch {
139 center,
140 radius,
141 left: Box::new(a_tree),
142 right: Box::new(b_tree),
143 }
144 }
145 }
146
147 fn rdistance<D: Distance<F>>(&self, p: Point<F>, dist_fn: &D) -> F {
148 let (center, radius) = match self {
149 BallTreeInner::Leaf { center, radius, .. } => (center.view(), radius),
150 BallTreeInner::Branch { center, radius, .. } => (center.reborrow(), radius),
151 };
152
153 let border_dist = dist_fn.distance(p, center.reborrow()) - *radius;
157 dist_fn.dist_to_rdist(border_dist.max(F::zero()))
158 }
159}
160
/// Nearest-neighbour index backed by a ball tree built over the rows of a borrowed 2D array.
#[derive(Debug, Clone, PartialEq)]
pub struct BallTreeIndex<'a, F: Float, D: Distance<F>> {
    /// Root node of the tree.
    tree: BallTreeInner<'a, F>,
    /// Distance function used both to build and to query the tree.
    dist_fn: D,
    /// Number of columns of the indexed data; query points must have this length.
    dim: usize,
    /// Number of indexed points (rows of the data).
    len: usize,
}
169
170impl<'a, F: Float, D: Distance<F>> BallTreeIndex<'a, F, D> {
171 pub fn new<DT: Data<Elem = F>>(
173 batch: &'a ArrayBase<DT, Ix2>,
174 leaf_size: usize,
175 dist_fn: D,
176 ) -> Result<Self, BuildError> {
177 let dim = batch.ncols();
178 let len = batch.nrows();
179 if leaf_size == 0 {
180 Err(BuildError::EmptyLeaf)
181 } else if dim == 0 {
182 Err(BuildError::ZeroDimension)
183 } else {
184 let points: Vec<_> = batch
185 .rows()
186 .into_iter()
187 .enumerate()
188 .map(|(i, pt)| (pt, i))
189 .collect();
190 Ok(BallTreeIndex {
191 tree: BallTreeInner::new(points, leaf_size, &dist_fn),
192 dist_fn,
193 dim,
194 len,
195 })
196 }
197 }
198
199 fn nn_helper(
200 &self,
201 point: Point<'_, F>,
202 k: usize,
203 max_radius: F,
204 ) -> Result<Vec<(Point<F>, usize)>, NnError> {
205 if self.dim != point.len() {
206 Err(NnError::WrongDimension)
207 } else if self.len == 0 {
208 Ok(Vec::new())
209 } else {
210 let mut out: BinaryHeap<MaxHeapElem<_, _>> = BinaryHeap::new();
211 let mut queue = BinaryHeap::new();
212 queue.push(MinHeapElem::new(
213 self.tree.rdistance(point, &self.dist_fn),
214 &self.tree,
215 ));
216
217 while let Some(MinHeapElem {
218 dist: Reverse(dist),
219 elem,
220 }) = queue.pop()
221 {
222 if dist >= max_radius || (out.len() == k && dist >= out.peek().unwrap().dist) {
223 break;
224 }
225
226 match elem {
227 BallTreeInner::Leaf { points, .. } => {
228 for p in points {
229 let dist = self.dist_fn.rdistance(point, p.0.reborrow());
230 if dist < max_radius
231 && (out.len() < k || out.peek().unwrap().dist > dist)
232 {
233 out.push(MaxHeapElem::new(dist, p));
234 if out.len() > k {
235 out.pop();
236 }
237 }
238 }
239 }
240 BallTreeInner::Branch { left, right, .. } => {
241 let dl = left.rdistance(point, &self.dist_fn);
242 let dr = right.rdistance(point, &self.dist_fn);
243
244 if dl <= max_radius {
245 queue.push(MinHeapElem::new(dl, left));
246 }
247 if dr <= max_radius {
248 queue.push(MinHeapElem::new(dr, right));
249 }
250 }
251 }
252 }
253 Ok(out
254 .into_sorted_vec()
255 .into_iter()
256 .map(|e| e.elem)
257 .map(|(pt, i)| (pt.reborrow(), *i))
258 .collect())
259 }
260 }
261}
262
263impl<F: Float, D: Distance<F>> NearestNeighbourIndex<F> for BallTreeIndex<'_, F, D> {
264 fn k_nearest(&self, point: Point<'_, F>, k: usize) -> Result<Vec<(Point<F>, usize)>, NnError> {
265 self.nn_helper(point, k, F::infinity())
266 }
267
268 fn within_range(
269 &self,
270 point: Point<'_, F>,
271 range: F,
272 ) -> Result<Vec<(Point<F>, usize)>, NnError> {
273 let range = self.dist_fn.dist_to_rdist(range);
274 self.nn_helper(point, self.len, range)
275 }
276}
277
/// Ball tree implementation of [`NearestNeighbour`].
///
/// This is a stateless unit struct. Use it to build a [`BallTreeIndex`] over a batch of points.
#[derive(Default, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct BallTree;
292
293impl BallTree {
294 pub fn new() -> Self {
296 Self
297 }
298}
299
300impl NearestNeighbour for BallTree {
301 fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
302 &self,
303 batch: &'a ArrayBase<DT, Ix2>,
304 leaf_size: usize,
305 dist_fn: D,
306 ) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
307 BallTreeIndex::new(batch, leaf_size, dist_fn).map(|v| Box::new(v) as NearestNeighbourBox<F>)
308 }
309}
310
#[cfg(test)]
mod test {
    use approx::assert_abs_diff_eq;
    use ndarray::{arr1, arr2, stack, Array1, Array2, Axis};

    use crate::distance::L2Dist;

    use super::*;

    /// Compile-time check that the builder, index and node types are `Send + Sync + Unpin`.
    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<BallTree>();
        has_autotraits::<BallTreeIndex<f64, L2Dist>>();
        has_autotraits::<BallTreeInner<f64>>();
    }

    /// Runs `partition` on the rows of `input` and checks the result against expectations.
    ///
    /// The left half, median and right half are compared with the expected values. Then the
    /// L2 radius of all rows around the returned median is checked.
    fn assert_partition(
        input: Array2<f64>,
        exp_left: Array2<f64>,
        exp_med: Array1<f64>,
        exp_right: Array2<f64>,
        exp_rad: f64,
    ) {
        let indexed: Vec<_> = input.rows().into_iter().zip(0usize..).collect();
        let (left, median, right) = partition(indexed.clone());

        let left_rows: Vec<_> = left.into_iter().map(|(row, _)| row).collect();
        let right_rows: Vec<_> = right.into_iter().map(|(row, _)| row).collect();
        assert_abs_diff_eq!(stack(Axis(0), &left_rows).unwrap(), exp_left);
        assert_abs_diff_eq!(median.to_owned(), exp_med);
        assert_abs_diff_eq!(stack(Axis(0), &right_rows).unwrap(), exp_right);

        let radius = calc_radius(indexed.iter().map(|(row, _)| row.reborrow()), median, &L2Dist);
        assert_abs_diff_eq!(radius, exp_rad);
    }

    #[test]
    fn partition_test() {
        let cases = vec![
            // Two points: the lower one goes left, the other is the median.
            (
                arr2(&[[0.0, 1.0], [2.0, 3.0]]),
                arr2(&[[0.0, 1.0]]),
                arr1(&[2.0, 3.0]),
                arr2(&[[2.0, 3.0]]),
                8.0f64.sqrt(),
            ),
            // Same points in reverse input order produce the same split.
            (
                arr2(&[[2.0, 3.0], [0.0, 1.0]]),
                arr2(&[[0.0, 1.0]]),
                arr1(&[2.0, 3.0]),
                arr2(&[[2.0, 3.0]]),
                8.0f64.sqrt(),
            ),
            // Three points: split on the first dimension, which has the widest range.
            (
                arr2(&[[0.3, 5.0], [4.5, 7.0], [8.1, 1.5]]),
                arr2(&[[0.3, 5.0]]),
                arr1(&[4.5, 7.0]),
                arr2(&[[4.5, 7.0], [8.1, 1.5]]),
                43.21f64.sqrt(),
            ),
            // All points identical: one point is moved left so neither half is empty.
            (
                arr2(&[[1.4, 4.3], [1.4, 4.3], [1.4, 4.3], [1.4, 4.3]]),
                arr2(&[[1.4, 4.3]]),
                arr1(&[1.4, 4.3]),
                arr2(&[[1.4, 4.3], [1.4, 4.3], [1.4, 4.3]]),
                0.0,
            ),
        ];

        for (input, exp_left, exp_med, exp_right, exp_rad) in cases {
            assert_partition(input, exp_left, exp_med, exp_right, exp_rad);
        }
    }
}