// linfa_trees/decision_trees/hyperparams.rs

1use linfa::{
2    error::{Error, Result},
3    Float, Label, ParamGuard,
4};
5use std::marker::PhantomData;
6
7#[cfg(feature = "serde")]
8use serde_crate::{Deserialize, Serialize};
9
10use crate::DecisionTree;
11
/// Criterion used to score candidate splits when growing a tree node.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SplitQuality {
    /// Gini impurity: the probability that a randomly drawn point in the subtree would be
    /// misclassified, i.e. one minus the sum over all labels of the squared probability of
    /// drawing that label. A node's impurity is the weighted sum of the impurities of its two
    /// subtrees; at each step the chosen split is the one that reduces this value the most.
    Gini,
    /// Entropy of a subtree: minus the sum, over all labels, of the probability of encountering
    /// that label times its base-two logarithm. The parent's entropy minus the weighted entropy
    /// of its two subtrees is the "information gain" of a split; at each step the split with the
    /// largest information gain is applied.
    Entropy,
}
31
32/// The set of hyperparameters that can be specified for fitting a
33/// [decision tree](DecisionTree).
34///
35/// ### Example
36///
37/// ```rust
38/// use linfa_trees::{DecisionTree, SplitQuality};
39/// use linfa_datasets::iris;
40/// use linfa::prelude::*;
41///
42/// // Initialize the default set of parameters
43/// let params = DecisionTree::params();
44/// // Set the parameters to the desired values
45/// let params = params.split_quality(SplitQuality::Entropy).max_depth(Some(5)).min_weight_leaf(2.);
46///
47/// // Load the data
48/// let (train, val) = linfa_datasets::iris().split_with_ratio(0.9);
49/// // Fit the decision tree on the training data
50/// let tree = params.fit(&train).unwrap();
51/// // Predict on validation and check accuracy
52/// let val_accuracy = tree.predict(&val).confusion_matrix(&val).unwrap().accuracy();
53/// assert!(val_accuracy > 0.9);
54/// ```
55///
56#[cfg_attr(
57    feature = "serde",
58    derive(Serialize, Deserialize),
59    serde(crate = "serde_crate")
60)]
61#[derive(Clone, Copy, Debug, PartialEq)]
62pub struct DecisionTreeValidParams<F, L> {
63    split_quality: SplitQuality,
64    max_depth: Option<usize>,
65    min_weight_split: f32,
66    min_weight_leaf: f32,
67    min_impurity_decrease: F,
68
69    label_marker: PhantomData<L>,
70}
71
72impl<F: Float, L> DecisionTreeValidParams<F, L> {
73    pub fn split_quality(&self) -> SplitQuality {
74        self.split_quality
75    }
76
77    pub fn max_depth(&self) -> Option<usize> {
78        self.max_depth
79    }
80
81    pub fn min_weight_split(&self) -> f32 {
82        self.min_weight_split
83    }
84
85    pub fn min_weight_leaf(&self) -> f32 {
86        self.min_weight_leaf
87    }
88
89    pub fn min_impurity_decrease(&self) -> F {
90        self.min_impurity_decrease
91    }
92}
93
94#[cfg_attr(
95    feature = "serde",
96    derive(Serialize, Deserialize),
97    serde(crate = "serde_crate")
98)]
99#[derive(Clone, Copy, Debug, PartialEq)]
100pub struct DecisionTreeParams<F, L>(DecisionTreeValidParams<F, L>);
101
102impl<F: Float, L: Label> DecisionTreeParams<F, L> {
103    pub fn new() -> Self {
104        Self(DecisionTreeValidParams {
105            split_quality: SplitQuality::Gini,
106            max_depth: None,
107            min_weight_split: 2.0,
108            min_weight_leaf: 1.0,
109            min_impurity_decrease: F::cast(0.00001),
110            label_marker: PhantomData,
111        })
112    }
113
114    /// Sets the metric used to decide the feature on which to split a node
115    pub fn split_quality(mut self, split_quality: SplitQuality) -> Self {
116        self.0.split_quality = split_quality;
117        self
118    }
119
120    /// Sets the optional limit to the depth of the decision tree
121    pub fn max_depth(mut self, max_depth: Option<usize>) -> Self {
122        self.0.max_depth = max_depth;
123        self
124    }
125
126    /// Sets the minimum weight of samples required to split a node.
127    ///
128    /// If the observations do not have associated weights, this value represents
129    /// the minimum number of samples required to split a node.
130    pub fn min_weight_split(mut self, min_weight_split: f32) -> Self {
131        self.0.min_weight_split = min_weight_split;
132        self
133    }
134
135    /// Sets the minimum weight of samples that a split has to place in each leaf
136    ///
137    /// If the observations do not have associated weights, this value represents
138    /// the minimum number of samples that a split has to place in each leaf.
139    pub fn min_weight_leaf(mut self, min_weight_leaf: f32) -> Self {
140        self.0.min_weight_leaf = min_weight_leaf;
141        self
142    }
143
144    /// Sets the minimum decrease in impurity that a split needs to bring in order for it to be applied
145    pub fn min_impurity_decrease(mut self, min_impurity_decrease: F) -> Self {
146        self.0.min_impurity_decrease = min_impurity_decrease;
147        self
148    }
149}
150
151impl<F: Float, L: Label> Default for DecisionTreeParams<F, L> {
152    fn default() -> Self {
153        Self::new()
154    }
155}
156
157impl<F: Float, L: Label> DecisionTree<F, L> {
158    /// Defaults are provided if the optional parameters are not specified:
159    /// * `split_quality = SplitQuality::Gini`
160    /// * `max_depth = None`
161    /// * `min_weight_split = 2.0`
162    /// * `min_weight_leaf = 1.0`
163    /// * `min_impurity_decrease = 0.00001`
164    // Violates the convention that new should return a value of type `Self`
165    #[allow(clippy::new_ret_no_self)]
166    pub fn params() -> DecisionTreeParams<F, L> {
167        DecisionTreeParams::new()
168    }
169}
170
171impl<F: Float, L> ParamGuard for DecisionTreeParams<F, L> {
172    type Checked = DecisionTreeValidParams<F, L>;
173    type Error = Error;
174
175    fn check_ref(&self) -> Result<&Self::Checked> {
176        if self.0.min_impurity_decrease < F::epsilon() {
177            Err(Error::Parameters(format!(
178                "Minimum impurity decrease should be greater than zero, but was {}",
179                self.0.min_impurity_decrease
180            )))
181        } else {
182            Ok(&self.0)
183        }
184    }
185
186    fn check(self) -> Result<Self::Checked> {
187        self.check_ref()?;
188        Ok(self.0)
189    }
190}