import numpy as np
from random import Random
from itertools import combinations
from typing import Dict, Tuple, List, Optional
from sys import stdout
import time

from binary_tree import binary_tree_pairwise_distances
from binary_tree import get_client_and_facility_node_ids
from binary_tree import build_binary_tree_with_facilities
from binary_tree import print_binary_tree
from binary_tree import get_facility_groups
from binary_tree import get_facility_capacities
from binary_tree import get_binary_tree_stats
from itertools import product


def _min_cost_flow_assignment(
    D_sub: np.ndarray,
    capacities: np.ndarray,
) -> Tuple[bool, float]:
    """
    Min-cost max-flow assignment for capacitated clustering (cost only).

    Graph:
        source -> clients -> facilities -> sink

    Every client has demand 1; each facility f has capacity capacities[f].
    Edge cost from client j to facility f is D_sub[j, f] (assumed >= 0; the
    first shortest-path round runs Dijkstra on the raw costs).

    Parameters
    ----------
    D_sub : (n_clients, k) array-like
        Distances from each client to each chosen facility in the combination.
    capacities : (k,) array-like
        Capacities of each chosen facility. Each entry is truncated with
        ``int()``; entries <= 0 mean the facility cannot serve any client.

    Returns
    -------
    feasible : bool
        True if all clients can be assigned to facilities.
    total_cost : float
        Minimum total cost if feasible; +inf otherwise.
    """
    import heapq

    D_sub = np.asarray(D_sub, dtype=float)
    capacities = np.asarray(capacities)

    n_clients, k = D_sub.shape
    demand = n_clients

    # Capacity exactly as the facility->sink edges below will see it.
    # Summing the raw vector instead would let a negative entry cancel real
    # capacity and report a feasible instance as infeasible.
    usable_caps = [max(int(capacities[f]), 0) for f in range(k)]
    if sum(usable_caps) < demand:
        return False, float("inf")

    # Node indexing:
    #   s = 0
    #   clients: 1 .. n_clients
    #   facilities: n_clients+1 .. n_clients+k
    #   t = n_clients + k + 1
    s = 0
    client_offset = 1
    facility_offset = 1 + n_clients
    t = 1 + n_clients + k
    N = t + 1

    # Residual graph as adjacency lists of edge dicts with keys
    # "to", "rev" (index of the paired edge in g[to]), "cap" and "cost".
    g: List[List[Dict]] = [[] for _ in range(N)]

    def add_edge(u: int, v: int, cap: int, cost: float) -> None:
        # Forward edge plus its zero-capacity reverse edge with negated cost.
        g[u].append({"to": v, "rev": len(g[v]), "cap": cap, "cost": cost})
        g[v].append({"to": u, "rev": len(g[u]) - 1, "cap": 0, "cost": -cost})

    # source -> clients (capacity 1)
    for j in range(n_clients):
        add_edge(s, client_offset + j, 1, 0.0)

    # clients -> facilities
    for j in range(n_clients):
        for f in range(k):
            add_edge(client_offset + j, facility_offset + f, 1, float(D_sub[j, f]))

    # facilities -> sink (closed facilities get no edge at all)
    for f, cap in enumerate(usable_caps):
        if cap > 0:
            add_edge(facility_offset + f, t, cap, 0.0)

    # successive shortest augmenting path with Johnson potentials
    flow = 0
    total_cost = 0.0
    INF = float("inf")
    potential = [0.0] * N  # Johnson's potentials

    while flow < demand:
        dist = [INF] * N
        prev_v = [-1] * N  # predecessor node on the shortest path
        prev_e = [-1] * N  # index of the edge used inside g[prev_v[...]]
        dist[s] = 0.0

        # Dijkstra on reduced costs
        pq = [(0.0, s)]
        while pq:
            d, v = heapq.heappop(pq)
            if d > dist[v] + 1e-15:
                continue  # stale heap entry
            for ei, e in enumerate(g[v]):
                if e["cap"] <= 0:
                    continue
                w = e["to"]
                nd = d + e["cost"] + potential[v] - potential[w]
                if nd < dist[w] - 1e-15:
                    dist[w] = nd
                    prev_v[w] = v
                    prev_e[w] = ei
                    heapq.heappush(pq, (nd, w))

        if dist[t] == INF:
            # cannot push more flow
            break

        # Update potentials so reduced costs stay usable for the next round
        for v in range(N):
            if dist[v] < INF:
                potential[v] += dist[v]

        # Bottleneck along the path (always 1 here: source->client edges have cap=1)
        add_flow = demand - flow
        v = t
        while v != s:
            e = g[prev_v[v]][prev_e[v]]
            add_flow = min(add_flow, e["cap"])
            v = prev_v[v]
        if add_flow <= 0:
            break

        # Push the flow and add up the original (non-reduced) edge costs
        v = t
        path_cost = 0.0
        while v != s:
            e = g[prev_v[v]][prev_e[v]]
            e["cap"] -= add_flow
            rev = g[v][e["rev"]]
            rev["cap"] += add_flow
            path_cost += e["cost"]
            v = prev_v[v]

        flow += add_flow
        total_cost += path_cost * add_flow

    feasible = (flow == demand)
    if not feasible:
        return False, float("inf")
    return True, total_cost

