linfa_nn/
distance.rs

1use linfa::Float;
2use ndarray::{Array2, ArrayBase, ArrayView, Axis, Data, Dimension, Ix2, Zip};
3use ndarray_stats::DeviationExt;
4
5#[cfg(feature = "serde")]
6use serde_crate::{Deserialize, Serialize};
7
8/// A distance function that can be used in spatial algorithms such as nearest neighbour.
9pub trait Distance<F: Float>: Clone + Send + Sync + Unpin {
10    /// Computes the distance between two points. For most spatial algorithms to work correctly,
11    /// **this metric must satisfy the Triangle Inequality.**
12    ///
13    /// Panics if the points have different dimensions.
14    fn distance<D: Dimension>(&self, a: ArrayView<F, D>, b: ArrayView<F, D>) -> F;
15
16    /// A faster version of the distance metric that keeps the order of the distance function. That
17    /// is, `dist(a, b) > dist(c, d)` implies `rdist(a, b) > rdist(c, d)`. For most algorithms this
18    /// is the same as `distance`. Unlike `distance`, this function does **not** need to satisfy
19    /// the Triangle Inequality.
20    #[inline]
21    fn rdistance<D: Dimension>(&self, a: ArrayView<F, D>, b: ArrayView<F, D>) -> F {
22        self.distance(a, b)
23    }
24
25    /// Converts the result of `rdistance` to `distance`
26    #[inline]
27    fn rdist_to_dist(&self, rdist: F) -> F {
28        rdist
29    }
30
31    /// Converts the result of `distance` to `rdistance`
32    #[inline]
33    fn dist_to_rdist(&self, dist: F) -> F {
34        dist
35    }
36}
37
38/// L1 or [Manhattan](https://en.wikipedia.org/wiki/Taxicab_geometry) distance
39#[cfg_attr(
40    feature = "serde",
41    derive(Serialize, Deserialize),
42    serde(crate = "serde_crate")
43)]
44#[derive(Debug, Clone, PartialEq, Eq)]
45pub struct L1Dist;
46impl<F: Float> Distance<F> for L1Dist {
47    #[inline]
48    fn distance<D: Dimension>(&self, a: ArrayView<F, D>, b: ArrayView<F, D>) -> F {
49        a.l1_dist(&b).unwrap()
50    }
51}
52
53/// L2 or [Euclidean](https://en.wikipedia.org/wiki/Euclidean_distance) distance
54#[cfg_attr(
55    feature = "serde",
56    derive(Serialize, Deserialize),
57    serde(crate = "serde_crate")
58)]
59#[derive(Debug, Clone, PartialEq, Eq)]
60pub struct L2Dist;
61impl<F: Float> Distance<F> for L2Dist {
62    #[inline]
63    fn distance<D: Dimension>(&self, a: ArrayView<F, D>, b: ArrayView<F, D>) -> F {
64        F::from(a.l2_dist(&b).unwrap()).unwrap()
65    }
66
67    #[inline]
68    fn rdistance<D: Dimension>(&self, a: ArrayView<F, D>, b: ArrayView<F, D>) -> F {
69        F::from(a.sq_l2_dist(&b).unwrap()).unwrap()
70    }
71
72    #[inline]
73    fn rdist_to_dist(&self, rdist: F) -> F {
74        rdist.sqrt()
75    }
76
77    #[inline]
78    fn dist_to_rdist(&self, dist: F) -> F {
79        dist.powi(2)
80    }
81}
82
83/// L-infinte or [Chebyshev](https://en.wikipedia.org/wiki/Chebyshev_distance) distance
84#[cfg_attr(
85    feature = "serde",
86    derive(Serialize, Deserialize),
87    serde(crate = "serde_crate")
88)]
89#[derive(Debug, Clone, PartialEq, Eq)]
90pub struct LInfDist;
91impl<F: Float> Distance<F> for LInfDist {
92    #[inline]
93    fn distance<D: Dimension>(&self, a: ArrayView<F, D>, b: ArrayView<F, D>) -> F {
94        a.linf_dist(&b).unwrap()
95    }
96}
97
98/// L-p or [Minkowsky](https://en.wikipedia.org/wiki/Minkowski_distance) distance
99#[cfg_attr(
100    feature = "serde",
101    derive(Serialize, Deserialize),
102    serde(crate = "serde_crate")
103)]
104#[derive(Debug, Clone, PartialEq)]
105pub struct LpDist<F: Float>(pub F);
106impl<F: Float> LpDist<F> {
107    pub fn new(p: F) -> Self {
108        LpDist(p)
109    }
110}
111impl<F: Float> Distance<F> for LpDist<F> {
112    #[inline]
113    fn distance<D: Dimension>(&self, a: ArrayView<F, D>, b: ArrayView<F, D>) -> F {
114        Zip::from(&a)
115            .and(&b)
116            .fold(F::zero(), |acc, &a, &b| acc + (a - b).abs().powf(self.0))
117            .powf(F::one() / self.0)
118    }
119}
120
121/// Computes a similarity matrix with gaussian kernel and scaling parameter `eps`
122///
123/// The generated matrix is a upper triangular matrix with dimension NxN (number of observations) and contains the similarity between all permutations of observations
124/// similarity
125pub fn to_gaussian_similarity<F: Float>(
126    observations: &ArrayBase<impl Data<Elem = F>, Ix2>,
127    eps: F,
128    dist_fn: &impl Distance<F>,
129) -> Array2<F> {
130    let n_observations = observations.len_of(Axis(0));
131    let mut similarity = Array2::eye(n_observations);
132
133    for i in 0..n_observations {
134        for j in 0..n_observations {
135            let a = observations.row(i);
136            let b = observations.row(j);
137
138            let distance = dist_fn.distance(a, b);
139            similarity[(i, j)] = (-distance / eps).exp();
140        }
141    }
142
143    similarity
144}
145
#[cfg(test)]
mod test {
    use approx::assert_abs_diff_eq;
    use ndarray::{arr1, arr2};

