from random import Random
from typing import Dict, Optional, List, Tuple
from collections import deque
import numpy as np

class TreeNode:
    """One node of a rooted, arbitrary-arity tree with facility/client data.

    ``__slots__`` keeps per-node memory small; every attribute listed there is
    initialised in ``__init__`` to a "not yet linked / not a facility" state
    that a tree builder is expected to overwrite.
    """

    __slots__ = (
        "node_number",
        "parent_id",
        "parent",
        "parent_distance",
        "children",
        "children_distances",
        "capacity",
        "facility_type",
        "facility_type_vector",
        "is_client",
        "is_leaf",
    )

    def __init__(self, node_number: int):
        # Identity of this node inside its tree.
        self.node_number = node_number

        # Role flags; the builder sets these once the tree shape is final.
        self.is_leaf: bool = False
        self.is_client: bool = False

        # Facility payload. Only facility leaves receive a positive capacity
        # and a 1-based type label; all other nodes keep these sentinels.
        self.capacity: int = 0
        self.facility_type: int = -1
        self.facility_type_vector: Tuple[int, ...] = ()

        # Downward links: child nodes plus, index-aligned, their edge lengths.
        self.children: List["TreeNode"] = []
        self.children_distances: List[int] = []

        # Upward link; stays None / 0 for the root or until the builder links it.
        self.parent: Optional["TreeNode"] = None
        self.parent_id: Optional[int] = None
        self.parent_distance: int = 0


def get_depth(node: Optional["TreeNode"]) -> int:
    """Return the height of the tree rooted at *node*, counted in nodes.

    A single node has height 1 and an empty tree (``None``) has height 0.

    The tree is walked level by level with an explicit list instead of
    recursion, so very deep trees (e.g. long paths) cannot exceed Python's
    recursion limit.
    """
    if node is None:
        return 0

    height = 0
    level = [node]
    while level:
        height += 1
        # Replace the current level by all of its children (next depth).
        level = [child for parent in level for child in parent.children]
    return height


def get_nodes_at_level(node: Optional[TreeNode], level: int) -> List[int]:
    """Collect the node_numbers found exactly *level* edges below *node*.

    The root is level 0 and nodes are reported left to right. An empty tree,
    a negative level, or a level deeper than the tree yields ``[]``.

    Debugging / visualisation aid only; the DP implementation does not use
    it, so its cost does not matter for the main algorithm.
    """
    if node is None or level < 0:
        return []

    frontier = [node]
    for _ in range(level):
        # Descend one level, keeping left-to-right order.
        frontier = [child for current in frontier for child in current.children]
        if not frontier:
            break
    return [member.node_number for member in frontier]


def get_node_by_number(root: Optional[TreeNode], target_number: int) -> Optional[TreeNode]:
    """Locate the node whose ``node_number`` equals *target_number*.

    Runs a pre-order depth-first search (children visited left to right)
    driven by an explicit stack rather than recursion, so deep trees cannot
    hit Python's recursion limit. Returns the first match in that order, or
    ``None`` when *root* is ``None`` or nothing matches. Worst case O(n),
    which is acceptable for occasional debugging lookups.
    """
    pending = [] if root is None else [root]
    while pending:
        candidate = pending.pop()
        if candidate.node_number == target_number:
            return candidate
        # Pushed reversed so the leftmost child is popped (visited) next.
        pending.extend(reversed(candidate.children))
    return None

