// linfa_trees/decision_trees/hyperparams.rs
use linfa::{
error::{Error, Result},
Float, Label, ParamGuard,
};
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use crate::DecisionTree;
/// Criterion used to pick the feature on which a node is split.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SplitQuality {
    /// Gini impurity: how likely a randomly chosen point of the subtree is to be
    /// misclassified, defined as one minus the sum over all labels of the squared
    /// probability of encountering that label.
    ///
    /// The Gini index of the root is the weighted sum of the indexes of its two
    /// subtrees. At each step the split is applied to the feature that reduces the
    /// Gini impurity of the root the most.
    Gini,
    /// Entropy of a subtree: the negated sum, over all labels, of the probability of
    /// encountering that label in the subtree times its base-two logarithm.
    ///
    /// The "information gain" of a split is the entropy of the root minus the
    /// weighted sum of the entropies of its two subtrees. At each step the split is
    /// applied to the feature with the largest information gain.
    Entropy,
}
/// The set of hyperparameters that can be specified for fitting a
/// [decision tree](DecisionTree).
///
/// This is the checked form of the hyperparameters: it is the `Checked` type that
/// [`DecisionTreeParams`] yields through its [`ParamGuard`] implementation.
///
/// ### Example
///
/// ```rust
/// use linfa_trees::{DecisionTree, SplitQuality};
/// use linfa_datasets::iris;
/// use linfa::prelude::*;
///
/// // Initialize the default set of parameters
/// let params = DecisionTree::params();
/// // Set the parameters to the desired values
/// let params = params.split_quality(SplitQuality::Entropy).max_depth(Some(5)).min_weight_leaf(2.);
///
/// // Load the data
/// let (train, val) = linfa_datasets::iris().split_with_ratio(0.9);
/// // Fit the decision tree on the training data
/// let tree = params.fit(&train).unwrap();
/// // Predict on validation and check accuracy
/// let val_accuracy = tree.predict(&val).confusion_matrix(&val).unwrap().accuracy();
/// assert!(val_accuracy > 0.9);
/// ```
///
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DecisionTreeValidParams<F, L> {
    /// Metric used to decide the feature on which to split a node.
    split_quality: SplitQuality,
    /// Optional limit on the depth of the tree; `None` means no limit is set.
    max_depth: Option<usize>,
    /// Minimum weight of samples required to split a node (the minimum number of
    /// samples when the observations carry no weights).
    min_weight_split: f32,
    /// Minimum weight of samples a split has to place in each leaf (the minimum
    /// number of samples when the observations carry no weights).
    min_weight_leaf: f32,
    /// Minimum decrease in impurity a split needs to bring in order to be applied.
    min_impurity_decrease: F,
    /// Zero-sized marker carrying the label type `L`; stores no data.
    label_marker: PhantomData<L>,
}
/// Read-only accessors for the checked hyperparameters.
impl<F: Float, L> DecisionTreeValidParams<F, L> {
    /// Returns the metric used to decide the feature on which to split a node.
    pub fn split_quality(&self) -> SplitQuality {
        self.split_quality
    }

    /// Returns the optional limit on the depth of the decision tree.
    pub fn max_depth(&self) -> Option<usize> {
        self.max_depth
    }

    /// Returns the minimum weight of samples required to split a node.
    pub fn min_weight_split(&self) -> f32 {
        self.min_weight_split
    }

    /// Returns the minimum weight of samples that a split has to place in each leaf.
    pub fn min_weight_leaf(&self) -> f32 {
        self.min_weight_leaf
    }

