import torch
import torch.nn.functional as F
import json
from torch_uncertainty.datamodules import CIFAR100DataModule
from torchvision.models import vit_b_16
try:
    from torchvision.models import ViT_B_16_Weights
except ImportError:
    ViT_B_16_Weights = None
import pytorch_lightning as pl
from torchmetrics.classification import BinaryAUROC, CalibrationError
from torchmetrics import Accuracy
from torch.utils.data import DataLoader
from torchvision import transforms

# ─── MODEL LOADING ────────────────────────────────────────────────────────────
def load_model(ckpt_path: str, device: str) -> torch.nn.Module:
    """Build a 100-class ViT-B/16 and optionally restore checkpoint weights.

    Args:
        ckpt_path: Path of a checkpoint to load. ``None``, an empty string, or
            any value starting with ``"torchvision://"`` means "no checkpoint":
            the network keeps its fresh initialization.
        device: Device the returned model is moved to.

    Returns:
        The model on ``device``, switched to eval mode.
    """
    skip_ckpt = ckpt_path in (None, "") or str(ckpt_path).startswith("torchvision://")
    if skip_ckpt:
        print("[info] No CIFAR100 checkpoint provided; initializing ViT-B/16 from scratch.")

    net = vit_b_16(weights=None, num_classes=100, image_size=224)
    if skip_ckpt:
        return net.to(device).eval()

    checkpoint = torch.load(ckpt_path, map_location="cpu")
    # Use the nested "state_dict" entry when present, else the object itself.
    raw_state = checkpoint.get("state_dict", checkpoint)

    # Pass 1: drop a leading "model." from every key.
    prefix = "model."
    stripped = {}
    for key, tensor in raw_state.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = tensor

    # Pass 2: map the flat classifier-head names onto the "heads.head.*" names.
    head_aliases = {
        "heads.weight": "heads.head.weight",
        "heads.bias": "heads.head.bias",
    }
    remapped = {head_aliases.get(key, key): tensor for key, tensor in stripped.items()}

    net.load_state_dict(remapped, strict=True)
    return net.to(device).eval()

# ─── LIGHTNING MODULE FOR METRICS ─────────────────────────────────────────────
class EvalModule(pl.LightningModule):
    """Test-only LightningModule that scores a classifier on two loaders.

    Test dataloader 0 is ``id_loader`` and dataloader 1 is ``ood_loader``.
    Accuracy and calibration error are accumulated on dataloader 0 only. The
    AUROC uses the maximum softmax probability as the score, with target 1
    for dataloader 0 samples and target 0 for every other dataloader.
    """

    def __init__(self, model: torch.nn.Module, id_loader, ood_loader):
        super().__init__()
        self.model = model
        self.id_loader = id_loader
        self.ood_loader = ood_loader

        multiclass_cfg = dict(task="multiclass", num_classes=100, dist_sync_on_step=True)
        self.auroc = BinaryAUROC(dist_sync_on_step=True)
        self.acc = Accuracy(**multiclass_cfg)
        self.ece = CalibrationError(**multiclass_cfg)

    def test_dataloader(self):
        """Return the two evaluation loaders, first ``id_loader`` then ``ood_loader``."""
        return [self.id_loader, self.ood_loader]

    def test_step(self, batch, batch_idx, dataloader_idx):
        """Update the metrics with one batch from the given dataloader."""
        inputs, targets = batch
        probabilities = F.softmax(self.model(inputs), dim=1)
        confidence, predicted = probabilities.max(dim=1)

        if dataloader_idx == 0:
            # Classification metrics only use batches from the first loader.
            self.acc.update(predicted, targets)
            self.ece.update(probabilities, targets)
            auroc_target = torch.ones_like(confidence, dtype=torch.int32)
        else:
            auroc_target = torch.zeros_like(confidence, dtype=torch.int32)

        self.auroc.update(confidence, auroc_target)

    def on_test_epoch_end(self):
        """Compute, log and reset all three metrics."""
        tracked = {
            "test_auroc": self.auroc,
            "test_accuracy": self.acc,
            "test_ece": self.ece,
        }
        # Compute everything first, then log, then reset.
        values = {name: metric.compute() for name, metric in tracked.items()}
        for name, value in values.items():
            self.log(name, value, prog_bar=True, sync_dist=True)
        for metric in tracked.values():
            metric.reset()

    def configure_optimizers(self):
        """No optimizer: this module is only used for testing."""
        return None

# ─── PRUNING LOOP ───────────────────────────────────────────────────
def _build_test_transform():
    """Return the image preprocessing used for every evaluation loader.

    Uses the preset attached to ``ViT_B_16_Weights.IMAGENET1K_V1`` when that
    enum is available, otherwise a manual resize / to-tensor / normalize
    pipeline.
    """
    if ViT_B_16_Weights is not None and hasattr(ViT_B_16_Weights, "IMAGENET1K_V1"):
        preset = ViT_B_16_Weights.IMAGENET1K_V1.transforms
        try:
            return preset(antialias=True)
        except TypeError:
            # The preset rejected the ``antialias`` keyword; retry without it.
            return preset()

    norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        norm,
    ])


def _head_columns(model, layer_idx, head_idx):
    """Return ``(out_proj, column_slice)`` addressing one attention head.

    The slice selects the ``out_proj.weight`` columns that are zeroed to
    remove the head. It is recomputed per layer rather than reused from a
    previous loop iteration.
    """
    mha = model.encoder.layers[layer_idx].self_attention
    head_dim = model.hidden_dim // mha.num_heads
    start = head_idx * head_dim
    return mha.out_proj, slice(start, start + head_dim)


