linfa_reduction/random_projection/
algorithms.rs1use std::marker::PhantomData;
2
3use linfa::{
4 dataset::{AsTargets, FromTargetArray},
5 prelude::Records,
6 traits::{Fit, Transformer},
7 DatasetBase, Float,
8};
9use ndarray::{linalg::Dot, Array2, ArrayBase, Data, Ix2};
10
11use rand::{prelude::Distribution, Rng, SeedableRng};
12use rand_xoshiro::Xoshiro256Plus;
13
14use super::hyperparams::RandomProjectionParamsInner;
15use super::{common::johnson_lindenstrauss_min_dim, methods::ProjectionMethod};
16use super::{RandomProjectionParams, RandomProjectionValidParams};
17use crate::ReductionError;
18
19pub struct RandomProjection<Proj: ProjectionMethod, F: Float>
21where
22 Proj::RandomDistribution: Distribution<F>,
23{
24 projection: Proj::ProjectionMatrix<F>,
25}
26
27impl<F, Proj, Rec, T, R> Fit<Rec, T, ReductionError> for RandomProjectionValidParams<Proj, R>
28where
29 F: Float,
30 Proj: ProjectionMethod,
31 Rec: Records<Elem = F>,
32 R: Rng + Clone,
33 Proj::RandomDistribution: Distribution<F>,
34{
35 type Object = RandomProjection<Proj, F>;
36
37 fn fit(&self, dataset: &linfa::DatasetBase<Rec, T>) -> Result<Self::Object, ReductionError> {
38 let n_samples = dataset.nsamples();
39 let n_features = dataset.nfeatures();
40 let mut rng = self.rng.clone();
41
42 let n_dims = match &self.params {
43 RandomProjectionParamsInner::Dimension { target_dim } => *target_dim,
44 RandomProjectionParamsInner::Epsilon { eps } => {
45 johnson_lindenstrauss_min_dim(n_samples, *eps)
46 }
47 };
48
49 if n_dims > n_features {
50 return Err(ReductionError::DimensionIncrease(n_dims, n_features));
51 }
52
53 let projection = Proj::generate_matrix(n_features, n_dims, &mut rng)?;
54
55 Ok(RandomProjection { projection })
56 }
57}
58
59impl<Proj: ProjectionMethod, F: Float> RandomProjection<Proj, F>
60where
61 Proj::RandomDistribution: Distribution<F>,
62{
63 pub fn params() -> RandomProjectionParams<Proj, Xoshiro256Plus> {
66 RandomProjectionParams(RandomProjectionValidParams {
67 params: RandomProjectionParamsInner::Epsilon { eps: 0.1 },
68 rng: Xoshiro256Plus::seed_from_u64(42),
69 marker: PhantomData,
70 })
71 }
72
73 pub fn params_with_rng<R>(rng: R) -> RandomProjectionParams<Proj, R>
76 where
77 R: Rng + Clone,
78 {
79 RandomProjectionParams(RandomProjectionValidParams {
80 params: RandomProjectionParamsInner::Epsilon { eps: 0.1 },
81 rng,
82 marker: PhantomData,
83 })
84 }
85}
86
87impl<Proj, F, D> Transformer<&ArrayBase<D, Ix2>, Array2<F>> for RandomProjection<Proj, F>
88where
89 Proj: ProjectionMethod,
90 F: Float,
91 D: Data<Elem = F>,
92 ArrayBase<D, Ix2>: Dot<Proj::ProjectionMatrix<F>, Output = Array2<F>>,
93 Proj::RandomDistribution: Distribution<F>,
94{
95 fn transform(&self, x: &ArrayBase<D, Ix2>) -> Array2<F> {
97 x.dot(&self.projection)
98 }
99}
100
101impl<Proj, F, D> Transformer<ArrayBase<D, Ix2>, Array2<F>> for RandomProjection<Proj, F>
102where
103 Proj: ProjectionMethod,
104 F: Float,
105 D: Data<Elem = F>,
106 ArrayBase<D, Ix2>: Dot<Proj::ProjectionMatrix<F>, Output = Array2<F>>,
107 Proj::RandomDistribution: Distribution<F>,
108{
109 fn transform(&self, x: ArrayBase<D, Ix2>) -> Array2<F> {
111 self.transform(&x)
112 }
113}
114
115impl<Proj, F, T> Transformer<DatasetBase<Array2<F>, T>, DatasetBase<Array2<F>, T>>
116 for RandomProjection<Proj, F>
117where
118 Proj: ProjectionMethod,
119 F: Float,
120 T: AsTargets,
121 for<'a> ArrayBase<ndarray::ViewRepr<&'a F>, Ix2>:
122 Dot<Proj::ProjectionMatrix<F>, Output = Array2<F>>,
123 Proj::RandomDistribution: Distribution<F>,
124{
125 fn transform(&self, data: DatasetBase<Array2<F>, T>) -> DatasetBase<Array2<F>, T> {
135 let new_records = self.transform(data.records().view());
136
137 DatasetBase::new(new_records, data.targets)
138 }
139}
140
141impl<'a, Proj, F, L, T> Transformer<&'a DatasetBase<Array2<F>, T>, DatasetBase<Array2<F>, T::View>>
142 for RandomProjection<Proj, F>
143where
144 Proj: ProjectionMethod,
145 F: Float,
146 L: 'a,
147 T: AsTargets<Elem = L> + FromTargetArray<'a>,
148 for<'b> ArrayBase<ndarray::ViewRepr<&'b F>, Ix2>:
149 Dot<Proj::ProjectionMatrix<F>, Output = Array2<F>>,
150 Proj::RandomDistribution: Distribution<F>,
151{
152 fn transform(&self, data: &'a DatasetBase<Array2<F>, T>) -> DatasetBase<Array2<F>, T::View> {
162 let new_records = self.transform(data.records().view());
163
164 DatasetBase::new(
165 new_records,
166 T::new_targets_view(AsTargets::as_targets(data)),
167 )
168 }
169}