linfa_bayes/
multinomial_nb.rs

use linfa::dataset::{AsSingleTargets, DatasetBase, Labels};
use linfa::traits::{Fit, FitWith, PredictInplace};
use linfa::{Float, Label};
use ndarray::{Array1, ArrayBase, ArrayView2, Data, Ix2};
use std::collections::HashMap;
use std::hash::Hash;

use crate::base_nb::{NaiveBayes, NaiveBayesValidParams};
use crate::error::{NaiveBayesError, Result};
use crate::hyperparams::{MultinomialNbParams, MultinomialNbValidParams};
use crate::{filter, ClassHistogram};

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

impl<'a, F, L, D, T> NaiveBayesValidParams<'a, F, L, D, T> for MultinomialNbValidParams<F, L>
where
    F: Float,
    L: Label + 'a,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
}

impl<F, L, D, T> Fit<ArrayBase<D, Ix2>, T, NaiveBayesError> for MultinomialNbValidParams<F, L>
where
    F: Float,
    L: Label + Ord,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
    type Object = MultinomialNb<F, L>;
    // Thin wrapper around the corresponding method of NaiveBayesValidParams
    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
        NaiveBayesValidParams::fit(self, dataset, None)
    }
}

impl<'a, F, L, D, T> FitWith<'a, ArrayBase<D, Ix2>, T, NaiveBayesError>
    for MultinomialNbValidParams<F, L>
where
    F: Float,
    L: Label + 'a,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
    type ObjectIn = Option<MultinomialNb<F, L>>;
    type ObjectOut = MultinomialNb<F, L>;

    fn fit_with(
        &self,
        model_in: Self::ObjectIn,
        dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
    ) -> Result<Self::ObjectOut> {
        let x = dataset.records();
        let y = dataset.as_single_targets();

        let mut model = match model_in {
            Some(temp) => temp,
            None => MultinomialNb {
                class_info: HashMap::new(),
            },
        };

        let yunique = dataset.labels();

        for class in yunique {
            // filter dataset for current class
            let xclass = filter(x.view(), y.view(), &class);

            // compute feature log probabilities and counts
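            // (with Lidstone smoothing this is the standard multinomial NB estimate
            //     theta_{c,i} = (N_{c,i} + alpha) / (N_c + alpha * n_features),
            //  where N_{c,i} is the total count of feature i in class c; the `false`
            //  flag presumably selects the multinomial rather than the Bernoulli
            //  behaviour of the shared histogram helper)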
            model
                .class_info
                .entry(class.clone())
                .or_insert_with(ClassHistogram::default)
                .update_with_smoothing(xclass.view(), self.alpha(), false);
        }

        // update priors: P(c) = class_count(c) / total count over all classes,
        // recomputed from the accumulated counts so that repeated `fit_with`
        // calls stay consistent with all the data seen so far
        let class_count_sum = model
            .class_info
            .values()
            .map(|x| x.class_count)
            .sum::<usize>();

        for info in model.class_info.values_mut() {
            info.prior = F::cast(info.class_count) / F::cast(class_count_sum);
        }

        Ok(model)
    }
}

impl<F: Float, L: Label, D> PredictInplace<ArrayBase<D, Ix2>, Array1<L>> for MultinomialNb<F, L>
where
    D: Data<Elem = F>,
{
    // Thin wrapper around the corresponding method of NaiveBayes
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
        NaiveBayes::predict_inplace(self, x, y);
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
        Array1::default(x.nrows())
    }
}

/// Fitted Multinomial Naive Bayes classifier.
///
/// See [MultinomialNbParams] for more information on the hyper-parameters.
///
/// # Model assumptions
///
/// The family of Naive Bayes classifiers assumes that the features are mutually independent
/// given the class. Because they do not model interactions between features, their modelling
/// capability is limited; in exchange, maximum-likelihood training has a closed-form solution
/// and fitting time is linear in the number of samples.
///
/// # Model usage example
///
/// The example below creates a set of hyperparameters, and then uses it to fit a Multinomial
/// Naive Bayes classifier on provided data.
///
/// ```rust
/// use linfa_bayes::{MultinomialNbParams, MultinomialNbValidParams, Result};
/// use linfa::prelude::*;
/// use ndarray::array;
///
/// let x = array![
///     [-2., -1.],
///     [-1., -1.],
///     [-1., -2.],
///     [1., 1.],
///     [1., 2.],
///     [2., 1.]
/// ];
/// let y = array![1, 1, 1, 2, 2, 2];
/// let ds = DatasetView::new(x.view(), y.view());
///
/// // create a new parameter set with the smoothing parameter set to `1`
/// let unchecked_params = MultinomialNbParams::new()
///     .alpha(1.0);
///
/// // fit a model with the unchecked parameter set
/// let model = unchecked_params.fit(&ds)?;
///
/// // transform into a verified parameter set
/// let checked_params = unchecked_params.check()?;
///
/// // update the model with the verified parameters; this only returns
/// // errors originating from the fitting process
/// let model = checked_params.fit_with(Some(model), &ds)?;
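///
/// // `fit_with` can also fit incrementally over chunks of the data (a sketch
/// // mirroring this crate's tests; the chunk size of 2 is arbitrary):
/// use ndarray::Axis;
/// let model = x
///     .axis_chunks_iter(Axis(0), 2)
///     .zip(y.axis_chunks_iter(Axis(0), 2))
///     .map(|(records, targets)| DatasetView::new(records, targets))
///     .try_fold(None, |model, chunk| checked_params.fit_with(model, &chunk).map(Some))?;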
/// # Result::Ok(())
/// ```
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct MultinomialNb<F: PartialEq, L: Eq + Hash> {
    class_info: HashMap<L, ClassHistogram<F>>,
}

impl<F: Float, L: Label> MultinomialNb<F, L> {
    /// Construct a new set of hyperparameters
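    ///
    /// ```rust
    /// // a usage sketch: equivalent to calling `MultinomialNbParams::new()` directly
    /// // (assumes `MultinomialNb` is re-exported at the crate root, as in the docs above)
    /// let params = linfa_bayes::MultinomialNb::<f64, usize>::params();
    /// ```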
    pub fn params() -> MultinomialNbParams<F, L> {
        MultinomialNbParams::new()
    }
}

impl<F, L> NaiveBayes<'_, F, L> for MultinomialNb<F, L>
where
    F: Float,
    L: Label + Ord,
{
    // Compute unnormalized posterior log probability
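    // for each class c: jll_c(x) = ln(prior(c)) + x . feature_log_prob(c), i.e.
    // the standard multinomial NB score ln P(c) + sum_i x_i * ln(theta_{c,i})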
    fn joint_log_likelihood(&self, x: ArrayView2<F>) -> HashMap<&L, Array1<F>> {
        let mut joint_log_likelihood = HashMap::new();
        for (class, info) in self.class_info.iter() {
            // Combine feature log probabilities and class priors to get log-likelihood for each class
            let jointi = info.prior.ln();
            let nij = x.dot(&info.feature_log_prob);
            joint_log_likelihood.insert(class, nij + jointi);
        }

        joint_log_likelihood
    }
}

#[cfg(test)]
mod tests {
    use super::{MultinomialNb, NaiveBayes, Result};
    use linfa::{
        traits::{Fit, FitWith, Predict},
        Dataset, DatasetView, Error,
    };

    use crate::{MultinomialNbParams, MultinomialNbValidParams};
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Axis};
    use std::collections::HashMap;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<MultinomialNb<f64, usize>>();
        has_autotraits::<MultinomialNbValidParams<f64, usize>>();
        has_autotraits::<MultinomialNbParams<f64, usize>>();
    }

    #[test]
    fn test_multinomial_nb() -> Result<()> {
        let ds = Dataset::new(
            array![[1., 0.], [2., 0.], [3., 0.], [0., 1.], [0., 2.], [0., 3.]],
            array![1, 1, 1, 2, 2, 2],
        );

        let fitted_clf = MultinomialNb::params().fit(&ds)?;
        let pred = fitted_clf.predict(ds.records());

        assert_abs_diff_eq!(pred, ds.targets());

        let jll = fitted_clf.joint_log_likelihood(ds.records().view());
        let mut expected = HashMap::new();
        // Computed with sklearn.naive_bayes.MultinomialNB
        expected.insert(
            &1usize,
            array![
                -0.82667857,
                -0.96020997,
                -1.09374136,
                -2.77258872,
                -4.85203026,
                -6.93147181
            ],
        );

        expected.insert(
            &2usize,
            array![
                -2.77258872,
                -4.85203026,
                -6.93147181,
                -0.82667857,
                -0.96020997,
                -1.09374136
            ],
        );

        for (key, value) in jll.iter() {
            assert_abs_diff_eq!(value, expected.get(key).unwrap(), epsilon = 1e-6);
        }

        Ok(())
    }

    #[test]
    fn test_mnb_fit_with() -> Result<()> {
        let x = array![[1., 0.], [2., 0.], [3., 0.], [0., 1.], [0., 2.], [0., 3.]];
        let y = array![1, 1, 1, 2, 2, 2];

        let clf = MultinomialNb::params();

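        // stream the dataset through `fit_with` in chunks of two samples to
        // exercise incremental fitting: `try_fold` threads the partial model
        // (starting from `None`) through every chunk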
        let model = x
            .axis_chunks_iter(Axis(0), 2)
            .zip(y.axis_chunks_iter(Axis(0), 2))
            .map(|(a, b)| DatasetView::new(a, b))
            .try_fold(None, |current, d| clf.fit_with(current, &d).map(Some))?
            .ok_or(Error::NotEnoughSamples)?;

        let pred = model.predict(&x);

        assert_abs_diff_eq!(pred, y);

        let jll = model.joint_log_likelihood(x.view());

        let mut expected = HashMap::new();
        // Computed with sklearn.naive_bayes.MultinomialNB
        expected.insert(
            &1usize,
            array![
                -0.82667857,
                -0.96020997,
                -1.09374136,
                -2.77258872,
                -4.85203026,
                -6.93147181
            ],
        );

        expected.insert(
            &2usize,
            array![
                -2.77258872,
                -4.85203026,
                -6.93147181,
                -0.82667857,
                -0.96020997,
                -1.09374136
            ],
        );

        for (key, value) in jll.iter() {
            assert_abs_diff_eq!(value, expected.get(key).unwrap(), epsilon = 1e-6);
        }

        Ok(())
    }
}