// linfa_nn/linear.rs

1use std::{cmp::Reverse, collections::BinaryHeap};
2
3use linfa::Float;
4use ndarray::{ArrayBase, ArrayView2, Data, Ix2};
5use noisy_float::NoisyFloat;
6#[cfg(feature = "serde")]
7use serde_crate::{Deserialize, Serialize};
8
9use crate::{
10    distance::Distance, heap_elem::MinHeapElem, BuildError, NearestNeighbour, NearestNeighbourBox,
11    NearestNeighbourIndex, NnError, Point,
12};
13
14/// Spatial indexing structure created by [`LinearSearch`]
15#[derive(Debug, Clone, PartialEq)]
16pub struct LinearSearchIndex<'a, F: Float, D: Distance<F>>(ArrayView2<'a, F>, D);
17
18impl<'a, F: Float, D: Distance<F>> LinearSearchIndex<'a, F, D> {
19    /// Creates a new `LinearSearchIndex`
20    pub fn new<DT: Data<Elem = F>>(
21        batch: &'a ArrayBase<DT, Ix2>,
22        dist_fn: D,
23    ) -> Result<Self, BuildError> {
24        if batch.ncols() == 0 {
25            Err(BuildError::ZeroDimension)
26        } else {
27            Ok(Self(batch.view(), dist_fn))
28        }
29    }
30}
31
32impl<F: Float, D: Distance<F>> NearestNeighbourIndex<F> for LinearSearchIndex<'_, F, D> {
33    fn k_nearest(&self, point: Point<'_, F>, k: usize) -> Result<Vec<(Point<F>, usize)>, NnError> {
34        if self.0.ncols() != point.len() {
35            Err(NnError::WrongDimension)
36        } else {
37            let mut heap = BinaryHeap::with_capacity(self.0.nrows());
38            for (i, pt) in self.0.rows().into_iter().enumerate() {
39                let dist = self.1.rdistance(point.reborrow(), pt.reborrow());
40                heap.push(MinHeapElem {
41                    elem: (pt.reborrow(), i),
42                    dist: Reverse(NoisyFloat::new(dist)),
43                });
44            }
45
46            Ok((0..k.min(heap.len()))
47                .map(|_| heap.pop().unwrap().elem)
48                .collect())
49        }
50    }
51
52    fn within_range(
53        &self,
54        point: Point<'_, F>,
55        range: F,
56    ) -> Result<Vec<(Point<F>, usize)>, NnError> {
57        if self.0.ncols() != point.len() {
58            Err(NnError::WrongDimension)
59        } else {
60            let range = self.1.dist_to_rdist(range);
61            Ok(self
62                .0
63                .rows()
64                .into_iter()
65                .enumerate()
66                .filter(|(_, pt)| self.1.rdistance(point.reborrow(), pt.reborrow()) < range)
67                .map(|(i, pt)| (pt, i))
68                .collect())
69        }
70    }
71}
72
/// Linear search, the simplest nearest neighbour algorithm.
///
/// There is no preprocessing. Every query computes the distance from the query point to every
/// indexed point, so the number of distance evaluations per query grows linearly with the
/// number of points. Building an index (e.g. via `from_batch`) returns a [`LinearSearchIndex`].
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(crate = "serde_crate"))]
pub struct LinearSearch;
83
84impl LinearSearch {
85    /// Creates an instance of `LinearSearch`
86    pub fn new() -> Self {
87        Self
88    }
89}
90
91impl NearestNeighbour for LinearSearch {
92    fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
93        &self,
94        batch: &'a ArrayBase<DT, Ix2>,
95        leaf_size: usize,
96        dist_fn: D,
97    ) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
98        if leaf_size == 0 {
99            return Err(BuildError::EmptyLeaf);
100        }
101        LinearSearchIndex::new(batch, dist_fn).map(|v| Box::new(v) as NearestNeighbourBox<F>)
102    }
103}