1use crate::{error::FastIcaError, fast_ica::FastIca, fast_ica::GFunc};
2use linfa::{Float, ParamGuard};
3#[cfg(feature = "serde")]
4use serde_crate::{Deserialize, Serialize};
5
6#[cfg_attr(
8 feature = "serde",
9 derive(Serialize, Deserialize),
10 serde(crate = "serde_crate")
11)]
12#[derive(Debug, Clone, PartialEq)]
13pub struct FastIcaValidParams<F: Float> {
14 ncomponents: Option<usize>,
15 gfunc: GFunc,
16 max_iter: usize,
17 tol: F,
18 random_state: Option<usize>,
19}
20
21impl<F: Float> FastIcaValidParams<F> {
22 pub fn ncomponents(&self) -> &Option<usize> {
23 &self.ncomponents
24 }
25
26 pub fn gfunc(&self) -> &GFunc {
27 &self.gfunc
28 }
29
30 pub fn max_iter(&self) -> usize {
31 self.max_iter
32 }
33
34 pub fn tol(&self) -> F {
35 self.tol
36 }
37
38 pub fn random_state(&self) -> &Option<usize> {
39 &self.random_state
40 }
41}
42
43#[derive(Debug, Clone, PartialEq)]
44pub struct FastIcaParams<F: Float>(FastIcaValidParams<F>);
45
46impl<F: Float> Default for FastIcaParams<F> {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl<F: Float> FastIca<F> {
53 pub fn params() -> FastIcaParams<F> {
54 FastIcaParams::new()
55 }
56}
57
58impl<F: Float> FastIcaParams<F> {
59 pub fn new() -> Self {
61 Self(FastIcaValidParams {
62 ncomponents: None,
63 gfunc: GFunc::Logcosh(1.),
64 max_iter: 200,
65 tol: F::cast(1e-4),
66 random_state: None,
67 })
68 }
69
70 pub fn ncomponents(mut self, ncomponents: usize) -> Self {
72 self.0.ncomponents = Some(ncomponents);
73 self
74 }
75
76 pub fn gfunc(mut self, gfunc: GFunc) -> Self {
78 self.0.gfunc = gfunc;
79 self
80 }
81
82 pub fn max_iter(mut self, max_iter: usize) -> Self {
84 self.0.max_iter = max_iter;
85 self
86 }
87
88 pub fn tol(mut self, tol: F) -> Self {
90 self.0.tol = tol;
91 self
92 }
93
94 pub fn random_state(mut self, random_state: usize) -> Self {
96 self.0.random_state = Some(random_state);
97 self
98 }
99}
100
101impl<F: Float> ParamGuard for FastIcaParams<F> {
102 type Checked = FastIcaValidParams<F>;
103 type Error = FastIcaError;
104
105 fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
106 if self.0.tol < F::zero() {
107 Err(FastIcaError::InvalidTolerance(self.0.tol.to_f32().unwrap()))
108 } else {
109 Ok(&self.0)
110 }
111 }
112
113 fn check(self) -> Result<Self::Checked, Self::Error> {
114 self.check_ref()?;
115 Ok(self.0)
116 }
117}