from random import Random
from typing import Optional, List, Tuple, Dict
from collections import deque
import numpy as np

class BinTreeNode:
    """Node of a binary tree carrying facility / client attributes.

    Only ``node_number``, ``left`` and ``right`` come from the constructor;
    every other attribute starts at a neutral default and is meant to be
    filled in by a tree builder (e.g. ``build_binary_tree_with_facilities``).
    Note that the constructor does NOT set ``parent`` on the given children.
    """

    __slots__ = (
        "node_number",
        "parent_id",
        "parent",
        "parent_distance",
        "left",
        "left_distance",
        "right",
        "right_distance",
        "capacity",
        "facility_type",
        "facility_type_vector",
        "num_clients",
        "is_client",
        "is_leaf",
    )

    def __init__(
        self,
        node_number: int,
        left: Optional["BinTreeNode"] = None,
        right: Optional["BinTreeNode"] = None,
    ):
        # Structural identifiers
        self.node_number = node_number

        # Parent information (filled in by the tree-builder)
        self.parent_id: Optional[int] = None
        self.parent: Optional["BinTreeNode"] = None
        self.parent_distance: int = 0  # distance to the parent

        # Children and edge lengths
        self.left: Optional["BinTreeNode"] = left
        self.left_distance: int = 0
        self.right: Optional["BinTreeNode"] = right
        self.right_distance: int = 0

        # Facility / client attributes
        self.capacity: int = 0           # >0 only for facility leaves
        self.facility_type: int = -1     # in {1,...,t} for facilities (just a label), -1 otherwise
        self.facility_type_vector: Tuple[int, ...] = tuple()
        # Declared in __slots__; initialised here so that reading it never
        # raises AttributeError on a node nobody has populated.
        self.num_clients: int = 0
        self.is_client: bool = False

        # Convenience flag set by the builder (NOT derived from left/right here)
        self.is_leaf: bool = False
# end BinTreeNode


def get_depth(node: Optional["BinTreeNode"]) -> int:
    """Return the height of the tree rooted at *node* (0 for an empty tree).

    Height is counted in nodes: a single node has height 1. The tree is
    walked level by level with an explicit frontier list instead of
    recursion, so very deep (e.g. path-shaped) trees do not hit Python's
    recursion limit.
    """
    if node is None:
        return 0

    height = 0
    frontier = [node]
    while frontier:
        height += 1
        # Replace the current level by all existing children of its nodes.
        frontier = [
            child
            for current in frontier
            for child in (current.left, current.right)
            if child is not None
        ]
    return height
# end get_depth


def get_nodes_at_level(node: Optional[BinTreeNode], level: int) -> List[int]:
    """Collect the ``node_number`` of every node exactly *level* edges below *node*.

    The root sits at level 0 and the numbers are returned left to right.
    An empty tree, a negative *level*, or a level deeper than the tree all
    yield an empty list. Intended for debugging / visualisation only.
    """
    if node is None or level < 0:
        return []

    frontier = [node]
    for _ in range(level):
        # Step one level down, keeping left-before-right order.
        frontier = [
            child
            for current in frontier
            for child in (current.left, current.right)
            if child is not None
        ]
        if not frontier:
            break
    return [member.node_number for member in frontier]
# end get_nodes_at_level


def get_node_by_number(root: Optional[BinTreeNode], target_number: int) -> Optional[BinTreeNode]:
    """Return the first node, in pre-order (left before right), whose
    ``node_number`` equals *target_number*; ``None`` if there is none.

    An explicit stack replaces recursion so deep trees cannot exceed
    Python's recursion limit. Worst case O(n); meant for occasional
    debugging lookups.
    """
    pending = [] if root is None else [root]
    while pending:
        candidate = pending.pop()
        if candidate.node_number == target_number:
            return candidate
        # Push right before left so the left child is popped (examined) first.
        pending.extend(
            child for child in (candidate.right, candidate.left) if child is not None
        )
    return None
