import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import open_clip
import json
import pytorch_lightning as pl
from torchmetrics.classification import BinaryAUROC, CalibrationError
from torchmetrics import Accuracy
import sys, os, argparse
from pathlib import Path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
BASE_DIR = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(BASE_DIR))

from models.utils import get_prompts, get_text_logits
from torch_uncertainty.datamodules.classification import ImageNetDataModule, CIFAR100DataModule

# (architecture name, pretrained tag) pair passed to open_clip when creating
# the model, its preprocessing transform and its tokenizer. Also used to
# derive the per-model output sub-directory name.
MODEL = (
    "ViT-B-32", 
    "laion2b_s34b_b79k"
    )

def get_datamodule(root, dataset, batch_size=512, num_workers=4):
    """Build the Torch Uncertainty datamodule(s) for the requested dataset.

    Every datamodule uses the open_clip preprocessing transform of ``MODEL``
    as its test transform and is created with ``eval_ood=True``.

    Args:
        root: Sequence of data root directories. ``root[0]`` is used for the
            single-dataset cases; the combined case uses ``root[0]`` for
            ImageNet-1k and ``root[1]`` for CIFAR100.
        dataset: ``"imagenet-1k"``, ``"cifar100"``, or any other value, which
            selects the combined ImageNet-1k + CIFAR100 setup.
        batch_size: Batch size handed to the datamodules (default 512, the
            value that was previously hard-coded).
        num_workers: Worker count handed to the datamodules (default 4).

    Returns:
        Dict mapping dataset name (``"imagenet-1k"`` and/or ``"cifar100"``)
        to its datamodule.

    Raises:
        ValueError: If the combined setup is requested with fewer than two
            data roots.
    """
    # Only the preprocessing transform is needed here; the model itself is
    # discarded.
    _, _, preprocess = open_clip.create_model_and_transforms(
        MODEL[0], pretrained=MODEL[1]
    )

    def build(datamodule_cls, data_root):
        # Single place for the constructor arguments shared by all cases.
        return datamodule_cls(
            root=data_root,
            batch_size=batch_size,
            test_transform=preprocess,
            num_workers=num_workers,
            eval_ood=True,
        )

    if dataset == "imagenet-1k":
        return {"imagenet-1k": build(ImageNetDataModule, root[0])}
    if dataset == "cifar100":
        return {"cifar100": build(CIFAR100DataModule, root[0])}

    # Any other dataset name means "both datasets", which needs one root each.
    if len(root) < 2:
        raise ValueError(
            "The combined ImageNet-1k + CIFAR100 setup needs two data roots "
            f"(ImageNet-1k first, CIFAR100 second), got {len(root)}: {root!r}"
        )
    return {
        "imagenet-1k": build(ImageNetDataModule, root[0]),
        "cifar100": build(CIFAR100DataModule, root[1]),
    }

