linfa_ica/fast_ica.rs

//! Fast algorithm for Independent Component Analysis (ICA)

use linfa::{
    dataset::{DatasetBase, Records, WithLapack, WithoutLapack},
    traits::*,
    Float,
};
#[cfg(not(feature = "blas"))]
use linfa_linalg::{eigh::*, svd::*};
use ndarray::{s, Array, Array1, Array2, ArrayBase, Axis, Data, Ix2};
#[cfg(feature = "blas")]
use ndarray_linalg::{eigh::Eigh, solveh::UPLO, svd::SVD};
use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt};
use ndarray_stats::QuantileExt;
use rand_xoshiro::Xoshiro256Plus;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

use crate::error::{FastIcaError, Result};
use crate::hyperparams::FastIcaValidParams;

impl<F: Float, D: Data<Elem = F>, T> Fit<ArrayBase<D, Ix2>, T, FastIcaError>
    for FastIcaValidParams<F>
{
    type Object = FastIca<F>;

    /// Fit the model
    ///
    /// # Errors
    ///
    /// Returns an error if the dataset contains no samples, if
    /// [`FastIcaValidParams::ncomponents`] is set to a number greater than the minimum of the
    /// number of rows and columns of the input, or if the `alpha` value set for
    /// [`GFunc::Logcosh`] is not between 1 and 2 inclusive.
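    ///
    /// # Example
    ///
    /// A minimal usage sketch (illustrative only: random data stands in for real mixed
    /// signals, and the `linfa_ica::fast_ica` import path assumes the module is exposed
    /// publicly by the crate):
    ///
    /// ```no_run
    /// use linfa::dataset::DatasetBase;
    /// use linfa::traits::{Fit, Predict};
    /// use linfa_ica::fast_ica::{FastIca, GFunc};
    /// use ndarray::Array;
    /// use ndarray_rand::{rand_distr::Uniform, RandomExt};
    ///
    /// // Observed mixtures: 100 samples of 2 mixed signals (random here for brevity)
    /// let mixed = Array::random((100, 2), Uniform::new(0.0f64, 1.0));
    /// let dataset = DatasetBase::from(mixed.clone());
    ///
    /// let model = FastIca::params()
    ///     .ncomponents(2)
    ///     .gfunc(GFunc::Logcosh(1.0))
    ///     .random_state(42)
    ///     .fit(&dataset)
    ///     .expect("FastICA fit failed");
    /// let sources = model.predict(&mixed);
    /// ```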
    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
        let x = &dataset.records;
        let (nsamples, nfeatures) = (x.nsamples(), x.nfeatures());
        if dataset.nsamples() == 0 {
            return Err(FastIcaError::NotEnoughSamples);
        }

        // If the number of components is not set, we take the minimum of
        // the number of rows and columns
        let ncomponents = self
            .ncomponents()
            .unwrap_or_else(|| nsamples.min(nfeatures));

        // The number of components cannot be greater than the minimum of
        // the number of rows and columns
        if ncomponents > nsamples.min(nfeatures) {
            return Err(FastIcaError::InvalidValue(format!(
                "ncomponents cannot be greater than the min({nsamples}, {nfeatures}), got {ncomponents}"
            )));
        }

        // We center the input by subtracting the mean of its features
        // safe unwrap because we already returned an error on zero samples
        let xmean = x.mean_axis(Axis(0)).unwrap();
        let mut xcentered = x - &xmean.view().insert_axis(Axis(0));

        // We transpose the centered matrix
        xcentered = xcentered.reversed_axes();

        // We whiten the matrix to remove any potential correlation between
        // the components
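        // Whitening sketch: with the thin SVD of the (centered, transposed) data X = U * S * V.T,
        // the matrix K = (U / S).T truncated to `ncomponents` rows maps X to K * X, whose rows
        // are mutually uncorrelated (they are proportional to rows of V.T).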
        let xcentered = xcentered.with_lapack();
        let k = match xcentered.svd(true, false)? {
            (Some(u), s, _) => {
                let s = s.mapv(F::Lapack::cast);
                // This slice operation will extract the "thin" SVD component of `u` regardless of
                // whether `.svd` returns a full or thin SVD, because the slice dimensions
                // correspond to the thin SVD dimensions.
                (u.slice_move(s![.., ..nsamples.min(nfeatures)]) / s)
                    .t()
                    .slice(s![..ncomponents, ..])
                    .to_owned()
            }
            _ => return Err(FastIcaError::SvdDecomposition),
        };

        let mut xwhitened = k.dot(&xcentered).without_lapack();
        let k = k.without_lapack();

        // We scale by the square root of the number of samples so that each
        // whitened component has unit variance
        let nsamples_sqrt = F::cast(nsamples).sqrt();
        xwhitened.mapv_inplace(|x| x * nsamples_sqrt);

        // We initialize the de-mixing matrix with a uniform distribution
        let w: Array2<f64>;
        if let Some(seed) = self.random_state() {
            let mut rng = Xoshiro256Plus::seed_from_u64(*seed as u64);
            w = Array::random_using((ncomponents, ncomponents), Uniform::new(0., 1.), &mut rng);
        } else {
            w = Array::random((ncomponents, ncomponents), Uniform::new(0., 1.));
        }
        let mut w = w.mapv(F::cast);

        // We find the optimized de-mixing matrix
        w = self.ica_parallel(&xwhitened, &w)?;

        // We compose the de-mixing matrix with the whitening matrix, so the
        // resulting components can be applied directly to centered input data
        let components = w.dot(&k);

        Ok(FastIca {
            mean: xmean,
            components,
        })
    }
}

impl<F: Float> FastIcaValidParams<F> {
    // Parallel FastICA, Optimization step
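    //
    // Each iteration applies the FastICA fixed-point update
    //   W_new = E[g(W * X) * X.T] - diag(E[g'(W * X)]) * W
    // with the expectations approximated by means over the columns of X, followed by
    // symmetric decorrelation to keep the rows of W orthonormal.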
    fn ica_parallel(&self, x: &Array2<F>, w: &Array2<F>) -> Result<Array2<F>> {
        let mut w = Self::sym_decorrelation(w)?;

        let p = x.ncols() as f64;

        for _ in 0..self.max_iter() {
            let (gwtx, g_wtx) = self.gfunc().exec(&w.dot(x))?;

            let lhs = gwtx.dot(&x.t()).mapv(|x| x / F::cast(p));
            let rhs = &w * &g_wtx.insert_axis(Axis(1));
            let wnew = Self::sym_decorrelation(&(lhs - rhs))?;

            // `lim` lets us check for convergence between the old and new
            // weight values; we want their dot products to be close to one
            let lim = *wnew
                .outer_iter()
                .zip(w.outer_iter())
                .map(|(a, b)| a.dot(&b))
                .collect::<Array1<F>>()
                .mapv(|x| x.abs())
                .mapv(|x| x - F::cast(1.))
                .mapv(|x| x.abs())
                .max()
                .unwrap();

            w = wnew;

            if lim < F::cast(self.tol()) {
                break;
            }
        }

        Ok(w)
    }

    // Symmetric decorrelation
    //
    // W <- (W * W.T)^{-1/2} * W
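    //
    // Computed through the eigendecomposition of the symmetric matrix W * W.T = E * D * E.T,
    // so that (W * W.T)^{-1/2} = E * D^{-1/2} * E.T, with the eigenvalues floored at 1e-7
    // before taking the inverse square root.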
    fn sym_decorrelation(w: &Array2<F>) -> Result<Array2<F>> {
        #[cfg(feature = "blas")]
        let (eig_val, eig_vec) = w.dot(&w.t()).with_lapack().eigh(UPLO::Upper)?;
        #[cfg(not(feature = "blas"))]
        let (eig_val, eig_vec) = w.dot(&w.t()).with_lapack().eigh()?;
        let eig_val = eig_val.mapv(F::cast);
        let eig_vec = eig_vec.without_lapack();

        let tmp = &eig_vec
            * &(eig_val.mapv(|x| x.sqrt()).mapv(|x| {
                // We lower bound the float value at 1e-7 when taking the reciprocal
                let lower_bound = F::cast(1e-7);
                if x < lower_bound {
                    return lower_bound.recip();
                }
                x.recip()
            }))
            .insert_axis(Axis(0));

        Ok(tmp.dot(&eig_vec.t()).dot(w))
    }
}

/// Fitted FastICA model for recovering the sources
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct FastIca<F> {
    mean: Array1<F>,
    components: Array2<F>,
}

impl<F: Float> PredictInplace<Array2<F>, Array2<F>> for FastIca<F> {
    /// Recover the sources
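    ///
    /// The sources are estimated as `(x - mean) * components.t()`, i.e. the centered
    /// observations projected through the learned un-mixing matrix.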
    fn predict_inplace(&self, x: &Array2<F>, y: &mut Array2<F>) {
        assert_eq!(
            y.shape(),
            &[x.nrows(), self.components.nrows()],
            "The number of data points must match the number of output targets."
        );

        let xcentered = x - &self.mean.view().insert_axis(Axis(0));
        *y = xcentered.dot(&self.components.t());
    }

    fn default_target(&self, x: &Array2<F>) -> Array2<F> {
        Array2::zeros((x.nrows(), self.components.nrows()))
    }
}

/// Some standard non-linear functions
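///
/// Each variant corresponds to the derivative `g` of a contrast function `G` used by
/// FastICA to approximate neg-entropy:
///
/// * `Logcosh(alpha)`: `g(x) = tanh(alpha * x)`, with `alpha` between 1 and 2
/// * `Exp`: `g(x) = x * exp(-x^2 / 2)`
/// * `Cube`: `g(x) = x^3`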
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq, Copy)]
pub enum GFunc {
    Logcosh(f64),
    Exp,
    Cube,
}

impl GFunc {
    // Selects the configured non-linearity and applies it to the input, returning a
    // tuple of `g(x)` (the first derivative of the contrast function) and the row-wise
    // mean of `g'(x)` (its second derivative)
    fn exec<A: Float>(&self, x: &Array2<A>) -> Result<(Array2<A>, Array1<A>)> {
        match self {
            Self::Cube => Ok(Self::cube(x)),
            Self::Exp => Ok(Self::exp(x)),
            Self::Logcosh(alpha) => Self::logcosh(x, *alpha),
        }
    }

    fn cube<A: Float>(x: &Array2<A>) -> (Array2<A>, Array1<A>) {
        (
            x.mapv(|x| x.powi(3)),
            x.mapv(|x| A::cast(3.) * x.powi(2))
                .mean_axis(Axis(1))
                .unwrap(),
        )
    }

    fn exp<A: Float>(x: &Array2<A>) -> (Array2<A>, Array1<A>) {
        // `exp` holds exp(-x^2 / 2), which is shared by both derivatives below
        let exp = x.mapv(|x| (-x.powi(2) / A::cast(2.)).exp());
        (
            x * &exp,
            (x.mapv(|x| A::cast(1.) - x.powi(2)) * &exp)
                .mean_axis(Axis(1))
                .unwrap(),
        )
    }

    fn logcosh<A: Float>(x: &Array2<A>, alpha: f64) -> Result<(Array2<A>, Array1<A>)> {
        if !(1.0..=2.0).contains(&alpha) {
            return Err(FastIcaError::InvalidValue(format!(
                "alpha must be between 1 and 2 inclusive, got {alpha}"
            )));
        }
        let alpha = A::cast(alpha);

        let gx = x.mapv(|x| (x * alpha).tanh());
        let g_x = gx.mapv(|x| alpha * (A::cast(1.) - x.powi(2)));

        Ok((gx, g_x.mean_axis(Axis(1)).unwrap()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use linfa::traits::{Fit, Predict};
    use ndarray::{array, concatenate};

    use crate::hyperparams::{FastIcaParams, FastIcaValidParams};
    use ndarray_rand::rand_distr::StudentT;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<FastIca<f64>>();
        has_autotraits::<GFunc>();
        has_autotraits::<FastIcaParams<f64>>();
        has_autotraits::<FastIcaValidParams<f64>>();
        has_autotraits::<FastIcaError>();
    }

    // Test to make sure the number of components set cannot be greater
    // than the minimum of the number of rows and columns of the input
    #[test]
    fn test_ncomponents_err() {
        let input = DatasetBase::from(Array::random((4, 4), Uniform::new(0.0, 1.0)));
        let ica = FastIca::params().ncomponents(100);
        let ica = ica.fit(&input);
        assert!(ica.is_err());
    }

    // Test to make sure the alpha value of `GFunc::Logcosh` is between
    // 1 and 2 inclusive
    #[test]
    fn test_logcosh_alpha_err() {
        let input = DatasetBase::from(Array::random((4, 4), Uniform::new(0.0, 1.0)));
        let ica = FastIca::params().gfunc(GFunc::Logcosh(10.));
        let ica = ica.fit(&input);
        assert!(ica.is_err());
    }

    // Helper macro that produces test-cases with the pattern test_fast_ica_*
    macro_rules! fast_ica_tests {
        ($($name:ident: $gfunc:expr,)*) => {
            paste::item! {
                $(
                    #[test]
                    fn [<test_fast_ica_$name>]() {
                        test_fast_ica($gfunc);
                    }
                )*
            }
        }
    }

    // Tests to make sure all of the `GFunc`'s non-linear functions and the
    // model itself perform well
    fast_ica_tests! {
        exp: GFunc::Exp, cube: GFunc::Cube, logcosh: GFunc::Logcosh(1.0),
    }

    // Helper function that mixes two signal sources, sends them to FastICA,
    // and makes sure the model can demix them with a considerable amount of
    // accuracy
    fn test_fast_ica(gfunc: GFunc) {
        let nsamples = 1000;

        // Center the data and make it have unit variance
        let center_and_norm = |s: &mut Array2<f64>| {
            let mean = s.mean_axis(Axis(0)).unwrap();
            *s -= &mean.insert_axis(Axis(0));
            let std = s.std_axis(Axis(0), 0.);
            *s /= &std.insert_axis(Axis(0));
        };

        // Creating a sawtooth signal
        let mut source1 = Array::linspace(0., 100., nsamples);
        source1.mapv_inplace(|x| {
            let tmp = 2. * f64::sin(x);
            if tmp > 0. {
                return 0.;
            }
            -1.
        });

        // Creating noise using a Student's t-distribution
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let source2 = Array::random_using((nsamples, 1), StudentT::new(1.0).unwrap(), &mut rng);

        // Column-concatenating both sources
        let mut sources = concatenate![Axis(1), source1.insert_axis(Axis(1)), source2];
        center_and_norm(&mut sources);

        // Mixing the two sources
        let phi: f64 = 0.6;
        let mixing = array![[phi.cos(), phi.sin()], [phi.sin(), -phi.cos()]];
        sources = mixing.dot(&sources.t());
        center_and_norm(&mut sources);

        sources = sources.reversed_axes();

        // We fit and transform using the model to unmix the two sources
        let ica = FastIca::params()
            .ncomponents(2)
            .gfunc(gfunc)
            .random_state(42);

        let sources_dataset = DatasetBase::from(sources.view());
        let ica = ica.fit(&sources_dataset).unwrap();
        let mut output = ica.predict(&sources);

        center_and_norm(&mut output);

        // Making sure the model output has the right shape
        assert_eq!(output.shape(), &[1000, 2]);

        // The order of the sources in the ICA output is not deterministic,
        // so we account for that here
        let s1 = sources.column(0);
        let s2 = sources.column(1);
        let mut s1_ = output.column(0);
        let mut s2_ = output.column(1);
        if s1_.dot(&s2).abs() > s1_.dot(&s1).abs() {
            s1_ = output.column(1);
            s2_ = output.column(0);
        }

        let similarity1 = s1.dot(&s1_).abs() / (nsamples as f64);
        let similarity2 = s2.dot(&s2_).abs() / (nsamples as f64);

        // We make sure the sawtooth signal identified by ICA from the mixed
        // sources is similar to the original sawtooth signal.
        // We ignore the noise signal's similarity measure.
        assert!(similarity1.max(similarity2) > 0.9);
    }
}