class TreeNode:
    """Inner (split) node of a binary decision tree.

    Holds the split ``dimension`` and ``threshold`` plus links to its
    children and parent. Links start out empty and are filled in by whoever
    builds the tree.
    """

    def __init__(self, dimension, threshold):
        self.dimension, self.threshold = dimension, threshold
        # Child and parent links; all unset until the node is attached.
        self.left = self.right = self.parent = None
        # Number of inner nodes in the subtree rooted at this node
        # (0 until computed).
        self.subtree_size = 0
        # Number of nodes above this one on the path to the root
        # (None until computed).
        self.depth = None

    def is_leaf(self):
        """Return False: an inner node is never a leaf."""
        return False

    def is_inner(self):
        """Return True: this is an inner node."""
        return True

    def to_string(self):
        """Return a one-line human-readable description of this node."""
        template = "Inner node (dimension: {}, threshold: {})"
        return template.format(self.dimension, self.threshold)

    def print(self):
        """Write ``to_string()`` to stdout."""
        text = self.to_string()
        print(text)


class LeafNode:
    """Terminal node of a binary decision tree carrying a class label."""

    def __init__(self, class_label):
        self.class_label = class_label
        # Inner node this leaf hangs from; None while the leaf is detached.
        self.parent = None
        # Number of nodes above this leaf on the path to the root, kept for
        # parity with TreeNode.depth; None until depths are computed.
        self.depth = None

    def is_leaf(self):
        """Return True: this is a leaf node."""
        return True

    def is_inner(self):
        """Return False: a leaf is never an inner node."""
        return False

    def to_string(self):
        """Return a one-line description including the parent's description.

        A detached leaf (``parent is None``) is rendered with ``parent: None``
        instead of raising ``AttributeError``.
        """
        parent_text = self.parent.to_string() if self.parent is not None else None
        return f"Leaf node (parent: {parent_text}, class: {self.class_label})"

    def print(self):
        """Write ``to_string()`` to stdout."""
        print(self.to_string())


class Tree:
    """Binary decision tree built from TreeNode (inner) and LeafNode objects.

    Nodes are created through ``insert_inner`` / ``insert_leaf``, which link
    them to their parent and also record them, in insertion order, in
    ``inner_nodes`` and ``leaves`` respectively.
    """

    def __init__(self):
        self.root = None
        self.inner_nodes = []
        self.leaves = []

    @staticmethod
    def _attach(parent, node, left):
        """Link ``node`` as the left (``left=True``) or right child of ``parent``.

        Raises:
            ValueError: if ``parent`` is a leaf, which cannot have children.
        """
        if isinstance(parent, LeafNode):
            # A child under a leaf would be unreachable from every traversal
            # while still being recorded in inner_nodes / leaves.
            raise ValueError("Parent must be an inner node")
        node.parent = parent
        if left:
            parent.left = node
        else:
            parent.right = node

    def insert_inner(self, dimension, threshold, parent = None, left = True):
        """Create an inner node and attach it to the tree.

        The first inner node inserted becomes the root and must be given no
        parent; every later one must name an inner-node ``parent``.

        Returns:
            The newly created TreeNode.

        Raises:
            ValueError: if a parent is given for the root, omitted for a
                non-root node, or is a leaf.
        """
        node = TreeNode(dimension, threshold)
        if self.root is None:
            if parent is not None:
                raise ValueError("Parent cannot be specified for root nodes")
            self.root = node
        else:
            if parent is None:
                raise ValueError("Parent must be specified for non-root nodes")
            self._attach(parent, node, left)
        self.inner_nodes.append(node)
        return node

    def insert_leaf(self, parent, left = True, class_label = None):
        """Create a leaf carrying ``class_label`` and attach it under ``parent``.

        Returns:
            The newly created LeafNode.

        Raises:
            ValueError: if ``parent`` is None or is itself a leaf.
        """
        if parent is None:
            raise ValueError("Parent must be specified for leaf nodes")
        node = LeafNode(class_label)
        self._attach(parent, node, left)
        self.leaves.append(node)
        return node

    def post_order_traversal(self):
        """Return all nodes in post-order (left subtree, right subtree, node)."""
        result = []
        self._post_order_recursive(self.root, result)
        return result

    def _post_order_recursive(self, node, result):
        """Append the post-order sequence of ``node``'s subtree to ``result``."""
        if node is None:
            return
        if not isinstance(node, LeafNode):
            self._post_order_recursive(node.left, result)
            self._post_order_recursive(node.right, result)
        result.append(node)

    def print_tree(self):
        """Print the tree to stdout as an indented ASCII outline.

        Inner nodes are shown as ``<dimension> <= <threshold>`` and leaves as
        ``class <label>``. Left children use a ``|--`` connector and right
        children a backquote connector.
        """
        def print_node(node, prefix="", is_left=True):
            if node is None:
                return
            connector = '|-- ' if is_left else '`-- '
            if isinstance(node, LeafNode):
                print(f"{prefix}{connector} class {node.class_label}")
                return
            print(f"{prefix}{connector}{node.dimension} <= {node.threshold}")
            # Under a left child the vertical rail continues; under a right
            # child (the last sibling) it ends.
            child_prefix = prefix + ("|   " if is_left else "    ")
            print_node(node.left, child_prefix, True)
            print_node(node.right, child_prefix, False)

        print_node(self.root, "")

    def compute_subtree_sizes(self):
        """Compute for each node the number of inner nodes in its subtree."""
        self._compute_subtree_sizes_recursive(self.root)

    def _compute_subtree_sizes_recursive(self, node):
        """Store and return the inner-node count of ``node``'s subtree.

        Missing children and leaves contribute 0. Each inner node stores
        1 + the counts of its two children in ``subtree_size``.
        """
        if node is None or isinstance(node, LeafNode):
            return 0
        left_size = self._compute_subtree_sizes_recursive(node.left)
        right_size = self._compute_subtree_sizes_recursive(node.right)
        node.subtree_size = left_size + right_size + 1
        return node.subtree_size

    def compute_node_depths(self):
        """Compute for each node v the number of nodes above v on a path to the root.

        Safe to call on an empty tree or on a tree with missing children.
        """
        self._compute_node_depths(self.root, 0)

    def _compute_node_depths(self, node, current_depth):
        """Set ``depth`` on ``node`` and, recursively, on all its descendants."""
        if node is None:
            # Empty tree or an unfilled child slot: nothing to annotate.
            return
        node.depth = current_depth
        if not isinstance(node, LeafNode):
            self._compute_node_depths(node.left, current_depth + 1)
            self._compute_node_depths(node.right, current_depth + 1)
