// linfa_bayes/lib.rs

1#![doc = include_str!("../README.md")]
2
3mod base_nb;
4mod bernoulli_nb;
5mod error;
6mod gaussian_nb;
7mod hyperparams;
8mod multinomial_nb;
9
10pub use base_nb::NaiveBayes;
11pub use bernoulli_nb::BernoulliNb;
12pub use error::{NaiveBayesError, Result};
13pub use gaussian_nb::GaussianNb;
14pub use hyperparams::{BernoulliNbParams, BernoulliNbValidParams};
15pub use hyperparams::{GaussianNbParams, GaussianNbValidParams};
16pub use hyperparams::{MultinomialNbParams, MultinomialNbValidParams};
17pub use multinomial_nb::MultinomialNb;
18
19use linfa::{Float, Label};
20use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
21
22#[cfg(feature = "serde")]
23use serde_crate::{Deserialize, Serialize};
24
/// Histogram of class occurrences for multinomial and binomial parameter estimation
#[derive(Debug, Default, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub(crate) struct ClassHistogram<F> {
    // Number of training rows observed so far for this class
    // (incremented batch-by-batch in `update_with_smoothing`)
    class_count: usize,
    // Prior probability of the class; not written in this file —
    // presumably filled in by the fitting code in the NB modules (verify there)
    prior: F,
    // Per-feature occurrence counts accumulated across all batches seen
    feature_count: Array1<F>,
    // Smoothed log P(feature | class), recomputed on every update
    feature_log_prob: Array1<F>,
}
38
39impl<F: Float> ClassHistogram<F> {
40    // Update log probabilities of features given class
41    fn update_with_smoothing(&mut self, x_new: ArrayView2<F>, alpha: F, total_count: bool) {
42        // If incoming data is empty no updates required
43        if x_new.nrows() == 0 {
44            return;
45        }
46
47        // unpack old class information
48        let ClassHistogram {
49            class_count,
50            feature_count,
51            feature_log_prob,
52            ..
53        } = self;
54
55        // count new feature occurrences
56        let feature_count_new: Array1<F> = x_new.sum_axis(Axis(0));
57
58        // if previous batch was empty, we send the new feature count calculated
59        if *class_count > 0 {
60            *feature_count = feature_count_new + feature_count.view();
61        } else {
62            *feature_count = feature_count_new;
63        }
64
65        // apply smoothing to feature counts
66        let feature_count_smoothed = feature_count.mapv(|x| x + alpha);
67
68        // compute total count (smoothed)
69        let count = if total_count {
70            F::cast(x_new.nrows()) + alpha * F::cast(2)
71        } else {
72            feature_count_smoothed.sum()
73        };
74
75        // compute log probabilities of each feature
76        *feature_log_prob = feature_count_smoothed.mapv(|x| x.ln() - count.ln());
77        // update class count
78        *class_count += x_new.nrows();
79    }
80}
81
82/// Returns a subset of x corresponding to the class specified by `ycondition`
83pub(crate) fn filter<F: Float, L: Label + Ord>(
84    x: ArrayView2<F>,
85    y: ArrayView1<L>,
86    ycondition: &L,
87) -> Array2<F> {
88    // We identify the row numbers corresponding to the class we are interested in
89    let index = y
90        .into_iter()
91        .enumerate()
92        .filter(|&(_, y)| (*ycondition == *y))
93        .map(|(i, _)| i)
94        .collect::<Vec<_>>();
95
96    // We subset x to only records corresponding to the class represented in `ycondition`
97    let mut xsubset = Array2::zeros((index.len(), x.ncols()));
98    index
99        .into_iter()
100        .enumerate()
101        .for_each(|(i, r)| xsubset.row_mut(i).assign(&x.slice(s![r, ..])));
102
103    xsubset
104}