1use linfa::Float;
2use ndarray::prelude::*;
3use ndarray::Data;
4use sprs::{CsMat, CsMatView};
5use std::ops::Mul;
6
7pub trait Inner {
10 type Elem: Float;
11
12 fn dot(&self, rhs: &ArrayView2<Self::Elem>) -> Array2<Self::Elem>;
13 fn sum(&self) -> Array1<Self::Elem>;
14 fn size(&self) -> usize;
15 fn column(&self, i: usize) -> Vec<Self::Elem>;
16 fn to_upper_triangle(&self) -> Vec<Self::Elem>;
17 fn is_dense(&self) -> bool;
18 fn diagonal(&self) -> Array1<Self::Elem>;
19}
20
21#[derive(Debug, Clone, PartialEq)]
24pub enum KernelInner<K1: Inner, K2: Inner> {
25 Dense(K1),
26 Sparse(K2),
27}
28
29impl<F: Float, D: Data<Elem = F>> Inner for ArrayBase<D, Ix2> {
30 type Elem = F;
31
32 fn dot(&self, rhs: &ArrayView2<F>) -> Array2<F> {
33 self.dot(rhs)
34 }
35 fn sum(&self) -> Array1<F> {
36 self.sum_axis(Axis(1))
37 }
38 fn size(&self) -> usize {
39 self.ncols()
40 }
41 fn column(&self, i: usize) -> Vec<F> {
42 self.column(i).to_vec()
43 }
44 fn to_upper_triangle(&self) -> Vec<F> {
45 self.indexed_iter()
46 .filter(|((row, col), _)| col > row)
47 .map(|(_, val)| *val)
48 .collect()
49 }
50
51 fn diagonal(&self) -> Array1<F> {
52 self.diag().to_owned()
53 }
54
55 fn is_dense(&self) -> bool {
56 true
57 }
58}
59
60impl<F: Float> Inner for CsMat<F> {
61 type Elem = F;
62
63 fn dot(&self, rhs: &ArrayView2<F>) -> Array2<F> {
64 self.mul(rhs)
65 }
66 fn sum(&self) -> Array1<F> {
67 let mut sum = Array1::zeros(self.cols());
68 for (val, i) in self.iter() {
69 let (_, col) = i;
70 sum[col] += *val;
71 }
72
73 sum
74 }
75 fn size(&self) -> usize {
76 self.cols()
77 }
78 fn column(&self, i: usize) -> Vec<F> {
79 (0..self.size())
80 .map(|j| *self.get(j, i).unwrap_or(&F::neg_zero()))
81 .collect::<Vec<_>>()
82 }
83 fn to_upper_triangle(&self) -> Vec<F> {
84 let mat = self.to_dense();
85 mat.indexed_iter()
86 .filter(|((row, col), _)| col > row)
87 .map(|(_, val)| *val)
88 .collect()
89 }
90
91 fn diagonal(&self) -> Array1<F> {
92 let diag_sprs = self.diag();
93 let mut diag = Array1::zeros(diag_sprs.dim());
94 for (sparse_i, sparse_elem) in diag_sprs.iter() {
95 diag[sparse_i] = *sparse_elem;
96 }
97 diag
98 }
99
100 fn is_dense(&self) -> bool {
101 false
102 }
103}
104
105impl<F: Float> Inner for CsMatView<'_, F> {
106 type Elem = F;
107
108 fn dot(&self, rhs: &ArrayView2<F>) -> Array2<F> {
109 self.mul(rhs)
110 }
111 fn sum(&self) -> Array1<F> {
112 let mut sum = Array1::zeros(self.cols());
113 for (val, i) in self.iter() {
114 let (_, col) = i;
115 sum[col] += *val;
116 }
117
118 sum
119 }
120 fn size(&self) -> usize {
121 self.cols()
122 }
123 fn column(&self, i: usize) -> Vec<F> {
124 (0..self.size())
125 .map(|j| *self.get(j, i).unwrap_or(&F::neg_zero()))
126 .collect::<Vec<_>>()
127 }
128 fn to_upper_triangle(&self) -> Vec<F> {
129 let mat = self.to_dense();
130 mat.indexed_iter()
131 .filter(|((row, col), _)| col > row)
132 .map(|(_, val)| *val)
133 .collect()
134 }
135 fn diagonal(&self) -> Array1<F> {
136 let diag_sprs = self.diag();
137 let mut diag = Array1::zeros(diag_sprs.dim());
138 for (sparse_i, sparse_elem) in diag_sprs.iter() {
139 diag[sparse_i] = *sparse_elem;
140 }
141 diag
142 }
143 fn is_dense(&self) -> bool {
144 false
145 }
146}