// linfa_ensemble/algorithm.rs
use std::{
    cmp::{Eq, Reverse},
    collections::HashMap,
    hash::Hash,
};

use linfa::{
    dataset::{AsTargets, AsTargetsMut, FromTargetArrayOwned, Records},
    error::Error,
    traits::*,
    DatasetBase,
};
use ndarray::{Array2, Axis, Zip};
use rand::Rng;

use crate::EnsembleLearnerValidParams;

/// A fitted ensemble: a collection of member models of the same type whose individual
/// predictions are combined into a single prediction.
///
/// When produced by this module's `Fit` implementation, each member was trained on its
/// own bootstrap resample of the training data, and predictions are aggregated by vote.
pub struct EnsembleLearner<M> {
    /// The member models of the ensemble.
    pub models: Vec<M>,
}

16impl<M> EnsembleLearner<M> {
17 pub fn generate_predictions<'b, R: Records, T>(
19 &'b self,
20 x: &'b R,
21 ) -> impl Iterator<Item = T> + 'b
22 where
23 M: Predict<&'b R, T>,
24 {
25 self.models.iter().map(move |m| m.predict(x))
26 }
27}
28
29impl<F: Clone, T, M> PredictInplace<Array2<F>, T> for EnsembleLearner<M>
30where
31 M: PredictInplace<Array2<F>, T>,
32 <T as AsTargets>::Elem: Copy + Eq + Hash + std::fmt::Debug,
33 T: AsTargets + AsTargetsMut<Elem = <T as AsTargets>::Elem>,
34{
35 fn predict_inplace(&self, x: &Array2<F>, y: &mut T) {
36 let y_array = y.as_targets();
37 assert_eq!(
38 x.nrows(),
39 y_array.len_of(Axis(0)),
40 "The number of data points must match the number of outputs."
41 );
42
43 let predictions = self.generate_predictions(x);
44
45 let mut prediction_maps = y_array.map(|_| HashMap::new());
47
48 for prediction in predictions {
49 let p_arr = prediction.as_targets();
50 assert_eq!(p_arr.shape(), y_array.shape());
51 Zip::from(&mut prediction_maps)
53 .and(&p_arr)
54 .for_each(|map, val| *map.entry(*val).or_insert(0) += 1);
55 }
56
57 let agg_preds = prediction_maps.map(|map| map.iter().max_by_key(|(_, v)| **v).unwrap().0);
59 let mut y_array = y.as_targets_mut();
60 for (y, pred) in y_array.iter_mut().zip(agg_preds.iter()) {
61 *y = **pred
62 }
63 }
64
65 fn default_target(&self, x: &Array2<F>) -> T {
66 self.models[0].default_target(x)
67 }
68}
69
70impl<D, T, P: Fit<Array2<D>, T::Owned, Error>, R: Rng + Clone> Fit<Array2<D>, T, Error>
71 for EnsembleLearnerValidParams<P, R>
72where
73 D: Clone,
74 T: FromTargetArrayOwned,
75 T::Elem: Copy + Eq + Hash,
76 T::Owned: AsTargets,
77{
78 type Object = EnsembleLearner<P::Object>;
79
80 fn fit(
81 &self,
82 dataset: &DatasetBase<Array2<D>, T>,
83 ) -> core::result::Result<Self::Object, Error> {
84 let mut models = Vec::new();
85 let mut rng = self.rng.clone();
86
87 let dataset_size =
88 ((dataset.records.nrows() as f64) * self.bootstrap_proportion).ceil() as usize;
89
90 let iter = dataset.bootstrap_samples(dataset_size, &mut rng);
91
92 for train in iter {
93 let model = self.model_params.fit(&train).unwrap();
94 models.push(model);
95
96 if models.len() == self.ensemble_size {
97 break;
98 }
99 }
100
101 Ok(EnsembleLearner { models })
102 }
103}