1pub mod inner;
18mod sparse;
19
20pub use inner::{Inner, KernelInner};
21use linfa_nn::CommonNearestNeighbour;
22use linfa_nn::NearestNeighbour;
23use ndarray::prelude::*;
24use ndarray::Data;
25#[cfg(feature = "serde")]
26use serde_crate::{Deserialize, Serialize};
27use sprs::{CsMat, CsMatView};
28use std::ops::Mul;
29
30use linfa::{
31 dataset::AsTargets, dataset::DatasetBase, dataset::FromTargetArray, dataset::Records,
32 traits::Transformer, Float,
33};
34
/// Storage layout selected for a kernel matrix.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum KernelType {
    /// The kernel function is evaluated for every pair of samples
    /// (see `dense_from_fn`).
    Dense,
    /// The kernel function is evaluated only for the entries of the adjacency
    /// matrix built by `sparse::adjacency_matrix`; the `usize` is forwarded to
    /// that function as `k` (presumably the number of neighbours per sample).
    Sparse(usize),
}
43
/// A kernel matrix together with the kernel method that produced it.
///
/// `K1` is the storage type held by `KernelInner::Dense` and `K2` the storage
/// type held by `KernelInner::Sparse`; both element types must implement
/// `Float`.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct KernelBase<K1: Inner, K2: Inner>
where
    K1::Elem: Float,
    K2::Elem: Float,
{
    /// Kernel matrix values, stored either densely or sparsely.
    #[cfg_attr(
        feature = "serde",
        serde(bound(
            serialize = "KernelInner<K1, K2>: Serialize",
            deserialize = "KernelInner<K1, K2>: Deserialize<'de>"
        ))
    )]
    pub inner: KernelInner<K1, K2>,
    /// Kernel function recorded alongside the matrix; `Kernel::new` copies it
    /// from the parameters it was given.
    #[cfg_attr(
        feature = "serde",
        serde(bound(
            serialize = "KernelMethod<K1::Elem>: Serialize",
            deserialize = "KernelMethod<K1::Elem>: Deserialize<'de>"
        ))
    )]
    pub method: KernelMethod<K1::Elem>,
}
74
/// Kernel that owns its data: `Array2` for the dense variant, `CsMat` for the
/// sparse variant.
pub type Kernel<F> = KernelBase<Array2<F>, CsMat<F>>;
/// Kernel that borrows its data: `ArrayView2` for the dense variant,
/// `CsMatView` for the sparse variant.
pub type KernelView<'a, F> = KernelBase<ArrayView2<'a, F>, CsMatView<'a, F>>;
79
80impl<F: Float, K1: Inner<Elem = F>, K2: Inner<Elem = F>> KernelBase<K1, K2> {
81 pub fn is_linear(&self) -> bool {
88 self.method.is_linear()
89 }
90
91 pub fn params() -> KernelParams<F, CommonNearestNeighbour> {
94 Self::params_with_nn(CommonNearestNeighbour::KdTree)
95 }
96
97 pub fn params_with_nn<N: NearestNeighbour>(nn_algo: N) -> KernelParams<F, N> {
99 KernelParams {
100 kind: KernelType::Dense,
101 method: KernelMethod::Gaussian(F::cast(0.5)),
102 nn_algo,
103 }
104 }
105
106 pub fn dot(&self, rhs: &ArrayView2<F>) -> Array2<F> {
122 match &self.inner {
123 KernelInner::Dense(inn) => inn.dot(rhs),
124 KernelInner::Sparse(inn) => inn.dot(rhs),
125 }
126 }
127
128 pub fn sum(&self) -> Array1<F> {
134 match &self.inner {
135 KernelInner::Dense(inn) => inn.sum(),
136 KernelInner::Sparse(inn) => inn.sum(),
137 }
138 }
139
140 pub fn size(&self) -> usize {
142 match &self.inner {
143 KernelInner::Dense(inn) => inn.size(),
144 KernelInner::Sparse(inn) => inn.size(),
145 }
146 }
147
148 pub fn column(&self, i: usize) -> Vec<F> {
162 match &self.inner {
163 KernelInner::Dense(inn) => inn.column(i),
164 KernelInner::Sparse(inn) => inn.column(i),
165 }
166 }
167
168 pub fn to_upper_triangle(&self) -> Vec<F> {
176 match &self.inner {
177 KernelInner::Dense(inn) => inn.to_upper_triangle(),
178 KernelInner::Sparse(inn) => inn.to_upper_triangle(),
179 }
180 }
181
182 pub fn diagonal(&self) -> Array1<F> {
189 match &self.inner {
190 KernelInner::Dense(inn) => inn.diagonal(),
191 KernelInner::Sparse(inn) => inn.diagonal(),
192 }
193 }
194}
195
196impl<'a, F: Float> Kernel<F> {
197 pub fn new<N: NearestNeighbour>(
198 dataset: ArrayView2<'a, F>,
199 params: &KernelParams<F, N>,
200 ) -> Kernel<F> {
201 let inner = match params.kind {
202 KernelType::Dense => KernelInner::Dense(dense_from_fn(&dataset, ¶ms.method)),
203 KernelType::Sparse(k) => {
204 KernelInner::Sparse(sparse_from_fn(&dataset, k, ¶ms.method, ¶ms.nn_algo))
205 }
206 };
207
208 Kernel {
209 inner,
210 method: params.method.clone(),
211 }
212 }
213
214 pub fn view(&'a self) -> KernelView<'a, F> {
216 KernelView {
217 inner: match &self.inner {
218 KernelInner::Dense(inn) => KernelInner::Dense(inn.view()),
219 KernelInner::Sparse(inn) => KernelInner::Sparse(inn.view()),
220 },
221 method: self.method.clone(),
222 }
223 }
224}
225
226impl<F: Float> KernelView<'_, F> {
227 pub fn to_owned(&self) -> Kernel<F> {
228 Kernel {
229 inner: match &self.inner {
230 KernelInner::Dense(inn) => KernelInner::Dense(inn.to_owned()),
231 KernelInner::Sparse(inn) => KernelInner::Sparse(inn.to_owned()),
232 },
233 method: self.method.clone(),
234 }
235 }
236}
237
238impl<F: Float, K1: Inner<Elem = F>, K2: Inner<Elem = F>> Records for KernelBase<K1, K2> {
239 type Elem = F;
240
241 fn nsamples(&self) -> usize {
242 self.size()
243 }
244
245 fn nfeatures(&self) -> usize {
246 self.size()
247 }
248}
249
/// Kernel functions that can be evaluated on a pair of samples.
///
/// The formulas below are the ones computed by `KernelMethod::distance`.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub enum KernelMethod<F> {
    /// `exp(-|a - b|^2 / eps)`, where the payload is `eps`.
    Gaussian(F),
    /// Sum of the element-wise products of `a` and `b`.
    Linear,
    /// `(<a, b> + c)^d`, where the payload is `(c, d)`.
    Polynomial(F, F),
}
271
272impl<F: Float> KernelMethod<F> {
273 pub fn distance(&self, a: ArrayView1<F>, b: ArrayView1<F>) -> F {
274 match *self {
275 KernelMethod::Gaussian(eps) => {
276 let distance = a
277 .iter()
278 .zip(b.iter())
279 .map(|(x, y)| (*x - *y) * (*x - *y))
280 .sum::<F>();
281
282 (-distance / eps).exp()
283 }
284 KernelMethod::Linear => a.mul(&b).sum(),
285 KernelMethod::Polynomial(c, d) => (a.mul(&b).sum() + c).powf(d),
286 }
287 }
288
289 pub fn is_linear(&self) -> bool {
290 matches!(*self, KernelMethod::Linear)
291 }
292}
293
/// Builder-style configuration consumed by `Kernel::new` and by the
/// `Transformer` implementations in this module.
#[derive(Debug, Clone, PartialEq)]
pub struct KernelParams<F, N = CommonNearestNeighbour> {
    /// Whether a dense or a sparse kernel matrix is built.
    kind: KernelType,
    /// Kernel function applied to pairs of samples.
    method: KernelMethod<F>,
    /// Nearest-neighbour algorithm passed to the sparse construction;
    /// `Kernel::new` does not read it for dense kernels.
    nn_algo: N,
}
304
305impl<F, N> KernelParams<F, N> {
306 pub fn method(mut self, method: KernelMethod<F>) -> Self {
308 self.method = method;
309 self
310 }
311
312 pub fn kind(mut self, kind: KernelType) -> Self {
314 self.kind = kind;
315 self
316 }
317
318 pub fn nn_algo(mut self, nn_algo: N) -> Self {
320 self.nn_algo = nn_algo;
321 self
322 }
323}
324
325impl<F: Float, N: NearestNeighbour> Transformer<&Array2<F>, Kernel<F>> for KernelParams<F, N> {
326 fn transform(&self, x: &Array2<F>) -> Kernel<F> {
340 Kernel::new(x.view(), self)
341 }
342}
343
344impl<'a, F: Float, N: NearestNeighbour> Transformer<ArrayView2<'a, F>, Kernel<F>>
345 for KernelParams<F, N>
346{
347 fn transform(&self, x: ArrayView2<'a, F>) -> Kernel<F> {
361 Kernel::new(x, self)
362 }
363}
364
365impl<'a, F: Float, N: NearestNeighbour> Transformer<&ArrayView2<'a, F>, Kernel<F>>
366 for KernelParams<F, N>
367{
368 fn transform(&self, x: &ArrayView2<'a, F>) -> Kernel<F> {
382 Kernel::new(*x, self)
383 }
384}
385
386impl<F: Float, T: AsTargets, N: NearestNeighbour>
387 Transformer<DatasetBase<Array2<F>, T>, DatasetBase<Kernel<F>, T>> for KernelParams<F, N>
388{
389 fn transform(&self, x: DatasetBase<Array2<F>, T>) -> DatasetBase<Kernel<F>, T> {
409 let kernel = Kernel::new(x.records.view(), self);
410 DatasetBase::new(kernel, x.targets)
411 }
412}
413
414impl<'a, F: Float, L: 'a, T: AsTargets<Elem = L> + FromTargetArray<'a>, N: NearestNeighbour>
415 Transformer<&'a DatasetBase<Array2<F>, T>, DatasetBase<Kernel<F>, T::View>>
416 for KernelParams<F, N>
417{
418 fn transform(&self, x: &'a DatasetBase<Array2<F>, T>) -> DatasetBase<Kernel<F>, T::View> {
436 let kernel = Kernel::new(x.records.view(), self);
437 DatasetBase::new(kernel, T::new_targets_view(x.as_targets()))
438 }
439}
440
441impl<
444 'a,
445 'b,
446 F: Float,
447 L: 'b,
448 T: AsTargets<Elem = L> + FromTargetArray<'b>,
449 N: NearestNeighbour,
450 > Transformer<&'b DatasetBase<ArrayView2<'a, F>, T>, DatasetBase<Kernel<F>, T::View>>
451 for KernelParams<F, N>
452{
453 fn transform(
471 &self,
472 x: &'b DatasetBase<ArrayView2<'a, F>, T>,
473 ) -> DatasetBase<Kernel<F>, T::View> {
474 let kernel = Kernel::new(x.records.view(), self);
475
476 DatasetBase::new(kernel, T::new_targets_view(x.as_targets()))
477 }
478}
479
480fn dense_from_fn<F: Float, D: Data<Elem = F>>(
481 dataset: &ArrayBase<D, Ix2>,
482 method: &KernelMethod<F>,
483) -> Array2<F> {
484 let n_observations = dataset.len_of(Axis(0));
485 let mut similarity = Array2::eye(n_observations);
486
487 for i in 0..n_observations {
488 for j in 0..n_observations {
489 let a = dataset.row(i);
490 let b = dataset.row(j);
491
492 similarity[(i, j)] = method.distance(a, b);
493 }
494 }
495
496 similarity
497}
498
499fn sparse_from_fn<F: Float, D: Data<Elem = F>, N: NearestNeighbour>(
500 dataset: &ArrayBase<D, Ix2>,
501 k: usize,
502 method: &KernelMethod<F>,
503 nn_algo: &N,
504) -> CsMat<F> {
505 let mut data = sparse::adjacency_matrix(dataset, k, nn_algo);
508
509 for (i, mut vec) in data.outer_iterator_mut().enumerate() {
513 for (j, val) in vec.iter_mut() {
517 let a = dataset.row(i);
518 let b = dataset.row(j);
519
520 *val = method.distance(a, b);
521 }
522 }
523 data
524}
525
526#[cfg(test)]
527mod tests {
528 use super::*;
529 use linfa::Dataset;
530 use linfa_nn::{BallTree, KdTree};
531 use ndarray::{Array1, Array2};
532 use std::f64::consts;
533
534 #[test]
535 fn autotraits() {
536 fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
537 has_autotraits::<KernelType>();
538 has_autotraits::<KernelBase<ArrayView2<f64>, ArrayView2<f64>>>();
539 has_autotraits::<KernelMethod<f64>>();
540 has_autotraits::<KernelParams<f64, f64>>();
541 has_autotraits::<KernelView<f64>>();
542 has_autotraits::<KernelInner<ArrayView2<f64>, ArrayView2<f64>>>();
543 has_autotraits::<Kernel<f64>>();
544 }
545
546 #[test]
547 fn sparse_from_fn_test() {
548 let input_mat = vec![
552 0., 0., 0.1, 0.1, 1., 1., 1.1, 1.1, 2., 2., 2.1, 2.1, 3., 3., 3.1, 3.1,
553 ];
554 let input_arr = Array2::from_shape_vec((8, 2), input_mat).unwrap();
555 let adj_mat = sparse_from_fn(&input_arr, 1, &KernelMethod::Linear, &KdTree);
556 assert_eq!(adj_mat.nnz(), 16);
557
558 assert_eq!(*adj_mat.get(0, 0).unwrap() as usize, 0);
560 assert_eq!((*adj_mat.get(1, 1).unwrap() * 100.) as usize, 2);
562 assert_eq!(*adj_mat.get(2, 2).unwrap() as usize, 2);
564 assert_eq!((*adj_mat.get(3, 3).unwrap() * 100.) as usize, 242);
566 assert_eq!(*adj_mat.get(4, 4).unwrap() as usize, 8);
568 assert_eq!((*adj_mat.get(5, 5).unwrap() * 100.) as usize, 882);
570 assert_eq!(*adj_mat.get(6, 6).unwrap() as usize, 18);
572 assert_eq!((*adj_mat.get(7, 7).unwrap() * 100.) as usize, 1922);
574
575 assert_eq!(*adj_mat.get(0, 1).unwrap() as usize, 0);
577 assert_eq!(*adj_mat.get(1, 0).unwrap() as usize, 0);
578
579 assert_eq!((*adj_mat.get(2, 3).unwrap() * 10.) as usize, 22);
581 assert_eq!((*adj_mat.get(3, 2).unwrap() * 10.) as usize, 22);
582
583 assert_eq!((*adj_mat.get(4, 5).unwrap() * 10.) as usize, 84);
585 assert_eq!((*adj_mat.get(5, 4).unwrap() * 10.) as usize, 84);
586
587 assert_eq!((*adj_mat.get(6, 7).unwrap() * 10.) as usize, 186);
589 assert_eq!((*adj_mat.get(7, 6).unwrap() * 10.) as usize, 186);
590 }
591
592 #[test]
593 fn dense_from_fn_test() {
594 let input_mat = vec![
598 0., 0., 0.1, 0.1, 1., 1., 1.1, 1.1, 2., 2., 2.1, 2.1, 3., 3., 3.1, 3.1,
599 ];
600 let input_arr = Array2::from_shape_vec((8, 2), input_mat).unwrap();
601 let method: KernelMethod<f64> = KernelMethod::Linear;
602
603 let similarity_matrix = dense_from_fn(&input_arr, &method);
604
605 for i in 0..8 {
606 for j in 0..8 {
607 assert!(
608 (similarity_matrix.row(i)[j]
609 - method.distance(input_arr.row(i), input_arr.row(j)))
610 .abs()
611 <= f64::EPSILON
612 );
613 }
614 }
615 }
616
617 #[test]
618 fn gaussian_test() {
619 let gauss_1 = KernelMethod::Gaussian(1.);
620
621 let p1 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
622 let p2 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
623 let distance = gauss_1.distance(p1.view(), p2.view());
624 let expected = 1.;
625
626 assert!(f64::abs(distance - expected) <= f64::EPSILON);
627
628 let p1 = Array1::from_shape_vec(2, vec![1., 1.]).unwrap();
629 let p2 = Array1::from_shape_vec(2, vec![5., 5.]).unwrap();
630 let distance = gauss_1.distance(p1.view(), p2.view());
631 let expected = (consts::E).powf(-32.);
632 assert!(f64::abs(distance - expected) <= f64::EPSILON);
634
635 let gauss_01 = KernelMethod::Gaussian(0.1);
636
637 let p1 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
638 let p2 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
639 let distance = gauss_01.distance(p1.view(), p2.view());
640 let expected = 1.;
641
642 assert!(f64::abs(distance - expected) <= f64::EPSILON);
643
644 let p1 = Array1::from_shape_vec(2, vec![1., 1.]).unwrap();
645 let p2 = Array1::from_shape_vec(2, vec![2., 2.]).unwrap();
646 let distance = gauss_01.distance(p1.view(), p2.view());
647 let expected = (consts::E).powf(-20.);
648
649 assert!(f64::abs(distance - expected) <= f64::EPSILON);
650 }
651
652 #[test]
653 fn poly2_test() {
654 let pol_0 = KernelMethod::Polynomial(0., 2.);
655
656 let p1 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
657 let p2 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
658 let distance = pol_0.distance(p1.view(), p2.view());
659 let expected = 0.;
660
661 assert!(f64::abs(distance - expected) <= f64::EPSILON);
662
663 let p1 = Array1::from_shape_vec(2, vec![1., 1.]).unwrap();
664 let p2 = Array1::from_shape_vec(2, vec![5., 5.]).unwrap();
665 let distance = pol_0.distance(p1.view(), p2.view());
666 let expected = 100.;
667 assert!(f64::abs(distance - expected) <= f64::EPSILON);
668
669 let pol_2 = KernelMethod::Polynomial(2., 2.);
670
671 let p1 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
672 let p2 = Array1::from_shape_vec(2, vec![0., 0.]).unwrap();
673 let distance = pol_2.distance(p1.view(), p2.view());
674 let expected = 4.;
675
676 assert!(f64::abs(distance - expected) <= f64::EPSILON);
677
678 let p1 = Array1::from_shape_vec(2, vec![1., 1.]).unwrap();
679 let p2 = Array1::from_shape_vec(2, vec![2., 2.]).unwrap();
680 let distance = pol_2.distance(p1.view(), p2.view());
681 let expected = 36.;
682
683 assert!(f64::abs(distance - expected) <= f64::EPSILON);
684 }
685
686 #[test]
687 fn test_kernel_dot() {
688 let input_vec: Vec<f64> = (0..100).map(|v| v as f64 * 0.1).collect();
689 let vec_to_multiply: Vec<f64> = (0..100).map(|v| v as f64 * 0.3).collect();
690 let input_arr = Array2::from_shape_vec((10, 10), input_vec).unwrap();
691 let to_multiply = Array2::from_shape_vec((10, 10), vec_to_multiply).unwrap();
692
693 let mul_mat = dense_from_fn(&input_arr, &KernelMethod::Linear).dot(&to_multiply);
695 let kernel = KernelView::params()
696 .kind(KernelType::Dense)
697 .method(KernelMethod::Linear)
698 .transform(input_arr.view());
699 let mul_ker = kernel.dot(&to_multiply.view());
700 assert!(matrices_almost_equal(mul_mat.view(), mul_ker.view()));
701
702 let mul_mat =
704 sparse_from_fn(&input_arr, 3, &KernelMethod::Linear, &KdTree).mul(&to_multiply.view());
705 let kernel = KernelView::params()
706 .kind(KernelType::Sparse(3))
707 .method(KernelMethod::Linear)
708 .transform(input_arr.view());
709 let mul_ker = kernel.dot(&to_multiply.view());
710 assert!(matrices_almost_equal(mul_mat.view(), mul_ker.view()));
711 }
712
713 #[test]
714 fn test_kernel_upper_triangle() {
715 let input_vec: Vec<f64> = (0..50).map(|v| v as f64 * 0.1).collect();
717 let input_arr_1 = Array2::from_shape_vec((5, 10), input_vec.clone()).unwrap();
718 let mut input_arr_2 = Array2::from_shape_vec((5, 10), input_vec).unwrap();
719 input_arr_2.invert_axis(Axis(0));
720 let input_arr =
721 ndarray::concatenate(Axis(0), &[input_arr_1.view(), input_arr_2.view()]).unwrap();
722
723 for kind in [KernelType::Dense, KernelType::Sparse(1)] {
724 let kernel = KernelView::params()
725 .kind(kind)
726 .method(KernelMethod::Gaussian(1e-5))
729 .transform(input_arr.view());
730 let mut kernel_upper_triang = kernel.to_upper_triangle();
731 assert_eq!(kernel_upper_triang.len(), 45);
732 kernel_upper_triang.reverse();
734 for i in 0..9 {
735 for j in (i + 1)..10 {
736 if j == (9 - i) {
737 assert_eq!(kernel_upper_triang.pop().unwrap() as usize, 1);
738 } else {
739 assert_eq!(kernel_upper_triang.pop().unwrap() as usize, 0);
740 }
741 }
742 }
743 assert!(kernel_upper_triang.is_empty());
744 }
745 }
746
747 #[test]
771 fn test_kernel_sum() {
772 let input_vec: Vec<f64> = (0..100).map(|v| v as f64 * 0.1).collect();
773 let input_arr = Array2::from_shape_vec((10, 10), input_vec).unwrap();
774
775 let method = KernelMethod::Linear;
776
777 let cols_sum = dense_from_fn(&input_arr, &method).sum_axis(Axis(1));
779 let kernel = KernelView::params()
780 .kind(KernelType::Dense)
781 .method(method.clone())
782 .transform(input_arr.view());
783 let kers_sum = kernel.sum();
784 assert!(arrays_almost_equal(cols_sum.view(), kers_sum.view()));
785
786 let cols_sum = sparse_from_fn(&input_arr, 3, &method, &BallTree)
788 .to_dense()
789 .sum_axis(Axis(1));
790 let kernel = KernelView::params()
791 .kind(KernelType::Sparse(3))
792 .method(method)
793 .transform(input_arr.view());
794 let kers_sum = kernel.sum();
795 assert!(arrays_almost_equal(cols_sum.view(), kers_sum.view()));
796 }
797
798 #[test]
799 fn test_kernel_diag() {
800 let input_vec: Vec<f64> = (0..100).map(|v| v as f64 * 0.1).collect();
801 let input_arr = Array2::from_shape_vec((10, 10), input_vec).unwrap();
802
803 let method = KernelMethod::Linear;
804
805 let input_diagonal = dense_from_fn(&input_arr, &method).diag().into_owned();
807 let kernel = KernelView::params()
808 .kind(KernelType::Dense)
809 .method(method.clone())
810 .transform(input_arr.view());
811 let kers_diagonal = kernel.diagonal();
812 assert!(arrays_almost_equal(
813 input_diagonal.view(),
814 kers_diagonal.view()
815 ));
816
817 let input_diagonal: Vec<_> = sparse_from_fn(&input_arr, 3, &method, &BallTree)
819 .outer_iterator()
820 .enumerate()
821 .map(|(i, row)| *row.get(i).unwrap())
822 .collect();
823 let input_diagonal = Array1::from_shape_vec(10, input_diagonal).unwrap();
824 let kernel = KernelView::params()
825 .kind(KernelType::Sparse(3))
826 .method(method)
827 .transform(input_arr.view());
828 let kers_diagonal = kernel.diagonal();
829 assert!(arrays_almost_equal(
830 input_diagonal.view(),
831 kers_diagonal.view()
832 ));
833 }
834
835 #[test]
837 fn test_kernel_transform_from_array2() {
838 let input_vec: Vec<f64> = (0..100).map(|v| v as f64 * 0.1).collect();
839 let input = Array2::from_shape_vec((50, 2), input_vec).unwrap();
840 check_kernel_from_array2_type(&input, KernelType::Dense);
843 check_kernel_from_array2_type(&input, KernelType::Sparse(3));
844 check_kernel_from_array_view_2_type(input.view(), KernelType::Dense);
847 check_kernel_from_array_view_2_type(input.view(), KernelType::Sparse(3));
848 }
849
850 #[test]
852 fn test_kernel_transform_from_dataset() {
853 let input_vec: Vec<f64> = (0..100).map(|v| v as f64 * 0.1).collect();
854 let input_arr = Array2::from_shape_vec((50, 2), input_vec).unwrap();
855 let input = Dataset::from(input_arr);
856 check_kernel_from_dataset_type(&input, KernelType::Dense);
859 check_kernel_from_dataset_type(&input, KernelType::Sparse(3));
860
861 check_kernel_from_dataset_view_type(&input.view(), KernelType::Dense);
864 check_kernel_from_dataset_view_type(&input.view(), KernelType::Sparse(3));
865 }
866
867 fn check_kernel_from_dataset_type<'a, L: 'a, T: AsTargets<Elem = L> + FromTargetArray<'a>>(
868 input: &'a DatasetBase<Array2<f64>, T>,
869 k_type: KernelType,
870 ) {
871 let methods = vec![
872 KernelMethod::Linear,
873 KernelMethod::Gaussian(0.1),
874 KernelMethod::Polynomial(1., 2.),
875 ];
876 for method in methods {
877 let kernel_ref = Kernel::new(
878 input.records().view(),
879 &Kernel::params_with_nn(KdTree)
880 .method(method.clone())
881 .kind(k_type.clone()),
882 );
883 let kernel_tr = Kernel::params()
884 .kind(k_type.clone())
885 .method(method.clone())
886 .transform(input);
887 match (&kernel_ref.inner, &kernel_tr.records().inner) {
888 (KernelInner::Dense(m1), KernelInner::Dense(m2)) => {
889 assert!(kernels_almost_equal(m1, m2))
890 }
891 (KernelInner::Sparse(m1), KernelInner::Sparse(m2)) => {
892 assert!(kernels_almost_equal(m1, m2))
893 }
894 _ => panic!("Kernel inners must match!"),
895 };
896 }
897 }
898
899 fn check_kernel_from_dataset_view_type<
900 'a,
901 L: 'a,
902 T: AsTargets<Elem = L> + FromTargetArray<'a>,
903 >(
904 input: &'a DatasetBase<ArrayView2<'a, f64>, T>,
905 k_type: KernelType,
906 ) {
907 let methods = vec![
908 KernelMethod::Linear,
909 KernelMethod::Gaussian(0.1),
910 KernelMethod::Polynomial(1., 2.),
911 ];
912 for method in methods {
913 let kernel_ref = Kernel::new(
914 *input.records(),
915 &Kernel::params_with_nn(KdTree)
916 .method(method.clone())
917 .kind(k_type.clone()),
918 );
919 let kernel_tr = Kernel::params()
920 .kind(k_type.clone())
921 .method(method.clone())
922 .transform(input);
923 match (&kernel_ref.inner, &kernel_tr.records().inner) {
924 (KernelInner::Dense(m1), KernelInner::Dense(m2)) => {
925 assert!(kernels_almost_equal(m1, m2))
926 }
927 (KernelInner::Sparse(m1), KernelInner::Sparse(m2)) => {
928 assert!(kernels_almost_equal(m1, m2))
929 }
930 _ => panic!("Kernel inners must match!"),
931 };
932 }
933 }
934
935 fn check_kernel_from_array2_type(input: &Array2<f64>, k_type: KernelType) {
937 let methods = vec![
938 KernelMethod::Linear,
939 KernelMethod::Gaussian(0.1),
940 KernelMethod::Polynomial(1., 2.),
941 ];
942 for method in methods {
943 let kernel_ref = Kernel::new(
944 input.view(),
945 &Kernel::params_with_nn(KdTree)
946 .method(method.clone())
947 .kind(k_type.clone()),
948 );
949 let kernel_tr = Kernel::params()
950 .kind(k_type.clone())
951 .method(method.clone())
952 .transform(input.view());
953 match (&kernel_ref.inner, &kernel_tr.inner) {
954 (KernelInner::Dense(m1), KernelInner::Dense(m2)) => {
955 assert!(kernels_almost_equal(m1, m2))
956 }
957 (KernelInner::Sparse(m1), KernelInner::Sparse(m2)) => {
958 assert!(kernels_almost_equal(m1, m2))
959 }
960 _ => panic!("Kernel inners must match!"),
961 };
962 }
963 }
964
965 fn check_kernel_from_array_view_2_type(input: ArrayView2<f64>, k_type: KernelType) {
967 let methods = vec![
968 KernelMethod::Linear,
969 KernelMethod::Gaussian(0.1),
970 KernelMethod::Polynomial(1., 2.),
971 ];
972 for method in methods {
973 let kernel_ref = Kernel::new(
974 input,
975 &Kernel::params_with_nn(KdTree)
976 .method(method.clone())
977 .kind(k_type.clone()),
978 );
979 let kernel_tr = Kernel::params()
980 .kind(k_type.clone())
981 .method(method.clone())
982 .transform(input);
983 match (&kernel_ref.inner, &kernel_tr.inner) {
984 (KernelInner::Dense(m1), KernelInner::Dense(m2)) => {
985 assert!(kernels_almost_equal(m1, m2))
986 }
987 (KernelInner::Sparse(m1), KernelInner::Sparse(m2)) => {
988 assert!(kernels_almost_equal(m1, m2))
989 }
990 _ => panic!("Kernel inners must match!"),
991 };
992 }
993 }
994
995 fn matrices_almost_equal(reference: ArrayView2<f64>, transformed: ArrayView2<f64>) -> bool {
997 for (ref_row, tr_row) in reference
998 .axis_iter(Axis(0))
999 .zip(transformed.axis_iter(Axis(0)))
1000 {
1001 if !arrays_almost_equal(ref_row, tr_row) {
1002 return false;
1003 }
1004 }
1005 true
1006 }
1007
1008 fn arrays_almost_equal(reference: ArrayView1<f64>, transformed: ArrayView1<f64>) -> bool {
1010 for (ref_item, tr_item) in reference.iter().zip(transformed.iter()) {
1011 if !values_almost_equal(ref_item, tr_item) {
1012 return false;
1013 }
1014 }
1015 true
1016 }
1017
1018 fn kernels_almost_equal<K: Inner<Elem = f64>>(reference: &K, transformed: &K) -> bool {
1020 for i in 0..reference.size() {
1021 if !vecs_almost_equal(reference.column(i), transformed.column(i)) {
1022 return false;
1023 }
1024 }
1025 true
1026 }
1027
1028 fn vecs_almost_equal(reference: Vec<f64>, transformed: Vec<f64>) -> bool {
1030 for (ref_item, tr_item) in reference.iter().zip(transformed.iter()) {
1031 if !values_almost_equal(ref_item, tr_item) {
1032 return false;
1033 }
1034 }
1035 true
1036 }
1037
1038 fn values_almost_equal(v1: &f64, v2: &f64) -> bool {
1040 (v1 - v2).abs() <= f64::EPSILON
1041 }
1042}