class EvalModule(pl.LightningModule):
    """Lightning module that evaluates a CLIP-style model on ID/OOD loaders.

    For every dataset name it tracks three metrics:

    * AUROC of the max-softmax score for separating in-distribution samples
      (label 1) from out-of-distribution samples (label 0),
    * multiclass accuracy on the in-distribution loader,
    * multiclass calibration error on the in-distribution loader.

    Args:
        model: Model exposing ``encode_image(x, normalize=True)`` and a
            ``logit_scale`` parameter.
        text_logits: Dict mapping dataset name to the text-side matrix that
            image features are right-multiplied with; its second dimension is
            used as the number of classes.
        val_id_loader: Dict mapping dataset name to the in-distribution loader.
        val_ood_loader: Dict mapping dataset name to the OOD loader.
    """

    def __init__(
            self, 
            model, 
            text_logits: dict, 
            val_id_loader: dict, 
            val_ood_loader: dict
            ):
        super().__init__()
        self.model = model
        self.dataset_names = list(text_logits.keys())
        self.text_logits = text_logits
        self.id_loaders = val_id_loader
        self.ood_loaders = val_ood_loader
        # Built eagerly so test_step also works when test_dataloader() has
        # not been invoked on this instance.
        self.loader_to_dataset = self._build_loader_index()

        # One metric instance of each kind per dataset.
        self.aurocs = nn.ModuleDict()
        self.accs   = nn.ModuleDict()
        self.eces   = nn.ModuleDict()
        for name in self.dataset_names:
            num_classes = self.text_logits[name].shape[1]
            self.aurocs[name] = BinaryAUROC(dist_sync_on_step=True)
            self.accs[name]   = Accuracy(task="multiclass", num_classes=num_classes, dist_sync_on_step=True)
            self.eces[name]   = CalibrationError(task="multiclass", num_classes=num_classes, dist_sync_on_step=True)

    def _build_loader_index(self):
        """Map dataloader index to ``(dataset name, is_id)``.

        Mirrors the order produced by ``test_dataloader``: for each dataset,
        the ID loader comes first and the OOD loader right after it.
        """
        index = {}
        for position, name in enumerate(self.dataset_names):
            index[2 * position] = (name, True)
            index[2 * position + 1] = (name, False)
        return index

    def test_dataloader(self):
        """Return the test loaders as ``[id_0, ood_0, id_1, ood_1, ...]``."""
        self.loader_to_dataset = self._build_loader_index()
        loaders = []
        for name in self.dataset_names:
            loaders.append(self.id_loaders[name])
            loaders.append(self.ood_loaders[name])
        return loaders

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Score one batch and update the metrics of the owning dataset."""
        x, y = batch
        name, is_id = self.loader_to_dataset[dataloader_idx]

        # Forward: scaled image features against the dataset's text matrix.
        image_logits = self.model.encode_image(x, normalize=True)
        logits = self.model.logit_scale.exp() * image_logits @ self.text_logits[name]
        probs = F.softmax(logits, dim=1)
        scores, preds = probs.max(dim=1)

        if is_id:
            # Classification metrics only make sense on in-distribution data.
            self.accs[name].update(preds, y)
            self.eces[name].update(probs, y)
            labels_ood = torch.ones_like(scores).int()
        else:
            labels_ood = torch.zeros_like(scores).int()

        # ID samples are the positive class for the AUROC.
        self.aurocs[name].update(scores, labels_ood)

    def on_test_epoch_end(self):
        """Compute, log and reset every per-dataset metric."""
        for name in self.dataset_names:
            auc = self.aurocs[name].compute()
            acc = self.accs[name].compute()
            ece = self.eces[name].compute()

            self.log(f"{name}/auroc", auc, prog_bar=True, sync_dist=True)
            self.log(f"{name}/accuracy", acc, prog_bar=True, sync_dist=True)
            self.log(f"{name}/ece", ece, prog_bar=True, sync_dist=True)

            # Reset so that repeated trainer.test() calls start from scratch.
            self.aurocs[name].reset()
            self.accs[name].reset()
            self.eces[name].reset()

    def configure_optimizers(self):
        """No optimizer: the module is used for evaluation only."""
        return None  

    def update_text_logits(self, new_text_logits: dict):
        """Replace the text matrices used by ``test_step``.

        Args:
            new_text_logits: Dict with exactly the same dataset names as the
                one given at construction time.

        Raises:
            ValueError: If the dataset names differ from the current ones.
        """
        dataset_names = list(new_text_logits.keys())
        # Explicit raise instead of assert so the check survives `python -O`.
        if set(dataset_names) != set(self.dataset_names):
            raise ValueError(
                f"Dataset names must match. Instead got: {dataset_names} but was expecting {self.dataset_names}"
            )
        self.text_logits = new_text_logits

def parse_args():
    """Parse the command-line arguments of the pruning script.

    Returns:
        argparse.Namespace with the fields ``device``, ``num_workers``,
        ``batch_size``, ``data_root`` (list of str), ``dataset``, ``out_dir``,
        ``metric``, ``circuit_path``, ``budget`` and ``prune_text_encoder``.
        ``circuit_path`` keeps the literal string ``"None"`` as its "not set"
        sentinel.
    """
    parser = argparse.ArgumentParser(description="OOD circuit pruning of CLIP attention heads")
    parser.add_argument("--device", type=int, required=True, help="Device to run the model on")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading")
    parser.add_argument("--batch_size", type=int, default=512, help="Batch size for data loading")
    parser.add_argument("--data_root", type=str, nargs="+", default=["data/in1k_torch_uncertainty"],
                        help="Root directory of the dataset; for in1k_cifar100 pass two roots (IN1K first, CIFAR100 second)")
    # Restricting the choices rejects unsupported names up front instead of
    # failing later when the output path is derived from the dataset.
    parser.add_argument("--dataset", type=str, default="imagenet-1k",
                        choices=["imagenet-1k", "cifar100", "in1k_cifar100"],
                        help="Dataset to use for pruning (default: imagenet-1k)")
    parser.add_argument("--out_dir", type=str, default="circuit", help="Output directory for results")
    parser.add_argument("--metric", type=str, required=True, choices=["auroc", "accuracy", "mix"], help="Metric to pick the best head to prune")
    parser.add_argument("--circuit_path", type=str, default="None", help="Path to already existing circuit pruning to continue extraction")
    parser.add_argument("--budget", type=int, default=40, help="Number of heads to prune")
    parser.add_argument("--prune_text_encoder", action="store_true", help="Whether to prune the text encoder as well")
    return parser.parse_args()

def _attention_module(clip_vit, layer, num_vision_layers):
    """Return the attention module for a global layer index (vision layers first, then text layers)."""
    if layer < num_vision_layers:
        return clip_vit.visual.transformer.resblocks[layer].attn
    return clip_vit.transformer.resblocks[layer - num_vision_layers].attn


def _head_columns(mha, head):
    """Return the slice of ``out_proj.weight`` columns that belongs to one head."""
    head_dim = mha.embed_dim // mha.num_heads
    return slice(head * head_dim, (head + 1) * head_dim)


def _zero_head(mha, head):
    """Zero the output-projection columns of one head in place."""
    with torch.no_grad():
        mha.out_proj.weight[:, _head_columns(mha, head)] = 0.0


def _compute_text_logits(dataset_names, clip_vit, tokenizer, torch_device, verbose=False):
    """Recompute the text matrix of every dataset with the model's current weights."""
    # The model is moved back explicitly before every recomputation, as the
    # original code did. NOTE(review): presumably the trainer can leave the
    # model on another device after a test run - confirm.
    clip_vit.to(torch_device)
    text_logits = dict()
    for name in dataset_names:
        if verbose:
            print(f"Updating text logits for {name}...")
        prompts = get_prompts(dataset=name)
        text_logits[name] = get_text_logits(prompts, clip_vit, tokenizer, device=torch_device)
    return text_logits


