from binary_tree import build_binary_tree_with_facilities
from binary_tree import get_depth, get_nodes_at_level, get_node_by_number
from binary_tree import print_binary_tree_node
from binary_tree import print_binary_tree
from binary_tree import get_binary_tree_stats

from binary_tree_brute_force import binary_tree_brute_force
from binary_tree_heuristic import solve_tree_kmedian_fast

import time

from random import Random
from sys import stdout
from collections import defaultdict
from itertools import product

def generate_valid_vectors(tup):
    """
    Enumerate every non-zero vector obtainable from `tup` by keeping or
    zeroing each coordinate.

    Coordinate i of a result is either 0 or tup[i]; a falsy tup[i] can only
    stay 0.  Vectors are produced in `itertools.product` order and the
    vector with no truthy coordinate is left out.
    """
    per_coordinate = []
    for entry in tup:
        if entry:
            per_coordinate.append([0, entry])
        else:
            per_coordinate.append([0])

    # filter(any, ...) keeps exactly the tuples with a truthy coordinate.
    return list(filter(any, product(*per_coordinate)))

def binary_tree_dp_old(root, n, t, k, alpha, beta):
    """
    Legacy bottom-up tree DP over edges.

    State:
        T[edge][rvec][cap] = cheapest cost found for the subtree below the
        directed edge `edge = (parent_id, child_id)`, where

          * rvec (length-t tuple) is the element-wise sum of the
            `facility_type_vector` of the facilities opened in the subtree;
          * cap is the net flow across the edge: a client leaf contributes
            +1, an opened facility leaf reserving c units contributes -c.

        Every edge adds |cap| * parent_distance to the cost.

    Limitations of this variant (intentionally left unchanged):
          * the final selection takes the cheapest state at the root edge
            for ANY value of `cap`; it does not require the flow to balance;
          * the "at most k" test uses sum(rvec) as the facility count.

    Parameters
    ----------
    root : tree node
        Root of the binary tree.
    n : int
        Unused; kept for a uniform signature.
    t : int
        Length of the group vectors.
    k : int
        Upper bound applied to sum(rvec) at the root.
    alpha, beta : sequences of length t
        Per-group lower / upper bounds applied to rvec at the root.

    Returns
    -------
    (best_cost, iters) : tuple
        best_cost is float('inf') when no root state satisfies the bounds;
        iters counts the DP table updates attempted.
    """
    T = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: float('inf'))))
    iters = 0
    depth = get_depth(root)
    if depth == 0:
        # Return the same (cost, iters) shape as every other path so that
        # callers can always unpack the result.
        return 0, iters

    zero_rvec = (0,) * t

    # Deepest level first, so child tables exist before their parent merges.
    for level in range(depth - 1, -1, -1):
        for node_id in get_nodes_at_level(root, level):
            node = get_node_by_number(root, node_id)
            parent = node.parent

            if node.is_leaf:
                if parent is None:
                    # A leaf that is also the root has no edge to fill in.
                    continue
                parent_dist = node.parent_distance
                leaf_table = T[(parent.node_number, node_id)]

                if node.is_client:
                    # One unit of demand crosses the edge upward.
                    leaf_table[zero_rvec][1] = parent_dist
                    iters += 1
                else:
                    leaf_table[zero_rvec][0] = 0  # Not open facility
                    # tuple() keeps the vector usable as a dict key even if
                    # the node stores it as a list.
                    rvec = tuple(node.facility_type_vector)
                    for c in range(node.capacity + 1):
                        leaf_table[rvec][-c] = c * parent_dist
                        iters += 1
                continue  # Done with leaf

            # Internal node: the root hangs from a virtual parent 0 at
            # distance 0.
            node_id = node.node_number
            if parent is None:
                parent_id, parent_dist = 0, 0
            else:
                parent_id, parent_dist = parent.node_number, node.parent_distance

            child_tables = [
                T[(node_id, child.node_number)]
                for child in (node.left, node.right)
                if child is not None
            ]
            if not child_tables:
                continue
            parent_table = T[(parent_id, node_id)]

            if len(child_tables) == 1:
                # Single child (left OR right): its states pass through
                # unchanged and only pay for crossing the parent edge.
                for rvec, cap_dict in child_tables[0].items():
                    for cap, cost in cap_dict.items():
                        new_cost = cost + abs(cap * parent_dist)
                        iters += 1
                        if new_cost < parent_table[rvec][cap]:
                            parent_table[rvec][cap] = new_cost
                continue

            # Two children: combine every pair of child states.
            left_table, right_table = child_tables
            for l_rvec, l_cap_dict in left_table.items():
                for r_rvec, r_cap_dict in right_table.items():
                    new_rvec = tuple(l + r for l, r in zip(l_rvec, r_rvec))
                    for l_cap, l_cost in l_cap_dict.items():
                        for r_cap, r_cost in r_cap_dict.items():
                            new_cap = l_cap + r_cap
                            new_cost = l_cost + r_cost + abs(new_cap * parent_dist)
                            iters += 1
                            if new_cost < parent_table[new_rvec][new_cap]:
                                parent_table[new_rvec][new_cap] = new_cost

    # Find best root solution.  The root's table is stored under the virtual
    # edge (0, root.node_number), whatever number the root carries.
    root_edge = (0, root.node_number)
    best_cost = float('inf')

    for rvec, cap_dict in T[root_edge].items():
        if not all(a <= r <= b for a, r, b in zip(alpha, rvec, beta)):
            continue
        if sum(rvec) > k:
            continue

        for cost in cap_dict.values():
            if cost < best_cost:
                best_cost = cost

    return best_cost, iters

