linfa_nn/
lib.rs

//! `linfa-nn` provides Rust implementations of common spatial indexing algorithms, as well as a
//! trait-based interface for performing nearest-neighbour and range queries using these
//! algorithms.
//!
//! ## The big picture
//!
//! `linfa-nn` is a crate in the `linfa` ecosystem, a wider effort to
//! bootstrap a toolkit for classical Machine Learning implemented in pure Rust,
//! similar in spirit to Python's `scikit-learn`.
//!
//! You can find a roadmap (and a selection of good first issues)
//! [here](https://github.com/LukeMathWalker/linfa/issues) - contributors are more than welcome!
//!
//! ## Current state
//!
//! Right now `linfa-nn` provides the following algorithms:
//! * [Linear Scan](LinearSearch)
//! * [KD Tree](KdTree)
//! * [Ball Tree](BallTree)
//!
//! The [`CommonNearestNeighbour`] enum should be used to dispatch flexibly between all of the
//! above algorithms.
23
24use distance::Distance;
25use linfa::Float;
26use ndarray::{ArrayBase, ArrayView1, Data, Ix2};
27#[cfg(feature = "serde")]
28use serde_crate::{Deserialize, Serialize};
29use thiserror::Error;
30
31mod balltree;
32mod heap_elem;
33mod kdtree;
34mod linear;
35
36pub mod distance;
37
38pub use crate::{balltree::*, kdtree::*, linear::*};
39
40pub(crate) type Point<'a, F> = ArrayView1<'a, F>;
41pub(crate) type NearestNeighbourBox<'a, F> = Box<dyn 'a + Send + Sync + NearestNeighbourIndex<F>>;
42
43/// Error returned when building nearest neighbour indices
44#[derive(Error, Debug)]
45pub enum BuildError {
46    #[error("points have dimension of 0")]
47    ZeroDimension,
48    #[error("leaf size is 0")]
49    EmptyLeaf,
50}
51
52/// Error returned when performing spatial queries on nearest neighbour indices
53#[derive(Error, Debug)]
54pub enum NnError {
55    #[error("dimensions of query point and stored points are different")]
56    WrongDimension,
57}
58
59/// Nearest neighbour algorithm builds a spatial index structure out of a batch of points. The
60/// distance between points is calculated using a provided distance function. The index implements
61/// the [`NearestNeighbourIndex`] trait and allows for efficient
62/// computing of nearest neighbour and range queries.
63pub trait NearestNeighbour: std::fmt::Debug + Send + Sync + Unpin {
64    /// Builds a spatial index using a MxN two-dimensional array representing M points with N
65    /// dimensions. Also takes `leaf_size`, which specifies the number of elements in the leaf
66    /// nodes of tree-like index structures.
67    ///
68    /// Returns an error if the points have dimensionality of 0 or if the leaf size is 0. If any
69    /// value in the batch is NaN or infinite, the behaviour is unspecified.
70    #[allow(clippy::wrong_self_convention)]
71    fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
72        &self,
73        batch: &'a ArrayBase<DT, Ix2>,
74        leaf_size: usize,
75        dist_fn: D,
76    ) -> Result<NearestNeighbourBox<'a, F>, BuildError>;
77
78    /// Builds a spatial index using a default leaf size. See `from_batch_with_leaf_size` for more
79    /// information.
80    #[allow(clippy::wrong_self_convention)]
81    fn from_batch<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
82        &self,
83        batch: &'a ArrayBase<DT, Ix2>,
84        dist_fn: D,
85    ) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
86        self.from_batch_with_leaf_size(batch, 2usize.pow(4), dist_fn)
87    }
88}
89
90/// A spatial index structure over a set of points, created by `NearestNeighbour`. Allows efficient
91/// computation of nearest neighbour and range queries over the set of points. Individual points
92/// are represented as one-dimensional array views.
93pub trait NearestNeighbourIndex<F: Float>: Send + Sync + Unpin {
94    /// Returns the `k` points in the index that are the closest to the provided point, along with
95    /// their positions in the original dataset. Points are returned in ascending order of the
96    /// distance away from the provided points, and less than `k` points will be returned if the
97    /// index contains fewer than `k`.
98    ///
99    /// Returns an error if the provided point has different dimensionality than the index's
100    /// points.
101    fn k_nearest(&self, point: Point<'_, F>, k: usize) -> Result<Vec<(Point<F>, usize)>, NnError>;
102
103    /// Returns all the points in the index that are within the specified distance to the provided
104    /// point, along with their positions in the original dataset. The points are not guaranteed to
105    /// be in any order, though many algorithms return the points in order of distance.
106    ///
107    /// Returns an error if the provided point has different dimensionality than the index's
108    /// points.
109    fn within_range(
110        &self,
111        point: Point<'_, F>,
112        range: F,
113    ) -> Result<Vec<(Point<F>, usize)>, NnError>;
114}
115
/// Selects one of the crate's [`NearestNeighbour`] implementations by value.
///
/// Each variant forwards index construction to the matching algorithm, so prefer this enum over
/// naming types such as `LinearSearch` or `KdTree` directly.
///
/// ## Example
///
/// ```rust
/// use rand_xoshiro::Xoshiro256Plus;
/// use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt};
/// use ndarray::{Array1, Array2};
/// use linfa_nn::{distance::*, CommonNearestNeighbour, NearestNeighbour};
///
/// // A seeded RNG keeps the generated points reproducible
/// let mut rng = Xoshiro256Plus::seed_from_u64(40);
/// let n_features = 3;
/// let distr = Uniform::new(-500., 500.);
/// // Random points that the index is built over
/// let points = Array2::random_using((5000, n_features), distr, &mut rng);
///
/// // K-D tree index using Euclidean (L2) distance
/// let nn = CommonNearestNeighbour::KdTree.from_batch(&points, L2Dist).unwrap();
///
/// let pt = Array1::random_using(n_features, distr, &mut rng);
/// // The 10 indexed points closest to `pt`
/// let nearest = nn.k_nearest(pt.view(), 10).unwrap();
/// // Every indexed point within 100 units of `pt`
/// let range = nn.within_range(pt.view(), 100.0).unwrap();
/// ```
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub enum CommonNearestNeighbour {
    /// Linear scan over all stored points
    LinearSearch,
    /// K-dimensional tree
    KdTree,
    /// Ball tree
    BallTree,
}
159
160impl NearestNeighbour for CommonNearestNeighbour {
161    fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
162        &self,
163        batch: &'a ArrayBase<DT, Ix2>,
164        leaf_size: usize,
165        dist_fn: D,
166    ) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
167        match self {
168            Self::LinearSearch => LinearSearch.from_batch_with_leaf_size(batch, leaf_size, dist_fn),
169            Self::KdTree => KdTree.from_batch_with_leaf_size(batch, leaf_size, dist_fn),
170            Self::BallTree => BallTree.from_batch_with_leaf_size(batch, leaf_size, dist_fn),
171        }
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that the crate's public types implement `Send`, `Sync` and `Unpin`.
    #[test]
    fn autotraits() {
        // Generic parameters are implicitly `Sized`, so only the auto traits need spelling out.
        fn assert_auto_traits<T: Send + Sync + Unpin>() {}

        assert_auto_traits::<CommonNearestNeighbour>();
        assert_auto_traits::<NearestNeighbourBox<'static, f64>>();
        assert_auto_traits::<BuildError>();
        assert_auto_traits::<NnError>();
    }
}