1use std::{cmp::Reverse, collections::BinaryHeap};
2
3use linfa::Float;
4use ndarray::{ArrayBase, ArrayView2, Data, Ix2};
5use noisy_float::NoisyFloat;
6#[cfg(feature = "serde")]
7use serde_crate::{Deserialize, Serialize};
8
9use crate::{
10 distance::Distance, heap_elem::MinHeapElem, BuildError, NearestNeighbour, NearestNeighbourBox,
11 NearestNeighbourIndex, NnError, Point,
12};
13
14#[derive(Debug, Clone, PartialEq)]
16pub struct LinearSearchIndex<'a, F: Float, D: Distance<F>>(ArrayView2<'a, F>, D);
17
18impl<'a, F: Float, D: Distance<F>> LinearSearchIndex<'a, F, D> {
19 pub fn new<DT: Data<Elem = F>>(
21 batch: &'a ArrayBase<DT, Ix2>,
22 dist_fn: D,
23 ) -> Result<Self, BuildError> {
24 if batch.ncols() == 0 {
25 Err(BuildError::ZeroDimension)
26 } else {
27 Ok(Self(batch.view(), dist_fn))
28 }
29 }
30}
31
32impl<F: Float, D: Distance<F>> NearestNeighbourIndex<F> for LinearSearchIndex<'_, F, D> {
33 fn k_nearest(&self, point: Point<'_, F>, k: usize) -> Result<Vec<(Point<F>, usize)>, NnError> {
34 if self.0.ncols() != point.len() {
35 Err(NnError::WrongDimension)
36 } else {
37 let mut heap = BinaryHeap::with_capacity(self.0.nrows());
38 for (i, pt) in self.0.rows().into_iter().enumerate() {
39 let dist = self.1.rdistance(point.reborrow(), pt.reborrow());
40 heap.push(MinHeapElem {
41 elem: (pt.reborrow(), i),
42 dist: Reverse(NoisyFloat::new(dist)),
43 });
44 }
45
46 Ok((0..k.min(heap.len()))
47 .map(|_| heap.pop().unwrap().elem)
48 .collect())
49 }
50 }
51
52 fn within_range(
53 &self,
54 point: Point<'_, F>,
55 range: F,
56 ) -> Result<Vec<(Point<F>, usize)>, NnError> {
57 if self.0.ncols() != point.len() {
58 Err(NnError::WrongDimension)
59 } else {
60 let range = self.1.dist_to_rdist(range);
61 Ok(self
62 .0
63 .rows()
64 .into_iter()
65 .enumerate()
66 .filter(|(_, pt)| self.1.rdistance(point.reborrow(), pt.reborrow()) < range)
67 .map(|(i, pt)| (pt, i))
68 .collect())
69 }
70 }
71}
72
/// Stateless builder for [`LinearSearchIndex`] values.
///
/// It carries no configuration; with the `serde` feature enabled it also derives
/// `Serialize` and `Deserialize`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(crate = "serde_crate"))]
pub struct LinearSearch;
83
84impl LinearSearch {
85 pub fn new() -> Self {
87 Self
88 }
89}
90
91impl NearestNeighbour for LinearSearch {
92 fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
93 &self,
94 batch: &'a ArrayBase<DT, Ix2>,
95 leaf_size: usize,
96 dist_fn: D,
97 ) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
98 if leaf_size == 0 {
99 return Err(BuildError::EmptyLeaf);
100 }
101 LinearSearchIndex::new(batch, dist_fn).map(|v| Box::new(v) as NearestNeighbourBox<F>)
102 }
103}