from __future__ import annotations
import argparse, json, re
from pathlib import Path
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.datasets as dsets
from torch.utils.data import DataLoader, random_split
import torchvision.models as models
import timm
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from scipy.spatial.distance import pdist, squareform
from scipy.stats import spearmanr, pearsonr
from dataloader.cub2011 import Cub2011                

from datasets import load_dataset        
from functools import partial         
from sklearn.metrics import silhouette_score

import umap
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches     
from itertools import cycle, islice
from matplotlib import cm

# Run-wide configuration.
BATCH_SIZE = 256                       # default batch size used by make_loader
SEED = 42                              # torch global seed and dataset-split seeds
BETA_INIT = 0.5                        # NOTE(review): not referenced in the visible code
CKPT_REGEX = re.compile(r".*\.pth$")   # a checkpoint is any file name ending in ".pth"
torch.manual_seed(SEED)                # seed torch's global RNG at import time


def strip_prefix(name: str) -> str:
    """Return *name* without a leading three-digit ``"NNN."`` prefix.

    ``"001.Black_footed_Albatross"`` -> ``"Black_footed_Albatross"``. Any name
    that does not start with exactly three digits followed by a dot is
    returned unchanged, including names shorter than four characters.
    """
    # Slice (name[3:4]) instead of indexing (name[3]) so short digit-only
    # names such as "12" or "123" don't raise IndexError.
    if name[:3].isdigit() and name[3:4] == ".":
        return name[4:]
    return name

from functools import partial

def collate_tiny_imagenet(batch, transform):
    """Collate Hugging Face Tiny-ImageNet rows into an ``(images, labels)`` pair.

    Each row must provide ``row["image"]`` (an object with ``.convert("RGB")``)
    and ``row["label"]``. *transform* turns the RGB image into a tensor; the
    transformed images are stacked along a new leading batch dimension and
    the labels become a 1-D tensor. It is a module-level function (bound with
    ``functools.partial`` by the caller) so it stays picklable.
    """
    pixels, targets = [], []
    for row in batch:
        pixels.append(transform(row["image"].convert("RGB")))
        targets.append(row["label"])
    return torch.stack(pixels), torch.tensor(targets)


