linfa_ica/
hyperparams.rs

1use crate::{error::FastIcaError, fast_ica::FastIca, fast_ica::GFunc};
2use linfa::{Float, ParamGuard};
3#[cfg(feature = "serde")]
4use serde_crate::{Deserialize, Serialize};
5
/// Fast Independent Component Analysis (ICA)
///
/// Validated hyper-parameters for FastICA. Instances of this type are obtained
/// from [`FastIcaParams`] through its [`ParamGuard`] implementation, which
/// rejects invalid settings such as a negative tolerance.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct FastIcaValidParams<F: Float> {
    /// Number of components to use; `None` means all are used.
    ncomponents: Option<usize>,
    /// G function used in the approximation to neg-entropy, see [`GFunc`].
    gfunc: GFunc,
    /// Maximum number of iterations during fit.
    max_iter: usize,
    /// Tolerance on the update at each iteration.
    tol: F,
    /// Seed for the random number generator, for reproducible results.
    /// NOTE(review): behaviour when `None` is decided by the fitting code — confirm there.
    random_state: Option<usize>,
}
20
21impl<F: Float> FastIcaValidParams<F> {
22    pub fn ncomponents(&self) -> &Option<usize> {
23        &self.ncomponents
24    }
25
26    pub fn gfunc(&self) -> &GFunc {
27        &self.gfunc
28    }
29
30    pub fn max_iter(&self) -> usize {
31        self.max_iter
32    }
33
34    pub fn tol(&self) -> F {
35        self.tol
36    }
37
38    pub fn random_state(&self) -> &Option<usize> {
39        &self.random_state
40    }
41}
42
/// Unvalidated FastICA hyper-parameters, configured through builder-style setters.
///
/// Convert into [`FastIcaValidParams`] through the [`ParamGuard`] implementation,
/// which rejects invalid values such as a negative tolerance.
#[derive(Debug, Clone, PartialEq)]
pub struct FastIcaParams<F: Float>(FastIcaValidParams<F>);
45
46impl<F: Float> Default for FastIcaParams<F> {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl<F: Float> FastIca<F> {
53    pub fn params() -> FastIcaParams<F> {
54        FastIcaParams::new()
55    }
56}
57
58impl<F: Float> FastIcaParams<F> {
59    /// Create new FastICA algorithm with default values for its parameters
60    pub fn new() -> Self {
61        Self(FastIcaValidParams {
62            ncomponents: None,
63            gfunc: GFunc::Logcosh(1.),
64            max_iter: 200,
65            tol: F::cast(1e-4),
66            random_state: None,
67        })
68    }
69
70    /// Set the number of components to use, if not set all are used
71    pub fn ncomponents(mut self, ncomponents: usize) -> Self {
72        self.0.ncomponents = Some(ncomponents);
73        self
74    }
75
76    /// G function used in the approximation to neg-entropy, refer [`GFunc`]
77    pub fn gfunc(mut self, gfunc: GFunc) -> Self {
78        self.0.gfunc = gfunc;
79        self
80    }
81
82    /// Set maximum number of iterations during fit
83    pub fn max_iter(mut self, max_iter: usize) -> Self {
84        self.0.max_iter = max_iter;
85        self
86    }
87
88    /// Set tolerance on upate at each iteration
89    pub fn tol(mut self, tol: F) -> Self {
90        self.0.tol = tol;
91        self
92    }
93
94    /// Set seed for random number generator for reproducible results.
95    pub fn random_state(mut self, random_state: usize) -> Self {
96        self.0.random_state = Some(random_state);
97        self
98    }
99}
100
101impl<F: Float> ParamGuard for FastIcaParams<F> {
102    type Checked = FastIcaValidParams<F>;
103    type Error = FastIcaError;
104
105    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
106        if self.0.tol < F::zero() {
107            Err(FastIcaError::InvalidTolerance(self.0.tol.to_f32().unwrap()))
108        } else {
109            Ok(&self.0)
110        }
111    }
112
113    fn check(self) -> Result<Self::Checked, Self::Error> {
114        self.check_ref()?;
115        Ok(self.0)
116    }
117}