def _evaluate(trainer, evaluator):
    """Run one test pass and return the first result dict."""
    return trainer.test(evaluator, verbose=False)[0]


def main(budget: int = 100, output_path: str = "pruning_history_ood_acc_100.json") -> None:
    """Greedily prune attention heads and record the metric history.

    At every step each remaining head is temporarily disabled by zeroing its
    ``out_proj`` columns and the model is re-evaluated. The head whose removal
    gives the highest ``0.5 * (AUROC + accuracy)`` is then zeroed permanently.
    The per-step metrics are written to ``output_path`` as JSON.

    Args:
        budget: Maximum number of heads to prune.
        output_path: Destination of the JSON history file.
    """
    # data
    # NOTE(review): dm.prepare_data() is never called here, so the datasets
    # presumably have to exist under ``root`` already — confirm.
    dm = CIFAR100DataModule(
        root="data",
        test_transform=_build_test_transform(),
        batch_size=1024,
        num_workers=8,
        eval_ood=True,
        pin_memory=True,
        persistent_workers=False,
    )

    # NOTE(review): index 2 of the test loaders is taken as the OOD loader;
    # verify against the installed torch_uncertainty version.
    dm.setup("test")
    ood_loader = dm.test_dataloader()[2]
    dm.setup("fit")
    id_loader = dm.val_dataloader()

    # load_model treats "torchvision://" paths as "no checkpoint", so the
    # network evaluated below is initialized from scratch.
    model = load_model("torchvision://vit_b_16_cifar100", "cpu")
    evaluator = EvalModule(model, id_loader, ood_loader)
    trainer = pl.Trainer(
        accelerator="gpu",
        devices="auto",
        strategy="ddp",
        logger=False,
    )

    # baseline
    print("\n>>> Baseline evaluation")
    current = _evaluate(trainer, evaluator)
    print(f"  AUROC = {current['test_auroc']:.4f}")
    print(f"  Accuracy = {current['test_accuracy']:.4%}")
    print(f"  ECE = {current['test_ece']:.4f}")

    # pruning setup
    num_layers = len(model.encoder.layers)
    heads_per_layer = model.encoder.layers[0].self_attention.num_heads
    remaining_heads = [
        (layer_idx, head_idx)
        for layer_idx in range(num_layers)
        for head_idx in range(heads_per_layer)
    ]
    pruned_heads = []

    history = [{
        "step": 0,
        "remaining": len(remaining_heads),
        "auroc": float(current["test_auroc"]),
        "accuracy": float(current["test_accuracy"]),
        "ece": float(current["test_ece"]),
    }]

    # loop
    for step in range(1, budget + 1):
        print(f"\n=== Pruning step {step}/{budget} ===")
        # The model has not changed since the last evaluation (baseline or the
        # previous step's post-prune run), so report that result again instead
        # of paying for another full test pass.
        print(f" Current: AUROC={current['test_auroc']:.4f}, "
              f"Acc={current['test_accuracy']:.4%}, ECE={current['test_ece']:.4f}")

        best_score = -1.0
        best_head = None
        best_metrics = None

        for layer_idx, head_idx in remaining_heads:
            out_proj, cols = _head_columns(model, layer_idx, head_idx)
            # Detached copy: the backup must not take part in autograd.
            saved = out_proj.weight[:, cols].detach().clone()
            try:
                with torch.no_grad():
                    out_proj.weight[:, cols] = 0.0
                res = _evaluate(trainer, evaluator)
            finally:
                # Always restore, even if the evaluation raised.
                with torch.no_grad():
                    out_proj.weight[:, cols] = saved

            auc = res["test_auroc"]
            acc = res["test_accuracy"]
            score = 0.5 * (auc + acc)
            print(f"  L{layer_idx}H{head_idx}: AUROC={auc:.4f}, Acc={acc:.4%}, Score={score:.4f}")

            if score > best_score:
                best_score = score
                best_head = (layer_idx, head_idx)
                best_metrics = (auc, acc)

        if best_head is None:
            # Happens when no head is left or when every score was NaN.
            print("[warn] No prunable head found; stopping early.")
            break

        # permanently prune best_head
        layer_idx, head_idx = best_head
        out_proj, cols = _head_columns(model, layer_idx, head_idx)
        with torch.no_grad():
            out_proj.weight[:, cols] = 0.0

        pruned_heads.append(best_head)
        remaining_heads.remove(best_head)
        print(f"→ Pruned head L{layer_idx}H{head_idx} (Score={best_score:.4f})")

        # log history
        current = _evaluate(trainer, evaluator)

        history.append({
            "step": step,
            "pruned": [layer_idx, head_idx],
            "remaining": len(remaining_heads),
            "auroc": float(current["test_auroc"]),
            "accuracy": float(current["test_accuracy"]),
            "ece": float(current["test_ece"]),
            "pruned_score": float(best_score),
            "pruned_head": [layer_idx, head_idx],
            "pruned_score_detail": {
                "auroc": float(best_metrics[0]),
                "accuracy": float(best_metrics[1]),
            },
        })

    # final report
    print("\n=== Pruning complete ===")
    print("Pruned heads:")
    for layer_idx, head_idx in pruned_heads:
        print(f"  Layer {layer_idx}, Head {head_idx}")
    print("\nRemaining heads:")
    for layer_idx, head_idx in remaining_heads:
        print(f"  Layer {layer_idx}, Head {head_idx}")

    # Under DDP every rank runs this function; let only the global-zero rank
    # write so the processes do not race on the same file.
    if trainer.is_global_zero:
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(history, f, indent=2)
        print(f"\nSaved {output_path}")

# Run the pruning experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()