def make_loader(dataset: str,
                data_root: Path,
                split: str = "test",
                batch_size: int = BATCH_SIZE,
                device=None):
    """Build a non-shuffling ``DataLoader`` and the list of fine class names.

    Args:
        dataset: ``"cifar"`` (CIFAR-100), ``"cub"`` (CUB-200-2011 via
            ``Cub2011``) or ``"imagenet"`` (Hugging Face
            ``zh-plus/tiny-imagenet``); case-insensitive.
        data_root: Data directory. CIFAR is stored under
            ``data_root / "cifar"``; CUB directly under ``data_root``.
            Tiny-ImageNet uses the Hugging Face cache.
        split: ``"train"``, ``"val"`` or ``"test"`` (anything else acts as
            ``"test"``).
            * CIFAR always returns the official test set, whatever *split* is.
            * CUB ``"val"`` is a ``SEED``-seeded 20 % subset of the test set.
            * Tiny-ImageNet ``"val"`` / ``"test"`` are the 80 % / 20 % sides
              of a ``SEED``-seeded ``train_test_split`` of HF's ``valid``
              split.
        batch_size: Loader batch size.
        device: Optional target device (str or ``torch.device``). It is only
            used to decide ``pin_memory`` for the Tiny-ImageNet loader. When
            ``None``, memory is pinned whenever CUDA is available.

    Returns:
        ``(loader, classes)``, where ``classes`` is the list of fine class
        names.

    Raises:
        ValueError: If *dataset* is not recognised.
    """
    ds = dataset.lower()

    if ds == "cifar":
        # ToTensor only: no normalisation for CIFAR. *split* is not consulted.
        tfm = T.Compose([T.ToTensor()])
        cifar = dsets.CIFAR100(data_root / "cifar",
                               train=False, download=True, transform=tfm)
        loader = DataLoader(cifar, batch_size, shuffle=False)
        classes = cifar.classes

    elif ds == "cub":
        imagenet_norm = T.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
        tfm_train = T.Compose([
            T.Resize(256), T.CenterCrop(224),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            imagenet_norm,
        ])
        tfm_eval = T.Compose([
            T.Resize(256), T.CenterCrop(224),
            T.ToTensor(),
            imagenet_norm,
        ])

        # Honour data_root (the CLI default "./data" matches the previous
        # hard-coded location). Passed as str in case Cub2011 expects one.
        cub_root = str(data_root)
        train_set = Cub2011(cub_root, train=True,
                            download=True, transform=tfm_train)
        test_set = Cub2011(cub_root, train=False,
                           download=True, transform=tfm_eval)

        if split == "train":
            subset = train_set
        elif split == "val":
            # Deterministic 20 % slice of the test set.
            vlen = int(0.2 * len(test_set))
            subset, _ = random_split(test_set, [vlen, len(test_set) - vlen],
                                     generator=torch.Generator().manual_seed(SEED))
        else:
            subset = test_set

        loader = DataLoader(subset, batch_size, shuffle=False)
        # Class names are always read from the training split.
        classes = [strip_prefix(c) for c in train_set.class_names]

    elif ds == "imagenet":
        tiny_norm = T.Normalize(mean=[0.480, 0.448, 0.398],
                                std=[0.277, 0.269, 0.282])
        train_tfm = T.Compose([
            T.RandomResizedCrop(64),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            tiny_norm,
        ])
        eval_tfm = T.Compose([
            T.Resize(64), T.CenterCrop(64),
            T.ToTensor(),
            tiny_norm,
        ])

        raw = load_dataset("zh-plus/tiny-imagenet")
        if split == "train":
            hf_ds, active_tfm = raw["train"], train_tfm
        else:  # "val" -> 80 % side, "test" -> 20 % side of HF's "valid" split
            vt = raw["valid"].train_test_split(test_size=0.2, seed=SEED)
            hf_ds = vt["train" if split == "val" else "test"]
            active_tfm = eval_tfm
        hf_ds = hf_ds.with_format("python", columns=["image", "label"])

        # Decide pinning from the explicit *device* argument instead of a
        # script-only global, so the function also works when imported.
        cuda_ok = torch.cuda.is_available()
        if device is None:
            pin = cuda_ok
        else:
            pin = cuda_ok and torch.device(device).type == "cuda"

        loader = DataLoader(
            hf_ds,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=pin,
            # partial over a module-level function keeps collate_fn picklable.
            collate_fn=partial(collate_tiny_imagenet, transform=active_tfm),
        )

        # NOTE(review): class order is the key order of this JSON; it is
        # assumed to match the dataset's integer labels — confirm.
        with open("fine_to_coarse_imagenet.json") as f:
            fine_to_coarse = json.load(f)
        classes = list(fine_to_coarse.keys())

    else:
        raise ValueError(f"Unknown dataset '{dataset}'")

    return loader, classes


class ResNet50Encoder(nn.Module):
    """Randomly initialised torchvision ResNet-50 that returns ``(logits, features)``.

    The stock ``fc`` layer of the trunk becomes a ``feat_dim`` projection, so
    ``backbone(x)`` yields the latent vector. A separate linear layer ``fc``
    maps that latent to ``num_classes`` logits. With ``cifar_stem=True`` the
    first convolution is replaced by a 3x3, stride-1 conv and the max-pool is
    removed.

    ``alpha`` (one learnable weight per entry of *losses*) and ``_smx`` are
    registered on the module but are not used by ``forward``.
    """

    def __init__(self,
                 losses: List[str],
                 num_classes: int,
                 feat_dim: int = 384,
                 cifar_stem: bool = False):
        super().__init__()
        self.selected = losses

        trunk = models.resnet50(weights=None)
        if cifar_stem:
            trunk.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1,
                                    padding=1, bias=False)
            trunk.maxpool = nn.Identity()
        trunk.fc = nn.Linear(trunk.fc.in_features, feat_dim)

        self.backbone = trunk
        self.fc = nn.Linear(feat_dim, num_classes)
        # NOTE(review): presumably consumed by training code / kept so
        # checkpoints containing "alpha" match this module's state dict.
        self.alpha = nn.Parameter(torch.ones(len(losses)))
        self._smx = nn.Softmax(dim=0)

    def forward(self, x):
        features = self.backbone(x)
        logits = self.fc(features)
        return logits, features