def _min_cost_flow_assignment_v2_old(
    D_sub: np.ndarray,
    capacities: np.ndarray,
) -> Tuple[bool, float, Optional[np.ndarray]]:
    """
    Min-cost max-flow assignment for capacitated clustering.

    Graph:
        source -> clients -> facilities -> sink

    Every client has demand 1; each facility f has capacity capacities[f].
    Edge cost from client j to facility f is D_sub[j, f] (assumed >= 0; the
    first shortest-path round runs Dijkstra on the raw costs).

    Parameters
    ----------
    D_sub : (n_clients, k) array-like
        Distances from each client to each chosen facility in the combination.
    capacities : (k,) array-like
        Capacities of each chosen facility. Each entry is truncated with
        ``int()``; entries <= 0 mean the facility cannot serve any client.

    Returns
    -------
    feasible : bool
        True if all clients can be assigned to facilities.
    total_cost : float
        Minimum total cost if feasible; +inf otherwise.
    assignment : np.ndarray or None
        If feasible, an array of shape (n_clients,) where assignment[j] is
        the *local* index of the facility in 0..k-1 to which client j is assigned.
        None if infeasible.
    """
    import heapq

    D_sub = np.asarray(D_sub, dtype=float)
    capacities = np.asarray(capacities)

    n_clients, k = D_sub.shape
    demand = n_clients

    # Capacity exactly as the facility->sink edges below will see it.
    # Summing the raw vector instead would let a negative entry cancel real
    # capacity and report a feasible instance as infeasible.
    usable_caps = [max(int(capacities[f]), 0) for f in range(k)]
    if sum(usable_caps) < demand:
        return False, float("inf"), None

    # Node indexing:
    #   s = 0
    #   clients: 1 .. n_clients
    #   facilities: n_clients+1 .. n_clients+k
    #   t = n_clients + k + 1
    s = 0
    client_offset = 1
    facility_offset = 1 + n_clients
    t = 1 + n_clients + k
    N = t + 1

    # Residual graph as adjacency lists of edge dicts with keys
    # "to", "rev" (index of the paired edge in g[to]), "cap" and "cost".
    g: List[List[Dict]] = [[] for _ in range(N)]

    def add_edge(u: int, v: int, cap: int, cost: float) -> None:
        # Forward edge plus its zero-capacity reverse edge with negated cost.
        g[u].append({"to": v, "rev": len(g[v]), "cap": cap, "cost": cost})
        g[v].append({"to": u, "rev": len(g[u]) - 1, "cap": 0, "cost": -cost})

    # source -> clients (capacity 1)
    for j in range(n_clients):
        add_edge(s, client_offset + j, 1, 0.0)

    # clients -> facilities
    for j in range(n_clients):
        for f in range(k):
            add_edge(client_offset + j, facility_offset + f, 1, float(D_sub[j, f]))

    # facilities -> sink (closed facilities get no edge at all)
    for f, cap in enumerate(usable_caps):
        if cap > 0:
            add_edge(facility_offset + f, t, cap, 0.0)

    # successive shortest augmenting path with Johnson potentials
    flow = 0
    total_cost = 0.0
    INF = float("inf")
    potential = [0.0] * N  # Johnson's potentials

    while flow < demand:
        dist = [INF] * N
        prev_v = [-1] * N  # predecessor node on the shortest path
        prev_e = [-1] * N  # index of the edge used inside g[prev_v[...]]
        dist[s] = 0.0

        # Dijkstra on reduced costs
        pq = [(0.0, s)]
        while pq:
            d, v = heapq.heappop(pq)
            if d > dist[v] + 1e-15:
                continue  # stale heap entry
            for ei, e in enumerate(g[v]):
                if e["cap"] <= 0:
                    continue
                w = e["to"]
                nd = d + e["cost"] + potential[v] - potential[w]
                if nd < dist[w] - 1e-15:
                    dist[w] = nd
                    prev_v[w] = v
                    prev_e[w] = ei
                    heapq.heappush(pq, (nd, w))

        if dist[t] == INF:
            # cannot push more flow
            break

        # Update potentials so reduced costs stay usable for the next round
        for v in range(N):
            if dist[v] < INF:
                potential[v] += dist[v]

        # Bottleneck along the path (always 1 here: source->client edges have cap=1)
        add_flow = demand - flow
        v = t
        while v != s:
            e = g[prev_v[v]][prev_e[v]]
            add_flow = min(add_flow, e["cap"])
            v = prev_v[v]
        if add_flow <= 0:
            break

        # Push the flow and add up the original (non-reduced) edge costs
        v = t
        path_cost = 0.0
        while v != s:
            e = g[prev_v[v]][prev_e[v]]
            e["cap"] -= add_flow
            rev = g[v][e["rev"]]
            rev["cap"] += add_flow
            path_cost += e["cost"]
            v = prev_v[v]

        flow += add_flow
        total_cost += path_cost * add_flow

    feasible = (flow == demand)
    if not feasible:
        return False, float("inf"), None

    # Recover assignment: for each client node, find which facility edge is saturated.
    assignment = np.full(n_clients, -1, dtype=int)
    for j in range(n_clients):
        u = client_offset + j
        for e in g[u]:
            to = e["to"]
            # Original client->facility edges had cap=1; after flow, used edges have cap=0.
            # The range test skips the reverse edge back to the source.
            if facility_offset <= to < facility_offset + k and e["cap"] == 0:
                assignment[j] = to - facility_offset
                break

    if np.any(assignment < 0):
        # Shouldn't happen if flow == demand, but be safe
        return False, float("inf"), None

    return True, total_cost, assignment

