linfa_ensemble/
lib.rs

1//! # Ensemble Learning Algorithms
2//!
3//! Ensemble methods combine the predictions of several base estimators built with a given
4//! learning algorithm in order to improve generalizability / robustness over a single estimator.
5//!
6//! This crate (`linfa-ensemble`) provides pure Rust implementations of popular ensemble techniques, such as:
7//! * [Bootstrap Aggregation](EnsembleLearner)
8//! * [Random Forest](RandomForest)
9//!
10//! ## Bootstrap Aggregation (aka Bagging)
11//!
12//! A typical example of an ensemble method is Bootstrap Aggregation, which combines the predictions of
13//! several decision trees (see [`linfa-trees`](linfa_trees)) trained on different subsets of samples drawn from the training dataset.
14//!
15//! ## Random Forest
16//!
17//! A special case of Bootstrap Aggregation using decision trees (see [`linfa-trees`](linfa_trees)) with random feature
18//! selection. A typical number of randomly selected features is $\sqrt{p}$, with $p$ being
19//! the number of available features.
20//!
21//! ## Reference
22//!
23//! * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html)
24//! * [An Introduction to Statistical Learning](https://www.statlearning.com/)
25//!
26//! ## Example
27//!
28//! This example shows how to train a bagging model using 100 decision trees,
29//! each trained on 70% of the training data (bootstrap sampling).
30//!
31//! ```no_run
32//! use linfa::prelude::{Fit, Predict};
33//! use linfa_ensemble::EnsembleLearnerParams;
34//! use linfa_trees::DecisionTree;
35//! use ndarray_rand::rand::SeedableRng;
36//! use rand::rngs::SmallRng;
37//!
38//! // Load Iris dataset
39//! let mut rng = SmallRng::seed_from_u64(42);
40//! let (train, test) = linfa_datasets::iris()
41//!     .shuffle(&mut rng)
42//!     .split_with_ratio(0.8);
43//!
44//! // Train the model on the iris dataset
45//! let bagging_model = EnsembleLearnerParams::new(DecisionTree::params())
46//!     .ensemble_size(100)        // Number of Decision Trees to fit
47//!     .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap
48//!     .fit(&train)
49//!     .unwrap();
50//!
51//! // Make predictions on the test set
52//! let predictions = bagging_model.predict(&test);
53//! ```
54//!
55//! This example shows how to train a [Random Forest](RandomForest) model using 100 decision trees,
56//! each trained on 70% of the training data (bootstrap sampling) and using only
57//! 30% of the available features.
58//!
59//! ```no_run
60//! use linfa::prelude::{Fit, Predict};
61//! use linfa_ensemble::RandomForestParams;
62//! use linfa_trees::DecisionTree;
63//! use ndarray_rand::rand::SeedableRng;
64//! use rand::rngs::SmallRng;
65//!
66//! // Load Iris dataset
67//! let mut rng = SmallRng::seed_from_u64(42);
68//! let (train, test) = linfa_datasets::iris()
69//!     .shuffle(&mut rng)
70//!     .split_with_ratio(0.8);
71//!
72//! // Train the model on the iris dataset
73//! let random_forest = RandomForestParams::new(DecisionTree::params())
74//!     .ensemble_size(100)        // Number of Decision Trees to fit
75//!     .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap
76//!     .feature_proportion(0.3)   // Select only 30% of the features
77//!     .fit(&train)
78//!     .unwrap();
79//!
80//! // Make predictions on the test set
81//! let predictions = random_forest.predict(&test);
82//! ```
83
84mod algorithm;
85mod hyperparams;
86
87pub use algorithm::*;
88pub use hyperparams::*;
89
#[cfg(test)]
mod tests {
    use super::*;
    use linfa::prelude::{Fit, Predict, ToConfusionMatrix};
    use linfa_trees::DecisionTree;
    use ndarray_rand::rand::SeedableRng;
    use rand::rngs::SmallRng;

    /// Fits a random forest (100 estimators, bootstrap proportion 0.7, feature
    /// proportion 0.3) on a shuffled iris split made with ratio 0.8, and checks
    /// that the accuracy on the held-out part is at least 0.9.
    #[test]
    fn test_random_forest_accuracy_on_iris_dataset() {
        // The RNG seeded with 42 first drives the shuffle, then is handed over
        // to the model parameters.
        let mut seeded_rng = SmallRng::seed_from_u64(42);
        let shuffled = linfa_datasets::iris().shuffle(&mut seeded_rng);
        let (training_set, holdout_set) = shuffled.split_with_ratio(0.8);

        let params = RandomForestParams::new_fixed_rng(DecisionTree::params(), seeded_rng)
            .ensemble_size(100)
            .bootstrap_proportion(0.7)
            .feature_proportion(0.3);
        let forest = params.fit(&training_set).unwrap();

        let accuracy = forest
            .predict(&holdout_set)
            .confusion_matrix(&holdout_set)
            .unwrap()
            .accuracy();
        assert!(accuracy >= 0.9, "Expected accuracy to be above 90%, got {}", accuracy);
    }

    /// Fits a bagging ensemble (100 estimators, bootstrap proportion 0.7) on a
    /// shuffled iris split made with ratio 0.8, and checks that the accuracy on
    /// the held-out part is at least 0.9.
    #[test]
    fn test_ensemble_learner_accuracy_on_iris_dataset() {
        // The RNG seeded with 42 first drives the shuffle, then is handed over
        // to the model parameters.
        let mut seeded_rng = SmallRng::seed_from_u64(42);
        let shuffled = linfa_datasets::iris().shuffle(&mut seeded_rng);
        let (training_set, holdout_set) = shuffled.split_with_ratio(0.8);

        let params = EnsembleLearnerParams::new_fixed_rng(DecisionTree::params(), seeded_rng)
            .ensemble_size(100)
            .bootstrap_proportion(0.7);
        let bagging = params.fit(&training_set).unwrap();

        let accuracy = bagging
            .predict(&holdout_set)
            .confusion_matrix(&holdout_set)
            .unwrap()
            .accuracy();
        assert!(accuracy >= 0.9, "Expected accuracy to be above 90%, got {}", accuracy);
    }
}