# from sklearn_extra.cluster import KMedoids
from typing import Optional, Dict, Any, List, Union
import numpy as np
from numpy.linalg import eigh
from scipy.linalg import eigh
from scipy.sparse import issparse
from scipy.sparse.csgraph import laplacian
from sklearn.cluster import KMeans

from init_SVT import init_SVT
from losses import compute_acc_ari_nmi


def to_one_hot(labels: np.ndarray, k: int) -> np.ndarray:
    """
    Encode integer labels as an (N, k) one-hot integer matrix.

    ``labels`` is flattened first, so any shape holding N labels is accepted.
    Row i has a single 1 in column ``labels[i]``. Labels are used directly as
    column indices, so a label >= k (or < -k) raises IndexError and a negative
    label wraps around like ordinary NumPy indexing.
    """
    flat = np.asarray(labels, dtype=int).ravel()
    # Selecting rows of the identity matrix gives exactly one 1 per label.
    return np.eye(k, dtype=int)[flat]

def _argmax_with_random_tie_breaking(scores: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    max_vals = scores.max(axis=1, keepdims=True)
    ties = (scores == max_vals)
    idx = np.empty(scores.shape[0], dtype=int)
    for i, row in enumerate(ties):
        candidates = np.where(row)[0]
        idx[i] = rng.choice(candidates)
    return idx

def local_refinement_by_neighbors(
    A: np.ndarray,
    pred_labels: np.ndarray,
    num_classes: int,
    random_state: Optional[int] = None,
    alpha: float = 1e-6
) -> np.ndarray:
    """
    One synchronous step of neighbour-based label refinement.

    For node i and community c the score is the smoothed fraction of c's
    members (not counting i itself) that are adjacent to i:

        score[i, c] = (sum_j A[i, j] * [y_j == c] + alpha)
                      / (size of c without i + 2 * alpha)

    Each node moves to its highest-scoring community. Exact ties are broken
    uniformly at random.

    Communities with no members other than i are never candidates for i.
    With zero members the smoothed score collapses to the prior value
    alpha / (2 * alpha) = 0.5, which in a sparse graph would otherwise
    outrank every real community and drag nodes into an empty one. A node
    that is the only member of its own community keeps its label.

    Parameters
    ----------
    A : (N, N) or (1, N, N) adjacency matrix (dense array-like or scipy
        sparse). Diagonal entries (self-loops) are ignored.
    pred_labels : N current labels, integers in [0, num_classes).
    num_classes : number of communities K.
    random_state : seed for the tie-breaking RNG (None -> unseeded).
    alpha : smoothing constant added to numerator and (twice) denominator.

    Returns
    -------
    (N,) integer array of refined labels.

    Raises
    ------
    ValueError
        If A is not square, or pred_labels has the wrong length or values
        outside [0, num_classes).
    """
    rng = np.random.default_rng(random_state)

    # --- Validate / normalise A ---
    if issparse(A):
        A = A.toarray()
    A = np.asarray(A, dtype=float)
    if A.ndim == 3 and A.shape[0] == 1:
        A = A[0]
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"A must be square (N,N), got {A.shape}")
    N = A.shape[0]
    # Drop self-loops so a node's own edge does not count as a neighbour vote.
    if np.any(np.diag(A) != 0):
        A = A.copy()
        np.fill_diagonal(A, 0.0)

    # min()/max() are undefined on an empty array, so only check values when N > 0.
    y = np.asarray(pred_labels).reshape(-1).astype(int)
    if y.shape[0] != N or (N > 0 and (y.min() < 0 or y.max() >= num_classes)):
        raise ValueError("pred_labels shape/value error")

    # --- Per-node statistics ---
    H = to_one_hot(y, num_classes).astype(float)       # (N, K) membership
    deg_to_comm = A @ H                                 # (N, K) edge weight from i into each community
    comm_sizes = H.sum(axis=0, keepdims=True)           # (1, K)
    comm_sizes_excl_self = comm_sizes - H               # (N, K) own community shrinks by one

    # --- Smoothed scores; communities with no other members are excluded ---
    has_members = comm_sizes_excl_self > 0
    probs = np.full(deg_to_comm.shape, -np.inf)
    np.divide(deg_to_comm + alpha, comm_sizes_excl_self + 2.0 * alpha,
              out=probs, where=has_members)

    # --- Best community per node (random tie-breaking) ---
    refined = _argmax_with_random_tie_breaking(probs, rng)

    # A node whose own community would be empty without it keeps its label.
    own_sizes_excl = comm_sizes_excl_self[np.arange(N), y]
    stay_mask = own_sizes_excl == 0
    refined[stay_mask] = y[stay_mask]

    return refined

