import torch
import torch.nn as nn
import open_clip

import yaml 
import sys
import warnings
from pathlib import Path
import copy
import os
import types
import argparse
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
BASE_DIR = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(BASE_DIR))

from torch_uncertainty import TUTrainer
from torch_uncertainty.routines import ClassificationRoutine
from torch_uncertainty.post_processing import TemperatureScaler
from torch_uncertainty.models import deep_ensembles
from models.utils import get_prompts, get_text_logits
from torch_uncertainty.datamodules.classification import CIFAR100DataModule, ImageNetDataModule
from utils.datamodules import \
    (CIFAR10DataModule, Food101DataModule, EuroSATDataModule, Sun397DataModule, 
     OxfordIIITPetDataModule, DTDDataModule, Caltech101DataModule)

from bayesvlm.hessians import load_hessians
from bayesvlm.utils import get_covariances, get_bayes_vlm_model
from models.block import MultiHeadAttentionPruned

# (architecture, pretrained tag) pairs handed to open_clip; the script only
# ever reads the first entry (CLIP_MODEL[0]).
CLIP_MODEL = [("ViT-B-32", "laion2b_s34b_b79k")]

def _swap_in_pruned_attention(blocks, device):
    """Replace ``block.attn`` of every block in ``blocks`` with a MultiHeadAttentionPruned copy.

    All heads are kept active. The head count and embedding width are read from
    the first block, so every block in ``blocks`` is assumed to share the same
    attention configuration. An empty ``blocks`` is a no-op.
    """
    if not blocks:
        return
    first_attn = blocks[0].attn
    head_indices = torch.arange(first_attn.num_heads, device=device)
    embed_dim = first_attn.embed_dim
    for block in blocks:
        attn = block.attn
        # Reuse the existing projection weights so the swap does not change outputs
        # (assuming MultiHeadAttentionPruned with all heads active is equivalent — TODO confirm).
        block.attn = MultiHeadAttentionPruned(
            head_indices,
            head_indices,
            embed_dim,
            attn.in_proj_weight.data,
            attn.out_proj.weight.data,
            attn.in_proj_bias.data,
            attn.out_proj.bias.data,
        )


def replace_mha_with_mha_pruned(clip_model):
    """Return a deep copy of ``clip_model`` whose attention layers are MultiHeadAttentionPruned.

    Both the vision tower (``clip_model.visual.transformer.resblocks``) and the
    text tower (``clip_model.transformer.resblocks``) are converted. Each tower
    gets its own head indices and embedding width. The input model is left
    untouched, and the copy is placed on the device of the input model's
    first parameter.
    """
    device = next(clip_model.parameters()).device

    clip_model_pruned = copy.deepcopy(clip_model).to(device)
    _swap_in_pruned_attention(clip_model_pruned.visual.transformer.resblocks, device)
    _swap_in_pruned_attention(clip_model_pruned.transformer.resblocks, device)

    return clip_model_pruned

def _sync_attention_heads(blocks, active_heads):
    """Recompute the head bookkeeping of every pruned attention layer in ``blocks``.

    After a (possibly pruned) state dict has been loaded, ``num_heads`` and
    ``out_proj.in_features`` must match the loaded ``in_proj_weight``.

    Args:
        blocks: residual blocks whose ``attn`` is a MultiHeadAttentionPruned.
        active_heads: sequence indexed by block position giving each layer's
            active heads, or a falsy value to mark every remaining head active.
    """
    if not blocks:
        return
    # Read head_dim from this tower's own layers instead of assuming it equals
    # the other tower's value.
    head_dim = blocks[0].attn.head_dim
    for i, block in enumerate(blocks):
        mha = block.attn
        weights = mha.in_proj_weight.data
        # in_proj_weight stacks the q, k and v projections, hence the division by 3.
        num_heads = weights.shape[0] // 3 // head_dim
        mha.num_heads = num_heads
        mha.out_proj.in_features = num_heads * head_dim
        mha.active_heads = active_heads[i] if active_heads else torch.arange(num_heads, device=weights.device)


