linfa_bayes/
hyperparams.rs

1use crate::NaiveBayesError;
2use linfa::{Float, ParamGuard};
3use std::marker::PhantomData;
4
5#[cfg(feature = "serde")]
6use serde_crate::{Deserialize, Serialize};
7
/// Checked hyper-parameters for fitting a [Gaussian Naive Bayes model](crate::gaussian_nb::GaussianNb).
///
/// A checked set is obtained by validating a
/// [`GaussianNbParams`](crate::hyperparams::GaussianNbParams) set, whose documentation lists the
/// available hyper-parameters; see [`GaussianNb`](crate::gaussian_nb::GaussianNb) for the model.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct GaussianNbValidParams<F, L> {
    /// Smoothing term kept for numerical stability of the variance calculation.
    var_smoothing: F,
    /// Zero-sized marker binding this parameter set to the label type `L`.
    label: PhantomData<L>,
}
23
24impl<F: Float, L> GaussianNbValidParams<F, L> {
25    /// Get the variance smoothing
26    pub fn var_smoothing(&self) -> F {
27        self.var_smoothing
28    }
29}
30
31/// A hyper-parameter set during construction for a [Gaussian Naive Bayes model](crate::gaussian_nb::GaussianNb).
32///
33/// The parameter set can be verified into a
34/// [`GaussianNbValidParams`](crate::hyperparams::GaussianNbValidParams) by calling
35/// [ParamGuard::check](Self::check). It is also possible to directly fit a model with
36/// [Fit::fit](linfa::traits::Fit::fit) or
37/// [FitWith::fit_with](linfa::traits::FitWith::fit_with) which implicitely verifies the parameter set
38/// prior to the model estimation and forwards any error.
39///
40/// See [`GaussianNb`](crate::gaussian_nb::GaussianNb) for information on the model.
41///
42/// # Parameters
43/// | Name | Default | Purpose | Range |
44/// | :--- | :--- | :---| :--- |
45/// | [var_smoothing](Self::var_smoothing) | `1e-9` | Stabilize variance calculation if ratios are small in update step | `[0, inf)` |
46///
47/// # Errors
48///
49/// The following errors can come from invalid hyper-parameters:
50///
51/// Returns [`InvalidSmoothing`](NaiveBayesError::InvalidSmoothing) if the smoothing
52/// parameter is negative.
53///
54#[derive(Debug, Clone, PartialEq, Eq)]
55pub struct GaussianNbParams<F, L>(GaussianNbValidParams<F, L>);
56
57impl<F: Float, L> Default for GaussianNbParams<F, L> {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63impl<F: Float, L> GaussianNbParams<F, L> {
64    /// Create new [GaussianNbParams] set with default values for its parameters
65    pub fn new() -> Self {
66        Self(GaussianNbValidParams {
67            var_smoothing: F::cast(1e-9),
68            label: PhantomData,
69        })
70    }
71
72    /// Specifies the portion of the largest variance of all the features that
73    /// is added to the variance for calculation stability
74    pub fn var_smoothing(mut self, var_smoothing: F) -> Self {
75        self.0.var_smoothing = var_smoothing;
76        self
77    }
78}
79
80impl<F: Float, L> ParamGuard for GaussianNbParams<F, L> {
81    type Checked = GaussianNbValidParams<F, L>;
82    type Error = NaiveBayesError;
83
84    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
85        if self.0.var_smoothing.is_negative() {
86            Err(NaiveBayesError::InvalidSmoothing(
87                self.0.var_smoothing.to_f64().unwrap(),
88            ))
89        } else {
90            Ok(&self.0)
91        }
92    }
93
94    fn check(self) -> Result<Self::Checked, Self::Error> {
95        self.check_ref()?;
96        Ok(self.0)
97    }
98}
99
/// Checked hyper-parameters for fitting a [Multinomial Naive Bayes model](crate::multinomial_nb::MultinomialNb).
///
/// A checked set is obtained by validating a
/// [`MultinomialNbParams`](crate::hyperparams::MultinomialNbParams) set, whose documentation lists
/// the available hyper-parameters; see [`MultinomialNb`](crate::multinomial_nb::MultinomialNb) for
/// the model.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MultinomialNbValidParams<F, L> {
    /// Additive (Laplace/Lidstone) smoothing parameter.
    alpha: F,
    /// Zero-sized marker binding this parameter set to the label type `L`.
    label: PhantomData<L>,
}
115
116impl<F: Float, L> MultinomialNbValidParams<F, L> {
117    /// Get the variance smoothing
118    pub fn alpha(&self) -> F {
119        self.alpha
120    }
121}
122
123/// A hyper-parameter set during construction for a [Multinomial Naive Bayes model](crate::multinomial_nb::MultinomialNb).
124///
125/// The parameter set can be verified into a
126/// [`MultinomialNbValidParams`](crate::hyperparams::MultinomialNbValidParams) by calling
127/// [ParamGuard::check](Self::check). It is also possible to directly fit a model with
128/// [Fit::fit](linfa::traits::Fit::fit) or
129/// [FitWith::fit_with](linfa::traits::FitWith::fit_with) which implicitely verifies the parameter set
130/// prior to the model estimation and forwards any error.
131///
132/// See [`MultinomialNb`](crate::multinomial_nb::MultinomialNb) for information on the model.
133///
134/// # Parameters
135/// | Name | Default | Purpose | Range |
136/// | :--- | :--- | :---| :--- |
137/// | [alpha](Self::alpha) | `1` | Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing) | `[0, inf)` |
138///
139/// # Errors
140///
141/// The following errors can come from invalid hyper-parameters:
142///
143/// Returns [`InvalidSmoothing`](NaiveBayesError::InvalidSmoothing) if the smoothing
144/// parameter is negative.
145///
146#[derive(Debug, Clone, PartialEq, Eq)]
147pub struct MultinomialNbParams<F, L>(MultinomialNbValidParams<F, L>);
148
149impl<F: Float, L> Default for MultinomialNbParams<F, L> {
150    fn default() -> Self {
151        Self::new()
152    }
153}
154
155impl<F: Float, L> MultinomialNbParams<F, L> {
156    /// Create new [MultinomialNbParams] set with default values for its parameters
157    pub fn new() -> Self {
158        Self(MultinomialNbValidParams {
159            alpha: F::cast(1),
160            label: PhantomData,
161        })
162    }
163
164    /// Specifies the portion of the largest variance of all the features that
165    /// is added to the variance for calculation stability
166    pub fn alpha(mut self, alpha: F) -> Self {
167        self.0.alpha = alpha;
168        self
169    }
170}
171
172impl<F: Float, L> ParamGuard for MultinomialNbParams<F, L> {
173    type Checked = MultinomialNbValidParams<F, L>;
174    type Error = NaiveBayesError;
175
176    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
177        if self.0.alpha.is_negative() {
178            Err(NaiveBayesError::InvalidSmoothing(
179                self.0.alpha.to_f64().unwrap(),
180            ))
181        } else {
182            Ok(&self.0)
183        }
184    }
185
186    fn check(self) -> Result<Self::Checked, Self::Error> {
187        self.check_ref()?;
188        Ok(self.0)
189    }
190}
191
/// A verified hyper-parameter set ready for the estimation of a [Bernoulli Naive Bayes model](crate::bernoulli_nb::BernoulliNb).
///
/// See [`BernoulliNb`](crate::bernoulli_nb::BernoulliNb) for information on the model and [`BernoulliNbParams`](crate::hyperparams::BernoulliNbParams) for information on hyperparameters.
// Optional serde support, matching the Gaussian and Multinomial valid-parameter types in
// this module, which derive it behind the same `serde` feature.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct BernoulliNbValidParams<F, L> {
    // Additive (Laplace/Lidstone) smoothing parameter
    alpha: F,
    // Threshold for binarization; `None` means the input is presumed to be binary already
    binarize: Option<F>,
    // Phantom data for label type
    label: PhantomData<L>,
}
204
205impl<F: Float, L> BernoulliNbValidParams<F, L> {
206    /// Get the variance smoothing
207    pub fn alpha(&self) -> F {
208        self.alpha
209    }
210    /// Get the binarization threshold
211    pub fn binarize(&self) -> Option<F> {
212        self.binarize
213    }
214}
215
216/// A hyper-parameter set during construction for a [Bernoulli Naive Bayes model](crate::bernoulli_nb::BernoulliNb).
217///
218/// The parameter set can be verified into a
219/// [`BernoulliNbValidParams`](crate::hyperparams::BernoulliNbValidParams) by calling
220/// [ParamGuard::check](Self::check). It is also possible to directly fit a model with
221/// [Fit::fit](linfa::traits::Fit::fit) or
222/// [FitWith::fit_with](linfa::traits::FitWith::fit_with) which implicitly verifies the parameter set
223/// prior to the model estimation and forwards any error.
224///
225/// See [`BernoulliNb`](crate::bernoulli_nb::BernoulliNb) for information on the model.
226///
227/// # Parameters
228/// | Name | Default | Purpose | Range |
229/// | :--- | :--- | :---| :--- |
230/// | [alpha](Self::alpha) | `1` | Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing) | `[0, inf)` |
231/// | [binarize](Self::binarize) | `0.0` | Threshold for binarization (mapping to booleans) of sample features. If `None`, input is presumed to already consist of binary vectors. | `(-inf, inf)` |
232///
233/// # Errors
234///
235/// The following errors can come from invalid hyper-parameters:
236///
237/// Returns [`InvalidSmoothing`](NaiveBayesError::InvalidSmoothing) if the smoothing
238/// parameter is negative.
239///
240#[derive(Debug, Clone, PartialEq)]
241pub struct BernoulliNbParams<F, L>(BernoulliNbValidParams<F, L>);
242
243impl<F: Float, L> Default for BernoulliNbParams<F, L> {
244    fn default() -> Self {
245        Self::new()
246    }
247}
248
249impl<F: Float, L> BernoulliNbParams<F, L> {
250    /// Create new [BernoulliNbParams] set with default values for its parameters
251    pub fn new() -> Self {
252        Self(BernoulliNbValidParams {
253            alpha: F::one(),
254            binarize: Some(F::zero()),
255            label: PhantomData,
256        })
257    }
258
259    /// Specifies the portion of the largest variance of all the features that
260    /// is added to the variance for calculation stability
261    pub fn alpha(mut self, alpha: F) -> Self {
262        self.0.alpha = alpha;
263        self
264    }
265
266    /// Set the binarization threshold
267    pub fn binarize(mut self, threshold: Option<F>) -> Self {
268        self.0.binarize = threshold;
269        self
270    }
271}
272
273impl<F: Float, L> ParamGuard for BernoulliNbParams<F, L> {
274    type Checked = BernoulliNbValidParams<F, L>;
275    type Error = NaiveBayesError;
276
277    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
278        if self.0.alpha.is_negative() {
279            Err(NaiveBayesError::InvalidSmoothing(
280                self.0.alpha.to_f64().unwrap(),
281            ))
282        } else {
283            Ok(&self.0)
284        }
285    }
286
287    fn check(self) -> Result<Self::Checked, Self::Error> {
288        self.check_ref()?;
289        Ok(self.0)
290    }
291}