def _min_cost_flow_assignment_v2(
    D_sub: np.ndarray,
    capacities: np.ndarray,
) -> Tuple[bool, float, Optional[np.ndarray]]:
    """
    Solve the minimum-cost assignment of clients to facilities with capacities
    using a standard min-cost max-flow algorithm.

    Network:
        source -> clients -> facilities -> sink

    Each client j has unit demand. Each facility f has capacity capacities[f].
    The cost of assigning client j to facility f is D_sub[j, f] (assumed >= 0).

    Parameters
    ----------
    D_sub : np.ndarray
        2D array of shape (n_clients, k_facilities) with nonnegative costs.
    capacities : np.ndarray
        1D array of shape (k_facilities,) with integer capacities. Entries
        <= 0 mean the facility cannot serve any client.

    Returns
    -------
    feasible : bool
        True iff it is possible to assign every client to some facility without
        exceeding capacities.
    total_cost : float
        The minimum possible total assignment cost if feasible, otherwise inf.
    assignment : Optional[np.ndarray]
        An array of length n_clients where assignment[j] is the index (0..k-1)
        of the facility serving client j (indexing columns of D_sub).
        None if infeasible.
    """
    import heapq

    D_sub = np.asarray(D_sub, dtype=float)
    capacities = np.asarray(capacities, dtype=int)

    n_clients, k = D_sub.shape
    demand = n_clients

    # Fail fast: if the capacity that will actually reach the sink cannot
    # cover every client, no flow can be feasible, so skip building the graph.
    # Non-positive entries are counted as 0, matching the edge construction.
    usable_caps = [max(int(capacities[f]), 0) for f in range(k)]
    if sum(usable_caps) < demand:
        return False, float("inf"), None

    # Node indexing:
    #   0                          : source
    #   1 .. n_clients             : clients
    #   n_clients+1 .. n_clients+k : facilities
    #   n_clients+k+1              : sink
    N = 1 + n_clients + k + 1
    s = 0
    client_offset = 1
    facility_offset = 1 + n_clients
    t = N - 1

    # Residual graph as adjacency lists of edge dicts with keys
    # "to", "rev" (index of the paired edge in g[to]), "cap" and "cost".
    g: List[List[Dict]] = [[] for _ in range(N)]

    def add_edge(fr: int, to: int, cap: int, cost: float) -> None:
        """Add directed edge fr->to with given capacity and cost."""
        g[fr].append({"to": to, "rev": len(g[to]), "cap": cap, "cost": float(cost)})
        g[to].append({"to": fr, "rev": len(g[fr]) - 1, "cap": 0, "cost": -float(cost)})

    # source -> clients (capacity 1, zero cost)
    for j in range(n_clients):
        add_edge(s, client_offset + j, 1, 0.0)

    # clients -> facilities (capacity 1, cost = distance)
    for j in range(n_clients):
        for f in range(k):
            add_edge(client_offset + j, facility_offset + f, 1, float(D_sub[j, f]))

    # facilities -> sink (capacity = facility capacity, zero cost)
    for f, cap in enumerate(usable_caps):
        if cap > 0:
            add_edge(facility_offset + f, t, cap, 0.0)

    # Successive shortest augmenting path with Johnson potentials.
    INF = float("inf")
    flow = 0
    total_cost = 0.0
    potential = [0.0] * N

    while flow < demand:
        dist = [INF] * N
        prev_v = [-1] * N  # predecessor node on the shortest path
        prev_e = [-1] * N  # index of the edge used inside g[prev_v[...]]
        dist[s] = 0.0

        # Dijkstra on reduced costs
        pq: List[Tuple[float, int]] = [(0.0, s)]
        while pq:
            d, v = heapq.heappop(pq)
            if d > dist[v]:
                continue  # stale heap entry
            for ei, e in enumerate(g[v]):
                if e["cap"] <= 0:
                    continue
                to = e["to"]
                nd = d + e["cost"] + potential[v] - potential[to]
                if nd < dist[to]:
                    dist[to] = nd
                    prev_v[to] = v
                    prev_e[to] = ei
                    heapq.heappush(pq, (nd, to))

        if dist[t] == INF:
            # Cannot send more flow; assignment infeasible.
            break

        # Update potentials so reduced costs stay usable for the next round
        for v in range(N):
            if dist[v] < INF:
                potential[v] += dist[v]

        # Find bottleneck on path (here it will be 1, but keep generic)
        add_flow = demand - flow
        v = t
        while v != s:
            pv = prev_v[v]
            pe = prev_e[v]
            if pv < 0 or pe < 0:
                # Broken predecessor chain: treat as "no path".
                add_flow = 0
                break
            e = g[pv][pe]
            add_flow = min(add_flow, e["cap"])
            v = pv

        if add_flow == 0:
            break

        # Apply flow and accumulate original costs (without potentials)
        v = t
        path_cost = 0.0
        while v != s:
            pv = prev_v[v]
            pe = prev_e[v]
            e = g[pv][pe]
            e["cap"] -= add_flow
            rev = g[v][e["rev"]]
            rev["cap"] += add_flow
            path_cost += e["cost"]
            v = pv

        flow += add_flow
        total_cost += path_cost * add_flow

    feasible = (flow == demand)
    if not feasible:
        return False, float("inf"), None

    # Recover assignment from residual graph: the saturated client->facility
    # edge (cap dropped from 1 to 0) is the one carrying the client's unit.
    assignment = np.full(n_clients, -1, dtype=int)
    for j in range(n_clients):
        u = client_offset + j
        for e in g[u]:
            to = e["to"]
            # The range test skips the reverse edge back to the source.
            if facility_offset <= to < facility_offset + k and e["cap"] == 0:
                assignment[j] = to - facility_offset
                break

    if np.any(assignment < 0):
        return False, float("inf"), None

    return True, total_cost, assignment

