linfa_trees/decision_trees/algorithm.rs

//! Linear decision trees
//!
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet, VecDeque};
use std::hash::{Hash, Hasher};

use linfa::dataset::AsSingleTargets;
use ndarray::{Array1, ArrayBase, Axis, Data, Ix1, Ix2};

use super::NodeIter;
use super::Tikz;
use super::{DecisionTreeValidParams, SplitQuality};
use linfa::{
    dataset::{Labels, Records},
    error::Error,
    error::Result,
    traits::*,
    DatasetBase, Float, Label,
};

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

/// RowMask tracks which observations are visible in a subtree
///
/// The decision tree algorithm splits observations at a certain split value for a specific feature. The
/// left and right children can then only use the observations that fall on their side of the split. To track
/// this, the observations are masked with a boolean vector, hiding all observations which are not
/// applicable in a lower subtree.
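///
/// A minimal sketch of the intended semantics (illustrative only; `RowMask` is private):
///
/// ```text
/// let all = RowMask::all(3);       // mask: [true, true, true],    nsamples: 3
/// let mut sub = RowMask::none(3);  // mask: [false, false, false], nsamples: 0
/// sub.mark(1);                     // mask: [false, true, false],  nsamples: 1
/// ```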
struct RowMask {
    mask: Vec<bool>,
    nsamples: usize,
}

impl RowMask {
    /// Generates a RowMask without hidden observations
    ///
    /// ### Parameters
    ///
    /// * `nsamples`: the total number of observations
    ///
    fn all(nsamples: usize) -> Self {
        RowMask {
            mask: vec![true; nsamples],
            nsamples,
        }
    }

    /// Generates a RowMask where all observations are hidden
    ///
    /// ### Parameters
    ///
    /// * `nsamples`: the total number of observations
    fn none(nsamples: usize) -> Self {
        RowMask {
            mask: vec![false; nsamples],
            nsamples: 0,
        }
    }

    /// Sets the observation at the specified index as visible
    ///
    /// ### Parameters
    ///
    /// * `idx`: the index of the observation to turn visible
    ///
    /// ### Panics
    ///
    /// If `idx` is out of bounds
    ///
    fn mark(&mut self, idx: usize) {
        self.mask[idx] = true;
        self.nsamples += 1;
    }
}

/// Sorted values of observations with indices (always for a particular feature)
struct SortedIndex<'a, F: Float> {
    feature_name: &'a str,
    sorted_values: Vec<(usize, F)>,
}

impl<'a, F: Float> SortedIndex<'a, F> {
    /// Sorts the values of a given feature in ascending order
    ///
    /// ### Parameters
    ///
    /// * `x`: the observations to sort
    /// * `feature_idx`: the index of the feature on which to sort the data
    /// * `feature_name`: the human readable name of the feature
    ///
    /// ### Returns
    ///
    /// A sorted vector of (index, value) pairs obtained by sorting the observations by
    /// the value of the specified feature.
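    ///
    /// For example (illustrative values only), a feature column `[0.5, 0.9, 0.1]` yields
    /// the sorted pairs `[(2, 0.1), (0, 0.5), (1, 0.9)]`.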
    fn of_array_column(
        x: &ArrayBase<impl Data<Elem = F>, Ix2>,
        feature_idx: usize,
        feature_name: &'a str,
    ) -> Self {
        let sliced_column: Vec<F> = x.index_axis(Axis(1), feature_idx).to_vec();
        let mut pairs: Vec<(usize, F)> = sliced_column.into_iter().enumerate().collect();
        pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Greater));

        SortedIndex {
            sorted_values: pairs,
            feature_name,
        }
    }
}

#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone)]
/// A node in the decision tree
pub struct TreeNode<F, L> {
    feature_idx: usize,
    feature_name: String,
    split_value: F,
    impurity_decrease: F,
    left_child: Option<Box<TreeNode<F, L>>>,
    right_child: Option<Box<TreeNode<F, L>>>,
    leaf_node: bool,
    prediction: L,
    depth: usize,
}

impl<F: Float, L: Label> Hash for TreeNode<F, L> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let data: Vec<u64> = vec![self.feature_idx as u64, self.leaf_node as u64];
        data.hash(state);
    }
}

impl<F, L> Eq for TreeNode<F, L> {}

impl<F, L> PartialEq for TreeNode<F, L> {
    fn eq(&self, other: &Self) -> bool {
        self.feature_idx == other.feature_idx
    }
}

impl<F: Float, L: Label + std::fmt::Debug> TreeNode<F, L> {
    fn empty_leaf(prediction: L, depth: usize) -> Self {
        TreeNode {
            feature_idx: 0,
            feature_name: "".to_string(),
            split_value: F::zero(),
            impurity_decrease: F::zero(),
            left_child: None,
            right_child: None,
            leaf_node: true,
            prediction,
            depth,
        }
    }

    /// Returns true if the node has no children
    pub fn is_leaf(&self) -> bool {
        self.leaf_node
    }

    /// Returns the depth of the node in the decision tree
    pub fn depth(&self) -> usize {
        self.depth
    }

    /// Returns `Some(prediction)` for leaf nodes and `None` for internal nodes.
    pub fn prediction(&self) -> Option<L> {
        if self.is_leaf() {
            Some(self.prediction.clone())
        } else {
            None
        }
    }

    /// Returns both children, first left then right
    pub fn children(&self) -> Vec<&Option<Box<TreeNode<F, L>>>> {
        vec![&self.left_child, &self.right_child]
    }

    /// Return the split (feature index, value) and its impurity decrease
    pub fn split(&self) -> (usize, F, F) {
        (self.feature_idx, self.split_value, self.impurity_decrease)
    }

    /// Returns the name of the feature used in the split if the node is internal,
    /// `None` otherwise
    pub fn feature_name(&self) -> Option<&String> {
        if self.leaf_node {
            None
        } else {
            Some(&self.feature_name)
        }
    }

    /// Recursively fits the node
    fn fit<D: Data<Elem = F>, T: AsSingleTargets<Elem = L> + Labels<Elem = L>>(
        data: &DatasetBase<ArrayBase<D, Ix2>, T>,
        mask: &RowMask,
        hyperparameters: &DecisionTreeValidParams<F, L>,
        sorted_indices: &[SortedIndex<F>],
        depth: usize,
    ) -> Result<Self> {
        // compute weighted frequencies for target classes
        let parent_class_freq = data.label_frequencies_with_mask(&mask.mask);
        // set our prediction for this subset to the modal class
        let prediction = find_modal_class(&parent_class_freq);
        // get targets from dataset
        let target = data.as_single_targets();

        // return empty leaf when we don't have enough samples or the maximal depth is reached
        if (mask.nsamples as f32) < hyperparameters.min_weight_split()
            || hyperparameters
                .max_depth()
                .map(|max_depth| depth >= max_depth)
                .unwrap_or(false)
        {
            return Ok(Self::empty_leaf(prediction, depth));
        }

        // Find best split for current level
        let mut best = None;

        // Iterate over all features
        for (feature_idx, sorted_index) in sorted_indices.iter().enumerate() {
            let mut right_class_freq = parent_class_freq.clone();
            let mut left_class_freq = HashMap::new();

            // We keep a running total of the aggregate weight in the right split
            // to avoid having to sum over the hash map
            let total_weight = parent_class_freq.values().sum::<f32>();
            let mut weight_on_right_side = total_weight;
            let mut weight_on_left_side = 0.0;

            // We start by putting all available observations in the right subtree
            // and then move the (sorted by `feature_idx`) observations one by one to
            // the left subtree and evaluate the quality of the resulting split. At each
            // iteration, the obtained split is compared with `best`, in order
            // to find the best possible split.
            // The resulting split will then have the observations with a value of their `feature_idx`
            // feature smaller than the split value in the left subtree and the others still in the right
            // subtree
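            //
            // For example (illustrative values only): if the visible, sorted feature values
            // are [1.0, 2.0, 3.0], the candidate splits evaluated are 1.5 ({1.0} | {2.0, 3.0})
            // and 2.5 ({1.0, 2.0} | {3.0}); the trivial split that would leave the right
            // subtree empty is never evaluated.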
            for i in 0..mask.mask.len() - 1 {
                // (index of the observation, value of its `feature_idx` feature)
                let (presorted_index, mut split_value) = sorted_index.sorted_values[i];

                // Skip if the observation is unavailable in this subtree
                if !mask.mask[presorted_index] {
                    continue;
                }

                // Target and weight of the current observation
                let sample_class = &target[presorted_index];
                let sample_weight = data.weight_for(presorted_index);

                // Move the observation from the right subtree to the left subtree

                // Decrement the weight on the class for this sample on the right
                // side by the weight of this sample
                *right_class_freq.get_mut(sample_class).unwrap() -= sample_weight;
                weight_on_right_side -= sample_weight;

                // Increment the weight on the class for this sample on the
                // left side by the weight of this sample
                *left_class_freq.entry(sample_class.clone()).or_insert(0.0) += sample_weight;
                weight_on_left_side += sample_weight;

                // Continue if the next value is equal, so that equal values end up in the same subtree
                if (sorted_index.sorted_values[i].1 - sorted_index.sorted_values[i + 1].1).abs()
                    < F::cast(1e-5)
                {
                    continue;
                }

                // If the split would result in too few samples in a leaf
                // then skip computing the quality
                if weight_on_right_side < hyperparameters.min_weight_leaf()
                    || weight_on_left_side < hyperparameters.min_weight_leaf()
                {
                    continue;
                }

                // Calculate the quality of each resulting subset of the dataset
                let (left_score, right_score) = match hyperparameters.split_quality() {
                    SplitQuality::Gini => (
                        gini_impurity(&left_class_freq),
                        gini_impurity(&right_class_freq),
                    ),
                    SplitQuality::Entropy => {
                        (entropy(&left_class_freq), entropy(&right_class_freq))
                    }
                };

                // Weight the qualities based on the fraction of the total weight in each subset
                let w = weight_on_left_side / total_weight;
                let score = w * left_score + (1.0 - w) * right_score;
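                // For example (illustrative numbers): with total_weight = 8.0 and
                // weight_on_left_side = 2.0, a left Gini of 0.0 and a right Gini of 0.375
                // give score = 0.25 * 0.0 + 0.75 * 0.375 = 0.28125.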

                // Take the midpoint between this value and the next one as the split value
                split_value = (split_value + sorted_index.sorted_values[i + 1].1) / F::cast(2.0);

                // Update the best split when the score improves
                best = match best.take() {
                    None => Some((feature_idx, split_value, score)),
                    Some((_, _, best_score)) if score < best_score => {
                        Some((feature_idx, split_value, score))
                    }
                    x => x,
                };
            }
        }

        // At this point all possible splits for all possible features have been computed
        // and the best one (if any) is stored in `best`. Now we can compute the
        // impurity decrease as `impurity of the node before splitting - impurity of the split`.
        // If the impurity decrease is above the threshold set in the parameters, then the split is
        // applied and `fit` is recursively called in the two resulting subtrees. If there is no
        // possible split, or if it doesn't bring enough impurity decrease, then the node is set as
        // a leaf node that predicts the most common label in the available observations.

        let impurity_decrease = if let Some((_, _, best_score)) = best {
            let parent_score = match hyperparameters.split_quality() {
                SplitQuality::Gini => gini_impurity(&parent_class_freq),
                SplitQuality::Entropy => entropy(&parent_class_freq),
            };
            let parent_score = F::cast(parent_score);

            // impurity decrease of the best split compared to the unsplit node
            parent_score - F::cast(best_score)
        } else {
            // no split was found, so there is no impurity decrease
            F::zero()
        };

        if impurity_decrease < hyperparameters.min_impurity_decrease() {
            return Ok(Self::empty_leaf(prediction, depth));
        }

        let (best_feature_idx, best_split_value, _) = best.unwrap();

        // determine new masks for the left and right subtrees
        let mut left_mask = RowMask::none(data.nsamples());
        let mut right_mask = RowMask::none(data.nsamples());

        for i in 0..data.nsamples() {
            if mask.mask[i] {
                if data.records()[(i, best_feature_idx)] <= best_split_value {
                    left_mask.mark(i);
                } else {
                    right_mask.mark(i);
                }
            }
        }

        // Recurse and refit on left and right subtrees
        let left_child = if left_mask.nsamples > 0 {
            Some(Box::new(TreeNode::fit(
                data,
                &left_mask,
                hyperparameters,
                sorted_indices,
                depth + 1,
            )?))
        } else {
            None
        };

        let right_child = if right_mask.nsamples > 0 {
            Some(Box::new(TreeNode::fit(
                data,
                &right_mask,
                hyperparameters,
                sorted_indices,
                depth + 1,
            )?))
        } else {
            None
        };

        let leaf_node = left_child.is_none() || right_child.is_none();

        Ok(TreeNode {
            feature_idx: best_feature_idx,
            feature_name: sorted_indices[best_feature_idx].feature_name.to_owned(),
            split_value: best_split_value,
            impurity_decrease,
            left_child,
            right_child,
            leaf_node,
            prediction,
            depth,
        })
    }

    /// Prune the tree after fitting it
    ///
    /// This removes subtrees in which every leaf makes the same prediction and
    /// replaces them with a single leaf. It is called right after fitting to keep
    /// the tree small.
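    ///
    /// A sketch of the effect (illustrative only):
    ///
    /// ```text
    /// before:                        after:
    ///   split on feature 0             leaf -> "a"
    ///   +-- leaf -> "a"
    ///   +-- leaf -> "a"
    /// ```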
    fn prune(&mut self) -> Option<L> {
        if self.is_leaf() {
            return Some(self.prediction.clone());
        }

        let left = self.left_child.as_mut().and_then(|x| x.prune());
        let right = self.right_child.as_mut().and_then(|x| x.prune());

        match (left, right) {
            (Some(x), Some(y)) => {
                if x == y {
                    self.prediction = x.clone();
                    self.right_child = None;
                    self.left_child = None;
                    self.leaf_node = true;

                    Some(x)
                } else {
                    None
                }
            }
            _ => None,
        }
    }
}

/// A fitted decision tree model for classification.
///
/// ### Structure
/// A decision tree structure is a binary tree where:
/// * Each internal node specifies a decision, represented by a choice of a feature and a "split value" such that all observations for which
///   `feature <= split_value` is true fall in the left subtree, while the others fall in the right subtree.
///
/// * Each leaf node makes a prediction, namely the most common label among the observations that reach it.
///
/// ### Algorithm
///
/// Starting with a single root node, decision trees are trained recursively by applying the following rule to every
/// node considered:
///
/// * Find the best split value for each feature of the observations belonging to the node;
/// * Select the feature (and its best split value) that maximizes the quality of the split;
/// * If the score of the split is sufficiently larger than the score of the unsplit node, then two child nodes are generated, the left one
///   containing all observations with `feature <= split value` and the right one containing the rest.
/// * If no suitable split is found, the node is marked as a leaf and its prediction is set to the most common label in the node.
///
/// The [quality score](SplitQuality) used can be specified in the [parameters](crate::DecisionTreeParams).
///
/// ### Predictions
///
/// To predict the label of a sample, the tree is traversed from the root to a leaf, choosing between left and right children according to
/// the values of the features of the sample. The final prediction for the sample is the prediction of the reached leaf.
///
/// ### Additional constraints
///
/// In order to avoid overfitting the training data, some additional constraints on the quality/quantity of splits can be added to the tree.
/// A description of these additional rules is provided in the [parameters](crate::DecisionTreeParams) page.
///
/// ### Example
///
/// Here is an example of how to train a decision tree from its parameters:
///
/// ```rust
///
/// use linfa_trees::DecisionTree;
/// use linfa::prelude::*;
/// use linfa_datasets;
///
/// // Load the dataset
/// let dataset = linfa_datasets::iris();
/// // Fit the tree
/// let tree = DecisionTree::params().fit(&dataset).unwrap();
/// // Get accuracy on training set
/// let accuracy = tree.predict(&dataset).confusion_matrix(&dataset).unwrap().accuracy();
///
/// assert!(accuracy > 0.9);
///
/// ```
///
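/// A fitted tree can also be inspected. A brief sketch (it reuses the iris dataset from the
/// example above; the exact tree depends on the data):
///
/// ```rust
/// use linfa_trees::DecisionTree;
/// use linfa::prelude::*;
///
/// let dataset = linfa_datasets::iris();
/// let tree = DecisionTree::params().max_depth(Some(3)).fit(&dataset).unwrap();
///
/// // The depth is bounded by the `max_depth` parameter
/// assert!(tree.max_depth() <= 3);
/// // One importance value per feature (iris has four features)
/// assert_eq!(tree.feature_importance().len(), 4);
/// // Every fitted tree has at least one leaf
/// assert!(tree.num_leaves() >= 1);
/// ```
///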
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct DecisionTree<F: Float, L: Label> {
    root_node: TreeNode<F, L>,
    num_features: usize,
}

impl<F: Float, L: Label + Default, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<L>>
    for DecisionTree<F, L>
{
    /// Make predictions for each row of a matrix of features `x`.
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
        assert_eq!(
            x.nrows(),
            y.len(),
            "The number of data points must match the number of output targets."
        );

        for (row, target) in x.rows().into_iter().zip(y.iter_mut()) {
            *target = make_prediction(&row, &self.root_node);
        }
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
        Array1::default(x.nrows())
    }
}

impl<F: Float, L: Label + std::fmt::Debug, D, T> Fit<ArrayBase<D, Ix2>, T, Error>
    for DecisionTreeValidParams<F, L>
where
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
    type Object = DecisionTree<F, L>;

    /// Fit a decision tree using these hyperparameters on a dataset consisting of
    /// a matrix of features and an array of labels.
    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
        let x = dataset.records();
        let feature_names = if dataset.feature_names().is_empty() {
            (0..x.nfeatures())
                .map(|idx| format!("feature-{idx}"))
                .collect()
        } else {
            dataset.feature_names().to_vec()
        };
        let all_idxs = RowMask::all(x.nrows());
        let sorted_indices: Vec<_> = (0..(x.ncols()))
            .map(|feature_idx| {
                SortedIndex::of_array_column(x, feature_idx, &feature_names[feature_idx])
            })
            .collect();

        let mut root_node = TreeNode::fit(dataset, &all_idxs, self, &sorted_indices, 0)?;
        root_node.prune();

        Ok(DecisionTree {
            root_node,
            num_features: dataset.records().ncols(),
        })
    }
}

impl<F: Float, L: Label> DecisionTree<F, L> {
    /// Create a node iterator in level-order (BFT)
    pub fn iter_nodes(&self) -> NodeIter<F, L> {
        // queue of nodes yet to explore
        let queue = vec![&self.root_node];

        NodeIter::new(VecDeque::from(queue))
    }

    /// Return the indices of the features used in at least one split of this tree
    pub fn features(&self) -> Vec<usize> {
        // set of feature indices used by the splits
        let mut fitted_features = HashSet::new();

        for node in self.iter_nodes().filter(|node| !node.is_leaf()) {
            if !fitted_features.contains(&node.feature_idx) {
                fitted_features.insert(node.feature_idx);
            }
        }

        fitted_features.into_iter().collect::<Vec<_>>()
    }

    /// Return the mean impurity decrease for each feature
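    ///
    /// For example (illustrative numbers): if a feature is used in two splits with impurity
    /// decreases `0.3` and `0.1`, its mean impurity decrease is `0.2`; features never used in
    /// a split get a mean decrease of zero.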
    pub fn mean_impurity_decrease(&self) -> Vec<F> {
        // total impurity decrease for each feature
        let mut impurity_decrease = vec![F::zero(); self.num_features];
        let mut num_nodes = vec![0; self.num_features];

        for node in self.iter_nodes().filter(|node| !node.leaf_node) {
            // add feature impurity decrease to list
            impurity_decrease[node.feature_idx] += node.impurity_decrease;
            num_nodes[node.feature_idx] += 1;
        }

        impurity_decrease
            .into_iter()
            .zip(num_nodes)
            .map(|(val, n)| if n == 0 { F::zero() } else { val / F::cast(n) })
            .collect()
    }

    /// Return the relative impurity decrease for each feature
    pub fn relative_impurity_decrease(&self) -> Vec<F> {
        let mean_impurity_decrease = self.mean_impurity_decrease();
        let sum = mean_impurity_decrease.iter().cloned().sum();

        mean_impurity_decrease
            .into_iter()
            .map(|x| x / sum)
            .collect()
    }

    /// Return the feature importance, i.e. the relative impurity decrease, for each feature
    pub fn feature_importance(&self) -> Vec<F> {
        self.relative_impurity_decrease()
    }

    /// Return root node of the tree
    pub fn root_node(&self) -> &TreeNode<F, L> {
        &self.root_node
    }

    /// Return max depth of the tree
    pub fn max_depth(&self) -> usize {
        self.iter_nodes()
            .fold(0, |max, node| usize::max(max, node.depth))
    }

    /// Return the number of leaves in this tree
    pub fn num_leaves(&self) -> usize {
        self.iter_nodes().filter(|node| node.is_leaf()).count()
    }

    /// Generates a [`Tikz`] structure to print the
    /// fitted tree in LaTeX using the tikz and forest packages, with the following default parameters:
    ///
    /// * `legend=false`
    /// * `complete=true`
    ///
    pub fn export_to_tikz(&self) -> Tikz<F, L> {
        Tikz::new(self)
    }
}

/// Classify the sample `x` recursively, starting from the tree node `node`.
fn make_prediction<F: Float, L: Label>(
    x: &ArrayBase<impl Data<Elem = F>, Ix1>,
    node: &TreeNode<F, L>,
) -> L {
    if node.leaf_node {
        node.prediction.clone()
    } else if x[node.feature_idx] < node.split_value {
        make_prediction(x, node.left_child.as_ref().unwrap())
    } else {
        make_prediction(x, node.right_child.as_ref().unwrap())
    }
}

/// Finds the most frequent class for a hash map of frequencies. If two
/// classes have the same weight the tie is broken arbitrarily (it depends on
/// the iteration order of the hash map).
fn find_modal_class<L: Label>(class_freq: &HashMap<L, f32>) -> L {
    // TODO: Refactor this with fold_first

    let val = class_freq
        .iter()
        .fold(None, |acc, (idx, freq)| match acc {
            None => Some((idx, freq)),
            Some((_best_idx, best_freq)) => {
                if best_freq > freq {
                    acc
                } else {
                    Some((idx, freq))
                }
            }
        })
        .unwrap()
        .0;

    (*val).clone()
}

/// Given the class frequencies, calculates the Gini impurity of the subset,
/// i.e. `1 - sum(p_i^2)` where `p_i` is the weighted relative frequency of class `i`.
fn gini_impurity<L: Label>(class_freq: &HashMap<L, f32>) -> f32 {
    let n_samples = class_freq.values().sum::<f32>();
    assert!(n_samples > 0.0);

    let purity = class_freq
        .values()
        .map(|x| x / n_samples)
        .map(|x| x * x)
        .sum::<f32>();

    1.0 - purity
}

/// Given the class frequencies, calculates the entropy of the subset,
/// i.e. `-sum(p_i * log2(p_i))` where `p_i` is the weighted relative frequency of class `i`.
fn entropy<L: Label>(class_freq: &HashMap<L, f32>) -> f32 {
    let n_samples = class_freq.values().sum::<f32>();
    assert!(n_samples > 0.0);

    class_freq
        .values()
        .map(|x| x / n_samples)
        .map(|x| if x > 0.0 { -x * x.log2() } else { 0.0 })
        .sum()
}

#[cfg(test)]
mod tests {
    use super::*;

    use approx::assert_abs_diff_eq;
    use linfa::{error::Result, metrics::ToConfusionMatrix, Dataset, ParamGuard};
    use ndarray::{array, concatenate, s, Array, Array1, Array2, Axis};
    use rand::rngs::SmallRng;

    use crate::DecisionTreeParams;
    use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt};

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<DecisionTree<f64, bool>>();
        has_autotraits::<TreeNode<f64, bool>>();
        has_autotraits::<DecisionTreeValidParams<f64, bool>>();
        has_autotraits::<DecisionTreeParams<f64, bool>>();
        has_autotraits::<NodeIter<f64, bool>>();
        has_autotraits::<Tikz<f64, bool>>();
    }

    #[test]
    fn prediction_for_rows_example() {
        let labels = Array::from(vec![0, 0, 0, 0, 0, 0, 1, 1]);
        let row_mask = RowMask::all(labels.len());

        let dataset: DatasetBase<(), Array1<usize>> = DatasetBase::new((), labels);
        let class_freq = dataset.label_frequencies_with_mask(&row_mask.mask);

        assert_eq!(find_modal_class(&class_freq), 0);
    }

    #[test]
    fn gini_impurity_example() {
        let class_freq = vec![(0, 6.0), (1, 2.0), (2, 0.0)].into_iter().collect();

        // Class 0 occurs 75% of the time
        // Class 1 occurs 25% of the time
        // Class 2 occurs 0% of the time
        // Gini impurity is 1 - 0.75*0.75 - 0.25*0.25 - 0*0 = 0.375
        assert_abs_diff_eq!(gini_impurity(&class_freq), 0.375, epsilon = 1e-5);
    }

    #[test]
    fn entropy_example() {
        let class_freq = vec![(0, 6.0), (1, 2.0), (2, 0.0)].into_iter().collect();

        // Class 0 occurs 75% of the time
        // Class 1 occurs 25% of the time
        // Class 2 occurs 0% of the time
        // Entropy is -0.75*log2(0.75) - 0.25*log2(0.25) - 0*log2(0) = 0.81127812
        assert_abs_diff_eq!(entropy(&class_freq), 0.81127, epsilon = 1e-5);

        // If split is perfect then entropy is zero
        let perfect_class_freq = vec![(0, 8.0), (1, 0.0), (2, 0.0)].into_iter().collect();

        assert_abs_diff_eq!(entropy(&perfect_class_freq), 0.0, epsilon = 1e-5);
    }

    #[test]
    /// Single feature test
    ///
    /// Generate a dataset where a single feature perfectly correlates
    /// with the target while the remaining features are uniform random
    /// noise and do not add any information.
    fn single_feature_random_noise_binary() -> Result<()> {
        // generate data with nine noise features and a single correlated feature
        let mut data = Array::random((50, 10), Uniform::new(-4., 4.));
        data.slice_mut(s![.., 8]).assign(
            &(0..50)
                .map(|x| if x < 25 { 0.0 } else { 1.0 })
                .collect::<Array1<_>>(),
        );

        let targets = (0..50).map(|x| x < 25).collect::<Array1<_>>();
        let dataset = Dataset::new(data, targets);

        let model = DecisionTree::params().max_depth(Some(2)).fit(&dataset)?;

        // we should only use feature index 8 here
        assert_eq!(&model.features(), &[8]);

        let ground_truth = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];

        for (imp, truth) in model.feature_importance().iter().zip(&ground_truth) {
            assert_abs_diff_eq!(imp, truth, epsilon = 1e-15);
        }

        // check for perfect accuracy
        let cm = model
            .predict(dataset.records())
            .confusion_matrix(&dataset)?;
        assert_abs_diff_eq!(cm.accuracy(), 1.0, epsilon = 1e-15);

        Ok(())
    }

    #[test]
    /// Check that for random data the max depth is used
    fn check_max_depth() -> Result<()> {
        let mut rng = SmallRng::seed_from_u64(42);

        // create random data where every sample has its own label
        let data = Array::random_using((50, 50), Uniform::new(-1., 1.), &mut rng);
        let targets = (0..50).collect::<Array1<usize>>();

        let dataset = Dataset::new(data, targets);

        // check that the provided depth is actually used
        for max_depth in &[1, 5, 10, 20] {
            let model = DecisionTree::params()
                .max_depth(Some(*max_depth))
                .min_impurity_decrease(1e-10f64)
                .min_weight_split(1e-10)
                .fit(&dataset)?;
            assert_eq!(model.max_depth(), *max_depth);
        }

        Ok(())
    }

    #[test]
    /// Small perfectly separable dataset test
    ///
    /// This dataset of three samples is perfectly separable using the second feature.
    fn perfectly_separable_small() -> Result<()> {
        let data = array![[1., 2., 3.], [1., 2., 4.], [1., 3., 3.5]];
        let targets = array![0, 0, 1];

        let dataset = Dataset::new(data.clone(), targets);
        let model = DecisionTree::params().max_depth(Some(1)).fit(&dataset)?;

        assert_eq!(model.predict(&data), array![0, 0, 1]);

        Ok(())
    }

    #[test]
    /// Small toy dataset from scikit-learn
    fn toy_dataset() -> Result<()> {
        let data = array![
            [0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 1.0, -14.0, 0.0, -4.0, 0.0, 0.0, 0.0, 0.0,],
            [0.0, 0.0, 5.0, 3.0, 0.0, -4.0, 0.0, 0.0, 1.0, -5.0, 0.2, 0.0, 4.0, 1.0,],
            [-1.0, -1.0, 0.0, 0.0, -4.5, 0.0, 0.0, 2.1, 1.0, 0.0, 0.0, -4.5, 0.0, 1.0,],
            [-1.0, -1.0, 0.0, -1.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.0, 0.0, 1.0,],
            [-1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,],
            [-1.0, -2.0, 0.0, 4.0, -3.0, 10.0, 4.0, 0.0, -3.2, 0.0, 4.0, 3.0, -4.0, 1.0,],
            [2.11, 0.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -3.0, 1.0,],
            [2.11, 0.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.0, 0.0, -2.0, 1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.0, 0.0, -2.0, 1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -1.0, 0.0,],
            [2.0, 8.0, 5.0, 1.0, 0.5, -4.0, 10.0, 0.0, 1.0, -5.0, 3.0, 0.0, 2.0, 0.0,],
            [2.0, 0.0, 1.0, 1.0, 1.0, -1.0, 1.0, 0.0, 0.0, -2.0, 3.0, 0.0, 1.0, 0.0,],
            [2.0, 0.0, 1.0, 2.0, 3.0, -1.0, 10.0, 2.0, 0.0, -1.0, 1.0, 2.0, 2.0, 0.0,],
            [1.0, 1.0, 0.0, 2.0, 2.0, -1.0, 1.0, 2.0, 0.0, -5.0, 1.0, 2.0, 3.0, 0.0,],
            [3.0, 1.0, 0.0, 3.0, 0.0, -4.0, 10.0, 0.0, 1.0, -5.0, 3.0, 0.0, 3.0, 1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 1.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -3.0, 1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 1.0, 0.0, 0.0, -3.2, 6.0, 1.5, 1.0, -1.0, -1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 10.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -1.0, -1.0,],
            [2.0, 0.0, 5.0, 1.0, 0.5, -2.0, 10.0, 0.0, 1.0, -5.0, 3.0, 1.0, 0.0, -1.0,],
            [2.0, 0.0, 1.0, 1.0, 1.0, -2.0, 1.0, 0.0, 0.0, -2.0, 0.0, 0.0, 0.0, 1.0,],
            [2.0, 1.0, 1.0, 1.0, 2.0, -1.0, 10.0, 2.0, 0.0, -1.0, 0.0, 2.0, 1.0, 1.0,],
            [1.0, 1.0, 0.0, 0.0, 1.0, -3.0, 1.0, 2.0, 0.0, -5.0, 1.0, 2.0, 1.0, 1.0,],
            [3.0, 1.0, 0.0, 1.0, 0.0, -4.0, 1.0, 0.0, 1.0, -2.0, 0.0, 0.0, 1.0, 0.0,]
        ];

        let targets = array![1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0];

        let dataset = Dataset::new(data, targets);
        let model = DecisionTree::params().fit(&dataset)?;
        let prediction = model.predict(&dataset);

        let cm = prediction.confusion_matrix(&dataset)?;
        assert!(cm.accuracy() > 0.95);

        Ok(())
    }

    #[test]
    /// Multi-class classification with four clusters
    fn multilabel_four_uniform() -> Result<()> {
        let mut data = concatenate(
            Axis(0),
            &[Array2::random((40, 2), Uniform::new(-1., 1.)).view()],
        )
        .unwrap();

        data.outer_iter_mut().enumerate().for_each(|(i, mut p)| {
            if i < 10 {
                p += &array![-2., -2.]
            } else if i < 20 {
                p += &array![-2., 2.];
            } else if i < 30 {
                p += &array![2., -2.];
            } else {
                p += &array![2., 2.];
            }
        });

        let targets = (0..40)
            .map(|x| match x {
                x if x < 10 => 0,
                x if x < 20 => 1,
                x if x < 30 => 2,
                _ => 3,
            })
            .collect::<Array1<_>>();

        let dataset = Dataset::new(data.clone(), targets);

        let model = DecisionTree::params().fit(&dataset)?;
        let prediction = model.predict(data);

        let cm = prediction.confusion_matrix(&dataset)?;
        assert!(cm.accuracy() > 0.99);

        Ok(())
    }

    #[test]
    #[should_panic]
    /// Check that a zero or negative impurity decrease is rejected
    fn panic_min_impurity_decrease() {
        DecisionTree::<f64, bool>::params()
            .min_impurity_decrease(0.0)
            .check()
            .unwrap();
    }
}