from dataclasses import dataclass, field
from typing import Optional, List, Any, Dict
import numpy as np
@dataclass(eq=False)
class HierarchyNode:
    """A node of a top-down hierarchy that holds direct references to its relatives.

    Equality and hashing are by identity (``eq=False``). A generated
    field-wise ``__eq__`` would compare ``centroid`` arrays, whose truth value
    is ambiguous and raises ``ValueError``. It would also recurse through the
    ``parent``/``children`` back-references. Identity semantics also make
    nodes usable in sets and as dict keys.

    Attributes:
        node_id: Integer identifier of the node.
        level: Level of the node within the hierarchy.
        parent: Parent node, or ``None`` for a root.
        children: Direct child nodes; an empty list means the node is a leaf.
        centroid: Optional numpy vector attached to the node (shape not fixed here).
        expert: Opaque per-node payload; its type is not constrained here.
    """

    node_id: int
    level: int
    # Hidden from repr: printing the parent would expand every sibling subtree.
    parent: Optional['HierarchyNode'] = field(default=None, repr=False)
    children: List['HierarchyNode'] = field(default_factory=list)
    centroid: Optional[np.ndarray] = None
    expert: Optional[Any] = None

    def is_leaf(self) -> bool:
        """Return True when the node has no children."""
        return not self.children

    def is_root(self) -> bool:
        """Return True when the node has no parent."""
        return self.parent is None
class TreeNode:
    """A hierarchy node that refers to its relatives by integer id.

    Instances are the elements stored in ``BottomUpHierarchyTree.nodes``.
    Only the identity fields are set by the constructor; the relationship
    links and payloads start empty and are assigned afterwards.
    """

    def __init__(
        self,
        node_id: int,
        level: int,
        centroid: np.ndarray
    ):
        self.node_id, self.level, self.centroid = node_id, level, centroid
        # Links to other nodes, expressed as ids rather than object references.
        self.parent_id: Optional[int] = None
        self.child_ids: List[int] = list()
        # Optional payloads; both stay None until something assigns them.
        self.data_indices: Optional[np.ndarray] = None
        self.adapter: Optional[Any] = None
class HierarchyTree:
    """A hierarchy of HierarchyNode objects grouped into per-level lists.

    ``levels`` maps a level number to the nodes recorded at that level.
    Nothing in this class populates it or ``root``; that is done by code
    not shown here.
    """

    def __init__(self, num_levels: int, branch_factor: int):
        self.num_levels = num_levels
        self.branch_factor = branch_factor
        self.root: Optional[HierarchyNode] = None
        self.levels: Dict[int, List[HierarchyNode]] = dict()

    def get_leaf_nodes(self) -> List[HierarchyNode]:
        """Return every recorded node without children.

        Levels are scanned in the dict's insertion order, and nodes in their
        stored order within each level.
        """
        leaves: List[HierarchyNode] = []
        for bucket in self.levels.values():
            leaves.extend(node for node in bucket if node.is_leaf())
        return leaves

    def get_nodes_at_level(self, level: int) -> List[HierarchyNode]:
        """Return the nodes recorded for ``level``, or an empty list if there are none.

        NOTE(review): when ``level`` is present, the result is the stored list
        itself, so mutating it mutates the tree. The empty fallback is a fresh
        list that is not stored.
        """
        if level in self.levels:
            return self.levels[level]
        return []
class BottomUpHierarchyTree:
    """A hierarchy stored as a flat list of TreeNode objects plus per-level id lists.

    ``nodes`` is indexed directly by node id (see ``get_node``), so node ids
    are expected to equal positions in that list. ``level_nodes[k]`` holds
    the ids of the nodes at level ``k``. The last level (``num_levels - 1``)
    is treated as the leaf level.
    """

    def __init__(self, num_levels: int, branch_factor: int):
        self.num_levels = num_levels
        self.branch_factor = branch_factor
        # Flat node storage; nothing in this class appends to it.
        self.nodes: List[TreeNode] = []
        # One independent id list per level (a fresh list each, not shared).
        self.level_nodes: List[List[int]] = [list() for _ in range(num_levels)]
        self.root_id: Optional[int] = None

    def get_node(self, node_id: int) -> TreeNode:
        """Return the node stored at position ``node_id`` of ``nodes``."""
        node = self.nodes[node_id]
        return node

    def get_nodes_at_level(self, level: int) -> List[TreeNode]:
        """Return a new list of the nodes whose ids are recorded for ``level``.

        ``level`` indexes ``level_nodes`` like any Python list. Out-of-range
        values raise ``IndexError``, and negative values count from the end.
        """
        ids = self.level_nodes[level]
        return [self.nodes[i] for i in ids]

    def get_leaf_nodes(self) -> List[TreeNode]:
        """Return the nodes recorded for the last level, ``num_levels - 1``."""
        return self.get_nodes_at_level(self.num_levels - 1)
