linfa_trees/decision_trees/tikz.rs
1use super::{DecisionTree, TreeNode};
2use linfa::{Float, Label};
3use std::collections::HashSet;
4use std::fmt::Debug;
5
6#[derive(Debug, Clone, PartialEq)]
31pub struct Tikz<'a, F: Float, L: Label + Debug> {
32 legend: bool,
33 complete: bool,
34 tree: &'a DecisionTree<F, L>,
35}
36
37impl<'a, F: Float, L: Debug + Label> Tikz<'a, F, L> {
38 pub fn new(tree: &'a DecisionTree<F, L>) -> Self {
44 Tikz {
45 legend: false,
46 complete: true,
47 tree,
48 }
49 }
50
51 fn format_node(node: &'a TreeNode<F, L>) -> String {
52 let depth = vec![""; node.depth() + 1].join("\t");
53 if let Some(prediction) = node.prediction() {
54 format!("{depth}[Label: {prediction:?}]")
55 } else {
56 let (idx, value, impurity_decrease) = node.split();
57 let mut out = format!(
58 "{depth}[Val(${idx}$) $ \\leq {value:.2}$ \\\\ Imp. ${impurity_decrease:.2}$"
59 );
60 for child in node.children().into_iter().filter_map(|x| x.as_ref()) {
61 out.push('\n');
62 out.push_str(&Self::format_node(child));
63 }
64 out.push(']');
65
66 out
67 }
68 }
69
70 pub fn complete(mut self, complete: bool) -> Self {
72 self.complete = complete;
73
74 self
75 }
76
77 pub fn with_legend(mut self) -> Self {
79 self.legend = true;
80
81 self
82 }
83
84 fn legend(&self) -> String {
85 if self.legend {
86 let mut map = HashSet::new();
87 let mut out = "\n".to_string()
88 + r#"\node [anchor=north west] at (current bounding box.north east) {%
89 \begin{tabular}{c c c}
90 \multicolumn{3}{@{}l@{}}{Legend:}\\
91 Imp.&:&Impurity decrease\\"#;
92 for node in self.tree.iter_nodes() {
93 if !node.is_leaf() && !map.contains(&node.split().0) {
94 let var = format!(
95 "Var({})&:&{}\\\\",
96 node.split().0,
97 node.feature_name().unwrap_or(&"".to_string())
99 );
100 out.push_str(&var);
101 map.insert(node.split().0);
102 }
103 }
104 out.push_str("\\end{tabular}};");
105 out
106 } else {
107 "".to_string()
108 }
109 }
110}
111
112use std::fmt;
113
114impl<F: Float, L: Debug + Label> fmt::Display for Tikz<'_, F, L> {
115 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
116 let mut out = if self.complete {
117 String::from(
118 r#"
119\documentclass[margin=10pt]{standalone}
120\usepackage{tikz,forest}
121\usetikzlibrary{arrows.meta}"#,
122 )
123 } else {
124 String::from("")
125 };
126 out.push_str(
127 r#"
128\forestset{
129default preamble={
130before typesetting nodes={
131 !r.replace by={[, coordinate, append]}
132},
133where n children=0{
134 tier=word,
135}{
136 %diamond, aspect=2,
137},
138where level=0{}{
139 if n=1{
140 edge label={node[pos=.2, above] {Y}},
141 }{
142 edge label={node[pos=.2, above] {N}},
143 }
144},
145for tree={
146 edge+={thick, -Latex},
147 s sep'+=2cm,
148 draw,
149 thick,
150 edge path'={ (!u) -| (.parent)},
151 align=center,
152}
153}
154}"#,
155 );
156
157 if self.complete {
158 out.push_str(r#"\begin{document}"#);
159 }
160 out.push_str(r#"\begin{forest}"#);
161
162 out.push_str(&Self::format_node(self.tree.root_node()));
163 out.push_str(&self.legend());
164 out.push_str("\n\t\\end{forest}\n");
165 if self.complete {
166 out.push_str("\\end{document}");
167 }
168
169 write!(f, "{out}")
170 }
171}