use linfa::dataset::{AsSingleTargets, DatasetBase, Labels};
use linfa::traits::{Fit, FitWith, PredictInplace};
use linfa::{Float, Label};
use ndarray::{Array1, ArrayBase, ArrayView2, Data, Ix2};
use std::collections::HashMap;
use std::hash::Hash;

use crate::base_nb::{NaiveBayes, NaiveBayesValidParams};
use crate::error::{NaiveBayesError, Result};
use crate::hyperparams::{MultinomialNbParams, MultinomialNbValidParams};
use crate::{filter, ClassHistogram};

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

// Opt `MultinomialNbValidParams` into the shared naive Bayes fitting
// machinery; the actual work happens in `fit_with` below.
impl<'a, F, L, D, T> NaiveBayesValidParams<'a, F, L, D, T> for MultinomialNbValidParams<F, L>
where
    F: Float,
    L: Label + 'a,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
}

impl<F, L, D, T> Fit<ArrayBase<D, Ix2>, T, NaiveBayesError> for MultinomialNbValidParams<F, L>
where
    F: Float,
    L: Label + Ord,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
    type Object = MultinomialNb<F, L>;

    // A one-shot `fit` is an incremental fit that starts from no model.
    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
        NaiveBayesValidParams::fit(self, dataset, None)
    }
}

impl<'a, F, L, D, T> FitWith<'a, ArrayBase<D, Ix2>, T, NaiveBayesError>
    for MultinomialNbValidParams<F, L>
where
    F: Float,
    L: Label + 'a,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
    // `ObjectIn = Option<model>` lets a partially fitted model be threaded
    // through successive `fit_with` calls, enabling incremental fitting over
    // batches of data; `None` starts a fresh fit.
    type ObjectIn = Option<MultinomialNb<F, L>>;
    type ObjectOut = MultinomialNb<F, L>;

    fn fit_with(
        &self,
        model_in: Self::ObjectIn,
        dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
    ) -> Result<Self::ObjectOut> {
        let x = dataset.records();
        let y = dataset.as_single_targets();

        // Start from the previous model when one is passed in, otherwise
        // begin with an empty one.
        let mut model = match model_in {
            Some(temp) => temp,
            None => MultinomialNb {
                class_info: HashMap::new(),
            },
        };

        let yunique = dataset.labels();

        for class in yunique {
            // Select the rows of `x` that belong to the current class.
            let xclass = filter(x.view(), y.view(), &class);

            // Fold the batch into this class's histogram, smoothing the
            // feature counts with `alpha`.
            model
                .class_info
                .entry(class)
                .or_insert_with(ClassHistogram::default)
                .update_with_smoothing(xclass.view(), self.alpha(), false);
        }

        // Update the class priors from the accumulated per-class counts.
        let class_count_sum = model
            .class_info
            .values()
            .map(|x| x.class_count)
            .sum::<usize>();

        for info in model.class_info.values_mut() {
            info.prior = F::cast(info.class_count) / F::cast(class_count_sum);
        }

        Ok(model)
    }
}

impl<F: Float, L: Label, D> PredictInplace<ArrayBase<D, Ix2>, Array1<L>> for MultinomialNb<F, L>
where
    D: Data<Elem = F>,
{
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
        NaiveBayes::predict_inplace(self, x, y);
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
        Array1::default(x.nrows())
    }
}

/// Fitted multinomial naive Bayes classifier.
///
/// Stores one [`ClassHistogram`] per class label: the class sample count, the
/// smoothed feature log probabilities and the class prior.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct MultinomialNb<F: PartialEq, L: Eq + Hash> {
    class_info: HashMap<L, ClassHistogram<F>>,
}

impl<F: Float, L: Label> MultinomialNb<F, L> {
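    /// Create the set of tunable hyperparameters for a multinomial model.
    ///
    /// A minimal usage sketch (marked `ignore`; it assumes the usual
    /// `linfa::prelude` imports and mirrors the tests at the bottom of this
    /// file):
    ///
    /// ```ignore
    /// use linfa::prelude::*;
    /// use ndarray::array;
    ///
    /// let ds = Dataset::new(
    ///     array![[1., 0.], [2., 0.], [0., 1.], [0., 2.]],
    ///     array![1, 1, 2, 2],
    /// );
    /// let model = MultinomialNb::params().fit(&ds)?;
    /// let pred = model.predict(ds.records());
    /// ```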
    pub fn params() -> MultinomialNbParams<F, L> {
        MultinomialNbParams::new()
    }
}

impl<F, L> NaiveBayes<'_, F, L> for MultinomialNb<F, L>
where
    F: Float,
    L: Label + Ord,
{
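    // The multinomial joint log likelihood of a sample `x` under class `c` is
    //
    //     ln P(c) + Σ_j x_j · ln θ_{c,j}
    //
    // where `θ_{c,j}` is the smoothed probability of feature `j` in class `c`
    // (held in log form in `feature_log_prob`); the sum over features is the
    // dot product computed below.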
    fn joint_log_likelihood(&self, x: ArrayView2<F>) -> HashMap<&L, Array1<F>> {
        let mut joint_log_likelihood = HashMap::new();
        for (class, info) in self.class_info.iter() {
            let jointi = info.prior.ln();
            let nij = x.dot(&info.feature_log_prob);
            joint_log_likelihood.insert(class, nij + jointi);
        }

        joint_log_likelihood
    }
}

#[cfg(test)]
mod tests {
    use super::{MultinomialNb, NaiveBayes, Result};
    use linfa::{
        traits::{Fit, FitWith, Predict},
        Dataset, DatasetView, Error,
    };

    use crate::{MultinomialNbParams, MultinomialNbValidParams};
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Axis};
    use std::collections::HashMap;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<MultinomialNb<f64, usize>>();
        has_autotraits::<MultinomialNbValidParams<f64, usize>>();
        has_autotraits::<MultinomialNbParams<f64, usize>>();
    }

    #[test]
    fn test_multinomial_nb() -> Result<()> {
        let ds = Dataset::new(
            array![[1., 0.], [2., 0.], [3., 0.], [0., 1.], [0., 2.], [0., 3.]],
            array![1, 1, 1, 2, 2, 2],
        );

        let fitted_clf = MultinomialNb::params().fit(&ds)?;
        let pred = fitted_clf.predict(ds.records());

        assert_abs_diff_eq!(pred, ds.targets());

        let jll = fitted_clf.joint_log_likelihood(ds.records().view());
        let mut expected = HashMap::new();
        expected.insert(
            &1usize,
            array![
                -0.82667857,
                -0.96020997,
                -1.09374136,
                -2.77258872,
                -4.85203026,
                -6.93147181
            ],
        );

        expected.insert(
            &2usize,
            array![
                -2.77258872,
                -4.85203026,
                -6.93147181,
                -0.82667857,
                -0.96020997,
                -1.09374136
            ],
        );

        for (key, value) in jll.iter() {
            assert_abs_diff_eq!(value, expected.get(key).unwrap(), epsilon = 1e-6);
        }

        Ok(())
    }

    #[test]
    fn test_mnb_fit_with() -> Result<()> {
        let x = array![[1., 0.], [2., 0.], [3., 0.], [0., 1.], [0., 2.], [0., 3.]];
        let y = array![1, 1, 1, 2, 2, 2];

        let clf = MultinomialNb::params();

        // Fit incrementally on two-row batches, threading the partially
        // fitted model through `try_fold`.
        let model = x
            .axis_chunks_iter(Axis(0), 2)
            .zip(y.axis_chunks_iter(Axis(0), 2))
            .map(|(a, b)| DatasetView::new(a, b))
            .try_fold(None, |current, d| clf.fit_with(current, &d).map(Some))?
            .ok_or(Error::NotEnoughSamples)?;

        let pred = model.predict(&x);

        assert_abs_diff_eq!(pred, y);

        let jll = model.joint_log_likelihood(x.view());

        let mut expected = HashMap::new();
        expected.insert(
            &1usize,
            array![
                -0.82667857,
                -0.96020997,
                -1.09374136,
                -2.77258872,
                -4.85203026,
                -6.93147181
            ],
        );

        expected.insert(
            &2usize,
            array![
                -2.77258872,
                -4.85203026,
                -6.93147181,
                -0.82667857,
                -0.96020997,
                -1.09374136
            ],
        );

        for (key, value) in jll.iter() {
            assert_abs_diff_eq!(value, expected.get(key).unwrap(), epsilon = 1e-6);
        }

        Ok(())
    }
}