def binary_tree_dp_v1(root, n, t, k, alpha, beta):
    """
    Exact dynamic program for capacitated fair-range k-median on a tree.

    State:
        T[edge][rvec][p][z] = minimum cost of serving all clients in the
        subtree below the directed edge `edge = (parent, child)` such that

          * rvec (length t tuple) gives, for each group g, the number of
            facilities chosen from that group inside the subtree;
          * p is the total **number of facilities opened** in the subtree;
          * z is the net flow across this edge:
                z > 0  : z units of client demand go *up* across the edge;
                z < 0  : -z units of unused facility capacity go *down*.

        The contribution of edge (parent, child) to the objective is
            |z| * length(parent, child).

    At the (virtual) edge above the root, we require:
        - z = 0   (all demand is matched to capacity within the tree),
        - p <= k  (at most k facilities in total), and
        - alpha <= rvec <= beta (fair-range group constraints).

    Parameters
    ----------
    root : tree node
        Root of the binary tree.
    n : int
        Unused; kept for a uniform signature.
    t : int
        Number of groups (length of rvec).
    k : int
        Maximum number of opened facilities.
    alpha, beta : sequences of length t
        Per-group lower / upper bounds.

    Returns
    -------
    (best_cost, iters) : tuple
        best_cost is float("inf") when no feasible root state exists;
        iters counts the DP table updates attempted.
    """
    # 4-level defaultdict: edge -> rvec -> p -> z -> cost
    T = defaultdict(
        lambda: defaultdict(
            lambda: defaultdict(
                lambda: defaultdict(lambda: float("inf"))
            )
        )
    )
    iters = 0

    depth = get_depth(root)
    if depth == 0:
        # Same (cost, iters) shape as the normal exit so callers can unpack.
        return 0, iters

    zero_rvec = (0,) * t

    # ------------------------------------------------------------------
    # Process nodes bottom-up by level
    # ------------------------------------------------------------------
    for level in range(depth - 1, -1, -1):
        for node_id in get_nodes_at_level(root, level):
            node = get_node_by_number(root, node_id)
            parent = node.parent

            # ----------------------------------------------------------
            # Leaf nodes: either a client or a facility
            # ----------------------------------------------------------
            if node.is_leaf:
                if parent is None:
                    # A leaf that is also the root has no edge to fill in.
                    continue
                parent_dist = node.parent_distance
                leaf_table = T[(parent.node_number, node_id)]

                if node.is_client:
                    # One unit of demand goes up to the parent.
                    leaf_table[zero_rvec][0][1] = parent_dist
                    iters += 1
                else:
                    # Option 1: facility not opened.
                    leaf_table[zero_rvec][0][0] = 0.0
                    iters += 1

                    # Option 2: facility opened (p = 1), reserving c units.
                    # tuple() keeps the vector usable as a dict key even if
                    # the node stores it as a list.
                    rvec = tuple(node.facility_type_vector)
                    for c in range(node.capacity + 1):
                        z = -c  # c units of capacity offered to the parent side
                        cost = c * parent_dist
                        if cost < leaf_table[rvec][1][z]:
                            leaf_table[rvec][1][z] = cost
                        iters += 1
                continue  # Done with leaf

            # ----------------------------------------------------------
            # Internal node: merge children and push flow to parent
            # ----------------------------------------------------------
            node_id = node.node_number
            if parent is None:
                parent_id, parent_dist = 0, 0.0  # virtual parent of the root
            else:
                parent_id, parent_dist = parent.node_number, node.parent_distance

            child_tables = [
                T[(node_id, child.node_number)]
                for child in (node.left, node.right)
                if child is not None
            ]
            # No children (should not happen: leaves handled above)
            if not child_tables:
                continue
            parent_table = T[(parent_id, node_id)]

            # Exactly one child: states pass through unchanged and only pay
            # for crossing the parent edge.
            if len(child_tables) == 1:
                for rvec, p_dict in child_tables[0].items():
                    for p, z_dict in p_dict.items():
                        if p > k:
                            continue
                        for z, cost in z_dict.items():
                            new_cost = cost + abs(z * parent_dist)
                            iters += 1
                            if new_cost < parent_table[rvec][p][z]:
                                parent_table[rvec][p][z] = new_cost
                continue

            # Both children present: merge combinations of left and right
            # states.  rvec and p are combined in the outer loops so they are
            # not rebuilt for every pair of flow values.
            left_table, right_table = child_tables
            for l_rvec, l_p_dict in left_table.items():
                for r_rvec, r_p_dict in right_table.items():
                    new_rvec = tuple(l + r for l, r in zip(l_rvec, r_rvec))
                    for l_p, l_z_dict in l_p_dict.items():
                        for r_p, r_z_dict in r_p_dict.items():
                            new_p = l_p + r_p
                            if new_p > k:
                                continue
                            for l_z, l_cost in l_z_dict.items():
                                for r_z, r_cost in r_z_dict.items():
                                    new_z = l_z + r_z
                                    new_cost = l_cost + r_cost + abs(new_z * parent_dist)
                                    iters += 1
                                    if new_cost < parent_table[new_rvec][new_p][new_z]:
                                        parent_table[new_rvec][new_p][new_z] = new_cost

    # ------------------------------------------------------------------
    # Extract best solution at the virtual edge above the root
    # ------------------------------------------------------------------
    root_edge = (0, root.node_number)
    best_cost = float("inf")

    for rvec, p_dict in T[root_edge].items():
        # Fair-range constraints
        if not all(a <= r <= b for a, r, b in zip(alpha, rvec, beta)):
            continue

        for p, z_dict in p_dict.items():
            # At most k facilities, and all demand matched internally (z = 0).
            if p > k or 0 not in z_dict:
                continue

            cost = z_dict[0]
            if cost < best_cost:
                best_cost = cost

    return best_cost, iters

