linfa_bayes/bernoulli_nb.rs

use linfa::dataset::{AsSingleTargets, DatasetBase, Labels};
use linfa::traits::{Fit, FitWith, PredictInplace};
use linfa::{Float, Label};
use ndarray::{Array1, ArrayBase, ArrayView2, CowArray, Data, Ix2};
use std::collections::HashMap;
use std::hash::Hash;

use crate::base_nb::{NaiveBayes, NaiveBayesValidParams};
use crate::error::{NaiveBayesError, Result};
use crate::hyperparams::{BernoulliNbParams, BernoulliNbValidParams};
use crate::{filter, ClassHistogram};

impl<'a, F, L, D, T> NaiveBayesValidParams<'a, F, L, D, T> for BernoulliNbValidParams<F, L>
where
    F: Float,
    L: Label + 'a,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
}

impl<F, L, D, T> Fit<ArrayBase<D, Ix2>, T, NaiveBayesError> for BernoulliNbValidParams<F, L>
where
    F: Float,
    L: Label + Ord,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
    type Object = BernoulliNb<F, L>;

    // Thin wrapper around the corresponding method of NaiveBayesValidParams
    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
        NaiveBayesValidParams::fit(self, dataset, None)
    }
}

impl<'a, F, L, D, T> FitWith<'a, ArrayBase<D, Ix2>, T, NaiveBayesError>
    for BernoulliNbValidParams<F, L>
where
    F: Float,
    L: Label + 'a,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
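    // Incremental fitting: `None` as input starts a fresh model, while
    // `Some(model)` updates a previously fitted model with a new batch of
    // data, enabling online learning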
    type ObjectIn = Option<BernoulliNb<F, L>>;
    type ObjectOut = BernoulliNb<F, L>;

    fn fit_with(
        &self,
        model_in: Self::ObjectIn,
        dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
    ) -> Result<Self::ObjectOut> {
        let x = dataset.records();
        let y = dataset.as_single_targets();

        let mut model = match model_in {
            Some(temp) => temp,
            None => BernoulliNb {
                class_info: HashMap::new(),
                binarize: self.binarize(),
            },
        };

        // Binarize data if the threshold is set
        let xbin = model.binarize(x).to_owned();

        // Calculate feature log probabilities
        let yunique = dataset.labels();
        for class in yunique {
            // We filter for records that correspond to the current class
            let xclass = filter(xbin.view(), y.view(), &class);

            // We compute the feature log probabilities and feature counts on
            // the slice corresponding to the current class
            model
                .class_info
                .entry(class)
                .or_insert_with(ClassHistogram::default)
                .update_with_smoothing(xclass.view(), self.alpha(), true);
        }

        // Update the priors: each class prior is its relative frequency
        // among all samples seen so far
        let class_count_sum = model
            .class_info
            .values()
            .map(|x| x.class_count)
            .sum::<usize>();

        for info in model.class_info.values_mut() {
            info.prior = F::cast(info.class_count) / F::cast(class_count_sum);
        }
        Ok(model)
    }
}

impl<F: Float, L: Label, D> PredictInplace<ArrayBase<D, Ix2>, Array1<L>> for BernoulliNb<F, L>
where
    D: Data<Elem = F>,
{
    // Thin wrapper around the corresponding method of NaiveBayes
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
        // Binarize data if the threshold is set
        let xbin = self.binarize(x);
        NaiveBayes::predict_inplace(self, &xbin, y);
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
        Array1::default(x.nrows())
    }
}

/// Fitted Bernoulli Naive Bayes classifier.
///
/// See [BernoulliNbParams] for more information on the hyper-parameters.
///
/// # Model assumptions
///
/// The family of Naive Bayes classifiers assumes independence between variables. They do not
/// model moments between variables and are therefore limited in modelling capability. The
/// advantage is linear fitting time, since maximum-likelihood training has a closed-form
/// solution.
///
/// # Model usage example
///
/// The example below creates a set of hyperparameters, and then uses it to fit
/// a Bernoulli Naive Bayes classifier on provided data.
///
/// ```rust
/// use linfa_bayes::{BernoulliNbParams, BernoulliNbValidParams, Result};
/// use linfa::prelude::*;
/// use ndarray::array;
///
/// let x = array![
///     [-2., -1.],
///     [-1., -1.],
///     [-1., -2.],
///     [1., 1.],
///     [1., 2.],
///     [2., 1.]
/// ];
/// let y = array![1, 1, 1, 2, 2, 2];
/// let ds = DatasetView::new(x.view(), y.view());
/// // create a new parameter set with the smoothing parameter set to `1`
/// let unchecked_params = BernoulliNbParams::new()
///     .alpha(1.0);
///
/// // fit model with unchecked parameter set
/// let model = unchecked_params.fit(&ds)?;
///
/// // transform into a verified parameter set
/// let checked_params = unchecked_params.check()?;
///
/// // update the model with the verified parameters; this only returns
/// // errors originating from the fitting process
/// let model = checked_params.fit_with(Some(model), &ds)?;
/// # Result::Ok(())
/// ```
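///
/// # Binarization
///
/// Before fitting or prediction the input can be binarized: with a threshold
/// `t`, every feature value strictly greater than `t` is mapped to one and
/// every other value to zero. See [BernoulliNbParams] for how to set or
/// disable the threshold.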
#[derive(Debug, Clone, PartialEq)]
pub struct BernoulliNb<F: PartialEq, L: Eq + Hash> {
    class_info: HashMap<L, ClassHistogram<F>>,
    binarize: Option<F>,
}

impl<F: Float, L: Label> BernoulliNb<F, L> {
    /// Construct a new set of hyperparameters
    pub fn params() -> BernoulliNbParams<F, L> {
        BernoulliNbParams::new()
    }

    // Binarize data if the threshold is set
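    // (e.g. with a threshold of `0.0` the row `[-1.0, 0.0, 2.0]` maps to
    // `[0.0, 0.0, 1.0]`: values strictly greater than the threshold become
    // one, all others zero)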
    fn binarize<'a, D>(&'a self, x: &'a ArrayBase<D, Ix2>) -> CowArray<'a, F, Ix2>
    where
        D: Data<Elem = F>,
    {
        if let Some(thr) = self.binarize {
            let xbin = x.map(|v| if v > &thr { F::one() } else { F::zero() });
            CowArray::from(xbin)
        } else {
            CowArray::from(x)
        }
    }
}

impl<F, L> NaiveBayes<'_, F, L> for BernoulliNb<F, L>
where
    F: Float,
    L: Label + Ord,
{
    // Compute unnormalized posterior log probability
    fn joint_log_likelihood(&self, x: ArrayView2<F>) -> HashMap<&L, Array1<F>> {
        let mut joint_log_likelihood = HashMap::new();
        for (class, info) in self.class_info.iter() {
            // Combine the feature log probabilities, the log probabilities of
            // their complements, and the class prior to get the joint
            // log-likelihood for each class
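            //
            // For binarized features x_i with per-class feature probabilities
            // p_i this is
            //   ln P(x | c) = Σ_i [x_i · ln(p_i) + (1 - x_i) · ln(1 - p_i)]
            //               = x · (ln(p) - ln(1 - p)) + Σ_i ln(1 - p_i)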
            let neg_prob = info.feature_log_prob.map(|lp| (F::one() - lp.exp()).ln());
            let feature_log_prob = &info.feature_log_prob - &neg_prob;
            let jll = x.dot(&feature_log_prob);
            joint_log_likelihood.insert(class, jll + info.prior.ln() + neg_prob.sum());
        }

        joint_log_likelihood
    }
}

#[cfg(test)]
mod tests {
    use super::{BernoulliNb, NaiveBayes, Result};
    use linfa::{
        traits::{Fit, Predict},
        DatasetView,
    };

    use crate::{BernoulliNbParams, BernoulliNbValidParams};
    use approx::assert_abs_diff_eq;
    use ndarray::array;
    use std::collections::HashMap;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<BernoulliNb<f64, usize>>();
        has_autotraits::<BernoulliNbValidParams<f64, usize>>();
        has_autotraits::<BernoulliNbParams<f64, usize>>();
    }

    #[test]
    fn test_bernoulli_nb() -> Result<()> {
        let x = array![[1., 0.], [0., 0.], [1., 1.], [0., 1.]];
        let y = array![1, 1, 2, 2];
        let data = DatasetView::new(x.view(), y.view());

        let params = BernoulliNb::params().binarize(None);
        let fitted_clf = params.fit(&data)?;
        assert!(fitted_clf.binarize.is_none());

        let pred = fitted_clf.predict(&x);
        assert_abs_diff_eq!(pred, y);

        let jll = fitted_clf.joint_log_likelihood(x.view());
        let mut expected = HashMap::new();
        expected.insert(
            &1usize,
            (array![0.1875f64, 0.1875, 0.0625, 0.0625]).map(|v| v.ln()),
        );

        expected.insert(
            &2usize,
            (array![0.0625f64, 0.0625, 0.1875, 0.1875]).map(|v| v.ln()),
        );

        for (key, value) in jll.iter() {
            assert_abs_diff_eq!(value, expected.get(key).unwrap(), epsilon = 1e-6);
        }

        Ok(())
    }

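    // Additional sanity check (assumes the default smoothing parameter): with
    // a threshold of `0.`, strictly positive values map to one, making the
    // two classes below separable in the binarized space
    #[test]
    fn test_binarize_threshold() -> Result<()> {
        let x = array![[-2., -1.], [-1., -1.], [1., 1.], [2., 1.]];
        let y = array![1, 1, 2, 2];
        let data = DatasetView::new(x.view(), y.view());

        let fitted_clf = BernoulliNb::params().binarize(Some(0.)).fit(&data)?;
        assert!(fitted_clf.binarize.is_some());

        let pred = fitted_clf.predict(&x);
        assert_abs_diff_eq!(pred, y);

        Ok(())
    }
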
    #[test]
    fn test_text_class() -> Result<()> {
        // From https://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html#tab:nbtoy
        let train = array![
            // C, B, S, M, T, J
            [2., 1., 0., 0., 0., 0.0f64],
            [2., 0., 1., 0., 0., 0.],
            [1., 0., 0., 1., 0., 0.],
            [1., 0., 0., 0., 1., 1.],
        ];
        let y = array![1, 1, 1, 2];
        let test = array![[3., 0., 0., 0., 1., 1.0f64]];

        let data = DatasetView::new(train.view(), y.view());
        let fitted_clf = BernoulliNb::params().fit(&data)?;
        let pred = fitted_clf.predict(&test);

        assert_abs_diff_eq!(pred, array![2]);

        // See: https://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html
        let jll = fitted_clf.joint_log_likelihood(fitted_clf.binarize(&test).view());
        assert_abs_diff_eq!(jll.get(&1).unwrap()[0].exp(), 0.005, epsilon = 1e-3);
        assert_abs_diff_eq!(jll.get(&2).unwrap()[0].exp(), 0.022, epsilon = 1e-3);

        Ok(())
    }
}