use crate::gaussian_mixture::errors::GmmError;
use crate::gaussian_mixture::hyperparams::{
    GmmCovarType, GmmInitMethod, GmmParams, GmmValidParams,
};
use crate::k_means::KMeans;
use linfa::{prelude::*, DatasetBase, Float};
use linfa_linalg::{cholesky::*, triangular::*};
use ndarray::{s, Array, Array1, Array2, Array3, ArrayBase, Axis, Data, Ix2, Ix3, Zip};
use ndarray_rand::rand::Rng;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;
use ndarray_stats::QuantileExt;
use rand_xoshiro::Xoshiro256Plus;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

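/// Gaussian mixture model (GMM) fitted by expectation-maximization (EM).
///
/// A fitted model stores, for each mixture component, its weight, mean and
/// full covariance matrix, together with the precision matrices and their
/// Cholesky factors used to evaluate Gaussian log-densities efficiently.
///
/// A minimal usage sketch, mirroring the tests at the end of this file
/// (`observations` is assumed to be an `Array2<f64>` of samples, `rng` a
/// seeded random number generator):
///
/// ```ignore
/// let dataset = DatasetBase::from(observations);
/// let gmm = GaussianMixtureModel::params(2)
///     .with_rng(rng)
///     .fit(&dataset)
///     .expect("GMM fitting");
/// let memberships = gmm.predict(dataset.records());
/// ```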
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, PartialEq)]
pub struct GaussianMixtureModel<F: Float> {
    covar_type: GmmCovarType,
    weights: Array1<F>,
    means: Array2<F>,
    covariances: Array3<F>,
    precisions: Array3<F>,
    precisions_chol: Array3<F>,
}

impl<F: Float> Clone for GaussianMixtureModel<F> {
    fn clone(&self) -> Self {
        Self {
            covar_type: self.covar_type,
            weights: self.weights.to_owned(),
            means: self.means.to_owned(),
            covariances: self.covariances.to_owned(),
            precisions: self.precisions.to_owned(),
            precisions_chol: self.precisions_chol.to_owned(),
        }
    }
}

impl<F: Float> GaussianMixtureModel<F> {
    fn new<D: Data<Elem = F>, R: Rng + Clone, T>(
        hyperparameters: &GmmValidParams<F, R>,
        dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
        mut rng: R,
    ) -> Result<GaussianMixtureModel<F>, GmmError> {
        let observations = dataset.records().view();
        let n_samples = observations.nrows();

        let resp = match hyperparameters.init_method() {
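            // Build the initial (n_samples, n_clusters) responsibility matrix,
            // either from a k-means clustering (one-hot rows) or from
            // row-normalized random values.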
            GmmInitMethod::KMeans => {
                let model = KMeans::params_with_rng(hyperparameters.n_clusters(), rng)
                    .check()
                    .unwrap()
                    .fit(dataset)?;
                let mut resp = Array::<F, Ix2>::zeros((n_samples, hyperparameters.n_clusters()));
                // One-hot encoding: sample `k` is fully assigned to its predicted cluster `idx`.
                for (k, idx) in model.predict(dataset.records()).iter().enumerate() {
                    resp[[k, *idx]] = F::cast(1.);
                }
                resp
            }
            GmmInitMethod::Random => {
                let mut resp = Array2::<f64>::random_using(
                    (n_samples, hyperparameters.n_clusters()),
                    Uniform::new(0., 1.),
                    &mut rng,
                );
                // Normalize each row so the responsibilities sum to one per sample.
                let totals = &resp.sum_axis(Axis(1)).insert_axis(Axis(0));
                resp = (resp.reversed_axes() / totals).reversed_axes();
                resp.mapv(F::cast)
            }
        };

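        // Initial M-step: estimate weights, means and covariances from the
        // initial responsibilities.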
        let (mut weights, means, covariances) = Self::estimate_gaussian_parameters(
            &observations,
            &resp,
            hyperparameters.covariance_type(),
            hyperparameters.reg_covariance(),
        )?;
        weights /= F::cast(n_samples);

        let precisions_chol = Self::compute_precisions_cholesky_full(&covariances)?;
        let precisions = Self::compute_precisions_full(&precisions_chol);

        Ok(GaussianMixtureModel {
            covar_type: *hyperparameters.covariance_type(),
            weights,
            means,
            covariances,
            precisions,
            precisions_chol,
        })
    }
}

impl<F: Float> GaussianMixtureModel<F> {
    pub fn params(n_clusters: usize) -> GmmParams<F, Xoshiro256Plus> {
        GmmParams::new(n_clusters)
    }

    pub fn params_with_rng<R: Rng + Clone>(n_clusters: usize, rng: R) -> GmmParams<F, R> {
        GmmParams::new_with_rng(n_clusters, rng)
    }

    pub fn weights(&self) -> &Array1<F> {
        &self.weights
    }

    pub fn means(&self) -> &Array2<F> {
        &self.means
    }

    pub fn covariances(&self) -> &Array3<F> {
        &self.covariances
    }

    pub fn precisions(&self) -> &Array3<F> {
        &self.precisions
    }

    pub fn centroids(&self) -> &Array2<F> {
        self.means()
    }
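
    /// Returns, for each sample, the posterior probability of belonging to
    /// each mixture component (the responsibilities); the result has shape
    /// `(n_samples, n_clusters)` and its rows sum to one.
    ///
    /// A short sketch of the expected output, assuming a fitted `gmm`:
    ///
    /// ```ignore
    /// let proba = gmm.predict_proba(dataset.records()); // (n_samples, n_clusters)
    /// ```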
    pub fn predict_proba<D: Data<Elem = F>>(
        &self,
        observations: &ArrayBase<D, Ix2>,
    ) -> Array2<F> {
        let (_, log_resp) = self.estimate_log_prob_resp(observations);
        log_resp.mapv(F::exp)
    }

    #[allow(clippy::type_complexity)]
    fn estimate_gaussian_parameters<D: Data<Elem = F>>(
        observations: &ArrayBase<D, Ix2>,
        resp: &Array2<F>,
        _covar_type: &GmmCovarType,
        reg_covar: F,
    ) -> Result<(Array1<F>, Array2<F>, Array3<F>), GmmError> {
        let nk = resp.sum_axis(Axis(0));
        if nk.min()? < &(F::cast(10.) * F::epsilon()) {
            return Err(GmmError::EmptyCluster(format!(
                "Cluster #{} has no more points. Consider decreasing the number of clusters or changing the initialization.",
                nk.argmin()? + 1
            )));
        }

        let nk2 = nk.to_owned().insert_axis(Axis(1));
        let means = resp.t().dot(observations) / nk2;
        let covariances =
            Self::estimate_gaussian_covariances_full(observations, resp, &nk, &means, reg_covar);
        Ok((nk, means, covariances))
    }

    fn estimate_gaussian_covariances_full<D: Data<Elem = F>>(
        observations: &ArrayBase<D, Ix2>,
        resp: &Array2<F>,
        nk: &Array1<F>,
        means: &Array2<F>,
        reg_covar: F,
    ) -> Array3<F> {
        let n_clusters = means.nrows();
        let n_features = means.ncols();
        let mut covariances = Array::zeros((n_clusters, n_features, n_features));
        for k in 0..n_clusters {
            let diff = observations - &means.row(k);
            let m = &diff.t() * &resp.index_axis(Axis(1), k);
            let mut cov_k = m.dot(&diff) / nk[k];
            cov_k.diag_mut().mapv_inplace(|x| x + reg_covar);
            covariances.slice_mut(s![k, .., ..]).assign(&cov_k);
        }
        covariances
    }

    fn compute_precisions_cholesky_full<D: Data<Elem = F>>(
        covariances: &ArrayBase<D, Ix3>,
    ) -> Result<Array3<F>, GmmError> {
        let n_clusters = covariances.shape()[0];
        let n_features = covariances.shape()[1];
        let mut precisions_chol = Array::zeros((n_clusters, n_features, n_features));
        for (k, covariance) in covariances.outer_iter().enumerate() {
            let sol = {
                let decomp = covariance.cholesky()?;
                decomp.solve_triangular_into(Array::eye(n_features), UPLO::Lower)?
            };

            precisions_chol.slice_mut(s![k, .., ..]).assign(&sol.t());
        }
        Ok(precisions_chol)
    }

    fn compute_precisions_full<D: Data<Elem = F>>(
        precisions_chol: &ArrayBase<D, Ix3>,
    ) -> Array3<F> {
        let mut precisions = Array3::zeros(precisions_chol.dim());
        for (k, prec_chol) in precisions_chol.outer_iter().enumerate() {
            precisions
                .slice_mut(s![k, .., ..])
                .assign(&prec_chol.dot(&prec_chol.t()));
        }
        precisions
    }

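    // Recompute the full precision matrices from their Cholesky factors; in
    // `fit` this is only done when a new best parameter set is saved.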
    fn refresh_precisions_full(&mut self) {
        self.precisions = Self::compute_precisions_full(&self.precisions_chol);
    }

    fn e_step<D: Data<Elem = F>>(
        &self,
        observations: &ArrayBase<D, Ix2>,
    ) -> Result<(F, Array2<F>), GmmError> {
        let (log_prob_norm, log_resp) = self.estimate_log_prob_resp(observations);
        let log_mean = log_prob_norm.mean().unwrap();
        Ok((log_mean, log_resp))
    }

    fn m_step<D: Data<Elem = F>>(
        &mut self,
        reg_covar: F,
        observations: &ArrayBase<D, Ix2>,
        log_resp: &Array2<F>,
    ) -> Result<(), GmmError> {
        let n_samples = observations.nrows();
        let (weights, means, covariances) = Self::estimate_gaussian_parameters(
            observations,
            &log_resp.mapv(|x| x.exp()),
            &self.covar_type,
            reg_covar,
        )?;
        self.means = means;
        self.weights = weights / F::cast(n_samples);
        self.covariances = covariances;
        self.precisions_chol = Self::compute_precisions_cholesky_full(&self.covariances)?;
        Ok(())
    }

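    // With full covariances the lower bound reduces to the mean per-sample
    // log-probability norm already computed in the E-step.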
    fn compute_lower_bound<D: Data<Elem = F>>(
        _log_resp: &ArrayBase<D, Ix2>,
        log_prob_norm: F,
    ) -> F {
        log_prob_norm
    }

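    // Estimate, for each sample, the log normalization term log p(x) and the
    // log responsibilities log p(k | x) under the current parameters.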
    fn estimate_log_prob_resp<D: Data<Elem = F>>(
        &self,
        observations: &ArrayBase<D, Ix2>,
    ) -> (Array1<F>, Array2<F>) {
        let weighted_log_prob = self.estimate_weighted_log_prob(observations);
        let log_prob_norm = weighted_log_prob
            .mapv(|x| x.exp())
            .sum_axis(Axis(1))
            .mapv(|x| x.ln());
        let log_resp = weighted_log_prob - log_prob_norm.to_owned().insert_axis(Axis(1));
        (log_prob_norm, log_resp)
    }

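    // log(π_k · p(x | θ_k)) = log p(x | θ_k) + log π_k for each component k.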
    fn estimate_weighted_log_prob<D: Data<Elem = F>>(
        &self,
        observations: &ArrayBase<D, Ix2>,
    ) -> Array2<F> {
        self.estimate_log_prob(observations) + self.estimate_log_weights()
    }

    fn estimate_log_prob<D: Data<Elem = F>>(&self, observations: &ArrayBase<D, Ix2>) -> Array2<F> {
        self.estimate_log_gaussian_prob(observations)
    }

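    // Gaussian log-density evaluated through the precision Cholesky factor P_k:
    //   log p(x | μ_k, Σ_k) = -0.5 · (d·ln(2π) + ‖(x - μ_k)ᵀ P_k‖²) + log det(P_k)
    // where d is the number of features and log det(P_k) is the sum of the
    // logs of the diagonal entries of P_k.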
    fn estimate_log_gaussian_prob<D: Data<Elem = F>>(
        &self,
        observations: &ArrayBase<D, Ix2>,
    ) -> Array2<F> {
        let n_samples = observations.nrows();
        let n_features = observations.ncols();
        let means = self.means();
        let n_clusters = means.nrows();
        let log_det = Self::compute_log_det_cholesky_full(&self.precisions_chol, n_features);
        let mut log_prob: Array2<F> = Array::zeros((n_samples, n_clusters));
        Zip::indexed(means.rows())
            .and(self.precisions_chol.outer_iter())
            .for_each(|k, mu, prec_chol| {
                let diff = (&observations.to_owned() - &mu).dot(&prec_chol);
                log_prob
                    .slice_mut(s![.., k])
                    .assign(&diff.mapv(|v| v * v).sum_axis(Axis(1)))
            });
        log_prob.mapv(|v| {
            F::cast(-0.5) * (v + F::cast(n_features as f64 * f64::ln(2. * std::f64::consts::PI)))
        }) + log_det
    }

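    // The log-determinant of each triangular factor is the sum of the logs of
    // its diagonal entries; flattening each (d, d) matrix lets the diagonal be
    // read off with a strided slice of step d + 1.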
    fn compute_log_det_cholesky_full<D: Data<Elem = F>>(
        matrix_chol: &ArrayBase<D, Ix3>,
        n_features: usize,
    ) -> Array1<F> {
        let n_clusters = matrix_chol.shape()[0];
        let log_diags = &matrix_chol
            .to_owned()
            .into_shape((n_clusters, n_features * n_features))
            .unwrap()
            .slice(s![.., ..; n_features + 1])
            .to_owned()
            .mapv(|x| x.ln());
        log_diags.sum_axis(Axis(1))
    }

    fn estimate_log_weights(&self) -> Array1<F> {
        self.weights().mapv(|x| x.ln())
    }
}

impl<F: Float, R: Rng + Clone, D: Data<Elem = F>, T> Fit<ArrayBase<D, Ix2>, T, GmmError>
    for GmmValidParams<F, R>
{
    type Object = GaussianMixtureModel<F>;

    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object, GmmError> {
        let observations = dataset.records().view();
        let mut gmm = GaussianMixtureModel::<F>::new(self, dataset, self.rng())?;

        let mut max_lower_bound = -F::infinity();
        let mut best_params = None;
        let mut best_iter = None;

        let n_runs = self.n_runs();

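        // Run EM `n_runs` times and keep the parameters achieving the highest
        // lower bound.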
        for _ in 0..n_runs {
            let mut lower_bound = -F::infinity();

            let mut converged_iter: Option<u64> = None;
            for n_iter in 0..self.max_n_iterations() {
                let prev_lower_bound = lower_bound;
                let (log_prob_norm, log_resp) = gmm.e_step(&observations)?;
                gmm.m_step(self.reg_covariance(), &observations, &log_resp)?;
                lower_bound =
                    GaussianMixtureModel::<F>::compute_lower_bound(&log_resp, log_prob_norm);
                let change = lower_bound - prev_lower_bound;
                if change.abs() < self.tolerance() {
                    converged_iter = Some(n_iter);
                    break;
                }
            }

            if lower_bound > max_lower_bound {
                max_lower_bound = lower_bound;
                gmm.refresh_precisions_full();
                best_params = Some(gmm.clone());
                best_iter = converged_iter;
            }
        }

        match best_iter {
            Some(_n_iter) => match best_params {
                Some(gmm) => Ok(gmm),
                _ => Err(GmmError::LowerBoundError(
                    "No lower bound improvement (-inf)".to_string(),
                )),
            },
            None => Err(GmmError::NotConverged(format!(
                "EM fitting did not converge within {} run(s). Try different init parameters, \
                or increase max_n_iterations or tolerance, or check for degenerate data.",
                n_runs
            ))),
        }
    }
}

impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<usize>>
    for GaussianMixtureModel<F>
{
    fn predict_inplace(&self, observations: &ArrayBase<D, Ix2>, targets: &mut Array1<usize>) {
        assert_eq!(
            observations.nrows(),
            targets.len(),
            "The number of data points must match the number of output targets."
        );

        let (_, log_resp) = self.estimate_log_prob_resp(observations);
        *targets = log_resp
            .mapv(F::exp)
            .map_axis(Axis(1), |row| row.argmax().unwrap());
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<usize> {
        Array1::zeros(x.nrows())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::{abs_diff_eq, assert_abs_diff_eq};
    use linfa_datasets::generate;
    use linfa_linalg::LinalgError;
    use linfa_linalg::Result as LAResult;
    use ndarray::Array;
    use ndarray::{array, concatenate, ArrayView1, ArrayView2, Axis};
    use ndarray_rand::rand::prelude::ThreadRng;
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::Normal;
    use ndarray_rand::rand_distr::{Distribution, StandardNormal};
    use ndarray_rand::RandomExt;
    use rand_xoshiro::Xoshiro256Plus;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<GaussianMixtureModel<f64>>();
        has_autotraits::<GmmError>();
        has_autotraits::<GmmParams<f64, Xoshiro256Plus>>();
        has_autotraits::<GmmValidParams<f64, Xoshiro256Plus>>();
        has_autotraits::<GmmInitMethod>();
        has_autotraits::<GmmCovarType>();
    }

    pub struct MultivariateNormal {
        mean: Array1<f64>,
        // Lower-triangular Cholesky factor of the covariance, kept for sampling.
        lower: Array2<f64>,
    }
    impl MultivariateNormal {
        pub fn new(mean: &ArrayView1<f64>, covariance: &ArrayView2<f64>) -> LAResult<Self> {
            let lower = covariance.cholesky()?;
            Ok(MultivariateNormal {
                mean: mean.to_owned(),
                lower,
            })
        }
    }
    impl Distribution<Array1<f64>> for MultivariateNormal {
        fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Array1<f64> {
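            // Draw z ~ N(0, I); then mean + L·z has covariance L·Lᵀ.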
            let res = Array1::random_using(self.mean.shape()[0], StandardNormal, rng);
            self.mean.clone() + self.lower.view().dot(&res)
        }
    }

    #[test]
    fn test_gmm_fit() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let weights = array![0.5, 0.5];
        let means = array![[0., 0.], [5., 5.]];
        let covars = array![[[1., 0.8], [0.8, 1.]], [[1.0, -0.6], [-0.6, 1.0]]];
        let mvn1 =
            MultivariateNormal::new(&means.slice(s![0, ..]), &covars.slice(s![0, .., ..])).unwrap();
        let mvn2 =
            MultivariateNormal::new(&means.slice(s![1, ..]), &covars.slice(s![1, .., ..])).unwrap();

        let n = 500;
        let mut observations = Array2::zeros((2 * n, means.ncols()));
        for (i, mut row) in observations.rows_mut().into_iter().enumerate() {
            let sample = if i < n {
                mvn1.sample(&mut rng)
            } else {
                mvn2.sample(&mut rng)
            };
            row.assign(&sample);
        }
        let dataset = DatasetBase::from(observations);
        let gmm = GaussianMixtureModel::params(2)
            .with_rng(rng)
            .fit(&dataset)
            .expect("GMM fitting");

        // The fitted parameters should match the generating ones, up to the
        // (arbitrary) ordering of the two clusters.
        let w = gmm.weights();
        assert_abs_diff_eq!(w, &weights, epsilon = 1e-1);
        let m = gmm.means();
        assert!(
            abs_diff_eq!(means, &m, epsilon = 1e-1)
                || abs_diff_eq!(means, m.slice(s![..;-1, ..]), epsilon = 1e-1)
        );
        let c = gmm.covariances();
        assert!(
            abs_diff_eq!(covars, &c, epsilon = 1e-1)
                || abs_diff_eq!(covars, c.slice(s![..;-1, .., ..]), epsilon = 1e-1)
        );
    }

    #[test]
    fn test_gmm_covariances() {
        let rng = rand_xoshiro::Xoshiro256Plus::seed_from_u64(123);

        let data_0 = ndarray::Array::random((500,), Normal::new(0., 0.5).unwrap());
        let data_1 = ndarray::Array::random((500,), Normal::new(1., 0.5).unwrap());
        let data_2 = ndarray::Array::random((500,), Normal::new(2., 0.5).unwrap());
        let data = ndarray::concatenate![ndarray::Axis(0), data_0, data_1, data_2];

        let data_2d = data.insert_axis(ndarray::Axis(1)).to_owned();
        let dataset = linfa::DatasetBase::from(data_2d);

        let gmm = GaussianMixtureModel::params(3)
            .n_runs(1)
            .tolerance(1e-4)
            .with_rng(rng)
            .max_n_iterations(500)
            .fit(&dataset)
            .expect("GMM fit");

        let expected = array![[[0.22564062]], [[0.26204446]], [[0.23393885]]];
        let expected = Array::from_iter(expected.iter().cloned());
        let actual = gmm.covariances();
        let actual = Array::from_iter(actual.iter().cloned());
        assert_abs_diff_eq!(expected, actual, epsilon = 1e-1);
    }

    fn function_test_1d(x: &Array2<f64>) -> Array2<f64> {
        let mut y = Array2::zeros(x.dim());
        Zip::from(&mut y).and(x).for_each(|yi, &xi| {
            if xi < 0.4 {
                *yi = xi * xi;
            } else if (0.4..0.8).contains(&xi) {
                *yi = 10. * xi + 1.;
            } else {
                *yi = f64::sin(10. * xi);
            }
        });
        y
    }

    #[test]
    fn test_zeroed_reg_covar_failure() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let xt = Array2::random_using((50, 1), Uniform::new(0., 1.0), &mut rng);
        let yt = function_test_1d(&xt);
        let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
        let dataset = DatasetBase::from(data);

        // Fitting with zero covariance regularization should fail on a
        // non-positive-definite covariance.
        let gmm = GaussianMixtureModel::params(3)
            .reg_covariance(0.)
            .with_rng(rng.clone())
            .fit(&dataset);

        match gmm.expect_err("should generate an error with reg_covar being zero") {
            GmmError::LinalgError(e) => {
                assert!(matches!(e, LinalgError::NotPositiveDefinite));
            }
            e => panic!("should be a linear algebra error: {:?}", e),
        }
        // The same fit succeeds with the default regularization.
        assert!(GaussianMixtureModel::params(3)
            .with_rng(rng)
            .fit(&dataset)
            .is_ok());
    }

    #[test]
    fn test_zeroed_reg_covar_const_failure() {
        // All the points are identical, so the covariance is zero.
        let xt = Array2::ones((50, 1));
        let data = concatenate(Axis(1), &[xt.view(), xt.view()]).unwrap();
        let dataset = DatasetBase::from(data);

        let gmm = GaussianMixtureModel::params(1)
            .reg_covariance(0.)
            .fit(&dataset);

        gmm.expect_err("should generate an error with reg_covar being zero");

        // The same fit succeeds with the default regularization.
        assert!(GaussianMixtureModel::params(1).fit(&dataset).is_ok());
    }

    #[test]
    fn test_centroids_prediction() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let expected_centroids = array![[0., 1.], [-10., 20.], [-1., 10.]];
        let n = 1000;
        let blobs = DatasetBase::from(generate::blobs(n, &expected_centroids, &mut rng));

        let n_clusters = expected_centroids.len_of(Axis(0));
        let gmm = GaussianMixtureModel::params(n_clusters)
            .with_rng(rng)
            .fit(&blobs)
            .expect("GMM fitting");

        let gmm_centroids = gmm.centroids();
        let memberships = gmm.predict(&expected_centroids);

        // Each generating centroid should lie close to the centroid of the
        // component it is assigned to.
        for (i, expected_c) in expected_centroids.outer_iter().enumerate() {
            let closest_c = gmm_centroids.index_axis(Axis(0), memberships[i]);
            Zip::from(&closest_c)
                .and(&expected_c)
                .for_each(|a, b| assert_abs_diff_eq!(a, b, epsilon = 1.))
        }
    }

    #[test]
    fn test_invalid_n_runs() {
        assert!(
            GaussianMixtureModel::params(1)
                .n_runs(0)
                .fit(&DatasetBase::from(array![[0.]]))
                .is_err(),
            "n_runs must be strictly positive"
        );
    }

    #[test]
    fn test_invalid_tolerance() {
        assert!(
            GaussianMixtureModel::params(1)
                .tolerance(0.)
                .fit(&DatasetBase::from(array![[0.]]))
                .is_err(),
            "tolerance must be strictly positive"
        );
    }

    #[test]
    fn test_invalid_n_clusters() {
        assert!(
            GaussianMixtureModel::params(0)
                .fit(&DatasetBase::from(array![[0., 0.]]))
                .is_err(),
            "n_clusters must be strictly positive"
        );
    }

    #[test]
    fn test_invalid_reg_covariance() {
        assert!(
            GaussianMixtureModel::params(1)
                .reg_covariance(-1e-6)
                .fit(&DatasetBase::from(array![[0.]]))
                .is_err(),
            "reg_covariance must be positive"
        );
    }

    #[test]
    fn test_invalid_max_n_iterations() {
        assert!(
            GaussianMixtureModel::params(1)
                .max_n_iterations(0)
                .fit(&DatasetBase::from(array![[0.]]))
                .is_err(),
            "max_n_iterations must be strictly positive"
        );
    }

    fn fittable<T: Fit<Array2<f64>, (), GmmError>>(_: T) {}
    #[test]
    fn thread_rng_fittable() {
        fittable(GaussianMixtureModel::params_with_rng(
            1,
            ThreadRng::default(),
        ));
    }

    #[test]
    fn test_predict_proba() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let centroids = array![[0.0, 1.0], [-10.0, 20.0], [-1.0, 10.0]];
        let n_samples_per_cluster = 1000;
        let dataset =
            DatasetBase::from(generate::blobs(n_samples_per_cluster, &centroids, &mut rng));
        let n_clusters = centroids.len_of(Axis(0));
        let n_samples = n_samples_per_cluster * n_clusters;

        let gmm = GaussianMixtureModel::params(n_clusters)
            .with_rng(rng)
            .fit(&dataset)
            .expect("Failed to fit GMM");

        let proba = gmm.predict_proba(dataset.records());

        assert_eq!(proba.dim(), (n_samples, n_clusters));

        // The responsibilities of each sample should sum to one.
        let row_sums = proba.sum_axis(Axis(1));
        let ones = ndarray::Array1::ones(n_samples);
        assert_abs_diff_eq!(row_sums, ones, epsilon = 1e-6);
    }
}
781}