def binary_tree_dp_v2(root, n, t, k, alpha, beta):
    """
    Exact dynamic program for capacitated fair-range k-median on a tree metric.

    This implementation follows the DP in Appendix B.2.2 of
    Gadekar & Thejaswi (2025), but:
      * Encodes the "at most k centers" constraint as one extra group,
      * Adds pruning for scalability.

    State for an edge e = (parent, child):
        T_e[rvec][z] = minimum cost for the subtree below e, where:
          - rvec is a length-(t+1) tuple of non-negative integers:
              rvec[g]  = #open centers from group g (for g < t)
              rvec[t]  = total #open centers in the subtree
          - z is the net flow of clients across e:
              z > 0 : z units of demand go upward across e
              z < 0 : -z units of surplus capacity go downward across e

        The edge e itself contributes |z| * length(e) to the cost.

    At the virtual root edge (0, root.node_number), we require:
        - Fair-range constraints for the original t groups: alpha <= rvec[:t] <= beta
        - At most k centers via the extra group: rvec[t] <= k
        - z = 0   (all demand is matched within the tree)

    Parameters
    ----------
    root : tree node or None
        Root of the binary tree.  Nodes are read through the attributes
        node_number, parent, left, right, is_leaf, is_client, capacity,
        facility_type_vector and parent_distance.
    n : int
        Unused; kept for a uniform signature.
    t : int
        Number of fairness groups.
    k : int
        Maximum number of opened centers.
    alpha, beta : sequences of length t
        Per-group lower / upper bounds.

    Returns
    -------
    (best_cost, iters) : tuple
        best_cost is float("inf") when no feasible solution exists, and
        (0.0, 0) is returned for an empty tree; iters is a rough counter of
        inner-loop iterations.

    Raises
    ------
    ValueError
        If a counted facility's `facility_type_vector` does not have length t.
    """
    if root is None:
        return 0.0, 0

    inf = float("inf")

    # ------------------------------------------------------------------
    # 0. Extend groups: add one extra "k-group" that counts centers.
    # ------------------------------------------------------------------
    t_orig = t
    t_ext = t_orig + 1

    alpha_ext = tuple(alpha) + (0,)   # no lower bound on #centers
    beta_ext = tuple(beta) + (k,)     # at most k centers

    def check_vector_length(node, base_vec):
        """Raise ValueError when a facility's group vector has the wrong length."""
        if len(base_vec) != t_orig:
            raise ValueError(
                f"facility_type_vector length {len(base_vec)} "
                f"!= t={t_orig} at node {node.node_number}"
            )

    # ------------------------------------------------------------------
    # 1. Subtree statistics: number of clients and, per group, how many
    #    facilities the subtree contains.
    # ------------------------------------------------------------------
    subtree_groups = {}

    def dfs_subtree(node):
        """Return (#clients, per-group facility counts) for the subtree at `node`."""
        is_client = getattr(node, "is_client", False)
        clients = 1 if is_client else 0
        groups = [0] * t_ext

        for child in (node.left, node.right):
            if child is None:
                continue
            child_clients, child_groups = dfs_subtree(child)
            clients += child_clients
            for j in range(t_ext):
                groups[j] += child_groups[j]

        # Facilities contribute their group vector; clients do not.
        # A facility leaf is counted even when its capacity is 0, because
        # dp_edge below can still open it.  Leaving such facilities out made
        # outside_groups too small, so the lower-bound pruning threw away
        # states that were in fact feasible.
        if not is_client:
            base_vec = getattr(node, "facility_type_vector", None)
            openable = (
                getattr(node, "is_leaf", False)
                or getattr(node, "capacity", 0) > 0
            )
            if base_vec is not None and openable:
                check_vector_length(node, base_vec)
                # Extended vector: original groups + "k-group"
                for j, flag in enumerate(list(base_vec) + [1]):
                    groups[j] += int(flag)

        subtree_groups[node.node_number] = groups
        return clients, groups

    total_clients, total_groups = dfs_subtree(root)
    total_groups = list(total_groups)

    # outside_groups[u][g] = how many facilities of group g can still be opened
    # outside the subtree rooted at u.
    outside_groups = {
        u_id: [max(0, total_groups[g] - groups[g]) for g in range(t_ext)]
        for u_id, groups in subtree_groups.items()
    }

    # ------------------------------------------------------------------
    # 2. DP over edges via recursion: returns T_e for edge (parent, node).
    #    Each T_e is: Dict[rvec -> Dict[z -> cost]].
    # ------------------------------------------------------------------
    iters = 0

    def prune_states(node, states):
        """
        Prune states at edge (parent, node) using:
          - global fair-range upper bounds,
          - feasibility w.r.t. fair-range lower bounds, given what's outside,
          - global |z| <= total_clients bound.
        """
        nonlocal iters
        outside = outside_groups[node.node_number]

        pruned = defaultdict(dict)
        for rvec, zdict in states.items():
            # 1. Never exceed upper bounds (including k via the last group).
            if any(rvec[g] > beta_ext[g] for g in range(t_ext)):
                continue

            # 2. Lower bounds of the original groups must still be reachable
            #    with the facilities available outside this subtree.
            if any(rvec[g] + outside[g] < alpha_ext[g] for g in range(t_orig)):
                continue

            # 3. Keep best cost for each z with |z| <= total_clients.
            for z, cost in zdict.items():
                if abs(z) > total_clients:
                    continue
                iters += 1
                if cost < pruned[rvec].get(z, inf):
                    pruned[rvec][z] = cost

        return pruned

    def state_size(states):
        """Approximate size of a state table: total (rvec,z) pairs."""
        return sum(len(zdict) for zdict in states.values())

    def dp_edge(node):
        """
        Compute DP table for the edge (parent, node).
        For the root, parent is virtual with id 0 and length 0.
        """
        nonlocal iters
        if node is None:
            return defaultdict(dict)

        if node.parent is None:
            parent_dist = 0.0
        else:
            parent_dist = float(getattr(node, "parent_distance", 0.0))

        states = defaultdict(dict)

        # --------------------------------------------------------------
        # Leaf nodes: either client or facility.
        # --------------------------------------------------------------
        if getattr(node, "is_leaf", False):
            zero_rvec = (0,) * t_ext

            if getattr(node, "is_client", False):
                # One unit of demand goes upward.
                states[zero_rvec][1] = parent_dist
                return prune_states(node, states)

            # Facility leaf
            # Option 1: facility not opened.
            states[zero_rvec][0] = 0.0

            # Option 2: facility opened.
            base_vec = getattr(node, "facility_type_vector", None)
            if base_vec is not None:
                check_vector_length(node, base_vec)
                ext_vec = tuple(list(base_vec) + [1])

                # We never need more reserved capacity than total_clients.
                cap = int(getattr(node, "capacity", 0))
                for c in range(min(cap, total_clients) + 1):
                    z = -c  # negative => surplus capacity going downward
                    cost = c * parent_dist
                    iters += 1
                    if cost < states[ext_vec].get(z, inf):
                        states[ext_vec][z] = cost

            return prune_states(node, states)

        # --------------------------------------------------------------
        # Internal node: combine child edge states and push through parent.
        # --------------------------------------------------------------
        left_states = dp_edge(node.left) if node.left is not None else None
        right_states = dp_edge(node.right) if node.right is not None else None

        # If somehow both children are None but is_leaf was False, just return empty.
        if left_states is None and right_states is None:
            return states

        # Only one child: propagate its states upward; rvec is unchanged and
        # only the cost of crossing this edge is added.
        if left_states is None or right_states is None:
            child_states = right_states if left_states is None else left_states
            for rvec, zdict in child_states.items():
                for z, cost in zdict.items():
                    if abs(z) > total_clients:
                        continue
                    new_cost = cost + abs(z) * parent_dist
                    iters += 1
                    if new_cost < states[rvec].get(z, inf):
                        states[rvec][z] = new_cost

            return prune_states(node, states)

        # Both children: merge.  Iterate outer over the smaller table.
        outer, inner = left_states, right_states
        if state_size(outer) > state_size(inner):
            outer, inner = inner, outer

        for l_rvec, l_zdict in outer.items():
            for r_rvec, r_zdict in inner.items():
                # Combine group vectors.
                new_rvec = tuple(l_rvec[g] + r_rvec[g] for g in range(t_ext))

                # Early upper-bound pruning on groups.
                if any(new_rvec[g] > beta_ext[g] for g in range(t_ext)):
                    continue

                for zl, cl in l_zdict.items():
                    for zr, cr in r_zdict.items():
                        new_z = zl + zr
                        if abs(new_z) > total_clients:
                            continue

                        new_cost = cl + cr + abs(new_z) * parent_dist
                        iters += 1
                        if new_cost < states[new_rvec].get(new_z, inf):
                            states[new_rvec][new_z] = new_cost

        # The child tables are not needed any more; drop them to save memory.
        left_states.clear()
        right_states.clear()

        return prune_states(node, states)

    # ------------------------------------------------------------------
    # 3. Run DP starting at the root; its edge has a virtual parent at
    #    distance 0, so we just call dp_edge(root).
    # ------------------------------------------------------------------
    root_states = dp_edge(root)

    best_cost = inf

    for zdict in root_states.values():
        # At the root the fair-range bounds and the k bound have already
        # been enforced by prune_states; only z = 0 remains to be required.
        if 0 not in zdict:
            continue
        cost = zdict[0]
        if cost < best_cost:
            best_cost = cost

    return best_cost, iters


