linfa_kernel/
inner.rs

1use linfa::Float;
2use ndarray::prelude::*;
3use ndarray::Data;
4use sprs::{CsMat, CsMatView};
5use std::ops::Mul;
6
7/// Specifies the methods an inner matrix of a kernel must
8/// be able to provide
9pub trait Inner {
10    type Elem: Float;
11
12    fn dot(&self, rhs: &ArrayView2<Self::Elem>) -> Array2<Self::Elem>;
13    fn sum(&self) -> Array1<Self::Elem>;
14    fn size(&self) -> usize;
15    fn column(&self, i: usize) -> Vec<Self::Elem>;
16    fn to_upper_triangle(&self) -> Vec<Self::Elem>;
17    fn is_dense(&self) -> bool;
18    fn diagonal(&self) -> Array1<Self::Elem>;
19}
20
21/// Allows a kernel to have either a dense or a sparse inner
22/// matrix in a way that is transparent to the user
23#[derive(Debug, Clone, PartialEq)]
24pub enum KernelInner<K1: Inner, K2: Inner> {
25    Dense(K1),
26    Sparse(K2),
27}
28
29impl<F: Float, D: Data<Elem = F>> Inner for ArrayBase<D, Ix2> {
30    type Elem = F;
31
32    fn dot(&self, rhs: &ArrayView2<F>) -> Array2<F> {
33        self.dot(rhs)
34    }
35    fn sum(&self) -> Array1<F> {
36        self.sum_axis(Axis(1))
37    }
38    fn size(&self) -> usize {
39        self.ncols()
40    }
41    fn column(&self, i: usize) -> Vec<F> {
42        self.column(i).to_vec()
43    }
44    fn to_upper_triangle(&self) -> Vec<F> {
45        self.indexed_iter()
46            .filter(|((row, col), _)| col > row)
47            .map(|(_, val)| *val)
48            .collect()
49    }
50
51    fn diagonal(&self) -> Array1<F> {
52        self.diag().to_owned()
53    }
54
55    fn is_dense(&self) -> bool {
56        true
57    }
58}
59
60impl<F: Float> Inner for CsMat<F> {
61    type Elem = F;
62
63    fn dot(&self, rhs: &ArrayView2<F>) -> Array2<F> {
64        self.mul(rhs)
65    }
66    fn sum(&self) -> Array1<F> {
67        let mut sum = Array1::zeros(self.cols());
68        for (val, i) in self.iter() {
69            let (_, col) = i;
70            sum[col] += *val;
71        }
72
73        sum
74    }
75    fn size(&self) -> usize {
76        self.cols()
77    }
78    fn column(&self, i: usize) -> Vec<F> {
79        (0..self.size())
80            .map(|j| *self.get(j, i).unwrap_or(&F::neg_zero()))
81            .collect::<Vec<_>>()
82    }
83    fn to_upper_triangle(&self) -> Vec<F> {
84        let mat = self.to_dense();
85        mat.indexed_iter()
86            .filter(|((row, col), _)| col > row)
87            .map(|(_, val)| *val)
88            .collect()
89    }
90
91    fn diagonal(&self) -> Array1<F> {
92        let diag_sprs = self.diag();
93        let mut diag = Array1::zeros(diag_sprs.dim());
94        for (sparse_i, sparse_elem) in diag_sprs.iter() {
95            diag[sparse_i] = *sparse_elem;
96        }
97        diag
98    }
99
100    fn is_dense(&self) -> bool {
101        false
102    }
103}
104
105impl<F: Float> Inner for CsMatView<'_, F> {
106    type Elem = F;
107
108    fn dot(&self, rhs: &ArrayView2<F>) -> Array2<F> {
109        self.mul(rhs)
110    }
111    fn sum(&self) -> Array1<F> {
112        let mut sum = Array1::zeros(self.cols());
113        for (val, i) in self.iter() {
114            let (_, col) = i;
115            sum[col] += *val;
116        }
117
118        sum
119    }
120    fn size(&self) -> usize {
121        self.cols()
122    }
123    fn column(&self, i: usize) -> Vec<F> {
124        (0..self.size())
125            .map(|j| *self.get(j, i).unwrap_or(&F::neg_zero()))
126            .collect::<Vec<_>>()
127    }
128    fn to_upper_triangle(&self) -> Vec<F> {
129        let mat = self.to_dense();
130        mat.indexed_iter()
131            .filter(|((row, col), _)| col > row)
132            .map(|(_, val)| *val)
133            .collect()
134    }
135    fn diagonal(&self) -> Array1<F> {
136        let diag_sprs = self.diag();
137        let mut diag = Array1::zeros(diag_sprs.dim());
138        for (sparse_i, sparse_elem) in diag_sprs.iter() {
139            diag[sparse_i] = *sparse_elem;
140        }
141        diag
142    }
143    fn is_dense(&self) -> bool {
144        false
145    }
146}