linfa_ensemble/lib.rs
1//! # Ensemble Learning Algorithms
2//!
3//! Ensemble methods combine the predictions of several base estimators built with a given
4//! learning algorithm in order to improve generalizability / robustness over a single estimator.
5//!
//! This crate (`linfa-ensemble`) provides pure Rust implementations of popular ensemble techniques, such as:
//! * [Bootstrap Aggregation](EnsembleLearner)
8//! * [Random Forest](RandomForest)
9//!
10//! ## Bootstrap Aggregation (aka Bagging)
11//!
//! A typical example of an ensemble method is Bootstrap Aggregation, which combines the predictions of
//! several decision trees (see [`linfa-trees`](linfa_trees)) trained on different subsets of the training dataset.
14//!
15//! ## Random Forest
16//!
//! A special case of Bootstrap Aggregation using decision trees (see [`linfa-trees`](linfa_trees)) with random feature
//! selection. A typical number of features to select at random is $\sqrt{p}$, with $p$ being
//! the number of available features.
20//!
//! ## References
22//!
23//! * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html)
24//! * [An Introduction to Statistical Learning](https://www.statlearning.com/)
25//!
26//! ## Example
27//!
28//! This example shows how to train a bagging model using 100 decision trees,
29//! each trained on 70% of the training data (bootstrap sampling).
30//!
31//! ```no_run
32//! use linfa::prelude::{Fit, Predict};
33//! use linfa_ensemble::EnsembleLearnerParams;
34//! use linfa_trees::DecisionTree;
35//! use ndarray_rand::rand::SeedableRng;
36//! use rand::rngs::SmallRng;
37//!
38//! // Load Iris dataset
39//! let mut rng = SmallRng::seed_from_u64(42);
40//! let (train, test) = linfa_datasets::iris()
41//! .shuffle(&mut rng)
42//! .split_with_ratio(0.8);
43//!
44//! // Train the model on the iris dataset
45//! let bagging_model = EnsembleLearnerParams::new(DecisionTree::params())
//!     .ensemble_size(100) // Number of decision trees to fit
47//! .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap
48//! .fit(&train)
49//! .unwrap();
50//!
51//! // Make predictions on the test set
52//! let predictions = bagging_model.predict(&test);
53//! ```
54//!
55//! This example shows how to train a [Random Forest](RandomForest) model using 100 decision trees,
56//! each trained on 70% of the training data (bootstrap sampling) and using only
57//! 30% of the available features.
58//!
59//! ```no_run
60//! use linfa::prelude::{Fit, Predict};
61//! use linfa_ensemble::RandomForestParams;
62//! use linfa_trees::DecisionTree;
63//! use ndarray_rand::rand::SeedableRng;
64//! use rand::rngs::SmallRng;
65//!
66//! // Load Iris dataset
67//! let mut rng = SmallRng::seed_from_u64(42);
68//! let (train, test) = linfa_datasets::iris()
69//! .shuffle(&mut rng)
70//! .split_with_ratio(0.8);
71//!
72//! // Train the model on the iris dataset
73//! let random_forest = RandomForestParams::new(DecisionTree::params())
//!     .ensemble_size(100) // Number of decision trees to fit
75//! .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap
//!     .feature_proportion(0.3) // Select only 30% of the features
77//! .fit(&train)
78//! .unwrap();
79//!
80//! // Make predictions on the test set
81//! let predictions = random_forest.predict(&test);
82//! ```
83
84mod algorithm;
85mod hyperparams;
86
87pub use algorithm::*;
88pub use hyperparams::*;
89
#[cfg(test)]
mod tests {
    use super::*;
    use linfa::prelude::{Fit, Predict, ToConfusionMatrix};
    use linfa_trees::DecisionTree;
    use ndarray_rand::rand::SeedableRng;
    use rand::rngs::SmallRng;

    /// Fits a 100-tree random forest (70% bootstrap proportion, 30% feature
    /// proportion) on 80% of the seeded, shuffled iris data and requires at
    /// least 90% accuracy on the held-out 20%.
    #[test]
    fn test_random_forest_accuracy_on_iris_dataset() {
        let mut seeded = SmallRng::seed_from_u64(42);
        let (train, test) = linfa_datasets::iris()
            .shuffle(&mut seeded)
            .split_with_ratio(0.8);

        // The rng that shuffled the data is moved into the learner via
        // `new_fixed_rng` (presumably so the ensemble's random sampling is
        // reproducible for a given seed).
        let params = RandomForestParams::new_fixed_rng(DecisionTree::params(), seeded)
            .ensemble_size(100)
            .bootstrap_proportion(0.7)
            .feature_proportion(0.3);
        let forest = params.fit(&train).unwrap();

        let accuracy = forest
            .predict(&test)
            .confusion_matrix(&test)
            .unwrap()
            .accuracy();
        assert!(
            accuracy >= 0.9,
            "Expected accuracy to be above 90%, got {}",
            accuracy
        );
    }

    /// Fits a plain bagging ensemble of 100 decision trees (70% bootstrap
    /// proportion) on 80% of the seeded, shuffled iris data and requires at
    /// least 90% accuracy on the held-out 20%.
    #[test]
    fn test_ensemble_learner_accuracy_on_iris_dataset() {
        let mut seeded = SmallRng::seed_from_u64(42);
        let (train, test) = linfa_datasets::iris()
            .shuffle(&mut seeded)
            .split_with_ratio(0.8);

        // Same rng hand-off as the random-forest test above.
        let params = EnsembleLearnerParams::new_fixed_rng(DecisionTree::params(), seeded)
            .ensemble_size(100)
            .bootstrap_proportion(0.7);
        let ensemble = params.fit(&train).unwrap();

        let accuracy = ensemble
            .predict(&test)
            .confusion_matrix(&test)
            .unwrap()
            .accuracy();
        assert!(
            accuracy >= 0.9,
            "Expected accuracy to be above 90%, got {}",
            accuracy
        );
    }
}