from collections import defaultdict

from binary_tree import build_binary_tree_with_facilities

def _is_dp_leaf(node):
    """Return True if the DP treats ``node`` as a leaf.

    A node is a leaf when it is flagged ``is_leaf`` or has no children.
    Treating unflagged childless nodes as leaves keeps a single such node
    from emptying every DP table above it.
    """
    return bool(getattr(node, "is_leaf", False)) or (
        node.left is None and node.right is None
    )


def _collect_preorder(root, stop_at_leaves=False):
    """Return the nodes reachable from ``root``, each parent before its children.

    The traversal is iterative, so deep trees do not hit Python's recursion
    limit. Nodes are de-duplicated by ``node_number``. When ``stop_at_leaves``
    is True, the children of DP leaves (see ``_is_dp_leaf``) are not visited.
    This matches the set of edges the DP evaluates. For a tree, iterating the
    result in reverse visits every child before its parent.
    """
    order = []
    stack = [root]
    seen = set()
    while stack:
        node = stack.pop()
        if node is None or node.node_number in seen:
            continue
        seen.add(node.node_number)
        order.append(node)
        if stop_at_leaves and _is_dp_leaf(node):
            continue
        stack.extend(child for child in (node.left, node.right) if child is not None)
    return order


def _facility_ext_vector(node, t_orig):
    """Return the node's facility group vector extended with a trailing 1.

    The trailing 1 feeds the extra "k-group" that counts opened centers.
    Client nodes and nodes without a ``facility_type_vector`` yield None.

    Raises
    ------
    ValueError
        If ``facility_type_vector`` does not have exactly ``t_orig`` entries.
    """
    if getattr(node, "is_client", False):
        return None
    base_vec = getattr(node, "facility_type_vector", None)
    if base_vec is None:
        return None
    if len(base_vec) != t_orig:
        raise ValueError(
            f"facility_type_vector length {len(base_vec)} "
            f"!= t={t_orig} at node {node.node_number}"
        )
    return tuple(base_vec) + (1,)