# end get_node_by_number

def binary_tree_pairwise_distances_old(root: "BinTreeNode") -> Tuple[List[int], np.ndarray]:
    """
    Legacy all-pairs distance computation for a BinTreeNode tree.

    Unlike ``binary_tree_pairwise_distances``, this version reads edge
    lengths from BOTH sources: a node's ``parent_distance`` is used when
    moving up to its parent, and the parent's ``left_distance`` /
    ``right_distance`` is used when moving down to a child. When those two
    values disagree for the same edge, the result is NOT symmetric
    (``dist_matrix[u, v] != dist_matrix[v, u]``).

    Parameters
    ----------
    root : BinTreeNode
        Root of the binary tree.

    Returns
    -------
    node_ids : List[int]
        node_number values of all reachable nodes, in traversal order.
    dist_matrix : np.ndarray
        ``n x n`` float matrix indexed by node_number:
        ``dist_matrix[u, v]`` is the distance travelling from node ``u`` to
        node ``v``. node_numbers are assumed to be ``0 .. n-1`` (as produced
        by ``build_binary_tree_with_facilities``).
        For ``root is None`` the function returns ``([], [])``.
    """
    if root is None:
        return [], []

    # ----------------------------------------------------------------------
    # 1. Collect all nodes reachable from the root (iterative DFS)
    # ----------------------------------------------------------------------
    nodes: List["BinTreeNode"] = []
    stack = [root]
    seen = set()

    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        nodes.append(node)
        if node.left is not None:
            stack.append(node.left)
        if node.right is not None:
            stack.append(node.right)

    # Map node_number -> row/column index in the intermediate matrix
    node_ids = [node.node_number for node in nodes]
    id_to_idx: Dict[int, int] = {nid: i for i, nid in enumerate(node_ids)}
    n = len(nodes)

    # ----------------------------------------------------------------------
    # 2. Directed adjacency: u -> parent weighs u.parent_distance,
    #    u -> child weighs u.left_distance / u.right_distance.
    #    Every child was collected in step 1, so it already has an entry.
    # ----------------------------------------------------------------------
    adj: Dict[int, List[Tuple[int, float]]] = {nid: [] for nid in node_ids}

    for node in nodes:
        u = node.node_number
        if node.parent is not None:
            adj[u].append((node.parent.node_number, float(node.parent_distance)))
        if node.left is not None:
            adj[u].append((node.left.node_number, float(node.left_distance)))
        if node.right is not None:
            adj[u].append((node.right.node_number, float(node.right_distance)))

    # ----------------------------------------------------------------------
    # 3. Traverse from every node; on a tree each node has one path from u,
    #    so the first time it is reached its accumulated distance is final.
    # ----------------------------------------------------------------------
    dist_matrix: List[List[float]] = [[0.0] * n for _ in range(n)]

    for i, u in enumerate(node_ids):
        queue = deque([u])
        dist: Dict[int, float] = {u: 0.0}

        while queue:
            x = queue.popleft()
            for y, w in adj.get(x, []):
                if y not in dist:
                    dist[y] = dist[x] + w
                    queue.append(y)

        for v, d in dist.items():
            dist_matrix[i][id_to_idx[v]] = d

    # ----------------------------------------------------------------------
    # 4. Re-index rows/cols by node_number. np.zeros (not np.ndarray, which
    #    hands back uninitialised memory) so any cell that is not written
    #    holds a well-defined 0.0 instead of garbage.
    # ----------------------------------------------------------------------
    node_ids_np = np.asarray(node_ids, dtype=int)
    dist_matrix_np = np.zeros((n, n), dtype=float)
    dist_matrix_np[np.ix_(node_ids_np, node_ids_np)] = np.asarray(dist_matrix, dtype=float)

    return node_ids, dist_matrix_np

