linfa_clustering/gaussian_mixture/
hyperparams.rs

1use crate::gaussian_mixture::errors::GmmError;
2use ndarray_rand::rand::{Rng, SeedableRng};
3use rand_xoshiro::Xoshiro256Plus;
4#[cfg(feature = "serde")]
5use serde_crate::{Deserialize, Serialize};
6
7use linfa::{Float, ParamGuard};
8
/// Describes how the covariance matrices of the mixture components relate to each other.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub enum GmmCovarType {
    /// Every component owns its own general covariance matrix.
    Full,
}
20
/// Selects the method used to initialize the GMM fitting algorithm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub enum GmmInitMethod {
    /// The fitting algorithm is initialized with the result of a
    /// [KMeans](crate::KMeans) clustering.
    KMeans,
    /// The fitting algorithm is initialized randomly.
    Random,
}
34
35#[cfg_attr(
36    feature = "serde",
37    derive(Serialize, Deserialize),
38    serde(crate = "serde_crate")
39)]
40#[derive(Clone, Debug, PartialEq)]
41/// The set of hyperparameters that can be specified for the execution of
42/// the [GMM algorithm](crate::GaussianMixtureModel).
43pub struct GmmValidParams<F: Float, R: Rng> {
44    n_clusters: usize,
45    covar_type: GmmCovarType,
46    tolerance: F,
47    reg_covar: F,
48    n_runs: u64,
49    max_n_iter: u64,
50    init_method: GmmInitMethod,
51    rng: R,
52}
53
54impl<F: Float, R: Rng + Clone> GmmValidParams<F, R> {
55    pub fn n_clusters(&self) -> usize {
56        self.n_clusters
57    }
58
59    pub fn covariance_type(&self) -> &GmmCovarType {
60        &self.covar_type
61    }
62
63    pub fn tolerance(&self) -> F {
64        self.tolerance
65    }
66
67    pub fn reg_covariance(&self) -> F {
68        self.reg_covar
69    }
70
71    pub fn n_runs(&self) -> u64 {
72        self.n_runs
73    }
74
75    pub fn max_n_iterations(&self) -> u64 {
76        self.max_n_iter
77    }
78
79    pub fn init_method(&self) -> &GmmInitMethod {
80        &self.init_method
81    }
82
83    pub fn rng(&self) -> R {
84        self.rng.clone()
85    }
86}
87
88#[cfg_attr(
89    feature = "serde",
90    derive(Serialize, Deserialize),
91    serde(crate = "serde_crate")
92)]
93#[derive(Clone, Debug, PartialEq)]
94/// The set of hyperparameters that can be specified for the execution of
95/// the [GMM algorithm](crate::GaussianMixtureModel).
96pub struct GmmParams<F: Float, R: Rng>(GmmValidParams<F, R>);
97
98impl<F: Float> GmmParams<F, Xoshiro256Plus> {
99    pub fn new(n_clusters: usize) -> Self {
100        Self::new_with_rng(n_clusters, Xoshiro256Plus::seed_from_u64(42))
101    }
102}
103
104impl<F: Float, R: Rng + Clone> GmmParams<F, R> {
105    pub fn new_with_rng(n_clusters: usize, rng: R) -> GmmParams<F, R> {
106        Self(GmmValidParams {
107            n_clusters,
108            covar_type: GmmCovarType::Full,
109            tolerance: F::cast(1e-3),
110            reg_covar: F::cast(1e-6),
111            n_runs: 1,
112            max_n_iter: 100,
113            init_method: GmmInitMethod::KMeans,
114            rng,
115        })
116    }
117
118    /// Set the covariance type.
119    pub fn covariance_type(mut self, covar_type: GmmCovarType) -> Self {
120        self.0.covar_type = covar_type;
121        self
122    }
123
124    /// Set the convergence threshold. EM iterations will stop when the lower bound average gain is below this threshold.
125    pub fn tolerance(mut self, tolerance: F) -> Self {
126        self.0.tolerance = tolerance;
127        self
128    }
129
130    /// Non-negative regularization added to the diagonal of covariance.
131    /// Allows to assure that the covariance matrices are all positive.
132    pub fn reg_covariance(mut self, reg_covar: F) -> Self {
133        self.0.reg_covar = reg_covar;
134        self
135    }
136
137    /// Set the number of initializations to perform. The best results are kept.
138    pub fn n_runs(mut self, n_runs: u64) -> Self {
139        self.0.n_runs = n_runs;
140        self
141    }
142
143    /// Set the number of EM iterations to perform.
144    pub fn max_n_iterations(mut self, max_n_iter: u64) -> Self {
145        self.0.max_n_iter = max_n_iter;
146        self
147    }
148
149    /// Set the method used to initialize the weights, the means and the precisions.
150    pub fn init_method(mut self, init_method: GmmInitMethod) -> Self {
151        self.0.init_method = init_method;
152        self
153    }
154
155    pub fn with_rng<R2: Rng + Clone>(self, rng: R2) -> GmmParams<F, R2> {
156        GmmParams(GmmValidParams {
157            n_clusters: self.0.n_clusters,
158            covar_type: self.0.covar_type,
159            tolerance: self.0.tolerance,
160            reg_covar: self.0.reg_covar,
161            n_runs: self.0.n_runs,
162            max_n_iter: self.0.max_n_iter,
163            init_method: self.0.init_method,
164            rng,
165        })
166    }
167}
168
169impl<F: Float, R: Rng> ParamGuard for GmmParams<F, R> {
170    type Checked = GmmValidParams<F, R>;
171    type Error = GmmError;
172
173    fn check_ref(&self) -> Result<&Self::Checked, GmmError> {
174        if self.0.n_clusters == 0 {
175            Err(GmmError::InvalidValue(
176                "`n_clusters` cannot be 0!".to_string(),
177            ))
178        } else if self.0.tolerance <= F::zero() {
179            Err(GmmError::InvalidValue(
180                "`tolerance` must be greater than 0!".to_string(),
181            ))
182        } else if self.0.reg_covar < F::zero() {
183            Err(GmmError::InvalidValue(
184                "`reg_covar` must be positive!".to_string(),
185            ))
186        } else if self.0.n_runs == 0 {
187            Err(GmmError::InvalidValue("`n_runs` cannot be 0!".to_string()))
188        } else if self.0.max_n_iter == 0 {
189            Err(GmmError::InvalidValue(
190                "`max_n_iterations` cannot be 0!".to_string(),
191            ))
192        } else {
193            Ok(&self.0)
194        }
195    }
196
197    fn check(self) -> Result<Self::Checked, GmmError> {
198        self.check_ref()?;
199        Ok(self.0)
200    }
201}