class ConvNeXtEncoder(nn.Module):
    """timm ConvNeXt-Tiny trunk (pretrained, no classifier) + projection + classifier.

    ``forward`` returns ``(logits, latent)``, where ``latent`` is the trunk
    output projected to ``feat_dim`` by ``head``.
    """

    def __init__(self, num_classes: int, feat_dim: int = 384):
        super().__init__()
        trunk = timm.create_model("convnext_tiny", pretrained=True, num_classes=0)
        self.backbone = trunk
        self.head = nn.Linear(trunk.num_features, feat_dim)
        self.fc = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        pooled = self.backbone(x)
        latent = self.head(pooled)
        logits = self.fc(latent)
        return logits, latent


class ViTEncoder(nn.Module):
    """Pretrained timm ViT trunk (created with ``num_classes=0``) + linear classifier.

    The trunk output is both the latent returned by ``forward`` and the
    classifier input. *img_size* is forwarded to ``timm.create_model``.
    """

    def __init__(self, num_classes: int,
                 model_name: str = "vit_small_patch16_224",
                 img_size: int = 64):
        super().__init__()
        self.backbone = timm.create_model(
            model_name,
            pretrained=True,
            img_size=img_size,
            num_classes=0,
        )
        width = self.backbone.num_features
        self.fc = nn.Linear(width, num_classes)

    def forward(self, x):
        latent = self.backbone(x)
        logits = self.fc(latent)
        return logits, latent