def binary_tree_dp_heuristic(
    root,
    n,
    t,
    k,
    alpha,
    beta,
    max_states_per_edge=5000,
):
    """
    Heuristic variant of the tree DP for capacitated fair-range k-median.

    Differences vs. the exact DP:
      * Follows the same DP formulation on edges.
      * Encodes the "at most k centers" constraint as one extra group.
      * Uses aggressive pruning AND cost-based truncation (keep only the
        top-K states per edge) to improve running time and memory.

    Because of the truncation step, this procedure is NOT guaranteed to
    be globally optimal, but in practice it is much faster and often
    matches the exact DP on moderate-sized instances.

    Parameters
    ----------
    root : BinTreeNode
        Root of the tree (after binarization and pushing clients/facilities
        to leaves as in the paper).
    n : int
        Number of clients in the instance (not used internally; kept for
        API compatibility).
    t : int
        Number of *fairness groups*.
    k : int
        Maximum allowed number of opened facilities.
    alpha, beta : sequences of int, length t
        Lower and upper bounds for each fairness group.

    max_states_per_edge : int, optional
        Maximum number of (rvec, z) states kept per edge after pruning.
        Smaller => faster but more approximate; larger => slower but closer
        to exact DP behavior.

    Returns
    -------
    best_cost : float
        Cost of the best DP state that survives pruning.
    best_group_counts : Tuple[int, ...]
        Length-t tuple giving the number of opened facilities chosen from
        each fairness group in the (heuristic) solution.
    best_num_centers : int
        Total number of centers opened in the (heuristic) solution.
    iters : int
        A rough counter of inner-loop iterations, for diagnostics.
    """
    if root is None:
        return 0.0, (0,) * t, 0, 0

    # ------------------------------------------------------------------
    # Extend groups: extra "k-group" that counts total number of centers.
    # ------------------------------------------------------------------
    t_orig = t
    t_ext = t_orig + 1

    alpha_ext = tuple(alpha) + (0,)   # no lower bound on #centers
    beta_ext  = tuple(beta) + (k,)    # at most k centers globally

    # ------------------------------------------------------------------
    # Collect all nodes reachable from root.
    # ------------------------------------------------------------------
    nodes = []
    stack = [root]
    seen = set()
    while stack:
        node = stack.pop()
        if node is None:
            continue
        if node.node_number in seen:
            continue
        seen.add(node.node_number)
        nodes.append(node)
        if node.left is not None:
            stack.append(node.left)
        if node.right is not None:
            stack.append(node.right)

    id2node = {node.node_number: node for node in nodes}

    # ------------------------------------------------------------------
    # Subtree statistics: #clients and max possible group counts per subtree.
    # ------------------------------------------------------------------
    subtree_clients = {}
    subtree_groups = {}

    def dfs_subtree(u_id):
        node = id2node[u_id]

        clients = 1 if getattr(node, "is_client", False) else 0
        groups = [0] * t_ext

        if node.left is not None:
            c_l, g_l = dfs_subtree(node.left.node_number)
            clients += c_l
            for j in range(t_ext):
                groups[j] += g_l[j]

        if node.right is not None:
            c_r, g_r = dfs_subtree(node.right.node_number)
            clients += c_r
            for j in range(t_ext):
                groups[j] += g_r[j]

        # Facilities contribute their group vector; clients do not.
        if not getattr(node, "is_client", False) and getattr(node, "capacity", 0) > 0:
            base_vec = getattr(node, "facility_type_vector", None)
            if base_vec is not None:
                if len(base_vec) != t_orig:
                    raise ValueError(
                        f"facility_type_vector length {len(base_vec)} "
                        f"!= t={t_orig} at node {node.node_number}"
                    )
                ext_vec = list(base_vec) + [1]  # extra group counts centers
                for j in range(t_ext):
                    groups[j] += int(ext_vec[j])

        subtree_clients[u_id] = clients
        subtree_groups[u_id] = groups
        return clients, groups

    total_clients, total_groups = dfs_subtree(root.node_number)
    total_groups = list(total_groups)

    # outside_groups[u][g] = facilities of group g that could still be opened
    # outside the subtree rooted at u.
    outside_groups = {
        u_id: [max(0, total_groups[g] - subtree_groups[u_id][g]) for g in range(t_ext)]
        for u_id in subtree_groups
    }

    iters = 0

    def state_size(states):
        return sum(len(zdict) for zdict in states.values())

    def truncate_states(states):
        """
        Cost-based truncation: keep only the top-K (rvec, z) states by cost.
        """
        nonlocal iters
        flat = []
        for rvec, zdict in states.items():
            for z, cost in zdict.items():
                flat.append((cost, rvec, z))

        if len(flat) <= max_states_per_edge:
            return states

        # Sort by cost (ascending) and keep the first K
        flat.sort(key=lambda x: x[0])
        truncated = defaultdict(dict)
        kept = 0
        for cost, rvec, z in flat:
            truncated[rvec][z] = cost
            kept += 1
            if kept >= max_states_per_edge:
                break

        return truncated

    def prune_states(node, states):
        """
        Logic pruning + (afterwards) cost-based truncation.
        """
        nonlocal iters
        u_id = node.node_number
        outside = outside_groups[u_id]

        pruned = defaultdict(dict)
        for rvec, zdict in states.items():
            # 1) hard upper bounds on groups (including "k-group")
            too_big = False
            for g in range(t_ext):
                if rvec[g] > beta_ext[g]:
                    too_big = True
                    break
            if too_big:
                continue

            # 2) feasibility of reaching lower bounds with facilities outside
            feasible = True
            for g in range(t_orig):
                if rvec[g] + outside[g] < alpha_ext[g]:
                    feasible = False
                    break
            if not feasible:
                continue

            # 3) keep best cost for each z with |z| <= total_clients
            for z, cost in zdict.items():
                if abs(z) > total_clients:
                    continue
                iters += 1
                cur = pruned[rvec].get(z, float("inf"))
                if cost < cur:
                    pruned[rvec][z] = cost

        if not pruned:
            return pruned

        # 4) heuristic: truncate to top-K by cost
        return truncate_states(pruned)

    def dp_edge(node):
        """
        DP over edge (parent, node).
        Returns: Dict[rvec -> Dict[z -> cost]]
        """
        nonlocal iters
        if node is None:
            return defaultdict(dict)

        if node.parent is None:
            parent_dist = 0.0
        else:
            parent_dist = float(getattr(node, "parent_distance", 0.0))

        states = defaultdict(dict)

        # --------------------------------------------------------------
        # Leaf
        # --------------------------------------------------------------
        if getattr(node, "is_leaf", False):
            zero_rvec = tuple(0 for _ in range(t_ext))

            if getattr(node, "is_client", False):
                # one unit of demand flows upward
                states[zero_rvec][1] = parent_dist
                return prune_states(node, states)

            # Facility leaf
            # Option 1: facility not opened.
            states[zero_rvec][0] = 0.0

            # Option 2: facility opened.
            base_vec = getattr(node, "facility_type_vector", None)
            if base_vec is not None:
                if len(base_vec) != t_orig:
                    raise ValueError(
                        f"facility_type_vector length {len(base_vec)} "
                        f"!= t={t_orig} at node {node.node_number}"
                    )
                ext_vec = tuple(list(base_vec) + [1])

                cap = int(getattr(node, "capacity", 0))
                cmax = min(cap, total_clients)
                for c in range(0, cmax + 1):
                    z = -c      # -c units of surplus capacity downwards
                    cost = c * parent_dist
                    iters += 1
                    cur = states[ext_vec].get(z, float("inf"))
                    if cost < cur:
                        states[ext_vec][z] = cost

            return prune_states(node, states)

        # --------------------------------------------------------------
        # Internal node: combine child edge states and push upward.
        # --------------------------------------------------------------
        left_states = dp_edge(node.left) if node.left is not None else None
        right_states = dp_edge(node.right) if node.right is not None else None

        if left_states is None and right_states is None:
            return states  # should not happen for non-leaf

        # Only one child: just push its states across this edge.
        if left_states is None:
            for rvec, zdict in right_states.items():
                for z, cost in zdict.items():
                    if abs(z) > total_clients:
                        continue
                    new_rvec = rvec
                    new_z = z
                    new_cost = cost + abs(new_z) * parent_dist
                    iters += 1
                    cur = states[new_rvec].get(new_z, float("inf"))
                    if new_cost < cur:
                        states[new_rvec][new_z] = new_cost
            return prune_states(node, states)

        if right_states is None:
            for rvec, zdict in left_states.items():
                for z, cost in zdict.items():
                    if abs(z) > total_clients:
                        continue
                    new_rvec = rvec
                    new_z = z
                    new_cost = cost + abs(new_z) * parent_dist
                    iters += 1
                    cur = states[new_rvec].get(new_z, float("inf"))
                    if new_cost < cur:
                        states[new_rvec][new_z] = new_cost
            return prune_states(node, states)

        # Both children present: merge.
        LS = left_states
        RS = right_states

        # Iterate outer over smaller table for speed.
        if state_size(LS) > state_size(RS):
            LS, RS = RS, LS

        for l_rvec, l_zdict in LS.items():
            for r_rvec, r_zdict in RS.items():
                # Combine group vectors.
                new_r = [l_rvec[g] + r_rvec[g] for g in range(t_ext)]

                # Early upper-bound pruning.
                too_big = False
                for g in range(t_ext):
                    if new_r[g] > beta_ext[g]:
                        too_big = True
                        break
                if too_big:
                    continue

                new_rvec = tuple(new_r)

                for zl, cl in l_zdict.items():
                    for zr, cr in r_zdict.items():
                        new_z = zl + zr
                        if abs(new_z) > total_clients:
                            continue
                        new_cost = cl + cr + abs(new_z) * parent_dist
                        iters += 1
                        cur = states[new_rvec].get(new_z, float("inf"))
                        if new_cost < cur:
                            states[new_rvec][new_z] = new_cost

        # Free child tables.
        left_states.clear()
        right_states.clear()

        return prune_states(node, states)

    # ------------------------------------------------------------------
    # Run DP at root, then pick best root state.
    # ------------------------------------------------------------------
    root_states = dp_edge(root)

    best_cost = float("inf")
    best_rvec_ext = None

    for rvec_ext, zdict in root_states.items():
        # Require z = 0 at the virtual root edge.
        if 0 not in zdict:
            continue

        # Re-check original fair-range constraints (safety).
        ok = True
        for g in range(t_orig):
            if rvec_ext[g] < alpha_ext[g] or rvec_ext[g] > beta_ext[g]:
                ok = False
                break
        if not ok:
            continue

        # At most k centers on the extra group (safety).
        if rvec_ext[t_orig] > k:
            continue

        cost = zdict[0]
        if cost < best_cost:
            best_cost = cost
            best_rvec_ext = rvec_ext

    if best_rvec_ext is None:
        # No feasible state survived pruning.
        return float("inf"), None, None, iters

    best_group_counts = tuple(best_rvec_ext[:t_orig])  # per fairness group
    best_num_centers = int(best_rvec_ext[t_orig])      # total centers

    return best_cost, best_group_counts, best_num_centers, iters

