linfa_trees/decision_trees/hyperparams.rs
1use linfa::{
2 error::{Error, Result},
3 Float, Label, ParamGuard,
4};
5use std::marker::PhantomData;
6
7#[cfg(feature = "serde")]
8use serde_crate::{Deserialize, Serialize};
9
10use crate::DecisionTree;
11
/// The metric used to determine the feature by which a node is split
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SplitQuality {
    /// Gini impurity: the probability of a randomly chosen point in the subtree
    /// being misclassified, defined as one minus the sum over all labels of the
    /// squared probability of encountering that label. The Gini index of the
    /// root is the weighted sum of the indexes of its two subtrees. At each
    /// step, the split chosen is the one that most decreases the Gini impurity
    /// of the root.
    Gini,
    /// Entropy of a subtree: the sum over all labels of the probability of
    /// encountering that label in the subtree times its base-two logarithm,
    /// with negative sign. The entropy of the root minus the weighted sum of
    /// the entropy of its two subtrees is the "information gain" of the split.
    /// At each step, the split is applied to the feature with the biggest
    /// information gain.
    Entropy,
}
31
/// The set of hyperparameters that can be specified for fitting a
/// [decision tree](DecisionTree).
///
/// ### Example
///
/// ```rust
/// use linfa_trees::{DecisionTree, SplitQuality};
/// use linfa_datasets::iris;
/// use linfa::prelude::*;
///
/// // Initialize the default set of parameters
/// let params = DecisionTree::params();
/// // Set the parameters to the desired values
/// let params = params.split_quality(SplitQuality::Entropy).max_depth(Some(5)).min_weight_leaf(2.);
///
/// // Load the data
/// let (train, val) = linfa_datasets::iris().split_with_ratio(0.9);
/// // Fit the decision tree on the training data
/// let tree = params.fit(&train).unwrap();
/// // Predict on validation and check accuracy
/// let val_accuracy = tree.predict(&val).confusion_matrix(&val).unwrap().accuracy();
/// assert!(val_accuracy > 0.9);
/// ```
///
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DecisionTreeValidParams<F, L> {
    /// Metric used to rank candidate splits (Gini impurity or entropy).
    split_quality: SplitQuality,
    /// Optional cap on tree depth; `None` means unbounded.
    max_depth: Option<usize>,
    /// Minimum total sample weight a node must hold to be considered for splitting.
    min_weight_split: f32,
    /// Minimum total sample weight each leaf produced by a split must receive.
    min_weight_leaf: f32,
    /// Minimum impurity decrease a split must achieve to be applied.
    min_impurity_decrease: F,

    // Ties the label type `L` to the parameter set without storing a label.
    label_marker: PhantomData<L>,
}
71
impl<F: Float, L> DecisionTreeValidParams<F, L> {
    /// Returns the metric used to decide the feature on which to split a node.
    pub fn split_quality(&self) -> SplitQuality {
        self.split_quality
    }

    /// Returns the optional limit to the depth of the decision tree.
    pub fn max_depth(&self) -> Option<usize> {
        self.max_depth
    }

    /// Returns the minimum weight of samples required to split a node.
    pub fn min_weight_split(&self) -> f32 {
        self.min_weight_split
    }

    /// Returns the minimum weight of samples that a split has to place in each leaf.
    pub fn min_weight_leaf(&self) -> f32 {
        self.min_weight_leaf
    }

    /// Returns the minimum decrease in impurity required to apply a split.
    pub fn min_impurity_decrease(&self) -> F {
        self.min_impurity_decrease
    }
}
93
/// An unchecked set of hyperparameters for a [decision tree](DecisionTree).
///
/// Wraps [`DecisionTreeValidParams`]; values are validated via
/// [`ParamGuard::check`](linfa::ParamGuard) before fitting.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DecisionTreeParams<F, L>(DecisionTreeValidParams<F, L>);
101
102impl<F: Float, L: Label> DecisionTreeParams<F, L> {
103 pub fn new() -> Self {
104 Self(DecisionTreeValidParams {
105 split_quality: SplitQuality::Gini,
106 max_depth: None,
107 min_weight_split: 2.0,
108 min_weight_leaf: 1.0,
109 min_impurity_decrease: F::cast(0.00001),
110 label_marker: PhantomData,
111 })
112 }
113
114 /// Sets the metric used to decide the feature on which to split a node
115 pub fn split_quality(mut self, split_quality: SplitQuality) -> Self {
116 self.0.split_quality = split_quality;
117 self
118 }
119
120 /// Sets the optional limit to the depth of the decision tree
121 pub fn max_depth(mut self, max_depth: Option<usize>) -> Self {
122 self.0.max_depth = max_depth;
123 self
124 }
125
126 /// Sets the minimum weight of samples required to split a node.
127 ///
128 /// If the observations do not have associated weights, this value represents
129 /// the minimum number of samples required to split a node.
130 pub fn min_weight_split(mut self, min_weight_split: f32) -> Self {
131 self.0.min_weight_split = min_weight_split;
132 self
133 }
134
135 /// Sets the minimum weight of samples that a split has to place in each leaf
136 ///
137 /// If the observations do not have associated weights, this value represents
138 /// the minimum number of samples that a split has to place in each leaf.
139 pub fn min_weight_leaf(mut self, min_weight_leaf: f32) -> Self {
140 self.0.min_weight_leaf = min_weight_leaf;
141 self
142 }
143
144 /// Sets the minimum decrease in impurity that a split needs to bring in order for it to be applied
145 pub fn min_impurity_decrease(mut self, min_impurity_decrease: F) -> Self {
146 self.0.min_impurity_decrease = min_impurity_decrease;
147 self
148 }
149}
150
151impl<F: Float, L: Label> Default for DecisionTreeParams<F, L> {
152 fn default() -> Self {
153 Self::new()
154 }
155}
156
157impl<F: Float, L: Label> DecisionTree<F, L> {
158 /// Defaults are provided if the optional parameters are not specified:
159 /// * `split_quality = SplitQuality::Gini`
160 /// * `max_depth = None`
161 /// * `min_weight_split = 2.0`
162 /// * `min_weight_leaf = 1.0`
163 /// * `min_impurity_decrease = 0.00001`
164 // Violates the convention that new should return a value of type `Self`
165 #[allow(clippy::new_ret_no_self)]
166 pub fn params() -> DecisionTreeParams<F, L> {
167 DecisionTreeParams::new()
168 }
169}
170
171impl<F: Float, L> ParamGuard for DecisionTreeParams<F, L> {
172 type Checked = DecisionTreeValidParams<F, L>;
173 type Error = Error;
174
175 fn check_ref(&self) -> Result<&Self::Checked> {
176 if self.0.min_impurity_decrease < F::epsilon() {
177 Err(Error::Parameters(format!(
178 "Minimum impurity decrease should be greater than zero, but was {}",
179 self.0.min_impurity_decrease
180 )))
181 } else {
182 Ok(&self.0)
183 }
184 }
185
186 fn check(self) -> Result<Self::Checked> {
187 self.check_ref()?;
188 Ok(self.0)
189 }
190}