linfa_ensemble/
algorithm.rs

1use crate::EnsembleLearnerValidParams;
2use linfa::{
3    dataset::{AsTargets, AsTargetsMut, FromTargetArrayOwned, Records},
4    error::Error,
5    traits::*,
6    DatasetBase,
7};
8use linfa_trees::DecisionTree;
9use ndarray::{Array2, Axis, Zip};
10use rand::Rng;
11use std::{cmp::Eq, collections::HashMap, hash::Hash};
12
13/// A fitted ensemble of [Decision Trees](DecisionTree) trained on a random subset of features.
14///
15/// Check out [EnsembleLearner] documentation for more information regarding [RandomForest] interface.
16pub type RandomForest<F, L> = EnsembleLearner<DecisionTree<F, L>>;
17
/// A fitted ensemble of learners for classification.
///
/// ## Structure
///
/// An Ensemble Learner is composed of a collection of fitted models of type `M`.
///
/// ## Fitting Algorithm
///
/// Given a [`DatasetBase`] denoted as `D`,
/// 1. Create as many distinct bootstrapped subsets of the original dataset `D` as there are
///    models to fit.
/// 2. Fit each model on its own bootstrapped subset of `D`.
///
/// Note that the subset size, as well as the subset of features to use in each training subset,
/// can be specified in the [parameters](crate::EnsembleLearnerParams).
///
/// ## Prediction Algorithm
///
/// The prediction result is obtained by majority voting across the fitted learners.
///
/// ## Example
///
/// This example shows how to train a bagging model using 100 decision trees,
/// each trained on 70% of the training data (bootstrap sampling).
/// ```no_run
/// use linfa::prelude::{Fit, Predict};
/// use linfa_ensemble::EnsembleLearnerParams;
/// use linfa_trees::DecisionTree;
/// use ndarray_rand::rand::SeedableRng;
/// use rand::rngs::SmallRng;
///
/// // Load Iris dataset
/// let mut rng = SmallRng::seed_from_u64(42);
/// let (train, test) = linfa_datasets::iris()
///     .shuffle(&mut rng)
///     .split_with_ratio(0.8);
///
/// // Train the model on the iris dataset
/// let bagging_model = EnsembleLearnerParams::new(DecisionTree::params())
///     .ensemble_size(100)        // Number of Decision Tree to fit
///     .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap
///     .fit(&train)
///     .unwrap();
///
/// // Make predictions on the test set
/// let predictions = bagging_model.predict(&test);
/// ```
///
/// ## References
///
/// * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html)
/// * [An Introduction to Statistical Learning](https://www.statlearning.com/)
pub struct EnsembleLearner<M> {
    /// The fitted member models, in the order they were trained.
    pub models: Vec<M>,
    /// For each entry of `models` (same position), the column indices of the records that
    /// model is given at prediction time.
    pub model_features: Vec<Vec<usize>>,
}
74
75impl<M> EnsembleLearner<M> {
76    // Generates prediction iterator returning predictions from each model
77    pub fn generate_predictions<'b, R: Records, T>(
78        &'b self,
79        x: &'b [R],
80    ) -> impl Iterator<Item = T> + 'b
81    where
82        M: Predict<&'b R, T>,
83    {
84        self.models
85            .iter()
86            .zip(x.iter())
87            .map(move |(m, sub_data)| m.predict(sub_data))
88    }
89}
90
91impl<F: Clone, T, M> PredictInplace<Array2<F>, T> for EnsembleLearner<M>
92where
93    M: PredictInplace<Array2<F>, T>,
94    <T as AsTargets>::Elem: Copy + Eq + Hash + std::fmt::Debug,
95    T: AsTargets + AsTargetsMut<Elem = <T as AsTargets>::Elem>,
96{
97    fn predict_inplace(&self, x: &Array2<F>, y: &mut T) {
98        let y_array = y.as_targets();
99        assert_eq!(
100            x.nrows(),
101            y_array.len_of(Axis(0)),
102            "The number of data points must match the number of outputs."
103        );
104
105        let sub_datas = self
106            .model_features
107            .iter()
108            .map(|feat| x.select(Axis(1), feat))
109            .collect::<Vec<_>>();
110        let predictions = self.generate_predictions(&sub_datas);
111
112        // prediction map has same shape as y_array, but the elements are maps
113        let mut prediction_maps = y_array.map(|_| HashMap::new());
114
115        for prediction in predictions {
116            let p_arr = prediction.as_targets();
117            assert_eq!(p_arr.shape(), y_array.shape());
118            // Insert each prediction value into the corresponding map
119            Zip::from(&mut prediction_maps)
120                .and(&p_arr)
121                .for_each(|map, val| *map.entry(*val).or_insert(0) += 1);
122        }
123
124        // For each prediction, pick the result with the highest number of votes
125        let agg_preds = prediction_maps.map(|map| map.iter().max_by_key(|(_, v)| **v).unwrap().0);
126        let mut y_array = y.as_targets_mut();
127        for (y, pred) in y_array.iter_mut().zip(agg_preds.iter()) {
128            *y = **pred
129        }
130    }
131
132    fn default_target(&self, x: &Array2<F>) -> T {
133        self.models[0].default_target(x)
134    }
135}
136
137impl<D, T, P: Fit<Array2<D>, T::Owned, Error>, R: Rng + Clone> Fit<Array2<D>, T, Error>
138    for EnsembleLearnerValidParams<P, R>
139where
140    D: Clone,
141    T: FromTargetArrayOwned,
142    T::Elem: Copy + Eq + Hash,
143    T::Owned: AsTargets,
144{
145    type Object = EnsembleLearner<P::Object>;
146
147    fn fit(
148        &self,
149        dataset: &DatasetBase<Array2<D>, T>,
150    ) -> core::result::Result<Self::Object, Error> {
151        let mut models = Vec::with_capacity(self.ensemble_size);
152        let mut model_features = Vec::with_capacity(self.ensemble_size);
153        let mut rng = self.rng.clone();
154
155        // Compute dataset and the subset of features ratio to be selected
156        let dataset_size =
157            ((dataset.records.nrows() as f64) * self.bootstrap_proportion).ceil() as usize;
158        let n_feat = dataset.records.ncols();
159        let n_sub = ((n_feat as f64) * self.feature_proportion).ceil() as usize;
160
161        let iter = dataset.bootstrap_with_indices((dataset_size, n_sub), &mut rng);
162        for (train, _, feature_selected) in iter {
163            let model = self.model_params.fit(&train).unwrap();
164            models.push(model);
165            model_features.push(feature_selected);
166
167            if models.len() == self.ensemble_size {
168                break;
169            }
170        }
171
172        Ok(EnsembleLearner {
173            models,
174            model_features,
175        })
176    }
177}