import numpy as np
import torch
import torch.nn.functional as F
from sklearn.cluster import SpectralClustering
from eval_utils import cluster_metric

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
dataset = "CIFAR-10"   # selects the ./data/<dataset>_* files and the cluster count
num_prompts = 7        # number of prompt-specific noun embedding files to fuse

# Hyper-parameter grids swept by main(): kNN size and softmax temperature.
grid_k = [30]
grid_temp = [0.04]

mu = 0.1        # added to sum(beta) when run_red_once turns beta into alpha
lam = 10        # passed as `lam` to update_beta_coordinate_descent
T_outer = 5     # outer iterations (diffusion followed by a beta update)
T_diffuse = 8   # maximum diffusion steps per outer iteration
TA_cd = 4       # coordinate-descent passes per beta update
seed = 0        # seeds torch, numpy and the spectral clustering step

# True: closed-form kernel; False: per-sample autograd kernel.
USE_PNTK_FAST = True

# Number of clusters to request for each supported dataset.
_CLUSTERS_PER_DATASET = {
    "CIFAR-10": 10,
    "STL-10": 10,
    "ImageNet-10": 10,
    "CIFAR-20": 20,
    "ImageNet-Dogs": 15,
    "DTD": 47,
    "UCF101": 101,
    "ImageNet": 1000,
}
if dataset not in _CLUSTERS_PER_DATASET:
    raise NotImplementedError(f"Unknown dataset {dataset}")
cluster_num = _CLUSTERS_PER_DATASET[dataset]

# Seed both RNGs before any tensor or array is created.
torch.manual_seed(seed)
np.random.seed(seed)

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def load_image_embeddings_and_labels(dataset):
    """Load the test-split image embeddings and labels for ``dataset``.

    Reads ``./data/{dataset}_image_embedding_test.npy`` and
    ``./data/{dataset}_labels_test.txt`` relative to the working directory.

    Args:
        dataset: Dataset name used to build both file paths.

    Returns:
        A pair ``(embeddings, labels)``: ``embeddings`` is a float32 tensor on
        the module-level ``device`` whose rows have been divided by their L2
        norm; ``labels`` is a NumPy integer array.
    """
    features = np.load(f"./data/{dataset}_image_embedding_test.npy")
    # NOTE(review): an all-zero row has norm 0 and would become NaN here.
    row_norms = np.linalg.norm(features, axis=1, keepdims=True)
    features = features / row_norms

    # Labels are parsed as floats first, then truncated to integers.
    targets = np.loadtxt(f"./data/{dataset}_labels_test.txt", dtype=float).astype(int)

    embeddings = torch.from_numpy(features).to(device).float()
    return embeddings, targets

def load_embeddings(path):
    """Load a ``.npy`` embedding matrix and L2-normalize each row.

    Args:
        path: Location of the ``.npy`` file to read.

    Returns:
        A float32 tensor on the module-level ``device`` whose rows have been
        divided by their L2 norm.
    """
    raw = np.load(path)
    # NOTE(review): an all-zero row has norm 0 and would become NaN here.
    unit_rows = raw / np.linalg.norm(raw, axis=1, keepdims=True)
    tensor = torch.from_numpy(unit_rows)
    return tensor.to(device).float()