def binary_tree_pairwise_distances(root: "BinTreeNode") -> Tuple[List[int], List[List[float]]]:
    """
    Compute pairwise distances between all nodes in a BinTreeNode tree.

    The distance between two nodes is the sum of edge lengths on the unique
    path between them. For consistency with the dynamic program and the
    brute-force solver, we treat each parent–child edge as having a single
    length equal to the *child* node's ``parent_distance`` and use this as
    an undirected edge weight.

    Edges are derived from the same ``left`` / ``right`` pointers used to
    discover the nodes. The result therefore does not depend on ``parent``
    pointers being set (the ``BinTreeNode`` constructor does not set them).
    Passing a subtree root considers only that subtree, ignoring the edge
    above it.

    Parameters
    ----------
    root : BinTreeNode
        Root of the binary tree (or of a subtree).

    Returns
    -------
    node_ids : List[int]
        List of node_number values, in the order used for the matrix.
    dist_matrix : List[List[float]]
        dist_matrix[i][j] is the distance between node_ids[i] and node_ids[j].
        For ``root is None`` the function returns ``([], [])``.
    """
    if root is None:
        return [], []

    # 1. Collect all nodes reachable through child pointers (iterative DFS).
    nodes: List["BinTreeNode"] = []
    stack = [root]
    while stack:
        node = stack.pop()
        nodes.append(node)
        if node.left is not None:
            stack.append(node.left)
        if node.right is not None:
            stack.append(node.right)

    node_ids = [node.node_number for node in nodes]
    id_to_idx: Dict[int, int] = {nid: i for i, nid in enumerate(node_ids)}
    n = len(nodes)

    # 2. Undirected adjacency; edge (parent, child) weighs child.parent_distance.
    adj: Dict[int, List[Tuple[int, float]]] = {nid: [] for nid in node_ids}
    for node in nodes:
        u = node.node_number
        for child in (node.left, node.right):
            if child is None:
                continue
            v = child.node_number
            w = float(child.parent_distance)
            adj[u].append((v, w))
            adj[v].append((u, w))

    # 3. Traverse from every node. On a tree each node has exactly one path
    #    from u, so its distance is final the first time it is reached.
    dist_matrix: List[List[float]] = [[0.0] * n for _ in range(n)]
    for i, u in enumerate(node_ids):
        queue = deque([u])
        dist: Dict[int, float] = {u: 0.0}
        while queue:
            x = queue.popleft()
            for y, w in adj[x]:
                if y not in dist:
                    dist[y] = dist[x] + w
                    queue.append(y)
        row = dist_matrix[i]
        for v, d in dist.items():
            row[id_to_idx[v]] = d

    return node_ids, dist_matrix