def binary_tree_dp_pruning(
    root,
    n,
    t,
    k,
    alpha,
    beta,
    max_states_per_edge=5000,
):
    """
    Heuristic variant of the tree DP for capacitated fair-range k-median.

    Differences vs. the exact DP:
      * Follows the same DP formulation on edges.
      * Encodes the "at most k centers" constraint as one extra group.
      * Uses aggressive pruning AND cost-based truncation (keep only the
        top-K states per edge) to improve running time and memory.

    Each edge (parent, u) carries a table ``rvec -> {z -> cost}``:

      * ``rvec`` counts the facilities opened per extended group inside u's
        subtree.
      * ``z`` is the net flow across the edge: +1 per client leaf, and -c
        for a facility leaf that reserves c units of capacity.
      * Moving flow ``z`` across an edge costs ``abs(z) * parent_distance``.

    The answer is the cheapest root state with ``z == 0`` that satisfies the
    fair-range bounds. Clients and facilities are only used at DP leaves
    (nodes flagged ``is_leaf`` or without children). All traversals are
    iterative, so deep trees do not hit Python's recursion limit.

    Because of the truncation step, this procedure is NOT guaranteed to
    be globally optimal, but in practice it is much faster and often
    matches the exact DP on moderate-sized instances.

    Parameters
    ----------
    root : BinTreeNode
        Root of the tree (after binarization and pushing clients/facilities
        to leaves as in the paper).
    n : int
        Number of clients in the instance (not used internally; kept for
        API compatibility).
    t : int
        Number of *fairness groups*.
    k : int
        Maximum allowed number of opened facilities.
    alpha, beta : sequences of int, length t
        Lower and upper bounds for each fairness group.
    max_states_per_edge : int or None, optional
        Maximum number of (rvec, z) states kept per edge after pruning.
        Smaller => faster but more approximate; larger => slower but closer
        to exact DP behavior. ``None`` disables truncation, leaving only
        the pruning rules that never discard a feasible state.

    Returns
    -------
    best_cost : float
        Cost of the best DP state that survives pruning (``inf`` if none).
    best_group_counts : Tuple[int, ...] or None
        Length-t tuple giving the number of opened facilities chosen from
        each fairness group in the (heuristic) solution; None if infeasible.
    best_num_centers : int or None
        Total number of centers opened in the (heuristic) solution; None if
        infeasible.
    iters : int
        A rough counter of inner-loop iterations, for diagnostics.

    Raises
    ------
    ValueError
        If ``alpha``/``beta`` do not have length ``t``, or if a facility's
        ``facility_type_vector`` does not have length ``t``.
    """
    if root is None:
        return 0.0, (0,) * t, 0, 0

    alpha = tuple(alpha)
    beta = tuple(beta)
    if len(alpha) != t or len(beta) != t:
        raise ValueError(
            f"alpha and beta must both have length t={t} "
            f"(got {len(alpha)} and {len(beta)})"
        )

    inf = float("inf")

    # ------------------------------------------------------------------
    # Extend groups: extra "k-group" that counts total number of centers.
    # ------------------------------------------------------------------
    t_orig = t
    t_ext = t_orig + 1
    alpha_ext = alpha + (0,)  # no lower bound on #centers
    beta_ext = beta + (k,)  # at most k centers globally
    zero_rvec = (0,) * t_ext

    # ------------------------------------------------------------------
    # Subtree statistics (bottom-up, no recursion): #clients and the number
    # of facilities per extended group in every subtree.
    # ------------------------------------------------------------------
    subtree_clients = {}
    subtree_groups = {}
    for node in reversed(_collect_preorder(root)):
        clients = 1 if getattr(node, "is_client", False) else 0
        groups = [0] * t_ext
        for child in (node.left, node.right):
            if child is None:
                continue
            clients += subtree_clients[child.node_number]
            for j, count in enumerate(subtree_groups[child.node_number]):
                groups[j] += count
        # Count every facility the DP could open, including capacity-0 ones.
        # Under-counting here would make the lower-bound pruning below
        # discard states that are actually feasible.
        ext_vec = _facility_ext_vector(node, t_orig)
        if ext_vec is not None:
            for j, count in enumerate(ext_vec):
                groups[j] += int(count)
        subtree_clients[node.node_number] = clients
        subtree_groups[node.node_number] = groups

    root_id = root.node_number
    total_clients = subtree_clients[root_id]
    total_groups = subtree_groups[root_id]

    # outside_groups[u][g] = facilities of group g that could still be opened
    # outside the subtree rooted at u.
    outside_groups = {
        u_id: [max(0, total_groups[g] - groups[g]) for g in range(t_ext)]
        for u_id, groups in subtree_groups.items()
    }

    # Smallest useful flow on edge (parent, u). Spare capacity offered upward
    # from u's subtree can only be used by clients outside it. States with
    # z < -(#clients outside u) can therefore never reach z == 0 at the root.
    # Dropping them is exact and frees truncation budget.
    z_min = {
        u_id: -(total_clients - clients)
        for u_id, clients in subtree_clients.items()
    }

    iters = 0

    def state_size(states):
        return sum(len(zdict) for zdict in states.values())

    def truncate_states(states):
        """Cost-based truncation: keep only the top-K (rvec, z) states by cost."""
        if max_states_per_edge is None:
            return states
        flat = [
            (cost, rvec, z)
            for rvec, zdict in states.items()
            for z, cost in zdict.items()
        ]
        if len(flat) <= max_states_per_edge:
            return states
        # Stable sort by cost; ties keep insertion order. At least one state
        # is always kept.
        flat.sort(key=lambda item: item[0])
        truncated = defaultdict(dict)
        for cost, rvec, z in flat[: max(1, max_states_per_edge)]:
            truncated[rvec][z] = cost
        return truncated

    def prune_states(node, states):
        """Logic pruning (bounds on rvec and z), then cost-based truncation."""
        nonlocal iters
        u_id = node.node_number
        outside = outside_groups[u_id]
        z_lo = z_min[u_id]

        pruned = defaultdict(dict)
        for rvec, zdict in states.items():
            # 1) hard upper bounds on groups (including "k-group")
            if any(rvec[g] > beta_ext[g] for g in range(t_ext)):
                continue
            # 2) lower bounds must still be reachable with outside facilities
            if any(rvec[g] + outside[g] < alpha_ext[g] for g in range(t_orig)):
                continue
            # 3) keep best cost for each z within the exact flow bounds
            for z, cost in zdict.items():
                if z < z_lo or z > total_clients:
                    continue
                iters += 1
                if cost < pruned[rvec].get(z, inf):
                    pruned[rvec][z] = cost

        if not pruned:
            return pruned

        # 4) heuristic: truncate to top-K by cost
        return truncate_states(pruned)

    def leaf_states(node, parent_dist):
        """Initial table for a leaf edge (client, facility, or empty leaf)."""
        nonlocal iters
        states = defaultdict(dict)
        if getattr(node, "is_client", False):
            # One unit of demand flows upward.
            states[zero_rvec][1] = parent_dist
            return states

        # Facility not opened (the only option for a leaf with no facility).
        states[zero_rvec][0] = 0.0

        ext_vec = _facility_ext_vector(node, t_orig)
        if ext_vec is not None:
            # Facility opened, reserving c units of capacity (z = -c), which
            # are moved across this edge.
            cmax = min(int(getattr(node, "capacity", 0)), total_clients)
            for c in range(cmax + 1):
                iters += 1
                cost = c * parent_dist
                if cost < states[ext_vec].get(-c, inf):
                    states[ext_vec][-c] = cost
        return states

    def lift_states(child, parent_dist, z_lo):
        """Push a single child's table across the edge (parent, node)."""
        nonlocal iters
        states = defaultdict(dict)
        for rvec, zdict in child.items():
            for z, cost in zdict.items():
                if z < z_lo or z > total_clients:
                    continue
                iters += 1
                new_cost = cost + abs(z) * parent_dist
                if new_cost < states[rvec].get(z, inf):
                    states[rvec][z] = new_cost
        return states

    def merge_states(left, right, parent_dist, z_lo):
        """Combine two child tables and push the result across (parent, node)."""
        nonlocal iters
        # Iterate outer over the smaller table for speed.
        if state_size(left) > state_size(right):
            left, right = right, left

        states = defaultdict(dict)
        for l_rvec, l_zdict in left.items():
            for r_rvec, r_zdict in right.items():
                new_rvec = tuple(a + b for a, b in zip(l_rvec, r_rvec))
                # Early upper-bound pruning before touching the z tables.
                if any(v > ub for v, ub in zip(new_rvec, beta_ext)):
                    continue
                for zl, cl in l_zdict.items():
                    for zr, cr in r_zdict.items():
                        new_z = zl + zr
                        if new_z < z_lo or new_z > total_clients:
                            continue
                        iters += 1
                        new_cost = cl + cr + abs(new_z) * parent_dist
                        if new_cost < states[new_rvec].get(new_z, inf):
                            states[new_rvec][new_z] = new_cost
        return states

    # ------------------------------------------------------------------
    # Edge DP in post-order (children before parents), without recursion.
    # Child tables are popped once consumed to keep peak memory down.
    # ------------------------------------------------------------------
    tables = {}
    for node in reversed(_collect_preorder(root, stop_at_leaves=True)):
        u_id = node.node_number
        if node.parent is None:
            parent_dist = 0.0
        else:
            parent_dist = float(getattr(node, "parent_distance", 0.0))

        if _is_dp_leaf(node):
            states = leaf_states(node, parent_dist)
        else:
            children = [
                tables.pop(child.node_number)
                for child in (node.left, node.right)
                if child is not None
            ]
            if len(children) == 1:
                states = lift_states(children[0], parent_dist, z_min[u_id])
            else:
                states = merge_states(
                    children[0], children[1], parent_dist, z_min[u_id]
                )
        tables[u_id] = prune_states(node, states)

    root_states = tables.pop(root_id)

    # ------------------------------------------------------------------
    # Pick the best root state.
    # ------------------------------------------------------------------
    best_cost = inf
    best_rvec_ext = None

    for rvec_ext, zdict in root_states.items():
        # Require z = 0 at the virtual root edge.
        if 0 not in zdict:
            continue
        # Re-check original fair-range constraints (safety).
        if any(
            not (alpha_ext[g] <= rvec_ext[g] <= beta_ext[g]) for g in range(t_orig)
        ):
            continue
        # At most k centers on the extra group (safety).
        if rvec_ext[t_orig] > k:
            continue
        cost = zdict[0]
        if cost < best_cost:
            best_cost = cost
            best_rvec_ext = rvec_ext

    if best_rvec_ext is None:
        # No feasible state survived pruning.
        return inf, None, None, iters

    best_group_counts = tuple(best_rvec_ext[:t_orig])  # per fairness group
    best_num_centers = int(best_rvec_ext[t_orig])  # total centers

    return best_cost, best_group_counts, best_num_centers, iters

def sanity_check():
    """Build one fixed-seed instance, run the heuristic DP on it, and print a summary.

    The instance has 255 clients, 3 fairness groups and at most 5 centers.
    Every group must receive between 1 and k centers.
    """
    seed = 1234567890
    n, t, k = 255, 3, 5
    facility_prob = 0.5
    alpha, beta = (1,) * t, (k,) * t
    max_capacity = int(n / k)
    max_states_per_edge = 10000

    tree_root = build_binary_tree_with_facilities(
        n, t, facility_prob, max_capacity, seed
    )
    outcome = binary_tree_dp_pruning(
        tree_root, n, t, k, alpha, beta, max_states_per_edge
    )
    cost, group_counts, num_centers, iters = outcome

    report_lines = (
        "Heuristic DP results:",
        f"  Cost: {cost}",
        f"  Group counts: {group_counts}",
        f"  Number of centers: {num_centers}",
        f"  Iterations: {iters}",
    )
    for line in report_lines:
        print(line)

# Run sanity_check() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    sanity_check()
