linfa_bayes/hyperparams.rs
1use crate::NaiveBayesError;
2use linfa::{Float, ParamGuard};
3use std::marker::PhantomData;
4
5#[cfg(feature = "serde")]
6use serde_crate::{Deserialize, Serialize};
7
/// A validated hyper-parameter set for fitting a [Gaussian Naive Bayes model](crate::gaussian_nb::GaussianNb).
///
/// Values of this type are produced by checking a
/// [`GaussianNbParams`](crate::hyperparams::GaussianNbParams) builder, which also documents the
/// individual hyper-parameters. The model itself is described in
/// [`GaussianNb`](crate::gaussian_nb::GaussianNb).
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GaussianNbValidParams<F, L> {
    /// Smoothing term that keeps the variance computation numerically stable.
    var_smoothing: F,
    /// Zero-sized marker that ties the label type `L` to this parameter set.
    label: PhantomData<L>,
}
23
24impl<F: Float, L> GaussianNbValidParams<F, L> {
25 /// Get the variance smoothing
26 pub fn var_smoothing(&self) -> F {
27 self.var_smoothing
28 }
29}
30
31/// A hyper-parameter set during construction for a [Gaussian Naive Bayes model](crate::gaussian_nb::GaussianNb).
32///
33/// The parameter set can be verified into a
34/// [`GaussianNbValidParams`](crate::hyperparams::GaussianNbValidParams) by calling
35/// [ParamGuard::check](Self::check). It is also possible to directly fit a model with
36/// [Fit::fit](linfa::traits::Fit::fit) or
37/// [FitWith::fit_with](linfa::traits::FitWith::fit_with) which implicitely verifies the parameter set
38/// prior to the model estimation and forwards any error.
39///
40/// See [`GaussianNb`](crate::gaussian_nb::GaussianNb) for information on the model.
41///
42/// # Parameters
43/// | Name | Default | Purpose | Range |
44/// | :--- | :--- | :---| :--- |
45/// | [var_smoothing](Self::var_smoothing) | `1e-9` | Stabilize variance calculation if ratios are small in update step | `[0, inf)` |
46///
47/// # Errors
48///
49/// The following errors can come from invalid hyper-parameters:
50///
51/// Returns [`InvalidSmoothing`](NaiveBayesError::InvalidSmoothing) if the smoothing
52/// parameter is negative.
53///
54#[derive(Debug, Clone, PartialEq, Eq)]
55pub struct GaussianNbParams<F, L>(GaussianNbValidParams<F, L>);
56
57impl<F: Float, L> Default for GaussianNbParams<F, L> {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl<F: Float, L> GaussianNbParams<F, L> {
64 /// Create new [GaussianNbParams] set with default values for its parameters
65 pub fn new() -> Self {
66 Self(GaussianNbValidParams {
67 var_smoothing: F::cast(1e-9),
68 label: PhantomData,
69 })
70 }
71
72 /// Specifies the portion of the largest variance of all the features that
73 /// is added to the variance for calculation stability
74 pub fn var_smoothing(mut self, var_smoothing: F) -> Self {
75 self.0.var_smoothing = var_smoothing;
76 self
77 }
78}
79
80impl<F: Float, L> ParamGuard for GaussianNbParams<F, L> {
81 type Checked = GaussianNbValidParams<F, L>;
82 type Error = NaiveBayesError;
83
84 fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
85 if self.0.var_smoothing.is_negative() {
86 Err(NaiveBayesError::InvalidSmoothing(
87 self.0.var_smoothing.to_f64().unwrap(),
88 ))
89 } else {
90 Ok(&self.0)
91 }
92 }
93
94 fn check(self) -> Result<Self::Checked, Self::Error> {
95 self.check_ref()?;
96 Ok(self.0)
97 }
98}
99
/// A validated hyper-parameter set for fitting a [Multinomial Naive Bayes model](crate::multinomial_nb::MultinomialNb).
///
/// Values of this type are produced by checking a
/// [`MultinomialNbParams`](crate::hyperparams::MultinomialNbParams) builder, which also documents
/// the individual hyper-parameters. The model itself is described in
/// [`MultinomialNb`](crate::multinomial_nb::MultinomialNb).
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultinomialNbValidParams<F, L> {
    /// Additive (Laplace/Lidstone) smoothing parameter.
    alpha: F,
    /// Zero-sized marker that ties the label type `L` to this parameter set.
    label: PhantomData<L>,
}
115
116impl<F: Float, L> MultinomialNbValidParams<F, L> {
117 /// Get the variance smoothing
118 pub fn alpha(&self) -> F {
119 self.alpha
120 }
121}
122
123/// A hyper-parameter set during construction for a [Multinomial Naive Bayes model](crate::multinomial_nb::MultinomialNb).
124///
125/// The parameter set can be verified into a
126/// [`MultinomialNbValidParams`](crate::hyperparams::MultinomialNbValidParams) by calling
127/// [ParamGuard::check](Self::check). It is also possible to directly fit a model with
128/// [Fit::fit](linfa::traits::Fit::fit) or
129/// [FitWith::fit_with](linfa::traits::FitWith::fit_with) which implicitely verifies the parameter set
130/// prior to the model estimation and forwards any error.
131///
132/// See [`MultinomialNb`](crate::multinomial_nb::MultinomialNb) for information on the model.
133///
134/// # Parameters
135/// | Name | Default | Purpose | Range |
136/// | :--- | :--- | :---| :--- |
137/// | [alpha](Self::alpha) | `1` | Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing) | `[0, inf)` |
138///
139/// # Errors
140///
141/// The following errors can come from invalid hyper-parameters:
142///
143/// Returns [`InvalidSmoothing`](NaiveBayesError::InvalidSmoothing) if the smoothing
144/// parameter is negative.
145///
146#[derive(Debug, Clone, PartialEq, Eq)]
147pub struct MultinomialNbParams<F, L>(MultinomialNbValidParams<F, L>);
148
149impl<F: Float, L> Default for MultinomialNbParams<F, L> {
150 fn default() -> Self {
151 Self::new()
152 }
153}
154
155impl<F: Float, L> MultinomialNbParams<F, L> {
156 /// Create new [MultinomialNbParams] set with default values for its parameters
157 pub fn new() -> Self {
158 Self(MultinomialNbValidParams {
159 alpha: F::cast(1),
160 label: PhantomData,
161 })
162 }
163
164 /// Specifies the portion of the largest variance of all the features that
165 /// is added to the variance for calculation stability
166 pub fn alpha(mut self, alpha: F) -> Self {
167 self.0.alpha = alpha;
168 self
169 }
170}
171
172impl<F: Float, L> ParamGuard for MultinomialNbParams<F, L> {
173 type Checked = MultinomialNbValidParams<F, L>;
174 type Error = NaiveBayesError;
175
176 fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
177 if self.0.alpha.is_negative() {
178 Err(NaiveBayesError::InvalidSmoothing(
179 self.0.alpha.to_f64().unwrap(),
180 ))
181 } else {
182 Ok(&self.0)
183 }
184 }
185
186 fn check(self) -> Result<Self::Checked, Self::Error> {
187 self.check_ref()?;
188 Ok(self.0)
189 }
190}
191
/// A verified hyper-parameter set ready for the estimation of a [Bernoulli Naive Bayes model](crate::bernoulli_nb::BernoulliNb).
///
/// See [`BernoulliNb`](crate::bernoulli_nb::BernoulliNb) for information on the model and [`BernoulliNbParams`](crate::hyperparams::BernoulliNbParams) for information on hyperparameters.
// Derives mirror `GaussianNbValidParams` / `MultinomialNbValidParams` so all verified
// parameter sets can be (de)serialized and compared the same way.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BernoulliNbValidParams<F, L> {
    /// Additive (Laplace/Lidstone) smoothing parameter.
    alpha: F,
    /// Threshold for binarization; `None` means the input is presumed to be binary already.
    binarize: Option<F>,
    /// Zero-sized marker that ties the label type `L` to this parameter set.
    label: PhantomData<L>,
}
204
205impl<F: Float, L> BernoulliNbValidParams<F, L> {
206 /// Get the variance smoothing
207 pub fn alpha(&self) -> F {
208 self.alpha
209 }
210 /// Get the binarization threshold
211 pub fn binarize(&self) -> Option<F> {
212 self.binarize
213 }
214}
215
216/// A hyper-parameter set during construction for a [Bernoulli Naive Bayes model](crate::bernoulli_nb::BernoulliNb).
217///
218/// The parameter set can be verified into a
219/// [`BernoulliNbValidParams`](crate::hyperparams::BernoulliNbValidParams) by calling
220/// [ParamGuard::check](Self::check). It is also possible to directly fit a model with
221/// [Fit::fit](linfa::traits::Fit::fit) or
222/// [FitWith::fit_with](linfa::traits::FitWith::fit_with) which implicitly verifies the parameter set
223/// prior to the model estimation and forwards any error.
224///
225/// See [`BernoulliNb`](crate::bernoulli_nb::BernoulliNb) for information on the model.
226///
227/// # Parameters
228/// | Name | Default | Purpose | Range |
229/// | :--- | :--- | :---| :--- |
230/// | [alpha](Self::alpha) | `1` | Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing) | `[0, inf)` |
231/// | [binarize](Self::binarize) | `0.0` | Threshold for binarization (mapping to booleans) of sample features. If `None`, input is presumed to already consist of binary vectors. | `(-inf, inf)` |
232///
233/// # Errors
234///
235/// The following errors can come from invalid hyper-parameters:
236///
237/// Returns [`InvalidSmoothing`](NaiveBayesError::InvalidSmoothing) if the smoothing
238/// parameter is negative.
239///
240#[derive(Debug, Clone, PartialEq)]
241pub struct BernoulliNbParams<F, L>(BernoulliNbValidParams<F, L>);
242
243impl<F: Float, L> Default for BernoulliNbParams<F, L> {
244 fn default() -> Self {
245 Self::new()
246 }
247}
248
249impl<F: Float, L> BernoulliNbParams<F, L> {
250 /// Create new [BernoulliNbParams] set with default values for its parameters
251 pub fn new() -> Self {
252 Self(BernoulliNbValidParams {
253 alpha: F::one(),
254 binarize: Some(F::zero()),
255 label: PhantomData,
256 })
257 }
258
259 /// Specifies the portion of the largest variance of all the features that
260 /// is added to the variance for calculation stability
261 pub fn alpha(mut self, alpha: F) -> Self {
262 self.0.alpha = alpha;
263 self
264 }
265
266 /// Set the binarization threshold
267 pub fn binarize(mut self, threshold: Option<F>) -> Self {
268 self.0.binarize = threshold;
269 self
270 }
271}
272
273impl<F: Float, L> ParamGuard for BernoulliNbParams<F, L> {
274 type Checked = BernoulliNbValidParams<F, L>;
275 type Error = NaiveBayesError;
276
277 fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
278 if self.0.alpha.is_negative() {
279 Err(NaiveBayesError::InvalidSmoothing(
280 self.0.alpha.to_f64().unwrap(),
281 ))
282 } else {
283 Ok(&self.0)
284 }
285 }
286
287 fn check(self) -> Result<Self::Checked, Self::Error> {
288 self.check_ref()?;
289 Ok(self.0)
290 }
291}