linfa_bayes/
base_nb.rs

1use ndarray::{Array1, Array2, ArrayBase, ArrayView2, Axis, Data, Ix2, Zip};
2use ndarray_stats::QuantileExt;
3use std::collections::HashMap;
4
5use crate::error::{NaiveBayesError, Result};
6use linfa::dataset::{AsTargets, DatasetBase, Labels};
7use linfa::traits::FitWith;
8use linfa::{Float, Label};
9
10// Trait computing predictions for fitted Naive Bayes models
11pub trait NaiveBayes<'a, F, L>
12where
13    F: Float,
14    L: Label + Ord,
15{
16    /// Compute the unnormalized posterior log probabilities.
17    /// The result is returned as an HashMap indexing log probabilities for each samples (eg x rows) by classes
18    /// (eg jll\[class\] -> (n_samples,) array)
19    fn joint_log_likelihood(&self, x: ArrayView2<F>) -> HashMap<&L, Array1<F>>;
20
21    #[doc(hidden)]
22    fn predict_inplace<D: Data<Elem = F>>(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
23        assert_eq!(
24            x.nrows(),
25            y.len(),
26            "The number of data points must match the number of output targets."
27        );
28
29        let joint_log_likelihood = self.joint_log_likelihood(x.view());
30
31        // We store the classes and likelihood info in an vec and matrix
32        // respectively for easier identification of the dominant class for
33        // each input
34        let nclasses = joint_log_likelihood.keys().len();
35        let n = x.nrows();
36        let mut classes = Vec::with_capacity(nclasses);
37        let mut likelihood = Array2::zeros((nclasses, n));
38        joint_log_likelihood
39            .iter()
40            .enumerate()
41            .for_each(|(i, (&key, value))| {
42                classes.push(key.clone());
43                likelihood.row_mut(i).assign(value);
44            });
45
46        // Identify the class with the maximum log likelihood
47        *y = likelihood.map_axis(Axis(0), |x| {
48            let i = x.argmax().unwrap();
49            classes[i].clone()
50        });
51    }
52
53    /// Compute log-probability estimates for each sample wrt classes.
54    /// The columns corresponds to classes in sorted order returned as the second output.
55    fn predict_log_proba(&self, x: ArrayView2<F>) -> (Array2<F>, Vec<&L>) {
56        let log_likelihood = self.joint_log_likelihood(x);
57
58        let mut classes = log_likelihood.keys().cloned().collect::<Vec<_>>();
59        classes.sort();
60
61        let n_samples = x.nrows();
62        let n_classes = log_likelihood.len();
63        let mut log_prob_mat = Array2::<F>::zeros((n_samples, n_classes));
64
65        Zip::from(log_prob_mat.columns_mut())
66            .and(&classes)
67            .for_each(|mut jll, &class| jll.assign(log_likelihood.get(class).unwrap()));
68
69        let log_prob_x = log_prob_mat
70            .mapv(|x| x.exp())
71            .sum_axis(Axis(1))
72            .mapv(|x| x.ln())
73            .into_shape((n_samples, 1))
74            .unwrap();
75
76        (log_prob_mat - log_prob_x, classes)
77    }
78
79    /// Compute probability estimates for each sample wrt classes.
80    /// The columns corresponds to classes in sorted order returned as the second output.  
81    fn predict_proba(&self, x: ArrayView2<F>) -> (Array2<F>, Vec<&L>) {
82        let (log_prob_mat, classes) = self.predict_log_proba(x);
83
84        (log_prob_mat.mapv(|v| v.exp()), classes)
85    }
86}
87
88// Common functionality for hyper-parameter sets of Naive Bayes models ready for estimation
89pub(crate) trait NaiveBayesValidParams<'a, F, L, D, T>:
90    FitWith<'a, ArrayBase<D, Ix2>, T, NaiveBayesError>
91where
92    F: Float,
93    L: Label + Ord,
94    D: Data<Elem = F>,
95    T: AsTargets<Elem = L> + Labels<Elem = L>,
96{
97    fn fit(
98        &self,
99        dataset: &'a DatasetBase<ArrayBase<D, Ix2>, T>,
100        model_none: Self::ObjectIn,
101    ) -> Result<Self::ObjectOut> {
102        // We extract the unique classes in sorted order
103        let mut unique_classes = dataset.targets.labels();
104        unique_classes.sort_unstable();
105
106        self.fit_with(model_none, dataset)
107    }
108}