import itertools
import math
from typing import Sequence, Dict, Any, Optional, List, Tuple

import numpy as np
import pandas as pd
from data import generate_random_facility_client_df


def _euclidean_distance_matrix_two_sets(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    """
    Pairwise Euclidean distances between the rows of X and the rows of Y.

    X: shape (n_x, d)
    Y: shape (n_y, d)

    Returns: shape (n_x, n_y), where entry (i, j) is ||X[i] - Y[j]||_2.
    """
    # Broadcast to (n_x, n_y, d): one coordinate-difference vector per pair.
    deltas = X[:, np.newaxis, :] - Y[np.newaxis, :, :]
    squared_norms = np.square(deltas).sum(axis=-1)
    return np.sqrt(squared_norms)


class _MCFEdge:
    """
    One directed edge of the residual graph used by the min-cost flow code.

    Attributes (declared via __slots__, so instances carry no __dict__):
      to   - node this edge points at
      rev  - index of the paired reverse edge inside G[to]
      cap  - remaining (residual) capacity
      cost - cost per unit of flow sent along this edge
    """

    __slots__ = ("to", "rev", "cap", "cost")

    def __init__(self, to: int, rev: int, cap: int, cost: float):
        self.to, self.rev, self.cap, self.cost = to, rev, cap, cost


def _mcf_add_edge(G: List[List[_MCFEdge]], fr: int, to: int, cap: int, cost: float) -> int:
    """
    Add directed edge fr -> to with capacity 'cap' and cost 'cost' to graph G.
    Also add the paired reverse edge to -> fr with capacity 0 and cost -cost,
    so that flow can later be cancelled through the residual graph.

    Self-loops (fr == to) are handled: both edges then live in the same
    adjacency list, so the reverse edge sits one slot after the forward edge.

    Parameters
    ----------
    G : adjacency lists, modified in place.
    fr, to : tail and head node indices of the forward edge.
    cap : capacity of the forward edge.
    cost : cost per unit of flow on the forward edge.

    Returns
    -------
    edge_index_in_fr : int
        Index of the forward edge in G[fr].
    """
    # Compute both slots before appending so the indices stay correct even
    # when fr == to (in which case the two appends hit the same list).
    fwd_index = len(G[fr])
    rev_index = len(G[to]) + (1 if fr == to else 0)

    G[fr].append(_MCFEdge(to=to, rev=rev_index, cap=cap, cost=cost))
    G[to].append(_MCFEdge(to=fr, rev=fwd_index, cap=0, cost=-cost))
    return fwd_index


def _min_cost_max_flow(
    G: List[List[_MCFEdge]],
    s: int,
    t: int,
    required_flow: int,
) -> Tuple[int, float]:
    """
    Successive-shortest-augmenting-path min-cost flow from s to t.

    Each round runs Dijkstra on costs reduced by node potentials, then pushes
    as much flow as possible (capped by the remaining demand) along the
    shortest s -> t path found.

    G is modified in place (residual capacities are updated).

    NOTE: potentials start at zero, so the first round is plain Dijkstra on
    the raw edge costs; this presumes edges with positive capacity have
    non-negative cost on entry.

    Returns (flow, cost): 'flow' <= required_flow is the amount actually
    sent, and 'cost' is the accumulated cost of sending it.
    """
    import heapq

    node_count = len(G)
    UNREACHED = 10**18

    potential = [0.0] * node_count
    parent_node = [0] * node_count
    parent_edge = [0] * node_count

    sent = 0
    total_cost = 0.0

    while sent < required_flow:
        # --- Dijkstra over reduced costs -----------------------------------
        dist = [UNREACHED] * node_count
        dist[s] = 0.0
        heap: List[Tuple[float, int]] = [(0.0, s)]

        while heap:
            d_u, u = heapq.heappop(heap)
            if d_u > dist[u]:
                # Stale heap entry; a shorter distance was already settled.
                continue
            for edge_pos, edge in enumerate(G[u]):
                if edge.cap <= 0:
                    continue
                w = edge.to
                candidate = d_u + edge.cost + potential[u] - potential[w]
                if candidate < dist[w]:
                    dist[w] = candidate
                    parent_node[w] = u
                    parent_edge[w] = edge_pos
                    heapq.heappush(heap, (candidate, w))

        if dist[t] == UNREACHED:
            # Sink unreachable in the residual graph: no more flow can be sent.
            break

        # Fold the new distances into the potentials (reachable nodes only).
        for node, d in enumerate(dist):
            if d < UNREACHED:
                potential[node] += d

        # --- Collect the augmenting path (walked backwards from t) ---------
        path: List[Tuple[int, _MCFEdge]] = []
        node = t
        while node != s:
            edge = G[parent_node[node]][parent_edge[node]]
            path.append((node, edge))
            node = parent_node[node]

        # Bottleneck: remaining demand or the tightest edge on the path.
        push = required_flow - sent
        for _, edge in path:
            if edge.cap < push:
                push = edge.cap

        sent += push
        # After the update, potential[t] is the true s -> t path cost.
        total_cost += push * potential[t]

        # --- Apply the augmentation to forward and reverse edges -----------
        for head, edge in path:
            edge.cap -= push
            G[head][edge.rev].cap += push

    return sent, total_cost


def _build_assignment_graph(
    D_cf: np.ndarray,
    caps: Sequence[int],
) -> Tuple[List[List[Any]], int, int, List[List[int]]]:
    """
    Build the min-cost-flow network that assigns clients to open facilities.

    Node layout: 0 = source, 1..C = clients, C+1..C+F_sub = facilities,
    C+F_sub+1 = sink. Edges: source->client (cap 1, cost 0),
    client->facility (cap 1, cost = distance), facility->sink
    (cap = facility capacity, cost 0).

    D_cf : shape (C, F_sub) client-to-facility distances.
    caps : length F_sub integer capacities.

    Returns (graph, source, sink, edge_rows) where edge_rows[ci][fi] is the
    index, inside graph[1 + ci], of the edge client ci -> facility fi.
    """
    C, F_sub = D_cf.shape
    source = 0
    sink = C + F_sub + 1
    graph: List[List[_MCFEdge]] = [[] for _ in range(C + F_sub + 2)]

    # source -> clients
    for ci in range(C):
        _mcf_add_edge(graph, source, 1 + ci, cap=1, cost=0.0)

    # clients -> facilities (cost = distance)
    edge_rows: List[List[int]] = []
    for ci in range(C):
        client_node = 1 + ci
        edge_rows.append(
            [
                _mcf_add_edge(
                    graph, client_node, 1 + C + fi, cap=1, cost=float(D_cf[ci, fi])
                )
                for fi in range(F_sub)
            ]
        )

    # facilities -> sink
    for fi in range(F_sub):
        _mcf_add_edge(graph, 1 + C + fi, sink, cap=int(caps[fi]), cost=0.0)

    return graph, source, sink, edge_rows


def _read_assignment(graph: List[List[Any]], edge_rows: List[List[int]]) -> Dict[int, int]:
    """
    Read the client -> facility assignment out of a solved residual graph.

    A client->facility edge started with capacity 1, so residual capacity 0
    means that edge carried the client's unit of flow.

    Returns a dict mapping client position -> facility position (both are
    positions within the subset graph, not df row indices).
    """
    assigned: Dict[int, int] = {}
    for ci, row in enumerate(edge_rows):
        client_node = 1 + ci
        for fi, edge_idx in enumerate(row):
            if graph[client_node][edge_idx].cap == 0:
                assigned[ci] = fi
    return assigned


def brute_force_capacitated_k_median(
    df: pd.DataFrame,
    k: int,
    alpha: Sequence[int],
    beta: Sequence[int],
    feature_cols: Optional[Sequence[str]] = None,
    is_facility_col: str = "is_facility",
    capacity_col: str = "capacity",
    group_prefix: str = "group",
) -> Dict[str, Any]:
    """
    Brute-force exact solver for the capacitated k-median with integer
    alpha/beta group constraints on facilities.

    Assumptions / Model
    -------------------
    - df contains:
        * feature columns (used for Euclidean distance)
        * a facility flag column (truthy = facility, falsy = client)
        * a capacity column (read for facility rows only, cast to int)
        * group columns whose names start with `group_prefix` (0/1)
          describing facility types / categories. A facility can belong to
          several groups.

    - There are t groups (t = number of group columns). alpha and beta are
      integer vectors of length t. For any subset S of open facilities
      (1 <= |S| <= k), with

          m_g(S) = # of open facilities in S with group g = 1

      the subset is admissible only if

          alpha[g] <= m_g(S) <= beta[g]   for all groups g.

      Because facilities can be multi-group, one facility can contribute to
      several m_g(S).

    - Capacity pre-check: sum of the (integer) capacities in S must be at
      least the number of clients.

    - For each admissible subset S the capacitated assignment is solved via
      min-cost max-flow on the network
      source -> clients -> facilities -> sink, with unit capacities on the
      first two layers, facility capacities on the last, and Euclidean
      distance as the cost of each client->facility edge. If all clients can
      be routed, the flow cost is the k-median cost of S.

    - Subsets are enumerated by increasing size and, within a size, in
      lexicographic order of facility row index; the first subset achieving
      the minimum cost is returned.

    Parameters
    ----------
    df : pd.DataFrame
        Data with features, facility flags, capacities, and group columns.
        The index is reset internally, so returned indices are row positions.
    k : int
        Maximum number of facilities to open (subset size is at most k).
    alpha, beta : Sequence[int]
        Integer vectors of length = number of group columns. alpha[g] is
        the lower bound and beta[g] is the upper bound on the number of
        selected facilities that belong to group g.
    feature_cols : sequence of str, optional
        Feature columns; if None, all columns starting with "f" are used.
    is_facility_col : str
        Column marking facilities (1) vs clients (0).
    capacity_col : str
        Column with facility capacities.
    group_prefix : str
        Prefix for group columns, e.g., "group" for group1, group2, ...

    Returns
    -------
    result : dict
        {
          "best_cost": float or math.inf (inf if no subset is feasible),
          "best_subset_facility_indices": List[int],  # row indices in df
          "assignment": Dict[int, int],               # client_idx -> facility_idx
          "num_facilities_total": int,
          "num_clients_total": int,
          "num_subsets_examined": int,
          "feasible_subsets_examined": int,
          "alpha": list(alpha),
          "beta": list(beta),
        }

    Raises
    ------
    ValueError
        If alpha/beta do not match the number of group columns, or if df
        contains no facilities.
    """
    # --- Basic setup --------------------------------------------------------
    df = df.reset_index(drop=True)

    if feature_cols is None:
        feature_cols = [c for c in df.columns if c.startswith("f")]
    # A tuple passed to .loc would be read as a single (MultiIndex) key.
    feature_cols = list(feature_cols)

    group_cols = [c for c in df.columns if c.startswith(group_prefix)]
    num_groups = len(group_cols)
    if len(alpha) != num_groups or len(beta) != num_groups:
        raise ValueError(
            f"alpha and beta must have length {num_groups} (number of group columns), "
            f"got len(alpha)={len(alpha)}, len(beta)={len(beta)}."
        )

    alpha = list(map(int, alpha))
    beta = list(map(int, beta))

    facilities_mask = df[is_facility_col].astype(bool).to_numpy()
    facility_indices = np.where(facilities_mask)[0]
    client_indices = np.where(~facilities_mask)[0]

    num_facilities_total = int(facility_indices.size)
    num_clients_total = int(client_indices.size)

    if num_facilities_total == 0:
        raise ValueError("No facilities in the dataframe (is_facility_col has no 1s).")
    if num_clients_total == 0:
        # Trivial: no clients, cost 0.
        return {
            "best_cost": 0.0,
            "best_subset_facility_indices": [],
            "assignment": {},
            "num_facilities_total": num_facilities_total,
            "num_clients_total": 0,
            "num_subsets_examined": 0,
            "feasible_subsets_examined": 0,
            "alpha": alpha,
            "beta": beta,
        }

    # --- Precompute everything that does not depend on the subset ----------
    # All arrays below are indexed by facility *position* (0..F-1), i.e. the
    # position inside facility_indices, not by df row index.
    X = df.loc[:, feature_cols].to_numpy(dtype=float)
    D_all = _euclidean_distance_matrix_two_sets(
        X[client_indices], X[facility_indices]
    )  # shape (C, F)
    caps_all = df.loc[facility_indices, capacity_col].to_numpy(dtype=int)   # (F,)
    groups_all = df.loc[facility_indices, group_cols].to_numpy(dtype=int)   # (F, t)
    alpha_arr = np.asarray(alpha, dtype=int)
    beta_arr = np.asarray(beta, dtype=int)

    # --- Brute-force enumeration -------------------------------------------
    best_cost = math.inf
    best_subset: List[int] = []
    best_assignment: Dict[int, int] = {}
    num_subsets_examined = 0
    feasible_subsets_examined = 0

    max_subset_size = min(k, num_facilities_total)

    # Enumerate all subsets of facility positions of size 1..max_subset_size.
    # facility_indices is sorted, so this visits subsets in the same order as
    # enumerating combinations of the row indices themselves.
    for size in range(1, max_subset_size + 1):
        for comb in itertools.combinations(range(num_facilities_total), size):
            num_subsets_examined += 1
            positions = list(comb)

            # Capacity check (same integer capacities the flow network uses)
            subset_caps = caps_all[positions]
            if int(subset_caps.sum()) < num_clients_total:
                continue

            # Fairness check with integer alpha/beta (multi-group allowed)
            group_counts = groups_all[positions].sum(axis=0)  # (t,)
            if not np.all((alpha_arr <= group_counts) & (group_counts <= beta_arr)):
                continue

            # Feasible wrt alpha-beta and capacity => do min-cost max-flow
            feasible_subsets_examined += 1

            graph, source, sink, edge_rows = _build_assignment_graph(
                D_all[:, positions], subset_caps
            )
            flow, cost = _min_cost_max_flow(
                graph, source, sink, required_flow=num_clients_total
            )

            if flow < num_clients_total:
                # This subset cannot serve all clients respecting capacities
                continue

            # Strict '<' keeps the first-found optimum on ties.
            if cost < best_cost:
                best_cost = cost
                # Plain Python ints, consistent with the assignment dict.
                best_subset = [int(facility_indices[p]) for p in positions]
                best_assignment = {
                    int(client_indices[ci]): int(facility_indices[positions[fi]])
                    for ci, fi in _read_assignment(graph, edge_rows).items()
                }

    return {
        "best_cost": best_cost,
        "best_subset_facility_indices": best_subset,
        "assignment": best_assignment,
        "num_facilities_total": num_facilities_total,
        "num_clients_total": num_clients_total,
        "num_subsets_examined": int(num_subsets_examined),
        "feasible_subsets_examined": int(feasible_subsets_examined),
        "alpha": alpha,
        "beta": beta,
    }

def test_stub():
    """
    Demo run: generate a random facility/client dataframe, solve it with
    brute_force_capacitated_k_median for k = 3, and print the result.
    """
    n_groups = 3

    sample_df = generate_random_facility_client_df(
        n_points=40,
        n_features=3,
        n_groups=n_groups,
        facility_probability=0.5,
        max_capacity=10,
        seed=0,
    )

    # Per-group bounds on the number of opened facilities:
    # at least 1 and at most 3 from every group.
    lower_bounds = tuple(1 for _ in range(n_groups))
    upper_bounds = tuple(3 for _ in range(n_groups))

    feature_names = [name for name in sample_df.columns if name.startswith("f")]

    outcome = brute_force_capacitated_k_median(
        df=sample_df,
        k=3,
        alpha=lower_bounds,
        beta=upper_bounds,
        feature_cols=feature_names,
    )

    report_lines = (
        ("Best cost:", "best_cost"),
        ("Best facility subset (df row indices):", "best_subset_facility_indices"),
        ("Assignment (client -> facility):", "assignment"),
    )
    for label, key in report_lines:
        print(label, outcome[key])

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    test_stub()
