linfa/dataset/
lapack_bounds.rs1use crate::Float;
2use ndarray::{ArrayBase, Data, Dimension, OwnedRepr, ViewRepr};
3
4pub trait WithLapack<D: Data + WithLapackData, I: Dimension> {
11 fn with_lapack(self) -> ArrayBase<D::D, I>;
12}
13
14impl<F: Float, D, I> WithLapack<D, I> for ArrayBase<D, I>
15where
16 D: Data<Elem = F> + WithLapackData,
17 I: Dimension,
18{
19 fn with_lapack(self) -> ArrayBase<D::D, I> {
20 D::with_lapack(self)
21 }
22}
23
24pub trait WithoutLapack<F: Float, D: Data + WithoutLapackData<F>, I: Dimension> {
31 fn without_lapack(self) -> ArrayBase<D::D, I>;
32}
33
34impl<F: Float, D, I> WithoutLapack<F, D, I> for ArrayBase<D, I>
35where
36 D: Data<Elem = F::Lapack> + WithoutLapackData<F>,
37 I: Dimension,
38{
39 fn without_lapack(self) -> ArrayBase<D::D, I> {
40 D::without_lapack(self)
41 }
42}
43
/// Reinterprets `a` as a value of type `B`, taking ownership without running `A`'s destructor.
///
/// Any resources owned by `a` (for example a heap buffer) are handed over to the returned
/// value, so they are released exactly once, when the `B` is dropped.
///
/// # Safety
///
/// `A` and `B` must have compatible memory layouts, and the bytes of `a` must form a valid
/// value of `B`. Equal size is checked at runtime. Layout compatibility and bit-pattern
/// validity cannot be checked here and remain the caller's responsibility.
///
/// # Panics
///
/// Panics if `size_of::<A>() != size_of::<B>()`. Previously this case silently read out of
/// bounds, which is undefined behaviour.
unsafe fn transmute<A, B>(a: A) -> B {
    assert_eq!(
        std::mem::size_of::<A>(),
        std::mem::size_of::<B>(),
        "transmute: source and target types differ in size"
    );

    // Wrap first so `a` can never be dropped once its bytes have been copied out.
    let a = std::mem::ManuallyDrop::new(a);

    // SAFETY: sizes match (checked above) and the caller guarantees layout/validity.
    // `transmute_copy` is also correct when `B` is more strictly aligned than `A`, which a
    // plain `ptr::read` through a cast pointer is not.
    unsafe { std::mem::transmute_copy::<std::mem::ManuallyDrop<A>, B>(&a) }
}
50
51pub trait WithLapackData
52where
53 Self: Data,
54{
55 type D: Data;
56
57 fn with_lapack<I>(x: ArrayBase<Self, I>) -> ArrayBase<Self::D, I>
61 where
62 I: Dimension,
63 {
64 unsafe { transmute(x) }
65 }
66}
67
68impl<F: Float> WithLapackData for OwnedRepr<F> {
69 type D = OwnedRepr<F::Lapack>;
70}
71
72impl<'a, F: Float> WithLapackData for ViewRepr<&'a F> {
73 type D = ViewRepr<&'a F::Lapack>;
74}
75
76impl<'a, F: Float> WithLapackData for ViewRepr<&'a mut F> {
77 type D = ViewRepr<&'a mut F::Lapack>;
78}
79
80pub trait WithoutLapackData<F: Float>
81where
82 Self: Data,
83{
84 type D: Data<Elem = F>;
85
86 fn without_lapack<I>(x: ArrayBase<Self, I>) -> ArrayBase<Self::D, I>
90 where
91 I: Dimension,
92 {
93 unsafe { transmute(x) }
94 }
95}
96
97impl<F: Float> WithoutLapackData<F> for OwnedRepr<F::Lapack> {
98 type D = OwnedRepr<F>;
99}
100
101impl<'a, F: Float> WithoutLapackData<F> for ViewRepr<&'a F::Lapack> {
102 type D = ViewRepr<&'a F>;
103}
104
105impl<'a, F: Float> WithoutLapackData<F> for ViewRepr<&'a mut F::Lapack> {
106 type D = ViewRepr<&'a mut F>;
107}
108
#[cfg(test)]
mod tests {
    use ndarray::{Array2, ArrayView2};
    #[cfg(feature = "ndarray-linalg")]
    use ndarray_linalg::eig::*;

    use super::{WithLapack, WithoutLapack};

    /// A 20x20 matrix with distinct, non-zero entries. An all-zero matrix would look
    /// identical after many incorrect reinterpretations, so it cannot catch corruption.
    fn sample() -> Array2<f32> {
        Array2::from_shape_fn((20, 20), |(i, j)| (i * 20 + j) as f32 + 0.5)
    }

    /// Round-trips an owned array through the Lapack type and back, checking that shape
    /// and every element survive both conversions.
    #[test]
    fn memory_check() {
        let original = sample();
        let a: Array2<f32> = original.clone().with_lapack();

        assert_eq!(a.shape(), &[20, 20]);
        assert_eq!(a, original);

        let b: Array2<f32> = a.clone().without_lapack();

        assert_eq!(b, a);
        assert_eq!(b, original);
    }

    /// Round-trips a shared view, exercising the `ViewRepr<&_>` implementations.
    #[test]
    fn view_round_trip() {
        let original = sample();
        let v: ArrayView2<f32> = original.view().with_lapack();

        assert_eq!(v.shape(), &[20, 20]);
        assert_eq!(v, original.view());

        let w: ArrayView2<f32> = v.without_lapack();

        assert_eq!(w, original.view());
    }

    /// With `ndarray-linalg` enabled, the converted array must support LAPACK routines.
    #[cfg(feature = "ndarray-linalg")]
    #[test]
    fn lapack_exists() {
        let a: Array2<f32> = Array2::zeros((4, 4));
        let a: Array2<f32> = a.with_lapack();

        let (_a, _b) = a.eig().unwrap();
    }
}