import math
from typing import List, Tuple, Dict, Optional

import numpy as np

from binary_tree import (
    binary_tree_pairwise_distances,
    get_client_and_facility_node_ids,
    get_facility_groups,
    get_facility_capacities,
    build_binary_tree_with_facilities,
)

from binary_tree_brute_force import _min_cost_flow_assignment_v2

def baseline_greedy_solution(
    root,
    k: int,
    alpha: List[int],
    beta: List[int],
) -> Tuple[float, Tuple[int, ...], np.ndarray, bool]:
    """
    Greedy baseline heuristic for capacitated fair-range k-median on a tree.

    Strategy:
        1. Build the client-facility distance matrix D_cf.
        2. Greedily open facilities (smallest average client distance first)
           until every fair-range lower bound in ``alpha`` is met and the
           opened capacity covers all clients, never exceeding ``k`` opened
           facilities or any upper bound in ``beta``.
        3. Run the capacitated min-cost assignment on the opened facilities.

    Being greedy, the selection can fail on an instance that does admit a
    solution; that outcome is reported through the ``feasible`` flag.

    Parameters
    ----------
    root
        Root of the tree, passed through to the ``binary_tree`` helpers.
    k : int
        Maximum number of facilities to open.
    alpha, beta : sequences of int
        Per-group lower and upper bounds on the number of opened facilities;
        the number of groups is taken to be ``len(alpha)``.

    Returns
    -------
    cost : float
        Total assignment cost, or ``math.inf`` when no solution was found.
    chosen_facilities : Tuple[int, ...]
        Node IDs of the opened facilities (empty when no solution was found).
    assignment : np.ndarray
        Assignment returned by the min-cost-flow helper: one entry per client,
        indexing into ``chosen_facilities``. Empty when no solution was found.
    feasible : bool
        True iff the returned solution is valid.

    Raises
    ------
    ValueError
        If the total capacity of all facilities is smaller than the number
        of clients.
    """

    def _infeasible() -> Tuple[float, Tuple[int, ...], np.ndarray, bool]:
        # Single definition of the "no solution found" result.
        return math.inf, tuple(), np.empty((0,), dtype=int), False

    # 1. Distances between all nodes; then get client/facility subsets
    node_ids, dist_matrix = binary_tree_pairwise_distances(root)
    dist_matrix = np.asarray(dist_matrix, dtype=float)
    id_to_idx = {nid: i for i, nid in enumerate(node_ids)}

    client_ids, facility_ids = get_client_and_facility_node_ids(root)
    n_clients = len(client_ids)
    n_fac = len(facility_ids)

    if n_clients == 0 or n_fac == 0:
        # Degenerate cases: nothing is opened and nothing is assigned. That
        # empty solution is only valid when there is no client to serve and
        # no group demands a positive number of facilities.
        # NOTE(review): with zero clients and a positive alpha this shortcut
        # does not try to open facilities; it reports "no solution found".
        if n_clients == 0 and not np.any(np.asarray(alpha) > 0):
            return 0.0, tuple(), np.empty((0,), dtype=int), True
        return _infeasible()

    client_idx = np.array([id_to_idx[cid] for cid in client_ids], dtype=int)
    facility_idx = np.array([id_to_idx[fid] for fid in facility_ids], dtype=int)

    # Distances from each client to each facility
    D_cf = dist_matrix[np.ix_(client_idx, facility_idx)]  # (n_clients, n_fac)

    # Facility groups and capacities
    facility_groups_dict = get_facility_groups(root)       # {fid: tuple length t}
    facility_caps_dict = get_facility_capacities(root)     # {fid: int}

    t = len(alpha)
    alpha = np.asarray(alpha, dtype=int)
    beta = np.asarray(beta, dtype=int)

    # G[j, g]: contribution of facility j to group g; missing entries count 0.
    G = np.zeros((n_fac, t), dtype=int)
    capacities = np.zeros(n_fac, dtype=int)
    for j, fid in enumerate(facility_ids):
        G[j, :] = np.asarray(facility_groups_dict.get(fid, (0,) * t), dtype=int)
        capacities[j] = int(facility_caps_dict.get(fid, 0))

    if int(capacities.sum()) < n_clients:
        raise ValueError("Instance capacity is infeasible: total capacity < #clients.")

    # Even opening every facility cannot reach some lower bound: the greedy
    # loop below would necessarily run out of candidates, so stop right away.
    if np.any(G.sum(axis=0) < alpha):
        return _infeasible()

    # ----------------------------------------------------------------------
    # Phase 1: greedily add facilities to satisfy alpha and total capacity.
    # ----------------------------------------------------------------------
    chosen: List[int] = []     # indices into facility_ids
    chosen_mask = np.zeros(n_fac, dtype=bool)
    group_counts = np.zeros(t, dtype=int)
    total_cap = 0

    # Simple facility score: average distance to clients (smaller is better)
    avg_dist = np.mean(D_cf, axis=0)  # shape (n_fac,)

    while True:
        unmet = group_counts < alpha
        short_on_capacity = total_cap < n_clients
        if not unmet.any() and not short_on_capacity:
            break

        if len(chosen) >= k:
            # Requirements are still open but the budget of k is exhausted.
            return _infeasible()

        # A candidate must contribute to an unmet group or to missing capacity
        helps = (G[:, unmet] > 0).any(axis=1)
        if short_on_capacity:
            helps = helps | (capacities > 0)
        # ... and must not push any group above its upper bound.
        within_beta = ~((group_counts + G) > beta).any(axis=1)
        candidates = np.flatnonzero(~chosen_mask & helps & within_beta)

        # Strict '<' keeps the lowest index among equally scored candidates.
        best_idx = None
        best_score = float("inf")
        for j in candidates:
            if avg_dist[j] < best_score:
                best_score = avg_dist[j]
                best_idx = int(j)

        if best_idx is None:
            # No facility can be added without violating k or beta.
            return _infeasible()

        chosen.append(best_idx)
        chosen_mask[best_idx] = True
        group_counts += G[best_idx, :]
        total_cap += int(capacities[best_idx])

    chosen.sort()
    chosen_facilities = tuple(facility_ids[j] for j in chosen)
    caps_sub = capacities[chosen]
    D_sub = D_cf[:, chosen]  # (n_clients, len(chosen))

    # ----------------------------------------------------------------------
    # Final capacitated assignment via min-cost flow
    # ----------------------------------------------------------------------
    feasible, total_cost, assignment = _min_cost_flow_assignment_v2(D_sub, caps_sub)
    if not feasible:
        # Not expected once the opened capacity covers all clients, but be safe.
        return _infeasible()

    return float(total_cost), chosen_facilities, assignment, True