def binary_tree_brute_force_old(root, k, lower_bounds, upper_bounds):
    """
    Exhaustive capacitated k-median with group constraints on a binary tree.

    Every subset of exactly ``k`` facilities is tried. A subset is evaluated
    only if its per-group facility counts lie within the bounds and its total
    capacity covers all clients; the clients are then assigned by
    ``_min_cost_flow_assignment_v2`` and the cheapest subset is kept.

    Parameters
    ----------
    root : BinTreeNode
        Root of the binary tree.
    k : int
        Number of facilities to select (exactly).
    lower_bounds : sequence of int
        Lower bounds for each group.
    upper_bounds : sequence of int
        Upper bounds for each group.

    Returns
    -------
    best_cost : float
        Total assignment cost of the best subset; +inf if none is feasible.
    best_combination : tuple of int or None
        Node IDs of the chosen facilities; None if none is feasible.
    best_assignment : np.ndarray or None
        Per-client index into ``best_combination``; None if none is feasible.
    """
    # Metric over all tree nodes
    all_ids, pairwise = binary_tree_pairwise_distances(root)
    pairwise = np.asarray(pairwise, dtype=float)

    # Node roles plus per-facility metadata
    client_ids, facility_ids = get_client_and_facility_node_ids(root)
    groups_by_facility = get_facility_groups(root)
    caps_by_facility = get_facility_capacities(root)

    # Position of every node ID inside the distance matrix
    row_of = {}
    for pos, nid in enumerate(all_ids):
        row_of[nid] = pos

    client_rows = np.array([row_of[cid] for cid in client_ids], dtype=int)
    facility_rows = np.array([row_of[fid] for fid in facility_ids], dtype=int)

    # client x facility distance table
    client_to_facility = pairwise[np.ix_(client_rows, facility_rows)]

    # Number of groups is read off the first group vector (0 when there is none)
    n_groups = len(next(iter(groups_by_facility.values()))) if groups_by_facility else 0

    n_facilities = len(facility_ids)
    membership = np.zeros((n_facilities, n_groups), dtype=int)
    no_groups = (0,) * n_groups
    for row, fid in enumerate(facility_ids):
        membership[row, :] = np.asarray(groups_by_facility.get(fid, no_groups), dtype=int)

    # Capacities in the same order as facility_ids; missing entries count as 0
    all_caps = np.array([caps_by_facility.get(fid, 0) for fid in facility_ids], dtype=int)

    lo = np.asarray(lower_bounds, dtype=int)
    hi = np.asarray(upper_bounds, dtype=int)
    n_clients = len(client_ids)

    incumbent_cost = float("inf")
    incumbent_centers = None
    incumbent_assignment = None

    # Subsets are enumerated over facility *positions*, not node IDs
    for subset in combinations(range(n_facilities), k):
        picked = np.array(subset, dtype=int)

        # Group-bound filter
        per_group = membership[picked, :].sum(axis=0)
        if not (np.all(per_group >= lo) and np.all(per_group <= hi)):
            continue

        # Cheap capacity filter before running the flow solver
        picked_caps = all_caps[picked]
        if int(picked_caps.sum()) < n_clients:
            continue

        feasible, cost, assignment = _min_cost_flow_assignment_v2(
            client_to_facility[:, picked], picked_caps
        )
        if feasible and cost < incumbent_cost:
            incumbent_cost = cost
            incumbent_centers = tuple(facility_ids[i] for i in subset)
            incumbent_assignment = assignment

    return incumbent_cost, incumbent_centers, incumbent_assignment

