from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
from random import Random

def _compute_best_two(
    D: np.ndarray,
    centers: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Find, for every client, its nearest and second-nearest open center.

    Parameters
    ----------
    D : array (n_f, n_c)
        Distance matrix; row i is facility i, column j is client j.
    centers : array (k,)
        Facility indices (rows of D) that are currently open.

    Returns
    -------
    best_facility : array (n_c,)
        Facility index (taken from ``centers``) closest to each client.
    best_dist : array (n_c,)
        Distance from each client to that closest center.
    second_dist : array (n_c,)
        Distance from each client to its second-closest center. When only
        one center is open this is an independent copy of ``best_dist``.

    Raises
    ------
    ValueError
        If ``centers`` is empty.
    """
    _, n_clients = D.shape
    n_open = len(centers)
    if n_open == 0:
        raise ValueError("Need at least one center")

    # Restrict the matrix to the open centers: shape (k, n_c).
    open_dist = D[centers, :]

    # Per client (column), rank the open centers from nearest to farthest.
    ranking = np.argsort(open_dist, axis=0)
    client_cols = np.arange(n_clients)

    # ranking[0] holds row positions inside open_dist, so map them back
    # through `centers` to obtain real facility indices.
    nearest_rows = ranking[0]
    best_facility = centers[nearest_rows]
    best_dist = open_dist[nearest_rows, client_cols]

    # With a single open center there is no runner-up; mirror best_dist.
    second_dist = (
        open_dist[ranking[1], client_cols] if n_open > 1 else best_dist.copy()
    )

    return best_facility, best_dist, second_dist


def _local_search_k_facility(
    D: np.ndarray,
    k: int,
    power: float,
    max_swaps: int,
    seed: Optional[int],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
    """
    Local-search (1-swap, best-improvement) heuristic for discrete
    k-median / k-means with candidate centers restricted to facilities.

    Objective:
      - k-median: power = 1  → minimize sum_j dist(f(j), j)
      - k-means : power = 2  → minimize sum_j dist(f(j), j)^2

    Starting from k facilities drawn at random, each iteration evaluates
    every (open center, closed facility) exchange and applies the one
    that lowers the objective the most. The search stops when no exchange
    improves the objective by more than 1e-8 or after ``max_swaps``
    exchanges have been applied.

    Parameters
    ----------
    D : array (n_f, n_c)
        Distance matrix, D[i, j] is distance from facility i to client j.
    k : int
        Number of centers to open (1 <= k <= n_f).
    power : float
        Exponent applied to each client's distance in the objective
        (1.0 for k-median, 2.0 for discrete k-means).
    max_swaps : int
        Maximum number of improving swaps to perform.
    seed : int or None
        Seed handed to ``np.random.default_rng`` for the initial centers.

    Returns
    -------
    centers : array (k,)
        Facility indices selected as centers (0-based).
    assignment : array (n_c,)
        For each client j, the facility index f_j ∈ centers serving it.
    dist_to_center : array (n_c,)
        For each client j, the distance D[f_j, j] (not raised to ``power``).
    cost : float
        Objective value sum_j D[f_j, j] ** power of the returned solution.

    Raises
    ------
    ValueError
        If ``k`` is not in the range 1..n_f.
    """
    tolerance = 1e-8  # minimum decrease for a swap to count as improving

    rng = np.random.default_rng(seed)
    n_f, n_c = D.shape

    if k <= 0 or k > n_f:
        raise ValueError(f"k must be between 1 and n_f = {n_f}, got k={k}")

    # --- Initial centers: k distinct facilities chosen at random ---
    centers = np.array(rng.choice(n_f, size=k, replace=False), dtype=int)

    # Membership mask so closed facilities can be listed cheaply.
    is_center = np.zeros(n_f, dtype=bool)
    is_center[centers] = True

    # Initial assignment and cost
    assignment, best_dist, second_dist = _compute_best_two(D, centers)
    current_cost = float(np.sum(best_dist ** power))

    # --- Local search with 1-swap ---
    for _ in range(max_swaps):
        best_swap = None
        best_swap_cost = current_cost

        # The set of closed facilities is fixed while swaps are evaluated.
        candidates = np.flatnonzero(~is_center)

        for out_pos, out_fac in enumerate(centers):
            # Distance each client falls back to once `out_fac` is closed:
            # clients it served drop to their second-best center, everyone
            # else keeps their current best. This depends only on the
            # outgoing center, so compute it once per outgoing center.
            fallback = np.where(assignment == out_fac, second_dist, best_dist)

            for in_fac in candidates:
                # After the swap each client uses the closer of the new
                # center and its fallback.
                new_dist = np.minimum(D[in_fac, :], fallback)
                new_cost = float(np.sum(new_dist ** power))

                if new_cost < best_swap_cost - tolerance:
                    best_swap_cost = new_cost
                    best_swap = (out_pos, out_fac, in_fac)

        if best_swap is None:
            break  # local optimum: no swap improves the objective

        # --- Apply best swap ---
        out_pos, out_fac, in_fac = best_swap
        centers[out_pos] = in_fac
        is_center[out_fac] = False
        is_center[in_fac] = True

        # Recompute assignment for the new center set. The cost evaluated
        # for the chosen swap is the cost of this new center set.
        assignment, best_dist, second_dist = _compute_best_two(D, centers)
        current_cost = best_swap_cost

    # best_dist holds the plain (un-powered) distances.
    return centers, assignment, best_dist, current_cost

def k_median_clustering(
    D: np.ndarray,
    k: int,
    max_swaps: int,
    seed: Optional[int],
    *,
    n_trials: int = 10,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Approximate k-median clustering with centers restricted to facilities.

    Runs the 1-swap local search ``n_trials`` times from different random
    starts and keeps the solution with the lowest k-median cost
    (sum of client-to-center distances).

    Parameters
    ----------
    D : array (n_f, n_c)
        Distance matrix, D[i, j] is distance from facility i to client j.
    k : int
        Number of centers to select (k <= n_f).
    max_swaps : int
        Maximum number of improving swaps in each local-search run.
    seed : int or None
        Seed for the generator that draws one sub-seed per run.
    n_trials : int, default 10
        Number of independent local-search runs (random restarts), used
        to reduce the chance of ending in a poor local minimum.

    Returns
    -------
    centers : array (k,)
        Selected facility indices (0-based, rows of D).
    client_facility : array (n_c,)
        For each client j (column of D), index of the facility that
        serves it.
    client_distance : array (n_c,)
        For each client j, D[client_facility[j], j].

    Raises
    ------
    ValueError
        If ``n_trials`` is smaller than 1.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be at least 1, got {n_trials}")

    seed_source = Random(seed)

    best_result = None
    best_cost = float("inf")
    for _ in range(n_trials):
        trial_seed = seed_source.randint(0, 1000000)
        centers, assignment, dist, cost = _local_search_k_facility(
            D=D,
            k=k,
            power=1.0,  # k-median objective: plain (unsquared) distances
            max_swaps=max_swaps,
            seed=trial_seed,
        )
        # Always keep the first run so a result is returned even when no
        # cost compares below the initial bound (e.g. NaN costs); after
        # that only a strictly cheaper run replaces the incumbent.
        if best_result is None or cost < best_cost:
            best_cost = cost
            best_result = (centers, assignment, dist)

    return best_result

def k_means_clustering(
    D: np.ndarray,
    k: int,
    max_swaps: int,
    seed: Optional[int],
    *,
    n_trials: int = 10,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Approximate (discrete) k-means clustering with centers restricted
    to facilities and cost = sum_j dist^2.

    Runs the 1-swap local search ``n_trials`` times from different random
    starts and keeps the solution with the lowest squared-distance cost.

    Parameters
    ----------
    D : array (n_f, n_c)
        Distance matrix, D[i, j] is distance from facility i to client j.
    k : int
        Number of centers to select (k <= n_f).
    max_swaps : int
        Maximum number of improving swaps in each local-search run.
    seed : int or None
        Seed for the generator that draws one sub-seed per run.
    n_trials : int, default 10
        Number of independent local-search runs (random restarts), used
        to reduce the chance of ending in a poor local minimum.

    Returns
    -------
    centers : array (k,)
        Selected facility indices (0-based, rows of D).
    client_facility : array (n_c,)
        For each client j (column of D), index of the facility that
        serves it.
    client_distance : array (n_c,)
        For each client j, D[client_facility[j], j] (plain distance).
        The k-means objective uses these squared, but the unsquared
        distances are returned for convenience.

    Raises
    ------
    ValueError
        If ``n_trials`` is smaller than 1.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be at least 1, got {n_trials}")

    seed_source = Random(seed)

    best_result = None
    best_cost = float("inf")
    for _ in range(n_trials):
        trial_seed = seed_source.randint(0, 1000000)
        centers, assignment, dist, cost = _local_search_k_facility(
            D=D,
            k=k,
            power=2.0,  # k-means objective: squared distances
            max_swaps=max_swaps,
            seed=trial_seed,
        )
        # Always keep the first run so a result is returned even when no
        # cost compares below the initial bound (e.g. NaN costs); after
        # that only a strictly cheaper run replaces the incumbent.
        if best_result is None or cost < best_cost:
            best_cost = cost
            best_result = (centers, assignment, dist)

    return best_result
