1use linfa::Float;
2use linfa_nn::{distance::L2Dist, NearestNeighbour};
3use ndarray::{ArrayBase, Axis, Data, Ix2};
4use sprs::{CsMat, CsMatBase};
5
6pub fn adjacency_matrix<F: Float, DT: Data<Elem = F>, N: NearestNeighbour>(
8 dataset: &ArrayBase<DT, Ix2>,
9 k: usize,
10 nn_algo: &N,
11) -> CsMat<F> {
12 let n_points = dataset.len_of(Axis(0));
13
14 assert!(k < n_points);
17 assert!(k > 0);
18
19 let nn = nn_algo
20 .from_batch(dataset, L2Dist)
21 .expect("Unexpected nearest neighbour error");
22
23 let mut data = Vec::with_capacity(n_points * (k + 1));
28 let mut indptr = Vec::with_capacity(n_points + 1);
29 let mut indices = Vec::with_capacity(n_points * (k + 1));
31 indptr.push(0);
32
33 let mut added = 0;
35 for (m, feature) in dataset.rows().into_iter().enumerate() {
36 let mut neighbours = nn.k_nearest(feature, k + 1).unwrap();
37
38 neighbours.sort_unstable_by_key(|(_, i)| *i);
42
43 indices.push(m);
44 data.push(F::one());
45 added += 1;
46
47 for &(_, i) in &neighbours {
49 if m != i {
50 indices.push(i);
51 data.push(F::one());
52 added += 1;
53 }
54 }
55
56 indptr.push(added);
57 }
58
59 let mat = CsMatBase::new_from_unsorted((n_points, n_points), indptr, indices, data).unwrap();
61 let transpose = mat.transpose_view().to_other_storage();
62 let mut mat = sprs::binop::csmat_binop(mat.view(), transpose.view(), |x, y| x.add(*y));
63
64 let val: F = F::one();
66 mat.map_inplace(|_| val);
67
68 mat
69}
70
#[cfg(test)]
mod tests {
    use super::*;
    use linfa_nn::{BallTree, KdTree};
    use ndarray::Array2;
    use sprs::CsMat;

    /// Asserts that `adj_mat` (an `n x n` matrix) stores exactly `1.0` on the
    /// diagonal and at both `(a, b)` and `(b, a)` for every `(a, b)` in `edges`,
    /// and stores nothing anywhere else.
    ///
    /// Each edge must be listed once, with `a < b`.
    fn assert_pattern(adj_mat: &CsMat<f64>, n: usize, edges: &[(usize, usize)]) {
        assert_eq!(adj_mat.nnz(), n + 2 * edges.len());
        for i in 0..n {
            for j in 0..n {
                let connected = i == j || edges.contains(&(i.min(j), i.max(j)));
                let expected = if connected { Some(&1.) } else { None };
                assert_eq!(adj_mat.get(i, j), expected, "entry ({}, {})", i, j);
            }
        }
    }

    #[test]
    fn adjacency_matrix_test() {
        // Eight points on the diagonal, grouped in tight pairs:
        // 0, 0.1, 1, 1.1, 2, 2.1, 3, 3.1.
        let input_mat = vec![
            0., 0., 0.1, 0.1, 1., 1., 1.1, 1.1, 2., 2., 2.1, 2.1, 3., 3., 3.1, 3.1,
        ];
        let input_arr = Array2::from_shape_vec((8, 2), input_mat).unwrap();

        // k = 1: every point is linked only to its partner in the pair.
        let adj_mat = adjacency_matrix(&input_arr, 1, &KdTree);
        assert_pattern(&adj_mat, 8, &[(0, 1), (2, 3), (4, 5), (6, 7)]);

        // k = 2: consecutive points are linked. The end points 0 and 7 also
        // reach two steps inward, to 2 and 5.
        let adj_mat = adjacency_matrix(&input_arr, 2, &KdTree);
        assert_pattern(
            &adj_mat,
            8,
            &[
                (0, 1),
                (0, 2),
                (1, 2),
                (2, 3),
                (3, 4),
                (4, 5),
                (5, 6),
                (5, 7),
                (6, 7),
            ],
        );
    }

    #[test]
    fn adjacency_matrix_test_2() {
        // Same kind of layout, shuffled so that point i pairs with point 7 - i.
        let input_mat = vec![
            0., 0., 3.1, 3.1, 1., 1., 2.1, 1.1, 2., 2., 1.1, 1.1, 3., 3., 0.1, 0.1,
        ];
        let input_arr = Array2::from_shape_vec((8, 2), input_mat).unwrap();

        let adj_mat = adjacency_matrix(&input_arr, 1, &BallTree);
        assert_pattern(&adj_mat, 8, &[(0, 7), (1, 6), (2, 5), (3, 4)]);
    }

    #[test]
    #[should_panic]
    fn sparse_panics_on_0_neighbours() {
        let input_mat = [
            [[0., 0.], [0.1, 0.1]],
            [[1., 1.], [1.1, 1.1]],
            [[2., 2.], [2.1, 2.1]],
            [[3., 3.], [3.1, 3.1]],
        ]
        .concat()
        .concat();
        let input_arr = Array2::from_shape_vec((8, 2), input_mat).unwrap();
        let _ = adjacency_matrix(&input_arr, 0, &KdTree);
    }

    #[test]
    #[should_panic]
    fn sparse_panics_on_n_neighbours() {
        let input_mat = [
            [[0., 0.], [0.1, 0.1]],
            [[1., 1.], [1.1, 1.1]],
            [[2., 2.], [2.1, 2.1]],
            [[3., 3.], [3.1, 3.1]],
        ]
        .concat()
        .concat();
        let input_arr = Array2::from_shape_vec((8, 2), input_mat).unwrap();
        let _ = adjacency_matrix(&input_arr, 8, &BallTree);
    }
}
205}