def tree_pairwise_distances(root: "TreeNode") -> Tuple[List[int], np.ndarray]:
    """
    Compute pairwise tree distances between all nodes reachable from *root*.

    Parameters
    ----------
    root : TreeNode
        Root of the tree (or of a subtree). Every reachable node must have a
        non-negative node_number.

    Returns
    -------
    node_ids : List[int]
        node_number of every node reachable from *root*, in the depth-first
        order in which they were discovered.
    dist_matrix : numpy.ndarray
        Square float array of shape ``(max(node_ids) + 1, max(node_ids) + 1)``
        indexed by node_number: ``dist_matrix[a, b]`` is the sum of edge
        lengths on the unique tree path between nodes ``a`` and ``b``.
        Rows/columns for numbers that do not occur in *node_ids* stay 0.
        For a tree numbered 0..n-1 this is an n x n matrix.

    If *root* is None, ``([], [])`` is returned.

    Raises
    ------
    ValueError
        If a node_number is negative (NumPy would silently wrap the index).
    """
    if root is None:
        return [], []

    # ------------------------------------------------------------------
    # 1. Collect all nodes reachable from the root (iterative DFS).
    # ------------------------------------------------------------------
    nodes = []
    stack = [root]
    seen = set()
    while stack:
        node = stack.pop()
        if id(node) in seen:  # visit each node object at most once
            continue
        seen.add(id(node))
        nodes.append(node)
        stack.extend(node.children)

    node_ids = [node.node_number for node in nodes]
    if min(node_ids) < 0:
        raise ValueError(
            "node_number values must be non-negative to index the distance matrix"
        )

    # ------------------------------------------------------------------
    # 2. Undirected weighted adjacency list keyed by node_number.
    # ------------------------------------------------------------------
    adj: Dict[int, List[Tuple[int, float]]] = {nid: [] for nid in node_ids}
    for node in nodes:
        u = node.node_number
        for child, w in zip(node.children, node.children_distances):
            v = child.node_number
            weight = float(w)
            adj[u].append((v, weight))
            adj[v].append((u, weight))  # every child was collected in step 1

    # ------------------------------------------------------------------
    # 3. BFS from every node. In a tree each node is first reached along
    #    the unique path, so accumulating edge weights gives exact
    #    distances even though edges are weighted.
    # ------------------------------------------------------------------
    size = max(node_ids) + 1
    dist_matrix = np.zeros((size, size), dtype=float)
    for source in node_ids:
        dist: Dict[int, float] = {source: 0.0}
        queue = deque([source])
        while queue:
            x = queue.popleft()
            for y, weight in adj[x]:
                if y not in dist:
                    dist[y] = dist[x] + weight
                    queue.append(y)
        # Write the whole row at once; keys/values views share the same order.
        dist_matrix[source, list(dist)] = list(dist.values())

    return node_ids, dist_matrix

###############################################################################
def build_tree_with_facilities(
    n: int,
    t: int,
    facility_probability: float = 0.5,
    max_facility_types: int = 1,  # kept for API compatibility; not used
    seed: int = 12345,
) -> Tuple[Optional["TreeNode"], Dict[int, "TreeNode"]]:
    """Build a random rooted tree with *n* nodes and randomly placed facilities.

    Tree shape: node 0 is the root; every node i (1 <= i < n) is attached to a
    parent drawn uniformly from {0, ..., i-1}, with an integer edge length
    drawn uniformly from {0, ..., 10}. This is a random recursive tree of
    arbitrary arity (not uniformly distributed over all labelled trees).

    Each leaf independently becomes either:
      - a client (probability 1 - facility_probability), or
      - a facility (probability facility_probability).
    Internal nodes stay purely structural (not clients, capacity 0).

    For facilities:
      - capacity is chosen uniformly from {1, ..., 10};
      - a random nonempty subset of the t groups is chosen, so each facility
        may belong to one or more groups.

    The group membership is encoded as a 0/1 vector of length t in
    `facility_type_vector`. Multiple 1's mean the facility is in several
    (possibly intersecting) groups. `facility_type` stores one of those
    groups, 1-based. Clients get the all-zero vector and facility_type -1.

    All random draws happen in a fixed order from an RNG seeded with *seed*,
    so identical arguments reproduce the identical tree.

    Parameters
    ----------
    n : int
        Total number of nodes in the tree.
    t : int
        Number of groups. Must be >= 1 whenever a facility is generated.
    facility_probability : float
        Probability that a leaf becomes a facility (otherwise it is a client).
    max_facility_types : int
        Unused in this version (multi-group facilities); kept for API compat.
    seed : int
        Seed for the RNG so experiments are reproducible.

    Returns
    -------
    root : TreeNode or None
        Root of the generated tree (node 0), or None if n <= 0.
    nodes_dict : Dict[int, TreeNode]
        Mapping node_number -> node for every node (empty if n <= 0).

    Raises
    ------
    ValueError
        If a facility has to be generated while t < 1.
    """
    if n <= 0:
        # Same (root, mapping) shape as the normal path so callers can unpack.
        return None, {}

    rng = Random(seed)

    # ------------------------------------------------------------------
    # 1. Create all nodes; node_number equals the index in `nodes`.
    # ------------------------------------------------------------------
    nodes: List["TreeNode"] = [TreeNode(i) for i in range(n)]

    # ------------------------------------------------------------------
    # 2. Random recursive tree: node i (1..n-1) gets a uniformly random
    #    parent among the earlier nodes {0, ..., i-1}. Edge lengths are
    #    drawn from {0, ..., 10}.
    # ------------------------------------------------------------------
    root = nodes[0]
    root.parent = None
    root.parent_id = None
    root.parent_distance = 0

    for i in range(1, n):
        node = nodes[i]
        parent = nodes[rng.randint(0, i - 1)]  # any of the nodes 0..i-1
        dist = rng.randint(0, 10)

        node.parent = parent
        node.parent_id = parent.node_number
        node.parent_distance = dist

        parent.children.append(node)
        parent.children_distances.append(dist)

    # Leaves are only known once every edge exists.
    for node in nodes:
        node.is_leaf = not node.children

    # ------------------------------------------------------------------
    # 3. For each leaf, randomly decide if it is a facility or a client
    #    and assign capacities / group labels. The order of RNG draws here
    #    determines which instance a given seed reproduces.
    # ------------------------------------------------------------------
    for node in nodes:
        if not node.is_leaf:
            continue

        if rng.random() < facility_probability:
            # Facility leaf
            node.capacity = rng.randint(1, 10)

            if t < 1:
                raise ValueError(
                    f"t must be >= 1 to assign facility groups, got t={t}"
                )
            # Random nonempty subset of {0, ..., t-1} as group indices; this
            # allows multi-group membership (intersecting groups).
            num_groups = rng.randint(1, t)
            group_indices = rng.sample(range(t), num_groups)

            # One representative group index, stored 1-based.
            node.facility_type = group_indices[0] + 1

            membership = [0] * t
            for gi in group_indices:
                membership[gi] = 1
            node.facility_type_vector = tuple(membership)

            node.is_client = False
        else:
            # Client leaf: no capacity and no group membership.
            node.is_client = True
            node.capacity = 0
            node.facility_type = -1
            node.facility_type_vector = (0,) * t

    nodes_dict: Dict[int, "TreeNode"] = {node.node_number: node for node in nodes}
    return root, nodes_dict

