1use linfa::{
2 dataset::{WithLapack, WithoutLapack},
3 DatasetBase, Float,
4};
5#[cfg(not(feature = "blas"))]
6use linfa_linalg::svd::*;
7use ndarray::{s, Array1, Array2, ArrayBase, ArrayView2, Axis, Data, DataMut, Ix1, Ix2, Zip};
8#[cfg(feature = "blas")]
9use ndarray_linalg::svd::*;
10use ndarray_stats::QuantileExt;
11
12pub fn outer<F: Float>(
13 a: &ArrayBase<impl Data<Elem = F>, Ix1>,
14 b: &ArrayBase<impl Data<Elem = F>, Ix1>,
15) -> Array2<F> {
16 let mut outer = Array2::zeros((a.len(), b.len()));
17 Zip::from(outer.rows_mut()).and(a).for_each(|mut out, ai| {
18 out.assign(&b.mapv(|v| *ai * v));
19 });
20 outer
21}
22
23pub fn pinv2<F: Float>(x: ArrayView2<F>, cond: Option<F>) -> Array2<F> {
25 let x = x.with_lapack();
26 #[cfg(feature = "blas")]
27 let (opt_u, s, opt_vh) = x.svd(true, true).unwrap();
28 #[cfg(not(feature = "blas"))]
29 let (opt_u, s, opt_vh) = x.svd(true, true).unwrap().sort_svd_desc();
30 let u = opt_u.unwrap();
31 let vh = opt_vh.unwrap();
32
33 let cond = cond
34 .unwrap_or(F::cast(*s.max().unwrap()) * F::cast(x.nrows().max(x.ncols())) * F::epsilon());
35
36 let rank = s.fold(0, |mut acc, v| {
37 if F::cast(*v) > cond {
38 acc += 1
39 };
40 acc
41 });
42
43 let mut ucut = u.slice_move(s![.., ..rank]);
44 ucut /= &s.slice(s![..rank]).mapv(F::Lapack::cast);
45
46 vh.slice(s![..rank, ..]).t().dot(&ucut.t()).without_lapack()
47}
48
49#[allow(clippy::type_complexity)]
50pub fn center_scale_dataset<F: Float, D: Data<Elem = F>>(
51 dataset: &DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
52 scale: bool,
53) -> (
54 Array2<F>,
55 Array2<F>,
56 Array1<F>,
57 Array1<F>,
58 Array1<F>,
59 Array1<F>,
60) {
61 let (xnorm, x_mean, x_std) = center_scale(dataset.records(), scale);
62 let (ynorm, y_mean, y_std) = center_scale(dataset.targets(), scale);
63 (xnorm, ynorm, x_mean, y_mean, x_std, y_std)
64}
65
66fn center_scale<F: Float>(
67 x: &ArrayBase<impl Data<Elem = F>, Ix2>,
68 scale: bool,
69) -> (Array2<F>, Array1<F>, Array1<F>) {
70 let x_mean = x.mean_axis(Axis(0)).unwrap();
71 let (xnorm, x_std) = if scale {
72 let mut x_std = x.std_axis(Axis(0), F::one());
73 x_std.mapv_inplace(|v| if v == F::zero() { F::one() } else { v });
74 ((x - &x_mean) / &x_std, x_std)
75 } else {
76 ((x - &x_mean), Array1::ones(x.ncols()))
77 };
78
79 (xnorm, x_mean, x_std)
80}
81
82pub fn svd_flip_1d<F: Float>(
83 x_weights: &mut ArrayBase<impl DataMut<Elem = F>, Ix1>,
84 y_weights: &mut ArrayBase<impl DataMut<Elem = F>, Ix1>,
85) {
86 let biggest_abs_val_idx = x_weights.mapv(|v| v.abs()).argmax().unwrap();
87 let sign: F = x_weights[biggest_abs_val_idx].signum();
88 x_weights.map_inplace(|v| *v *= sign);
89 y_weights.map_inplace(|v| *v *= sign);
90}
91
92pub fn svd_flip<F: Float>(
93 u: ArrayBase<impl Data<Elem = F>, Ix2>,
94 v: ArrayBase<impl Data<Elem = F>, Ix2>,
95) -> (Array2<F>, Array2<F>) {
96 let abs_u = u.mapv(|v| v.abs());
98 let max_abs_val_indices = abs_u.map_axis(Axis(0), |col| col.argmax().unwrap());
99 let mut signs = Array1::<F>::zeros(u.ncols());
100 let range: Vec<usize> = (0..u.ncols()).collect();
101 Zip::from(&mut signs)
102 .and(&max_abs_val_indices)
103 .and(&range)
104 .for_each(|s, &i, &j| *s = u[[i, j]].signum());
105 (&u * &signs, &v * &signs.insert_axis(Axis(1)))
106}
107
108#[cfg(test)]
109mod tests {
110 use super::*;
111 use approx::assert_abs_diff_eq;
112 use ndarray::array;
113
114 #[test]
115 fn test_outer() {
116 let a = array![1., 2., 3.];
117 let b = array![2., 3.];
118 let expected = array![[2., 3.], [4., 6.], [6., 9.]];
119 assert_abs_diff_eq!(expected, outer(&a, &b));
120 }
121
122 #[test]
123 fn test_pinv2() {
124 let a = array![[1., 2., 3.], [4., 5., 6.], [7., 8., 10.]];
125 let a_pinv2 = pinv2(a.view(), None);
126 assert_abs_diff_eq!(a.dot(&a_pinv2), Array2::eye(3), epsilon = 1e-6)
127 }
128}