def binary_tree_pairwise_distances_node_num(root: "BinTreeNode") -> Tuple[List[int], np.ndarray]:
    """
    Compute pairwise distances between all nodes in a BinTreeNode tree.

    The distance between two nodes is the sum of edge lengths on the unique
    path between them. For consistency with the dynamic program and the
    brute-force solver, we treat each parent–child edge as having a single
    length equal to the *child* node's ``parent_distance`` and use this as
    an undirected edge weight.

    Edges are derived from the same ``left`` / ``right`` pointers used to
    discover the nodes. The result therefore does not depend on ``parent``
    pointers being set, and a subtree root considers only its own subtree.

    Parameters
    ----------
    root : BinTreeNode
        Root of the binary tree (or of a subtree).

    Returns
    -------
    node_ids : List[int]
        List of node_number values.
    dist_matrix : np.ndarray
        A square matrix such that ``dist_matrix[u, v]`` is the distance
        between nodes with ``node_number`` u and v.  This mirrors the
        convention used in ``tree_pairwise_distances`` for general trees.
        The side length is ``max(node_ids) + 1``. This is exactly ``n`` when
        the numbers are ``0 .. n-1``. Rows and columns of numbers absent
        from the tree are 0.0. node_numbers are assumed to be non-negative.
        For ``root is None`` the function returns ``([], [])``.
    """
    if root is None:
        return [], []

    # 1. Collect all nodes reachable through child pointers (iterative DFS).
    nodes: List["BinTreeNode"] = []
    stack = [root]
    while stack:
        node = stack.pop()
        nodes.append(node)
        if node.left is not None:
            stack.append(node.left)
        if node.right is not None:
            stack.append(node.right)

    node_ids = [node.node_number for node in nodes]
    id_to_idx: Dict[int, int] = {nid: i for i, nid in enumerate(node_ids)}
    n = len(nodes)

    # 2. Undirected adjacency; edge (parent, child) weighs child.parent_distance.
    adj: Dict[int, List[Tuple[int, float]]] = {nid: [] for nid in node_ids}
    for node in nodes:
        u = node.node_number
        for child in (node.left, node.right):
            if child is None:
                continue
            v = child.node_number
            w = float(child.parent_distance)
            adj[u].append((v, w))
            adj[v].append((u, w))

    # 3. Traverse from every node. On a tree each node has exactly one path
    #    from u, so its distance is final the first time it is reached.
    dist_matrix_idx: List[List[float]] = [[0.0] * n for _ in range(n)]
    for i, u in enumerate(node_ids):
        queue = deque([u])
        dist: Dict[int, float] = {u: 0.0}
        while queue:
            x = queue.popleft()
            for y, w in adj[x]:
                if y not in dist:
                    dist[y] = dist[x] + w
                    queue.append(y)
        row = dist_matrix_idx[i]
        for v, d in dist.items():
            row[id_to_idx[v]] = d

    # 4. Re-index so rows/cols are node_numbers, like ``tree_pairwise_distances``.
    #    Sizing by the largest number lets a subtree (ids not starting at 0)
    #    fit instead of raising IndexError.
    size = max(node_ids) + 1
    dist_matrix_np = np.zeros((size, size), dtype=float)
    node_ids_np = np.asarray(node_ids, dtype=int)
    dist_matrix_np[np.ix_(node_ids_np, node_ids_np)] = np.asarray(dist_matrix_idx, dtype=float)

    return node_ids, dist_matrix_np

def get_client_and_facility_node_ids(root) -> Tuple[List[int], List[int]]:
    """
    Split the node_numbers of a binary tree into clients and facilities.

    A node counts as a client when ``is_client`` is truthy. Otherwise it
    counts as a facility when ``capacity > 0``. Nodes that are neither are
    skipped. Missing attributes fall back to ``is_client=False``,
    ``capacity=0`` and no children, so any node-like object is accepted.

    Parameters
    ----------
    root : TreeNode / BinTreeNode
        Root of the binary tree, or None for an empty tree.

    Returns
    -------
    client_ids : List[int]
        node_numbers of client nodes, in traversal order.
    facility_ids : List[int]
        node_numbers of facility nodes, in traversal order.
    """
    client_ids: List[int] = []
    facility_ids: List[int] = []
    if root is None:
        return client_ids, facility_ids

    visited = set()
    pending = [root]
    while pending:
        current = pending.pop()
        if current in visited:
            # A node reachable along two paths is classified only once.
            continue
        visited.add(current)

        if getattr(current, "is_client", False):
            client_ids.append(current.node_number)
        elif getattr(current, "capacity", 0) > 0:
            facility_ids.append(current.node_number)

        children = (getattr(current, "left", None), getattr(current, "right", None))
        pending.extend(child for child in children if child is not None)

    return client_ids, facility_ids

