linfa_ensemble/
algorithm.rs

1use crate::EnsembleLearnerValidParams;
2use linfa::{
3    dataset::{AsTargets, AsTargetsMut, FromTargetArrayOwned, Records},
4    error::Error,
5    traits::*,
6    DatasetBase,
7};
8use ndarray::{Array2, Axis, Zip};
9use rand::Rng;
10use std::{cmp::Eq, collections::HashMap, hash::Hash};
11
/// An ensemble of fitted models of type `M`.
///
/// `Debug` and `Clone` are derived (each conditional on `M` implementing the
/// same trait) so that a fitted ensemble can be inspected and duplicated like
/// any other model value.
#[derive(Debug, Clone)]
pub struct EnsembleLearner<M> {
    /// The fitted member models, in the order they were trained.
    pub models: Vec<M>,
}
15
16impl<M> EnsembleLearner<M> {
17    // Generates prediction iterator returning predictions from each model
18    pub fn generate_predictions<'b, R: Records, T>(
19        &'b self,
20        x: &'b R,
21    ) -> impl Iterator<Item = T> + 'b
22    where
23        M: Predict<&'b R, T>,
24    {
25        self.models.iter().map(move |m| m.predict(x))
26    }
27}
28
29impl<F: Clone, T, M> PredictInplace<Array2<F>, T> for EnsembleLearner<M>
30where
31    M: PredictInplace<Array2<F>, T>,
32    <T as AsTargets>::Elem: Copy + Eq + Hash + std::fmt::Debug,
33    T: AsTargets + AsTargetsMut<Elem = <T as AsTargets>::Elem>,
34{
35    fn predict_inplace(&self, x: &Array2<F>, y: &mut T) {
36        let y_array = y.as_targets();
37        assert_eq!(
38            x.nrows(),
39            y_array.len_of(Axis(0)),
40            "The number of data points must match the number of outputs."
41        );
42
43        let predictions = self.generate_predictions(x);
44
45        // prediction map has same shape as y_array, but the elements are maps
46        let mut prediction_maps = y_array.map(|_| HashMap::new());
47
48        for prediction in predictions {
49            let p_arr = prediction.as_targets();
50            assert_eq!(p_arr.shape(), y_array.shape());
51            // Insert each prediction value into the corresponding map
52            Zip::from(&mut prediction_maps)
53                .and(&p_arr)
54                .for_each(|map, val| *map.entry(*val).or_insert(0) += 1);
55        }
56
57        // For each prediction, pick the result with the highest number of votes
58        let agg_preds = prediction_maps.map(|map| map.iter().max_by_key(|(_, v)| **v).unwrap().0);
59        let mut y_array = y.as_targets_mut();
60        for (y, pred) in y_array.iter_mut().zip(agg_preds.iter()) {
61            *y = **pred
62        }
63    }
64
65    fn default_target(&self, x: &Array2<F>) -> T {
66        self.models[0].default_target(x)
67    }
68}
69
70impl<D, T, P: Fit<Array2<D>, T::Owned, Error>, R: Rng + Clone> Fit<Array2<D>, T, Error>
71    for EnsembleLearnerValidParams<P, R>
72where
73    D: Clone,
74    T: FromTargetArrayOwned,
75    T::Elem: Copy + Eq + Hash,
76    T::Owned: AsTargets,
77{
78    type Object = EnsembleLearner<P::Object>;
79
80    fn fit(
81        &self,
82        dataset: &DatasetBase<Array2<D>, T>,
83    ) -> core::result::Result<Self::Object, Error> {
84        let mut models = Vec::new();
85        let mut rng = self.rng.clone();
86
87        let dataset_size =
88            ((dataset.records.nrows() as f64) * self.bootstrap_proportion).ceil() as usize;
89
90        let iter = dataset.bootstrap_samples(dataset_size, &mut rng);
91
92        for train in iter {
93            let model = self.model_params.fit(&train).unwrap();
94            models.push(model);
95
96            if models.len() == self.ensemble_size {
97                break;
98            }
99        }
100
101        Ok(EnsembleLearner { models })
102    }
103}