def binary_tree_brute_force(
    root,
    k: int,
    lower_bounds,
    upper_bounds,
):
    """
    Exhaustive solver for capacitated fair-range k-median on a tree metric.

    All facility subsets of size 1..min(k, number of facilities) are tried.
    A subset is scored only when its per-group counts respect the bounds and
    its total capacity covers every client; scoring is done by
    ``_min_cost_flow_assignment_v2``.

    Parameters
    ----------
    root : BinTreeNode
        Root of the binary tree.
    k : int
        Maximum number of facilities allowed to be selected.
    lower_bounds : sequence of int
        Lower bounds for each group.
    upper_bounds : sequence of int
        Upper bounds for each group.

    Returns
    -------
    best_cost : float
        Optimal clustering cost; +inf when no subset is feasible.
    best_centers : Tuple[int, ...]
        Node numbers of the facilities that are opened (None if infeasible).
    best_assignment : np.ndarray
        Array of length n_clients; entry j is the index (0..len(best_centers)-1)
        of the facility in ``best_centers`` serving client j (None if infeasible).
    """
    # Distances between every pair of tree nodes
    ids_in_matrix, full_dist = binary_tree_pairwise_distances(root)
    full_dist = np.asarray(full_dist, dtype=float)

    # Which nodes are clients / facilities, and what each facility carries
    client_ids, facility_ids = get_client_and_facility_node_ids(root)
    group_vectors = get_facility_groups(root)
    capacity_of = get_facility_capacities(root)

    # node ID -> row/column in full_dist
    position = dict((nid, pos) for pos, nid in enumerate(ids_in_matrix))

    rows = np.array([position[cid] for cid in client_ids], dtype=int)
    cols = np.array([position[fid] for fid in facility_ids], dtype=int)
    dist_cf = full_dist[np.ix_(rows, cols)]  # (n_clients, n_facilities)

    # Group count is taken from the first stored vector, or 0 with no vectors
    group_total = 0
    if group_vectors:
        group_total = len(next(iter(group_vectors.values())))

    n_fac = len(facility_ids)
    group_table = np.zeros((n_fac, group_total), dtype=int)
    for pos, fid in enumerate(facility_ids):
        default_vec = (0,) * group_total
        group_table[pos, :] = np.asarray(group_vectors.get(fid, default_vec), dtype=int)

    # Facility capacities aligned with facility_ids (0 when unknown)
    cap_vector = np.array(
        [capacity_of.get(fid, 0) for fid in facility_ids],
        dtype=int,
    )

    lower_bounds = np.asarray(lower_bounds, dtype=int)
    upper_bounds = np.asarray(upper_bounds, dtype=int)

    client_total = len(client_ids)
    size_limit = min(k, n_fac)

    # Lazily enumerate subsets of facility positions, smallest sizes first
    candidates = (
        subset
        for size in range(1, size_limit + 1)
        for subset in combinations(range(n_fac), size)
    )

    winner_cost = float("inf")
    winner_centers = None
    winner_assignment = None

    for subset in candidates:
        sel = np.array(subset, dtype=int)

        # Per-group number of selected facilities must stay inside the bounds
        counts = group_table[sel, :].sum(axis=0)
        below = np.any(counts < lower_bounds)
        if below or np.any(counts > upper_bounds):
            continue

        # Selected facilities must be able to hold every client
        sel_caps = cap_vector[sel]
        if int(sel_caps.sum()) < client_total:
            continue

        sub_dist = dist_cf[:, sel]  # (n_clients, len(subset))
        ok, cost, assignment = _min_cost_flow_assignment_v2(sub_dist, sel_caps)
        if not ok:
            continue

        if cost < winner_cost:
            winner_cost = cost
            # Translate positions back to node numbers
            winner_centers = tuple(facility_ids[i] for i in subset)
            winner_assignment = assignment

    return winner_cost, winner_centers, winner_assignment