    /// Returns the minimum decrease in impurity a split needs to bring in order to
    /// be applied.
    pub fn min_impurity_decrease(&self) -> F {
        self.min_impurity_decrease
    }
}
/// Unchecked hyperparameters for fitting a [decision tree](DecisionTree).
///
/// A thin wrapper around [`DecisionTreeValidParams`] whose builder-style setters
/// accept any value; the wrapped parameters are handed out by the [`ParamGuard`]
/// implementation once they pass its check.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DecisionTreeParams<F, L>(DecisionTreeValidParams<F, L>);
impl<F: Float, L: Label> DecisionTreeParams<F, L> {
    /// Creates a parameter set with every hyperparameter at its default value:
    /// * `split_quality = SplitQuality::Gini`
    /// * `max_depth = None`
    /// * `min_weight_split = 2.0`
    /// * `min_weight_leaf = 1.0`
    /// * `min_impurity_decrease = 0.00001`
    pub fn new() -> Self {
        let defaults = DecisionTreeValidParams {
            split_quality: SplitQuality::Gini,
            max_depth: None,
            min_weight_split: 2.0,
            min_weight_leaf: 1.0,
            min_impurity_decrease: F::cast(0.00001),
            label_marker: PhantomData,
        };
        Self(defaults)
    }

    /// Sets the metric used to decide the feature on which to split a node
    pub fn split_quality(self, split_quality: SplitQuality) -> Self {
        Self(DecisionTreeValidParams {
            split_quality,
            ..self.0
        })
    }

    /// Sets the optional limit to the depth of the decision tree
    pub fn max_depth(self, max_depth: Option<usize>) -> Self {
        Self(DecisionTreeValidParams {
            max_depth,
            ..self.0
        })
    }

    /// Sets the minimum weight of samples required to split a node.
    ///
    /// If the observations do not have associated weights, this value represents
    /// the minimum number of samples required to split a node.
    pub fn min_weight_split(self, min_weight_split: f32) -> Self {
        Self(DecisionTreeValidParams {
            min_weight_split,
            ..self.0
        })
    }

    /// Sets the minimum weight of samples that a split has to place in each leaf
    ///
    /// If the observations do not have associated weights, this value represents
    /// the minimum number of samples that a split has to place in each leaf.
    pub fn min_weight_leaf(self, min_weight_leaf: f32) -> Self {
        Self(DecisionTreeValidParams {
            min_weight_leaf,
            ..self.0
        })
    }

    /// Sets the minimum decrease in impurity that a split needs to bring in order for it to be applied
    pub fn min_impurity_decrease(self, min_impurity_decrease: F) -> Self {
        Self(DecisionTreeValidParams {
            min_impurity_decrease,
            ..self.0
        })
    }
}
impl<F: Float, L: Label> Default for DecisionTreeParams<F, L> {
fn default() -> Self {
Self::new()
}
}
impl<F: Float, L: Label> DecisionTree<F, L> {
    /// Creates the default set of hyperparameters for fitting a decision tree.
    ///
    /// Defaults are provided if the optional parameters are not specified:
    /// * `split_quality = SplitQuality::Gini`
    /// * `max_depth = None`
    /// * `min_weight_split = 2.0`
    /// * `min_weight_leaf = 1.0`
    /// * `min_impurity_decrease = 0.00001`
    pub fn params() -> DecisionTreeParams<F, L> {
        DecisionTreeParams::new()
    }
}
impl<F: Float, L> ParamGuard for DecisionTreeParams<F, L> {
    type Checked = DecisionTreeValidParams<F, L>;
    type Error = Error;

    /// Validates the hyperparameters and, on success, borrows them in checked form.
    ///
    /// Only `min_impurity_decrease` is inspected; the other hyperparameters are
    /// passed through as they were set.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Parameters`] when `min_impurity_decrease` is smaller than
    /// `F::epsilon()` or is NaN.
    fn check_ref(&self) -> Result<&Self::Checked> {
        // Phrased as a positive `>=` test on purpose: every comparison with NaN is
        // false, so a NaN threshold lands in the error branch instead of slipping
        // through a `< epsilon` rejection test.
        if self.0.min_impurity_decrease >= F::epsilon() {
            Ok(&self.0)
        } else {
            Err(Error::Parameters(format!(
                "Minimum impurity decrease should be greater than zero, but was {}",
                self.0.min_impurity_decrease
            )))
        }
    }

    /// Validates the hyperparameters and, on success, returns them by value in
    /// checked form.
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as [`check_ref`](Self::check_ref).
    fn check(self) -> Result<Self::Checked> {
        self.check_ref()?;
        Ok(self.0)
    }
}