1#![doc = include_str!("../README.md")]
2
3mod base_nb;
4mod bernoulli_nb;
5mod error;
6mod gaussian_nb;
7mod hyperparams;
8mod multinomial_nb;
9
10pub use base_nb::NaiveBayes;
11pub use bernoulli_nb::BernoulliNb;
12pub use error::{NaiveBayesError, Result};
13pub use gaussian_nb::GaussianNb;
14pub use hyperparams::{BernoulliNbParams, BernoulliNbValidParams};
15pub use hyperparams::{GaussianNbParams, GaussianNbValidParams};
16pub use hyperparams::{MultinomialNbParams, MultinomialNbValidParams};
17pub use multinomial_nb::MultinomialNb;
18
19use linfa::{Float, Label};
20use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
21
22#[cfg(feature = "serde")]
23use serde_crate::{Deserialize, Serialize};
24
/// Per-class running statistics used to derive smoothed feature log
/// probabilities.
///
/// The counters are filled in by `update_with_smoothing`; a
/// default-constructed value has a zero sample count.
#[derive(Debug, Default, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub(crate) struct ClassHistogram<F> {
    /// Number of sample rows folded into this histogram so far.
    class_count: usize,
    /// NOTE(review): presumably the prior probability of the class; it is not
    /// read or written by any code in this file, so confirm against the callers.
    prior: F,
    /// Column-wise sum of every batch folded into this histogram
    /// (one entry per feature column).
    feature_count: Array1<F>,
    /// Natural log of each smoothed feature count minus the natural log of
    /// the normalising total chosen by `update_with_smoothing`.
    feature_log_prob: Array1<F>,
}
38
39impl<F: Float> ClassHistogram<F> {
40 fn update_with_smoothing(&mut self, x_new: ArrayView2<F>, alpha: F, total_count: bool) {
42 if x_new.nrows() == 0 {
44 return;
45 }
46
47 let ClassHistogram {
49 class_count,
50 feature_count,
51 feature_log_prob,
52 ..
53 } = self;
54
55 let feature_count_new: Array1<F> = x_new.sum_axis(Axis(0));
57
58 if *class_count > 0 {
60 *feature_count = feature_count_new + feature_count.view();
61 } else {
62 *feature_count = feature_count_new;
63 }
64
65 let feature_count_smoothed = feature_count.mapv(|x| x + alpha);
67
68 let count = if total_count {
70 F::cast(x_new.nrows()) + alpha * F::cast(2)
71 } else {
72 feature_count_smoothed.sum()
73 };
74
75 *feature_log_prob = feature_count_smoothed.mapv(|x| x.ln() - count.ln());
77 *class_count += x_new.nrows();
79 }
80}
81
82pub(crate) fn filter<F: Float, L: Label + Ord>(
84 x: ArrayView2<F>,
85 y: ArrayView1<L>,
86 ycondition: &L,
87) -> Array2<F> {
88 let index = y
90 .into_iter()
91 .enumerate()
92 .filter(|&(_, y)| (*ycondition == *y))
93 .map(|(i, _)| i)
94 .collect::<Vec<_>>();
95
96 let mut xsubset = Array2::zeros((index.len(), x.ncols()));
98 index
99 .into_iter()
100 .enumerate()
101 .for_each(|(i, r)| xsubset.row_mut(i).assign(&x.slice(s![r, ..])));
102
103 xsubset
104}