def local_refinement_by_neighbors_multi(
    A: np.ndarray,
    init_labels: np.ndarray,
    num_classes: int,
    *,
    true_labels: Optional[Union[np.ndarray, "torch.Tensor"]] = None,  # optional: enables acc/ARI/NMI
    num_iters: int = 5,
    alpha: float = 1e-6,
    random_state: Optional[int] = None,
    tol: int = 0,
    verbose: bool = True,
    return_history: bool = True,
):
    """
    Multi-step local refinement with synchronous updates.

    Every step recomputes the neighbour statistics from the whole-graph
    labelling produced by the previous step, via local_refinement_by_neighbors.

    Parameters
    ----------
    A : (N, N) or (1, N, N) adjacency matrix (dense array-like or scipy sparse).
    init_labels : (N,) initial labels in [0, K).
    num_classes : K.
    true_labels : optional ground truth used to compute acc/ARI/NMI. May be a
        torch LongTensor of shape (1, N) or a numpy array of shape (N,).
        A numpy array is converted to a (1, N) torch tensor. If torch is
        unavailable, or the type is unsupported, metrics are disabled.
    num_iters : maximum number of steps. Values <= 0 run no steps.
    alpha : smoothing constant, same meaning as in the single-step version.
    random_state : seed for tie-breaking. When given, a per-step seed is
        drawn from a Generator seeded with it.
    tol : stop early once a step changes at most ``tol`` labels.
    verbose : print per-step statistics.
    return_history : include the per-step history in the result.

    Returns
    -------
    dict with
      - "labels": final labels (N,)
      - "iter_done": number of steps actually run (0 if num_iters <= 0)
      - "stopped_early": True if the loop exited because n_changed <= tol
      - "history" (only if return_history): list of dicts
            {"iter": i, "labels": (N,), "n_changed": int,
             "acc": float, "ari": float, "nmi": float}  # metrics only when available

    Raises
    ------
    ValueError
        If A is not square, or init_labels has the wrong length or values.
    """
    rng = np.random.default_rng(random_state)

    # --- Normalise A ---
    if issparse(A):
        A = A.toarray()
    A = np.asarray(A, dtype=float)
    if A.ndim == 3 and A.shape[0] == 1:
        A = A[0]
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"A must be square (N,N), got {A.shape}")
    N = A.shape[0]
    if np.any(np.diag(A) != 0):
        A = A.copy()
        np.fill_diagonal(A, 0.0)

    # --- Normalise labels (min/max are undefined on an empty array) ---
    y = np.asarray(init_labels).reshape(-1).astype(int)
    if y.shape[0] != N or (N > 0 and (y.min() < 0 or y.max() >= num_classes)):
        raise ValueError("init_labels shape/value error")

    # --- Optional ground truth for metrics ---
    use_metrics = true_labels is not None
    tl_torch = None
    if use_metrics:
        try:
            import torch
            if isinstance(true_labels, np.ndarray):
                tl_torch = torch.tensor(true_labels.reshape(1, -1), dtype=torch.long)
            elif isinstance(true_labels, torch.Tensor):
                tl_torch = true_labels
            else:
                raise TypeError("true_labels must be numpy.ndarray or torch.Tensor")
        except Exception as e:
            # Best effort: if torch is missing or the conversion fails, skip metrics.
            use_metrics = False
            tl_torch = None
            if verbose:
                print(f"[warn] metrics disabled ({e})")

    history: List[Dict[str, Any]] = []
    stopped_early = False
    iter_done = 0  # stays 0 when num_iters <= 0 and the loop body never runs

    for it in range(1, num_iters + 1):
        # Synchronous update: one refinement step based on the current labelling y.
        new_y = local_refinement_by_neighbors(
            A=A,
            pred_labels=y,
            num_classes=num_classes,
            random_state=rng.integers(1<<31) if random_state is not None else None,
            alpha=alpha
        )

        n_changed = int(np.sum(new_y != y))

        rec: Dict[str, Any] = {"iter": it, "labels": new_y.copy(), "n_changed": n_changed}

        if use_metrics:
            try:
                # Called as compute_acc_ari_nmi(pred_labels_np, true_labels_torch, n_classes).
                acc, best_pred, ari, nmi = compute_acc_ari_nmi(new_y, tl_torch, num_classes)
                rec.update({"acc": float(acc), "ari": float(ari), "nmi": float(nmi)})
                if verbose:
                    print(f"[Iter {it}] acc={acc:.4f}, ari={ari:.4f}, nmi={nmi:.4f}, changed={n_changed}")
            except Exception as e:
                # Best effort: a metric failure must not abort the refinement itself.
                if verbose:
                    print(f"[Iter {it}] changed={n_changed} (metrics unavailable: {e})")
        elif verbose:
            print(f"[Iter {it}] changed={n_changed}")

        if return_history:
            history.append(rec)

        y = new_y  # input for the next step
        iter_done = it

        if n_changed <= tol:
            stopped_early = True
            break

    result = {
        "labels": y,
        "iter_done": iter_done,
        "stopped_early": stopped_early
    }
    if return_history:
        result["history"] = history
    return result