def _rewrap_loader(source, batch_size, num_workers):
    """Build a fresh DataLoader over the dataset and sampler of an existing loader."""
    return DataLoader(
        source.dataset,
        batch_size=batch_size,
        sampler=source.sampler,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=False,
    )


def _metric(results, dataset, key):
    """Return one metric: the dataset's own value, or the IN1K/CIFAR100 mean in the combined setup."""
    if dataset != "in1k_cifar100":
        return results[f"{dataset}/{key}"]
    return (results[f"imagenet-1k/{key}"] + results[f"cifar100/{key}"]) / 2


def _metric_fields(results, dataset):
    """Flatten test results into the ``{name}_{metric}`` (and ``avg_{metric}``) history fields."""
    combined = dataset == "in1k_cifar100"
    names = ["imagenet-1k", "cifar100"] if combined else [dataset]
    fields = {}
    for name in names:
        for key in ("auroc", "accuracy", "ece"):
            fields[f"{name}_{key}"] = float(results[f"{name}/{key}"])
    if combined:
        for key in ("auroc", "accuracy", "ece"):
            fields[f"avg_{key}"] = float(_metric(results, dataset, key))
    return fields


def main():
    """Greedily prune attention heads of the CLIP model named by ``MODEL``.

    Workflow:
    1. Build the datamodule(s) for the chosen dataset and wrap their
       validation (ID) and OOD test datasets in fresh DataLoaders.
    2. Load the pretrained open_clip model and enumerate the prunable heads
       as ``(layer, head)`` pairs. Text-encoder layers, when enabled, are
       numbered after the vision layers.
    3. Optionally resume from an existing pruning history: the stored heads
       are zeroed again and the budget is reduced by the steps already done.
    4. Compute the text matrices with the (possibly already pruned) model and
       evaluate the baseline (AUROC, accuracy, ECE).
    5. For every pruning step, temporarily zero each remaining head, evaluate,
       and permanently prune the head whose removal gives the highest score
       (AUROC, accuracy, or mix = mean of both).
    6. After each step, re-evaluate and rewrite the JSON history file.

    The history is written to
    ``{out_dir}/{model_name}/{dataset_path}[_text]_vision_pruned/{file_name}``
    where ``file_name`` is ``metric_{metric}.json`` (or
    ``metric_avg_auroc_accuracy.json`` for ``mix``), with a ``_1`` suffix
    when resuming from an existing circuit.

    Raises:
        ValueError: If the dataset name is not one of the supported ones.
    """
    args = parse_args()
    batch_size = args.batch_size
    num_workers = args.num_workers
    data_root = args.data_root
    dataset = args.dataset
    device = args.device
    metric = args.metric
    budget = args.budget
    prune_text_encoder = args.prune_text_encoder
    circuit_path = args.circuit_path
    # The CLI uses the literal string "None" as the "no circuit" sentinel.
    resuming = circuit_path != "None"
    combined = dataset == "in1k_cifar100"

    dataset_dirs = {
        "imagenet-1k": "in1k",
        "cifar100": "cifar100",
        "in1k_cifar100": "in1k_cifar100",
    }
    if dataset not in dataset_dirs:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(dataset_dirs)}"
        )
    dataset_path = dataset_dirs[dataset]

    model_name = f"{MODEL[0].lower().replace('-', '_')}_{MODEL[1].lower()}"
    out_dir = os.path.join(args.out_dir, f"{model_name}/{dataset_path}")
    if prune_text_encoder:
        out_dir += "_text"
    out_dir += "_vision_pruned"

    os.makedirs(out_dir, exist_ok=True)
    file_name = f"metric_{metric}.json" if metric in ["auroc", "accuracy"] else "metric_avg_auroc_accuracy.json"

    print("=" * 40)
    print(f"Using device: {device}")
    print(f"Batch size: {batch_size}") 
    print(f"Number of workers: {num_workers}")
    print(f"Dataset: {dataset}")
    print(f"Data root: {data_root}")
    print(f"Output directory: {out_dir}")
    print(f"Metric for pruning: {metric}")
    print(f"Pruning budget: {budget} heads")
    print(f"Prune text encoder: {prune_text_encoder}")
    if resuming:
        # Write to a separate file so the resumed-from history is untouched.
        file_name = file_name.replace(".json", "_1.json")
        print(f"Using existing circuit pruning from: {circuit_path}")
    print("=" * 40)

    out_path = os.path.join(out_dir, file_name)
    dm_dicts = get_datamodule(data_root, dataset)

    val_ood_loaders = dict()
    val_id_loaders = dict()
    for name, dm in dm_dicts.items():
        dm.setup("fit")
        val_id_loaders[name] = _rewrap_loader(dm.val_dataloader(), batch_size, num_workers)

        dm.setup("test")
        # Index 2 of the test loaders is used as the OOD loader, as in the
        # original code - assumes the datamodule's loader ordering; TODO confirm.
        val_ood_loaders[name] = _rewrap_loader(dm.test_dataloader()[2], batch_size, num_workers)

    clip_vit, _, _ = open_clip.create_model_and_transforms(MODEL[0], 
                                                           pretrained=MODEL[1])
    torch_device = torch.device(f"cuda:{device}" if torch.cuda.is_available() else "cpu")
    clip_vit.to(torch_device)
    clip_vit.eval()
    tokenizer = open_clip.get_tokenizer(MODEL[0])

    num_layers_vision_encoder = len(clip_vit.visual.transformer.resblocks)
    heads_per_layer = clip_vit.visual.transformer.resblocks[0].attn.num_heads
    remaining_heads = [(l, h) for l in range(num_layers_vision_encoder) for h in range(heads_per_layer)]
    if prune_text_encoder:
        heads_per_layer_text = clip_vit.transformer.resblocks[0].attn.num_heads
        num_layers_text_encoder = len(clip_vit.transformer.resblocks)
        # Treat text encoder layers as continuing after visual encoder layers.
        remaining_heads += [(l + num_layers_vision_encoder, h)
                            for l in range(num_layers_text_encoder) for h in range(heads_per_layer_text)]
    pruned_heads = []

    if resuming:
        with open(circuit_path, "r") as f:
            history = json.load(f)
        removed_heads = history[-1]["step"]
        budget -= removed_heads
        for pruning_step in history:
            if "pruned" in pruning_step:
                head = tuple(pruning_step["pruned"])
                remaining_heads.remove(head)
                pruned_heads.append(head)
        initial_step = len(pruned_heads)
        for l, h in pruned_heads:
            _zero_head(_attention_module(clip_vit, l, num_layers_vision_encoder), h)
        print(f"\n>>> Starting from existing circuit pruning with {removed_heads} heads removed")   
        print(f"  Remaining heads: {len(remaining_heads)}")
    else:
        initial_step = 0

    # Text matrices are computed only AFTER a resumed circuit has been
    # applied, so text heads pruned in a previous run are reflected in them.
    text_logits = _compute_text_logits(val_id_loaders, clip_vit, tokenizer, torch_device)

    evaluator = EvalModule(clip_vit, text_logits, val_id_loaders, val_ood_loaders)
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=[device],
        strategy="ddp",
        logger=False,
    )

    print("\n>>> Baseline evaluation")
    base = trainer.test(evaluator, verbose=False)[0]
    baseline_auc = _metric(base, dataset, "auroc")
    baseline_acc = _metric(base, dataset, "accuracy")
    baseline_ece = _metric(base, dataset, "ece")
    if not combined:
        print(f"  AUROC = {baseline_auc:.4f}")
        print(f"  Accuracy = {baseline_acc:.4%}")
        print(f"  ECE = {baseline_ece:.4f}")
    else:
        print("=" * 40)
        print(f"  IN1K AUROC = {base['imagenet-1k/auroc']:.4f}")
        print(f"  IN1K Accuracy = {base['imagenet-1k/accuracy']:.4%}")
        print(f"  IN1K ECE = {base['imagenet-1k/ece']:.4f}")
        print("=" * 40)
        print(f"  CIFAR100 AUROC = {base['cifar100/auroc']:.4f}")
        print(f"  CIFAR100 Accuracy = {base['cifar100/accuracy']:.4%}")
        print(f"  CIFAR100 ECE = {base['cifar100/ece']:.4f}")
        print("=" * 40)
        print(f"  Average AUROC = {baseline_auc:.4f}")
        print(f"  Average Accuracy = {baseline_acc:.4%}")
        print(f"  Average ECE = {baseline_ece:.4f}")

    if not resuming:
        # A resumed run keeps the loaded history (including its own step 0).
        history = [{
            "step": 0,
            "remaining": len(remaining_heads),
            **_metric_fields(base, dataset),
        }]

    for step in range(1, budget+1):
        if not remaining_heads:
            print("\nNo heads left to prune; stopping early.")
            break

        print(f"\n=== Pruning step {step}/{budget} ===")
        cur = trainer.test(evaluator, verbose=False)[0]
        cur_auc = _metric(cur, dataset, "auroc")
        cur_acc = _metric(cur, dataset, "accuracy")
        cur_ece = _metric(cur, dataset, "ece")
        if not combined:
            print(f" Current: AUROC={cur_auc:.4f}, "
                f"Acc={cur_acc:.4%}, ECE={cur_ece:.4f}")
        else:
            print(f"Avg AUROC={cur_auc:.4f}, "
                f"Avg Acc={cur_acc:.4%}, "
                f"Avg ECE={cur_ece:.4f}")

        best_score = -1.0
        least_important_head = None
        best_metrics = None

        # remaining_heads lists vision heads before text heads, so the text
        # matrices left behind by a text candidate never leak into a vision
        # candidate of the same step; they are refreshed after the step.
        for (l, h) in remaining_heads:
            mha = _attention_module(clip_vit, l, num_layers_vision_encoder)
            cols = _head_columns(mha, h)

            # Keep a copy so the candidate ablation can be undone.
            orig = mha.out_proj.weight[:, cols].detach().clone()
            _zero_head(mha, h)
            if l >= num_layers_vision_encoder:
                # Ablating a text head changes the text embeddings.
                evaluator.update_text_logits(
                    _compute_text_logits(val_id_loaders, clip_vit, tokenizer, torch_device, verbose=True)
                )

            res = trainer.test(evaluator, verbose=False)[0]

            auc = _metric(res, dataset, "auroc")
            acc = _metric(res, dataset, "accuracy")
            if metric == "auroc":
                score = auc
            elif metric == "accuracy":
                score = acc
            else:  # "mix"
                score = 0.5 * (auc + acc)
            print(f"  L{l}H{h}: AUROC={auc:.4f}, Acc={acc:.4%}, Score={score:.4f}")

            with torch.no_grad():
                mha.out_proj.weight[:, cols] = orig

            if score > best_score:
                best_score = score
                least_important_head = (l, h)
                best_metrics = (auc, acc)

        if least_important_head is None:
            # Happens only if no candidate beat the initial score (e.g. NaN scores).
            print("\nNo candidate head produced a valid score; stopping early.")
            break

        l, h = least_important_head
        _zero_head(_attention_module(clip_vit, l, num_layers_vision_encoder), h)

        pruned_heads.append(least_important_head)
        remaining_heads.remove(least_important_head)
        print(f"→ Pruned head L{l}H{h} (Score={best_score:.4f})")

        # Ensure the evaluator has text matrices matching the pruned model.
        evaluator.update_text_logits(
            _compute_text_logits(val_id_loaders, clip_vit, tokenizer, torch_device, verbose=True)
        )

        new = trainer.test(evaluator, verbose=False)[0]
        # best_metrics is (auroc, accuracy); label the values accordingly.
        if not combined:
            best_metrics_fields = {
                f"{dataset}_auroc": float(best_metrics[0]),
                f"{dataset}_accuracy" : float(best_metrics[1])
            }
        else:
            best_metrics_fields = {
                "avg_accuracy": float(best_metrics[1]),
                "avg_auroc" : float(best_metrics[0])
            }
        history.append({
            "step": step + initial_step,
            "pruned": [l, h],
            "remaining": len(remaining_heads),
            **_metric_fields(new, dataset),
            "best_score": float(best_score),
            "least_important_head": [l, h],
            "best_metrics": best_metrics_fields,
        })

        # Rewrite the full history after every step so progress survives interruptions.
        with open(out_path, "w") as f:
            json.dump(history, f, indent=2)
        print(f"\nSaved {out_path}")

    print("\n=== Pruning complete ===")
    print("Pruned heads:")
    for l, h in pruned_heads:
        print(f"  Layer {l}, Head {h}")
    print("\nRemaining heads:")
    for l, h in remaining_heads:
        print(f"  Layer {l}, Head {h}")

# Run the pruning workflow only when executed as a script, not on import.
if __name__ == "__main__":
    main()