linfa_ensemble/
hyperparams.rs

1use linfa::{
2    error::{Error, Result},
3    ParamGuard,
4};
5use linfa_trees::DecisionTreeParams;
6use rand::rngs::ThreadRng;
7use rand::Rng;
8
/// Hyper-parameters consumed by the fitting procedure of the
/// [Ensemble Learner](crate::EnsembleLearner).
///
/// This is the checked counterpart of [`EnsembleLearnerParams`]: the builder hands out a value
/// of this type once its parameter validation has succeeded.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct EnsembleLearnerValidParams<P, R> {
    /// How many models make up the ensemble
    pub ensemble_size: usize,
    /// Fraction of the total number of training samples handed to each model for training
    pub bootstrap_proportion: f64,
    /// Fraction of the total number of training features handed to each model for training
    pub feature_proportion: f64,
    /// Parameters used for the base model
    pub model_params: P,
    /// Random number generator carried along with the hyper-parameters
    // NOTE(review): presumably the source of randomness for sample/feature selection during
    // fitting -- confirm against the fitting code, which is not part of this file.
    pub rng: R,
}
23
24/// A helper struct for building a set of [Ensemble Learner](crate::EnsembleLearner) hyper-parameters.
25#[derive(Clone, Copy, Debug, PartialEq)]
26pub struct EnsembleLearnerParams<P, R>(EnsembleLearnerValidParams<P, R>);
27
28/// A helper struct for building a set of [Random Forest](crate::RandomForest) hyper-parameters.
29pub type RandomForestParams<F, L, R> = EnsembleLearnerParams<DecisionTreeParams<F, L>, R>;
30
31impl<P> EnsembleLearnerParams<P, ThreadRng> {
32    pub fn new(model_params: P) -> EnsembleLearnerParams<P, ThreadRng> {
33        Self::new_fixed_rng(model_params, rand::thread_rng())
34    }
35}
36
37impl<P, R: Rng + Clone> EnsembleLearnerParams<P, R> {
38    pub fn new_fixed_rng(model_params: P, rng: R) -> EnsembleLearnerParams<P, R> {
39        Self(EnsembleLearnerValidParams {
40            ensemble_size: 1,
41            bootstrap_proportion: 1.0,
42            feature_proportion: 1.0,
43            model_params,
44            rng,
45        })
46    }
47
48    /// Specifies the number of models to fit in the ensemble.
49    pub fn ensemble_size(mut self, size: usize) -> Self {
50        self.0.ensemble_size = size;
51        self
52    }
53
54    /// Sets the proportion of the total number of training samples that should be given to each model for training
55    ///
56    /// Note that the `proportion` should be in the interval (0, 1] in order to pass the  
57    /// parameter validation check.
58    pub fn bootstrap_proportion(mut self, proportion: f64) -> Self {
59        self.0.bootstrap_proportion = proportion;
60        self
61    }
62
63    /// Sets the proportion of the total number of training features that should be given to each model for training
64    ///
65    /// Note that the `proportion` should be in the interval (0, 1] in order to pass the
66    /// parameter validation check.
67    pub fn feature_proportion(mut self, proportion: f64) -> Self {
68        self.0.feature_proportion = proportion;
69        self
70    }
71}
72
73impl<P, R> ParamGuard for EnsembleLearnerParams<P, R> {
74    type Checked = EnsembleLearnerValidParams<P, R>;
75    type Error = Error;
76
77    fn check_ref(&self) -> Result<&Self::Checked> {
78        if self.0.bootstrap_proportion > 1.0 || self.0.bootstrap_proportion <= 0.0 {
79            Err(Error::Parameters(format!(
80                "Bootstrap proportion should be greater than zero and less than or equal to one, but was {}",
81                self.0.bootstrap_proportion
82            )))
83        } else if self.0.ensemble_size < 1 {
84            Err(Error::Parameters(format!(
85                "Ensemble size should be less than one, but was {}",
86                self.0.ensemble_size
87            )))
88        } else if self.0.feature_proportion > 1.0 || self.0.feature_proportion <= 0.0 {
89            Err(Error::Parameters(format!(
90                "Feature proportion should be greater than zero and less than or equal to one, but was {}",
91                self.0.feature_proportion
92            )))
93        } else {
94            Ok(&self.0)
95        }
96    }
97
98    fn check(self) -> Result<Self::Checked> {
99        self.check_ref()?;
100        Ok(self.0)
101    }
102}