from collections import defaultdict

def _prune_z_dominance(zdict: Dict[int, float]) -> Dict[int, float]:
    """
    For a fixed group vector r, given states {z -> cost}, keep only those
    that are non-dominated in the (|z|, cost) plane.

    Dominance: (z1, c1) dominates (z2, c2) if |z1| <= |z2| and c1 <= c2.
    """
    if not zdict:
        return zdict

    items = sorted(zdict.items(), key=lambda x: (abs(x[0]), x[1]))
    pruned: Dict[int, float] = {}
    best_cost = float("inf")

    for z, c in items:
        # If we have previously seen some z' with |z'| <= |z| and cost <= c,
        # then (z, c) is dominated.
        if c >= best_cost:
            continue
        pruned[z] = c
        best_cost = c

    return pruned


def binary_tree_dp_pruned(
    root,
    n: int,
    t: int,
    k: int,
    alpha,
    beta,
    use_z_cap: bool = True,
    Z_MAX: int = 50,
):
    """
    Pruned variant of the tree DP for capacitated fair-range k-median.

    A DP state at a node is a pair (rvec, z): ``rvec`` is the vector of group
    counts of the facilities opened in the subtree (length t+1), and ``z`` is
    the signed flow on the edge to the parent. A client leaf produces z = +1
    (one unit of demand going upward); an opened facility leaf produces
    z = -c for every c in 0..min(capacity, #clients). Each edge adds
    ``|z| * parent_distance`` to the cost.

    - "At most k centers" is encoded as an extra group (index t) whose upper
      bound is k and to which every facility contributes 1.
    - States are dropped when a group count exceeds its upper bound, or when a
      lower bound can no longer be reached even with all facilities outside
      the subtree.
    - States with |z| above a bound are dropped. The bound is ``Z_MAX`` when
      ``use_z_cap`` is True and the total number of clients otherwise.
    - With ``use_z_cap``, a combined flow whose magnitude exceeds ``Z_MAX`` is
      clamped to +/-Z_MAX (heuristic: the result is then no longer exact).
    - z-dominance pruning is applied per group vector.

    This DP is used as an *improvement* heuristic; it may fail to find any
    feasible root state (cost = +inf).

    NOTE(review): only leaves create client/facility states; an internal
    node's own ``is_client``/``capacity`` is counted in the subtree statistics
    but not in the DP itself. Presumably clients and facilities are always
    leaves - confirm against the tree builder.
    NOTE(review): a facility leaf with capacity 0 can still be "opened" by the
    DP although the subtree statistics only count facilities with positive
    capacity - confirm this is intended.

    Parameters
    ----------
    root
        Tree root; nodes are read through ``node_number``, ``left``,
        ``right``, ``parent`` and the optional attributes ``is_leaf``,
        ``is_client``, ``capacity``, ``facility_type_vector`` and
        ``parent_distance``.
    n : int
        Not used by this function; kept for API compatibility.
    t : int
        Number of fairness groups.
    k : int
        Maximum number of opened facilities.
    alpha, beta : sequences of length t
        Per-group lower and upper bounds.
    use_z_cap : bool
        Enable the |z| <= Z_MAX cap/clamp described above.
    Z_MAX : int
        Cap on |z| when ``use_z_cap`` is True.

    Returns
    -------
    best_cost : float
        Best DP cost found (may be +inf if no feasible state survives).
    best_rvec_ext : Optional[Tuple[int, ...]]
        The corresponding group-count vector (length t+1), or None.
    iters : int
        Rough count of inner-loop iterations (for diagnostics).

    Raises
    ------
    ValueError
        If a node's ``facility_type_vector`` does not have length ``t``.
    """
    if root is None:
        return 0.0, (0,) * (t + 1), 0

    # ----------------------------------------------------------------------
    # Extend groups: extra "k-group" that counts total centers.
    # ----------------------------------------------------------------------
    t_orig = t
    t_ext = t_orig + 1

    alpha_ext = tuple(alpha) + (0,)   # no lower bound on total centers
    beta_ext = tuple(beta) + (k,)     # at most k centers globally

    inf = float("inf")

    # ----------------------------------------------------------------------
    # Build the (node_number -> node) map; the first node seen for a number
    # wins.
    # ----------------------------------------------------------------------
    id2node = {}
    stack = [root]
    while stack:
        node = stack.pop()
        if node is None or node.node_number in id2node:
            continue
        id2node[node.node_number] = node
        if node.left is not None:
            stack.append(node.left)
        if node.right is not None:
            stack.append(node.right)

    def extended_vector(node):
        """Return the node's group vector plus the k-group entry, or None."""
        base_vec = getattr(node, "facility_type_vector", None)
        if base_vec is None:
            return None
        if len(base_vec) != t_orig:
            raise ValueError(
                f"facility_type_vector length {len(base_vec)} "
                f"!= t={t_orig} at node {node.node_number}"
            )
        return tuple(base_vec) + (1,)  # extra group counts total centers

    # ----------------------------------------------------------------------
    # Subtree statistics: #clients and max possible group counts per subtree.
    # ----------------------------------------------------------------------
    subtree_groups: Dict[int, List[int]] = {}

    def dfs_subtree(u_id: int):
        node = id2node[u_id]
        is_client = getattr(node, "is_client", False)
        clients = 1 if is_client else 0
        groups = [0] * t_ext

        for child in (node.left, node.right):
            if child is None:
                continue
            c_sub, g_sub = dfs_subtree(child.node_number)
            clients += c_sub
            for j in range(t_ext):
                groups[j] += g_sub[j]

        # Facilities contribute their group vector; clients do not.
        if not is_client and getattr(node, "capacity", 0) > 0:
            ext_vec = extended_vector(node)
            if ext_vec is not None:
                for j in range(t_ext):
                    groups[j] += int(ext_vec[j])

        subtree_groups[u_id] = groups
        return clients, groups

    total_clients, total_groups = dfs_subtree(root.node_number)
    total_groups = list(total_groups)

    # outside_groups[u][g] = facilities of group g that can still be opened
    # outside the subtree rooted at u.
    outside_groups: Dict[int, List[int]] = {
        u_id: [max(0, total_groups[g] - sub[g]) for g in range(t_ext)]
        for u_id, sub in subtree_groups.items()
    }

    # Bound on |z| used when filtering states.
    cap_bound = Z_MAX if use_z_cap else total_clients

    iters = 0

    def prune_states(node, states: Dict[Tuple[int, ...], Dict[int, float]]):
        """
        Logical pruning + z-dominance; no beam truncation.
        """
        nonlocal iters
        outside = outside_groups[node.node_number]

        pruned: Dict[Tuple[int, ...], Dict[int, float]] = defaultdict(dict)

        for rvec, zdict in states.items():
            # 1) group upper bounds (including extra k-group)
            if any(rvec[g] > beta_ext[g] for g in range(t_ext)):
                continue

            # 2) lower bounds must stay reachable with the facilities that
            #    lie outside this subtree
            if any(rvec[g] + outside[g] < alpha_ext[g] for g in range(t_orig)):
                continue

            # 3) keep finite-cost states whose |z| is within the bound
            local_zdict: Dict[int, float] = {}
            for z, cost in zdict.items():
                if abs(z) > cap_bound:
                    continue
                iters += 1
                if cost < inf:
                    local_zdict[z] = cost

            if not local_zdict:
                continue

            # 4) z-dominance pruning for this rvec
            pruned[rvec] = _prune_z_dominance(local_zdict)

        return pruned

    def dp_edge(node):
        """
        DP on edge (parent, node).
        Returns: Dict[rvec -> Dict[z -> cost]]
        """
        nonlocal iters
        if node is None:
            return defaultdict(dict)

        # Edge length to parent (the root has no parent edge)
        if node.parent is None:
            parent_dist = 0.0
        else:
            parent_dist = float(getattr(node, "parent_distance", 0.0))

        states: Dict[Tuple[int, ...], Dict[int, float]] = defaultdict(dict)

        # --------------------------------------------------------------
        # Leaf node
        # --------------------------------------------------------------
        if getattr(node, "is_leaf", False):
            zero_rvec = (0,) * t_ext

            if getattr(node, "is_client", False):
                # One unit of demand goes upward
                states[zero_rvec][1] = parent_dist
                return prune_states(node, states)

            # Facility leaf
            # Option 1: not opened
            states[zero_rvec][0] = 0.0

            # Option 2: opened, serving c units through the parent edge
            ext_vec = extended_vector(node)
            if ext_vec is not None:
                cap = int(getattr(node, "capacity", 0))
                cmax = min(cap, total_clients)
                for c in range(0, cmax + 1):
                    z = -c
                    cost = c * parent_dist
                    iters += 1
                    cur = states[ext_vec].get(z, inf)
                    if cost < cur:
                        states[ext_vec][z] = cost

            return prune_states(node, states)

        # --------------------------------------------------------------
        # Internal node: combine children
        # --------------------------------------------------------------
        left_states = dp_edge(node.left) if node.left is not None else None
        right_states = dp_edge(node.right) if node.right is not None else None

        if left_states is None and right_states is None:
            return states  # should not happen

        if left_states is None or right_states is None:
            # Single child: forward its states across this node's parent edge.
            only_child = right_states if left_states is None else left_states
            for rvec, zdict in only_child.items():
                bucket = states[rvec]
                for z, cost in zdict.items():
                    # Optional z cap
                    if use_z_cap and abs(z) > Z_MAX:
                        z = Z_MAX if z > 0 else -Z_MAX
                    new_cost = cost + abs(z) * parent_dist
                    iters += 1
                    if new_cost < bucket.get(z, inf):
                        bucket[z] = new_cost
            return prune_states(node, states)

        # Two children: iterate the smaller table in the outer loop
        LS, RS = left_states, right_states
        size_L = sum(len(zdict) for zdict in LS.values())
        size_R = sum(len(zdict) for zdict in RS.values())
        if size_L > size_R:
            LS, RS = RS, LS

        for l_rvec, l_zdict in LS.items():
            for r_rvec, r_zdict in RS.items():
                # Combine group counts
                new_r = [l_rvec[g] + r_rvec[g] for g in range(t_ext)]
                # Early upper-bound pruning
                if any(new_r[g] > beta_ext[g] for g in range(t_ext)):
                    continue
                bucket = states[tuple(new_r)]

                for zl, cl in l_zdict.items():
                    for zr, cr in r_zdict.items():
                        new_z = zl + zr
                        if use_z_cap and abs(new_z) > Z_MAX:
                            new_z = Z_MAX if new_z > 0 else -Z_MAX
                        new_cost = cl + cr + abs(new_z) * parent_dist
                        iters += 1
                        if new_cost < bucket.get(new_z, inf):
                            bucket[new_z] = new_cost

        # Optionally free child tables for memory (not strictly necessary in Python)
        left_states.clear()
        right_states.clear()

        return prune_states(node, states)

    # ----------------------------------------------------------------------
    # Run DP from root
    # ----------------------------------------------------------------------
    root_states = dp_edge(root)

    best_cost = inf
    best_rvec_ext = None

    for rvec_ext, zdict in root_states.items():
        # Require z = 0 at virtual root edge
        if 0 not in zdict:
            continue

        # Safety check: fair-range for original groups
        if any(
            rvec_ext[g] < alpha_ext[g] or rvec_ext[g] > beta_ext[g]
            for g in range(t_orig)
        ):
            continue

        # Safety: at most k centers via extra group
        if rvec_ext[t_orig] > k:
            continue

        cost = zdict[0]
        if cost < best_cost:
            best_cost = cost
            best_rvec_ext = rvec_ext

    return best_cost, best_rvec_ext, iters