def _cluster_from_U(U, k, *, n_init: int = 20, random_state: Optional[int] = 42):
    """
    Cluster the rows of a spectral embedding with scikit-learn's KMeans.

    Parameters
    ----------
    U : (n, d) array-like embedding; each row is one point (node).
    k : number of clusters.
    n_init : number of KMeans restarts (default 20, as before).
    random_state : KMeans seed (default 42, as before), so repeated calls with
        the same input give the same labels.

    Returns
    -------
    labels : (n,) cluster labels as returned by ``KMeans.fit_predict``.
    """
    U = np.asarray(U, dtype=float)
    km = KMeans(n_clusters=k, n_init=n_init, max_iter=300, tol=1e-6, random_state=random_state)
    return km.fit_predict(U)


def spectral_clustering_adj(A, k, true_labels, normalized: bool = False, *, run_all: bool = False,
                            random_state: int = 0):
    """
    Spectral clustering of a graph followed by one local-refinement step.

    Accuracy / ARI / NMI are reported for both the raw clustering and the
    refined labels.

    Available spectral embeddings:
      - normalized   : the k eigenvectors of scipy's normalized Laplacian
                       (``laplacian(A, normed=True)``) with the smallest
                       eigenvalues, then row-normalised (Ng-Jordan-Weiss);
      - unnormalized : the k eigenvectors of ``laplacian(A, normed=False)``
                       with the smallest eigenvalues;
      - adjacency    : the k eigenvectors of A with the largest |eigenvalue|,
                       ordered by decreasing |eigenvalue|.

    Each embedding is clustered with _cluster_from_U. The ``best_pred`` output
    of compute_acc_ari_nmi (presumably the prediction re-mapped onto the true
    label ids) is then refined with local_refinement_by_neighbors.

    Legacy interface (compatible with test_single_first_period):
        spectral_clustering_adj(A, k, true_labels, normalized=False)
        -> (acc_sc, ari_sc, nmi_sc, acc_ref, ari_ref, nmi_ref)
        Only the normalized / unnormalized paths are reachable this way.

    All embeddings at once:
        spectral_clustering_adj(A, k, true_labels, run_all=True)
        -> {
            'normalized':   {'sc': (acc,ari,nmi), 'refined': (acc,ari,nmi)},
            'unnormalized': {'sc': (acc,ari,nmi), 'refined': (acc,ari,nmi)},
            'adjacency':    {'sc': (acc,ari,nmi), 'refined': (acc,ari,nmi)},
        }

    Parameters
    ----------
    A : (N, N) or (1, N, N) adjacency matrix: numpy array, torch.Tensor or
        scipy sparse. It is symmetrised as (A + A.T) / 2.
    k : number of clusters.
    true_labels : ground-truth labels (numpy or torch) for compute_acc_ari_nmi.
    normalized : legacy mode only; choose the normalized or unnormalized Laplacian.
    run_all : run all three embeddings and return the dict shown above.
    random_state : seed for the random tie-breaking in the refinement step, so
        repeated runs give identical refined labels.
    """
    # --- Input normalisation: torch.Tensor / sparse / (1,N,N) / (N,N) ---
    try:
        import torch
    except Exception:
        torch = None  # torch is optional; without it there are no tensors to convert
    if torch is not None:
        if isinstance(A, torch.Tensor):
            A = A.detach().cpu().numpy()
        if isinstance(true_labels, torch.Tensor):
            true_labels = true_labels.detach().cpu().numpy()

    # Densify and convert before looking at .ndim, so lists and sparse inputs work too.
    if issparse(A):
        A = A.toarray()
    A = np.asarray(A, dtype=np.float64)
    if A.ndim == 3 and A.shape[0] == 1:
        A = A[0]

    # The symmetric eigensolvers below require a symmetric matrix.
    A = 0.5 * (A + A.T)

    def _evaluate(U):
        """Cluster embedding U, refine the result, return (sc metrics, refined metrics)."""
        labels = _cluster_from_U(U, k)
        acc_sc, best_pred, ari_sc, nmi_sc = compute_acc_ari_nmi(labels, true_labels, k)
        refined = local_refinement_by_neighbors(A, best_pred, k, random_state=random_state)
        acc_ref, _, ari_ref, nmi_ref = compute_acc_ari_nmi(refined, true_labels, k)
        return (acc_sc, ari_sc, nmi_sc), (acc_ref, ari_ref, nmi_ref)

    def _run_normalized():
        L = laplacian(A, normed=True).astype(np.float64)
        w, V = eigh(L)
        U = V[:, np.argsort(w)[:k]]  # k smallest eigenvalues
        # Ng-Jordan-Weiss row normalisation (epsilon guards all-zero rows).
        U = U / (np.linalg.norm(U, axis=1, keepdims=True) + 1e-12)
        return _evaluate(U)

    def _run_unnormalized():
        L = laplacian(A, normed=False).astype(np.float64)
        w, V = eigh(L)
        U = V[:, np.argsort(w)[:k]]  # k smallest eigenvalues
        return _evaluate(U)

    def _run_adjacency():
        # Dense full eigendecomposition; keep the k largest |eigenvalue|, largest first.
        w, V = eigh(A)
        idx = np.argsort(np.abs(w))[-k:]
        idx = idx[np.argsort(np.abs(w[idx]))[::-1]]
        return _evaluate(V[:, idx])

    if run_all:
        sc_n, ref_n = _run_normalized()
        sc_u, ref_u = _run_unnormalized()
        sc_a, ref_a = _run_adjacency()
        return {
            'normalized': {'sc': sc_n, 'refined': ref_n},
            'unnormalized': {'sc': sc_u, 'refined': ref_u},
            'adjacency': {'sc': sc_a, 'refined': ref_a},
        }

    # Legacy interface: one path chosen by `normalized`, flattened to a 6-tuple.
    sc, ref = _run_normalized() if normalized else _run_unnormalized()
    return (*sc, *ref)
