1use linfa::dataset::{AsSingleTargets, DatasetBase, Labels};
2use linfa::traits::{Fit, FitWith, PredictInplace};
3use linfa::{Float, Label};
4use ndarray::{Array1, ArrayBase, ArrayView2, CowArray, Data, Ix2};
5use std::collections::HashMap;
6use std::hash::Hash;
7
8use crate::base_nb::{NaiveBayes, NaiveBayesValidParams};
9use crate::error::{NaiveBayesError, Result};
10use crate::hyperparams::{BernoulliNbParams, BernoulliNbValidParams};
11use crate::{filter, ClassHistogram};
12
// Marker impl: the Bernoulli parameter set opts into the shared naive-Bayes
// fitting machinery. All behavior comes from the default methods of
// `NaiveBayesValidParams` (see `base_nb`), so the body is intentionally empty.
impl<'a, F, L, D, T> NaiveBayesValidParams<'a, F, L, D, T> for BernoulliNbValidParams<F, L>
where
    F: Float,
    L: Label + 'a,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
}
21
impl<F, L, D, T> Fit<ArrayBase<D, Ix2>, T, NaiveBayesError> for BernoulliNbValidParams<F, L>
where
    F: Float,
    L: Label + Ord,
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
    type Object = BernoulliNb<F, L>;

    /// Fit a Bernoulli naive Bayes classifier on `dataset` from scratch.
    ///
    /// Delegates to the incremental `NaiveBayesValidParams::fit`, passing
    /// `None` as the prior model so a fresh model is trained.
    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
        NaiveBayesValidParams::fit(self, dataset, None)
    }
}
36
37impl<'a, F, L, D, T> FitWith<'a, ArrayBase<D, Ix2>, T, NaiveBayesError>
38 for BernoulliNbValidParams<F, L>
39where
40 F: Float,
41 L: Label + 'a,
42 D: Data<Elem = F>,
43 T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
44{
45 type ObjectIn = Option<BernoulliNb<F, L>>;
46 type ObjectOut = BernoulliNb<F, L>;
47
48 fn fit_with(
49 &self,
50 model_in: Self::ObjectIn,
51 dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
52 ) -> Result<Self::ObjectOut> {
53 let x = dataset.records();
54 let y = dataset.as_single_targets();
55
56 let mut model = match model_in {
57 Some(temp) => temp,
58 None => BernoulliNb {
59 class_info: HashMap::new(),
60 binarize: self.binarize(),
61 },
62 };
63
64 let xbin = model.binarize(x).to_owned();
66
67 let yunique = dataset.labels();
69 for class in yunique {
70 let xclass = filter(xbin.view(), y.view(), &class);
72
73 model
76 .class_info
77 .entry(class)
78 .or_insert_with(ClassHistogram::default)
79 .update_with_smoothing(xclass.view(), self.alpha(), true);
80 }
81
82 let class_count_sum = model
84 .class_info
85 .values()
86 .map(|x| x.class_count)
87 .sum::<usize>();
88
89 for info in model.class_info.values_mut() {
90 info.prior = F::cast(info.class_count) / F::cast(class_count_sum);
91 }
92 Ok(model)
93 }
94}
95
96impl<F: Float, L: Label, D> PredictInplace<ArrayBase<D, Ix2>, Array1<L>> for BernoulliNb<F, L>
97where
98 D: Data<Elem = F>,
99{
100 fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
102 let xbin = self.binarize(x);
104 NaiveBayes::predict_inplace(self, &xbin, y);
105 }
106
107 fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
108 Array1::default(x.nrows())
109 }
110}
111
/// Fitted Bernoulli naive Bayes classifier.
///
/// Holds one `ClassHistogram` per class label (sample count, prior and
/// per-feature log-probabilities) together with the optional threshold used
/// to binarize features before training and prediction.
#[derive(Debug, Clone, PartialEq)]
pub struct BernoulliNb<F: PartialEq, L: Eq + Hash> {
    // Per-class statistics accumulated during fitting.
    class_info: HashMap<L, ClassHistogram<F>>,
    // Binarization threshold; `None` means features are used as-is.
    binarize: Option<F>,
}
163
164impl<F: Float, L: Label> BernoulliNb<F, L> {
165 pub fn params() -> BernoulliNbParams<F, L> {
167 BernoulliNbParams::new()
168 }
169
170 fn binarize<'a, D>(&'a self, x: &'a ArrayBase<D, Ix2>) -> CowArray<'a, F, Ix2>
172 where
173 D: Data<Elem = F>,
174 {
175 if let Some(thr) = self.binarize {
176 let xbin = x.map(|v| if v > &thr { F::one() } else { F::zero() });
177 CowArray::from(xbin)
178 } else {
179 CowArray::from(x)
180 }
181 }
182}
183
184impl<F, L> NaiveBayes<'_, F, L> for BernoulliNb<F, L>
185where
186 F: Float,
187 L: Label + Ord,
188{
189 fn joint_log_likelihood(&self, x: ArrayView2<F>) -> HashMap<&L, Array1<F>> {
191 let mut joint_log_likelihood = HashMap::new();
192 for (class, info) in self.class_info.iter() {
193 let neg_prob = info.feature_log_prob.map(|lp| (F::one() - lp.exp()).ln());
196 let feature_log_prob = &info.feature_log_prob - &neg_prob;
197 let jll = x.dot(&feature_log_prob);
198 joint_log_likelihood.insert(class, jll + info.prior.ln() + neg_prob.sum());
199 }
200
201 joint_log_likelihood
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::{BernoulliNb, NaiveBayes, Result};
    use linfa::{
        traits::{Fit, Predict},
        DatasetView,
    };

    use crate::{BernoulliNbParams, BernoulliNbValidParams};
    use approx::assert_abs_diff_eq;
    use ndarray::array;
    use std::collections::HashMap;

    #[test]
    fn autotraits() {
        // Compile-time check that the public types are Send + Sync + Unpin.
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<BernoulliNb<f64, usize>>();
        has_autotraits::<BernoulliNbValidParams<f64, usize>>();
        has_autotraits::<BernoulliNbParams<f64, usize>>();
    }

    #[test]
    fn test_bernoulli_nb() -> Result<()> {
        let x = array![[1., 0.], [0., 0.], [1., 1.], [0., 1.]];
        let y = array![1, 1, 2, 2];
        let dataset = DatasetView::new(x.view(), y.view());

        // Features are already 0/1, so binarization is disabled.
        let clf = BernoulliNb::params().binarize(None).fit(&dataset)?;
        assert!(clf.binarize.is_none());

        // The fitted model must reproduce the training labels.
        assert_abs_diff_eq!(clf.predict(&x), y);

        // Expected per-class joint likelihoods for each training sample.
        let mut expected = HashMap::new();
        expected.insert(
            &1usize,
            (array![0.1875f64, 0.1875, 0.0625, 0.0625]).map(|v| v.ln()),
        );
        expected.insert(
            &2usize,
            (array![0.0625f64, 0.0625, 0.1875, 0.1875,]).map(|v| v.ln()),
        );

        let jll = clf.joint_log_likelihood(x.view());
        for (class, ll) in &jll {
            assert_abs_diff_eq!(ll, expected.get(class).unwrap(), epsilon = 1e-6);
        }

        Ok(())
    }

    #[test]
    fn test_text_class() -> Result<()> {
        // Small word-count matrix; last row/class is the minority class.
        let train = array![
            [2., 1., 0., 0., 0., 0.0f64],
            [2., 0., 1., 0., 0., 0.],
            [1., 0., 0., 1., 0., 0.],
            [1., 0., 0., 0., 1., 1.],
        ];
        let y = array![1, 1, 1, 2];
        let test = array![[3., 0., 0., 0., 1., 1.0f64]];

        let dataset = DatasetView::new(train.view(), y.view());
        let clf = BernoulliNb::params().fit(&dataset)?;

        assert_abs_diff_eq!(clf.predict(&test), array![2]);

        // Likelihoods of the binarized test sample match reference values.
        let jll = clf.joint_log_likelihood(clf.binarize(&test).view());
        assert_abs_diff_eq!(jll.get(&1).unwrap()[0].exp(), 0.005, epsilon = 1e-3);
        assert_abs_diff_eq!(jll.get(&2).unwrap()[0].exp(), 0.022, epsilon = 1e-3);

        Ok(())
    }
}