def solve_tree_kmedian_fast(
    root,
    n: int,
    t: int,
    k: int,
    alpha,
    beta,
    use_dp: bool = True,
    use_z_cap: bool = True,
    Z_MAX: int = 50,
):
    """
    Run the greedy baseline and, optionally, the pruned DP; report both.

    The clustering that is returned is always the greedy baseline. The DP is
    executed only for its cost and diagnostics: it does not change the
    returned centers or assignment.

    Parameters
    ----------
    root : BinTreeNode
        Root of the binary tree.
    n : int
        Number of clients (forwarded to the DP).
    t : int
        Number of fairness groups.
    k : int
        Maximum number of facilities to open.
    alpha, beta : sequences of int, length t
        Fair-range lower and upper bounds.
    use_dp : bool
        Whether to run the pruned DP as well.
    use_z_cap : bool
        Forwarded to the DP: enables z-capping.
    Z_MAX : int
        Forwarded to the DP: maximum |z| when z-capping is enabled.

    Returns
    -------
    solution : dict
        {
          "cost": float,                 # cost of the returned (baseline) clustering
          "centers": Tuple[int, ...],    # facility node IDs
          "assignment": np.ndarray,      # as returned by the baseline
          "baseline_cost": float,        # same value as "cost"
          "dp_cost": float,              # +inf if the DP was skipped or found nothing
          "dp_rvec_ext": Optional[Tuple[int, ...]],
          "dp_iters": int,
          "used": str,                   # "baseline" or "baseline+dp"
          "feasible": bool,              # feasibility flag of the baseline
        }
    """
    # Step 1: the greedy baseline supplies the concrete clustering.
    greedy_cost, greedy_centers, greedy_assignment, feasible = (
        baseline_greedy_solution(root, k, alpha, beta)
    )

    # Step 2: the DP runs after the baseline and only when requested;
    # otherwise its outputs keep their "not run" placeholder values.
    if use_dp:
        used = "baseline+dp"
        dp_cost, dp_rvec_ext, dp_iters = binary_tree_dp_pruned(
            root, n, t, k, alpha, beta, use_z_cap=use_z_cap, Z_MAX=Z_MAX
        )
    else:
        used = "baseline"
        dp_cost, dp_rvec_ext, dp_iters = float("inf"), None, 0

    # Step 3: assemble the report around the baseline clustering.
    reported_cost = float(greedy_cost)
    solution = {}
    solution["cost"] = reported_cost
    solution["centers"] = greedy_centers
    solution["assignment"] = greedy_assignment
    solution["baseline_cost"] = reported_cost
    solution["dp_cost"] = float(dp_cost)
    solution["dp_rvec_ext"] = dp_rvec_ext
    solution["dp_iters"] = dp_iters
    solution["used"] = used
    solution["feasible"] = feasible
    return solution

def test_stub():
    """Build one seeded random instance, solve it, and print a short report."""
    n, t, k = 31, 3, 4
    seed = 86087334
    lower_bounds = [1 for _ in range(t)]   # every group needs one facility
    upper_bounds = [k for _ in range(t)]   # no group may exceed the budget k

    root = build_binary_tree_with_facilities(
        n,
        t,
        facility_probability=0.5,
        max_facility_types=2,
        seed=seed,
    )

    result = solve_tree_kmedian_fast(
        root,
        n,
        t,
        k,
        lower_bounds,
        upper_bounds,
        use_dp=True,
        use_z_cap=True,
        Z_MAX=50,
    )

    # (label, result key) pairs, printed in this order.
    report = (
        ("Feasible:", "feasible"),
        ("Returned cost:", "cost"),
        ("Baseline cost:", "baseline_cost"),
        ("DP best cost (possibly lower bound):", "dp_cost"),
        ("Opened centers:", "centers"),
    )
    for label, key in report:
        print(label, result[key])

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    test_stub()