def build_model(tag: str, losses: List[str], n_cls: int,
                ckpt: Path, device: torch.device) -> nn.Module:
    """Rebuild the architecture named by *tag* and load the checkpoint weights.

    The family is chosen by substring match on the lower-cased *tag*
    (``"resnet"``, ``"convnext"`` or ``"vit"``, checked in that order).
    Constructor details are inferred from the checkpoint:

    * ResNet: a kernel height of 3 in ``backbone.conv1.weight`` selects the
      CIFAR (3x3) stem.
    * ViT: the input size is recovered from the number of positional
      embeddings, assuming one extra (class) token and 16-pixel patches.
      It falls back to 224 when the remaining count is not a perfect square.

    Args:
        tag: Model family name, e.g. ``"resnet50"``.
        losses: Loss names (only used by the ResNet encoder).
        n_cls: Number of fine classes.
        ckpt: File written by ``torch.save``. It may hold a dict with a
            ``"model_state_dict"`` entry, or a bare state dict.
        device: Device the model is moved to.

    Returns:
        The model, in eval mode, on *device*.

    Raises:
        NotImplementedError: If *tag* matches no known family.
    """
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    blob = torch.load(ckpt, map_location="cpu")
    if isinstance(blob, dict) and "model_state_dict" in blob:
        state = blob["model_state_dict"]
    else:
        state = blob

    family = tag.lower()
    if "resnet" in family:
        kernel_h = state["backbone.conv1.weight"].shape[2]
        net = ResNet50Encoder(losses, n_cls, cifar_stem=(kernel_h == 3))

    elif "convnext" in family:
        net = ConvNeXtEncoder(n_cls)

    elif "vit" in family:
        tokens = state["backbone.pos_embed"].shape[1]
        grid = int(round((tokens - 1) ** 0.5))  # patch grid side
        img_size = grid * 16 if grid * grid == tokens - 1 else 224
        net = ViTEncoder(n_cls, img_size=img_size)

    else:
        raise NotImplementedError(tag)

    # strict=False tolerates partial checkpoints. Report mismatches so that a
    # wrong key prefix does not silently leave the network at its initial
    # weights.
    result = net.load_state_dict(state, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(f"⚠  {Path(ckpt).name}: {len(result.missing_keys)} missing, "
              f"{len(result.unexpected_keys)} unexpected state-dict keys")
    return net.to(device).eval()



def nn_hit_rate(lat, pred, lut):
    """Fraction of predicted-class centroids whose nearest other centroid shares their coarse class.

    Latents are grouped by predicted fine label and averaged into one
    centroid per label. For every centroid, the closest *other* centroid
    (Euclidean) is found. A hit is counted when both map to the same coarse
    id through *lut*.

    Args:
        lat: ``(N, D)`` latent vectors.
        pred: ``(N,)`` predicted fine-class ids (tensor or sequence of ints).
        lut: Fine-id -> coarse-id lookup (tensor or sequence). It must be
            indexable by every value in *pred*.

    Returns:
        The hit rate as a Python ``float``, or ``nan`` when fewer than two
        distinct predicted classes are present.
    """
    pred = torch.as_tensor(pred, device=lat.device).long().reshape(-1)
    classes, member_of = torch.unique(pred, return_inverse=True)
    if classes.numel() < 2:
        return np.nan

    # Per-class centroids via scatter-add instead of a Python loop over samples.
    sums = torch.zeros(classes.numel(), lat.size(1),
                       dtype=lat.dtype, device=lat.device)
    sums.index_add_(0, member_of, lat)
    counts = torch.bincount(member_of, minlength=classes.numel()).to(lat.dtype)
    means = sums / counts.unsqueeze(1)

    lut = torch.as_tensor(lut)
    coarse = lut[classes.to(lut.device)].to(means.device)

    dist = torch.cdist(means, means)
    dist.fill_diagonal_(float("inf"))  # exclude each centroid from its own search
    nearest = dist.argmin(dim=1)
    return float((coarse == coarse[nearest]).float().mean())


def corr_metrics(lat_cent, sem_cent):
    """Correlate pairwise centroid distances in latent space and text-embedding space.

    *lat_cent* and *sem_cent* are ``(K, D_*)`` tensors whose rows correspond:
    row *i* of each describes the same fine class. Each row is scaled to unit
    L2 norm. All ``K*(K-1)/2`` pairwise Euclidean distances are computed in
    each space, and the two distance vectors are correlated.

    Returns:
        ``{"spearman": rho, "pearson": r}``. Both are ``nan`` when K < 3,
        because fewer than two distance pairs cannot be correlated (SciPy's
        ``pearsonr`` raises on length-1 input).
    """
    if lat_cent.size(0) < 3:
        return {"spearman": np.nan, "pearson": np.nan}
    # dim=1 -> per-row unit L2 norm. Note that F.normalize(x, 1) would set
    # p=1 (L1 normalisation), because p is the second positional argument.
    lat_unit = F.normalize(lat_cent.detach(), dim=1).cpu().numpy()
    sem_unit = F.normalize(sem_cent.detach(), dim=1).cpu().numpy()
    # pdist's condensed output already lists the (i < j) pairs in row-major
    # order, so the two vectors are aligned pair-for-pair.
    lat_d = pdist(lat_unit, "euclidean")
    sem_d = pdist(sem_unit, "euclidean")
    return {"spearman": spearmanr(lat_d, sem_d)[0],
            "pearson": pearsonr(lat_d, sem_d)[0]}

def sil_score(latents: torch.Tensor, labels: torch.Tensor) -> float:
    """
    Euclidean silhouette score of *latents* grouped by *labels*.

    Used with latent vectors grouped by *predicted* fine labels.
    Returns NaN when the score is undefined:
    * fewer than 2 distinct labels, or
    * every sample has its own label.
    scikit-learn requires ``2 <= n_labels <= n_samples - 1`` and raises
    otherwise.
    """
    n_labels = torch.unique(labels).numel()
    if n_labels < 2 or n_labels >= labels.numel():
        return float("nan")
    # scikit-learn expects host-side numpy arrays.
    return float(silhouette_score(latents.detach().cpu().numpy(),
                                  labels.detach().cpu().numpy(),
                                  metric="euclidean"))




# Apply matplotlib's built-in "default" style globally when this module is
# imported. This resets any style previously applied in the process, so the
# figures below render with default rcParams.
plt.style.use("default")  

def _default_palette(n: int) -> List[tuple]:
    """Return *n* RGBA colour tuples sampled evenly from matplotlib's ``tab20`` colormap.

    ``tab20`` has 20 colours, so entries repeat when *n* > 20. Returns an
    empty list when *n* <= 0.
    """
    if n <= 0:
        return []
    # pyplot.get_cmap(name, lut) replaces matplotlib.cm.get_cmap, which was
    # deprecated in Matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap("tab20", n)
    return [cmap(i) for i in range(n)]


def generate_legend_image(class_labels, output_path=Path("legend_only.png"), ncol: int = 4):
    """Save a stand-alone legend that maps each class label to its palette colour.

    Colours come from ``_default_palette(len(class_labels))`` in list order.
    They therefore match plots that colour class *i* with ``palette[i]`` from
    a palette of the same size.

    Args:
        class_labels: Labels, in palette order.
        output_path: Destination image (saved at 300 dpi with a tight bounding box).
        ncol: Number of legend columns.
    """
    palette = _default_palette(len(class_labels))

    fig, ax = plt.subplots(figsize=(6, 6))
    fig.patch.set_visible(False)  # hide the figure background patch
    ax.axis('off')

    handles = [mpatches.Patch(color=colour, label=label)
               for colour, label in zip(palette, class_labels)]
    ax.legend(handles=handles, loc='center', ncol=ncol,
              frameon=True, framealpha=1.0, fontsize=8,
              edgecolor='black')

    fig.savefig(output_path, dpi=300, bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    print(f"Legend saved to {output_path}")



def plot_umap(lat: torch.Tensor,
              fine_labels: torch.Tensor,
              coarse_labels: torch.Tensor,
              combo: str,
              dataset: str,
              model: str,
              coarse_names: List[str],
              outdir: Path = Path("."),
              random_state: int = 42,
              frame: bool = False):
    """Project latents to 2-D with UMAP and save fine- and coarse-level plots.

    Writes two PNGs into *outdir*:

    * ``{dataset}_{model}_{combo}_fine.png`` — one dot per fine label, placed
      at the mean UMAP position of that label's samples.
    * ``{dataset}_{model}_{combo}_coarse.png`` — every sample (faint) plus
      the fine-label centroids (outlined), both coloured by coarse label.

    Also calls ``generate_legend_image(coarse_names)`` with its default
    output path.

    Args:
        lat: ``(N, D)`` latent vectors (any device).
        fine_labels: ``(N,)`` fine-label ids used to form centroids.
        coarse_labels: ``(N,)`` coarse ids in ``range(len(coarse_names))``.
            Each fine label's centroid takes the smallest coarse id seen
            among its samples.
        combo, dataset, model: Only used to build the output file names.
        coarse_names: Coarse class names; their count sets the coarse palette.
        outdir: Output directory.
        random_state: UMAP seed.
        frame: If True, keep the axes on (ticks removed) and draw a thin
            black frame. The default hides the axes entirely, frame included.
    """
    reducer = umap.UMAP(n_components=2,
                        metric="euclidean",
                        random_state=random_state)
    # detach/cpu so GPU-resident latents also work.
    emb = reducer.fit_transform(lat.detach().cpu().numpy())

    generate_legend_image(coarse_names)

    fine_np = fine_labels.cpu().numpy()
    coarse_np = coarse_labels.cpu().numpy()
    unique_fine = np.unique(fine_np)
    centroids = np.vstack([emb[fine_np == i].mean(axis=0) for i in unique_fine])
    coarse_of_fine = np.array([np.unique(coarse_np[fine_np == i])[0]
                               for i in unique_fine], dtype=int)

    # Fine plot: centroids only, one palette colour per fine label.
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(
        centroids[:, 0], centroids[:, 1],
        c=_default_palette(len(unique_fine)),
        s=50, alpha=0.9, linewidths=0
    )
    _save_umap_figure(fig, ax, emb,
                      outdir / f"{dataset}_{model}_{combo}_fine.png", frame)

    # Coarse plot: faint samples underneath, outlined centroids on top.
    palette_coarse = _default_palette(len(coarse_names))
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(
        emb[:, 0], emb[:, 1],
        c=[palette_coarse[cid] for cid in coarse_np],
        s=6, alpha=0.2, linewidths=0
    )
    ax.scatter(
        centroids[:, 0], centroids[:, 1],
        c=[palette_coarse[cid] for cid in coarse_of_fine],
        s=50, alpha=1, linewidths=1, edgecolors='black'
    )
    _save_umap_figure(fig, ax, emb,
                      outdir / f"{dataset}_{model}_{combo}_coarse.png", frame)


def _save_umap_figure(fig, ax, emb, fname: Path, frame: bool) -> None:
    """Strip ticks, fit limits to *emb*, apply the frame choice, save at 300 dpi, close."""
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlim(emb[:, 0].min(), emb[:, 0].max())
    ax.set_ylim(emb[:, 1].min(), emb[:, 1].max())
    if frame:
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_color('black')
            spine.set_linewidth(0.8)
    else:
        # axis('off') also suppresses the spines, so no frame is drawn.
        ax.axis('off')
    fig.savefig(fname, dpi=300, bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    print("saved", fname)


def _class_hierarchy(dataset: str, fine: List[str], device: torch.device):
    """Load the fine->coarse mapping JSON for *dataset*.

    Returns ``(fine, coarse, lut)``:
    * ``fine``: the fine names, prefix-stripped for cub/imagenet;
    * ``coarse``: the sorted coarse names that occur;
    * ``lut``: a long tensor on *device* where ``lut[i]`` is the index in
      ``coarse`` of fine class ``i``.

    Raises:
        RuntimeError: If no mapping JSON is configured for *dataset*.
    """
    ds = dataset.lower()
    json_files = {
        "cifar": "fine_to_coarse_cifar.json",
        "cub": "fine_to_coarse_cub.json",
        "imagenet": "fine_to_coarse_imagenet.json",
    }
    if ds not in json_files:
        raise RuntimeError("Add mapping JSON for this dataset")
    if ds in ("cub", "imagenet"):
        fine = [strip_prefix(c) for c in fine]

    with open(json_files[ds]) as f:
        f2c = json.load(f)

    coarse = sorted({f2c[c] for c in fine})
    coarse_id = {name: i for i, name in enumerate(coarse)}  # O(1) lookups
    lut = torch.tensor([coarse_id[f2c[c]] for c in fine],
                       dtype=torch.long, device=device)
    return fine, coarse, lut


def _predict(net: nn.Module, loader, device: torch.device):
    """Run *net* over *loader*.

    Returns ``(accuracy, latents, predictions)``, with the two tensors on CPU.

    Raises:
        RuntimeError: If the loader yields no samples.
    """
    correct = total = 0
    lat, pr = [], []
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            logits, z = net(imgs)
            p = logits.argmax(1)
            correct += (p == labels).sum().item()
            total += labels.size(0)
            lat.append(z.cpu())
            pr.append(p.cpu())
    if total == 0:
        raise RuntimeError("evaluation loader yielded no samples")
    return correct / total, torch.cat(lat), torch.cat(pr)


def evaluate_pair(root: Path, dataset: str, model: str,
                  device: torch.device, data_root: Path) -> pd.DataFrame:
    """Evaluate every loss-combination checkpoint under ``root / f"{dataset}_{model}"``.

    Each sub-directory of the pair directory is one combination: its name is
    a comma-separated list of loss names, and it holds ``*.pth`` checkpoints.
    The first checkpoint in sorted order is evaluated. For each combination
    this records:
    * test accuracy,
    * the nearest-neighbour coarse hit rate,
    * the silhouette score of the latents grouped by predicted class,
    * distance correlations between latent centroids and sentence embeddings
      of the class names.
    It also saves UMAP plots into the pair directory.

    Returns:
        One row per evaluated combination, with columns ``combo``,
        ``accuracy``, ``nn_hit_rate``, ``silhouette``, ``spearman`` and
        ``pearson``, sorted by accuracy (descending). The frame is empty, but
        still has those columns, when no checkpoint was found.

    Raises:
        FileNotFoundError: If the pair directory does not exist.
        RuntimeError: If no mapping JSON is configured for *dataset*, or if
            the loader is empty.
    """
    pair_dir = root / f"{dataset}_{model}"
    if not pair_dir.is_dir():
        raise FileNotFoundError(pair_dir)

    loader, fine = make_loader(dataset, data_root)
    fine, coarse, lut = _class_hierarchy(dataset, fine, device)

    sent = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    fine_emb = torch.tensor(sent.encode(fine, normalize_embeddings=True),
                            device=device)

    rows = []
    for combo_dir in sorted(p for p in pair_dir.iterdir() if p.is_dir()):
        # Sort so the chosen checkpoint doesn't depend on iterdir order.
        ckpts = sorted(p for p in combo_dir.iterdir() if CKPT_REGEX.match(p.name))
        if not ckpts:
            print(f"⚠  no checkpoint in {combo_dir.name}")
            continue

        net = build_model(model, combo_dir.name.split(","), len(fine),
                          ckpts[0], device)
        acc, lat, pr = _predict(net, loader, device)
        sil = sil_score(lat, pr)
        hit = nn_hit_rate(lat, pr, lut)

        coarse_pred = lut[pr.to(lut.device)]  # predicted fine id -> coarse id
        plot_umap(lat,
                  pr,
                  coarse_pred,
                  combo_dir.name.replace(",", "_"),
                  dataset,
                  model,
                  coarse_names=coarse,
                  outdir=pair_dir)

        # One latent centroid per predicted class. The sorted class ids keep
        # each centroid aligned with that class's sentence embedding.
        keys = torch.unique(pr)
        lat_cent = torch.stack([lat[pr == k].mean(0) for k in keys])
        sem_cent = fine_emb[keys.to(device)]

        rows.append({
            "combo": combo_dir.name,
            "accuracy": acc,
            "nn_hit_rate": hit,
            "silhouette": sil,
            **corr_metrics(lat_cent, sem_cent)
        })

    columns = ["combo", "accuracy", "nn_hit_rate", "silhouette",
               "spearman", "pearson"]
    if not rows:
        # sort_values("accuracy") would raise KeyError on a column-less frame.
        return pd.DataFrame(columns=columns)
    return (pd.DataFrame(rows)
              .sort_values("accuracy", ascending=False)
              .reset_index(drop=True))

def cli(argv=None):
    """Parse command-line options.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with these fields:
        * ``root`` (Path);
        * ``dataset``;
        * ``model``;
        * ``device``: defaults to the first available of "cuda", "mps",
          "cpu";
        * ``data`` (Path, default ``./data``).
    """
    # getattr guard: older torch builds have no torch.backends.mps.
    mps = getattr(torch.backends, "mps", None)
    default_device = ("cuda" if torch.cuda.is_available()
                      else "mps" if mps is not None and mps.is_available()
                      else "cpu")

    p = argparse.ArgumentParser()
    p.add_argument("--root",    required=True,  type=Path)
    p.add_argument("--dataset", required=True,  help="cifar | cub | imagenet")
    p.add_argument("--model",   required=True,  help="resnet50 | convnext | vit")
    p.add_argument("--device",  default=default_device)
    p.add_argument("--data",    default="./data", type=Path)
    return p.parse_args(argv)


if __name__ == "__main__":
    # Parse once; `args` stays a module-level global that other functions
    # in this module may read.
    args = cli()
    report = evaluate_pair(args.root, args.dataset, args.model,
                           torch.device(args.device), args.data)

    shown = ["combo", "accuracy", "nn_hit_rate", "silhouette",
             "spearman", "pearson"]
    print(f"\n=== {args.dataset.upper()} | {args.model} ===")
    print(report[shown].to_string(index=False, float_format="%.4f"))