linfa_ensemble/
hyperparams.rs1use linfa::{
2 error::{Error, Result},
3 ParamGuard,
4};
5use rand::rngs::ThreadRng;
6use rand::Rng;
7
/// Validated hyperparameters for an ensemble learner.
///
/// Values of this type are normally produced from [`EnsembleLearnerParams`]
/// through its `ParamGuard` implementation. The fields are public, so code
/// that builds this struct literally bypasses that validation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct EnsembleLearnerValidParams<P, R> {
    /// Number of members in the ensemble; expected to be at least one.
    pub ensemble_size: usize,
    /// Bootstrap sampling proportion; expected to lie in `(0, 1]`.
    pub bootstrap_proportion: f64,
    /// Hyperparameters of the underlying base model.
    pub model_params: P,
    /// Random number generator carried with the parameters
    /// (presumably used for bootstrap sampling — not shown in this file).
    pub rng: R,
}
18
19#[derive(Clone, Copy, Debug, PartialEq)]
20pub struct EnsembleLearnerParams<P, R>(EnsembleLearnerValidParams<P, R>);
21
22impl<P> EnsembleLearnerParams<P, ThreadRng> {
23 pub fn new(model_params: P) -> EnsembleLearnerParams<P, ThreadRng> {
24 Self::new_fixed_rng(model_params, rand::thread_rng())
25 }
26}
27
28impl<P, R: Rng + Clone> EnsembleLearnerParams<P, R> {
29 pub fn new_fixed_rng(model_params: P, rng: R) -> EnsembleLearnerParams<P, R> {
30 Self(EnsembleLearnerValidParams {
31 ensemble_size: 1,
32 bootstrap_proportion: 1.0,
33 model_params,
34 rng,
35 })
36 }
37
38 pub fn ensemble_size(mut self, size: usize) -> Self {
39 self.0.ensemble_size = size;
40 self
41 }
42
43 pub fn bootstrap_proportion(mut self, proportion: f64) -> Self {
44 self.0.bootstrap_proportion = proportion;
45 self
46 }
47}
48
49impl<P, R> ParamGuard for EnsembleLearnerParams<P, R> {
50 type Checked = EnsembleLearnerValidParams<P, R>;
51 type Error = Error;
52
53 fn check_ref(&self) -> Result<&Self::Checked> {
54 if self.0.bootstrap_proportion > 1.0 || self.0.bootstrap_proportion <= 0.0 {
55 Err(Error::Parameters(format!(
56 "Bootstrap proportion should be greater than zero and less than or equal to one, but was {}",
57 self.0.bootstrap_proportion
58 )))
59 } else if self.0.ensemble_size < 1 {
60 Err(Error::Parameters(format!(
61 "Ensemble size should be less than one, but was {}",
62 self.0.ensemble_size
63 )))
64 } else {
65 Ok(&self.0)
66 }
67 }
68
69 fn check(self) -> Result<Self::Checked> {
70 self.check_ref()?;
71 Ok(self.0)
72 }
73}