def load_clip_checkpoint(
        model,
        ckpt_path,
        ):
    """Return a pruned copy of ``model`` carrying the weights stored at ``ckpt_path``.

    ``model`` itself is not modified. It is first copied with its attention
    layers swapped for MultiHeadAttentionPruned (see replace_mha_with_mha_pruned),
    because pruned checkpoints store qkv weights whose first dimension no longer
    matches the stock attention layers. If ``ckpt_path`` does not exist, that
    converted copy is returned with the original weights.

    The checkpoint may be a raw state dict or a dict with a ``state_dict`` entry.
    Head masks are read from ``active_heads`` or ``pruning['heads']`` when
    present; an ``int`` value there is treated as "no mask".
    """
    # prepare model to load state_dict
    # qkv_weight first dim is expected to be 3x embed dim
    # thus the following line simply replaces the self-attention block to avoid the error
    model = replace_mha_with_mha_pruned(model)

    if not os.path.exists(ckpt_path):
        return model

    # NOTE: weights_only=False unpickles arbitrary objects — only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    state = ckpt.get('state_dict', ckpt)
    if "active_heads" in ckpt:
        active_heads = ckpt['active_heads']
    elif "pruning" in ckpt:
        active_heads = ckpt["pruning"].get("heads", None)
    else:
        active_heads = None

    if isinstance(active_heads, int):
        active_heads = None

    # Strip wrapper prefixes (e.g. a Lightning routine's ``model.`` or DDP's ``module.``).
    state = {k.replace('model.', ''): v for k, v in state.items()}
    state = {k.replace('module.', ''): v for k, v in state.items()}
    state = {
        (k if k.startswith("heads.head") else k.replace("heads.", "heads.head.")): v
        for k, v in state.items()
    }

    # Pruned tensors are smaller than the freshly built ones: resize the
    # parameters first so the strict load below does not fail on shape mismatches.
    for name, param in model.named_parameters():
        if name in state and param.shape != state[name].shape:
            param.data = torch.zeros_like(state[name], dtype=state[name].dtype, device=param.device)

    model.load_state_dict(state, strict=True)

    # NOTE(review): the same ``active_heads`` sequence is indexed by layer
    # position for both towers — confirm the checkpoint format matches this.
    _sync_attention_heads(model.visual.transformer.resblocks, active_heads)
    _sync_attention_heads(model.transformer.resblocks, active_heads)

    return model

