linfa/dataset/
lapack_bounds.rs

1use crate::Float;
2use ndarray::{ArrayBase, Data, Dimension, OwnedRepr, ViewRepr};
3
4/// Add the Lapack bound to the floating point of a dataset
5///
6/// This helper trait is introduced to avoid leaking `Lapack + Scalar` bounds to the outside which
7/// causes ambiguities when calling functions like `abs` for `num_traits::Float` and
8/// `Cauchy::Scalar`. We are only using real values here, but the LAPACK routines
9/// require that `Cauchy::Scalar` is implemented.
10pub trait WithLapack<D: Data + WithLapackData, I: Dimension> {
11    fn with_lapack(self) -> ArrayBase<D::D, I>;
12}
13
14impl<F: Float, D, I> WithLapack<D, I> for ArrayBase<D, I>
15where
16    D: Data<Elem = F> + WithLapackData,
17    I: Dimension,
18{
19    fn with_lapack(self) -> ArrayBase<D::D, I> {
20        D::with_lapack(self)
21    }
22}
23
24/// Remove the Lapack bound to the floating point of a dataset
25///
26/// This helper trait is introduced to avoid leaking `Lapack + Scalar` bounds to the outside which
27/// causes ambiguities when calling functions like `abs` for `num_traits::Float` and
28/// `Cauchy::Scalar`. We are only using real values here, but the LAPACK routines
29/// require that `Cauchy::Scalar` is implemented.
30pub trait WithoutLapack<F: Float, D: Data + WithoutLapackData<F>, I: Dimension> {
31    fn without_lapack(self) -> ArrayBase<D::D, I>;
32}
33
34impl<F: Float, D, I> WithoutLapack<F, D, I> for ArrayBase<D, I>
35where
36    D: Data<Elem = F::Lapack> + WithoutLapackData<F>,
37    I: Dimension,
38{
39    fn without_lapack(self) -> ArrayBase<D::D, I> {
40        D::without_lapack(self)
41    }
42}
43
/// Reinterpret a value of type `A` as a value of type `B`, consuming `a`.
///
/// Unlike [`std::mem::transmute`], this accepts generic types whose sizes the compiler cannot
/// compare at compile time; the sizes are compared at runtime instead. Ownership of any resources
/// held by `a` (e.g. a heap buffer) is transferred to the returned value; `a` is never dropped.
///
/// # Safety
///
/// `A` and `B` must have the same memory layout, and the bytes of `a` must form a valid `B`.
///
/// # Panics
///
/// Panics if `A` and `B` have different sizes. In that case nothing has been reinterpreted yet,
/// and `a` is dropped normally.
unsafe fn transmute<A, B>(a: A) -> B {
    // Reading `size_of::<B>()` bytes out of a smaller `A` would read past the end of the value.
    assert_eq!(
        std::mem::size_of::<A>(),
        std::mem::size_of::<B>(),
        "lapack_bounds: cannot transmute between types of different size"
    );

    // `ManuallyDrop` suppresses `a`'s destructor: its resources now belong to the returned `B`.
    let a = std::mem::ManuallyDrop::new(a);

    // SAFETY: the sizes are equal (checked above) and the caller guarantees layout compatibility.
    // `read_unaligned` avoids additionally requiring `align_of::<B>() <= align_of::<A>()`.
    unsafe { std::ptr::read_unaligned(&*a as *const A as *const B) }
}
50
51pub trait WithLapackData
52where
53    Self: Data,
54{
55    type D: Data;
56
57    /// Add trait bound `Lapack` and `Scalar` to NdArray's floating point
58    ///
59    /// This is safe, because only implemented for D == Self
60    fn with_lapack<I>(x: ArrayBase<Self, I>) -> ArrayBase<Self::D, I>
61    where
62        I: Dimension,
63    {
64        unsafe { transmute(x) }
65    }
66}
67
68impl<F: Float> WithLapackData for OwnedRepr<F> {
69    type D = OwnedRepr<F::Lapack>;
70}
71
72impl<'a, F: Float> WithLapackData for ViewRepr<&'a F> {
73    type D = ViewRepr<&'a F::Lapack>;
74}
75
76impl<'a, F: Float> WithLapackData for ViewRepr<&'a mut F> {
77    type D = ViewRepr<&'a mut F::Lapack>;
78}
79
80pub trait WithoutLapackData<F: Float>
81where
82    Self: Data,
83{
84    type D: Data<Elem = F>;
85
86    /// Add trait bound `Lapack` and `Scalar` to NdArray's floating point
87    ///
88    /// This is safe, because only implemented for D == Self
89    fn without_lapack<I>(x: ArrayBase<Self, I>) -> ArrayBase<Self::D, I>
90    where
91        I: Dimension,
92    {
93        unsafe { transmute(x) }
94    }
95}
96
97impl<F: Float> WithoutLapackData<F> for OwnedRepr<F::Lapack> {
98    type D = OwnedRepr<F>;
99}
100
101impl<'a, F: Float> WithoutLapackData<F> for ViewRepr<&'a F::Lapack> {
102    type D = ViewRepr<&'a F>;
103}
104
105impl<'a, F: Float> WithoutLapackData<F> for ViewRepr<&'a mut F::Lapack> {
106    type D = ViewRepr<&'a mut F>;
107}
108
#[cfg(test)]
mod tests {
    use ndarray::{Array2, ArrayView2, ArrayViewMut2};
    #[cfg(feature = "ndarray-linalg")]
    use ndarray_linalg::eig::*;

    use super::{WithLapack, WithoutLapack};

    /// Round-trips an owned array through both conversions and checks that the shape and
    /// contents survive.
    #[test]
    fn memory_check() {
        let a: Array2<f32> = Array2::zeros((20, 20));
        let a: Array2<f32> = a.with_lapack();

        assert_eq!(a.shape(), &[20, 20]);

        let b: Array2<f32> = a.clone().without_lapack();

        assert_eq!(b, a);
    }

    /// Shared views go through both conversions and still see the original data.
    #[test]
    fn view_roundtrip() {
        let owned: Array2<f32> = Array2::from_shape_fn((3, 4), |(i, j)| (i * 4 + j) as f32);

        let view: ArrayView2<f32> = owned.view().with_lapack();
        assert_eq!(view, owned);

        // Name `F` explicitly rather than relying on inference through the associated type.
        let back: ArrayView2<f32> = WithoutLapack::<f32, _, _>::without_lapack(view);
        assert_eq!(back, owned);
    }

    /// Writes through converted mutable views land in the underlying owned array.
    #[test]
    fn view_mut_writes_through() {
        let mut owned: Array2<f32> = Array2::zeros((2, 2));
        {
            let mut view: ArrayViewMut2<f32> = owned.view_mut().with_lapack();
            view[[0, 1]] = 3.0;

            let mut back: ArrayViewMut2<f32> = WithoutLapack::<f32, _, _>::without_lapack(view);
            back[[1, 0]] = 5.0;
        }

        assert_eq!(owned[[0, 0]], 0.0);
        assert_eq!(owned[[0, 1]], 3.0);
        assert_eq!(owned[[1, 0]], 5.0);
        assert_eq!(owned[[1, 1]], 0.0);
    }

    #[cfg(feature = "ndarray-linalg")]
    #[test]
    fn lapack_exists() {
        let a: Array2<f32> = Array2::zeros((4, 4));
        let a: Array2<f32> = a.with_lapack();

        let (_a, _b) = a.eig().unwrap();
    }
}