linfa_bayes/gaussian_nb.rs

use linfa::dataset::{AsSingleTargets, DatasetBase, Labels};
use linfa::traits::{Fit, FitWith, PredictInplace};
use linfa::{Float, Label};
use ndarray::{Array1, ArrayBase, ArrayView2, Axis, Data, Ix2};
use ndarray_stats::QuantileExt;
use std::collections::HashMap;
use std::hash::Hash;

use crate::base_nb::{NaiveBayes, NaiveBayesValidParams};
use crate::error::{NaiveBayesError, Result};
use crate::filter;
use crate::hyperparams::{GaussianNbParams, GaussianNbValidParams};

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

impl<'a, F, L, D, T> NaiveBayesValidParams<'a, F, L, D, T> for GaussianNbValidParams<F, L>
where
    F: Float,
    L: Label + 'a,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
}

impl<F, L, D, T> Fit<ArrayBase<D, Ix2>, T, NaiveBayesError> for GaussianNbValidParams<F, L>
where
    F: Float,
    L: Label + Ord,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
    type Object = GaussianNb<F, L>;

    // Thin wrapper around the corresponding method of NaiveBayesValidParams
    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
        NaiveBayesValidParams::fit(self, dataset, None)
    }
}

impl<'a, F, L, D, T> FitWith<'a, ArrayBase<D, Ix2>, T, NaiveBayesError>
    for GaussianNbValidParams<F, L>
where
    F: Float,
    L: Label + 'a,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
    type ObjectIn = Option<GaussianNb<F, L>>;
    type ObjectOut = GaussianNb<F, L>;

    fn fit_with(
        &self,
        model_in: Self::ObjectIn,
        dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
    ) -> Result<Self::ObjectOut> {
        let x = dataset.records();
        let y = dataset.as_single_targets();

        // If the ratio of the variance between dimensions is too small, it will cause
        // numerical errors. We address this by artificially boosting the variance
        // by `epsilon` (a small fraction of the variance of the largest feature)
        let epsilon = self.var_smoothing() * *x.var_axis(Axis(0), F::zero()).max()?;
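        // (illustration: with e.g. `var_smoothing = 1e-9` and a largest
        // per-feature variance of 4.0, each variance is boosted by
        // `epsilon = 4e-9`)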

        let mut model = match model_in {
            Some(mut temp) => {
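                // when warm-starting from an existing model, remove the
                // smoothing added at the end of the previous fit so the raw
                // variances can be updated incrementally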
                temp.class_info
                    .values_mut()
                    .for_each(|x| x.sigma -= epsilon);
                temp
            }
            None => GaussianNb {
                class_info: HashMap::new(),
            },
        };

        let yunique = dataset.labels();

        for class in yunique {
            // We filter for records that correspond to the current class
            let xclass = filter(x.view(), y.view(), &class);

            // We count the number of occurrences of the class
            let nclass = xclass.nrows();

            // We compute the update of the gaussian mean and variance
            let class_info = model
                .class_info
                .entry(class)
                .or_insert_with(GaussianClassInfo::default);

            let (theta_new, sigma_new) = Self::update_mean_variance(class_info, xclass.view());

            // We now update the mean, variance and class count
            class_info.theta = theta_new;
            class_info.sigma = sigma_new;
            class_info.class_count += nclass;
        }

        // We add back the epsilon previously subtracted for numerical
        // calculation stability
        model
            .class_info
            .values_mut()
            .for_each(|x| x.sigma += epsilon);

        // We update the priors
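        // (prior(c) = class_count(c) / total count; e.g. class counts {1: 3, 2: 3}
        // over 6 samples yield priors of 0.5 each)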
        let class_count_sum = model
            .class_info
            .values()
            .map(|x| x.class_count)
            .sum::<usize>();

        for info in model.class_info.values_mut() {
            info.prior = F::cast(info.class_count) / F::cast(class_count_sum);
        }

        Ok(model)
    }
}

impl<F: Float, L: Label, D> PredictInplace<ArrayBase<D, Ix2>, Array1<L>> for GaussianNb<F, L>
where
    D: Data<Elem = F>,
{
    // Thin wrapper around the corresponding method of NaiveBayes
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
        NaiveBayes::predict_inplace(self, x, y);
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
        Array1::default(x.nrows())
    }
}
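
// Note: `linfa` provides blanket `Predict` implementations on top of
// `PredictInplace`, so a fitted model is typically called as
// `let pred = model.predict(&x);` (see the tests below).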

impl<F, L> GaussianNbValidParams<F, L>
where
    F: Float,
{
    // Compute online update of gaussian mean and variance
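    //
    // Given an old batch with count n_a, mean mu_a and variance var_a, and a
    // new batch with count n_b, mean mu_b and variance var_b, the combined
    // statistics are computed per feature (a Chan et al. style parallel update):
    //
    //   mu  = (n_a * mu_a + n_b * mu_b) / (n_a + n_b)
    //   ssd = n_a * var_a + n_b * var_b + n_a * n_b / (n_a + n_b) * (mu_a - mu_b)^2
    //   var = ssd / (n_a + n_b)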
    fn update_mean_variance(
        info_old: &GaussianClassInfo<F>,
        x_new: ArrayView2<F>,
    ) -> (Array1<F>, Array1<F>) {
        // Deconstruct old state
        let (count_old, mu_old, var_old) = (info_old.class_count, &info_old.theta, &info_old.sigma);

        // If the incoming data is empty, no update is required
        if x_new.nrows() == 0 {
            return (mu_old.to_owned(), var_old.to_owned());
        }

        let count_new = x_new.nrows();

        // unwrap is safe because None is returned only when the number of records
        // along the specified axis is 0, and we return early if we have 0 rows
        let mu_new = x_new.mean_axis(Axis(0)).unwrap();
        let var_new = x_new.var_axis(Axis(0), F::zero());

        // If the previous batch was empty, we return the newly calculated mean and variance
        if count_old == 0 {
            return (mu_new, var_new);
        }

        let count_total = count_old + count_new;

        // Combine old and new mean, taking into consideration the number
        // of observations
        let mu_new_weighted = &mu_new * F::cast(count_new);
        let mu_old_weighted = mu_old * F::cast(count_old);
        let mu_weighted = (mu_new_weighted + mu_old_weighted).mapv(|x| x / F::cast(count_total));

        // Combine old and new variance, taking into consideration the number
        // of observations. This is achieved by combining the sums of squared
        // differences
        let ssd_old = var_old * F::cast(count_old);
        let ssd_new = var_new * F::cast(count_new);
        let weight = F::cast(count_new * count_old) / F::cast(count_total);
        let ssd_weighted = ssd_old + ssd_new + (mu_old - mu_new).mapv(|x| weight * x.powi(2));
        let var_weighted = ssd_weighted.mapv(|x| x / F::cast(count_total));

        (mu_weighted, var_weighted)
    }
}

/// Fitted Gaussian Naive Bayes classifier.
///
/// See [GaussianNbParams] for more information on the hyper-parameters.
///
/// # Model assumptions
///
/// The family of Naive Bayes classifiers assumes independence between variables. These models do
/// not capture moments (e.g. covariance) between variables and are therefore limited in modelling
/// capability. The advantage is a linear fitting time with closed-form maximum-likelihood
/// training.
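///
/// Concretely, under the independence assumption the class-conditional likelihood factorizes
/// over the features, `P(x | c) = P(x_1 | c) * ... * P(x_d | c)`, and Gaussian Naive Bayes
/// models each factor `P(x_j | c)` as a univariate Gaussian with per-class mean and variance.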
///
/// # Model usage example
///
/// The example below creates a set of hyperparameters, and then uses it to fit a Gaussian Naive Bayes
/// classifier on provided data.
///
/// ```rust
/// use linfa_bayes::{GaussianNbParams, GaussianNbValidParams, Result};
/// use linfa::prelude::*;
/// use ndarray::array;
///
/// let x = array![
///     [-2., -1.],
///     [-1., -1.],
///     [-1., -2.],
///     [1., 1.],
///     [1., 2.],
///     [2., 1.]
/// ];
/// let y = array![1, 1, 1, 2, 2, 2];
/// let ds = DatasetView::new(x.view(), y.view());
///
/// // create a new parameter set with variance smoothing equal to `1e-5`
/// let unchecked_params = GaussianNbParams::new()
///     .var_smoothing(1e-5);
///
/// // fit model with unchecked parameter set
/// let model = unchecked_params.fit(&ds)?;
///
/// // transform into a verified parameter set
/// let checked_params = unchecked_params.check()?;
///
/// // update model with the verified parameters; this only returns
/// // errors originating from the fitting process
/// let model = checked_params.fit_with(Some(model), &ds)?;
/// # Result::Ok(())
/// ```
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct GaussianNb<F: PartialEq, L: Eq + Hash> {
    class_info: HashMap<L, GaussianClassInfo<F>>,
}

#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Default, Clone, PartialEq)]
struct GaussianClassInfo<F> {
    class_count: usize,
    prior: F,
    theta: Array1<F>,
    sigma: Array1<F>,
}

impl<F: Float, L: Label> GaussianNb<F, L> {
    /// Construct a new set of hyperparameters
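    ///
    /// # Example
    ///
    /// A minimal sketch; this is equivalent to calling [GaussianNbParams::new] directly.
    ///
    /// ```rust
    /// use linfa_bayes::GaussianNb;
    ///
    /// let _params = GaussianNb::<f64, usize>::params();
    /// ```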
    pub fn params() -> GaussianNbParams<F, L> {
        GaussianNbParams::new()
    }
}

impl<F, L> NaiveBayes<'_, F, L> for GaussianNb<F, L>
where
    F: Float,
    L: Label + Ord,
{
    // Compute unnormalized posterior log probability
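    //
    // For a class c with prior p_c, per-feature means theta_cj and variances
    // sigma_cj, the joint log-likelihood of a sample x is
    //
    //   ln P(x, c) = ln p_c - 0.5 * sum_j ln(2 * pi * sigma_cj)
    //                       - 0.5 * sum_j (x_j - theta_cj)^2 / sigma_cj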
    fn joint_log_likelihood(&self, x: ArrayView2<F>) -> HashMap<&L, Array1<F>> {
        let mut joint_log_likelihood = HashMap::new();

        for (class, info) in self.class_info.iter() {
            let jointi = info.prior.ln();

            let mut nij = info
                .sigma
                .mapv(|x| F::cast(2. * std::f64::consts::PI) * x)
                .mapv(|x| x.ln())
                .sum();
            nij = F::cast(-0.5) * nij;

            let nij = ((x.to_owned() - &info.theta).mapv(|x| x.powi(2)) / &info.sigma)
                .sum_axis(Axis(1))
                .mapv(|x| x * F::cast(0.5))
                .mapv(|x| nij - x);

            joint_log_likelihood.insert(class, nij + jointi);
        }

        joint_log_likelihood
    }
}

#[cfg(test)]
mod tests {
    use super::{GaussianNb, NaiveBayes, Result};
    use linfa::{
        traits::{Fit, FitWith, Predict},
        DatasetView, Error,
    };

    use crate::gaussian_nb::GaussianClassInfo;
    use crate::{GaussianNbParams, GaussianNbValidParams, NaiveBayesError};
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Axis};
    use std::collections::HashMap;

    #[test]
    fn autotraits() {
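        // compile-time check that these types are Send + Sync + Sized + Unpin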
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<GaussianNb<f64, usize>>();
        has_autotraits::<GaussianClassInfo<f64>>();
        has_autotraits::<GaussianNbParams<f64, usize>>();
        has_autotraits::<GaussianNbValidParams<f64, usize>>();
        has_autotraits::<NaiveBayesError>();
    }

    #[test]
    fn test_gaussian_nb() -> Result<()> {
        let x = array![
            [-2., -1.],
            [-1., -1.],
            [-1., -2.],
            [1., 1.],
            [1., 2.],
            [2., 1.]
        ];
        let y = array![1, 1, 1, 2, 2, 2];

        let data = DatasetView::new(x.view(), y.view());
        let fitted_clf = GaussianNb::params().fit(&data)?;
        let pred = fitted_clf.predict(&x);

        assert_abs_diff_eq!(pred, y);

        let jll = fitted_clf.joint_log_likelihood(x.view());

        // expected values from GaussianNB scikit-learn 1.6.1
        let mut expected = HashMap::new();
        expected.insert(
            &1usize,
            array![
                -2.276946847943017,
                -1.5269468546930165,
                -2.276946847943017,
                -25.52694663869301,
                -38.27694652394301,
                -38.27694652394301
            ],
        );
        expected.insert(
            &2usize,
            array![
                -38.27694652394301,
                -25.52694663869301,
                -38.27694652394301,
                -1.5269468546930165,
                -2.276946847943017,
                -2.276946847943017
            ],
        );

        assert_eq!(jll, expected);

        let expected_proba = array![
            [1.00000000e+00, 2.31952358e-16],
            [1.00000000e+00, 3.77513536e-11],
            [1.00000000e+00, 2.31952358e-16],
            [3.77513536e-11, 1.00000000e+00],
            [2.31952358e-16, 1.00000000e+00],
            [2.31952358e-16, 1.00000000e+00]
        ];

        let (y_pred_proba, classes) = fitted_clf.predict_proba(x.view());
        assert_eq!(classes, vec![&1usize, &2]);
        assert_abs_diff_eq!(expected_proba, y_pred_proba, epsilon = 1e-10);

        let (y_pred_log_proba, classes) = fitted_clf.predict_log_proba(x.view());
        assert_eq!(classes, vec![&1usize, &2]);
        assert_abs_diff_eq!(
            y_pred_proba.mapv(f64::ln),
            y_pred_log_proba,
            epsilon = 1e-10
        );

        Ok(())
    }

    #[test]
    fn test_gnb_fit_with() -> Result<()> {
        let x = array![
            [-2., -1.],
            [-1., -1.],
            [-1., -2.],
            [1., 1.],
            [1., 2.],
            [2., 1.]
        ];
        let y = array![1, 1, 1, 2, 2, 2];

        let clf = GaussianNb::params();

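        // stream the data in chunks of two samples, folding `fit_with` over
        // the chunks so each call updates the model fitted on the previous ones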
        let model = x
            .axis_chunks_iter(Axis(0), 2)
            .zip(y.axis_chunks_iter(Axis(0), 2))
            .map(|(a, b)| DatasetView::new(a, b))
            .try_fold(None, |current, d| clf.fit_with(current, &d).map(Some))?
            .ok_or(Error::NotEnoughSamples)?;

        let pred = model.predict(&x);

        assert_abs_diff_eq!(pred, y);

        let jll = model.joint_log_likelihood(x.view());

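        // same expected joint log-likelihood values as in `test_gaussian_nb` above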
        let mut expected = HashMap::new();
        expected.insert(
            &1usize,
            array![
                -2.276946847943017,
                -1.5269468546930165,
                -2.276946847943017,
                -25.52694663869301,
                -38.27694652394301,
                -38.27694652394301
            ],
        );
        expected.insert(
            &2usize,
            array![
                -38.27694652394301,
                -25.52694663869301,
                -38.27694652394301,
                -1.5269468546930165,
                -2.276946847943017,
                -2.276946847943017
            ],
        );

        for (key, value) in jll.iter() {
            assert_abs_diff_eq!(value, expected.get(key).unwrap(), epsilon = 1e-6);
        }

        Ok(())
    }
}