linfa_ensemble/lib.rs
1//! # Ensemble Learning Algorithms
2//!
//! Ensemble methods combine the predictions of several base estimators built with a given
//! learning algorithm to improve generalizability and robustness over a single estimator.
5//!
6//! ## Bootstrap Aggregation (aka Bagging)
7//!
//! A typical example of an ensemble method is Bootstrap AGgregation (bagging), which combines the
//! predictions of several decision trees (see `linfa-trees`), each trained on a different random
//! sample of the training dataset.
10//!
11//! ## Reference
12//!
13//! * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html)
14//!
15//! ## Example
16//!
17//! This example shows how to train a bagging model using 100 decision trees,
18//! each trained on 70% of the training data (bootstrap sampling).
19//!
20//! ```no_run
21//! use linfa::prelude::{Fit, Predict};
22//! use linfa_ensemble::EnsembleLearnerParams;
23//! use linfa_trees::DecisionTree;
24//! use ndarray_rand::rand::SeedableRng;
25//! use rand::rngs::SmallRng;
26//!
27//! // Load Iris dataset
28//! let mut rng = SmallRng::seed_from_u64(42);
29//! let (train, test) = linfa_datasets::iris()
30//! .shuffle(&mut rng)
31//! .split_with_ratio(0.8);
32//!
33//! // Train the model on the iris dataset
34//! let bagging_model = EnsembleLearnerParams::new(DecisionTree::params())
35//! .ensemble_size(100)
36//! .bootstrap_proportion(0.7)
37//! .fit(&train)
38//! .unwrap();
39//!
40//! // Make predictions on the test set
41//! let predictions = bagging_model.predict(&test);
42//! ```
43//!
44mod algorithm;
45mod hyperparams;
46
47pub use algorithm::*;
48pub use hyperparams::*;
49
#[cfg(test)]
mod tests {
    use super::*;
    use linfa::prelude::{Fit, Predict, ToConfusionMatrix};
    use linfa_trees::DecisionTree;
    use ndarray_rand::rand::SeedableRng;
    use rand::rngs::SmallRng;

    /// Trains a bagging ensemble of 100 decision trees on a seeded 80/20 split of the Iris
    /// dataset and checks that accuracy on the held-out 20% is at least 90%.
    #[test]
    fn test_ensemble_learner_accuracy_on_iris_dataset() {
        // Fixed seed so the shuffle, and therefore the train/test split, is reproducible.
        let mut rng = SmallRng::seed_from_u64(42);
        let (train, test) = linfa_datasets::iris()
            .shuffle(&mut rng)
            .split_with_ratio(0.8);

        // 100 base trees, each fitted with `bootstrap_proportion(0.7)`.
        let model = EnsembleLearnerParams::new(DecisionTree::params())
            .ensemble_size(100)
            .bootstrap_proportion(0.7)
            .fit(&train)
            .unwrap();

        let predictions = model.predict(&test);

        let cm = predictions.confusion_matrix(&test).unwrap();
        let acc = cm.accuracy();
        // The check is inclusive (`>=`), so the failure message says "at least", not "above".
        assert!(
            acc >= 0.9,
            "Expected accuracy to be at least 90%, got {}",
            acc
        );
    }
}