import os
from dataclasses import dataclass
from typing import Tuple

import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from sklearn.cluster import KMeans

def l2_norm(x: torch.Tensor) -> torch.Tensor:
    """Scale every row of ``x`` to unit Euclidean length.

    The norm is taken along ``dim=1`` and floored at ``1e-6`` before the
    division, so an all-zero row comes back as zeros instead of NaN.
    """
    row_norms = x.norm(p=2.0, dim=1, keepdim=True).clamp_min(1e-6)
    return x / row_norms

def orthogonal_loss(W: torch.Tensor) -> torch.Tensor:
    """Return the squared Frobenius norm of ``W @ W.T - I``.

    ``I`` is the identity with one row per row of ``W``; the result is zero
    exactly when ``W @ W.T`` equals that identity.
    """
    n_rows = W.size(0)
    identity = torch.eye(n_rows, device=W.device)
    gram_residual = torch.matmul(W, W.T) - identity
    return torch.norm(gram_residual, p="fro").pow(2)

def load_embeddings(path: str, map_location=None):
    """Load an embeddings bundle that was written with ``torch.save``.

    Args:
        path: File to read.
        map_location: Forwarded unchanged to ``torch.load``. Pass ``"cpu"``
            to open a bundle whose tensors were saved on a device that is
            not available on this machine. The default ``None`` keeps
            ``torch.load``'s own placement behaviour.

    Returns:
        A 4-tuple ``(image_embeds, text_embeds, meta, obj)`` where ``obj`` is
        the complete loaded object and the first three entries are the
        values stored under those keys in it.

    Raises:
        KeyError: if the loaded mapping lacks one of the three keys.
    """
    # NOTE(review): depending on the torch version / ``weights_only`` default,
    # torch.load may unpickle arbitrary objects -- only use on trusted files.
    obj = torch.load(path, map_location=map_location)
    return obj["image_embeds"], obj["text_embeds"], obj["meta"], obj


def kmeans_labels(x: torch.Tensor, k: int, seed: int = 42):
    """Cluster the rows of ``x`` with scikit-learn ``KMeans`` and return labels.

    Args:
        x: Feature matrix; it is copied to CPU and converted to NumPy for
            the fit.
        k: Number of clusters.
        seed: ``random_state`` handed to ``KMeans``.

    Returns:
        A tensor with one cluster id per row of ``x``, placed on ``x``'s
        device.
    """
    model = KMeans(n_clusters=k, random_state=seed, n_init=10)
    assignments = model.fit_predict(x.cpu().numpy())
    return torch.tensor(assignments, device=x.device)

def centers(feats: torch.Tensor, labels: torch.Tensor, k: int):
    """Compute one L2-normalised centre per cluster id in ``range(k)``.

    Args:
        feats: Feature matrix, one sample per row.
        labels: Cluster id of each row of ``feats``.
        k: Number of clusters; ids ``0 .. k-1`` are looked up.

    Returns:
        A ``(k, feats.shape[1])`` tensor. Row ``i`` is the mean of the rows
        labelled ``i`` passed through ``l2_norm``; a cluster with no members
        yields an all-zero row.
    """
    d = feats.shape[1]
    # new_zeros inherits dtype *and* device from feats, so the stacked result
    # keeps feats' dtype even when a cluster is empty (a plain torch.zeros
    # would be float32 regardless of the feature dtype).
    empty_center = feats.new_zeros(d)
    cs = []
    for i in range(k):
        members = feats[labels == i]
        cs.append(members.mean(0) if len(members) else empty_center)
    return l2_norm(torch.stack(cs))

def hungarian(src: torch.Tensor, tgt: torch.Tensor):
    """Match rows of ``src`` to rows of ``tgt`` with the Hungarian algorithm.

    The cost of pairing two rows is ``1 - <src_row, tgt_row>`` (a cosine
    distance if both inputs are already unit-normalised -- not checked here).

    Returns:
        The column indices produced by ``linear_sum_assignment`` as a tensor
        on ``src``'s device.
    """
    similarity = torch.matmul(src, tgt.T)
    cost_matrix = (1 - similarity).cpu().numpy()
    _, tgt_for_src = linear_sum_assignment(cost_matrix)
    return torch.tensor(tgt_for_src, device=src.device)

