linfa_nn/lib.rs
1//! `linfa-nn` provides Rust implementations of common spatial indexing algorithms, as well as a
2//! trait-based interface for performing nearest-neighbour and range queries using these
3//! algorithms.
4//!
5//! ## The big picture
6//!
7//! `linfa-nn` is a crate in the `linfa` ecosystem, a wider effort to
8//! bootstrap a toolkit for classical Machine Learning implemented in pure Rust,
//! akin in spirit to Python's `scikit-learn`.
10//!
11//! You can find a roadmap (and a selection of good first issues)
12//! [here](https://github.com/LukeMathWalker/linfa/issues) - contributors are more than welcome!
13//!
14//! ## Current state
15//!
16//! Right now `linfa-nn` provides the following algorithms:
17//! * [Linear Scan](LinearSearch)
18//! * [KD Tree](KdTree)
19//! * [Ball Tree](BallTree)
20//!
//! The [`CommonNearestNeighbour`] enum should be used to dispatch
//! between all of the above algorithms flexibly.
23
24use distance::Distance;
25use linfa::Float;
26use ndarray::{ArrayBase, ArrayView1, Data, Ix2};
27#[cfg(feature = "serde")]
28use serde_crate::{Deserialize, Serialize};
29use thiserror::Error;
30
31mod balltree;
32mod heap_elem;
33mod kdtree;
34mod linear;
35
36pub mod distance;
37
38pub use crate::{balltree::*, kdtree::*, linear::*};
39
40pub(crate) type Point<'a, F> = ArrayView1<'a, F>;
41pub(crate) type NearestNeighbourBox<'a, F> = Box<dyn 'a + Send + Sync + NearestNeighbourIndex<F>>;
42
43/// Error returned when building nearest neighbour indices
44#[derive(Error, Debug)]
45pub enum BuildError {
46 #[error("points have dimension of 0")]
47 ZeroDimension,
48 #[error("leaf size is 0")]
49 EmptyLeaf,
50}
51
52/// Error returned when performing spatial queries on nearest neighbour indices
53#[derive(Error, Debug)]
54pub enum NnError {
55 #[error("dimensions of query point and stored points are different")]
56 WrongDimension,
57}
58
59/// Nearest neighbour algorithm builds a spatial index structure out of a batch of points. The
60/// distance between points is calculated using a provided distance function. The index implements
61/// the [`NearestNeighbourIndex`] trait and allows for efficient
62/// computing of nearest neighbour and range queries.
63pub trait NearestNeighbour: std::fmt::Debug + Send + Sync + Unpin {
64 /// Builds a spatial index using a MxN two-dimensional array representing M points with N
65 /// dimensions. Also takes `leaf_size`, which specifies the number of elements in the leaf
66 /// nodes of tree-like index structures.
67 ///
68 /// Returns an error if the points have dimensionality of 0 or if the leaf size is 0. If any
69 /// value in the batch is NaN or infinite, the behaviour is unspecified.
70 #[allow(clippy::wrong_self_convention)]
71 fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
72 &self,
73 batch: &'a ArrayBase<DT, Ix2>,
74 leaf_size: usize,
75 dist_fn: D,
76 ) -> Result<NearestNeighbourBox<'a, F>, BuildError>;
77
78 /// Builds a spatial index using a default leaf size. See `from_batch_with_leaf_size` for more
79 /// information.
80 #[allow(clippy::wrong_self_convention)]
81 fn from_batch<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
82 &self,
83 batch: &'a ArrayBase<DT, Ix2>,
84 dist_fn: D,
85 ) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
86 self.from_batch_with_leaf_size(batch, 2usize.pow(4), dist_fn)
87 }
88}
89
90/// A spatial index structure over a set of points, created by `NearestNeighbour`. Allows efficient
91/// computation of nearest neighbour and range queries over the set of points. Individual points
92/// are represented as one-dimensional array views.
93pub trait NearestNeighbourIndex<F: Float>: Send + Sync + Unpin {
94 /// Returns the `k` points in the index that are the closest to the provided point, along with
95 /// their positions in the original dataset. Points are returned in ascending order of the
96 /// distance away from the provided points, and less than `k` points will be returned if the
97 /// index contains fewer than `k`.
98 ///
99 /// Returns an error if the provided point has different dimensionality than the index's
100 /// points.
101 fn k_nearest(&self, point: Point<'_, F>, k: usize) -> Result<Vec<(Point<F>, usize)>, NnError>;
102
103 /// Returns all the points in the index that are within the specified distance to the provided
104 /// point, along with their positions in the original dataset. The points are not guaranteed to
105 /// be in any order, though many algorithms return the points in order of distance.
106 ///
107 /// Returns an error if the provided point has different dimensionality than the index's
108 /// points.
109 fn within_range(
110 &self,
111 point: Point<'_, F>,
112 range: F,
113 ) -> Result<Vec<(Point<F>, usize)>, NnError>;
114}
115
/// Value-level selector for the [`NearestNeighbour`] implementations shipped with this crate.
///
/// Each variant dispatches to the corresponding algorithm. Prefer this enum over using types
/// like `LinearSearch` and `KdTree` directly.
///
/// ## Example
///
/// ```rust
/// use rand_xoshiro::Xoshiro256Plus;
/// use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt};
/// use ndarray::{Array1, Array2};
/// use linfa_nn::{distance::*, CommonNearestNeighbour, NearestNeighbour};
///
/// // Use seedable RNG for generating points
/// let mut rng = Xoshiro256Plus::seed_from_u64(40);
/// let n_features = 3;
/// let distr = Uniform::new(-500., 500.);
/// // Randomly generate points for building the index
/// let points = Array2::random_using((5000, n_features), distr, &mut rng);
///
/// // Build a K-D tree with Euclidean distance as the distance function
/// let nn = CommonNearestNeighbour::KdTree.from_batch(&points, L2Dist).unwrap();
///
/// let pt = Array1::random_using(n_features, distr, &mut rng);
/// // Compute the 10 nearest points to `pt` in the index
/// let nearest = nn.k_nearest(pt.view(), 10).unwrap();
/// // Compute all points within 100 units of `pt`
/// let range = nn.within_range(pt.view(), 100.0).unwrap();
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize), serde(crate = "serde_crate"))]
#[non_exhaustive]
pub enum CommonNearestNeighbour {
    /// Linear search
    LinearSearch,
    /// KD Tree
    KdTree,
    /// Ball Tree
    BallTree,
}
159
160impl NearestNeighbour for CommonNearestNeighbour {
161 fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
162 &self,
163 batch: &'a ArrayBase<DT, Ix2>,
164 leaf_size: usize,
165 dist_fn: D,
166 ) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
167 match self {
168 Self::LinearSearch => LinearSearch.from_batch_with_leaf_size(batch, leaf_size, dist_fn),
169 Self::KdTree => KdTree.from_batch_with_leaf_size(batch, leaf_size, dist_fn),
170 Self::BallTree => BallTree.from_batch_with_leaf_size(batch, leaf_size, dist_fn),
171 }
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that the crate's key types are `Send`, `Sync`, `Sized` and `Unpin`.
    #[test]
    fn autotraits() {
        // Instantiating this function only type-checks when `T` satisfies every bound.
        fn assert_autotraits<T>()
        where
            T: Send + Sync + Sized + Unpin,
        {
        }

        assert_autotraits::<CommonNearestNeighbour>();
        assert_autotraits::<NearestNeighbourBox<'static, f64>>();
        assert_autotraits::<BuildError>();
        assert_autotraits::<NnError>();
    }
}