linfa_trees/decision_trees/
tikz.rs

1use super::{DecisionTree, TreeNode};
2use linfa::{Float, Label};
3use std::collections::HashSet;
4use std::fmt::Debug;
5
6/// Struct to print a fitted decision tree in Tex using tikz and forest.
7///
8/// There are two settable parameters:
9///
10/// * `legend`: if true, a box with the names of the split features will appear in the top right of the tree
11/// * `complete`: if true, a complete and standalone Tex document will be generated; otherwise the result will an embeddable Tex tree.
12///
13/// ### Usage
14///
15/// ```rust
16/// use linfa::prelude::*;
17/// use linfa_datasets;
18/// use linfa_trees::DecisionTree;
19///
20/// // Load dataset
21/// let dataset = linfa_datasets::iris();
22/// // Fit the tree
23/// let tree = DecisionTree::params().fit(&dataset).unwrap();
24/// // Export to tikz
25/// let tikz = tree.export_to_tikz().with_legend();
26/// let latex_tree = tikz.to_string();
27/// // Now you can write latex_tree to the preferred destination
28///
29/// ```
30#[derive(Debug, Clone, PartialEq)]
31pub struct Tikz<'a, F: Float, L: Label + Debug> {
32    legend: bool,
33    complete: bool,
34    tree: &'a DecisionTree<F, L>,
35}
36
37impl<'a, F: Float, L: Debug + Label> Tikz<'a, F, L> {
38    /// Creates a new Tikz structure for the decision tree
39    /// with the following default parameters:
40    ///
41    /// * `legend=false`
42    /// * `complete=true`
43    pub fn new(tree: &'a DecisionTree<F, L>) -> Self {
44        Tikz {
45            legend: false,
46            complete: true,
47            tree,
48        }
49    }
50
51    fn format_node(node: &'a TreeNode<F, L>) -> String {
52        let depth = vec![""; node.depth() + 1].join("\t");
53        if let Some(prediction) = node.prediction() {
54            format!("{depth}[Label: {prediction:?}]")
55        } else {
56            let (idx, value, impurity_decrease) = node.split();
57            let mut out = format!(
58                "{depth}[Val(${idx}$) $ \\leq {value:.2}$ \\\\ Imp. ${impurity_decrease:.2}$"
59            );
60            for child in node.children().into_iter().filter_map(|x| x.as_ref()) {
61                out.push('\n');
62                out.push_str(&Self::format_node(child));
63            }
64            out.push(']');
65
66            out
67        }
68    }
69
70    /// Whether a complete Tex document should be generated
71    pub fn complete(mut self, complete: bool) -> Self {
72        self.complete = complete;
73
74        self
75    }
76
77    /// Add a legend to the generated tree
78    pub fn with_legend(mut self) -> Self {
79        self.legend = true;
80
81        self
82    }
83
84    fn legend(&self) -> String {
85        if self.legend {
86            let mut map = HashSet::new();
87            let mut out = "\n".to_string()
88                + r#"\node [anchor=north west] at (current bounding box.north east) {%
89                \begin{tabular}{c c c}
90                  \multicolumn{3}{@{}l@{}}{Legend:}\\
91                  Imp.&:&Impurity decrease\\"#;
92            for node in self.tree.iter_nodes() {
93                if !node.is_leaf() && !map.contains(&node.split().0) {
94                    let var = format!(
95                        "Var({})&:&{}\\\\",
96                        node.split().0,
97                        // TODO:: why use lengend if there are no feature names? Should it be allowed?
98                        node.feature_name().unwrap_or(&"".to_string())
99                    );
100                    out.push_str(&var);
101                    map.insert(node.split().0);
102                }
103            }
104            out.push_str("\\end{tabular}};");
105            out
106        } else {
107            "".to_string()
108        }
109    }
110}
111
112use std::fmt;
113
114impl<F: Float, L: Debug + Label> fmt::Display for Tikz<'_, F, L> {
115    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
116        let mut out = if self.complete {
117            String::from(
118                r#"
119\documentclass[margin=10pt]{standalone}
120\usepackage{tikz,forest}
121\usetikzlibrary{arrows.meta}"#,
122            )
123        } else {
124            String::from("")
125        };
126        out.push_str(
127            r#"
128\forestset{
129default preamble={
130before typesetting nodes={
131  !r.replace by={[, coordinate, append]}
132},  
133where n children=0{
134  tier=word,
135}{  
136  %diamond, aspect=2,
137},  
138where level=0{}{
139  if n=1{
140    edge label={node[pos=.2, above] {Y}},
141  }{  
142    edge label={node[pos=.2, above] {N}},
143  }   
144},  
145for tree={
146  edge+={thick, -Latex},
147  s sep'+=2cm,
148  draw,
149  thick,
150  edge path'={ (!u) -| (.parent)},
151  align=center,
152}   
153}
154}"#,
155        );
156
157        if self.complete {
158            out.push_str(r#"\begin{document}"#);
159        }
160        out.push_str(r#"\begin{forest}"#);
161
162        out.push_str(&Self::format_node(self.tree.root_node()));
163        out.push_str(&self.legend());
164        out.push_str("\n\t\\end{forest}\n");
165        if self.complete {
166            out.push_str("\\end{document}");
167        }
168
169        write!(f, "{out}")
170    }
171}