linfa_ensemble/
lib.rs

1//! # Ensemble Learning Algorithms
2//!
3//! Ensemble methods combine the predictions of several base estimators built with a given
4//! learning algorithm in order to improve generalizability / robustness over a single estimator.
5//!
6//! ## Bootstrap Aggregation (aka Bagging)
7//!
8//! A typical example of an ensemble method is Bootstrap Aggregation, which combines the predictions of
9//! several decision trees (see `linfa-trees`) trained on different sample subsets of the training dataset.
10//!
11//! ## Reference
12//!
13//! * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html)
14//!
15//! ## Example
16//!
17//! This example shows how to train a bagging model using 100 decision trees,
18//! each trained on 70% of the training data (bootstrap sampling).
19//!
20//! ```no_run
21//! use linfa::prelude::{Fit, Predict};
22//! use linfa_ensemble::EnsembleLearnerParams;
23//! use linfa_trees::DecisionTree;
24//! use ndarray_rand::rand::SeedableRng;
25//! use rand::rngs::SmallRng;
26//!
27//! // Load Iris dataset
28//! let mut rng = SmallRng::seed_from_u64(42);
29//! let (train, test) = linfa_datasets::iris()
30//!     .shuffle(&mut rng)
31//!     .split_with_ratio(0.8);
32//!
33//! // Train the model on the iris dataset
34//! let bagging_model = EnsembleLearnerParams::new(DecisionTree::params())
35//!     .ensemble_size(100)
36//!     .bootstrap_proportion(0.7)
37//!     .fit(&train)
38//!     .unwrap();
39//!
40//! // Make predictions on the test set
41//! let predictions = bagging_model.predict(&test);
42//! ```
43//!
44mod algorithm;
45mod hyperparams;
46
47pub use algorithm::*;
48pub use hyperparams::*;
49
#[cfg(test)]
mod tests {
    use super::*;
    use linfa::prelude::{Fit, Predict, ToConfusionMatrix};
    use linfa_trees::DecisionTree;
    use ndarray_rand::rand::SeedableRng;
    use rand::rngs::SmallRng;

    /// Fits an ensemble of 100 decision trees on an 80% split of the iris
    /// dataset and checks that the accuracy measured on the remaining
    /// samples is at least 0.9.
    #[test]
    fn test_ensemble_learner_accuracy_on_iris_dataset() {
        // The shuffle is driven by an RNG seeded with a fixed value.
        let mut seeded_rng = SmallRng::seed_from_u64(42);
        let shuffled = linfa_datasets::iris().shuffle(&mut seeded_rng);
        let (train_set, holdout) = shuffled.split_with_ratio(0.8);

        // Configure the ensemble first, then fit it as a separate step.
        let params = EnsembleLearnerParams::new(DecisionTree::params())
            .ensemble_size(100)
            .bootstrap_proportion(0.7);
        let ensemble = params.fit(&train_set).unwrap();

        // Score the held-out predictions through a confusion matrix.
        let predicted = ensemble.predict(&holdout);
        let accuracy = predicted.confusion_matrix(&holdout).unwrap().accuracy();

        assert!(
            accuracy >= 0.9,
            "Expected accuracy to be above 90%, got {}",
            accuracy
        );
    }
}