1use linfa::{
4 dataset::{DatasetBase, Records, WithLapack, WithoutLapack},
5 traits::*,
6 Float,
7};
8#[cfg(not(feature = "blas"))]
9use linfa_linalg::{eigh::*, svd::*};
10use ndarray::{Array, Array1, Array2, ArrayBase, Axis, Data, Ix2};
11#[cfg(feature = "blas")]
12use ndarray_linalg::{eigh::Eigh, solveh::UPLO, svd::SVD};
13use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt};
14use ndarray_stats::QuantileExt;
15use rand_xoshiro::Xoshiro256Plus;
16#[cfg(feature = "serde")]
17use serde_crate::{Deserialize, Serialize};
18
19use crate::error::{FastIcaError, Result};
20use crate::hyperparams::FastIcaValidParams;
21
22impl<F: Float, D: Data<Elem = F>, T> Fit<ArrayBase<D, Ix2>, T, FastIcaError>
23 for FastIcaValidParams<F>
24{
25 type Object = FastIca<F>;
26
27 fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
37 let x = &dataset.records;
38 let (nsamples, nfeatures) = (x.nsamples(), x.nfeatures());
39 if dataset.nsamples() == 0 {
40 return Err(FastIcaError::NotEnoughSamples);
41 }
42
43 let ncomponents = self
46 .ncomponents()
47 .unwrap_or_else(|| nsamples.min(nfeatures));
48
49 if ncomponents > nsamples.min(nfeatures) {
52 return Err(FastIcaError::InvalidValue(format!(
53 "ncomponents cannot be greater than the min({nsamples}, {nfeatures}), got {ncomponents}"
54 )));
55 }
56
57 let xmean = x.mean_axis(Axis(0)).unwrap();
60 let mut xcentered = x - &xmean.view().insert_axis(Axis(0));
61
62 xcentered = xcentered.reversed_axes();
64
65 let xcentered = xcentered.with_lapack();
68 let k = match xcentered.svd(true, false)? {
69 (Some(u), s, _) => {
70 let s = s.mapv(F::Lapack::cast);
71 (u.slice_move(s![.., ..nsamples.min(nfeatures)]) / s)
75 .t()
76 .slice(s![..ncomponents, ..])
77 .to_owned()
78 }
79 _ => return Err(FastIcaError::SvdDecomposition),
80 };
81
82 let mut xwhitened = k.dot(&xcentered).without_lapack();
83 let k = k.without_lapack();
84
85 let nsamples_sqrt = F::cast(nsamples).sqrt();
87 xwhitened.mapv_inplace(|x| x * nsamples_sqrt);
88
89 let w: Array2<f64>;
91 if let Some(seed) = self.random_state() {
92 let mut rng = Xoshiro256Plus::seed_from_u64(*seed as u64);
93 w = Array::random_using((ncomponents, ncomponents), Uniform::new(0., 1.), &mut rng);
94 } else {
95 w = Array::random((ncomponents, ncomponents), Uniform::new(0., 1.));
96 }
97 let mut w = w.mapv(F::cast);
98
99 w = self.ica_parallel(&xwhitened, &w)?;
101
102 let components = w.dot(&k);
104
105 Ok(FastIca {
106 mean: xmean,
107 components,
108 })
109 }
110}
111
112impl<F: Float> FastIcaValidParams<F> {
113 fn ica_parallel(&self, x: &Array2<F>, w: &Array2<F>) -> Result<Array2<F>> {
115 let mut w = Self::sym_decorrelation(w)?;
116
117 let p = x.ncols() as f64;
118
119 for _ in 0..self.max_iter() {
120 let (gwtx, g_wtx) = self.gfunc().exec(&w.dot(x))?;
121
122 let lhs = gwtx.dot(&x.t()).mapv(|x| x / F::cast(p));
123 let rhs = &w * &g_wtx.insert_axis(Axis(1));
124 let wnew = Self::sym_decorrelation(&(lhs - rhs))?;
125
126 let lim = *wnew
129 .outer_iter()
130 .zip(w.outer_iter())
131 .map(|(a, b)| a.dot(&b))
132 .collect::<Array1<F>>()
133 .mapv(|x| x.abs())
134 .mapv(|x| x - F::cast(1.))
135 .mapv(|x| x.abs())
136 .max()
137 .unwrap();
138
139 w = wnew;
140
141 if lim < F::cast(self.tol()) {
142 break;
143 }
144 }
145
146 Ok(w)
147 }
148
149 fn sym_decorrelation(w: &Array2<F>) -> Result<Array2<F>> {
153 #[cfg(feature = "blas")]
154 let (eig_val, eig_vec) = w.dot(&w.t()).with_lapack().eigh(UPLO::Upper)?;
155 #[cfg(not(feature = "blas"))]
156 let (eig_val, eig_vec) = w.dot(&w.t()).with_lapack().eigh()?;
157 let eig_val = eig_val.mapv(F::cast);
158 let eig_vec = eig_vec.without_lapack();
159
160 let tmp = &eig_vec
161 * &(eig_val.mapv(|x| x.sqrt()).mapv(|x| {
162 let lower_bound = F::cast(1e-7);
164 if x < lower_bound {
165 return lower_bound.recip();
166 }
167 x.recip()
168 }))
169 .insert_axis(Axis(0));
170
171 Ok(tmp.dot(&eig_vec.t()).dot(w))
172 }
173}
174
175#[cfg_attr(
177 feature = "serde",
178 derive(Serialize, Deserialize),
179 serde(crate = "serde_crate")
180)]
181#[derive(Debug, Clone, PartialEq)]
182pub struct FastIca<F> {
183 mean: Array1<F>,
184 components: Array2<F>,
185}
186
187impl<F: Float> PredictInplace<Array2<F>, Array2<F>> for FastIca<F> {
188 fn predict_inplace(&self, x: &Array2<F>, y: &mut Array2<F>) {
190 assert_eq!(
191 y.shape(),
192 &[x.nrows(), self.components.nrows()],
193 "The number of data points must match the number of output targets."
194 );
195
196 let xcentered = x - &self.mean.view().insert_axis(Axis(0));
197 *y = xcentered.dot(&self.components.t());
198 }
199
200 fn default_target(&self, x: &Array2<F>) -> Array2<F> {
201 Array2::zeros((x.nrows(), self.components.nrows()))
202 }
203}
204
/// Nonlinearity applied to the projected data during the FastICA iterations.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub enum GFunc {
    /// `tanh`-based nonlinearity; the payload is the `alpha` factor, which is
    /// rejected at evaluation time unless it lies in `[1, 2]`.
    Logcosh(f64),
    /// Exponential nonlinearity.
    Exp,
    /// Cubic nonlinearity.
    Cube,
}
217
218impl GFunc {
219 fn exec<A: Float>(&self, x: &Array2<A>) -> Result<(Array2<A>, Array1<A>)> {
223 match self {
224 Self::Cube => Ok(Self::cube(x)),
225 Self::Exp => Ok(Self::exp(x)),
226 Self::Logcosh(alpha) => Self::logcosh(x, *alpha),
227 }
228 }
229
230 fn cube<A: Float>(x: &Array2<A>) -> (Array2<A>, Array1<A>) {
231 (
232 x.mapv(|x| x.powi(3)),
233 x.mapv(|x| A::cast(3.) * x.powi(2))
234 .mean_axis(Axis(1))
235 .unwrap(),
236 )
237 }
238
239 fn exp<A: Float>(x: &Array2<A>) -> (Array2<A>, Array1<A>) {
240 let exp = x.mapv(|x| -x.powi(2) / A::cast(2.));
241 (
242 x * &exp,
243 (x.mapv(|x| A::cast(1.) - x.powi(2)) * &exp)
244 .mean_axis(Axis(1))
245 .unwrap(),
246 )
247 }
248
249 fn logcosh<A: Float>(x: &Array2<A>, alpha: f64) -> Result<(Array2<A>, Array1<A>)> {
251 if !(1.0..=2.0).contains(&alpha) {
253 return Err(FastIcaError::InvalidValue(format!(
254 "alpha must be between 1 and 2 inclusive, got {alpha}"
255 )));
256 }
257 let alpha = A::cast(alpha);
258
259 let gx = x.mapv(|x| (x * alpha).tanh());
260 let g_x = gx.mapv(|x| alpha * (A::cast(1.) - x.powi(2)));
261
262 Ok((gx, g_x.mean_axis(Axis(1)).unwrap()))
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use linfa::traits::{Fit, Predict};

    use crate::hyperparams::{FastIcaParams, FastIcaValidParams};
    use ndarray_rand::rand_distr::StudentT;

    /// The listed types must all be `Send + Sync + Sized + Unpin`.
    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<FastIca<f64>>();
        has_autotraits::<GFunc>();
        has_autotraits::<FastIcaParams<f64>>();
        has_autotraits::<FastIcaValidParams<f64>>();
        has_autotraits::<FastIcaError>();
    }

    /// Asking for more components than `min(nsamples, nfeatures)` must fail.
    #[test]
    fn test_ncomponents_err() {
        let records = Array::random((4, 4), Uniform::new(0.0, 1.0));
        let dataset = DatasetBase::from(records);
        let params = FastIca::params().ncomponents(100);
        let fitted = params.fit(&dataset);
        assert!(fitted.is_err());
    }

    /// A `Logcosh` alpha outside `[1, 2]` must fail.
    #[test]
    fn test_logcosh_alpha_err() {
        let records = Array::random((4, 4), Uniform::new(0.0, 1.0));
        let dataset = DatasetBase::from(records);
        let params = FastIca::params().gfunc(GFunc::Logcosh(10.));
        let fitted = params.fit(&dataset);
        assert!(fitted.is_err());
    }

    // Generates one `test_fast_ica_<name>` test per `name: gfunc` pair, each
    // delegating to `test_fast_ica`.
    macro_rules! fast_ica_tests {
        ($($name:ident: $gfunc:expr,)*) => {
            paste::item! {
                $(
                    #[test]
                    fn [<test_fast_ica_$name>]() {
                        test_fast_ica($gfunc);
                    }
                )*
            }
        }
    }

    fast_ica_tests! {
        exp: GFunc::Exp,
        cube: GFunc::Cube,
        logcosh: GFunc::Logcosh(1.0),
    }

    /// Mixes two synthetic signals, fits FastICA with `gfunc` and checks that
    /// at least one predicted column matches its input column up to sign.
    fn test_fast_ica(gfunc: GFunc) {
        /// Subtracts the mean along `Axis(0)` and divides by the standard
        /// deviation along `Axis(0)` (zero delta degrees of freedom).
        fn center_and_norm(s: &mut Array2<f64>) {
            let mean = s.mean_axis(Axis(0)).unwrap();
            *s -= &mean.insert_axis(Axis(0));
            let std = s.std_axis(Axis(0), 0.);
            *s /= &std.insert_axis(Axis(0));
        }

        let nsamples = 1000;

        // First signal: two-level, 0 where sin(t) is positive and -1 elsewhere.
        let mut source1 = Array::linspace(0., 100., nsamples);
        source1.mapv_inplace(|t| if 2. * f64::sin(t) > 0. { 0. } else { -1. });

        // Second signal: seeded Student-t samples with one degree of freedom.
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let source2 = Array::random_using((nsamples, 1), StudentT::new(1.0).unwrap(), &mut rng);

        let mut sources = concatenate![Axis(1), source1.insert_axis(Axis(1)), source2];
        center_and_norm(&mut sources);

        // Mix both signals with a fixed 2x2 matrix built from `phi`.
        let phi: f64 = 0.6;
        let (sin_phi, cos_phi) = (phi.sin(), phi.cos());
        let mixing = array![[cos_phi, sin_phi], [sin_phi, -cos_phi]];
        let mut mixed = mixing.dot(&sources.t());
        center_and_norm(&mut mixed);

        // Back to one sample per row.
        let sources = mixed.reversed_axes();

        let params = FastIca::params()
            .ncomponents(2)
            .gfunc(gfunc)
            .random_state(42);

        let sources_dataset = DatasetBase::from(sources.view());
        let model = params.fit(&sources_dataset).unwrap();
        let mut output = model.predict(&sources);

        center_and_norm(&mut output);

        assert_eq!(output.shape(), &[1000, 2]);

        // Components may come out in either order: pair each input column
        // with the output column it correlates with more strongly.
        let s1 = sources.column(0);
        let s2 = sources.column(1);
        let first = output.column(0);
        let swapped = first.dot(&s2).abs() > first.dot(&s1).abs();
        let (s1_, s2_) = if swapped {
            (output.column(1), output.column(0))
        } else {
            (output.column(0), output.column(1))
        };

        let similarity1 = s1.dot(&s1_).abs() / (nsamples as f64);
        let similarity2 = s2.dot(&s2_).abs() / (nsamples as f64);

        // The sign of a component is arbitrary, hence the absolute values.
        assert!(similarity1.max(similarity2) > 0.9);
    }
}