use std::cmp::Ordering;
use std::collections::{HashMap, HashSet, VecDeque};
use std::hash::{Hash, Hasher};

use linfa::dataset::AsSingleTargets;
use ndarray::{Array1, ArrayBase, Axis, Data, Ix1, Ix2};

use super::NodeIter;
use super::Tikz;
use super::{DecisionTreeValidParams, SplitQuality};
use linfa::{
    dataset::{Labels, Records},
    error::Error,
    error::Result,
    traits::*,
    DatasetBase, Float, Label,
};

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

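/// A boolean mask over the rows of a dataset, together with the number of
/// currently selected rows.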
struct RowMask {
    mask: Vec<bool>,
    nsamples: usize,
}

impl RowMask {
    /// Selects all rows.
    fn all(nsamples: usize) -> Self {
        RowMask {
            mask: vec![true; nsamples],
            nsamples,
        }
    }

    /// Selects no rows.
    fn none(nsamples: usize) -> Self {
        RowMask {
            mask: vec![false; nsamples],
            nsamples: 0,
        }
    }

    /// Marks the row at `idx` as selected. Does not check whether the row was
    /// already selected, so the sample count assumes each row is marked at
    /// most once.
    fn mark(&mut self, idx: usize) {
        self.mask[idx] = true;
        self.nsamples += 1;
    }
}

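/// One feature column of the dataset, paired with the original row indices
/// and sorted in ascending order of the feature value.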
struct SortedIndex<'a, F: Float> {
    feature_name: &'a str,
    sorted_values: Vec<(usize, F)>,
}

impl<'a, F: Float> SortedIndex<'a, F> {
    fn of_array_column(
        x: &ArrayBase<impl Data<Elem = F>, Ix2>,
        feature_idx: usize,
        feature_name: &'a str,
    ) -> Self {
        let sliced_column: Vec<F> = x.index_axis(Axis(1), feature_idx).to_vec();
        let mut pairs: Vec<(usize, F)> = sliced_column.into_iter().enumerate().collect();
        // Values that do not compare (e.g. NaN) are sorted towards the end.
        pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Greater));

        SortedIndex {
            sorted_values: pairs,
            feature_name,
        }
    }
}

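/// A node in a decision tree: either an internal node that splits one feature
/// at a threshold, or a leaf carrying a class prediction.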
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone)]
pub struct TreeNode<F, L> {
    feature_idx: usize,
    feature_name: String,
    split_value: F,
    impurity_decrease: F,
    left_child: Option<Box<TreeNode<F, L>>>,
    right_child: Option<Box<TreeNode<F, L>>>,
    leaf_node: bool,
    prediction: L,
    depth: usize,
}

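// Note: equality and hashing consider only the split feature index (plus the
// leaf flag for hashing), not the whole subtree.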
impl<F: Float, L: Label> Hash for TreeNode<F, L> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let data: Vec<u64> = vec![self.feature_idx as u64, self.leaf_node as u64];
        data.hash(state);
    }
}

impl<F, L> Eq for TreeNode<F, L> {}

impl<F, L> PartialEq for TreeNode<F, L> {
    fn eq(&self, other: &Self) -> bool {
        self.feature_idx == other.feature_idx
    }
}

impl<F: Float, L: Label + std::fmt::Debug> TreeNode<F, L> {
    /// Creates a leaf node at the given depth carrying `prediction`.
    fn empty_leaf(prediction: L, depth: usize) -> Self {
        TreeNode {
            feature_idx: 0,
            feature_name: "".to_string(),
            split_value: F::zero(),
            impurity_decrease: F::zero(),
            left_child: None,
            right_child: None,
            leaf_node: true,
            prediction,
            depth,
        }
    }

    /// Returns `true` if the node has no children.
    pub fn is_leaf(&self) -> bool {
        self.leaf_node
    }

    /// Returns the depth of the node in the tree, starting at 0 for the root.
    pub fn depth(&self) -> usize {
        self.depth
    }

    /// Returns the prediction of a leaf node, or `None` for an internal node.
    pub fn prediction(&self) -> Option<L> {
        if self.is_leaf() {
            Some(self.prediction.clone())
        } else {
            None
        }
    }

    /// Returns the left and right children of the node.
    pub fn children(&self) -> Vec<&Option<Box<TreeNode<F, L>>>> {
        vec![&self.left_child, &self.right_child]
    }

    /// Returns the split feature index, split value and impurity decrease.
    pub fn split(&self) -> (usize, F, F) {
        (self.feature_idx, self.split_value, self.impurity_decrease)
    }

    /// Returns the name of the split feature, or `None` for a leaf node.
    pub fn feature_name(&self) -> Option<&String> {
        if self.leaf_node {
            None
        } else {
            Some(&self.feature_name)
        }
    }

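    /// Recursively fits a subtree on the rows of `data` selected by `mask`.
    ///
    /// For every feature, candidate splits are scanned in ascending order of
    /// the feature value while the left/right class frequencies are updated
    /// incrementally; the split with the lowest weighted impurity (Gini or
    /// entropy) is kept. When a stopping criterion is hit (minimum split
    /// weight, maximum depth, or minimum impurity decrease), a leaf
    /// predicting the modal class of the node is returned instead.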
    fn fit<D: Data<Elem = F>, T: AsSingleTargets<Elem = L> + Labels<Elem = L>>(
        data: &DatasetBase<ArrayBase<D, Ix2>, T>,
        mask: &RowMask,
        hyperparameters: &DecisionTreeValidParams<F, L>,
        sorted_indices: &[SortedIndex<F>],
        depth: usize,
    ) -> Result<Self> {
        let parent_class_freq = data.label_frequencies_with_mask(&mask.mask);
        let prediction = find_modal_class(&parent_class_freq);
        let target = data.as_single_targets();

        // Stop early if there is too little weight to split or the maximum
        // depth has been reached.
        if (mask.nsamples as f32) < hyperparameters.min_weight_split()
            || hyperparameters
                .max_depth()
                .map(|max_depth| depth >= max_depth)
                .unwrap_or(false)
        {
            return Ok(Self::empty_leaf(prediction, depth));
        }

        // Best split found so far: (feature index, split value, weighted
        // impurity score).
        let mut best = None;

        for (feature_idx, sorted_index) in sorted_indices.iter().enumerate() {
            // Start with all samples on the right side of the split and move
            // them to the left one by one, in ascending order of the feature
            // value, updating the class frequencies incrementally.
            let mut right_class_freq = parent_class_freq.clone();
            let mut left_class_freq = HashMap::new();

            let total_weight = parent_class_freq.values().sum::<f32>();
            let mut weight_on_right_side = total_weight;
            let mut weight_on_left_side = 0.0;

            for i in 0..mask.mask.len() - 1 {
                let (presorted_index, mut split_value) = sorted_index.sorted_values[i];

                if !mask.mask[presorted_index] {
                    continue;
                }

                let sample_class = &target[presorted_index];
                let sample_weight = data.weight_for(presorted_index);

                // Move the sample from the right to the left side.
                *right_class_freq.get_mut(sample_class).unwrap() -= sample_weight;
                weight_on_right_side -= sample_weight;

                *left_class_freq.entry(sample_class.clone()).or_insert(0.0) += sample_weight;
                weight_on_left_side += sample_weight;

                // Skip splits between (nearly) identical feature values.
                if (sorted_index.sorted_values[i].1 - sorted_index.sorted_values[i + 1].1).abs()
                    < F::cast(1e-5)
                {
                    continue;
                }

                // Skip splits that would leave too little weight on one side.
                if weight_on_right_side < hyperparameters.min_weight_leaf()
                    || weight_on_left_side < hyperparameters.min_weight_leaf()
                {
                    continue;
                }

                let (right_score, left_score) = match hyperparameters.split_quality() {
                    SplitQuality::Gini => (
                        gini_impurity(&right_class_freq),
                        gini_impurity(&left_class_freq),
                    ),
                    SplitQuality::Entropy => {
                        (entropy(&right_class_freq), entropy(&left_class_freq))
                    }
                };

                // Weight each side's impurity by its share of the total weight.
                let w = weight_on_right_side / total_weight;
                let score = w * right_score + (1.0 - w) * left_score;

                // Split halfway between this feature value and the next one.
                split_value = (split_value + sorted_index.sorted_values[i + 1].1) / F::cast(2.0);

                best = match best.take() {
                    None => Some((feature_idx, split_value, score)),
                    Some((_, _, best_score)) if score < best_score => {
                        Some((feature_idx, split_value, score))
                    }
                    x => x,
                };
            }
        }

        let impurity_decrease = if let Some((_, _, best_score)) = best {
            let parent_score = match hyperparameters.split_quality() {
                SplitQuality::Gini => gini_impurity(&parent_class_freq),
                SplitQuality::Entropy => entropy(&parent_class_freq),
            };
            let parent_score = F::cast(parent_score);

            parent_score - F::cast(best_score)
        } else {
            F::zero()
        };

        if impurity_decrease < hyperparameters.min_impurity_decrease() {
            return Ok(Self::empty_leaf(prediction, depth));
        }

        let (best_feature_idx, best_split_value, _) = best.unwrap();

        // Partition the rows of this node along the best split.
        let mut left_mask = RowMask::none(data.nsamples());
        let mut right_mask = RowMask::none(data.nsamples());

        for i in 0..data.nsamples() {
            if mask.mask[i] {
                if data.records()[(i, best_feature_idx)] <= best_split_value {
                    left_mask.mark(i);
                } else {
                    right_mask.mark(i);
                }
            }
        }

        // Recurse into the non-empty partitions.
        let left_child = if left_mask.nsamples > 0 {
            Some(Box::new(TreeNode::fit(
                data,
                &left_mask,
                hyperparameters,
                sorted_indices,
                depth + 1,
            )?))
        } else {
            None
        };

        let right_child = if right_mask.nsamples > 0 {
            Some(Box::new(TreeNode::fit(
                data,
                &right_mask,
                hyperparameters,
                sorted_indices,
                depth + 1,
            )?))
        } else {
            None
        };

        let leaf_node = left_child.is_none() || right_child.is_none();

        Ok(TreeNode {
            feature_idx: best_feature_idx,
            feature_name: sorted_indices[best_feature_idx].feature_name.to_owned(),
            split_value: best_split_value,
            impurity_decrease,
            left_child,
            right_child,
            leaf_node,
            prediction,
            depth,
        })
    }

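    /// Collapses subtrees in which every leaf carries the same prediction.
    ///
    /// Returns the common prediction of the subtree if it is (or has become)
    /// a leaf, and `None` otherwise.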
    fn prune(&mut self) -> Option<L> {
        if self.is_leaf() {
            return Some(self.prediction.clone());
        }

        let left = self.left_child.as_mut().and_then(|x| x.prune());
        let right = self.right_child.as_mut().and_then(|x| x.prune());

        match (left, right) {
            (Some(x), Some(y)) => {
                if x == y {
                    self.prediction = x.clone();
                    self.right_child = None;
                    self.left_child = None;
                    self.leaf_node = true;

                    Some(x)
                } else {
                    None
                }
            }
            _ => None,
        }
    }
}

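/// A fitted decision tree classifier.
///
/// # Example
///
/// A minimal fit-and-predict sketch, mirroring the `perfectly_separable_small`
/// test below; it assumes `DecisionTree` is reachable from the crate root as
/// `linfa_trees::DecisionTree`.
///
/// ```rust,ignore
/// use linfa::prelude::*;
/// use linfa_trees::DecisionTree;
/// use ndarray::array;
///
/// let data = array![[1., 2., 3.], [1., 2., 4.], [1., 3., 3.5]];
/// let targets = array![0, 0, 1];
/// let dataset = Dataset::new(data.clone(), targets);
///
/// // A single split is enough to separate the two classes.
/// let model = DecisionTree::params().max_depth(Some(1)).fit(&dataset).unwrap();
/// assert_eq!(model.predict(&data), array![0, 0, 1]);
/// ```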
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct DecisionTree<F: Float, L: Label> {
    root_node: TreeNode<F, L>,
    num_features: usize,
}

impl<F: Float, L: Label + Default, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<L>>
    for DecisionTree<F, L>
{
    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
        assert_eq!(
            x.nrows(),
            y.len(),
            "The number of data points must match the number of output targets."
        );

        for (row, target) in x.rows().into_iter().zip(y.iter_mut()) {
            *target = make_prediction(&row, &self.root_node);
        }
    }

    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
        Array1::default(x.nrows())
    }
}

impl<F: Float, L: Label + std::fmt::Debug, D, T> Fit<ArrayBase<D, Ix2>, T, Error>
    for DecisionTreeValidParams<F, L>
where
    D: Data<Elem = F>,
    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
    type Object = DecisionTree<F, L>;

    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
        let x = dataset.records();
        // Fall back to generated feature names if the dataset provides none.
        let feature_names = if dataset.feature_names().is_empty() {
            (0..x.nfeatures())
                .map(|idx| format!("feature-{idx}"))
                .collect()
        } else {
            dataset.feature_names().to_vec()
        };
        let all_idxs = RowMask::all(x.nrows());
        // Pre-sort each feature column once; the recursion reuses the order.
        let sorted_indices: Vec<_> = (0..(x.ncols()))
            .map(|feature_idx| {
                SortedIndex::of_array_column(x, feature_idx, &feature_names[feature_idx])
            })
            .collect();

        let mut root_node = TreeNode::fit(dataset, &all_idxs, self, &sorted_indices, 0)?;
        root_node.prune();

        Ok(DecisionTree {
            root_node,
            num_features: dataset.records().ncols(),
        })
    }
}

impl<F: Float, L: Label> DecisionTree<F, L> {
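    /// Returns an iterator over all nodes of the tree, starting at the root.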
    pub fn iter_nodes(&self) -> NodeIter<F, L> {
        // Queue of nodes yet to visit, seeded with the root node.
        let queue = vec![&self.root_node];

        NodeIter::new(VecDeque::from(queue))
    }

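    /// Returns the indices of all features used by at least one split node.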
    pub fn features(&self) -> Vec<usize> {
        let mut fitted_features = HashSet::new();

        for node in self.iter_nodes().filter(|node| !node.is_leaf()) {
            if !fitted_features.contains(&node.feature_idx) {
                fitted_features.insert(node.feature_idx);
            }
        }

        fitted_features.into_iter().collect::<Vec<_>>()
    }

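    /// Returns, for each feature, the mean impurity decrease over all split
    /// nodes using that feature (zero for features the tree does not use).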
    pub fn mean_impurity_decrease(&self) -> Vec<F> {
        let mut impurity_decrease = vec![F::zero(); self.num_features];
        let mut num_nodes = vec![0; self.num_features];

        for node in self.iter_nodes().filter(|node| !node.leaf_node) {
            impurity_decrease[node.feature_idx] += node.impurity_decrease;
            num_nodes[node.feature_idx] += 1;
        }

        impurity_decrease
            .into_iter()
            .zip(num_nodes)
            .map(|(val, n)| if n == 0 { F::zero() } else { val / F::cast(n) })
            .collect()
    }

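    /// Returns the mean impurity decrease of each feature, normalized so
    /// that the values sum to one.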
    pub fn relative_impurity_decrease(&self) -> Vec<F> {
        let mean_impurity_decrease = self.mean_impurity_decrease();
        let sum = mean_impurity_decrease.iter().cloned().sum();

        mean_impurity_decrease
            .into_iter()
            .map(|x| x / sum)
            .collect()
    }

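    /// Returns the feature importances, measured as the relative impurity
    /// decrease of each feature.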
    pub fn feature_importance(&self) -> Vec<F> {
        self.relative_impurity_decrease()
    }

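    /// Returns a reference to the root node of the tree.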
    pub fn root_node(&self) -> &TreeNode<F, L> {
        &self.root_node
    }

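    /// Returns the maximum depth of the tree.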
    pub fn max_depth(&self) -> usize {
        self.iter_nodes()
            .fold(0, |max, node| usize::max(max, node.depth))
    }

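    /// Returns the number of leaf nodes in the tree.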
    pub fn num_leaves(&self) -> usize {
        self.iter_nodes().filter(|node| node.is_leaf()).count()
    }

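    /// Exports the tree to a [`Tikz`] structure for rendering with LaTeX.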
    pub fn export_to_tikz(&self) -> Tikz<F, L> {
        Tikz::new(self)
    }
}

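/// Walks the tree for a single sample, descending left when the feature value
/// lies below the node's split value, and returns the prediction of the
/// reached leaf.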
fn make_prediction<F: Float, L: Label>(
    x: &ArrayBase<impl Data<Elem = F>, Ix1>,
    node: &TreeNode<F, L>,
) -> L {
    if node.leaf_node {
        node.prediction.clone()
    } else if x[node.feature_idx] < node.split_value {
        make_prediction(x, node.left_child.as_ref().unwrap())
    } else {
        make_prediction(x, node.right_child.as_ref().unwrap())
    }
}

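/// Returns the class with the highest frequency.
///
/// Panics if `class_freq` is empty.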
fn find_modal_class<L: Label>(class_freq: &HashMap<L, f32>) -> L {
    let val = class_freq
        .iter()
        .fold(None, |acc, (idx, freq)| match acc {
            None => Some((idx, freq)),
            Some((_best_idx, best_freq)) => {
                if best_freq > freq {
                    acc
                } else {
                    Some((idx, freq))
                }
            }
        })
        .unwrap()
        .0;

    (*val).clone()
}

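/// Gini impurity of a set of weighted class frequencies: `1 - Σ p_k²`, where
/// `p_k` is the relative frequency of class `k`. Zero for a pure node, larger
/// values for more mixed nodes.
///
/// Panics if the total frequency is not positive.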
fn gini_impurity<L: Label>(class_freq: &HashMap<L, f32>) -> f32 {
    let n_samples = class_freq.values().sum::<f32>();
    assert!(n_samples > 0.0);

    let purity = class_freq
        .values()
        .map(|x| x / n_samples)
        .map(|x| x * x)
        .sum::<f32>();

    1.0 - purity
}

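/// Shannon entropy of a set of weighted class frequencies:
/// `-Σ p_k log₂(p_k)`, with the convention that `0 · log₂(0) = 0`.
///
/// Panics if the total frequency is not positive.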
fn entropy<L: Label>(class_freq: &HashMap<L, f32>) -> f32 {
    let n_samples = class_freq.values().sum::<f32>();
    assert!(n_samples > 0.0);

    class_freq
        .values()
        .map(|x| x / n_samples)
        .map(|x| if x > 0.0 { -x * x.log2() } else { 0.0 })
        .sum()
}

#[cfg(test)]
mod tests {
    use super::*;

    use approx::assert_abs_diff_eq;
    use linfa::{error::Result, metrics::ToConfusionMatrix, Dataset, ParamGuard};
    use ndarray::{array, concatenate, s, Array, Array1, Array2, Axis};
    use rand::rngs::SmallRng;

    use crate::DecisionTreeParams;
    use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt};

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<DecisionTree<f64, bool>>();
        has_autotraits::<TreeNode<f64, bool>>();
        has_autotraits::<DecisionTreeValidParams<f64, bool>>();
        has_autotraits::<DecisionTreeParams<f64, bool>>();
        has_autotraits::<NodeIter<f64, bool>>();
        has_autotraits::<Tikz<f64, bool>>();
    }

    #[test]
    fn prediction_for_rows_example() {
        let labels = Array::from(vec![0, 0, 0, 0, 0, 0, 1, 1]);
        let row_mask = RowMask::all(labels.len());

        let dataset: DatasetBase<(), Array1<usize>> = DatasetBase::new((), labels);
        let class_freq = dataset.label_frequencies_with_mask(&row_mask.mask);

        assert_eq!(find_modal_class(&class_freq), 0);
    }

    #[test]
    fn gini_impurity_example() {
        let class_freq = vec![(0, 6.0), (1, 2.0), (2, 0.0)].into_iter().collect();

        // 1 - (0.75² + 0.25² + 0²) = 0.375
        assert_abs_diff_eq!(gini_impurity(&class_freq), 0.375, epsilon = 1e-5);
    }

    #[test]
    fn entropy_example() {
        let class_freq = vec![(0, 6.0), (1, 2.0), (2, 0.0)].into_iter().collect();

        // -(0.75 * log2(0.75) + 0.25 * log2(0.25)) ≈ 0.81127
        assert_abs_diff_eq!(entropy(&class_freq), 0.81127, epsilon = 1e-5);

        // A perfectly pure node has zero entropy.
        let perfect_class_freq = vec![(0, 8.0), (1, 0.0), (2, 0.0)].into_iter().collect();

        assert_abs_diff_eq!(entropy(&perfect_class_freq), 0.0, epsilon = 1e-5);
    }

    #[test]
    fn single_feature_random_noise_binary() -> Result<()> {
        // Feature 8 separates the two classes perfectly; all others are noise.
        let mut data = Array::random((50, 10), Uniform::new(-4., 4.));
        data.slice_mut(s![.., 8]).assign(
            &(0..50)
                .map(|x| if x < 25 { 0.0 } else { 1.0 })
                .collect::<Array1<_>>(),
        );

        let targets = (0..50).map(|x| x < 25).collect::<Array1<_>>();
        let dataset = Dataset::new(data, targets);

        let model = DecisionTree::params().max_depth(Some(2)).fit(&dataset)?;

        // The tree should split only on the informative feature.
        assert_eq!(&model.features(), &[8]);

        let ground_truth = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];

        for (imp, truth) in model.feature_importance().iter().zip(&ground_truth) {
            assert_abs_diff_eq!(imp, truth, epsilon = 1e-15);
        }

        let cm = model
            .predict(dataset.records())
            .confusion_matrix(&dataset)?;
        assert_abs_diff_eq!(cm.accuracy(), 1.0, epsilon = 1e-15);

        Ok(())
    }

    #[test]
    fn check_max_depth() -> Result<()> {
        let mut rng = SmallRng::seed_from_u64(42);

        // With random data and all-distinct targets, the tree grows until it
        // hits the maximum depth.
        let data = Array::random_using((50, 50), Uniform::new(-1., 1.), &mut rng);
        let targets = (0..50).collect::<Array1<usize>>();

        let dataset = Dataset::new(data, targets);

        for max_depth in &[1, 5, 10, 20] {
            let model = DecisionTree::params()
                .max_depth(Some(*max_depth))
                .min_impurity_decrease(1e-10f64)
                .min_weight_split(1e-10)
                .fit(&dataset)?;
            assert_eq!(model.max_depth(), *max_depth);
        }

        Ok(())
    }

    #[test]
    fn perfectly_separable_small() -> Result<()> {
        let data = array![[1., 2., 3.], [1., 2., 4.], [1., 3., 3.5]];
        let targets = array![0, 0, 1];

        let dataset = Dataset::new(data.clone(), targets);
        let model = DecisionTree::params().max_depth(Some(1)).fit(&dataset)?;

        assert_eq!(model.predict(&data), array![0, 0, 1]);

        Ok(())
    }

    #[test]
    fn toy_dataset() -> Result<()> {
        let data = array![
            [0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 1.0, -14.0, 0.0, -4.0, 0.0, 0.0, 0.0, 0.0,],
            [0.0, 0.0, 5.0, 3.0, 0.0, -4.0, 0.0, 0.0, 1.0, -5.0, 0.2, 0.0, 4.0, 1.0,],
            [-1.0, -1.0, 0.0, 0.0, -4.5, 0.0, 0.0, 2.1, 1.0, 0.0, 0.0, -4.5, 0.0, 1.0,],
            [-1.0, -1.0, 0.0, -1.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.0, 0.0, 1.0,],
            [-1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,],
            [-1.0, -2.0, 0.0, 4.0, -3.0, 10.0, 4.0, 0.0, -3.2, 0.0, 4.0, 3.0, -4.0, 1.0,],
            [2.11, 0.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -3.0, 1.0,],
            [2.11, 0.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.0, 0.0, -2.0, 1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.0, 0.0, -2.0, 1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 11.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -1.0, 0.0,],
            [2.0, 8.0, 5.0, 1.0, 0.5, -4.0, 10.0, 0.0, 1.0, -5.0, 3.0, 0.0, 2.0, 0.0,],
            [2.0, 0.0, 1.0, 1.0, 1.0, -1.0, 1.0, 0.0, 0.0, -2.0, 3.0, 0.0, 1.0, 0.0,],
            [2.0, 0.0, 1.0, 2.0, 3.0, -1.0, 10.0, 2.0, 0.0, -1.0, 1.0, 2.0, 2.0, 0.0,],
            [1.0, 1.0, 0.0, 2.0, 2.0, -1.0, 1.0, 2.0, 0.0, -5.0, 1.0, 2.0, 3.0, 0.0,],
            [3.0, 1.0, 0.0, 3.0, 0.0, -4.0, 10.0, 0.0, 1.0, -5.0, 3.0, 0.0, 3.0, 1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 1.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -3.0, 1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 1.0, 0.0, 0.0, -3.2, 6.0, 1.5, 1.0, -1.0, -1.0,],
            [2.11, 8.0, -6.0, -0.5, 0.0, 10.0, 0.0, 0.0, -3.2, 6.0, 0.5, 0.0, -1.0, -1.0,],
            [2.0, 0.0, 5.0, 1.0, 0.5, -2.0, 10.0, 0.0, 1.0, -5.0, 3.0, 1.0, 0.0, -1.0,],
            [2.0, 0.0, 1.0, 1.0, 1.0, -2.0, 1.0, 0.0, 0.0, -2.0, 0.0, 0.0, 0.0, 1.0,],
            [2.0, 1.0, 1.0, 1.0, 2.0, -1.0, 10.0, 2.0, 0.0, -1.0, 0.0, 2.0, 1.0, 1.0,],
            [1.0, 1.0, 0.0, 0.0, 1.0, -3.0, 1.0, 2.0, 0.0, -5.0, 1.0, 2.0, 1.0, 1.0,],
            [3.0, 1.0, 0.0, 1.0, 0.0, -4.0, 1.0, 0.0, 1.0, -2.0, 0.0, 0.0, 1.0, 0.0,]
        ];

        let targets = array![1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0];

        let dataset = Dataset::new(data, targets);
        let model = DecisionTree::params().fit(&dataset)?;
        let prediction = model.predict(&dataset);

        let cm = prediction.confusion_matrix(&dataset)?;
        assert!(cm.accuracy() > 0.95);

        Ok(())
    }

    #[test]
    fn multilabel_four_uniform() -> Result<()> {
        // Four uniform blobs centred at (±2, ±2), one label per blob.
        let mut data = concatenate(
            Axis(0),
            &[Array2::random((40, 2), Uniform::new(-1., 1.)).view()],
        )
        .unwrap();

        data.outer_iter_mut().enumerate().for_each(|(i, mut p)| {
            if i < 10 {
                p += &array![-2., -2.]
            } else if i < 20 {
                p += &array![-2., 2.];
            } else if i < 30 {
                p += &array![2., -2.];
            } else {
                p += &array![2., 2.];
            }
        });

        let targets = (0..40)
            .map(|x| match x {
                x if x < 10 => 0,
                x if x < 20 => 1,
                x if x < 30 => 2,
                _ => 3,
            })
            .collect::<Array1<_>>();

        let dataset = Dataset::new(data.clone(), targets);

        let model = DecisionTree::params().fit(&dataset)?;
        let prediction = model.predict(data);

        let cm = prediction.confusion_matrix(&dataset)?;
        assert!(cm.accuracy() > 0.99);

        Ok(())
    }

    #[test]
    #[should_panic]
    fn panic_min_impurity_decrease() {
        // The parameter check must reject a non-positive impurity decrease.
        DecisionTree::<f64, bool>::params()
            .min_impurity_decrease(0.0)
            .check()
            .unwrap();
    }
}