import matplotlib.pyplot as plt
import numpy as np
import graphviz
# Retina mode
%config InlineBackend.figure_format = 'retina'
%matplotlib inline
Basics of Classes and Plotting Trees
ML
Tutorial
class TreeNode:
    """General (n-ary) tree node that can print or draw itself.

    Attributes:
        name: Text shown for the node in both renderings.
        value: Payload printed next to the name by ``display_tree_text``.
        shape: Graphviz node shape used by ``display_tree_graphviz``.
        children: Child ``TreeNode`` objects, in insertion order.
    """

    def __init__(self, name, value=None, shape='rectangle'):
        self.name, self.value, self.shape = name, value, shape
        self.children = []

    def add_child(self, child_node):
        """Append ``child_node`` as the last child of this node."""
        self.children.append(child_node)

    def display_tree_text(self, level=0):
        """Print this subtree depth-first, one ``|- name: value`` line per node.

        Each line is prefixed by ``level`` spaces; children are printed right
        after their parent at ``level + 1``.
        """
        print(f"{' ' * level}|- {self.name}: {self.value}")
        for kid in self.children:
            kid.display_tree_text(level=level + 1)

    def display_tree_graphviz(self, dot=None, parent_name=None, graph=None):
        """Add this subtree to a Graphviz digraph and return that digraph.

        Args:
            dot: Unused; only forwarded to the recursive calls.
            parent_name: The parent ``TreeNode`` object itself (its ``id()``
                names the edge tail), or ``None`` for the root.
            graph: Digraph to draw into; a fresh PNG-format
                ``graphviz.Digraph`` is created when ``None``.

        Returns:
            The digraph holding this node and all of its descendants.
        """
        digraph = graphviz.Digraph(format='png') if graph is None else graph
        # id() keeps node names unique even when two nodes share a label.
        me = str(id(self))
        digraph.node(me, str(self.name), shape=self.shape)
        if parent_name is not None:
            digraph.edge(str(id(parent_name)), me)
        for kid in self.children:
            kid.display_tree_graphviz(dot, self, digraph)
        return digraph

    def display_tree_directly(self):
        """Render the tree as a PNG in the notebook output.

        ``display`` is not imported here; it is expected to be IPython's
        built-in, so this only works inside an IPython/Jupyter session.
        """
        dot_source = self.display_tree_graphviz().source
        display(graphviz.Source(dot_source, format='png'))
# Creating nodes: a root with two children, the second of which has a child.
root = TreeNode("Root")
child1, child2, child3 = (
    TreeNode(label) for label in ("Child 1", "Child 2", "Child 3")
)

# Building the tree structure
root.add_child(child1)
root.add_child(child2)
child2.add_child(child3)

# Displaying the tree in text format (one space of indent per depth level)
root.display_tree_text()
# Output:
# |- Root: None
#  |- Child 1: None
#  |- Child 2: None
#   |- Child 3: None

# Graphviz rendering; as the last expression of a notebook cell, the bare
# name below displays the Digraph inline.
graph = root.display_tree_graphviz()
graph
class DecisionTreeNode:
    """Binary decision-tree node that can draw itself with Graphviz.

    A split node carries a ``feature`` and a ``threshold``. A leaf is
    typically built with an empty feature, ``threshold=None`` and a
    ``decision``. When drawn, the edge to ``left`` is labelled "True" and the
    edge to ``right`` is labelled "False".
    """

    def __init__(self, feature, threshold, decision=None, left=None, right=None, shape='box'):
        """Create a node.

        Args:
            feature: Feature shown in the node label; ``""`` or ``None`` for a
                node with no split test. Non-string values are shown via
                ``str()``.
            threshold: Split threshold, shown as ``<= threshold``; ``None``
                omits the test.
            decision: Optional outcome, shown on its own label line.
            left: Child drawn on the "True" edge, or ``None``.
            right: Child drawn on the "False" edge, or ``None``.
            shape: Graphviz node shape.
        """
        self.feature = feature
        self.threshold = threshold
        self.decision = decision
        self.left = left
        self.right = right
        self.shape = shape

    def display_tree_graphviz(self, dot=None, parent_name=None, graph=None, edge_label=None):
        """Add this subtree to a Graphviz digraph and return that digraph.

        Args:
            dot: Unused; only forwarded to the recursive calls.
            parent_name: The parent ``DecisionTreeNode`` object itself (its
                ``id()`` names the edge tail), or ``None`` for the root.
            graph: Digraph to draw into; a fresh PNG-format
                ``graphviz.Digraph`` is created when ``None``.
            edge_label: Label for the edge from the parent; ``None`` draws an
                unlabelled edge.

        Returns:
            The digraph holding this node and all of its descendants.
        """
        if graph is None:
            graph = graphviz.Digraph(format='png')

        # Assemble the label line by line. Dropping an empty header line
        # stops leaves (feature "", no threshold) from getting a blank first
        # line ("\nDecision: ..."). str() lets non-string features be shown
        # instead of crashing on str concatenation.
        feature = "" if self.feature is None else str(self.feature)
        if self.threshold is None:
            header = feature
        else:
            header = f"{feature} <= {self.threshold}"
        label_lines = [header] if header else []
        if self.decision is not None:
            label_lines.append(f"Decision: {self.decision}")

        # id() keeps node names unique even when two nodes share a label.
        graph.node(str(id(self)), "\n".join(label_lines), shape=self.shape)
        if parent_name is not None:
            # graphviz emits no label attribute when label is None, so a
            # single call covers both the labelled and unlabelled case.
            graph.edge(str(id(parent_name)), str(id(self)), label=edge_label)

        if self.left is not None:
            self.left.display_tree_graphviz(dot, self, graph, edge_label="True")
        if self.right is not None:
            self.right.display_tree_graphviz(dot, self, graph, edge_label="False")
        return graph
# Leaves: no feature or threshold, only a decision value.
left_left = DecisionTreeNode("", None, decision=20.0)
left_right = DecisionTreeNode("", None, decision=10.0)
right_left = DecisionTreeNode("", None, decision=30.0)
right_right = DecisionTreeNode("", None, decision=40.0)

# Split nodes, wired to their leaves at construction time.
left_child = DecisionTreeNode(
    "Feature B", 3.0, decision=None, left=left_left, right=left_right
)
right_child = DecisionTreeNode(
    "Feature C", 8.0, decision=None, left=right_left, right=right_right
)

# Root split over the two subtrees.
root = DecisionTreeNode(
    "Feature A", 5.0, decision=None, left=left_child, right=right_child
)

# Last expression of the cell, so the notebook displays the returned Digraph.
root.display_tree_graphviz()