linfa_ensemble/
algorithm.rs1use crate::EnsembleLearnerValidParams;
2use linfa::{
3 dataset::{AsTargets, AsTargetsMut, FromTargetArrayOwned, Records},
4 error::Error,
5 traits::*,
6 DatasetBase,
7};
8use linfa_trees::DecisionTree;
9use ndarray::{Array2, Axis, Zip};
10use rand::Rng;
11use std::{cmp::Eq, collections::HashMap, hash::Hash};
12
13pub type RandomForest<F, L> = EnsembleLearner<DecisionTree<F, L>>;
17
/// A fitted ensemble of models of type `M`.
///
/// `models` and `model_features` are parallel vectors: the entry at position
/// `i` of `model_features` lists the column indices that are selected from the
/// input records before they are handed to the model at position `i`.
#[derive(Debug, Clone)]
pub struct EnsembleLearner<M> {
    /// The member models of the ensemble.
    pub models: Vec<M>,
    /// For each member model, the indices of the feature columns it receives.
    pub model_features: Vec<Vec<usize>>,
}
74
75impl<M> EnsembleLearner<M> {
76 pub fn generate_predictions<'b, R: Records, T>(
78 &'b self,
79 x: &'b [R],
80 ) -> impl Iterator<Item = T> + 'b
81 where
82 M: Predict<&'b R, T>,
83 {
84 self.models
85 .iter()
86 .zip(x.iter())
87 .map(move |(m, sub_data)| m.predict(sub_data))
88 }
89}
90
91impl<F: Clone, T, M> PredictInplace<Array2<F>, T> for EnsembleLearner<M>
92where
93 M: PredictInplace<Array2<F>, T>,
94 <T as AsTargets>::Elem: Copy + Eq + Hash + std::fmt::Debug,
95 T: AsTargets + AsTargetsMut<Elem = <T as AsTargets>::Elem>,
96{
97 fn predict_inplace(&self, x: &Array2<F>, y: &mut T) {
98 let y_array = y.as_targets();
99 assert_eq!(
100 x.nrows(),
101 y_array.len_of(Axis(0)),
102 "The number of data points must match the number of outputs."
103 );
104
105 let sub_datas = self
106 .model_features
107 .iter()
108 .map(|feat| x.select(Axis(1), feat))
109 .collect::<Vec<_>>();
110 let predictions = self.generate_predictions(&sub_datas);
111
112 let mut prediction_maps = y_array.map(|_| HashMap::new());
114
115 for prediction in predictions {
116 let p_arr = prediction.as_targets();
117 assert_eq!(p_arr.shape(), y_array.shape());
118 Zip::from(&mut prediction_maps)
120 .and(&p_arr)
121 .for_each(|map, val| *map.entry(*val).or_insert(0) += 1);
122 }
123
124 let agg_preds = prediction_maps.map(|map| map.iter().max_by_key(|(_, v)| **v).unwrap().0);
126 let mut y_array = y.as_targets_mut();
127 for (y, pred) in y_array.iter_mut().zip(agg_preds.iter()) {
128 *y = **pred
129 }
130 }
131
132 fn default_target(&self, x: &Array2<F>) -> T {
133 self.models[0].default_target(x)
134 }
135}
136
137impl<D, T, P: Fit<Array2<D>, T::Owned, Error>, R: Rng + Clone> Fit<Array2<D>, T, Error>
138 for EnsembleLearnerValidParams<P, R>
139where
140 D: Clone,
141 T: FromTargetArrayOwned,
142 T::Elem: Copy + Eq + Hash,
143 T::Owned: AsTargets,
144{
145 type Object = EnsembleLearner<P::Object>;
146
147 fn fit(
148 &self,
149 dataset: &DatasetBase<Array2<D>, T>,
150 ) -> core::result::Result<Self::Object, Error> {
151 let mut models = Vec::with_capacity(self.ensemble_size);
152 let mut model_features = Vec::with_capacity(self.ensemble_size);
153 let mut rng = self.rng.clone();
154
155 let dataset_size =
157 ((dataset.records.nrows() as f64) * self.bootstrap_proportion).ceil() as usize;
158 let n_feat = dataset.records.ncols();
159 let n_sub = ((n_feat as f64) * self.feature_proportion).ceil() as usize;
160
161 let iter = dataset.bootstrap_with_indices((dataset_size, n_sub), &mut rng);
162 for (train, _, feature_selected) in iter {
163 let model = self.model_params.fit(&train).unwrap();
164 models.push(model);
165 model_features.push(feature_selected);
166
167 if models.len() == self.ensemble_size {
168 break;
169 }
170 }
171
172 Ok(EnsembleLearner {
173 models,
174 model_features,
175 })
176 }
177}