def filter_external(
    coco_feat: torch.Tensor,
    cc3m_feat: torch.Tensor,
    k_in: int = 100,
    k_out: int = 100,
    tau_c: float = 0.8,
    tau_v: float = 0.7,
    seed: int = 42,
):
    """Keep the external samples that sit close to the internal cluster centres.

    Both feature sets are clustered with KMeans. An external cluster is
    considered only when its centre has similarity of at least ``tau_c`` to
    some internal centre; inside such a cluster, a sample is kept when its
    best similarity to an internal centre is at least ``tau_v``.

    Args:
        coco_feat: Internal features, one sample per row.
        cc3m_feat: External features, one sample per row.
        k_in: Number of internal clusters.
        k_out: Number of external clusters.
        tau_c: Centre-level similarity threshold.
        tau_v: Sample-level similarity threshold.
        seed: Random state for both KMeans runs.

    Returns:
        ``(selected, keep_idx, pseudo_labels, weights, co_lbl, co_ctr)``:
        the kept rows of ``cc3m_feat``, their row indices (grouped by external
        cluster id, not globally sorted), the index of the best internal
        centre for each kept row, that best similarity, the internal cluster
        labels and the internal centres. When no sample passes the
        thresholds the first four entries are empty tensors.
    """
    co_lbl = kmeans_labels(coco_feat, k_in, seed)
    cc_lbl = kmeans_labels(cc3m_feat, k_out, seed)
    co_ctr = centers(coco_feat, co_lbl, k_in)
    cc_ctr = centers(cc3m_feat, cc_lbl, k_out)

    keep_idx, pseudo, weights = [], [], []
    sim_mat = cc_ctr @ co_ctr.T
    for cid in range(k_out):
        # Skip external clusters whose centre matches no internal centre.
        if sim_mat[cid].max() < tau_c:
            continue
        vec_idx = (cc_lbl == cid).nonzero(as_tuple=False).squeeze(1)
        if vec_idx.numel() == 0:
            continue
        vec = l2_norm(cc3m_feat[vec_idx])
        sim = vec @ co_ctr.T
        mval, midx = sim.max(1)
        mask = mval >= tau_v
        if mask.any():
            keep_idx.append(vec_idx[mask])
            pseudo.append(midx[mask])
            weights.append(mval[mask])

    if keep_idx:
        keep_idx = torch.cat(keep_idx)
        pseudo_labels = torch.cat(pseudo)
        weights = torch.cat(weights)
    else:
        # torch.cat rejects an empty list; return well-typed empty results so
        # "nothing survived the filter" is a value, not an opaque crash.
        ext_device = cc3m_feat.device
        keep_idx = torch.empty(0, dtype=torch.long, device=ext_device)
        pseudo_labels = torch.empty(0, dtype=torch.long, device=ext_device)
        weights = cc3m_feat.new_empty(0)
    return cc3m_feat[keep_idx], keep_idx, pseudo_labels, weights, co_lbl, co_ctr

def _knn_indices(feats, start, stop, k):
    """Return the ``k`` most similar row indices for ``feats[start:stop]``, excluding self."""
    sim = torch.matmul(feats[start:stop], feats.T)          # (b, N)
    rows = torch.arange(stop - start, device=feats.device)
    cols = torch.arange(start, stop, device=feats.device)
    # -inf (not -1.0): -1.0 is a reachable cosine value, so an antipodal
    # neighbour could otherwise tie with, or lose to, the masked self entry.
    sim[rows, cols] = float("-inf")
    return sim.topk(k, dim=1).indices                       # (b, k)


