linfa_ensemble/
hyperparams.rs1use linfa::{
2 error::{Error, Result},
3 ParamGuard,
4};
5use linfa_trees::DecisionTreeParams;
6use rand::rngs::ThreadRng;
7use rand::Rng;
8
/// Ensemble-learner hyperparameters that have passed validation.
///
/// Values of this type are produced by `ParamGuard::check` / `check_ref` on
/// [`EnsembleLearnerParams`], so the range checks described on each field have
/// already been applied.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct EnsembleLearnerValidParams<P, R> {
    /// Number of models in the ensemble; validation rejects values below 1.
    pub ensemble_size: usize,
    /// Proportion checked against the range `(0, 1]` during validation.
    /// Presumably the fraction of training samples drawn for each model's
    /// bootstrap sample — confirm against the fitting code.
    pub bootstrap_proportion: f64,
    /// Proportion checked against the range `(0, 1]` during validation.
    /// Presumably the fraction of features each model is trained on — confirm
    /// against the fitting code.
    pub feature_proportion: f64,
    /// Hyperparameters of the underlying base model (for example
    /// `DecisionTreeParams` in [`RandomForestParams`]).
    pub model_params: P,
    /// Random number generator carried alongside the hyperparameters.
    pub rng: R,
}

24#[derive(Clone, Copy, Debug, PartialEq)]
26pub struct EnsembleLearnerParams<P, R>(EnsembleLearnerValidParams<P, R>);
27
28pub type RandomForestParams<F, L, R> = EnsembleLearnerParams<DecisionTreeParams<F, L>, R>;
30
31impl<P> EnsembleLearnerParams<P, ThreadRng> {
32 pub fn new(model_params: P) -> EnsembleLearnerParams<P, ThreadRng> {
33 Self::new_fixed_rng(model_params, rand::thread_rng())
34 }
35}
36
37impl<P, R: Rng + Clone> EnsembleLearnerParams<P, R> {
38 pub fn new_fixed_rng(model_params: P, rng: R) -> EnsembleLearnerParams<P, R> {
39 Self(EnsembleLearnerValidParams {
40 ensemble_size: 1,
41 bootstrap_proportion: 1.0,
42 feature_proportion: 1.0,
43 model_params,
44 rng,
45 })
46 }
47
48 pub fn ensemble_size(mut self, size: usize) -> Self {
50 self.0.ensemble_size = size;
51 self
52 }
53
54 pub fn bootstrap_proportion(mut self, proportion: f64) -> Self {
59 self.0.bootstrap_proportion = proportion;
60 self
61 }
62
63 pub fn feature_proportion(mut self, proportion: f64) -> Self {
68 self.0.feature_proportion = proportion;
69 self
70 }
71}
72
73impl<P, R> ParamGuard for EnsembleLearnerParams<P, R> {
74 type Checked = EnsembleLearnerValidParams<P, R>;
75 type Error = Error;
76
77 fn check_ref(&self) -> Result<&Self::Checked> {
78 if self.0.bootstrap_proportion > 1.0 || self.0.bootstrap_proportion <= 0.0 {
79 Err(Error::Parameters(format!(
80 "Bootstrap proportion should be greater than zero and less than or equal to one, but was {}",
81 self.0.bootstrap_proportion
82 )))
83 } else if self.0.ensemble_size < 1 {
84 Err(Error::Parameters(format!(
85 "Ensemble size should be less than one, but was {}",
86 self.0.ensemble_size
87 )))
88 } else if self.0.feature_proportion > 1.0 || self.0.feature_proportion <= 0.0 {
89 Err(Error::Parameters(format!(
90 "Feature proportion should be greater than zero and less than or equal to one, but was {}",
91 self.0.feature_proportion
92 )))
93 } else {
94 Ok(&self.0)
95 }
96 }
97
98 fn check(self) -> Result<Self::Checked> {
99 self.check_ref()?;
100 Ok(self.0)
101 }
102}