###############################################################################
# Helper routines for pretty-printing / debugging
###############################################################################
def print_tree(
    node: Optional["TreeNode"],
    level: int = 0,
    prefix: str = "Root: ",
) -> None:
    """Print the tree sideways (root at left, children to the right).

    Children are printed before their parent and in reverse order, so child 0
    is printed just above its parent and the root line comes last. Each line
    is indented by five spaces per level and starts with *prefix*
    (``C{i}--- `` for the i-th child).

    Node tags:
      - clients: ``[C]``;
      - facility leaves (non-client leaves): ``[F groups=<vector> cap=<capacity>]``;
      - internal, purely structural nodes: no tag, only the node number.
    """
    if node is None:
        return

    # Print children in reverse order so that child 0 appears at the bottom.
    for idx in range(len(node.children) - 1, -1, -1):
        print_tree(node.children[idx], level + 1, prefix=f"C{idx}--- ")

    # Describe this node
    desc = f"{node.node_number}"
    if node.is_client:
        desc += " [C]"
    elif not node.children:
        # A non-client leaf is a facility.
        desc += f" [F groups={node.facility_type_vector} cap={node.capacity}]"
    # Internal nodes are structural only and must not be shown as facilities.

    print("     " * level + prefix + desc)


def print_tree_node(node: Optional[TreeNode]) -> None:
    """Dump every stored attribute of *node* to stdout, one field per line.

    Prints the single line ``None`` when *node* is ``None``. Debugging aid.
    """
    if node is not None:
        parent_label = node.parent.node_number if node.parent else 'None'
        kid_numbers = [kid.node_number for kid in node.children]
        report = (
            "---------Tree Node Details---------",
            f"Node Number       : {node.node_number}",
            f"Parent Node       : {parent_label}",
            f"Parent Distance   : {node.parent_distance}",
            f"Children          : {kid_numbers}",
            f"Children Distances: {node.children_distances}",
            f"Capacity          : {node.capacity}",
            f"Facility Type     : {node.facility_type}",
            f"Type Vector       : {node.facility_type_vector}",
            f"Is Client         : {node.is_client}",
            f"Is Leaf           : {node.is_leaf}",
            "-----------------------------------",
        )
        print(*report, sep="\n")
    else:
        print("None")


###############################################################################
def main() -> None:
    """Build a small random instance and print it sideways (demo entry point)."""
    n = 10
    t = 3
    facility_probability = 0.5
    seed = 123456
    # build_tree_with_facilities returns (root, nodes_by_number). Passing
    # the whole tuple to print_tree fails with AttributeError, so unpack it
    # and print only the root.
    root, _nodes_by_number = build_tree_with_facilities(
        n, t, facility_probability=facility_probability, seed=seed
    )
    print_tree(root)


# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()