def knn_overlap(orig, mapped, k=10, chunk=4096):
    """Measure how well cosine k-nearest-neighbour sets survive a mapping.

    Rows of ``orig`` and ``mapped`` are L2-normalised; for each row the ``k``
    most similar other rows are found in each space, and the fraction of the
    original neighbours that also appear among the mapped neighbours is
    averaged over all rows.

    Args:
        orig: Features before mapping, one sample per row.
        mapped: Features after mapping; assumed to have the same number of
            rows as ``orig``, in the same order.
        k: Neighbours per row. Clamped to ``N - 1`` because a row is never
            its own neighbour.
        chunk: Number of query rows processed per similarity block.

    Returns:
        The mean overlap as a Python float in ``[0, 1]``; ``0.0`` when there
        are fewer than two rows or ``k`` is not positive.
    """
    orig = F.normalize(orig, dim=1)
    mapped = F.normalize(mapped, dim=1)
    N = orig.size(0)

    # topk raises when k exceeds the row length, and there are only N - 1
    # candidates once self is excluded.
    k = min(k, N - 1)
    if k <= 0:
        return 0.0

    overlap_sum = 0.0
    for s in range(0, N, chunk):
        e = min(s + chunk, N)
        knn_o = _knn_indices(orig, s, e, k)
        knn_m = _knn_indices(mapped, s, e, k)

        shared = (knn_o.unsqueeze(2) == knn_m.unsqueeze(1))   # (b, k, k)
        per_row = shared.any(2).float().mean(1)               # (b,)
        overlap_sum += per_row.sum().item()

    return overlap_sum / N
def l2_normalize(x: torch.Tensor) -> torch.Tensor:
    """Divide each row of ``x`` by its norm along ``dim=1``.

    The norm is floored at ``1e-9`` first so zero rows do not produce NaN.
    """
    norms = torch.norm(x, dim=1, keepdim=True)
    safe_norms = torch.clamp(norms, min=1e-9)
    return torch.div(x, safe_norms)

def whitening(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, float]:
    """Centre ``x`` on its column mean and divide by the mean row norm.

    Returns:
        ``(x_scaled, mu, scale)``: the centred-and-scaled features, the
        column mean that was subtracted, and the scalar divisor as a Python
        float.
    """
    mean_vec = x.mean(dim=0)
    centered = x - mean_vec
    avg_norm = centered.norm(dim=1).mean()
    return centered / avg_norm, mean_vec, avg_norm.item()

def csls_torch(x_src: torch.Tensor, x_tgt: torch.Tensor, k: int = 10) -> torch.Tensor:
    """Compute CSLS-adjusted similarities between source and target rows.

    For the raw similarity ``sim = x_src @ x_tgt.T`` the result is
    ``2 * sim - r_src - r_tgt``, where ``r_src`` is each source row's mean
    similarity to its ``k`` best targets and ``r_tgt`` is each target's mean
    similarity to its ``k`` best sources.

    Args:
        x_src: Source features, shape ``[Ns, d]``.
        x_tgt: Target features, shape ``[Nt, d]``.
        k: Neighbourhood size. It is clamped to the number of available
            candidates on each side, since ``topk`` raises when asked for
            more entries than exist.

    Returns:
        A ``[Ns, Nt]`` tensor. If either side is empty the (empty) raw
        similarity matrix is returned as is.
    """
    sim = x_src @ x_tgt.T                                   # [Ns, Nt]
    n_src, n_tgt = sim.shape
    if n_src == 0 or n_tgt == 0:
        return sim

    k_src = min(k, n_tgt)   # each source row has only Nt candidates
    k_tgt = min(k, n_src)   # each target column has only Ns candidates
    r_src = sim.topk(k_src, dim=1).values.mean(1, keepdim=True)   # [Ns, 1]
    r_tgt = sim.topk(k_tgt, dim=0).values.mean(0, keepdim=True)   # [1, Nt]
    return 2 * sim - r_src - r_tgt                          # [Ns, Nt]

