linfa_reduction/random_projection/algorithms.rs

1use std::marker::PhantomData;
2
3use linfa::{
4    dataset::{AsTargets, FromTargetArray},
5    prelude::Records,
6    traits::{Fit, Transformer},
7    DatasetBase, Float,
8};
9use ndarray::{linalg::Dot, Array2, ArrayBase, Data, Ix2};
10
11use rand::{prelude::Distribution, Rng, SeedableRng};
12use rand_xoshiro::Xoshiro256Plus;
13
14use super::hyperparams::RandomProjectionParamsInner;
15use super::{common::johnson_lindenstrauss_min_dim, methods::ProjectionMethod};
16use super::{RandomProjectionParams, RandomProjectionValidParams};
17use crate::ReductionError;
18
19/// Embedding via random projection
20pub struct RandomProjection<Proj: ProjectionMethod, F: Float>
21where
22    Proj::RandomDistribution: Distribution<F>,
23{
24    projection: Proj::ProjectionMatrix<F>,
25}
26
27impl<F, Proj, Rec, T, R> Fit<Rec, T, ReductionError> for RandomProjectionValidParams<Proj, R>
28where
29    F: Float,
30    Proj: ProjectionMethod,
31    Rec: Records<Elem = F>,
32    R: Rng + Clone,
33    Proj::RandomDistribution: Distribution<F>,
34{
35    type Object = RandomProjection<Proj, F>;
36
37    fn fit(&self, dataset: &linfa::DatasetBase<Rec, T>) -> Result<Self::Object, ReductionError> {
38        let n_samples = dataset.nsamples();
39        let n_features = dataset.nfeatures();
40        let mut rng = self.rng.clone();
41
42        let n_dims = match &self.params {
43            RandomProjectionParamsInner::Dimension { target_dim } => *target_dim,
44            RandomProjectionParamsInner::Epsilon { eps } => {
45                johnson_lindenstrauss_min_dim(n_samples, *eps)
46            }
47        };
48
49        if n_dims > n_features {
50            return Err(ReductionError::DimensionIncrease(n_dims, n_features));
51        }
52
53        let projection = Proj::generate_matrix(n_features, n_dims, &mut rng)?;
54
55        Ok(RandomProjection { projection })
56    }
57}
58
59impl<Proj: ProjectionMethod, F: Float> RandomProjection<Proj, F>
60where
61    Proj::RandomDistribution: Distribution<F>,
62{
63    /// Create new parameters for a [`RandomProjection`] with default value
64    /// `eps = 0.1` and a [`Xoshiro256Plus`] RNG.
65    pub fn params() -> RandomProjectionParams<Proj, Xoshiro256Plus> {
66        RandomProjectionParams(RandomProjectionValidParams {
67            params: RandomProjectionParamsInner::Epsilon { eps: 0.1 },
68            rng: Xoshiro256Plus::seed_from_u64(42),
69            marker: PhantomData,
70        })
71    }
72
73    /// Create new parameters for a [`RandomProjection`] with default values
74    /// `eps = 0.1` and the provided [`Rng`].
75    pub fn params_with_rng<R>(rng: R) -> RandomProjectionParams<Proj, R>
76    where
77        R: Rng + Clone,
78    {
79        RandomProjectionParams(RandomProjectionValidParams {
80            params: RandomProjectionParamsInner::Epsilon { eps: 0.1 },
81            rng,
82            marker: PhantomData,
83        })
84    }
85}
86
87impl<Proj, F, D> Transformer<&ArrayBase<D, Ix2>, Array2<F>> for RandomProjection<Proj, F>
88where
89    Proj: ProjectionMethod,
90    F: Float,
91    D: Data<Elem = F>,
92    ArrayBase<D, Ix2>: Dot<Proj::ProjectionMatrix<F>, Output = Array2<F>>,
93    Proj::RandomDistribution: Distribution<F>,
94{
95    /// Compute the embedding of a two-dimensional array
96    fn transform(&self, x: &ArrayBase<D, Ix2>) -> Array2<F> {
97        x.dot(&self.projection)
98    }
99}
100
101impl<Proj, F, D> Transformer<ArrayBase<D, Ix2>, Array2<F>> for RandomProjection<Proj, F>
102where
103    Proj: ProjectionMethod,
104    F: Float,
105    D: Data<Elem = F>,
106    ArrayBase<D, Ix2>: Dot<Proj::ProjectionMatrix<F>, Output = Array2<F>>,
107    Proj::RandomDistribution: Distribution<F>,
108{
109    /// Compute the embedding of a two-dimensional array
110    fn transform(&self, x: ArrayBase<D, Ix2>) -> Array2<F> {
111        self.transform(&x)
112    }
113}
114
115impl<Proj, F, T> Transformer<DatasetBase<Array2<F>, T>, DatasetBase<Array2<F>, T>>
116    for RandomProjection<Proj, F>
117where
118    Proj: ProjectionMethod,
119    F: Float,
120    T: AsTargets,
121    for<'a> ArrayBase<ndarray::ViewRepr<&'a F>, Ix2>:
122        Dot<Proj::ProjectionMatrix<F>, Output = Array2<F>>,
123    Proj::RandomDistribution: Distribution<F>,
124{
125    /// Compute the embedding of a dataset
126    ///
127    /// # Parameter
128    ///
129    /// * `data`: a dataset
130    ///
131    /// # Returns
132    ///
133    /// New dataset, with data equal to the embedding of the input data
134    fn transform(&self, data: DatasetBase<Array2<F>, T>) -> DatasetBase<Array2<F>, T> {
135        let new_records = self.transform(data.records().view());
136
137        DatasetBase::new(new_records, data.targets)
138    }
139}
140
141impl<'a, Proj, F, L, T> Transformer<&'a DatasetBase<Array2<F>, T>, DatasetBase<Array2<F>, T::View>>
142    for RandomProjection<Proj, F>
143where
144    Proj: ProjectionMethod,
145    F: Float,
146    L: 'a,
147    T: AsTargets<Elem = L> + FromTargetArray<'a>,
148    for<'b> ArrayBase<ndarray::ViewRepr<&'b F>, Ix2>:
149        Dot<Proj::ProjectionMatrix<F>, Output = Array2<F>>,
150    Proj::RandomDistribution: Distribution<F>,
151{
152    /// Compute the embedding of a dataset
153    ///
154    /// # Parameter
155    ///
156    /// * `data`: a dataset
157    ///
158    /// # Returns
159    ///
160    /// New dataset, with data equal to the embedding of the input data
161    fn transform(&self, data: &'a DatasetBase<Array2<F>, T>) -> DatasetBase<Array2<F>, T::View> {
162        let new_records = self.transform(data.records().view());
163
164        DatasetBase::new(
165            new_records,
166            T::new_targets_view(AsTargets::as_targets(data)),
167        )
168    }
169}