def get_facility_groups(root) -> Dict[int, Tuple[int, ...]]:
    """
    Map every facility's node_number to its group membership vector.

    A node is a facility when ``is_client`` is falsy and ``capacity > 0``.
    Missing attributes default to ``is_client=False`` and ``capacity=0``.
    Its vector is read from ``facility_type_vector``, defaulting to ``()``,
    and is always stored as a tuple.

    Parameters
    ----------
    root : BinTreeNode / TreeNode
        Root of the binary tree, or None. Nodes are expected to have:
        - node.node_number : int
        - node.is_client   : bool
        - node.capacity    : int
        - node.facility_type_vector : tuple or list of ints
        - node.left, node.right : child pointers

    Returns
    -------
    facility_groups : Dict[int, Tuple[int, ...]]
        facility node_number -> tuple(facility_type_vector), in traversal order.
    """
    groups: Dict[int, Tuple[int, ...]] = {}
    if root is None:
        return groups

    pending: List = [root]
    while pending:
        current = pending.pop()

        is_facility = (
            not getattr(current, "is_client", False)
            and getattr(current, "capacity", 0) > 0
        )
        if is_facility:
            groups[current.node_number] = tuple(getattr(current, "facility_type_vector", ()))

        for child in (getattr(current, "left", None), getattr(current, "right", None)):
            if child is not None:
                pending.append(child)

    return groups

def get_facility_capacities(root) -> Dict[int, int]:
    """
    Traverse the binary tree and build a dictionary mapping
    facility node_number -> capacity.

    A node is treated as a facility if:
      - node.is_client is falsy (a missing attribute counts as False), and
      - node.capacity > 0 (a missing attribute counts as 0)

    This is the same facility test used by ``get_facility_groups``.

    Parameters
    ----------
    root : BinTreeNode / TreeNode
        Root of the binary tree, or None. Nodes are expected to have:
        - node.node_number : int
        - node.is_client   : bool
        - node.capacity    : int
        - node.left, node.right : child pointers

    Returns
    -------
    facility_caps : Dict[int, int]
        Dictionary where:
          key   = facility node_number
          value = its capacity, as stored on the node
    """
    if root is None:
        return {}

    facility_caps: Dict[int, int] = {}

    stack: List = [root]
    while stack:
        node = stack.pop()

        # Check if this node is a facility (not a client, positive capacity)
        is_client = getattr(node, "is_client", False)
        capacity = getattr(node, "capacity", 0)

        if (not is_client) and capacity > 0:
            facility_caps[node.node_number] = capacity

        # Traverse children
        if getattr(node, "left", None) is not None:
            stack.append(node.left)
        if getattr(node, "right", None) is not None:
            stack.append(node.right)

    return facility_caps

def get_binary_tree_stats(root: "BinTreeNode") -> Dict[str, int]:
    """
    Compute basic statistics about the binary tree.

    Leaves are recognised by each node's ``is_leaf`` flag (set by
    ``build_binary_tree_with_facilities``), not by inspecting children.
    Every leaf whose ``is_client`` is False is counted as a facility.

    Parameters
    ----------
    root : BinTreeNode
        Root of the binary tree, or None for an empty tree.

    Returns
    -------
    stats : Dict[str, int]
        Dictionary with keys:
          - 'num_nodes'      : total number of nodes
          - 'num_clients'    : number of client leaves
          - 'num_facilities' : number of non-client leaves
          - 'total_capacity' : sum of ``capacity`` over those facility leaves
          - 'depth'          : height of the tree counted in nodes
                               (1 for a single node, 0 for an empty tree)
    """
    num_nodes = 0
    num_clients = 0
    num_facilities = 0
    total_capacity = 0
    depth = 0

    # Each stack entry carries its level (root = 1), so the height falls
    # out of this single iterative pass. No second, recursive traversal is
    # needed, and deep trees cannot hit the recursion limit.
    stack = [] if root is None else [(root, 1)]
    while stack:
        node, level = stack.pop()
        num_nodes += 1
        if level > depth:
            depth = level

        if node.is_leaf:
            if node.is_client:
                num_clients += 1
            else:
                num_facilities += 1
                total_capacity += getattr(node, "capacity", 0)

        if node.left is not None:
            stack.append((node.left, level + 1))
        if node.right is not None:
            stack.append((node.right, level + 1))

    return {
        'num_nodes': num_nodes,
        'num_clients': num_clients,
        'num_facilities': num_facilities,
        'total_capacity': total_capacity,
        'depth': depth,
    }

