linfa_clustering/gaussian_mixture/
hyperparams.rs1use crate::gaussian_mixture::errors::GmmError;
2use ndarray_rand::rand::{Rng, SeedableRng};
3use rand_xoshiro::Xoshiro256Plus;
4#[cfg(feature = "serde")]
5use serde_crate::{Deserialize, Serialize};
6
7use linfa::{Float, ParamGuard};
8
/// Covariance structure used by the components of the mixture.
///
/// `Full` is the only variant available and is the value that
/// `GmmParams::new_with_rng` selects by default.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GmmCovarType {
    /// Full covariance; presumably one unconstrained covariance matrix per
    /// component (confirm against the fitting code, which is not in this file).
    Full,
}
20
/// Strategy used to initialise the mixture before fitting.
///
/// `GmmParams::new_with_rng` selects `KMeans` by default.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GmmInitMethod {
    /// Initialise from a K-means clustering (as the name suggests; the
    /// initialisation code itself lives outside this file).
    KMeans,
    /// Initialise randomly (presumably using the RNG stored in the
    /// hyper-parameters; confirm against the fitting code).
    Random,
}
34
35#[cfg_attr(
36 feature = "serde",
37 derive(Serialize, Deserialize),
38 serde(crate = "serde_crate")
39)]
40#[derive(Clone, Debug, PartialEq)]
41pub struct GmmValidParams<F: Float, R: Rng> {
44 n_clusters: usize,
45 covar_type: GmmCovarType,
46 tolerance: F,
47 reg_covar: F,
48 n_runs: u64,
49 max_n_iter: u64,
50 init_method: GmmInitMethod,
51 rng: R,
52}
53
54impl<F: Float, R: Rng + Clone> GmmValidParams<F, R> {
55 pub fn n_clusters(&self) -> usize {
56 self.n_clusters
57 }
58
59 pub fn covariance_type(&self) -> &GmmCovarType {
60 &self.covar_type
61 }
62
63 pub fn tolerance(&self) -> F {
64 self.tolerance
65 }
66
67 pub fn reg_covariance(&self) -> F {
68 self.reg_covar
69 }
70
71 pub fn n_runs(&self) -> u64 {
72 self.n_runs
73 }
74
75 pub fn max_n_iterations(&self) -> u64 {
76 self.max_n_iter
77 }
78
79 pub fn init_method(&self) -> &GmmInitMethod {
80 &self.init_method
81 }
82
83 pub fn rng(&self) -> R {
84 self.rng.clone()
85 }
86}
87
88#[cfg_attr(
89 feature = "serde",
90 derive(Serialize, Deserialize),
91 serde(crate = "serde_crate")
92)]
93#[derive(Clone, Debug, PartialEq)]
94pub struct GmmParams<F: Float, R: Rng>(GmmValidParams<F, R>);
97
98impl<F: Float> GmmParams<F, Xoshiro256Plus> {
99 pub fn new(n_clusters: usize) -> Self {
100 Self::new_with_rng(n_clusters, Xoshiro256Plus::seed_from_u64(42))
101 }
102}
103
104impl<F: Float, R: Rng + Clone> GmmParams<F, R> {
105 pub fn new_with_rng(n_clusters: usize, rng: R) -> GmmParams<F, R> {
106 Self(GmmValidParams {
107 n_clusters,
108 covar_type: GmmCovarType::Full,
109 tolerance: F::cast(1e-3),
110 reg_covar: F::cast(1e-6),
111 n_runs: 1,
112 max_n_iter: 100,
113 init_method: GmmInitMethod::KMeans,
114 rng,
115 })
116 }
117
118 pub fn covariance_type(mut self, covar_type: GmmCovarType) -> Self {
120 self.0.covar_type = covar_type;
121 self
122 }
123
124 pub fn tolerance(mut self, tolerance: F) -> Self {
126 self.0.tolerance = tolerance;
127 self
128 }
129
130 pub fn reg_covariance(mut self, reg_covar: F) -> Self {
133 self.0.reg_covar = reg_covar;
134 self
135 }
136
137 pub fn n_runs(mut self, n_runs: u64) -> Self {
139 self.0.n_runs = n_runs;
140 self
141 }
142
143 pub fn max_n_iterations(mut self, max_n_iter: u64) -> Self {
145 self.0.max_n_iter = max_n_iter;
146 self
147 }
148
149 pub fn init_method(mut self, init_method: GmmInitMethod) -> Self {
151 self.0.init_method = init_method;
152 self
153 }
154
155 pub fn with_rng<R2: Rng + Clone>(self, rng: R2) -> GmmParams<F, R2> {
156 GmmParams(GmmValidParams {
157 n_clusters: self.0.n_clusters,
158 covar_type: self.0.covar_type,
159 tolerance: self.0.tolerance,
160 reg_covar: self.0.reg_covar,
161 n_runs: self.0.n_runs,
162 max_n_iter: self.0.max_n_iter,
163 init_method: self.0.init_method,
164 rng,
165 })
166 }
167}
168
169impl<F: Float, R: Rng> ParamGuard for GmmParams<F, R> {
170 type Checked = GmmValidParams<F, R>;
171 type Error = GmmError;
172
173 fn check_ref(&self) -> Result<&Self::Checked, GmmError> {
174 if self.0.n_clusters == 0 {
175 Err(GmmError::InvalidValue(
176 "`n_clusters` cannot be 0!".to_string(),
177 ))
178 } else if self.0.tolerance <= F::zero() {
179 Err(GmmError::InvalidValue(
180 "`tolerance` must be greater than 0!".to_string(),
181 ))
182 } else if self.0.reg_covar < F::zero() {
183 Err(GmmError::InvalidValue(
184 "`reg_covar` must be positive!".to_string(),
185 ))
186 } else if self.0.n_runs == 0 {
187 Err(GmmError::InvalidValue("`n_runs` cannot be 0!".to_string()))
188 } else if self.0.max_n_iter == 0 {
189 Err(GmmError::InvalidValue(
190 "`max_n_iterations` cannot be 0!".to_string(),
191 ))
192 } else {
193 Ok(&self.0)
194 }
195 }
196
197 fn check(self) -> Result<Self::Checked, GmmError> {
198 self.check_ref()?;
199 Ok(self.0)
200 }
201}