1use crate::Float;
2use linfa_kernel::Kernel;
3use ndarray::Array1;
4
5pub trait Permutable<F: Float> {
6 fn swap_indices(&mut self, i: usize, j: usize);
7 fn distances(&self, idx: usize, length: usize) -> Vec<F>;
8 fn self_distance(&self, idx: usize) -> F;
9 fn inner(&self) -> &Kernel<F>;
10 fn into_inner(self) -> Kernel<F>;
11}
12
13#[derive(Debug, Clone, PartialEq)]
18pub struct PermutableKernel<F: Float> {
19 kernel: Kernel<F>,
20 kernel_diag: Array1<F>,
21 kernel_indices: Vec<usize>,
22 targets: Vec<bool>,
23}
24
25impl<F: Float> PermutableKernel<F> {
26 pub fn new(kernel: Kernel<F>, targets: Vec<bool>) -> PermutableKernel<F> {
27 let kernel_diag = kernel.diagonal();
28 let kernel_indices = (0..kernel.size()).collect::<Vec<_>>();
29
30 PermutableKernel {
31 kernel,
32 kernel_diag,
33 kernel_indices,
34 targets,
35 }
36 }
37}
38
39impl<F: Float> Permutable<F> for PermutableKernel<F> {
40 fn swap_indices(&mut self, i: usize, j: usize) {
42 self.kernel_indices.swap(i, j);
43 }
44
45 fn distances(&self, idx: usize, length: usize) -> Vec<F> {
47 let idx = self.kernel_indices[idx];
48
49 let kernel = self.kernel.column(idx);
50 let target_i = self.targets[idx];
51
52 (0..length)
54 .map(|j| {
55 let val = kernel[self.kernel_indices[j]];
56 let target_j = self.targets[self.kernel_indices[j]];
57
58 if target_j != target_i {
59 -val
60 } else {
61 val
62 }
63 })
64 .collect()
65 }
66
67 fn inner(&self) -> &Kernel<F> {
69 &self.kernel
70 }
71
72 fn into_inner(self) -> Kernel<F> {
74 self.kernel
75 }
76
77 fn self_distance(&self, idx: usize) -> F {
79 let idx = self.kernel_indices[idx];
80
81 self.kernel_diag[idx]
82 }
83}
84
85#[derive(Debug, Clone, PartialEq)]
86pub struct PermutableKernelOneClass<F: Float> {
87 kernel: Kernel<F>,
88 kernel_diag: Array1<F>,
89 kernel_indices: Vec<usize>,
90}
91
92impl<F: Float> PermutableKernelOneClass<F> {
93 pub fn new(kernel: Kernel<F>) -> PermutableKernelOneClass<F> {
94 let kernel_diag = kernel.diagonal();
95 let kernel_indices = (0..kernel.size()).collect::<Vec<_>>();
96
97 PermutableKernelOneClass {
98 kernel,
99 kernel_diag,
100 kernel_indices,
101 }
102 }
103}
104
105impl<F: Float> Permutable<F> for PermutableKernelOneClass<F> {
106 fn swap_indices(&mut self, i: usize, j: usize) {
108 self.kernel_indices.swap(i, j);
109 }
110
111 fn distances(&self, idx: usize, length: usize) -> Vec<F> {
113 let idx = self.kernel_indices[idx];
114
115 let kernel = self.kernel.column(idx);
116
117 (0..length)
119 .map(|j| kernel[self.kernel_indices[j]])
120 .collect()
121 }
122
123 fn inner(&self) -> &Kernel<F> {
125 &self.kernel
126 }
127
128 fn into_inner(self) -> Kernel<F> {
130 self.kernel
131 }
132
133 fn self_distance(&self, idx: usize) -> F {
135 let idx = self.kernel_indices[idx];
136
137 self.kernel_diag[idx]
138 }
139}
140
141#[derive(Debug, Clone, PartialEq)]
142pub struct PermutableKernelRegression<F: Float> {
143 kernel: Kernel<F>,
144 kernel_diag: Array1<F>,
145 kernel_indices: Vec<usize>,
146 signs: Vec<bool>,
147}
148
149impl<F: Float> PermutableKernelRegression<F> {
150 pub fn new(kernel: Kernel<F>) -> PermutableKernelRegression<F> {
151 let kernel_diag = kernel.diagonal();
152 let kernel_indices = (0..2 * kernel.size())
153 .map(|x| {
154 if x < kernel.size() {
155 x
156 } else {
157 x - kernel.size()
158 }
159 })
160 .collect::<Vec<_>>();
161 let signs = (0..kernel.size() * 2)
162 .map(|x| x < kernel.size())
163 .collect::<Vec<_>>();
164
165 PermutableKernelRegression {
166 kernel,
167 kernel_diag,
168 kernel_indices,
169 signs,
170 }
171 }
172}
173
174impl<F: Float> Permutable<F> for PermutableKernelRegression<F> {
175 fn swap_indices(&mut self, i: usize, j: usize) {
177 self.kernel_indices.swap(i, j);
178 self.signs.swap(i, j);
179 }
180
181 fn distances(&self, idx: usize, length: usize) -> Vec<F> {
183 let kernel = self.kernel.column(self.kernel_indices[idx]);
184
185 let sign_i = self.signs[idx];
187 (0..length)
188 .map(|j| {
189 let val = kernel[self.kernel_indices[j]];
190 let sign_j = self.signs[j];
191
192 if sign_i != sign_j {
193 -val
194 } else {
195 val
196 }
197 })
198 .collect()
199 }
200
201 fn inner(&self) -> &Kernel<F> {
203 &self.kernel
204 }
205
206 fn into_inner(self) -> Kernel<F> {
208 self.kernel
209 }
210
211 fn self_distance(&self, idx: usize) -> F {
213 let idx = self.kernel_indices[idx];
214
215 self.kernel_diag[idx]
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::{
        Permutable, PermutableKernel, PermutableKernelOneClass, PermutableKernelRegression,
    };
    use approx::assert_abs_diff_eq;
    use linfa_kernel::{Kernel, KernelInner, KernelMethod};
    use ndarray::{array, Array2};

    /// Wraps a dense (symmetric) matrix in a linear kernel so that every entry
    /// returned by the permutable wrappers is fixed by the test itself.
    fn dense_kernel(matrix: Array2<f64>) -> Kernel<f64> {
        Kernel {
            inner: KernelInner::Dense(matrix),
            method: KernelMethod::Linear,
        }
    }

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<PermutableKernel<f64>>();
        has_autotraits::<PermutableKernelOneClass<f64>>();
        has_autotraits::<PermutableKernelRegression<f64>>();
    }

    #[test]
    fn test_permutable_kernel() {
        let dist = array![[1.0, 0.3, 0.1], [0.3, 1.0, 0.5], [0.1, 0.5, 1.0]];
        let targets = vec![true, true, true];
        let dist = Kernel {
            inner: KernelInner::Dense(dist),
            method: KernelMethod::Linear,
        };

        let mut kernel = PermutableKernel::new(dist, targets);

        assert_abs_diff_eq!(*kernel.distances(0, 3), [1.0, 0.3, 0.1]);
        assert_abs_diff_eq!(*kernel.distances(1, 3), [0.3, 1.0, 0.5]);
        assert_abs_diff_eq!(*kernel.distances(2, 3), [0.1, 0.5, 1.0]);

        kernel.swap_indices(0, 1);

        assert_abs_diff_eq!(*kernel.distances(0, 3), [1.0, 0.3, 0.5]);
        assert_abs_diff_eq!(*kernel.distances(1, 3), [0.3, 1.0, 0.1]);
        assert_abs_diff_eq!(*kernel.distances(2, 3), [0.5, 0.1, 1.0]);

        kernel.swap_indices(1, 2);

        assert_abs_diff_eq!(*kernel.distances(0, 3), [1.0, 0.5, 0.3]);
        assert_abs_diff_eq!(*kernel.distances(1, 3), [0.5, 1.0, 0.1]);
        assert_abs_diff_eq!(*kernel.distances(2, 3), [0.3, 0.1, 1.0]);
    }

    /// Mixed targets must negate exactly the entries whose targets differ,
    /// both before and after a permutation, including for a partial `length`.
    #[test]
    fn test_permutable_kernel_mixed_targets() {
        let dist = dense_kernel(array![[1.0, 0.3, 0.1], [0.3, 1.0, 0.5], [0.1, 0.5, 1.0]]);
        let mut kernel = PermutableKernel::new(dist, vec![true, false, true]);

        assert_abs_diff_eq!(*kernel.distances(0, 3), [1.0, -0.3, 0.1]);
        assert_abs_diff_eq!(*kernel.distances(1, 3), [-0.3, 1.0, -0.5]);

        kernel.swap_indices(0, 1);

        assert_abs_diff_eq!(*kernel.distances(0, 3), [1.0, -0.3, -0.5]);
        assert_abs_diff_eq!(*kernel.distances(2, 2), [-0.5, 0.1]);
    }

    /// The one-class kernel only reorders values; `self_distance` follows the permutation.
    #[test]
    fn test_permutable_kernel_one_class() {
        let dist = dense_kernel(array![[2.0, 0.3, 0.1], [0.3, 3.0, 0.5], [0.1, 0.5, 4.0]]);
        let mut kernel = PermutableKernelOneClass::new(dist);

        assert_abs_diff_eq!(*kernel.distances(0, 3), [2.0, 0.3, 0.1]);
        assert_abs_diff_eq!(kernel.self_distance(2), 4.0);

        kernel.swap_indices(0, 2);

        assert_abs_diff_eq!(*kernel.distances(0, 3), [4.0, 0.5, 0.1]);
        assert_abs_diff_eq!(*kernel.distances(1, 2), [0.5, 3.0]);
        assert_abs_diff_eq!(kernel.self_distance(0), 4.0);
        assert_abs_diff_eq!(kernel.self_distance(1), 3.0);
        assert_abs_diff_eq!(kernel.self_distance(2), 2.0);
    }

    /// The regression kernel doubles the index space; the second half carries the
    /// opposite tag, and tags must travel with their indices when swapped.
    #[test]
    fn test_permutable_kernel_regression() {
        let dist = dense_kernel(array![[2.0, 0.3], [0.3, 5.0]]);
        let mut kernel = PermutableKernelRegression::new(dist);

        assert_abs_diff_eq!(*kernel.distances(0, 4), [2.0, 0.3, -2.0, -0.3]);
        assert_abs_diff_eq!(*kernel.distances(3, 4), [-0.3, -5.0, 0.3, 5.0]);
        assert_abs_diff_eq!(kernel.self_distance(2), 2.0);
        assert_abs_diff_eq!(kernel.self_distance(3), 5.0);

        kernel.swap_indices(0, 3);

        assert_abs_diff_eq!(*kernel.distances(0, 4), [5.0, -5.0, 0.3, -0.3]);
        assert_abs_diff_eq!(*kernel.distances(3, 2), [-0.3, 0.3]);
        assert_abs_diff_eq!(kernel.self_distance(0), 5.0);
        assert_abs_diff_eq!(kernel.self_distance(3), 2.0);
    }
}