linfa_ensemble/hyperparams.rs

1use linfa::{
2    error::{Error, Result},
3    ParamGuard,
4};
5use rand::rngs::ThreadRng;
6use rand::Rng;
7
/// Validated hyperparameters for an ensemble learner.
///
/// Values of this type are produced by validating an `EnsembleLearnerParams`
/// through `ParamGuard::check` / `ParamGuard::check_ref`.
///
/// Note: the derived `Copy` only applies when both `P` and `R` are `Copy`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct EnsembleLearnerValidParams<P, R> {
    /// The number of models in the ensemble
    pub ensemble_size: usize,
    /// The proportion of the total number of training samples that should be given to each model for training
    ///
    /// Validation requires this to lie in the half-open interval `(0, 1]`.
    pub bootstrap_proportion: f64,
    /// The model parameters for the base model
    pub model_params: P,
    /// Random number generator carried with the parameters.
    // NOTE(review): presumably used to draw each model's bootstrap sample during
    // fitting; the consuming code is not visible here — confirm.
    pub rng: R,
}
18
/// Unchecked builder for ensemble hyperparameters.
///
/// Configure it with the builder methods, then validate it via `ParamGuard`
/// (`check` / `check_ref`) to obtain an `EnsembleLearnerValidParams`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct EnsembleLearnerParams<P, R>(EnsembleLearnerValidParams<P, R>);
21
22impl<P> EnsembleLearnerParams<P, ThreadRng> {
23    pub fn new(model_params: P) -> EnsembleLearnerParams<P, ThreadRng> {
24        Self::new_fixed_rng(model_params, rand::thread_rng())
25    }
26}
27
28impl<P, R: Rng + Clone> EnsembleLearnerParams<P, R> {
29    pub fn new_fixed_rng(model_params: P, rng: R) -> EnsembleLearnerParams<P, R> {
30        Self(EnsembleLearnerValidParams {
31            ensemble_size: 1,
32            bootstrap_proportion: 1.0,
33            model_params,
34            rng,
35        })
36    }
37
38    pub fn ensemble_size(mut self, size: usize) -> Self {
39        self.0.ensemble_size = size;
40        self
41    }
42
43    pub fn bootstrap_proportion(mut self, proportion: f64) -> Self {
44        self.0.bootstrap_proportion = proportion;
45        self
46    }
47}
48
49impl<P, R> ParamGuard for EnsembleLearnerParams<P, R> {
50    type Checked = EnsembleLearnerValidParams<P, R>;
51    type Error = Error;
52
53    fn check_ref(&self) -> Result<&Self::Checked> {
54        if self.0.bootstrap_proportion > 1.0 || self.0.bootstrap_proportion <= 0.0 {
55            Err(Error::Parameters(format!(
56                "Bootstrap proportion should be greater than zero and less than or equal to one, but was {}",
57                self.0.bootstrap_proportion
58            )))
59        } else if self.0.ensemble_size < 1 {
60            Err(Error::Parameters(format!(
61                "Ensemble size should be less than one, but was {}",
62                self.0.ensemble_size
63            )))
64        } else {
65            Ok(&self.0)
66        }
67    }
68
69    fn check(self) -> Result<Self::Checked> {
70        self.check_ref()?;
71        Ok(self.0)
72    }
73}