from typing import Hashable, Any


class Tree[T: Hashable, Node]:
    
    def __init__(self):

        self._nodes: dict[tuple[T, ...], Node] = {}
        self._children: dict[tuple[T, ...], dict[T, Node]] = {}

    @property
    def root(self):
        return self.get_node(())
    
    def parent_path(self, path: tuple[T, ...], loop_root: bool = False):
        if len(path) == 0:
            if loop_root:
                return path
            else:
                raise ValueError("a root has no parent.")
        else:
            return path[:-1]
    
    def get_parent(self, path: tuple[T, ...], loop_root: bool = False):
        try:
            path = self.parent_path(path, loop_root)
        except ValueError:
            return None
        return self.get_node(path)

    def __contains__(self, path: tuple[T, ...]):
        return path in self._nodes
    
    def get_node_attr[Default](
            self, 
            path: tuple[T, ...],
            name: str,
            default: Default = None
    ) -> Any | Default:
        
        node = self.get_node(path)
        return getattr(node, name, default)
    
    def set_node_attr[Default](
            self, 
            path: tuple[T, ...],
            name: str,
            default: Default = None
    ) -> Any | Default:
        
        node = self.get_node(path)
        return getattr(node, name, default)

    def __setitem__(self, path: tuple[T, ...], node: Node):
        if self.__contains__(path):
            self._nodes[path] = node

        else:
            if path == ():  # add root
                self._nodes[path] = node
                self._children[path] = {}
            else:
                # update parent
                parent_path = self.parent_path(path)
                if not self.__contains__(path):
                    raise KeyError(parent_path)
                self._children[parent_path][path[-1]] = node

                self._nodes[path] = node
                self._children[path] = {}
    
    def __delitem__(self, path: tuple[T, ...]):

        if not self.__contains__(path):
            raise KeyError(path)
        
        to_remove: set[tuple[T, ...]] = set()
        stack = [path]

        while stack:
            path = stack.pop()
            to_remove.add(path)
            for child in self._children[path].keys():
                stack.append(path + (child,))
        
        for path in to_remove:
            del self._nodes[path], self._children[path]

    def clear(self):
        self._nodes.clear()
        self._children.clear()

    def get_node[Default](self, path: tuple[T, ...], default: Default = None) -> Node | Default:
        return self._nodes.get(path, default)
    
    def get_children[Default](self, path: tuple[T, ...], default: Default = None) -> dict[T, Node] | Default:
        return self._children.get(path, default)