    use super::*;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<L1Dist>();
        has_autotraits::<L2Dist>();
        has_autotraits::<LInfDist>();
        has_autotraits::<LpDist<f64>>();
    }

    /// Shared checks for a distance: the expected value on a fixed pair, `rdist`/`dist`
    /// round-tripping, infinite inputs and the triangle inequality.
    fn dist_test<D: Distance<f64>>(dist: D, result: f64) {
        let a = arr1(&[0.5, 6.6]);
        let b = arr1(&[4.4, 3.0]);
        let ab = dist.distance(a.view(), b.view());
        assert_abs_diff_eq!(ab, result, epsilon = 1e-3);
        assert_abs_diff_eq!(dist.rdist_to_dist(dist.dist_to_rdist(ab)), ab);

        let a = arr1(&[f64::INFINITY, 6.6]);
        let b = arr1(&[4.4, f64::NEG_INFINITY]);
        assert!(dist.distance(a.view(), b.view()).is_infinite());

        // Triangle inequality: d(a, c) <= d(a, b) + d(b, c) (equality is allowed).
        let a = arr1(&[0.5, 6.6]);
        let b = arr1(&[4.4, 3.0]);
        let c = arr1(&[-4.5, 3.3]);
        let ab = dist.distance(a.view(), b.view());
        let bc = dist.distance(b.view(), c.view());
        let ac = dist.distance(a.view(), c.view());
        assert!(ac <= ab + bc);
    }

    #[test]
    fn l1_dist() {
        dist_test(L1Dist, 7.5);
    }

    #[test]
    fn l2_dist() {
        dist_test(L2Dist, 5.3075);

        // Check squared distance
        let a = arr1(&[0.5, 6.6]);
        let b = arr1(&[4.4, 3.0]);
        assert_abs_diff_eq!(L2Dist.rdistance(a.view(), b.view()), 28.17, epsilon = 1e-3);
    }

    #[test]
    fn linf_dist() {
        dist_test(LInfDist, 3.9);
    }

    #[test]
    fn lp_dist() {
        dist_test(LpDist(3.3), 4.635);
    }

    #[test]
    fn gaussian_similarity() {
        let obs = arr2(&[[0.0, 0.0], [3.0, 4.0], [1.0, 1.0]]);
        let sim = to_gaussian_similarity(&obs, 2.0, &L2Dist);

        assert_eq!(sim.dim(), (3, 3));
        for i in 0..3 {
            assert_abs_diff_eq!(sim[(i, i)], 1.0);
            for j in 0..3 {
                assert_abs_diff_eq!(sim[(i, j)], sim[(j, i)]);
            }
        }
        // ||(0, 0) - (3, 4)|| = 5, so the entry is exp(-5 / 2).
        assert_abs_diff_eq!(sim[(0, 1)], (-2.5f64).exp(), epsilon = 1e-12);
    }
}