def build_mnn(csls_sim: torch.Tensor, topk: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
    """Extract mutual nearest-neighbour pairs from a similarity matrix.

    A source row ``i`` and target column ``j`` form a pair when ``j`` is the
    argmax of row ``i`` and ``i`` is the argmax of column ``j``.

    Args:
        csls_sim: ``[Ns, Nt]`` similarity matrix.
        topk: Currently unused; only the single best match is considered.

    Returns:
        ``(src_idx, tgt_idx)``: aligned index tensors of the mutual pairs.
    """
    best_tgt_per_src = csls_sim.argmax(dim=1)
    best_src_per_tgt = csls_sim.argmax(dim=0)
    src_ids = torch.arange(csls_sim.size(0), device=csls_sim.device)
    # A source is kept when its best target points straight back at it.
    is_mutual = best_src_per_tgt[best_tgt_per_src].eq(src_ids)
    return src_ids[is_mutual], best_tgt_per_src[is_mutual]

@dataclass
class AlignParams:
    """Bundle of the statistics and linear map used to align two feature sets.

    ``save`` writes the fields as a plain dict (tensors moved to CPU) and
    ``load`` rebuilds an instance from such a file.
    """

    # Value subtracted from external features before scaling.
    mu_ext: torch.Tensor
    # Value subtracted from internal features before scaling.
    mu_int: torch.Tensor
    # Scalar divisor applied to centred external features.
    scale_ext: float
    # Scalar divisor applied to centred internal features.
    scale_int: float
    # Matrix applied (transposed, on the right) to processed external features.
    W: torch.Tensor

    def save(self, path):
        """Write all fields to ``path`` with ``torch.save``; tensors go to CPU first."""
        state = {
            "mu_ext": self.mu_ext.cpu(),
            "mu_int": self.mu_int.cpu(),
            "scale_ext": self.scale_ext,
            "scale_int": self.scale_int,
            "W": self.W.cpu(),
        }
        torch.save(state, str(path))

    @staticmethod
    def load(path: str) -> "AlignParams":
        """Read a file written by ``save`` and return the parameters on CPU."""
        state = torch.load(str(path), map_location="cpu")
        return AlignParams(**state)

def whitening_norm(x):
    """Centre ``x``, divide by the mean row norm, then L2-normalise each row.

    Returns:
        ``(x_unit, mu, scale)``: the processed features, the column mean that
        was subtracted, and the scalar divisor as a Python float.
    """
    center = x.mean(0)
    shifted = x - center
    mean_norm = shifted.norm(dim=1).mean()
    unit_rows = l2_normalize(shifted / mean_norm)
    return unit_rows, center, mean_norm.item()

@torch.no_grad()
def fit_unsupervised(
    X_ext, X_int,
    n_iter=10, k_csls=50, min_pairs=500,
):
    """Fit an orthogonal map from external to internal features without labels.

    Each round maps the external features with the current ``W``, scores
    them against the internal features with CSLS, keeps the mutual
    nearest-neighbour pairs, and refits ``W`` on those pairs via SVD
    (orthogonal Procrustes).

    Args:
        X_ext: External features, one sample per row.
        X_int: Internal features with the same feature dimension.
        n_iter: Maximum number of refinement rounds.
        k_csls: Neighbourhood size passed to ``csls_torch``.
        min_pairs: After the first round, stop refining once fewer mutual
            pairs than this are found.

    Returns:
        ``AlignParams`` holding the centring/scaling statistics of both sets
        and the final ``W``.
    """
    # whitening_norm centres, rescales and L2-normalises the rows.
    X_ext_w, mu_e, sc_e = whitening_norm(X_ext)
    X_int_w, mu_i, sc_i = whitening_norm(X_int)
    d = X_ext.size(1)
    # Match the processed features' dtype so the matmul below cannot hit a
    # dtype mismatch when the inputs are not float32.
    W = torch.eye(d, device=X_ext.device, dtype=X_ext_w.dtype)

    for it in range(n_iter):
        X_ext_m = X_ext_w @ W.T
        sim_csls = csls_torch(X_ext_m, X_int_w, k=k_csls)
        idx_e, idx_i = build_mnn(sim_csls)
        print(f"[iter {it}] MNN pairs: {len(idx_e)}")

        n_pairs = len(idx_e)
        if n_pairs < min_pairs and it > 0:
            # Too few anchors for a trustworthy refit: keep the previous W
            # instead of overwriting it (the original only printed here).
            print("pairs is too low…")
            break
        if n_pairs == 0:
            # With no pairs the cross-covariance is all zeros and its SVD
            # carries no information; leave W as it is.
            break

        # Orthogonal Procrustes on the matched pairs.
        M = X_int_w[idx_i].T @ X_ext_w[idx_e]
        U, _, Vh = torch.linalg.svd(M)
        W_new = U @ Vh
        delta = (W_new - W).abs().mean().item()
        W = W_new
        print(f"ΔW={delta:.6f}")
        if delta < 1e-6:
            break

    return AlignParams(mu_e, mu_i, sc_e, sc_i, W)


def map_external(v_ext: torch.Tensor, p: AlignParams) -> torch.Tensor:
    """Apply the external-side preprocessing and the learned map ``p.W``.

    The features are shifted by ``p.mu_ext``, divided by ``p.scale_ext`` and
    right-multiplied by ``p.W.T``.
    """
    centered = v_ext - p.mu_ext
    scaled = centered / p.scale_ext
    return torch.matmul(scaled, p.W.T)

def preprocess_internal(v_int: torch.Tensor, p: AlignParams) -> torch.Tensor:
    """Shift internal features by ``p.mu_int`` and divide by ``p.scale_int``."""
    centered = v_int - p.mu_int
    return centered / p.scale_int

def _ensure_parent_dir(path: str) -> None:
    """Create the directory that will contain ``path``, if the path names one."""
    parent = os.path.dirname(path)
    # dirname is "" for a bare filename; os.makedirs("") raises, and the
    # current directory needs no creation.
    if parent:
        os.makedirs(parent, exist_ok=True)


def run_easy_alignment(
    coco_path: str,
    cc3m_path: str,
    dataset_name: str = "MScoco",
    ratio: float = 0.35,
    k_in: int = 128,
    k_out: int = 512,
    tau_c: float = 0.7,
    tau_v: float = 0.7,
    n_iter: int = 5,
    k_csls: int = 50,
    min_pairs: int = 500,
    ckpt_dir: str = "./checkpoints/align_net",
    out_internal_easy: str = "./coco_data/internal_kb_easy.pt",
    out_external_easy: str = "./cc3m_kb/cc3m_kb_easy_MScoco.pt",
    device: str = None,
) -> Tuple[str, str]:
    """Filter the external embeddings and align them to the internal ones.

    Pipeline, run once for images and then for texts:
      1. ``filter_external`` selects external samples near internal clusters
         (the text pass runs on the samples that survived the image pass).
      2. ``fit_unsupervised`` fits alignment parameters on the selection and
         the parameters are saved under ``ckpt_dir/dataset_name/ratio/``.
      3. Internal features are preprocessed and external features mapped
         with those parameters.

    Args:
        coco_path: Internal embeddings file (read via ``load_embeddings``,
            which expects ``image_embeds``, ``text_embeds`` and ``meta``).
        cc3m_path: External embeddings file with the same layout.
        dataset_name: Used only to build the checkpoint sub-directory.
        ratio: Used only to build the checkpoint sub-directory.
        k_in, k_out, tau_c, tau_v: Forwarded to ``filter_external``.
        n_iter, k_csls, min_pairs: Forwarded to ``fit_unsupervised``.
        ckpt_dir: Root directory for the saved alignment parameters.
        out_internal_easy: Output file for the processed internal bundle.
        out_external_easy: Output file for the aligned external bundle.
        device: Torch device string; defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.

    Returns:
        ``(out_internal_easy, out_external_easy)``, the two paths written.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)
    _ensure_parent_dir(out_internal_easy)
    _ensure_parent_dir(out_external_easy)

    coco_img, coco_txt, _, data_coco = load_embeddings(coco_path)
    cc3m_img, cc3m_txt, cc3m_meta, _ = load_embeddings(cc3m_path)
    coco_img = coco_img.to(device)
    coco_txt = coco_txt.to(device)
    cc3m_img = cc3m_img.to(device)
    cc3m_txt = cc3m_txt.to(device)

    # ---- image pass -------------------------------------------------------
    cc_sel_img, keep_idx_img, *_ = filter_external(
        coco_img, cc3m_img, k_in=k_in, k_out=k_out, tau_c=tau_c, tau_v=tau_v
    )

    # Keep image, text and meta rows in lock-step for the surviving samples.
    cc3m_img_sel = cc3m_img[keep_idx_img].detach().cpu()
    cc3m_txt_sel = cc3m_txt[keep_idx_img].detach().cpu()
    meta_sel = [cc3m_meta[i] for i in keep_idx_img.cpu().tolist()]

    params_img = fit_unsupervised(
        X_ext=cc_sel_img.float(),
        X_int=coco_img.float(),
        n_iter=n_iter, k_csls=k_csls, min_pairs=min_pairs,
    )

    ckpt_img = os.path.join(ckpt_dir, f"{dataset_name}/{ratio}/unsup_align_params_img.pt")
    os.makedirs(os.path.dirname(ckpt_img), exist_ok=True)
    params_img.save(ckpt_img)

    coco_img_proc = preprocess_internal(coco_img, params_img).cpu()
    data_coco["image_embeds"] = coco_img_proc

    cc3m_img_aligned = map_external(cc3m_img_sel.to(device), params_img).cpu()

    # ---- text pass (on the image-filtered subset) ---------------------------
    coco_txt_feat = coco_txt
    cc3m_txt_feat = cc3m_txt_sel.to(device)

    cc_sel_txt, keep_idx_txt, *_ = filter_external(
        coco_txt_feat, cc3m_txt_feat,
        k_in=k_in, k_out=k_out, tau_c=tau_c, tau_v=tau_v
    )

    # keep_idx_txt indexes into the image-filtered subset, so it is applied
    # to the already-subset tensors and meta list.
    keep_txt_cpu = keep_idx_txt.cpu()
    cc3m_img_aligned = cc3m_img_aligned[keep_txt_cpu]
    cc3m_txt_sel = cc3m_txt_sel[keep_txt_cpu]
    meta_sel = [meta_sel[i] for i in keep_txt_cpu.tolist()]

    params_txt = fit_unsupervised(
        X_ext=cc_sel_txt.float(),
        X_int=coco_txt_feat.float(),
        n_iter=n_iter, k_csls=k_csls, min_pairs=min_pairs,
    )
    ckpt_txt = os.path.join(ckpt_dir, f"{dataset_name}/{ratio}/unsup_align_params_txt.pt")
    os.makedirs(os.path.dirname(ckpt_txt), exist_ok=True)
    params_txt.save(ckpt_txt)

    coco_txt_proc = preprocess_internal(coco_txt_feat, params_txt).cpu()
    data_coco["text_embeds"] = coco_txt_proc

    cc3m_txt_aligned = map_external(cc3m_txt_sel.to(device), params_txt).cpu()

    # ---- write outputs --------------------------------------------------
    torch.save(
        {"image_embeds": cc3m_img_aligned,
         "text_embeds":  cc3m_txt_aligned,
         "meta":         meta_sel},
        out_external_easy
    )
    torch.save(data_coco, out_internal_easy)

    print(f"internal_easy -> {out_internal_easy}\n  - external_easy -> {out_external_easy}\n"
          f"{ckpt_img}\n  - {ckpt_txt}")

    return out_internal_easy, out_external_easy
