use linfa::dataset::{AsSingleTargets, DatasetBase, Labels};
use linfa::traits::{Fit, FitWith, PredictInplace};
use linfa::{Float, Label};
use ndarray::{Array1, ArrayBase, ArrayView2, Axis, Data, Ix2};
use ndarray_stats::QuantileExt;
use std::collections::HashMap;
use std::hash::Hash;

use crate::base_nb::{NaiveBayes, NaiveBayesValidParams};
use crate::error::{NaiveBayesError, Result};
use crate::filter;
use crate::hyperparams::{GaussianNbParams, GaussianNbValidParams};

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

impl<'a, F, L, D, T> NaiveBayesValidParams<'a, F, L, D, T> for GaussianNbValidParams<F, L>
where
    F: Float,
    L: Label + 'a,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
}

impl<F, L, D, T> Fit<ArrayBase<D, Ix2>, T, NaiveBayesError> for GaussianNbValidParams<F, L>
where
    F: Float,
    L: Label + Ord,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
    type Object = GaussianNb<F, L>;

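    // Fit the model to the full dataset by delegating to the incremental
    // `fit_with` implementation with no previous model.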
    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
        NaiveBayesValidParams::fit(self, dataset, None)
    }
}

impl<'a, F, L, D, T> FitWith<'a, ArrayBase<D, Ix2>, T, NaiveBayesError>
    for GaussianNbValidParams<F, L>
where
    F: Float,
    L: Label + 'a,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
    type ObjectIn = Option<GaussianNb<F, L>>;
    type ObjectOut = GaussianNb<F, L>;

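    // Update an existing model (or fit a fresh one when `model_in` is `None`)
    // with an additional batch of observations, enabling incremental learning.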
    fn fit_with(
        &self,
        model_in: Self::ObjectIn,
        dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
    ) -> Result<Self::ObjectOut> {
        let x = dataset.records();
        let y = dataset.as_single_targets();

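        // Variance-smoothing term: a configurable fraction of the largest
        // per-feature variance, added to every class variance for numerical
        // stability.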
        let epsilon = self.var_smoothing() * *x.var_axis(Axis(0), F::zero()).max()?;

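        // When continuing from a previous model, subtract the smoothing term
        // that was added at the end of the last fit so the stored variances
        // are raw again.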
        let mut model = match model_in {
            Some(mut temp) => {
                temp.class_info
                    .values_mut()
                    .for_each(|x| x.sigma -= epsilon);
                temp
            }
            None => GaussianNb {
                class_info: HashMap::new(),
            },
        };

        let yunique = dataset.labels();

        for class in yunique {
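            // Select only the rows of `x` belonging to the current class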
            let xclass = filter(x.view(), y.view(), &class);

            let nclass = xclass.nrows();

            let class_info = model
                .class_info
                .entry(class)
                .or_insert_with(GaussianClassInfo::default);

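            // Merge the statistics of this batch into the running per-class
            // mean and variance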
            let (theta_new, sigma_new) = Self::update_mean_variance(class_info, xclass.view());

            class_info.theta = theta_new;
            class_info.sigma = sigma_new;
            class_info.class_count += nclass;
        }

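        // Add the smoothing term back onto every class variance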
        model
            .class_info
            .values_mut()
            .for_each(|x| x.sigma += epsilon);

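        // Class priors are the relative frequencies of the classes seen so far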
        let class_count_sum = model
            .class_info
            .values()
            .map(|x| x.class_count)
            .sum::<usize>();

        for info in model.class_info.values_mut() {
            info.prior = F::cast(info.class_count) / F::cast(class_count_sum);
        }

        Ok(model)
    }
}

impl<F: Float, L: Label, D> PredictInplace<ArrayBase<D, Ix2>, Array1<L>> for GaussianNb<F, L>
where
    D: Data<Elem = F>,
{
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
        NaiveBayes::predict_inplace(self, x, y);
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
        Array1::default(x.nrows())
    }
}

impl<F, L> GaussianNbValidParams<F, L>
where
    F: Float,
{
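    // Merge the running per-class mean and variance in `info_old` with the
    // statistics of a new batch `x_new`, returning the combined mean and
    // variance (cf. the incremental update of Chan, Golub and LeVeque).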
    fn update_mean_variance(
        info_old: &GaussianClassInfo<F>,
        x_new: ArrayView2<F>,
    ) -> (Array1<F>, Array1<F>) {
        let (count_old, mu_old, var_old) = (info_old.class_count, &info_old.theta, &info_old.sigma);

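        // An empty batch leaves the running statistics unchanged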
        if x_new.nrows() == 0 {
            return (mu_old.to_owned(), var_old.to_owned());
        }

        let count_new = x_new.nrows();

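        // unwrap is safe: `mean_axis` returns `None` only for an empty axis,
        // which is ruled out above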
        let mu_new = x_new.mean_axis(Axis(0)).unwrap();
        let var_new = x_new.var_axis(Axis(0), F::zero());

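        // With no previously seen observations, the batch statistics are final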
        if count_old == 0 {
            return (mu_new, var_new);
        }

        let count_total = count_old + count_new;

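        // Combined mean: count-weighted average of the old and new means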
        let mu_new_weighted = &mu_new * F::cast(count_new);
        let mu_old_weighted = mu_old * F::cast(count_old);
        let mu_weighted = (mu_new_weighted + mu_old_weighted).mapv(|x| x / F::cast(count_total));

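        // Combined variance: sum of the old and new sums of squared deviations,
        // plus a correction term accounting for the difference between the two
        // means, all normalized by the total count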
        let ssd_old = var_old * F::cast(count_old);
        let ssd_new = var_new * F::cast(count_new);
        let weight = F::cast(count_new * count_old) / F::cast(count_total);
        let ssd_weighted = ssd_old + ssd_new + (mu_old - mu_new).mapv(|x| weight * x.powi(2));
        let var_weighted = ssd_weighted.mapv(|x| x / F::cast(count_total));

        (mu_weighted, var_weighted)
    }
}

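/// Fitted Gaussian Naive Bayes classifier.
///
/// Stores, for every class, the number of observations seen so far, the class
/// prior, and the per-feature mean (`theta`) and variance (`sigma`).
///
/// # Example
///
/// A minimal usage sketch, mirroring the data used in the tests below:
///
/// ```rust
/// use linfa::traits::{Fit, Predict};
/// use linfa::DatasetView;
/// use linfa_bayes::{GaussianNb, Result};
/// use ndarray::array;
///
/// fn main() -> Result<()> {
///     let x = array![[-2., -1.], [-1., -1.], [-1., -2.], [1., 1.], [1., 2.], [2., 1.]];
///     let y = array![1, 1, 1, 2, 2, 2];
///     let data = DatasetView::new(x.view(), y.view());
///
///     // Fit with the default hyperparameters and predict on the training data
///     let model = GaussianNb::params().fit(&data)?;
///     let pred = model.predict(&x);
///     assert_eq!(pred, y);
///     Ok(())
/// }
/// ```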
186#[cfg_attr(
233 feature = "serde",
234 derive(Serialize, Deserialize),
235 serde(crate = "serde_crate")
236)]
237#[derive(Debug, Clone, PartialEq)]
238pub struct GaussianNb<F: PartialEq, L: Eq + Hash> {
239 class_info: HashMap<L, GaussianClassInfo<F>>,
240}
241
242#[cfg_attr(
243 feature = "serde",
244 derive(Serialize, Deserialize),
245 serde(crate = "serde_crate")
246)]
247#[derive(Debug, Default, Clone, PartialEq)]
248struct GaussianClassInfo<F> {
249 class_count: usize,
250 prior: F,
251 theta: Array1<F>,
252 sigma: Array1<F>,
253}
254
impl<F: Float, L: Label> GaussianNb<F, L> {
    /// Construct a new set of hyperparameters with default values
    pub fn params() -> GaussianNbParams<F, L> {
        GaussianNbParams::new()
    }
}

impl<F, L> NaiveBayes<'_, F, L> for GaussianNb<F, L>
where
    F: Float,
    L: Label + Ord,
{
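    // Compute the unnormalized posterior log-probability
    // ln P(class) + ln P(x | class) for every class, where the likelihood is a
    // product of independent univariate Gaussians over the features.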
    fn joint_log_likelihood(&self, x: ArrayView2<F>) -> HashMap<&L, Array1<F>> {
        let mut joint_log_likelihood = HashMap::new();

        for (class, info) in self.class_info.iter() {
            // Log prior of the class
            let jointi = info.prior.ln();

            // Normalization term of the Gaussian log-density:
            // -1/2 * sum_j ln(2 * pi * sigma_j)
            let mut nij = info
                .sigma
                .mapv(|x| F::cast(2. * std::f64::consts::PI) * x)
                .mapv(|x| x.ln())
                .sum();
            nij = F::cast(-0.5) * nij;

            // Quadratic term, evaluated row-wise over all samples:
            // -1/2 * sum_j (x_j - theta_j)^2 / sigma_j
            let nij = ((x.to_owned() - &info.theta).mapv(|x| x.powi(2)) / &info.sigma)
                .sum_axis(Axis(1))
                .mapv(|x| x * F::cast(0.5))
                .mapv(|x| nij - x);

            joint_log_likelihood.insert(class, nij + jointi);
        }

        joint_log_likelihood
    }
}

293#[cfg(test)]
294mod tests {
295 use super::{GaussianNb, NaiveBayes, Result};
296 use linfa::{
297 traits::{Fit, FitWith, Predict},
298 DatasetView, Error,
299 };
300
301 use crate::gaussian_nb::GaussianClassInfo;
302 use crate::{GaussianNbParams, GaussianNbValidParams, NaiveBayesError};
303 use approx::assert_abs_diff_eq;
304 use ndarray::{array, Axis};
305 use std::collections::HashMap;
306
307 #[test]
308 fn autotraits() {
309 fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
310 has_autotraits::<GaussianNb<f64, usize>>();
311 has_autotraits::<GaussianClassInfo<f64>>();
312 has_autotraits::<GaussianNbParams<f64, usize>>();
313 has_autotraits::<GaussianNbValidParams<f64, usize>>();
314 has_autotraits::<NaiveBayesError>();
315 }
316
    #[test]
    fn test_gaussian_nb() -> Result<()> {
        let x = array![
            [-2., -1.],
            [-1., -1.],
            [-1., -2.],
            [1., 1.],
            [1., 2.],
            [2., 1.]
        ];
        let y = array![1, 1, 1, 2, 2, 2];

        let data = DatasetView::new(x.view(), y.view());
        let fitted_clf = GaussianNb::params().fit(&data)?;
        let pred = fitted_clf.predict(&x);

        assert_abs_diff_eq!(pred, y);

        let jll = fitted_clf.joint_log_likelihood(x.view());

        let mut expected = HashMap::new();
        expected.insert(
            &1usize,
            array![
                -2.276946847943017,
                -1.5269468546930165,
                -2.276946847943017,
                -25.52694663869301,
                -38.27694652394301,
                -38.27694652394301
            ],
        );
        expected.insert(
            &2usize,
            array![
                -38.27694652394301,
                -25.52694663869301,
                -38.27694652394301,
                -1.5269468546930165,
                -2.276946847943017,
                -2.276946847943017
            ],
        );

        assert_eq!(jll, expected);

        let expected_proba = array![
            [1.00000000e+00, 2.31952358e-16],
            [1.00000000e+00, 3.77513536e-11],
            [1.00000000e+00, 2.31952358e-16],
            [3.77513536e-11, 1.00000000e+00],
            [2.31952358e-16, 1.00000000e+00],
            [2.31952358e-16, 1.00000000e+00]
        ];

        let (y_pred_proba, classes) = fitted_clf.predict_proba(x.view());
        assert_eq!(classes, vec![&1usize, &2]);
        assert_abs_diff_eq!(expected_proba, y_pred_proba, epsilon = 1e-10);

        let (y_pred_log_proba, classes) = fitted_clf.predict_log_proba(x.view());
        assert_eq!(classes, vec![&1usize, &2]);
        assert_abs_diff_eq!(
            y_pred_proba.mapv(f64::ln),
            y_pred_log_proba,
            epsilon = 1e-10
        );

        Ok(())
    }

    #[test]
    fn test_gnb_fit_with() -> Result<()> {
        let x = array![
            [-2., -1.],
            [-1., -1.],
            [-1., -2.],
            [1., 1.],
            [1., 2.],
            [2., 1.]
        ];
        let y = array![1, 1, 1, 2, 2, 2];

        let clf = GaussianNb::params();

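        // Fit the model incrementally, two samples at a time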
        let model = x
            .axis_chunks_iter(Axis(0), 2)
            .zip(y.axis_chunks_iter(Axis(0), 2))
            .map(|(a, b)| DatasetView::new(a, b))
            .try_fold(None, |current, d| clf.fit_with(current, &d).map(Some))?
            .ok_or(Error::NotEnoughSamples)?;

        let pred = model.predict(&x);

        assert_abs_diff_eq!(pred, y);

        let jll = model.joint_log_likelihood(x.view());

        let mut expected = HashMap::new();
        expected.insert(
            &1usize,
            array![
                -2.276946847943017,
                -1.5269468546930165,
                -2.276946847943017,
                -25.52694663869301,
                -38.27694652394301,
                -38.27694652394301
            ],
        );
        expected.insert(
            &2usize,
            array![
                -38.27694652394301,
                -25.52694663869301,
                -38.27694652394301,
                -1.5269468546930165,
                -2.276946847943017,
                -2.276946847943017
            ],
        );

        for (key, value) in jll.iter() {
            assert_abs_diff_eq!(value, expected.get(key).unwrap(), epsilon = 1e-6);
        }

        Ok(())
    }
}