def sanity_check():
    """
    Cross-check the DP solvers against brute force on random instances.

    For every (n, t, k) combination and trial, a random tree instance is
    built and solved by brute force, by binary_tree_dp_v1 and
    binary_tree_dp_v2, by binary_tree_dp_heuristic (max_states_per_edge=5000)
    and by solve_tree_kmedian_fast.  One row per instance is written to
    stdout: seed and parameters, instance stats, the build time, the run
    time of each solver, the cost of each solver, and a PASS/FAIL flag.

    The flag is PASS only when the v1 and v2 DP costs both match the
    brute-force cost (compared with math.isclose so that float rounding
    from different summation orders does not cause a spurious FAIL).  The
    heuristic and fast solver costs are printed but not checked.

    Feel free to tweak nn, tt, kk, no_of_trials for your own tests.
    """
    import math  # local import: only needed for the tolerance comparison

    facility_prob = 0.5
    seed = 123456789

    l_random = Random()
    l_random.seed(seed)

    # Tree sizes (all of the form 2^h - 1, i.e., complete binary trees).
    nn = [255]
    # Number of groups.
    tt = [3, 4]
    # Number of centers k.
    kk = [3, 4, 5, 6]
    no_of_trials = 2

    header = (
        f"{'input_seed':>10s} "
        f"{'n':>4s} {'t':>3s} {'k':>3s} {'rep':>3s} "
        f"{'n_c':>4s} {'n_f':>4s} {'cap':>4s} "
        f"{'input':>5s} {'BF':>6s} "
        f"{'DP_v1':>6s} {'DP_v2':>6s} {'DP_h1':>6s} {'DP_f1':>6s} "
        f"{'BF':>8s} {'DP_v1':>8s} {'DP_v2':>8s} {'DP_h1':>8s} {'DP_f1':>8s} "
        f"{'test':>5s}"
    )
    stdout.write(header + "\n")
    stdout.write("=" * len(header) + "\n")

    for n, t, k in product(nn, tt, kk):
        for trial in range(no_of_trials):
            input_seed = l_random.randint(1, 123456789)

            alpha = (1,) * t
            beta = (k,) * t

            start_time = time.time()
            # Build random instance on a tree metric.
            root = build_binary_tree_with_facilities(
                n, t, facility_probability=facility_prob,
                max_facility_types=2, seed=input_seed
            )
            input_time = time.time() - start_time

            input_stats = get_binary_tree_stats(root)
            n_c = input_stats['num_clients']
            n_f = input_stats['num_facilities']
            cap_sum = input_stats['total_capacity']

            # Emit the instance description right away so a slow solver
            # still shows which instance it is working on.
            stdout.write(f"{input_seed:10d} {n:4d} {t:3d} {k:3d} {trial:3d} {n_c:4d} {n_f:4d} {cap_sum:4d}")
            stdout.flush()

            # Run brute-force for sanity check.
            bf_start = time.time()
            bf_cost, _bf_centers, _bf_assignment = binary_tree_brute_force(root, k, alpha, beta)
            bf_time = time.time() - bf_start

            # Run exact DP on tree for that instance.
            dp_start_v1 = time.time()
            dp_cost_v1, _dp_iters_v1 = binary_tree_dp_v1(root, n, t, k, alpha, beta)
            dp_time_v1 = time.time() - dp_start_v1

            dp_start_v2 = time.time()
            dp_cost_v2, _dp_iters_v2 = binary_tree_dp_v2(root, n, t, k, alpha, beta)
            dp_time_v2 = time.time() - dp_start_v2

            # Truncated (heuristic) DP; only its cost and time are reported.
            dp_start_heur = time.time()
            dp_cost_heur, _groups_heur, _centers_heur, _iters_heur = binary_tree_dp_heuristic(
                root, n, t, k, alpha, beta,
                max_states_per_edge=5000
            )
            dp_time_heur = time.time() - dp_start_heur

            dp_start_fast = time.time()
            result = solve_tree_kmedian_fast(root, n, t, k, alpha, beta)
            # Fixed: the original subtracted the undefined name
            # `df_start_fast`, raising NameError on every run.
            dp_time_fast = time.time() - dp_start_fast

            test_passed = (
                math.isclose(dp_cost_v1, bf_cost)
                and math.isclose(dp_cost_v2, bf_cost)
            )

            stdout.write(
                f" {input_time:5.2f} {bf_time:6.2f} "
                f"{dp_time_v1:6.2f} {dp_time_v2:6.2f} {dp_time_heur:6.2f} {dp_time_fast:6.2f} "
                f"{bf_cost:8.2f} {dp_cost_v1:8.2f} {dp_cost_v2:8.2f} {dp_cost_heur:8.2f} {result.cost:8.2f} "
                f"{'PASS' if test_passed else 'FAIL':>5s}\n"
            )
    stdout.write("=" * len(header) + "\n")

