linfa_pls/
utils.rs

1use linfa::{
2    dataset::{WithLapack, WithoutLapack},
3    DatasetBase, Float,
4};
5#[cfg(not(feature = "blas"))]
6use linfa_linalg::svd::*;
7use ndarray::{s, Array1, Array2, ArrayBase, ArrayView2, Axis, Data, DataMut, Ix1, Ix2, Zip};
8#[cfg(feature = "blas")]
9use ndarray_linalg::svd::*;
10use ndarray_stats::QuantileExt;
11
12pub fn outer<F: Float>(
13    a: &ArrayBase<impl Data<Elem = F>, Ix1>,
14    b: &ArrayBase<impl Data<Elem = F>, Ix1>,
15) -> Array2<F> {
16    let mut outer = Array2::zeros((a.len(), b.len()));
17    Zip::from(outer.rows_mut()).and(a).for_each(|mut out, ai| {
18        out.assign(&b.mapv(|v| *ai * v));
19    });
20    outer
21}
22
23/// Calculates the pseudo inverse of a matrix
24pub fn pinv2<F: Float>(x: ArrayView2<F>, cond: Option<F>) -> Array2<F> {
25    let x = x.with_lapack();
26    #[cfg(feature = "blas")]
27    let (opt_u, s, opt_vh) = x.svd(true, true).unwrap();
28    #[cfg(not(feature = "blas"))]
29    let (opt_u, s, opt_vh) = x.svd(true, true).unwrap().sort_svd_desc();
30    let u = opt_u.unwrap();
31    let vh = opt_vh.unwrap();
32
33    let cond = cond
34        .unwrap_or(F::cast(*s.max().unwrap()) * F::cast(x.nrows().max(x.ncols())) * F::epsilon());
35
36    let rank = s.fold(0, |mut acc, v| {
37        if F::cast(*v) > cond {
38            acc += 1
39        };
40        acc
41    });
42
43    let mut ucut = u.slice_move(s![.., ..rank]);
44    ucut /= &s.slice(s![..rank]).mapv(F::Lapack::cast);
45
46    vh.slice(s![..rank, ..]).t().dot(&ucut.t()).without_lapack()
47}
48
49#[allow(clippy::type_complexity)]
50pub fn center_scale_dataset<F: Float, D: Data<Elem = F>>(
51    dataset: &DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>>,
52    scale: bool,
53) -> (
54    Array2<F>,
55    Array2<F>,
56    Array1<F>,
57    Array1<F>,
58    Array1<F>,
59    Array1<F>,
60) {
61    let (xnorm, x_mean, x_std) = center_scale(dataset.records(), scale);
62    let (ynorm, y_mean, y_std) = center_scale(dataset.targets(), scale);
63    (xnorm, ynorm, x_mean, y_mean, x_std, y_std)
64}
65
66fn center_scale<F: Float>(
67    x: &ArrayBase<impl Data<Elem = F>, Ix2>,
68    scale: bool,
69) -> (Array2<F>, Array1<F>, Array1<F>) {
70    let x_mean = x.mean_axis(Axis(0)).unwrap();
71    let (xnorm, x_std) = if scale {
72        let mut x_std = x.std_axis(Axis(0), F::one());
73        x_std.mapv_inplace(|v| if v == F::zero() { F::one() } else { v });
74        ((x - &x_mean) / &x_std, x_std)
75    } else {
76        ((x - &x_mean), Array1::ones(x.ncols()))
77    };
78
79    (xnorm, x_mean, x_std)
80}
81
82pub fn svd_flip_1d<F: Float>(
83    x_weights: &mut ArrayBase<impl DataMut<Elem = F>, Ix1>,
84    y_weights: &mut ArrayBase<impl DataMut<Elem = F>, Ix1>,
85) {
86    let biggest_abs_val_idx = x_weights.mapv(|v| v.abs()).argmax().unwrap();
87    let sign: F = x_weights[biggest_abs_val_idx].signum();
88    x_weights.map_inplace(|v| *v *= sign);
89    y_weights.map_inplace(|v| *v *= sign);
90}
91
92pub fn svd_flip<F: Float>(
93    u: ArrayBase<impl Data<Elem = F>, Ix2>,
94    v: ArrayBase<impl Data<Elem = F>, Ix2>,
95) -> (Array2<F>, Array2<F>) {
96    // columns of u, rows of v
97    let abs_u = u.mapv(|v| v.abs());
98    let max_abs_val_indices = abs_u.map_axis(Axis(0), |col| col.argmax().unwrap());
99    let mut signs = Array1::<F>::zeros(u.ncols());
100    let range: Vec<usize> = (0..u.ncols()).collect();
101    Zip::from(&mut signs)
102        .and(&max_abs_val_indices)
103        .and(&range)
104        .for_each(|s, &i, &j| *s = u[[i, j]].signum());
105    (&u * &signs, &v * &signs.insert_axis(Axis(1)))
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// The outer product of a 3-vector and a 2-vector is the expected 3x2 matrix.
    #[test]
    fn test_outer() {
        let lhs = array![1., 2., 3.];
        let rhs = array![2., 3.];
        let product = outer(&lhs, &rhs);
        let expected = array![[2., 3.], [4., 6.], [6., 9.]];
        assert_abs_diff_eq!(expected, product);
    }

    /// Multiplying a matrix by its pseudo inverse yields the identity (up to 1e-6).
    #[test]
    fn test_pinv2() {
        let matrix = array![[1., 2., 3.], [4., 5., 6.], [7., 8., 10.]];
        let pseudo_inverse = pinv2(matrix.view(), None);
        let identity_candidate = matrix.dot(&pseudo_inverse);
        assert_abs_diff_eq!(identity_candidate, Array2::eye(3), epsilon = 1e-6)
    }
}