def get_datamodule(
        dataset_name,
        root,
        batch_size,
        num_workers,
        eval_ood,
        temp_scaling_root="data/in1k_torch_uncertainty",
        ):
    """Build the evaluation datamodule for ``dataset_name`` and the temperature-scaling datamodule.

    Args:
        dataset_name: one of the keys of ``num_classes_by_dataset`` below.
        root: dataset root. Forwarded to the cifar100, imagenet-1k, cifar10,
            food101 and caltech101 datamodules; the other datamodules are
            constructed without it.
        batch_size: batch size used by both returned datamodules.
        num_workers: dataloader workers for the torch_uncertainty datamodules.
        eval_ood: forwarded to the torch_uncertainty datamodules (cifar100,
            imagenet-1k and the temperature-scaling ImageNet module).
        temp_scaling_root: root of the ImageNet-1k data used to fit the temperature.

    Returns:
        ``(dm, dm_temp_scaling, num_classes)``.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    num_classes_by_dataset = {
        "cifar100": 100,
        "imagenet-1k": 1000,
        "cifar10": 10,
        "food101": 101,
        "sun397": 397,
        "oxford_pet": 37,
        "dtd": 47,
        "eurosat": 10,
        "caltech101": 101,
    }
    # Validate before instantiating (and possibly downloading) the CLIP model.
    if dataset_name not in num_classes_by_dataset:
        raise ValueError(f"Dataset {dataset_name} is not supported.")
    num_classes = num_classes_by_dataset[dataset_name]

    # Only the evaluation transform is used; the model itself is discarded.
    architecture, pretraining = CLIP_MODEL[0]
    _, _, val_tfms = open_clip.create_model_and_transforms(
        architecture,
        pretrained=pretraining
        )

    # Keyword arguments shared by every torch_uncertainty datamodule built here.
    tu_kwargs = {
        "batch_size": batch_size,
        "test_transform": val_tfms,
        "num_workers": num_workers,
        "eval_ood": eval_ood,
        "pin_memory": True,
        "persistent_workers": False,
    }

    if dataset_name == "cifar100":
        dm = CIFAR100DataModule(root=root, **tu_kwargs)
    elif dataset_name == "imagenet-1k":
        dm = ImageNetDataModule(root=root, **tu_kwargs)
    else:
        # Seeded before building the custom datamodules, presumably so any
        # random splits they draw are reproducible — confirm in utils.datamodules.
        torch.manual_seed(42)
        custom_datamodules = {
            "cifar10": CIFAR10DataModule,
            "food101": Food101DataModule,
            "sun397": Sun397DataModule,
            "oxford_pet": OxfordIIITPetDataModule,
            "dtd": DTDDataModule,
            "eurosat": EuroSATDataModule,
            "caltech101": Caltech101DataModule,
        }
        dm_kwargs = {"transform": val_tfms, "batch_size": batch_size}
        # NOTE(review): sun397, oxford_pet, dtd and eurosat do not receive ``root``;
        # confirm they locate their data on their own.
        if dataset_name in ("cifar10", "food101", "caltech101"):
            dm_kwargs["data_dir"] = root
        dm = custom_datamodules[dataset_name](**dm_kwargs)

    # Use IN-1K val set as in (A. Baumann et al., 2025)
    # https://arxiv.org/pdf/2412.06014
    dm_temp_scaling = ImageNetDataModule(root=temp_scaling_root, **tu_kwargs)

    return dm, dm_temp_scaling, num_classes
    
def clip_forward(self, x):
    """Zero-shot classification forward pass; ``get_model`` binds it onto OpenCLIP models.

    Encodes ``x`` with ``encode_image(..., normalize=True)`` and scores the image
    features against ``self.text_logits``, scaled by ``exp(self.logit_scale)``.
    ``self.text_logits`` must be assigned before calling (presumably an
    (embed_dim, num_classes) matrix of class text embeddings — TODO confirm).
    """
    features = self.encode_image(x, normalize=True)
    scale = self.logit_scale.exp()
    logits = scale * features @ self.text_logits
    return logits
    
def get_model(
        model_name,
        dataset_name,
        device
        ):
    """Instantiate the zero-shot model ``model_name`` configured for ``dataset_name``.

    * ``"clip_vit_b_32"``: an OpenCLIP model (architecture and weights from
      ``CLIP_MODEL[0]``) moved to ``device``. The dataset prompts and tokenizer
      are attached as ``prompts`` and ``tokenizer``, and ``forward`` is replaced
      by ``clip_forward``. ``text_logits`` is not set here; the caller must
      assign it before the first forward pass.
    * ``"bayes_vlm_b_32"``: a BayesVLM model whose covariances are built from the
      precomputed image and text Hessians. ``device`` is not used for this model.

    Raises:
        ValueError: for any other ``model_name``.
    """
    if model_name == "clip_vit_b_32":
        architecture, pretraining = CLIP_MODEL[0]
        model, _, _ = open_clip.create_model_and_transforms(
            architecture,
            pretrained=pretraining
        )
        tokenizer = open_clip.get_tokenizer(architecture)
        model.prompts = get_prompts(dataset=dataset_name)
        model.tokenizer = tokenizer
        model.to(device)
        # Bound on the instance so nn.Module.__call__ dispatches to clip_forward.
        model.forward = types.MethodType(clip_forward, model)
        return model

    if model_name == "bayes_vlm_b_32":
        architecture, _ = CLIP_MODEL[0]
        model = get_bayes_vlm_model(dataset_name)
        A_img, B_img = load_hessians(model_name=architecture, tag='img', return_info=False)
        A_txt, B_txt = load_hessians(model_name=architecture, tag='txt', return_info=False)
        cov_img, cov_txt = get_covariances(model.open_clip_model, A_img, B_img, A_txt, B_txt)
        model.set_covariances(cov_img, cov_txt)
        return model

    # Report the requested name; the YAML config has no 'model' section and the
    # global ``config`` only exists when this file runs as a script.
    raise ValueError(f"Model {model_name} is not supported.")

def parse_args():
    """Parse the command-line options of this evaluation script.

    Returns:
        argparse.Namespace with boolean ``eval_ood`` and ``eval_temp_scaling``
        flags and an integer ``device`` (the CUDA device index, default 0).
    """
    arg_parser = argparse.ArgumentParser(
        description="Testing script for pruned Vision Transformers"
    )
    for flag in ("--eval_ood", "--eval_temp_scaling"):
        arg_parser.add_argument(flag, action="store_true")
    arg_parser.add_argument("--device", type=int, default=0)
    args = arg_parser.parse_args()
    return args

def summary(config, parser):
    """Print a short summary of the run configuration.

    Args:
        config: the loaded YAML config; ``config['dataset']['name']`` is printed as-is.
        parser: the ``argparse.Namespace`` returned by ``parse_args`` (despite
            its name, not the ArgumentParser itself).
    """
    rule = "=" * 40
    lines = [
        "Testing OpenCLIP ViT/B-32 Configuration Summary",
        rule,
        f"Dataset: {config['dataset']['name']}",
        f"Eval OOD: {parser.eval_ood}",
        f"Eval Temperature Scaling: {parser.eval_temp_scaling}",
        f"Device: {parser.device}",
        rule,
    ]
    print("\n".join(lines))

if __name__ == "__main__":

    warnings.filterwarnings('ignore')

    parser = parse_args()
    eval_ood = parser.eval_ood
    eval_temp_scaling = parser.eval_temp_scaling
    device = parser.device
    cuda_device = f"cuda:{device}"

    with open("testing/test.yaml", "r") as file:
        config = yaml.safe_load(file)

    if eval_ood:
        if "imagenet-1k" not in config["dataset"]["name"] and "cifar100" not in config["dataset"]["name"]:
            raise ValueError("When evaluating OOD one of the datasets must be imagenet-1k or cifar100")

    summary(config, parser)

    datasets_name, datasets_root = config["dataset"]["name"], config["dataset"]["root"]
    batch_size = config["dataset"]["batch_size"]
    num_workers = config["dataset"]["num_workers"]

    trainer = TUTrainer(accelerator="gpu", enable_progress_bar=True, devices=[device])

    for method in config["methods"]:
        checkpoints = config["methods"][method]
        if "bayes" in method.lower():
            model_name = "bayes_vlm_b_32"
        elif "clip" in method.lower():
            model_name = "clip_vit_b_32"
        else:
            raise ValueError(f"Couldn't determine model to use. Method name should contain either 'clip' or 'bayes', but got {method.lower()}")
        for dataset_name, dataset_root in zip(datasets_name, datasets_root):
            print("=" * 40)
            print(f"Dataset: {dataset_name}")
            print("=" * 40)

            # get_datamodule only forwards eval_ood to the imagenet-1k / cifar100
            # datamodules, so OOD evaluation is restricted to those datasets. The
            # same per-dataset flag must reach both the datamodule and the routine.
            eval_ood_curr_dataset = eval_ood and dataset_name in ("imagenet-1k", "cifar100")

            dm, dm_temp_scaling, num_classes = get_datamodule(
                dataset_name,
                dataset_root,
                batch_size,
                num_workers,
                eval_ood_curr_dataset
                )
            dm.setup()

            print(f"Testing method: {method}")
            print("Checkpoints:")
            for ckpt in checkpoints:
                print(f" - {ckpt}")
            print("="*40)

            if not eval_temp_scaling:
                model = get_model(model_name, dataset_name, device=cuda_device)

                if model_name != "bayes_vlm_b_32":
                    # load_clip_checkpoint works on a copy, so each checkpoint
                    # yields an independent member built from the same base model.
                    models = [load_clip_checkpoint(model, ckpt) for ckpt in checkpoints]
                    for member in models:
                        member.text_logits = get_text_logits(
                            member.prompts, member, member.tokenizer, device=cuda_device
                        )
                else:
                    models = [model]

                models = [member.eval().to(cuda_device) for member in models]

                # Use the deep-ensemble constructor to evaluate ID and OOD performance.
                # With Hydra Ensembles, forward passes remain independent, so the behavior
                # matches deep ensembles in zero-shot settings or whenever the MLPs stay
                # fixed and do not require merging. This is not true for computational performance.
                model = deep_ensembles(models) if len(models) > 1 else models[0]

                routine = ClassificationRoutine(
                    num_classes=num_classes,
                    model=model,
                    loss=nn.CrossEntropyLoss(),
                    eval_ood=eval_ood_curr_dataset
                )
                trainer.test(routine, datamodule=dm)
            else:
                print("="*40)
                print("Evaluating with Temperature Scaling")

                if model_name == "bayes_vlm_b_32":
                    raise ValueError("Bayes-VLM is not compatible with temperature scaling.")

                # The temperature is fitted on ImageNet-1k, so members start with
                # ImageNet-1k prompts; they are swapped for the target prompts below.
                model = get_model(model_name, dataset_name="imagenet-1k", device=cuda_device)
                models = [load_clip_checkpoint(model, ckpt) for ckpt in checkpoints]
                models = [member.eval().to(cuda_device) for member in models]

                for member in models:
                    member.text_logits = get_text_logits(
                        member.prompts, member, member.tokenizer, device=cuda_device
                    )

                # Use the deep-ensemble constructor to evaluate ID and OOD performance.
                # With Hydra Ensembles, forward passes remain independent, so the behavior
                # matches deep ensembles in zero-shot settings or whenever the MLPs stay
                # fixed and do not require merging. This is not true for computational performance.
                model = deep_ensembles(models) if len(models) > 1 else models[0]

                dm_temp_scaling.setup("fit")
                scaler1 = TemperatureScaler(model=model, device=cuda_device)
                print("Fitting temperature scaling...")
                scaler1.fit(dataloader=dm_temp_scaling.postprocess_dataloader())

                # Replace the class text embeddings with the target dataset's, in place,
                # so the scaler (which wraps `model`) keeps its fitted temperature.
                tokenizer = open_clip.get_tokenizer(CLIP_MODEL[0][0])
                prompts = get_prompts(dataset=dataset_name)
                members = model.core_models if len(models) > 1 else [model]
                for member in members:
                    member.text_logits = get_text_logits(
                        prompts,
                        member,
                        tokenizer,
                        device=cuda_device
                    )

                routine = ClassificationRoutine(
                    num_classes=num_classes,
                    model=scaler1,
                    loss=nn.CrossEntropyLoss(),
                    eval_ood=eval_ood_curr_dataset
                )
                trainer.test(routine, datamodule=dm)