@torch.no_grad()
def build_knn_graph_from_similarity(K_np, k):
    """Build a symmetric, non-negative kNN affinity matrix from similarities.

    Each row keeps the similarities of its ``k`` largest off-diagonal
    candidates (the diagonal is zeroed before selection). The result is then
    symmetrized with an element-wise maximum, non-finite values are replaced
    by zero and negative values are clipped to zero.

    Args:
        K_np: Square NumPy similarity matrix. It is copied, not modified.
        k: Number of neighbours to keep per row.

    Returns:
        A float32 NumPy array with the same shape as ``K_np``.
    """
    sim = K_np.copy()
    np.fill_diagonal(sim, 0.0)  # a point is never its own neighbour
    n_nodes = sim.shape[0]

    # Partition on the negated scores so the k largest end up in front.
    pivot = min(k, n_nodes - 1) - 1
    neighbours = np.argpartition(-sim, kth=pivot, axis=1)[:, :k]

    src = np.repeat(np.arange(n_nodes), neighbours.shape[1])
    dst = neighbours.reshape(-1)

    graph = np.zeros_like(sim, dtype=np.float32)
    graph[src, dst] = sim[src, dst]

    # Keep an edge if either endpoint selected the other.
    graph = np.maximum(graph, graph.T)
    np.nan_to_num(graph, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
    return np.maximum(graph, 0.0)

def transition_from_W(W):
    """Symmetrically normalize an affinity matrix by its row degrees.

    Computes ``D^{-1/2} W D^{-1/2}`` where the degrees (row sums of ``W``) are
    floored at 1.0 before the inverse square root is taken. Rows and columns
    of nodes whose degree is not positive are zeroed in the output.

    NOTE(review): because of the 1.0 floor, a node whose degree lies in
    (0, 1) is left unscaled rather than normalized. Confirm this is intended
    and not just a division-by-zero guard.

    Args:
        W: Square NumPy affinity matrix. It is not modified.

    Returns:
        The normalized matrix as a new NumPy array.
    """
    degree = W.sum(axis=1)
    isolated = degree <= 0

    if isolated.any():
        # Work on a copy so the caller's matrix is left untouched.
        W = W.copy()
        W[isolated, :] = 0.0
        W[:, isolated] = 0.0
        degree = W.sum(axis=1)

    inv_sqrt_degree = 1.0 / np.sqrt(np.maximum(degree, 1.0))
    normalized = (inv_sqrt_degree[:, None] * W) * inv_sqrt_degree[None, :]

    # Isolated nodes get all-zero rows and columns.
    normalized[isolated, :] = 0.0
    normalized[:, isolated] = 0.0
    return normalized


@torch.no_grad()
def compute_pntk_matrix_fast(images_emb, nouns_emb, temp, Kx_precomputed=None):
    """Compute the kernel matrix in closed form, without autograd.

    Entry ``(i, j)`` equals ``(x_i . x_j) * (s_i . s_j) / temp**2``, where
    ``x_i`` is row ``i`` of ``images_emb`` and ``s_i`` is the softmax over
    nouns of ``x_i @ nouns_emb.T / temp``. This is the closed form of the
    gradient Gram matrix that ``compute_pntk_matrix_autograd`` builds one
    sample at a time.

    Args:
        images_emb: 2-D tensor with one row per image.
        nouns_emb: 2-D tensor with one row per noun and the same feature
            dimension as ``images_emb``.
        temp: Softmax temperature.
        Kx_precomputed: Optional precomputed ``images_emb @ images_emb.T``;
            computed here when ``None``.

    Returns:
        The kernel as a float32 NumPy array on the CPU.
    """
    if Kx_precomputed is not None:
        feature_gram = Kx_precomputed
    else:
        feature_gram = images_emb @ images_emb.t()

    noun_probs = torch.softmax((images_emb @ nouns_emb.t()) / temp, dim=1)
    prob_gram = noun_probs @ noun_probs.t()

    kernel = (feature_gram * prob_gram) / (temp * temp)
    return kernel.detach().cpu().float().numpy()

def compute_pntk_matrix_autograd(images_emb, nouns_emb, temp):
    """Compute the kernel matrix from per-sample gradients via autograd.

    For every image row the scalar ``logsumexp(x @ nouns_emb.T / temp)`` is
    differentiated with respect to ``nouns_emb``. The flattened gradients
    are stacked and their Gram matrix is returned.

    Args:
        images_emb: 2-D tensor with one row per image.
        nouns_emb: 2-D tensor with one row per noun. It must require grad,
            otherwise ``torch.autograd.grad`` raises.
        temp: Temperature dividing the logits.

    Returns:
        The gradient Gram matrix as a float32 NumPy array.
    """
    per_sample_grads = []
    for row in range(images_emb.shape[0]):
        # Reset any accumulated gradient on the parameter before each sample.
        if nouns_emb.grad is not None:
            nouns_emb.grad.zero_()

        sample = images_emb[row : row + 1]
        objective = torch.logsumexp((sample @ nouns_emb.t()) / temp, dim=1).sum()
        (grad,) = torch.autograd.grad(
            objective, nouns_emb, create_graph=False, retain_graph=False
        )
        per_sample_grads.append(grad.flatten().detach().cpu())

    jacobian = torch.stack(per_sample_grads, dim=0)
    gram = jacobian @ jacobian.t()
    return gram.numpy().astype(np.float32)


def compute_Hv(A, S):
    """Return ``sum(A * A) - sum(A * (S @ A @ S.T))`` as a Python float.

    Args:
        A: Square NumPy matrix.
        S: Square NumPy matrix of the same size, applied on both sides of
            ``A``.

    Returns:
        The difference between the squared Frobenius norm of ``A`` and the
        element-wise inner product of ``A`` with ``S @ A @ S.T``.
    """
    propagated = S @ A @ S.T
    self_energy = (A * A).sum()
    cross_energy = (A * propagated).sum()
    return float(self_energy - cross_energy)

def project_to_simplex(v):
    """Map a vector onto the probability simplex.

    Negative entries are clipped to zero first. If nothing positive remains,
    the uniform vector is returned. Otherwise the clipped vector is shifted
    by a threshold found with a sort and cumulative sum, then clipped at
    zero again.

    Args:
        v: 1-D array-like of weights.

    Returns:
        A non-negative NumPy array of the same length.
    """
    clipped = np.maximum(v, 0.0)
    size = len(clipped)

    if clipped.sum() == 0:
        return np.ones_like(clipped) / size

    descending = np.sort(clipped)[::-1]
    running_total = np.cumsum(descending)
    # Entries that stay positive after subtracting the running threshold.
    still_active = descending > (running_total - 1) / (np.arange(size) + 1)
    active_positions = np.nonzero(still_active)[0]
    if active_positions.size == 0:
        return np.ones_like(clipped) / size

    last_active = active_positions[-1]
    shift = (running_total[last_active] - 1) / (last_active + 1.0)
    return np.maximum(clipped - shift, 0.0)

def update_beta_coordinate_descent(beta, H, lam, passes=1):
    """Refine the view weights ``beta`` by pairwise coordinate descent.

    In each sweep, every pair ``(i, j)`` with ``i < j`` has its combined mass
    ``beta[i] + beta[j]`` re-split as

        beta[i] = (lam * (beta[i] + beta[j]) + H[j] - H[i]) / (2 * lam)
        beta[j] = (beta[i] + beta[j]) - new beta[i]

    with both values clipped to ``[0, 1]``. After every sweep the vector is
    passed through ``project_to_simplex``.

    The caller's ``beta`` is never modified: the routine works on a private
    floating-point copy. Previously the first sweep wrote into the input
    array, and an integer-typed input silently truncated the updates.

    Args:
        beta: 1-D array-like of current weights.
        H: 1-D array-like of per-view costs, indexed like ``beta``.
        lam: Positive coefficient in the update formula above.
        passes: Number of full sweeps over all pairs.

    Returns:
        A new NumPy array with the updated weights. A floating input keeps
        its dtype until ``project_to_simplex`` is applied.

    Raises:
        ValueError: If ``lam`` is not strictly positive (the update divides
            by ``2 * lam``).
    """
    if lam <= 0:
        raise ValueError(f"lam must be positive, got {lam}")

    # np.array copies by default, so in-place pair updates stay local.
    weights = np.array(beta)
    if not np.issubdtype(weights.dtype, np.floating):
        # Integer storage would truncate the fractional updates below.
        weights = weights.astype(np.float64)

    n_views = len(weights)
    for _ in range(passes):
        for i in range(n_views):
            for j in range(i + 1, n_views):
                pair_mass = weights[i] + weights[j]
                bi_new = (lam * pair_mass + (H[j] - H[i])) / (2.0 * lam)
                bi_new = min(max(bi_new, 0.0), 1.0)
                bj_new = min(max(pair_mass - bi_new, 0.0), 1.0)
                weights[i], weights[j] = bi_new, bj_new
        # Clipping can break the sum-to-one constraint; restore it per sweep.
        weights = project_to_simplex(weights)
    return weights


def run_red_once(images_emb, labels, k_val, temp_val):
    """Run the full pipeline for one ``(k, temperature)`` setting.

    Steps:
      1. For each of the ``num_prompts`` noun-embedding files, compute a
         kernel matrix (closed form or autograd, depending on
         ``USE_PNTK_FAST``), turn it into a kNN graph and normalize it.
      2. Alternate ``T_outer`` times between diffusing the fused similarity
         ``A`` through all per-prompt operators and updating the prompt
         weights ``beta``.
      3. Run spectral clustering on the final ``A`` and hand the predictions
         to ``cluster_metric``.

    Args:
        images_emb: 2-D tensor with one row per image.
        labels: Ground-truth labels forwarded to ``cluster_metric``.
        k_val: Number of neighbours for the kNN graphs.
        temp_val: Temperature for the kernel computation.

    Returns:
        None. Progress is printed and results go to ``cluster_metric``.
    """
    n_samples = images_emb.shape[0]
    # The image Gram matrix does not depend on the prompt, so the fast path
    # computes it once and reuses it for every prompt.
    image_gram = (images_emb @ images_emb.t()) if USE_PNTK_FAST else None

    # ---- Stage 1: one normalized kNN operator per prompt -----------------
    operators = []
    for i in range(num_prompts):
        prompt_emb = load_embeddings(
            f"./data/{dataset}_nouns_embedding_prompt_{i}_selected.npy"
        )
        # Gradients are only needed by the autograd kernel.
        prompt_emb.requires_grad_(not USE_PNTK_FAST)

        print(f"[Prompt {i+1}/{num_prompts}] computing pNTK (temp={temp_val})...")
        if USE_PNTK_FAST:
            kernel = compute_pntk_matrix_fast(
                images_emb, prompt_emb, temp_val, Kx_precomputed=image_gram
            )
        else:
            kernel = compute_pntk_matrix_autograd(images_emb, prompt_emb, temp_val)

        print(f"[Prompt {i+1}/{num_prompts}] building kNN graph (k={k_val})...")
        graph = build_knn_graph_from_similarity(kernel, k=k_val)
        operators.append(transition_from_W(graph).astype(np.float32))

        # Drop the large intermediates before the next prompt is processed.
        del prompt_emb, kernel, graph
        torch.cuda.empty_cache()

    # ---- Stage 2: alternate diffusion and weight updates -----------------
    identity = np.eye(n_samples, dtype=np.float32)
    A = np.eye(n_samples, dtype=np.float32)
    beta = np.ones(num_prompts, dtype=np.float32) / num_prompts

    for Tout in range(T_outer):
        alpha = beta / (mu + float(beta.sum()))
        alpha_sum = float(alpha.sum())
        print(f"[Outer {Tout+1}/{T_outer}] sum(alpha)={alpha_sum:.4f}")

        previous = A
        for t in range(T_diffuse):
            propagated = np.zeros_like(A)
            for weight, S_v in zip(alpha, operators):
                propagated += weight * (S_v @ A @ S_v.T)

            A = propagated + (1.0 - alpha_sum) * identity
            A = 0.5 * (A + A.T)   # keep the similarity symmetric
            A[A < 0] = 0.0        # and non-negative

            # Relative Frobenius change since the previous step.
            delta = np.linalg.norm(A - previous, 'fro') / (
                np.linalg.norm(previous, 'fro') + 1e-12
            )
            if delta < 1e-4:
                print(f"  early stop at t={t+1}, delta={delta:.3e}")
                break
            previous = A

        costs = np.array([compute_Hv(A, S_v) for S_v in operators], dtype=np.float32)
        beta = update_beta_coordinate_descent(beta, costs, lam, passes=TA_cd)

        beta_entropy = float(-(beta * (np.log(beta + 1e-12))).sum())
        print(f"  beta: {np.round(beta, 4)}, H(beta)={beta_entropy:.3f}")

    # ---- Stage 3: cluster on the fused similarity ------------------------
    print("Final RED similarity ready:", A.shape)
    clusterer = SpectralClustering(
        n_clusters=cluster_num,
        affinity="precomputed",
        assign_labels="discretize",
        random_state=seed,
    )
    assignments = clusterer.fit_predict(A)
    print("Spectral clustering done.")
    cluster_metric(labels, assignments)


def main():
    """Load the image embeddings once and run every grid configuration.

    Each ``(k, temperature)`` pair from ``grid_k`` x ``grid_temp`` is passed
    to ``run_red_once`` in turn, with ``grid_k`` as the outer loop.
    """
    images_emb, labels = load_image_embeddings_and_labels(dataset)
    # Re-normalize on the torch side (rows were already scaled by the loader).
    images_emb = F.normalize(images_emb, dim=1)

    settings = [(k_choice, t_choice) for k_choice in grid_k for t_choice in grid_temp]
    for kk, tt in settings:
        print(f"\n=== k={kk}, temp={tt} ===")
        run_red_once(images_emb, labels, k_val=kk, temp_val=tt)

# Run the grid search only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
