linfa_kernel/
sparse.rs

1use linfa::Float;
2use linfa_nn::{distance::L2Dist, NearestNeighbour};
3use ndarray::{ArrayBase, Axis, Data, Ix2};
4use sprs::{CsMat, CsMatBase};
5
6/// Create sparse adjacency matrix from dense dataset
7pub fn adjacency_matrix<F: Float, DT: Data<Elem = F>, N: NearestNeighbour>(
8    dataset: &ArrayBase<DT, Ix2>,
9    k: usize,
10    nn_algo: &N,
11) -> CsMat<F> {
12    let n_points = dataset.len_of(Axis(0));
13
14    // ensure that the number of neighbours is at least one and less than the total number of
15    // points
16    assert!(k < n_points);
17    assert!(k > 0);
18
19    let nn = nn_algo
20        .from_batch(dataset, L2Dist)
21        .expect("Unexpected nearest neighbour error");
22
23    // allocate buffer to initialize the sparse matrix later on
24    //  * data: we have exact #points * k positive entries
25    //  * indptr: has structure [0,k,2k,...,#points*k]
26    //  * indices: filled with the nearest indices
27    let mut data = Vec::with_capacity(n_points * (k + 1));
28    let mut indptr = Vec::with_capacity(n_points + 1);
29    //let indptr = (0..n_points+1).map(|x| x * (k+1)).collect::<Vec<_>>();
30    let mut indices = Vec::with_capacity(n_points * (k + 1));
31    indptr.push(0);
32
33    // find neighbours for each data point
34    let mut added = 0;
35    for (m, feature) in dataset.rows().into_iter().enumerate() {
36        let mut neighbours = nn.k_nearest(feature, k + 1).unwrap();
37
38        //dbg!(&neighbours);
39
40        // sort by indices
41        neighbours.sort_unstable_by_key(|(_, i)| *i);
42
43        indices.push(m);
44        data.push(F::one());
45        added += 1;
46
47        // push each index into the indices array
48        for &(_, i) in &neighbours {
49            if m != i {
50                indices.push(i);
51                data.push(F::one());
52                added += 1;
53            }
54        }
55
56        indptr.push(added);
57    }
58
59    // create CSR matrix from data, indptr and indices
60    let mat = CsMatBase::new_from_unsorted((n_points, n_points), indptr, indices, data).unwrap();
61    let transpose = mat.transpose_view().to_other_storage();
62    let mut mat = sprs::binop::csmat_binop(mat.view(), transpose.view(), |x, y| x.add(*y));
63
64    // ensure that all values are one
65    let val: F = F::one();
66    mat.map_inplace(|_| val);
67
68    mat
69}
70
#[cfg(test)]
mod tests {
    use super::*;
    use linfa_nn::{BallTree, KdTree};
    use ndarray::Array2;

    /// Eight 2-D points on the diagonal, grouped in close pairs with consecutive indices:
    /// (0, 0) (0.1, 0.1) | (1, 1) (1.1, 1.1) | (2, 2) (2.1, 2.1) | (3, 3) (3.1, 3.1)
    fn paired_points() -> Array2<f64> {
        let input_mat = vec![
            0., 0., 0.1, 0.1, 1., 1., 1.1, 1.1, 2., 2., 2.1, 2.1, 3., 3., 3.1, 3.1,
        ];
        Array2::from_shape_vec((8, 2), input_mat).unwrap()
    }

    /// Checks every entry of the `n x n` matrix `adj_mat`. Positions where `is_adjacent(i, j)`
    /// holds must store exactly one; all other positions must be structural zeros (`None`).
    fn assert_adjacency(
        adj_mat: &CsMat<f64>,
        n: usize,
        is_adjacent: impl Fn(usize, usize) -> bool,
    ) {
        for i in 0..n {
            for j in 0..n {
                let expected = if is_adjacent(i, j) { Some(&1.) } else { None };
                assert_eq!(
                    adj_mat.get(i, j),
                    expected,
                    "unexpected entry at ({}, {})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn adjacency_matrix_test() {
        let input_arr = paired_points();

        // k = 1: each point's nearest neighbour is its partner in the same pair. The matrix
        // holds the 8 diagonal entries plus (0,1), (1,0), (2,3), (3,2), ... -> 16 non-zeros.
        let adj_mat = adjacency_matrix(&input_arr, 1, &KdTree);
        assert_eq!(adj_mat.nnz(), 16);
        assert_adjacency(&adj_mat, 8, |i, j| i / 2 == j / 2);

        // k = 2: each point also reaches the closest point of an adjacent pair.
        // - After symmetrisation, all consecutive points are connected: 8 diagonal + 14
        //   off-diagonal entries.
        // - The first and last points also reach two steps inward: (0,2), (2,0), (5,7), (7,5).
        // -> 26 non-zeros in total.
        let adj_mat = adjacency_matrix(&input_arr, 2, &KdTree);
        assert_eq!(adj_mat.nnz(), 26);
        assert_adjacency(&adj_mat, 8, |i, j| {
            let (lo, hi) = (i.min(j), i.max(j));
            hi - lo <= 1 || matches!((lo, hi), (0, 2) | (5, 7))
        });
    }

    #[test]
    fn adjacency_matrix_test_2() {
        // pts 0 & 1    pts 2 & 3    pts 4 & 5     pts 6 & 7
        // |0.| |3.1| _ |1.| |2.1| _ |2.| |1.1| _  |3.| |0.1|
        // |0.| |3.1|   |1.| |2.1|   |2.| |1.1|    |3.| |0.1|
        let input_mat = vec![
            0., 0., 3.1, 3.1, 1., 1., 2.1, 2.1, 2., 2., 1.1, 1.1, 3., 3., 0.1, 0.1,
        ];
        let input_arr = Array2::from_shape_vec((8, 2), input_mat).unwrap();

        // I expect non-zeros on the diagonal, and each point i to be the single nearest
        // neighbour of its mirror point 7 - i (and vice versa):
        // 0 <-> 7, 1 <-> 6, 2 <-> 5, 3 <-> 4
        let adj_mat = adjacency_matrix(&input_arr, 1, &BallTree);
        assert_eq!(adj_mat.nnz(), 16);
        assert_adjacency(&adj_mat, 8, |i, j| i == j || i + j == 7);
    }

    #[test]
    #[should_panic]
    fn sparse_panics_on_0_neighbours() {
        let _ = adjacency_matrix(&paired_points(), 0, &KdTree);
    }

    #[test]
    #[should_panic]
    fn sparse_panics_on_n_neighbours() {
        let _ = adjacency_matrix(&paired_points(), 8, &BallTree);
    }
}