linfa_clustering/gaussian_mixture/algorithm.rs

use crate::gaussian_mixture::errors::GmmError;
use crate::gaussian_mixture::hyperparams::{
    GmmCovarType, GmmInitMethod, GmmParams, GmmValidParams,
};
use crate::k_means::KMeans;
use linfa::{prelude::*, DatasetBase, Float};
use linfa_linalg::{cholesky::*, triangular::*};
use ndarray::{s, Array, Array1, Array2, Array3, ArrayBase, Axis, Data, Ix2, Ix3, Zip};
use ndarray_rand::rand::Rng;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;
use ndarray_stats::QuantileExt;
use rand_xoshiro::Xoshiro256Plus;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
/// Gaussian Mixture Model (GMM) aims at clustering a dataset by finding normally
/// distributed sub-datasets (hence the Gaussian Mixture name).
///
/// GMM assumes all the data points are generated from a mixture of a number K
/// of Gaussian distributions with certain parameters.
/// The expectation-maximization (EM) algorithm is used to fit the GMM to the dataset
/// by parameterizing the weight, mean, and covariance of each cluster distribution.
///
/// This implementation is a port of the [scikit-learn 0.23.2 Gaussian Mixture](https://scikit-learn.org)
/// implementation.
///
/// ## The algorithm
///
/// The general idea is to maximize the likelihood (equivalently the log-likelihood),
/// that is, to maximize the probability that the dataset is drawn from our mixture of normal distributions.
///
/// After an initialization step, which can be either random or taken from the result
/// of the [KMeans](KMeans) algorithm (the default value of the `init_method` parameter),
/// the core EM iterative algorithm for Gaussian Mixture is a fixed-point two-step algorithm:
///
/// 1. Expectation step: compute the expectation of the likelihood of the current Gaussian mixture model wrt the dataset.
/// 2. Maximization step: update the Gaussian parameters (weights, means and covariances) to maximize the likelihood.
///
/// We stop iterating when there is no significant change in the Gaussian parameters (controlled by the `tolerance` parameter) or
/// when we reach a maximum number of iterations (controlled by the `max_n_iterations` parameter).
/// As the initialization of the algorithm is subject to randomness, several initializations are performed (controlled by
/// the `n_runs` parameter).
///
/// ## Tutorial
///
/// Let's do a walkthrough of a training-predict-save example.
///
/// ```rust
/// use linfa::DatasetBase;
/// use linfa::prelude::*;
/// use linfa_clustering::{GmmValidParams, GaussianMixtureModel};
/// use linfa_datasets::generate;
/// use ndarray::{Axis, array, s, Zip};
/// use ndarray_rand::rand::SeedableRng;
/// use rand_xoshiro::Xoshiro256Plus;
/// use approx::assert_abs_diff_eq;
///
/// let mut rng = Xoshiro256Plus::seed_from_u64(42);
/// let expected_centroids = array![[0., 1.], [-10., 20.], [-1., 10.]];
/// let n = 200;
///
/// // We generate a dataset from points normally distributed around some distant centroids.
/// let dataset = DatasetBase::from(generate::blobs(n, &expected_centroids, &mut rng));
///
/// // Our GMM is expected to have a number of clusters equal to the number of centroids
/// // used to generate the dataset
/// let n_clusters = expected_centroids.len_of(Axis(0));
///
/// // We fit the model from the dataset setting some options
/// let gmm = GaussianMixtureModel::params(n_clusters)
///             .n_runs(10)
///             .tolerance(1e-4)
///             .with_rng(rng)
///             .fit(&dataset).expect("GMM fitting");
///
/// // Then we can get dataset membership information, targets contain **cluster indexes**
/// // corresponding to the cluster infos in the list of GMM means and covariances
/// let blobs_dataset = gmm.predict(dataset);
/// let DatasetBase {
///     records: _blobs_records,
///     targets: blobs_targets,
///     ..
/// } = blobs_dataset;
/// println!("GMM means = {:?}", gmm.means());
/// println!("GMM covariances = {:?}", gmm.covariances());
/// println!("GMM membership = {:?}", blobs_targets);
///
/// // We can also get the nearest cluster for a new point
/// let new_observation = DatasetBase::from(array![[-9., 20.5]]);
/// // Predict returns the **index** of the nearest cluster
/// let dataset = gmm.predict(new_observation);
/// // We can retrieve the actual centroid of the closest cluster using `.centroids()` (alias of .means())
/// let closest_centroid = &gmm.centroids().index_axis(Axis(0), dataset.targets()[0]);
/// ```
#[derive(Debug, PartialEq)]
pub struct GaussianMixtureModel<F: Float> {
    covar_type: GmmCovarType,
    weights: Array1<F>,
    means: Array2<F>,
    covariances: Array3<F>,
    precisions: Array3<F>,
    precisions_chol: Array3<F>,
}

impl<F: Float> Clone for GaussianMixtureModel<F> {
    fn clone(&self) -> Self {
        Self {
            covar_type: self.covar_type,
            weights: self.weights.to_owned(),
            means: self.means.to_owned(),
            covariances: self.covariances.to_owned(),
            precisions: self.precisions.to_owned(),
            precisions_chol: self.precisions_chol.to_owned(),
        }
    }
}

impl<F: Float> GaussianMixtureModel<F> {
    fn new<D: Data<Elem = F>, R: Rng + Clone, T>(
        hyperparameters: &GmmValidParams<F, R>,
        dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
        mut rng: R,
    ) -> Result<GaussianMixtureModel<F>, GmmError> {
        let observations = dataset.records().view();
        let n_samples = observations.nrows();

        // We initialize the responsibilities (n_samples, n_clusters) of each cluster,
        // that is, given a sample, the probability of each cluster being its source.
        // Responsibilities can be initialized either from a KMeans result or randomly.
        let resp = match hyperparameters.init_method() {
            GmmInitMethod::KMeans => {
                let model = KMeans::params_with_rng(hyperparameters.n_clusters(), rng)
                    .check()
                    .unwrap()
                    .fit(dataset)?;
                let mut resp = Array::<F, Ix2>::zeros((n_samples, hyperparameters.n_clusters()));
                for (k, idx) in model.predict(dataset.records()).iter().enumerate() {
                    resp[[k, *idx]] = F::cast(1.);
                }
                resp
            }
            GmmInitMethod::Random => {
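                // Draw uniform random responsibilities and normalize each row so that,
                // for every sample, the cluster probabilities sum to one.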
                let mut resp = Array2::<f64>::random_using(
                    (n_samples, hyperparameters.n_clusters()),
                    Uniform::new(0., 1.),
                    &mut rng,
                );
                let totals = &resp.sum_axis(Axis(1)).insert_axis(Axis(0));
                resp = (resp.reversed_axes() / totals).reversed_axes();
                resp.mapv(F::cast)
            }
        };

        // We compute an initial GMM model from the dataset and the initial responsibilities,
        // with respect to the covariance specification.
        let (mut weights, means, covariances) = Self::estimate_gaussian_parameters(
            &observations,
            &resp,
            hyperparameters.covariance_type(),
            hyperparameters.reg_covariance(),
        )?;
        weights /= F::cast(n_samples);

        // GmmCovarType = full
        let precisions_chol = Self::compute_precisions_cholesky_full(&covariances)?;
        let precisions = Self::compute_precisions_full(&precisions_chol);

        Ok(GaussianMixtureModel {
            covar_type: *hyperparameters.covariance_type(),
            weights,
            means,
            covariances,
            precisions,
            precisions_chol,
        })
    }
}

impl<F: Float> GaussianMixtureModel<F> {
    pub fn params(n_clusters: usize) -> GmmParams<F, Xoshiro256Plus> {
        GmmParams::new(n_clusters)
    }

    pub fn params_with_rng<R: Rng + Clone>(n_clusters: usize, rng: R) -> GmmParams<F, R> {
        GmmParams::new_with_rng(n_clusters, rng)
    }

    pub fn weights(&self) -> &Array1<F> {
        &self.weights
    }

    pub fn means(&self) -> &Array2<F> {
        &self.means
    }

    pub fn covariances(&self) -> &Array3<F> {
        &self.covariances
    }

    pub fn precisions(&self) -> &Array3<F> {
        &self.precisions
    }

    pub fn centroids(&self) -> &Array2<F> {
        self.means()
    }
    /// Returns the responsibilities as an (n_obs, n_clusters) array, where entry (i, j) gives
    /// the probability that the ith observation belongs to the jth cluster.
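    ///
    /// A short usage sketch, fitting on generated blob data as in the crate-level example above:
    ///
    /// ```
    /// use linfa::prelude::*;
    /// use linfa::DatasetBase;
    /// use linfa_clustering::GaussianMixtureModel;
    /// use linfa_datasets::generate;
    /// use ndarray::array;
    /// use ndarray_rand::rand::SeedableRng;
    /// use rand_xoshiro::Xoshiro256Plus;
    ///
    /// let mut rng = Xoshiro256Plus::seed_from_u64(42);
    /// let centroids = array![[0., 1.], [-10., 20.]];
    /// // 100 samples are generated around each centroid, 200 observations in total
    /// let dataset = DatasetBase::from(generate::blobs(100, &centroids, &mut rng));
    /// let gmm = GaussianMixtureModel::params(2)
    ///     .with_rng(rng)
    ///     .fit(&dataset)
    ///     .expect("GMM fitting");
    /// // Each row of `proba` sums to 1 and holds the membership probabilities of one observation
    /// let proba = gmm.predict_proba(dataset.records());
    /// assert_eq!(proba.dim(), (200, 2));
    /// ```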
    pub fn predict_proba<D: Data<Elem = F>>(&self, observations: &ArrayBase<D, Ix2>) -> Array2<F> {
        let (_, log_resp) = self.estimate_log_prob_resp(observations);
        log_resp.mapv(F::exp)
    }

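    // Compute per-cluster sufficient statistics from the responsibilities:
    // nk[k] = sum_i resp[i, k] (the "soft" number of samples assigned to cluster k),
    // means[k] = sum_i resp[i, k] * x_i / nk[k], and the full covariance matrices.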
    #[allow(clippy::type_complexity)]
    fn estimate_gaussian_parameters<D: Data<Elem = F>>(
        observations: &ArrayBase<D, Ix2>,
        resp: &Array2<F>,
        _covar_type: &GmmCovarType,
        reg_covar: F,
    ) -> Result<(Array1<F>, Array2<F>, Array3<F>), GmmError> {
        let nk = resp.sum_axis(Axis(0));
        if nk.min()? < &(F::cast(10.) * F::epsilon()) {
            return Err(GmmError::EmptyCluster(format!(
                "Cluster #{} has no more points. Consider decreasing the number of clusters or changing the initialization.",
                nk.argmin()? + 1
            )));
        }

        let nk2 = nk.to_owned().insert_axis(Axis(1));
        let means = resp.t().dot(observations) / nk2;
        // GmmCovarType = Full
        let covariances =
            Self::estimate_gaussian_covariances_full(observations, resp, &nk, &means, reg_covar);
        Ok((nk, means, covariances))
    }

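    // Weighted covariance of cluster k:
    // cov_k = sum_i resp[i, k] * (x_i - mean_k)(x_i - mean_k)^T / nk[k] + reg_covar * I
    // The `reg_covar` term added to the diagonal keeps the matrix positive definite.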
    fn estimate_gaussian_covariances_full<D: Data<Elem = F>>(
        observations: &ArrayBase<D, Ix2>,
        resp: &Array2<F>,
        nk: &Array1<F>,
        means: &Array2<F>,
        reg_covar: F,
    ) -> Array3<F> {
        let n_clusters = means.nrows();
        let n_features = means.ncols();
        let mut covariances = Array::zeros((n_clusters, n_features, n_features));
        for k in 0..n_clusters {
            let diff = observations - &means.row(k);
            let m = &diff.t() * &resp.index_axis(Axis(1), k);
            let mut cov_k = m.dot(&diff) / nk[k];
            cov_k.diag_mut().mapv_inplace(|x| x + reg_covar);
            covariances.slice_mut(s![k, .., ..]).assign(&cov_k);
        }
        covariances
    }

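    // With covariance Sigma_k = L_k L_k^T (Cholesky decomposition), the Cholesky factor of
    // the precision Sigma_k^{-1} is (L_k^{-1})^T; L_k^{-1} is obtained by solving L_k X = I.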
    fn compute_precisions_cholesky_full<D: Data<Elem = F>>(
        covariances: &ArrayBase<D, Ix3>,
    ) -> Result<Array3<F>, GmmError> {
        let n_clusters = covariances.shape()[0];
        let n_features = covariances.shape()[1];
        let mut precisions_chol = Array::zeros((n_clusters, n_features, n_features));
        for (k, covariance) in covariances.outer_iter().enumerate() {
            let sol = {
                let decomp = covariance.cholesky()?;
                decomp.solve_triangular_into(Array::eye(n_features), UPLO::Lower)?
            };

            precisions_chol.slice_mut(s![k, .., ..]).assign(&sol.t());
        }
        Ok(precisions_chol)
    }

    fn compute_precisions_full<D: Data<Elem = F>>(
        precisions_chol: &ArrayBase<D, Ix3>,
    ) -> Array3<F> {
        let mut precisions = Array3::zeros(precisions_chol.dim());
        for (k, prec_chol) in precisions_chol.outer_iter().enumerate() {
            precisions
                .slice_mut(s![k, .., ..])
                .assign(&prec_chol.dot(&prec_chol.t()));
        }
        precisions
    }

    // Refresh the precision matrices; this is only done at the end of the fitting procedure
    fn refresh_precisions_full(&mut self) {
        self.precisions = Self::compute_precisions_full(&self.precisions_chol);
    }

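    // Expectation step: compute the responsibilities and the mean log-likelihood
    // of the observations under the current parameters.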
    fn e_step<D: Data<Elem = F>>(
        &self,
        observations: &ArrayBase<D, Ix2>,
    ) -> Result<(F, Array2<F>), GmmError> {
        let (log_prob_norm, log_resp) = self.estimate_log_prob_resp(observations);
        let log_mean = log_prob_norm.mean().unwrap();
        Ok((log_mean, log_resp))
    }

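    // Maximization step: re-estimate the weights, means and covariances from the
    // responsibilities computed in the E-step, then refresh the precision Cholesky factors.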
    fn m_step<D: Data<Elem = F>>(
        &mut self,
        reg_covar: F,
        observations: &ArrayBase<D, Ix2>,
        log_resp: &Array2<F>,
    ) -> Result<(), GmmError> {
        let n_samples = observations.nrows();
        let (weights, means, covariances) = Self::estimate_gaussian_parameters(
            observations,
            &log_resp.mapv(|x| x.exp()),
            &self.covar_type,
            reg_covar,
        )?;
        self.means = means;
        self.weights = weights / F::cast(n_samples);
        self.covariances = covariances;
        // GmmCovarType = Full()
        self.precisions_chol = Self::compute_precisions_cholesky_full(&self.covariances)?;
        Ok(())
    }

    // We keep the method names and boundaries of the scikit-learn implementation, which
    // also handles Bayesian mixtures; hence the `_log_resp` argument below, which is not used here.
    fn compute_lower_bound<D: Data<Elem = F>>(
        _log_resp: &ArrayBase<D, Ix2>,
        log_prob_norm: F,
    ) -> F {
        log_prob_norm
    }

    // Estimate the log probabilities (ln P(X)) and responsibilities for each sample:
    // compute the weighted log probabilities per component and the responsibilities
    // for each sample in X with respect to the current state of the model.
    fn estimate_log_prob_resp<D: Data<Elem = F>>(
        &self,
        observations: &ArrayBase<D, Ix2>,
    ) -> (Array1<F>, Array2<F>) {
        let weighted_log_prob = self.estimate_weighted_log_prob(observations);
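        // log_prob_norm[i] = ln(sum_k exp(weighted_log_prob[i, k])) = ln P(x_i),
        // log_resp[i, k] = weighted_log_prob[i, k] - log_prob_norm[i] (log responsibilities).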
        let log_prob_norm = weighted_log_prob
            .mapv(|x| x.exp())
            .sum_axis(Axis(1))
            .mapv(|x| x.ln());
        let log_resp = weighted_log_prob - log_prob_norm.to_owned().insert_axis(Axis(1));
        (log_prob_norm, log_resp)
    }

    // Estimate the weighted log probabilities for each sample with respect to the model
    fn estimate_weighted_log_prob<D: Data<Elem = F>>(
        &self,
        observations: &ArrayBase<D, Ix2>,
    ) -> Array2<F> {
        self.estimate_log_prob(observations) + self.estimate_log_weights()
    }

    // Compute the log probabilities for each sample with respect to the model, which is Gaussian
    fn estimate_log_prob<D: Data<Elem = F>>(&self, observations: &ArrayBase<D, Ix2>) -> Array2<F> {
        self.estimate_log_gaussian_prob(observations)
    }

    // Compute the log-likelihood in the case of Gaussian probabilities:
    // ln P(X|Mean, Precision) = -0.5 * (d*ln(2*PI) - ln(det(Precision)) + (X-Mean)^t.Precision.(X-Mean))
    fn estimate_log_gaussian_prob<D: Data<Elem = F>>(
        &self,
        observations: &ArrayBase<D, Ix2>,
    ) -> Array2<F> {
        let n_samples = observations.nrows();
        let n_features = observations.ncols();
        let means = self.means();
        let n_clusters = means.nrows();
        // GmmCovarType = full
        // ln(det(precision_chol)) is half of ln(det(precision))
        let log_det = Self::compute_log_det_cholesky_full(&self.precisions_chol, n_features);
        let mut log_prob: Array2<F> = Array::zeros((n_samples, n_clusters));
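        // Mahalanobis term: with precision P_k = C_k C_k^T (precision Cholesky factor C_k),
        // (x - mu_k)^T P_k (x - mu_k) = ||(x - mu_k) C_k||^2, computed row by row below.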
        Zip::indexed(means.rows())
            .and(self.precisions_chol.outer_iter())
            .for_each(|k, mu, prec_chol| {
                let diff = (&observations.to_owned() - &mu).dot(&prec_chol);
                log_prob
                    .slice_mut(s![.., k])
                    .assign(&diff.mapv(|v| v * v).sum_axis(Axis(1)))
            });
        log_prob.mapv(|v| {
            F::cast(-0.5) * (v + F::cast(n_features as f64 * f64::ln(2. * std::f64::consts::PI)))
        }) + log_det
    }

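    // ln(det(C_k)) for a triangular Cholesky factor C_k is the sum of the logs of its
    // diagonal entries. Each (n_features, n_features) factor is flattened and every
    // (n_features + 1)-th element is taken to extract the diagonal.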
    fn compute_log_det_cholesky_full<D: Data<Elem = F>>(
        matrix_chol: &ArrayBase<D, Ix3>,
        n_features: usize,
    ) -> Array1<F> {
        let n_clusters = matrix_chol.shape()[0];
        let log_diags = &matrix_chol
            .to_owned()
            .into_shape((n_clusters, n_features * n_features))
            .unwrap()
            .slice(s![.., ..; n_features+1])
            .to_owned()
            .mapv(|x| x.ln());
        log_diags.sum_axis(Axis(1))
    }

    fn estimate_log_weights(&self) -> Array1<F> {
        self.weights().mapv(|x| x.ln())
    }
}

impl<F: Float, R: Rng + Clone, D: Data<Elem = F>, T> Fit<ArrayBase<D, Ix2>, T, GmmError>
    for GmmValidParams<F, R>
{
    type Object = GaussianMixtureModel<F>;

    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object, GmmError> {
        let observations = dataset.records().view();
        let mut gmm = GaussianMixtureModel::<F>::new(self, dataset, self.rng())?;

        let mut max_lower_bound = -F::infinity();
        let mut best_params = None;
        let mut best_iter = None;

        let n_runs = self.n_runs();

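        // Run the EM algorithm `n_runs` times and keep the parameters achieving the best
        // lower bound (the mean log-likelihood of the dataset).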
        for _ in 0..n_runs {
            let mut lower_bound = -F::infinity();

            let mut converged_iter: Option<u64> = None;
            for n_iter in 0..self.max_n_iterations() {
                let prev_lower_bound = lower_bound;
                let (log_prob_norm, log_resp) = gmm.e_step(&observations)?;
                gmm.m_step(self.reg_covariance(), &observations, &log_resp)?;
                lower_bound =
                    GaussianMixtureModel::<F>::compute_lower_bound(&log_resp, log_prob_norm);
                let change = lower_bound - prev_lower_bound;
                if change.abs() < self.tolerance() {
                    converged_iter = Some(n_iter);
                    break;
                }
            }

            if lower_bound > max_lower_bound {
                max_lower_bound = lower_bound;
                gmm.refresh_precisions_full();
                best_params = Some(gmm.clone());
                best_iter = converged_iter;
            }
        }

        match best_iter {
            Some(_n_iter) => match best_params {
                Some(gmm) => Ok(gmm),
                _ => Err(GmmError::LowerBoundError(
                    "No lower bound improvement (-inf)".to_string(),
                )),
            },
            None => Err(GmmError::NotConverged(format!(
                "EM fitting algorithm {} did not converge. Try different init parameters, \
                            increase max_n_iterations or tolerance, or check for degenerate data.",
                (n_runs + 1)
            ))),
        }
    }
}

impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<usize>>
    for GaussianMixtureModel<F>
{
    fn predict_inplace(&self, observations: &ArrayBase<D, Ix2>, targets: &mut Array1<usize>) {
        assert_eq!(
            observations.nrows(),
            targets.len(),
            "The number of data points must match the number of output targets."
        );

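        // Assign each observation to the component with the highest responsibility.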
        let (_, log_resp) = self.estimate_log_prob_resp(observations);
        *targets = log_resp
            .mapv(F::exp)
            .map_axis(Axis(1), |row| row.argmax().unwrap());
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<usize> {
        Array1::zeros(x.nrows())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::{abs_diff_eq, assert_abs_diff_eq};
    use linfa_datasets::generate;
    use linfa_linalg::LinalgError;
    use linfa_linalg::Result as LAResult;
    use ndarray::Array;
    use ndarray::{array, concatenate, ArrayView1, ArrayView2, Axis};
    use ndarray_rand::rand::prelude::ThreadRng;
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::Normal;
    use ndarray_rand::rand_distr::{Distribution, StandardNormal};
    use ndarray_rand::RandomExt;
    use rand_xoshiro::Xoshiro256Plus;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<GaussianMixtureModel<f64>>();
        has_autotraits::<GmmError>();
        has_autotraits::<GmmParams<f64, Xoshiro256Plus>>();
        has_autotraits::<GmmValidParams<f64, Xoshiro256Plus>>();
        has_autotraits::<GmmInitMethod>();
        has_autotraits::<GmmCovarType>();
    }

    pub struct MultivariateNormal {
        mean: Array1<f64>,
        /// Lower triangular matrix (Cholesky decomposition of the covariance matrix)
        lower: Array2<f64>,
    }
    impl MultivariateNormal {
        pub fn new(mean: &ArrayView1<f64>, covariance: &ArrayView2<f64>) -> LAResult<Self> {
            let lower = covariance.cholesky()?;
            Ok(MultivariateNormal {
                mean: mean.to_owned(),
                lower,
            })
        }
    }
    impl Distribution<Array1<f64>> for MultivariateNormal {
        fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Array1<f64> {
            // standard normal distribution
            let res = Array1::random_using(self.mean.shape()[0], StandardNormal, rng);
            // use Cholesky decomposition to obtain a sample of our general multivariate normal
            self.mean.clone() + self.lower.view().dot(&res)
        }
    }

    #[test]
    fn test_gmm_fit() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let weights = array![0.5, 0.5];
        let means = array![[0., 0.], [5., 5.]];
        let covars = array![[[1., 0.8], [0.8, 1.]], [[1.0, -0.6], [-0.6, 1.0]]];
        let mvn1 =
            MultivariateNormal::new(&means.slice(s![0, ..]), &covars.slice(s![0, .., ..])).unwrap();
        let mvn2 =
            MultivariateNormal::new(&means.slice(s![1, ..]), &covars.slice(s![1, .., ..])).unwrap();

        let n = 500;
        let mut observations = Array2::zeros((2 * n, means.ncols()));
        for (i, mut row) in observations.rows_mut().into_iter().enumerate() {
            let sample = if i < n {
                mvn1.sample(&mut rng)
            } else {
                mvn2.sample(&mut rng)
            };
            row.assign(&sample);
        }
        let dataset = DatasetBase::from(observations);
        let gmm = GaussianMixtureModel::params(2)
            .with_rng(rng)
            .fit(&dataset)
            .expect("GMM fitting");

        // check weights
        let w = gmm.weights();
        assert_abs_diff_eq!(w, &weights, epsilon = 1e-1);
        // check means (since kmeans centroids are ordered randomly, we try matching both orderings)
        let m = gmm.means();
        assert!(
            abs_diff_eq!(means, &m, epsilon = 1e-1)
                || abs_diff_eq!(means, m.slice(s![..;-1, ..]), epsilon = 1e-1)
        );
        // check covariances
        let c = gmm.covariances();
        assert!(
            abs_diff_eq!(covars, &c, epsilon = 1e-1)
                || abs_diff_eq!(covars, c.slice(s![..;-1, .., ..]), epsilon = 1e-1)
        );
    }

    #[test]
    fn test_gmm_covariances() {
        let rng = rand_xoshiro::Xoshiro256Plus::seed_from_u64(123);

        let data_0 = ndarray::Array::random((500,), Normal::new(0., 0.5).unwrap());
        let data_1 = ndarray::Array::random((500,), Normal::new(1., 0.5).unwrap());
        let data_2 = ndarray::Array::random((500,), Normal::new(2., 0.5).unwrap());
        let data = ndarray::concatenate![ndarray::Axis(0), data_0, data_1, data_2];

        let data_2d = data.insert_axis(ndarray::Axis(1)).to_owned();
        let dataset = linfa::DatasetBase::from(data_2d);

        let gmm = GaussianMixtureModel::params(3)
            .n_runs(1)
            .tolerance(1e-4)
            .with_rng(rng)
            .max_n_iterations(500)
            .fit(&dataset)
            .expect("GMM fit");

        // expected results from scikit-learn 1.3.1
        let expected = array![[[0.22564062]], [[0.26204446]], [[0.23393885]]];
        let expected = Array::from_iter(expected.iter().cloned());
        let actual = gmm.covariances();
        let actual = Array::from_iter(actual.iter().cloned());
        assert_abs_diff_eq!(expected, actual, epsilon = 1e-1);
    }

    fn function_test_1d(x: &Array2<f64>) -> Array2<f64> {
        let mut y = Array2::zeros(x.dim());
        Zip::from(&mut y).and(x).for_each(|yi, &xi| {
            if xi < 0.4 {
                *yi = xi * xi;
            } else if (0.4..0.8).contains(&xi) {
                *yi = 10. * xi + 1.;
            } else {
                *yi = f64::sin(10. * xi);
            }
        });
        y
    }

    #[test]
    fn test_zeroed_reg_covar_failure() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let xt = Array2::random_using((50, 1), Uniform::new(0., 1.0), &mut rng);
        let yt = function_test_1d(&xt);
        let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
        let dataset = DatasetBase::from(data);

        // Test that cholesky decomposition fails when reg_covariance is zero
        let gmm = GaussianMixtureModel::params(3)
            .reg_covariance(0.)
            .with_rng(rng.clone())
            .fit(&dataset);

        match gmm.expect_err("should generate an error with reg_covar being zero") {
            GmmError::LinalgError(e) => {
                assert!(matches!(e, LinalgError::NotPositiveDefinite));
            }
            e => panic!("should be a linear algebra error: {:?}", e),
        }
        // Test it passes when default value is used
        assert!(GaussianMixtureModel::params(3)
            .with_rng(rng)
            .fit(&dataset)
            .is_ok());
    }

    #[test]
    fn test_zeroed_reg_covar_const_failure() {
        // repeat values such that covariance is zero
        let xt = Array2::ones((50, 1));
        let data = concatenate(Axis(1), &[xt.view(), xt.view()]).unwrap();
        let dataset = DatasetBase::from(data);

        // Test that cholesky decomposition fails when reg_covariance is zero
        let gmm = GaussianMixtureModel::params(1)
            .reg_covariance(0.)
            .fit(&dataset);

        gmm.expect_err("should generate an error with reg_covar being zero");

        // Test it passes when default value is used
        assert!(GaussianMixtureModel::params(1).fit(&dataset).is_ok());
    }

    #[test]
    fn test_centroids_prediction() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let expected_centroids = array![[0., 1.], [-10., 20.], [-1., 10.]];
        let n = 1000;
        let blobs = DatasetBase::from(generate::blobs(n, &expected_centroids, &mut rng));

        let n_clusters = expected_centroids.len_of(Axis(0));
        let gmm = GaussianMixtureModel::params(n_clusters)
            .with_rng(rng)
            .fit(&blobs)
            .expect("GMM fitting");

        let gmm_centroids = gmm.centroids();
        let memberships = gmm.predict(&expected_centroids);

        // check that the centroids used to generate the test dataset belong to the right predicted cluster
        for (i, expected_c) in expected_centroids.outer_iter().enumerate() {
            let closest_c = gmm_centroids.index_axis(Axis(0), memberships[i]);
            Zip::from(&closest_c)
                .and(&expected_c)
                .for_each(|a, b| assert_abs_diff_eq!(a, b, epsilon = 1.))
        }
    }

    #[test]
    fn test_invalid_n_runs() {
        assert!(
            GaussianMixtureModel::params(1)
                .n_runs(0)
                .fit(&DatasetBase::from(array![[0.]]))
                .is_err(),
            "n_runs must be strictly positive"
        );
    }

    #[test]
    fn test_invalid_tolerance() {
        assert!(
            GaussianMixtureModel::params(1)
                .tolerance(0.)
                .fit(&DatasetBase::from(array![[0.]]))
                .is_err(),
            "tolerance must be strictly positive"
        );
    }

    #[test]
    fn test_invalid_n_clusters() {
        assert!(
            GaussianMixtureModel::params(0)
                .fit(&DatasetBase::from(array![[0., 0.]]))
                .is_err(),
            "n_clusters must be strictly positive"
        );
    }

    #[test]
    fn test_invalid_reg_covariance() {
        assert!(
            GaussianMixtureModel::params(1)
                .reg_covariance(-1e-6)
                .fit(&DatasetBase::from(array![[0.]]))
                .is_err(),
            "reg_covariance must be positive"
        );
    }

    #[test]
    fn test_invalid_max_n_iterations() {
        assert!(
            GaussianMixtureModel::params(1)
                .max_n_iterations(0)
                .fit(&DatasetBase::from(array![[0.]]))
                .is_err(),
            "max_n_iterations must be strictly positive"
        );
    }

    fn fittable<T: Fit<Array2<f64>, (), GmmError>>(_: T) {}
    #[test]
    fn thread_rng_fittable() {
        fittable(GaussianMixtureModel::params_with_rng(
            1,
            ThreadRng::default(),
        ));
    }

    #[test]
    fn test_predict_proba() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let centroids = array![[0.0, 1.0], [-10.0, 20.0], [-1.0, 10.0]];
        let n_samples_per_cluster = 1000;
        let dataset =
            DatasetBase::from(generate::blobs(n_samples_per_cluster, &centroids, &mut rng));
        let n_clusters = centroids.len_of(Axis(0));
        let n_samples = n_samples_per_cluster * n_clusters;

        let gmm = GaussianMixtureModel::params(n_clusters)
            .with_rng(rng)
            .fit(&dataset)
            .expect("Failed to fit GMM");

        let proba = gmm.predict_proba(dataset.records());

        assert_eq!(proba.dim(), (n_samples, n_clusters));

        let row_sums = proba.sum_axis(Axis(1));
        let ones = ndarray::Array1::ones(n_samples);
        assert_abs_diff_eq!(row_sums, ones, epsilon = 1e-6);
    }
781}