1use ndarray::{Array1, Array2, ArrayBase, ArrayView2, Axis, Data, Ix2, Zip};
2use ndarray_stats::QuantileExt;
3use std::collections::HashMap;
4
5use crate::error::{NaiveBayesError, Result};
6use linfa::dataset::{AsTargets, DatasetBase, Labels};
7use linfa::traits::FitWith;
8use linfa::{Float, Label};
9
10pub trait NaiveBayes<'a, F, L>
12where
13 F: Float,
14 L: Label + Ord,
15{
16 fn joint_log_likelihood(&self, x: ArrayView2<F>) -> HashMap<&L, Array1<F>>;
20
21 #[doc(hidden)]
22 fn predict_inplace<D: Data<Elem = F>>(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
23 assert_eq!(
24 x.nrows(),
25 y.len(),
26 "The number of data points must match the number of output targets."
27 );
28
29 let joint_log_likelihood = self.joint_log_likelihood(x.view());
30
31 let nclasses = joint_log_likelihood.keys().len();
35 let n = x.nrows();
36 let mut classes = Vec::with_capacity(nclasses);
37 let mut likelihood = Array2::zeros((nclasses, n));
38 joint_log_likelihood
39 .iter()
40 .enumerate()
41 .for_each(|(i, (&key, value))| {
42 classes.push(key.clone());
43 likelihood.row_mut(i).assign(value);
44 });
45
46 *y = likelihood.map_axis(Axis(0), |x| {
48 let i = x.argmax().unwrap();
49 classes[i].clone()
50 });
51 }
52
53 fn predict_log_proba(&self, x: ArrayView2<F>) -> (Array2<F>, Vec<&L>) {
56 let log_likelihood = self.joint_log_likelihood(x);
57
58 let mut classes = log_likelihood.keys().cloned().collect::<Vec<_>>();
59 classes.sort();
60
61 let n_samples = x.nrows();
62 let n_classes = log_likelihood.len();
63 let mut log_prob_mat = Array2::<F>::zeros((n_samples, n_classes));
64
65 Zip::from(log_prob_mat.columns_mut())
66 .and(&classes)
67 .for_each(|mut jll, &class| jll.assign(log_likelihood.get(class).unwrap()));
68
69 let log_prob_x = log_prob_mat
70 .mapv(|x| x.exp())
71 .sum_axis(Axis(1))
72 .mapv(|x| x.ln())
73 .into_shape((n_samples, 1))
74 .unwrap();
75
76 (log_prob_mat - log_prob_x, classes)
77 }
78
79 fn predict_proba(&self, x: ArrayView2<F>) -> (Array2<F>, Vec<&L>) {
82 let (log_prob_mat, classes) = self.predict_log_proba(x);
83
84 (log_prob_mat.mapv(|v| v.exp()), classes)
85 }
86}
87
88pub(crate) trait NaiveBayesValidParams<'a, F, L, D, T>:
90 FitWith<'a, ArrayBase<D, Ix2>, T, NaiveBayesError>
91where
92 F: Float,
93 L: Label + Ord,
94 D: Data<Elem = F>,
95 T: AsTargets<Elem = L> + Labels<Elem = L>,
96{
97 fn fit(
98 &self,
99 dataset: &'a DatasetBase<ArrayBase<D, Ix2>, T>,
100 model_none: Self::ObjectIn,
101 ) -> Result<Self::ObjectOut> {
102 let mut unique_classes = dataset.targets.labels();
104 unique_classes.sort_unstable();
105
106 self.fit_with(model_none, dataset)
107 }
108}