###############################################################################
def build_binary_tree_with_facilities(
    n: int,
    t: int,
    facility_probability: float = 0.5,
    max_capacity: int = 10,
    seed: int = 12345,
) -> Optional[BinTreeNode]:
    """Build a complete binary tree with *n* nodes and randomly placed facilities.

    Nodes are numbered ``0 .. n-1`` in array layout. Node ``i`` has children
    ``2i+1`` and ``2i+2`` (when those are ``< n``) and parent ``(i-1)//2``.

    Edge lengths are drawn uniformly from ``{0, ..., 10}``. For every
    parent–child edge, two independent values are drawn: the child's
    ``parent_distance`` and the parent's ``left_distance`` /
    ``right_distance``. These generally differ. The distance helpers
    ``binary_tree_pairwise_distances`` and
    ``binary_tree_pairwise_distances_node_num`` use only ``parent_distance``.

    Each leaf independently becomes either:
      - a client (probability 1 - facility_probability), or
      - a facility (probability facility_probability).
    Internal nodes stay purely structural (capacity 0, not a client).

    For facilities:
      - capacity is chosen uniformly from {1, ..., max_capacity};
      - a nonempty random subset of the t groups is chosen,
        so each facility may belong to one or more groups;
      - ``facility_type`` is the 1-based index of one of those groups
        (the first element of the random sample).

    The group membership is encoded as a 0/1 vector of length t in
    `facility_type_vector`. Multiple 1's mean the facility is in
    several (possibly intersecting) groups, matching the model in the
    capacitated fair-range clustering paper. Clients get the all-zero
    vector of length t.

    Parameters
    ----------
    n : int
        Total number of nodes in the (complete) binary tree.
    t : int
        Number of groups.
    facility_probability : float
        Probability that a leaf becomes a facility (otherwise it is a client).
    max_capacity : int
        Largest capacity a facility can receive (capacities are >= 1).
    seed : int
        Seed for a private RNG, so experiments are reproducible and the
        global ``random`` state is left untouched.

    Returns
    -------
    BinTreeNode or None
        The root of the generated tree, or None if n <= 0.

    Raises
    ------
    ValueError
        From ``random.randint`` if a facility is generated while
        ``max_capacity < 1`` or ``t < 1``.
    """
    if n <= 0:
        return None

    # Equivalent to Random() followed by .seed(seed).
    local_random = Random(seed)

    # ------------------------------------------------------------------
    # 1. Create all nodes
    # ------------------------------------------------------------------
    nodes: List[BinTreeNode] = [BinTreeNode(i) for i in range(n)]

    # ------------------------------------------------------------------
    # 2. Link parent / children pointers and assign random edge lengths.
    #    The draw order per node (parent_distance, left_distance,
    #    right_distance) must stay fixed so a seed keeps producing the
    #    same tree.
    # ------------------------------------------------------------------
    for i, node in enumerate(nodes):
        if i > 0:
            parent_node = nodes[(i - 1) // 2]
            node.parent = parent_node
            node.parent_id = parent_node.node_number
            node.parent_distance = local_random.randint(0, 10)
        # The root keeps the constructor defaults: no parent, distance 0.

        left_index = 2 * i + 1
        right_index = 2 * i + 2

        if left_index < n:
            node.left = nodes[left_index]
            node.left_distance = local_random.randint(0, 10)

        if right_index < n:
            node.right = nodes[right_index]
            node.right_distance = local_random.randint(0, 10)

    # After all children are linked, mark leaf nodes once.
    for node in nodes:
        node.is_leaf = node.left is None and node.right is None

    # ------------------------------------------------------------------
    # 3. For each leaf, randomly decide if it is a facility or a client
    #    and assign capacities / group labels.
    # ------------------------------------------------------------------
    for node in nodes:
        if not node.is_leaf:
            continue

        if local_random.random() < facility_probability:
            # Facility leaf
            node.capacity = local_random.randint(1, max_capacity)

            # Random nonempty subset of {0, ..., t-1} as group indices;
            # this allows multi-group membership, i.e. intersecting groups.
            num_groups = local_random.randint(1, t)
            group_indices = local_random.sample(range(t), num_groups)

            # One representative group index (1-based) for convenience.
            node.facility_type = group_indices[0] + 1

            # Multi-group membership encoded as a 0/1 vector of length t.
            membership = [0] * t
            for gi in group_indices:
                membership[gi] = 1
            node.facility_type_vector = tuple(membership)

            node.is_client = False
        else:
            # Client leaf: no capacity, no group membership.
            node.is_client = True
            node.capacity = 0
            node.facility_type = -1
            node.facility_type_vector = (0,) * t

    # Root is always the first node in array representation.
    return nodes[0]
# end build_binary_tree_with_facilities


###############################################################################
# Helper routines for pretty-printing / debugging
###############################################################################
def print_binary_tree(
    node: Optional[BinTreeNode],
    level: int = 0,
    prefix: str = "Root: ",
) -> None:
    """Print the subtree rooted at *node* sideways.

    The right subtree is printed above a node and the left subtree below it.
    Each deeper level is indented by five more spaces and tagged with
    ``R--- `` / ``L--- ``. Client leaves are marked ``[C]``. Other leaves
    show their group vector and capacity, and internal nodes show only
    their number.
    """
    if node is None:
        return

    # Right side first, so it ends up above this line in the output.
    print_binary_tree(node.right, level + 1, prefix="R--- ")

    label = str(node.node_number)
    if node.is_leaf:
        label += (
            " [C]"
            if node.is_client
            else f" [F groups={node.facility_type_vector} cap={node.capacity}]"
        )
    indent = "     " * level
    print(f"{indent}{prefix}{label}")

    print_binary_tree(node.left, level + 1, prefix="L--- ")
# end print_binary_tree


def print_binary_tree_node(node: Optional[BinTreeNode]) -> None:
    """Dump every stored attribute of a single node to stdout (debug aid).

    Child and parent fields show the neighbour's node_number, or
    ``None`` / ``N/A`` when that neighbour is absent. Prints ``None`` when
    *node* itself is None.
    """
    if node is None:
        print("None")
        return

    fields = [
        ("Node Number    ", node.node_number),
        ("Parent Node    ", node.parent.node_number if node.parent else 'None'),
        ("Parent Distance", node.parent_distance),
        ("Left Child     ", node.left.node_number if node.left else 'None'),
        ("Left Distance  ", node.left_distance if node.left else 'N/A'),
        ("Right Child    ", node.right.node_number if node.right else 'None'),
        ("Right Distance ", node.right_distance if node.right else 'N/A'),
        ("Capacity       ", node.capacity),
        ("Facility Type  ", node.facility_type),
        ("Type Vector    ", node.facility_type_vector),
        ("Is Client      ", node.is_client),
        ("Is Leaf        ", node.is_leaf),
    ]
    lines = ["---------Binary Tree Node Details---------"]
    lines.extend(f"{label}: {value}" for label, value in fields)
    lines.append("-----------------------------------")
    print("\n".join(lines))
# end print_binary_tree_node


###############################################################################
def main() -> None:
    """Build a small demo tree (10 nodes, 3 groups, fixed seed) and print it sideways."""
    n = 10
    t = 3
    facility_probability = 0.5
    max_capacity = 10
    seed = 123456
    root = build_binary_tree_with_facilities(
        n,
        t,
        facility_probability=facility_probability,
        max_capacity=max_capacity,
        seed=seed,
    )
    print_binary_tree(root)
    print_binary_tree(root)

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