def scalability_test_1():
    """
    Time binary_tree_dp_v2 against the truncated heuristic DP on larger trees.

    For every (n, t, k) combination and trial, a random tree instance is
    built and solved once by binary_tree_dp_v2 and then three times by
    binary_tree_dp_heuristic, with max_states_per_edge set to 10000, 20000
    and 30000 (columns h1, h2, h3).  Each instance produces one stdout row:
    seed, parameters, build time, the four solve times and the four costs.
    binary_tree_dp_v1 is not run here.

    Feel free to tweak nn, tt, kk, no_of_trials for your own tests.
    """
    facility_prob = 0.5
    seed = 123456789
    # State caps for the heuristic runs, in column order h1, h2, h3.
    heuristic_caps = (10000, 20000, 30000)

    l_random = Random()
    l_random.seed(seed)

    # Tree sizes (all of the form 2^h - 1, i.e., complete binary trees).
    nn = [255, 511, 1023]
    # Number of groups.
    tt = [3, 4]
    # Number of centers k.
    # NOTE(review): 6 appears twice, so k=6 gets twice as many instances as
    # the other values -- confirm this is intended.
    kk = [4, 5, 6, 6, 7, 8]
    no_of_trials = 2

    # (label, width) of every column; columns are right-aligned and
    # separated by a single space.
    columns = [
        ('seed', 10),
        ('n', 4), ('t', 3), ('k', 3), ('rep', 3),
        ('input', 8),
        ('DP_time_v2', 10), ('DP_time_h1', 10), ('DP_time_h2', 10), ('DP_time_h3', 10),
        ('DP_cost_v2', 10), ('DP_cost_h1', 10), ('DP_cost_h2', 10), ('DP_cost_h3', 10),
    ]
    header = " ".join(f"{label:>{width}s}" for label, width in columns)
    rule = "=" * len(header)
    stdout.write(header + "\n")
    stdout.write(rule + "\n")

    for n, t, k in product(nn, tt, kk):
        for trial in range(no_of_trials):
            input_seed = l_random.randint(1, 123456789)

            # Print the instance description before any heavy work starts.
            stdout.write(f"{input_seed:10d} {n:4d} {t:3d} {k:3d} {trial:3d} ")
            stdout.flush()

            alpha = (1,) * t
            beta = (k,) * t

            # Build random instance on a tree metric.
            build_start = time.time()
            root = build_binary_tree_with_facilities(
                n, t, facility_probability=facility_prob,
                max_facility_types=2, seed=input_seed
            )
            input_time = time.time() - build_start

            # First timing/cost column: the v2 DP.
            tic = time.time()
            dp_cost_v2, _dp_iters_v2 = binary_tree_dp_v2(root, n, t, k, alpha, beta)
            solve_times = [time.time() - tic]
            solve_costs = [dp_cost_v2]

            # Remaining columns: the heuristic DP at each state cap.
            for cap in heuristic_caps:
                tic = time.time()
                heur_cost, _group_counts, _num_centers, _heur_iters = binary_tree_dp_heuristic(
                    root, n, t, k, alpha, beta,
                    max_states_per_edge=cap
                )
                solve_times.append(time.time() - tic)
                solve_costs.append(heur_cost)

            time_cols = " ".join(f"{value:10.2f}" for value in solve_times)
            cost_cols = " ".join(f"{value:10.2f}" for value in solve_costs)
            stdout.write(f"{input_time:8.2f} {time_cols} {cost_cols}\n")
    stdout.write(rule + "\n")


if __name__ == "__main__":
    # NOTE(review): this previously called sanity_check_1(), a name with no
    # definition in the visible part of this file; the sanity checker defined
    # here is sanity_check(). Confirm no separate sanity_check_1 was intended.
    sanity_check()
    # scalability_test_1()
