linfa_kernel/
lib.rs

//! ## Kernel methods
//!
//! Kernel methods are a class of algorithms for pattern analysis, whose best-known member is the
//! [support vector machine](https://en.wikipedia.org/wiki/Support_vector_machine). They owe their name to the kernel functions,
//! which map the features to some higher-dimensional target space. Common examples of kernel
//! functions are the radial basis function (based on the Euclidean distance) and polynomial kernels.
//!
//! ## Current State
//!
//! linfa-kernel currently provides implementations of linear, RBF (Gaussian) and polynomial kernels,
//! with either a sparse or a dense representation. Furthermore, a k-nearest-neighbour approximation
//! can be used to reduce the size of the kernel matrix.
//!
//! Low-rank kernel approximations are currently missing, but are on the roadmap. Examples of these are the
//! [Nyström approximation](https://www.jmlr.org/papers/volume6/drineas05a/drineas05a.pdf) or [Quasi Random Fourier Features](http://www-personal.umich.edu/~aniketde/processed_md/Stats608_Aniketde.pdf).
16
17pub mod inner;
18mod sparse;
19
20pub use inner::{Inner, KernelInner};
21use linfa_nn::CommonNearestNeighbour;
22use linfa_nn::NearestNeighbour;
23use ndarray::prelude::*;
24use ndarray::Data;
25#[cfg(feature = "serde")]
26use serde_crate::{Deserialize, Serialize};
27use sprs::{CsMat, CsMatView};
28use std::ops::Mul;
29
30use linfa::{
31    dataset::AsTargets, dataset::DatasetBase, dataset::FromTargetArray, dataset::Records,
32    traits::Transformer, Float,
33};
34
/// Kernel representation, which can be either dense or sparse.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum KernelType {
    /// Every entry of the square kernel matrix is computed and stored.
    Dense,
    /// Only the entries between records that are neighbours in a k-nearest-neighbour
    /// adjacency matrix are computed and stored.
    ///
    /// The number of neighbours `k` must be between 1 and the total number of
    /// samples in the input minus one.
    Sparse(usize),
}
43
/// A generic kernel: a square kernel matrix together with the kernel method
/// (inner product) that was used to compute it.
///
/// `K1` is the storage used by the dense representation and `K2` the storage used by
/// the sparse representation; see the `Kernel` (owning) and `KernelView` (borrowing)
/// aliases for the concrete choices.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct KernelBase<K1: Inner, K2: Inner>
where
    K1::Elem: Float,
    K2::Elem: Float,
{
    /// The kernel matrix, stored either densely (`K1`) or sparsely (`K2`)
    #[cfg_attr(
        feature = "serde",
        serde(bound(
            serialize = "KernelInner<K1, K2>: Serialize",
            deserialize = "KernelInner<K1, K2>: Deserialize<'de>"
        ))
    )]
    pub inner: KernelInner<K1, K2>,
    #[cfg_attr(
        feature = "serde",
        serde(bound(
            serialize = "KernelMethod<K1::Elem>: Serialize",
            deserialize = "KernelMethod<K1::Elem>: Deserialize<'de>"
        ))
    )]
    /// The kernel method (inner product) used to compute the entries of the kernel matrix
    pub method: KernelMethod<K1::Elem>,
}
74
/// A kernel that owns its inner matrix: a dense `Array2` or a sparse `CsMat`.
pub type Kernel<F> = KernelBase<Array2<F>, CsMat<F>>;
/// A kernel that borrows its inner matrix (a dense `ArrayView2` or a sparse
/// `CsMatView`), e.g. the one returned by `Kernel::view`.
pub type KernelView<'a, F> = KernelBase<ArrayView2<'a, F>, CsMatView<'a, F>>;
79
80impl<F: Float, K1: Inner<Elem = F>, K2: Inner<Elem = F>> KernelBase<K1, K2> {
81    /// Whether the kernel is a linear kernel
82    ///
83    /// ## Returns
84    ///
85    /// - `true`: if the kernel is linear
86    /// - `false`: otherwise
87    pub fn is_linear(&self) -> bool {
88        self.method.is_linear()
89    }
90
91    /// Generates the default set of parameters for building a kernel.
92    /// Use this to initialize a set of parameters to be customized using `KernelParams`'s methods
93    pub fn params() -> KernelParams<F, CommonNearestNeighbour> {
94        Self::params_with_nn(CommonNearestNeighbour::KdTree)
95    }
96
97    /// Generate parameters with a specific nearest neighbour algorithm
98    pub fn params_with_nn<N: NearestNeighbour>(nn_algo: N) -> KernelParams<F, N> {
99        KernelParams {
100            kind: KernelType::Dense,
101            method: KernelMethod::Gaussian(F::cast(0.5)),
102            nn_algo,
103        }
104    }
105
106    /// Performs the matrix product between the kernel matrix
107    /// and the input
108    ///
109    /// ## Parameters
110    ///
111    /// - `rhs`: The matrix on the right-hand side of the multiplication
112    ///
113    /// ## Returns
114    ///
115    /// A new matrix containing the matrix product between the kernel
116    /// and `rhs`
117    ///
118    /// ## Panics
119    ///
120    /// If the shapes of kernel and `rhs` are not compatible for multiplication
121    pub fn dot(&self, rhs: &ArrayView2<F>) -> Array2<F> {
122        match &self.inner {
123            KernelInner::Dense(inn) => inn.dot(rhs),
124            KernelInner::Sparse(inn) => inn.dot(rhs),
125        }
126    }
127
128    /// Sums all elements in the same row of the kernel matrix
129    ///
130    /// ## Returns
131    ///
132    /// A new array with the sum of all the elements in each row
133    pub fn sum(&self) -> Array1<F> {
134        match &self.inner {
135            KernelInner::Dense(inn) => inn.sum(),
136            KernelInner::Sparse(inn) => inn.sum(),
137        }
138    }
139
140    /// Gives the size of the side of the square kernel matrix
141    pub fn size(&self) -> usize {
142        match &self.inner {
143            KernelInner::Dense(inn) => inn.size(),
144            KernelInner::Sparse(inn) => inn.size(),
145        }
146    }
147
148    /// Getter for a column of the kernel matrix
149    ///
150    /// ## Params
151    ///
152    /// - `i`: the index of the column
153    ///
154    /// ## Returns
155    ///
156    /// The i-th column of the kernel matrix, stored as a `Vec`
157    ///
158    /// ## Panics
159    ///
160    /// If `i` is out of bounds
161    pub fn column(&self, i: usize) -> Vec<F> {
162        match &self.inner {
163            KernelInner::Dense(inn) => inn.column(i),
164            KernelInner::Sparse(inn) => inn.column(i),
165        }
166    }
167
168    /// Getter for the data in the upper triangle of the kernel
169    /// matrix
170    ///
171    /// ## Returns
172    ///
173    /// A copy of all elements in the upper triangle of the kernel
174    /// matrix, stored in a `Vec`
175    pub fn to_upper_triangle(&self) -> Vec<F> {
176        match &self.inner {
177            KernelInner::Dense(inn) => inn.to_upper_triangle(),
178            KernelInner::Sparse(inn) => inn.to_upper_triangle(),
179        }
180    }
181
182    /// Getter for the elements in the diagonal of the kernel matrix
183    ///
184    /// ## Returns
185    ///
186    /// A new array containing the copy of all elements in the diagonal fo
187    /// the kernel matrix
188    pub fn diagonal(&self) -> Array1<F> {
189        match &self.inner {
190            KernelInner::Dense(inn) => inn.diagonal(),
191            KernelInner::Sparse(inn) => inn.diagonal(),
192        }
193    }
194}
195
196impl<'a, F: Float> Kernel<F> {
197    pub fn new<N: NearestNeighbour>(
198        dataset: ArrayView2<'a, F>,
199        params: &KernelParams<F, N>,
200    ) -> Kernel<F> {
201        let inner = match params.kind {
202            KernelType::Dense => KernelInner::Dense(dense_from_fn(&dataset, &params.method)),
203            KernelType::Sparse(k) => {
204                KernelInner::Sparse(sparse_from_fn(&dataset, k, &params.method, &params.nn_algo))
205            }
206        };
207
208        Kernel {
209            inner,
210            method: params.method.clone(),
211        }
212    }
213
214    /// Gives a KernelView which has a view on the original kernel's inner matrix
215    pub fn view(&'a self) -> KernelView<'a, F> {
216        KernelView {
217            inner: match &self.inner {
218                KernelInner::Dense(inn) => KernelInner::Dense(inn.view()),
219                KernelInner::Sparse(inn) => KernelInner::Sparse(inn.view()),
220            },
221            method: self.method.clone(),
222        }
223    }
224}
225
226impl<F: Float> KernelView<'_, F> {
227    pub fn to_owned(&self) -> Kernel<F> {
228        Kernel {
229            inner: match &self.inner {
230                KernelInner::Dense(inn) => KernelInner::Dense(inn.to_owned()),
231                KernelInner::Sparse(inn) => KernelInner::Sparse(inn.to_owned()),
232            },
233            method: self.method.clone(),
234        }
235    }
236}
237
238impl<F: Float, K1: Inner<Elem = F>, K2: Inner<Elem = F>> Records for KernelBase<K1, K2> {
239    type Elem = F;
240
241    fn nsamples(&self) -> usize {
242        self.size()
243    }
244
245    fn nfeatures(&self) -> usize {
246        self.size()
247    }
248}
249
/// The inner product definition used by a kernel.
///
/// There are three methods available:
///
/// - Gaussian(eps): `d(x, x') = exp(-||x - x'||^2 / eps)`
/// - Linear: `d(x, x') = <x, x'>`
/// - Polynomial(constant, degree): `d(x, x') = (<x, x'> + constant)^(degree)`
///
/// Here `||x - x'||^2` is the squared Euclidean distance and `<x, x'>` the Euclidean
/// inner product.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub enum KernelMethod<F> {
    /// Gaussian(eps): `exp(-||x - x'||^2 / eps)`, where `||x - x'||^2` is the
    /// squared Euclidean distance between the two points
    Gaussian(F),
    /// Euclidean inner product: `<x, x'>`
    Linear,
    /// Polynomial(constant, degree): `(<x, x'> + constant)^(degree)`
    Polynomial(F, F),
}
271
272impl<F: Float> KernelMethod<F> {
273    pub fn distance(&self, a: ArrayView1<F>, b: ArrayView1<F>) -> F {
274        match *self {
275            KernelMethod::Gaussian(eps) => {
276                let distance = a
277                    .iter()
278                    .zip(b.iter())
279                    .map(|(x, y)| (*x - *y) * (*x - *y))
280                    .sum::<F>();
281
282                (-distance / eps).exp()
283            }
284            KernelMethod::Linear => a.mul(&b).sum(),
285            KernelMethod::Polynomial(c, d) => (a.mul(&b).sum() + c).powf(d),
286        }
287    }
288
289    pub fn is_linear(&self) -> bool {
290        matches!(*self, KernelMethod::Linear)
291    }
292}
293
/// Defines the set of parameters needed to build a kernel.
///
/// Instances are obtained through `KernelBase::params` or `KernelBase::params_with_nn`
/// and customized with the builder-style setters `kind`, `method` and `nn_algo`.
#[derive(Debug, Clone, PartialEq)]
pub struct KernelParams<F, N = CommonNearestNeighbour> {
    /// Whether to construct a dense or sparse kernel
    kind: KernelType,
    /// The inner product used by the kernel
    method: KernelMethod<F>,
    /// Nearest neighbour algorithm for calculating adjacency matrices
    /// (only used when `kind` is `KernelType::Sparse`)
    nn_algo: N,
}
304
305impl<F, N> KernelParams<F, N> {
306    /// Setter for `method`, the inner product used by the kernel
307    pub fn method(mut self, method: KernelMethod<F>) -> Self {
308        self.method = method;
309        self
310    }
311
312    /// Setter for `kind`, whether to construct a dense or sparse kernel
313    pub fn kind(mut self, kind: KernelType) -> Self {
314        self.kind = kind;
315        self
316    }
317
318    /// Setter for `nn_algo`, nearest neighbour algorithm for calculating adjacency matrices
319    pub fn nn_algo(mut self, nn_algo: N) -> Self {
320        self.nn_algo = nn_algo;
321        self
322    }
323}
324
325impl<F: Float, N: NearestNeighbour> Transformer<&Array2<F>, Kernel<F>> for KernelParams<F, N> {
326    /// Builds a kernel from a view of the input data.
327    ///
328    /// ## Parameters
329    ///
330    /// - `x`: view of a matrix of records (#records, #features)
331    ///
332    /// A kernel build from `x` according to the parameters on which
333    /// this method is called
334    ///
335    /// ## Panics
336    ///
337    /// If the kernel type is `Sparse` and the number of neighbors specified is
338    /// not between 1 and #records-1
339    fn transform(&self, x: &Array2<F>) -> Kernel<F> {
340        Kernel::new(x.view(), self)
341    }
342}
343
344impl<'a, F: Float, N: NearestNeighbour> Transformer<ArrayView2<'a, F>, Kernel<F>>
345    for KernelParams<F, N>
346{
347    /// Builds a kernel from a view of the input data.
348    ///
349    /// ## Parameters
350    ///
351    /// - `x`: view of a matrix of records (#records, #features)
352    ///
353    /// A kernel build from `x` according to the parameters on which
354    /// this method is called
355    ///
356    /// ## Panics
357    ///
358    /// If the kernel type is `Sparse` and the number of neighbors specified is
359    /// not between 1 and #records-1
360    fn transform(&self, x: ArrayView2<'a, F>) -> Kernel<F> {
361        Kernel::new(x, self)
362    }
363}
364
365impl<'a, F: Float, N: NearestNeighbour> Transformer<&ArrayView2<'a, F>, Kernel<F>>
366    for KernelParams<F, N>
367{
368    /// Builds a kernel from a view of the input data.
369    ///
370    /// ## Parameters
371    ///
372    /// - `x`: view of a matrix of records (#records, #features)
373    ///
374    /// A kernel build from `x` according to the parameters on which
375    /// this method is called
376    ///
377    /// ## Panics
378    ///
379    /// If the kernel type is `Sparse` and the number of neighbors specified is
380    /// not between 1 and #records-1
381    fn transform(&self, x: &ArrayView2<'a, F>) -> Kernel<F> {
382        Kernel::new(*x, self)
383    }
384}
385
386impl<F: Float, T: AsTargets, N: NearestNeighbour>
387    Transformer<DatasetBase<Array2<F>, T>, DatasetBase<Kernel<F>, T>> for KernelParams<F, N>
388{
389    /// Builds a new Dataset with the kernel as the records and the same targets as the input one.
390    ///
391    /// It takes ownership of the original dataset.
392    ///
393    /// ## Parameters
394    ///
395    /// - `x`: A dataset with a matrix of records (#records, #features) and any targets
396    ///
397    /// ## Returns
398    ///
399    /// A new dataset with:
400    ///  - records: a kernel build from `x.records()` according to the parameters on which
401    ///    this method is called
402    ///  - targets: same as `x.targets()`
403    ///
404    /// ## Panics
405    ///
406    /// If the kernel type is `Sparse` and the number of neighbors specified is
407    /// not between 1 and #records-1
408    fn transform(&self, x: DatasetBase<Array2<F>, T>) -> DatasetBase<Kernel<F>, T> {
409        let kernel = Kernel::new(x.records.view(), self);
410        DatasetBase::new(kernel, x.targets)
411    }
412}
413
414impl<'a, F: Float, L: 'a, T: AsTargets<Elem = L> + FromTargetArray<'a>, N: NearestNeighbour>
415    Transformer<&'a DatasetBase<Array2<F>, T>, DatasetBase<Kernel<F>, T::View>>
416    for KernelParams<F, N>
417{
418    /// Builds a new Dataset with the kernel as the records and the same targets as the input one.
419    ///
420    /// ## Parameters
421    ///
422    /// - `x`: A dataset with a matrix of records (#records, #features) and any targets
423    ///
424    /// ## Returns
425    ///
426    /// A new dataset with:
427    ///  - records: a kernel build from `x.records()` according to the parameters on which
428    ///    this method is called
429    ///  - targets: same as `x.targets()`
430    ///
431    /// ## Panics
432    ///
433    /// If the kernel type is `Sparse` and the number of neighbors specified is
434    /// not between 1 and #records-1
435    fn transform(&self, x: &'a DatasetBase<Array2<F>, T>) -> DatasetBase<Kernel<F>, T::View> {
436        let kernel = Kernel::new(x.records.view(), self);
437        DatasetBase::new(kernel, T::new_targets_view(x.as_targets()))
438    }
439}
440
441// lifetime 'b allows the kernel to borrow the underlying data
442// for a possibly shorter time than 'a, useful in fold_fit
443impl<
444        'a,
445        'b,
446        F: Float,
447        L: 'b,
448        T: AsTargets<Elem = L> + FromTargetArray<'b>,
449        N: NearestNeighbour,
450    > Transformer<&'b DatasetBase<ArrayView2<'a, F>, T>, DatasetBase<Kernel<F>, T::View>>
451    for KernelParams<F, N>
452{
453    /// Builds a new Dataset with the kernel as the records and the same targets as the input one.
454    ///
455    /// ## Parameters
456    ///
457    /// - `x`: A dataset with a matrix of records (##records, ##features) and any targets
458    ///
459    /// ## Returns
460    ///
461    /// A new dataset with:
462    ///  - records: a kernel build from `x.records()` according to the parameters on which
463    ///    this method is called
464    ///  - targets: a slice of `x.targets()`
465    ///
466    /// ## Panics
467    ///
468    /// If the kernel type is `Sparse` and the number of neighbors specified is
469    /// not between 1 and ##records-1
470    fn transform(
471        &self,
472        x: &'b DatasetBase<ArrayView2<'a, F>, T>,
473    ) -> DatasetBase<Kernel<F>, T::View> {
474        let kernel = Kernel::new(x.records.view(), self);
475
476        DatasetBase::new(kernel, T::new_targets_view(x.as_targets()))
477    }
478}
479
480fn dense_from_fn<F: Float, D: Data<Elem = F>>(
481    dataset: &ArrayBase<D, Ix2>,
482    method: &KernelMethod<F>,
483) -> Array2<F> {
484    let n_observations = dataset.len_of(Axis(0));
485    let mut similarity = Array2::eye(n_observations);
486
487    for i in 0..n_observations {
488        for j in 0..n_observations {
489            let a = dataset.row(i);
490            let b = dataset.row(j);
491
492            similarity[(i, j)] = method.distance(a, b);
493        }
494    }
495
496    similarity
497}
498
499fn sparse_from_fn<F: Float, D: Data<Elem = F>, N: NearestNeighbour>(
500    dataset: &ArrayBase<D, Ix2>,
501    k: usize,
502    method: &KernelMethod<F>,
503    nn_algo: &N,
504) -> CsMat<F> {
505    // compute adjacency matrix between points in the input dataset:
506    // one point for each row
507    let mut data = sparse::adjacency_matrix(dataset, k, nn_algo);
508
509    // iterate through each row of the adjacency matrix where each
510    // row is represented by a vec containing a pair (col_index, value)
511    // for each non-zero element in the row
512    for (i, mut vec) in data.outer_iterator_mut().enumerate() {
513        // If there is a non-zero element in row i at index j
514        // then it means that points i and j in the input matrix are
515        // k-neighbours and their distance is stored in position (i,j)
516        for (j, val) in vec.iter_mut() {
517            let a = dataset.row(i);
518            let b = dataset.row(j);
519
520            *val = method.distance(a, b);
521        }
522    }
523    data
524}
525
526#[cfg(test)]
527mod tests {
528    use super::*;
529    use linfa::Dataset;
530    use linfa_nn::{BallTree, KdTree};
531    use ndarray::{Array1, Array2};
532    use std::f64::consts;
533
534    #[test]
535    fn autotraits() {
536        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
537        has_autotraits::<KernelType>();
538        has_autotraits::<KernelBase<ArrayView2<f64>, ArrayView2<f64>>>();
539        has_autotraits::<KernelMethod<f64>>();
540        has_autotraits::<KernelParams<f64, f64>>();
541        has_autotraits::<KernelView<f64>>();
542        has_autotraits::<KernelInner<ArrayView2<f64>, ArrayView2<f64>>>();
543        has_autotraits::<Kernel<f64>>();
544    }
545
546    #[test]
547    fn sparse_from_fn_test() {
548        // pts 0 & 1    pts 2 & 3    pts 4 & 5     pts 6 & 7
549        // |0.| |0.1| _ |1.| |1.1| _ |2.| |2.1| _  |3.| |3.1|
550        // |0.| |0.1|   |1.| |1.1|   |2.| |2.1|    |3.| |3.1|
551        let input_mat = vec![
552            0., 0., 0.1, 0.1, 1., 1., 1.1, 1.1, 2., 2., 2.1, 2.1, 3., 3., 3.1, 3.1,
553        ];
554        let input_arr = Array2::from_shape_vec((8, 2), input_mat).unwrap();
555        let adj_mat = sparse_from_fn(&input_arr, 1, &KernelMethod::Linear, &KdTree);
556        assert_eq!(adj_mat.nnz(), 16);
557
558        // 2*0^2
559        assert_eq!(*adj_mat.get(0, 0).unwrap() as usize, 0);
560        // 2*0.1^2
561        assert_eq!((*adj_mat.get(1, 1).unwrap() * 100.) as usize, 2);
562        // 2*1^2
563        assert_eq!(*adj_mat.get(2, 2).unwrap() as usize, 2);
564        // 2*1.1^2
565        assert_eq!((*adj_mat.get(3, 3).unwrap() * 100.) as usize, 242);
566        // 2 * 2^2
567        assert_eq!(*adj_mat.get(4, 4).unwrap() as usize, 8);
568        // 2 * 2.1^2
569        assert_eq!((*adj_mat.get(5, 5).unwrap() * 100.) as usize, 882);
570        // 2 * 3^2
571        assert_eq!(*adj_mat.get(6, 6).unwrap() as usize, 18);
572        // 2 * 3.1^2
573        assert_eq!((*adj_mat.get(7, 7).unwrap() * 100.) as usize, 1922);
574
575        // 2*(0 * 0.1)
576        assert_eq!(*adj_mat.get(0, 1).unwrap() as usize, 0);
577        assert_eq!(*adj_mat.get(1, 0).unwrap() as usize, 0);
578
579        // 2*(1 * 1.1)
580        assert_eq!((*adj_mat.get(2, 3).unwrap() * 10.) as usize, 22);
581        assert_eq!((*adj_mat.get(3, 2).unwrap() * 10.) as usize, 22);
582
583        // 2*(2 * 2.1)
584        assert_eq!((*adj_mat.get(4, 5).unwrap() * 10.) as usize, 84);
585        assert_eq!((*adj_mat.get(5, 4).unwrap() * 10.) as usize, 84);
586
587        // 2*(3 * 3.1)
588        assert_eq!((*adj_mat.get(6, 7).unwrap() * 10.) as usize, 186);
589        assert_eq!((*adj_mat.get(7, 6).unwrap() * 10.) as usize, 186);
590    }
591
592    #[test]
593    fn dense_from_fn_test() {
594        // pts 0 & 1    pts 2 & 3    pts 4 & 5     pts 6 & 7
595        // |0.| |0.1| _ |1.| |1.1| _ |2.| |2.1| _  |3.| |3.1|
596        // |0.| |0.1|   |1.| |1.1|   |2.| |2.1|    |3.| |3.1|
597        let input_mat = vec![
598            0., 0., 0.1, 0.1, 1., 1., 1.1, 1.1, 2., 2., 2.1, 2.1, 3., 3., 3.1, 3.1,
599        ];
600        let input_arr = Array2::from_shape_vec((8, 2), input_mat).unwrap();
601        let method: KernelMethod<f64> = KernelMethod::Linear;
602
603        let similarity_matrix = dense_from_fn(&input_arr, &method);
604
605        for i in 0..8 {
606            for j in 0..8 {
607                assert!(
608                    (similarity_matrix.row(i)[j]
609                        - method.distance(input_arr.row(i), input_arr.row(j)))
610                    .abs()
611                        <= f64::EPSILON
612                );
613            }
614        }
615    }
616
617    #[test]
618    fn gaussian_test() {
619        let gauss_1 = KernelMethod::Gaussian(1.);
620
621        let p1 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
622        let p2 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
623        let distance = gauss_1.distance(p1.view(), p2.view());
624        let expected = 1.;
625
626        assert!(f64::abs(distance - expected) <= f64::EPSILON);
627
628        let p1 = Array1::from_shape_vec(2, vec![1., 1.]).unwrap();
629        let p2 = Array1::from_shape_vec(2, vec![5., 5.]).unwrap();
630        let distance = gauss_1.distance(p1.view(), p2.view());
631        let expected = (consts::E).powf(-32.);
632        // this fails with e^-31 or e^-33 so f64::EPSILON still holds
633        assert!(f64::abs(distance - expected) <= f64::EPSILON);
634
635        let gauss_01 = KernelMethod::Gaussian(0.1);
636
637        let p1 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
638        let p2 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
639        let distance = gauss_01.distance(p1.view(), p2.view());
640        let expected = 1.;
641
642        assert!(f64::abs(distance - expected) <= f64::EPSILON);
643
644        let p1 = Array1::from_shape_vec(2, vec![1., 1.]).unwrap();
645        let p2 = Array1::from_shape_vec(2, vec![2., 2.]).unwrap();
646        let distance = gauss_01.distance(p1.view(), p2.view());
647        let expected = (consts::E).powf(-20.);
648
649        assert!(f64::abs(distance - expected) <= f64::EPSILON);
650    }
651
652    #[test]
653    fn poly2_test() {
654        let pol_0 = KernelMethod::Polynomial(0., 2.);
655
656        let p1 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
657        let p2 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
658        let distance = pol_0.distance(p1.view(), p2.view());
659        let expected = 0.;
660
661        assert!(f64::abs(distance - expected) <= f64::EPSILON);
662
663        let p1 = Array1::from_shape_vec(2, vec![1., 1.]).unwrap();
664        let p2 = Array1::from_shape_vec(2, vec![5., 5.]).unwrap();
665        let distance = pol_0.distance(p1.view(), p2.view());
666        let expected = 100.;
667        assert!(f64::abs(distance - expected) <= f64::EPSILON);
668
669        let pol_2 = KernelMethod::Polynomial(2., 2.);
670
671        let p1 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
672        let p2 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
673        let distance = pol_2.distance(p1.view(), p2.view());
674        let expected = 4.;
675
676        assert!(f64::abs(distance - expected) <= f64::EPSILON);
677
678        let p1 = Array1::from_shape_vec(2, vec![1., 1.]).unwrap();
679        let p2 = Array1::from_shape_vec(2, vec![2., 2.]).unwrap();
680        let distance = pol_2.distance(p1.view(), p2.view());
681        let expected = 36.;
682
683        assert!(f64::abs(distance - expected) <= f64::EPSILON);
684    }
685
686    #[test]
687    fn test_kernel_dot() {
688        let input_vec: Vec<f64> = (0..100).map(|v| v as f64 * 0.1).collect();
689        let vec_to_multiply: Vec<f64> = (0..100).map(|v| v as f64 * 0.3).collect();
690        let input_arr = Array2::from_shape_vec((10, 10), input_vec).unwrap();
691        let to_multiply = Array2::from_shape_vec((10, 10), vec_to_multiply).unwrap();
692
693        // dense kernel dot
694        let mul_mat = dense_from_fn(&input_arr, &KernelMethod::Linear).dot(&to_multiply);
695        let kernel = KernelView::params()
696            .kind(KernelType::Dense)
697            .method(KernelMethod::Linear)
698            .transform(input_arr.view());
699        let mul_ker = kernel.dot(&to_multiply.view());
700        assert!(matrices_almost_equal(mul_mat.view(), mul_ker.view()));
701
702        // sparse kernel dot
703        let mul_mat =
704            sparse_from_fn(&input_arr, 3, &KernelMethod::Linear, &KdTree).mul(&to_multiply.view());
705        let kernel = KernelView::params()
706            .kind(KernelType::Sparse(3))
707            .method(KernelMethod::Linear)
708            .transform(input_arr.view());
709        let mul_ker = kernel.dot(&to_multiply.view());
710        assert!(matrices_almost_equal(mul_mat.view(), mul_ker.view()));
711    }
712
713    #[test]
714    fn test_kernel_upper_triangle() {
715        // symmetric vec, kernel matrix is a "cross" of ones
716        let input_vec: Vec<f64> = (0..50).map(|v| v as f64 * 0.1).collect();
717        let input_arr_1 = Array2::from_shape_vec((5, 10), input_vec.clone()).unwrap();
718        let mut input_arr_2 = Array2::from_shape_vec((5, 10), input_vec).unwrap();
719        input_arr_2.invert_axis(Axis(0));
720        let input_arr =
721            ndarray::concatenate(Axis(0), &[input_arr_1.view(), input_arr_2.view()]).unwrap();
722
723        for kind in [KernelType::Dense, KernelType::Sparse(1)] {
724            let kernel = KernelView::params()
725                .kind(kind)
726                // Such a value for eps brings to zero the inner product
727                // between any two points that are not equal
728                .method(KernelMethod::Gaussian(1e-5))
729                .transform(input_arr.view());
730            let mut kernel_upper_triang = kernel.to_upper_triangle();
731            assert_eq!(kernel_upper_triang.len(), 45);
732            //so that i can use pop()
733            kernel_upper_triang.reverse();
734            for i in 0..9 {
735                for j in (i + 1)..10 {
736                    if j == (9 - i) {
737                        assert_eq!(kernel_upper_triang.pop().unwrap() as usize, 1);
738                    } else {
739                        assert_eq!(kernel_upper_triang.pop().unwrap() as usize, 0);
740                    }
741                }
742            }
743            assert!(kernel_upper_triang.is_empty());
744        }
745    }
746
747    /* #[test]
748    fn test_kernel_weighted_sum() {
749        let input_vec: Vec<f64> = (0..100).map(|v| v as f64 * 0.1).collect();
750        let input_arr = Array2::from_shape_vec((10, 10), input_vec).unwrap();
751        let weights = [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.];
752        for kind in vec![KernelType::Dense, KernelType::Sparse(1)] {
753            let kernel = KernelView::params()
754                .kind(kind)
755                // Such a value for eps brings to zero the inner product
756                // between any two points that are not equal
757                .method(KernelMethod::Gaussian(1e-5))
758                .transform(input_arr.view());
759            for (sample, w) in input_arr.outer_iter().zip(&weights) {
760                // with that kernel, only the input samples have non
761                // zero inner product with the samples used to generate the matrix.
762                // In particular, they have inner product equal to one only for the
763                // column corresponding to themselves
764                //let w_sum = kernel.weighted_sum(&weights, sample);
765                //assert!(values_almost_equal(&w_sum, w));
766            }
767        }
768    }*/
769
770    #[test]
771    fn test_kernel_sum() {
772        let input_vec: Vec<f64> = (0..100).map(|v| v as f64 * 0.1).collect();
773        let input_arr = Array2::from_shape_vec((10, 10), input_vec).unwrap();
774
775        let method = KernelMethod::Linear;
776
777        // dense kernel sum
778        let cols_sum = dense_from_fn(&input_arr, &method).sum_axis(Axis(1));
779        let kernel = KernelView::params()
780            .kind(KernelType::Dense)
781            .method(method.clone())
782            .transform(input_arr.view());
783        let kers_sum = kernel.sum();
784        assert!(arrays_almost_equal(cols_sum.view(), kers_sum.view()));
785
786        // sparse kernel sum
787        let cols_sum = sparse_from_fn(&input_arr, 3, &method, &BallTree)
788            .to_dense()
789            .sum_axis(Axis(1));
790        let kernel = KernelView::params()
791            .kind(KernelType::Sparse(3))
792            .method(method)
793            .transform(input_arr.view());
794        let kers_sum = kernel.sum();
795        assert!(arrays_almost_equal(cols_sum.view(), kers_sum.view()));
796    }
797
798    #[test]
799    fn test_kernel_diag() {
800        let input_vec: Vec<f64> = (0..100).map(|v| v as f64 * 0.1).collect();
801        let input_arr = Array2::from_shape_vec((10, 10), input_vec).unwrap();
802
803        let method = KernelMethod::Linear;
804
805        // dense kernel diag
806        let input_diagonal = dense_from_fn(&input_arr, &method).diag().into_owned();
807        let kernel = KernelView::params()
808            .kind(KernelType::Dense)
809            .method(method.clone())
810            .transform(input_arr.view());
811        let kers_diagonal = kernel.diagonal();
812        assert!(arrays_almost_equal(
813            input_diagonal.view(),
814            kers_diagonal.view()
815        ));
816
817        // sparse kernel diag
818        let input_diagonal: Vec<_> = sparse_from_fn(&input_arr, 3, &method, &BallTree)
819            .outer_iterator()
820            .enumerate()
821            .map(|(i, row)| *row.get(i).unwrap())
822            .collect();
823        let input_diagonal = Array1::from_shape_vec(10, input_diagonal).unwrap();
824        let kernel = KernelView::params()
825            .kind(KernelType::Sparse(3))
826            .method(method)
827            .transform(input_arr.view());
828        let kers_diagonal = kernel.diagonal();
829        assert!(arrays_almost_equal(
830            input_diagonal.view(),
831            kers_diagonal.view()
832        ));
833    }
834
835    // inspired from scikit learn's tests
836    #[test]
837    fn test_kernel_transform_from_array2() {
838        let input_vec: Vec<f64> = (0..100).map(|v| v as f64 * 0.1).collect();
839        let input = Array2::from_shape_vec((50, 2), input_vec).unwrap();
840        // checks that the transform for Array2 builds the right kernel
841        // according to its input params.
842        check_kernel_from_array2_type(&input, KernelType::Dense);
843        check_kernel_from_array2_type(&input, KernelType::Sparse(3));
844        // checks that the transform for ArrayView2 builds the right kernel
845        // according to its input params.
846        check_kernel_from_array_view_2_type(input.view(), KernelType::Dense);
847        check_kernel_from_array_view_2_type(input.view(), KernelType::Sparse(3));
848    }
849
850    // inspired from scikit learn's tests
851    #[test]
852    fn test_kernel_transform_from_dataset() {
853        let input_vec: Vec<f64> = (0..100).map(|v| v as f64 * 0.1).collect();
854        let input_arr = Array2::from_shape_vec((50, 2), input_vec).unwrap();
855        let input = Dataset::from(input_arr);
856        // checks that the transform for dataset builds the right kernel
857        // according to its input params.
858        check_kernel_from_dataset_type(&input, KernelType::Dense);
859        check_kernel_from_dataset_type(&input, KernelType::Sparse(3));
860
861        // checks that the transform for dataset view builds the right kernel
862        // according to its input params.
863        check_kernel_from_dataset_view_type(&input.view(), KernelType::Dense);
864        check_kernel_from_dataset_view_type(&input.view(), KernelType::Sparse(3));
865    }
866
867    fn check_kernel_from_dataset_type<'a, L: 'a, T: AsTargets<Elem = L> + FromTargetArray<'a>>(
868        input: &'a DatasetBase<Array2<f64>, T>,
869        k_type: KernelType,
870    ) {
871        let methods = vec![
872            KernelMethod::Linear,
873            KernelMethod::Gaussian(0.1),
874            KernelMethod::Polynomial(1., 2.),
875        ];
876        for method in methods {
877            let kernel_ref = Kernel::new(
878                input.records().view(),
879                &Kernel::params_with_nn(KdTree)
880                    .method(method.clone())
881                    .kind(k_type.clone()),
882            );
883            let kernel_tr = Kernel::params()
884                .kind(k_type.clone())
885                .method(method.clone())
886                .transform(input);
887            match (&kernel_ref.inner, &kernel_tr.records().inner) {
888                (KernelInner::Dense(m1), KernelInner::Dense(m2)) => {
889                    assert!(kernels_almost_equal(m1, m2))
890                }
891                (KernelInner::Sparse(m1), KernelInner::Sparse(m2)) => {
892                    assert!(kernels_almost_equal(m1, m2))
893                }
894                _ => panic!("Kernel inners must match!"),
895            };
896        }
897    }
898
899    fn check_kernel_from_dataset_view_type<
900        'a,
901        L: 'a,
902        T: AsTargets<Elem = L> + FromTargetArray<'a>,
903    >(
904        input: &'a DatasetBase<ArrayView2<'a, f64>, T>,
905        k_type: KernelType,
906    ) {
907        let methods = vec![
908            KernelMethod::Linear,
909            KernelMethod::Gaussian(0.1),
910            KernelMethod::Polynomial(1., 2.),
911        ];
912        for method in methods {
913            let kernel_ref = Kernel::new(
914                *input.records(),
915                &Kernel::params_with_nn(KdTree)
916                    .method(method.clone())
917                    .kind(k_type.clone()),
918            );
919            let kernel_tr = Kernel::params()
920                .kind(k_type.clone())
921                .method(method.clone())
922                .transform(input);
923            match (&kernel_ref.inner, &kernel_tr.records().inner) {
924                (KernelInner::Dense(m1), KernelInner::Dense(m2)) => {
925                    assert!(kernels_almost_equal(m1, m2))
926                }
927                (KernelInner::Sparse(m1), KernelInner::Sparse(m2)) => {
928                    assert!(kernels_almost_equal(m1, m2))
929                }
930                _ => panic!("Kernel inners must match!"),
931            };
932        }
933    }
934
935    /// Test method for checking each KernelMethod can operate on `&Array2<f64>` using type and `view()`
936    fn check_kernel_from_array2_type(input: &Array2<f64>, k_type: KernelType) {
937        let methods = vec![
938            KernelMethod::Linear,
939            KernelMethod::Gaussian(0.1),
940            KernelMethod::Polynomial(1., 2.),
941        ];
942        for method in methods {
943            let kernel_ref = Kernel::new(
944                input.view(),
945                &Kernel::params_with_nn(KdTree)
946                    .method(method.clone())
947                    .kind(k_type.clone()),
948            );
949            let kernel_tr = Kernel::params()
950                .kind(k_type.clone())
951                .method(method.clone())
952                .transform(input.view());
953            match (&kernel_ref.inner, &kernel_tr.inner) {
954                (KernelInner::Dense(m1), KernelInner::Dense(m2)) => {
955                    assert!(kernels_almost_equal(m1, m2))
956                }
957                (KernelInner::Sparse(m1), KernelInner::Sparse(m2)) => {
958                    assert!(kernels_almost_equal(m1, m2))
959                }
960                _ => panic!("Kernel inners must match!"),
961            };
962        }
963    }
964
965    /// Test method for checking each KernelMethod can operate on `ArrayView2<f64>` type
966    fn check_kernel_from_array_view_2_type(input: ArrayView2<f64>, k_type: KernelType) {
967        let methods = vec![
968            KernelMethod::Linear,
969            KernelMethod::Gaussian(0.1),
970            KernelMethod::Polynomial(1., 2.),
971        ];
972        for method in methods {
973            let kernel_ref = Kernel::new(
974                input,
975                &Kernel::params_with_nn(KdTree)
976                    .method(method.clone())
977                    .kind(k_type.clone()),
978            );
979            let kernel_tr = Kernel::params()
980                .kind(k_type.clone())
981                .method(method.clone())
982                .transform(input);
983            match (&kernel_ref.inner, &kernel_tr.inner) {
984                (KernelInner::Dense(m1), KernelInner::Dense(m2)) => {
985                    assert!(kernels_almost_equal(m1, m2))
986                }
987                (KernelInner::Sparse(m1), KernelInner::Sparse(m2)) => {
988                    assert!(kernels_almost_equal(m1, m2))
989                }
990                _ => panic!("Kernel inners must match!"),
991            };
992        }
993    }
994
995    /// Determines if two matrices:`ArrayView2<f64>` are equivalent within f64::EPSILON
996    fn matrices_almost_equal(reference: ArrayView2<f64>, transformed: ArrayView2<f64>) -> bool {
997        for (ref_row, tr_row) in reference
998            .axis_iter(Axis(0))
999            .zip(transformed.axis_iter(Axis(0)))
1000        {
1001            if !arrays_almost_equal(ref_row, tr_row) {
1002                return false;
1003            }
1004        }
1005        true
1006    }
1007
1008    /// Determines if two arrays:`ArrayView1<64>` are equivalent within f64::EPSILON
1009    fn arrays_almost_equal(reference: ArrayView1<f64>, transformed: ArrayView1<f64>) -> bool {
1010        for (ref_item, tr_item) in reference.iter().zip(transformed.iter()) {
1011            if !values_almost_equal(ref_item, tr_item) {
1012                return false;
1013            }
1014        }
1015        true
1016    }
1017
1018    /// Determines if two kernels are equivalent for all matched elements are equivalent within f64::EPSILON
1019    fn kernels_almost_equal<K: Inner<Elem = f64>>(reference: &K, transformed: &K) -> bool {
1020        for i in 0..reference.size() {
1021            if !vecs_almost_equal(reference.column(i), transformed.column(i)) {
1022                return false;
1023            }
1024        }
1025        true
1026    }
1027
1028    /// Determines if all matched elements within a pair of vectors are equivalent within f64::EPSILON
1029    fn vecs_almost_equal(reference: Vec<f64>, transformed: Vec<f64>) -> bool {
1030        for (ref_item, tr_item) in reference.iter().zip(transformed.iter()) {
1031            if !values_almost_equal(ref_item, tr_item) {
1032                return false;
1033            }
1034        }
1035        true
1036    }
1037
1038    /// Determines if two values are equal within an absolute difference of f64::EPSILON
1039    fn values_almost_equal(v1: &f64, v2: &f64) -> bool {
1040        (v1 - v2).abs() <= f64::EPSILON
1041    }
1042}