def sanity_check(init_seed=123456789):
    """
    Run the brute-force solver over a grid of random instances and print a table.

    For every (n, t, k) in {15, 31, 63, 127} x {2, 3} x {2, 3, 4} a tree is
    built with ``build_binary_tree_with_facilities`` (facility probability
    0.5) from a seed drawn from a generator seeded with ``init_seed``.
    Instances with fewer than k facilities are skipped. Each remaining
    instance is solved with ``binary_tree_brute_force`` using lower bound 1
    and upper bound k for every group, and one row with instance statistics,
    build time, solve time and cost is written to stdout.

    Parameters
    ----------
    init_seed : int
        Seed of the generator that produces the per-instance seeds.
    """
    local_random = Random(init_seed)

    nn = [15, 31, 63, 127]
    tt = [2, 3]
    kk = [2, 3, 4]
    facility_probability = 0.5

    header = (
        f"{'input_seed':>10s}{'n':>4s} {'t':>5s} {'k':>5s} "
        f"{'n_c':>4s} {'n_f':>4s} {'cap':>4s} "
        f"{'input_time':>10s} {'time':>10s} {'cost':>10s}"
    )
    rule = "-" * len(header)

    stdout.write(header + "\n")
    stdout.write(rule + "\n")

    for n, t, k in product(nn, tt, kk):
        input_seed = local_random.randint(1, 123456789)

        # perf_counter is monotonic and high-resolution, so the measured
        # durations cannot be distorted by wall-clock adjustments.
        build_start = time.perf_counter()
        root = build_binary_tree_with_facilities(n, t, facility_probability, input_seed)
        input_time = time.perf_counter() - build_start

        tree_stats = get_binary_tree_stats(root)

        # Skip instances that have fewer than k facilities
        if tree_stats['num_facilities'] < k:
            continue

        lower_bounds = [1] * t
        upper_bounds = [k] * t

        solve_start = time.perf_counter()
        cost, _centers, _assignment = binary_tree_brute_force(root, k, lower_bounds, upper_bounds)
        brute_force_time = time.perf_counter() - solve_start

        stdout.write(
            f"{input_seed:10d}"
            f"{n:4d} {t:5d} {k:5d} "
            f"{tree_stats['num_clients']:4d} {tree_stats['num_facilities']:4d} "
            f"{tree_stats['total_capacity']:4d} "
            f"{input_time:10.4f} {brute_force_time:10.4f} "
            f"{cost:10.4f}\n"
        )
        stdout.flush()
    stdout.write(rule + "\n")

def test_stub():
    """
    Solve one fixed small instance and print the result.

    Builds a tree from hard-coded arguments (15, 2, 0.5, seed 86087334),
    prints it, runs ``binary_tree_brute_force`` with k = 3, lower bound 1 and
    upper bound 3 for each of the 2 groups, then prints the chosen
    facilities and the total cost.
    """
    tree_size, group_count, max_centers = 15, 2, 3
    seed = 86087334
    p_facility = 0.5

    tree = build_binary_tree_with_facilities(tree_size, group_count, p_facility, seed)
    print_binary_tree(tree)

    # Same bounds for every group: at least one facility, at most max_centers
    per_group_min = [1] * group_count
    per_group_max = [max_centers] * group_count

    result = binary_tree_brute_force(tree, max_centers, per_group_min, per_group_max)
    total_cost, opened, _assignment = result

    print("Chosen facilities:", opened)
    print("Total cost:", total_cost)


# Script entry point: when executed directly (not imported), run the
# brute-force sanity sweep with its default seed.
if __name__ == "__main__":
    sanity_check()
