1use linfa::Float;
2use ndarray::{Array2, ArrayBase, ArrayView, Axis, Data, Dimension, Ix2, Zip};
3use ndarray_stats::DeviationExt;
4
5#[cfg(feature = "serde")]
6use serde_crate::{Deserialize, Serialize};
7
8pub trait Distance<F: Float>: Clone + Send + Sync + Unpin {
10 fn distance<D: Dimension>(&self, a: ArrayView<F, D>, b: ArrayView<F, D>) -> F;
15
16 #[inline]
21 fn rdistance<D: Dimension>(&self, a: ArrayView<F, D>, b: ArrayView<F, D>) -> F {
22 self.distance(a, b)
23 }
24
25 #[inline]
27 fn rdist_to_dist(&self, rdist: F) -> F {
28 rdist
29 }
30
31 #[inline]
33 fn dist_to_rdist(&self, dist: F) -> F {
34 dist
35 }
36}
37
38#[cfg_attr(
40 feature = "serde",
41 derive(Serialize, Deserialize),
42 serde(crate = "serde_crate")
43)]
44#[derive(Debug, Clone, PartialEq, Eq)]
45pub struct L1Dist;
46impl<F: Float> Distance<F> for L1Dist {
47 #[inline]
48 fn distance<D: Dimension>(&self, a: ArrayView<F, D>, b: ArrayView<F, D>) -> F {
49 a.l1_dist(&b).unwrap()
50 }
51}
52
53#[cfg_attr(
55 feature = "serde",
56 derive(Serialize, Deserialize),
57 serde(crate = "serde_crate")
58)]
59#[derive(Debug, Clone, PartialEq, Eq)]
60pub struct L2Dist;
61impl<F: Float> Distance<F> for L2Dist {
62 #[inline]
63 fn distance<D: Dimension>(&self, a: ArrayView<F, D>, b: ArrayView<F, D>) -> F {
64 F::from(a.l2_dist(&b).unwrap()).unwrap()
65 }
66
67 #[inline]
68 fn rdistance<D: Dimension>(&self, a: ArrayView<F, D>, b: ArrayView<F, D>) -> F {
69 F::from(a.sq_l2_dist(&b).unwrap()).unwrap()
70 }
71
72 #[inline]
73 fn rdist_to_dist(&self, rdist: F) -> F {
74 rdist.sqrt()
75 }
76
77 #[inline]
78 fn dist_to_rdist(&self, dist: F) -> F {
79 dist.powi(2)
80 }
81}
82
83#[cfg_attr(
85 feature = "serde",
86 derive(Serialize, Deserialize),
87 serde(crate = "serde_crate")
88)]
89#[derive(Debug, Clone, PartialEq, Eq)]
90pub struct LInfDist;
91impl<F: Float> Distance<F> for LInfDist {
92 #[inline]
93 fn distance<D: Dimension>(&self, a: ArrayView<F, D>, b: ArrayView<F, D>) -> F {
94 a.linf_dist(&b).unwrap()
95 }
96}
97
98#[cfg_attr(
100 feature = "serde",
101 derive(Serialize, Deserialize),
102 serde(crate = "serde_crate")
103)]
104#[derive(Debug, Clone, PartialEq)]
105pub struct LpDist<F: Float>(pub F);
106impl<F: Float> LpDist<F> {
107 pub fn new(p: F) -> Self {
108 LpDist(p)
109 }
110}
111impl<F: Float> Distance<F> for LpDist<F> {
112 #[inline]
113 fn distance<D: Dimension>(&self, a: ArrayView<F, D>, b: ArrayView<F, D>) -> F {
114 Zip::from(&a)
115 .and(&b)
116 .fold(F::zero(), |acc, &a, &b| acc + (a - b).abs().powf(self.0))
117 .powf(F::one() / self.0)
118 }
119}
120
121pub fn to_gaussian_similarity<F: Float>(
126 observations: &ArrayBase<impl Data<Elem = F>, Ix2>,
127 eps: F,
128 dist_fn: &impl Distance<F>,
129) -> Array2<F> {
130 let n_observations = observations.len_of(Axis(0));
131 let mut similarity = Array2::eye(n_observations);
132
133 for i in 0..n_observations {
134 for j in 0..n_observations {
135 let a = observations.row(i);
136 let b = observations.row(j);
137
138 let distance = dist_fn.distance(a, b);
139 similarity[(i, j)] = (-distance / eps).exp();
140 }
141 }
142
143 similarity
144}
145
#[cfg(test)]
mod test {
    use approx::assert_abs_diff_eq;
    use ndarray::arr1;

    use super::*;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<L1Dist>();
        has_autotraits::<L2Dist>();
        has_autotraits::<LInfDist>();
        has_autotraits::<LpDist<f64>>();
    }

    /// Checks shared by every metric: the value on a known pair of points, the
    /// reduced-distance round trip, infinite coordinates producing an infinite distance,
    /// and the triangle inequality on three points.
    fn dist_test<D: Distance<f64>>(dist: D, result: f64) {
        let p = arr1(&[0.5, 6.6]);
        let q = arr1(&[4.4, 3.0]);
        let r = arr1(&[-4.5, 3.3]);

        // Known value, then distance -> reduced distance -> distance.
        let pq = dist.distance(p.view(), q.view());
        assert_abs_diff_eq!(pq, result, epsilon = 1e-3);
        let round_trip = dist.rdist_to_dist(dist.dist_to_rdist(pq));
        assert_abs_diff_eq!(round_trip, pq);

        // Infinite coordinates must give an infinite distance.
        let pos_inf = arr1(&[f64::INFINITY, 6.6]);
        let neg_inf = arr1(&[4.4, f64::NEG_INFINITY]);
        assert!(dist.distance(pos_inf.view(), neg_inf.view()).is_infinite());

        // Triangle inequality.
        let qr = dist.distance(q.view(), r.view());
        let pr = dist.distance(p.view(), r.view());
        assert!(pq + qr > pr)
    }

    #[test]
    fn l1_dist() {
        dist_test(L1Dist, 7.5);
    }

    #[test]
    fn l2_dist() {
        dist_test(L2Dist, 5.3075);

        // The reduced L2 distance is the squared distance.
        let p = arr1(&[0.5, 6.6]);
        let q = arr1(&[4.4, 3.0]);
        let squared = L2Dist.rdistance(p.view(), q.view());
        assert_abs_diff_eq!(squared, 28.17, epsilon = 1e-3);
    }

    #[test]
    fn linf_dist() {
        dist_test(LInfDist, 3.9);
    }

    #[test]
    fn lp_dist() {
        dist_test(LpDist(3.3), 4.635);
    }
}