import warnings
import types
from tqdm import tqdm
import sys, os
from pathlib import Path
import argparse
warnings.filterwarnings('ignore')

import torch
import numpy as np
from torch.nn import functional as F
import torch.nn as nn
import open_clip 

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
BASE_DIR = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(BASE_DIR))

from models.utils import get_prompts
from vilu.models import viluAttention
from vilu.metrics import get_fpr95, get_auroc, get_aupr
from torch_uncertainty.metrics.classification import \
    AdaptiveCalibrationError, CalibrationError, AURC, AUGRC, CovAt5Risk, RiskAt80Cov

from utils.datamodules import \
    (CIFAR10DataModule, Food101DataModule, Sun397DataModule, DTDDataModule, 
     OxfordIIITPetDataModule, EuroSATDataModule, Caltech101DataModule)
from torch_uncertainty.datamodules.classification import CIFAR100DataModule, ImageNetDataModule

# open_clip backbone used everywhere in this script: (architecture, pretrained tag).
MODEL = ("ViT-B-32", "laion2b_s34b_b79k")

# Dataset key -> path of the ViLU checkpoint loaded for that dataset.
# Contains an entry for every --dataset_cls choice (and hence every
# --dataset_ood choice as well).
VILU_PATH = {
    "imagenet-1k": "vilu/weights/vilu_imagenet_laion.ckpt",
    "cifar100": "vilu/weights/vilu_cifar100_laion.ckpt",
    "cifar10": "vilu/weights/vilu_cifar10_laion.ckpt",
    "food101": "vilu/weights/vilu_food101_laion.ckpt",
    "sun397": "vilu/weights/vilu_sun397_laion.ckpt",
    "oxford_pet": "vilu/weights/vilu_oxford_pet_laion.ckpt",
    "dtd": "vilu/weights/vilu_dtd_laion.ckpt",
    "eurosat": "vilu/weights/vilu_eurosat_laion.ckpt",
    "caltech101": "vilu/weights/vilu_caltech101_laion.ckpt",
}

def parse_args(argv=None):
    """Parse the command-line options of this evaluation script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for programmatic calls and tests).

    Returns:
        argparse.Namespace with ``dataset_cls`` (list of str or None),
        ``dataset_ood`` (str or None) and ``device`` (str or None).

    Raises:
        SystemExit: on invalid arguments (standard argparse behavior).
    """
    parser = argparse.ArgumentParser(
        description="Evaluate ViLU failure prediction on classification "
                    "datasets and/or OOD detection on CIFAR100 or ImageNet-1k")
    parser.add_argument(
        "--dataset_cls", type=str, nargs="+", required=False, default=None,
        choices=[
            "cifar100", "imagenet-1k", "cifar10",
            "sun397", "food101", "eurosat",
            "dtd", "oxford_pet", "caltech101",
        ],
        help="Datasets on which ViLU is evaluated as a misclassification detector.")
    parser.add_argument(
        "--dataset_ood", type=str, required=False,
        choices=["cifar100", "imagenet-1k"],
        help="In-distribution dataset for the OOD-detection evaluation.")
    parser.add_argument(
        "--device", type=str, default=None,
        help="CUDA device index (e.g. 0). Defaults to 'cuda' if available, else 'cpu'.")
    return parser.parse_args(argv)

def _negative_msp(model, dataloader, device, desc):
    """Return ``-max(softmax(model(images)))`` for every sample as a 1-D numpy array."""
    batches = []
    for images, _ in tqdm(dataloader, leave=False, desc=desc):
        logits = model(images.to(device))
        msp = torch.softmax(logits, dim=1).max(dim=1).values
        # Negated: the score grows as the model's top-class probability shrinks.
        batches.append((-msp).cpu())
    if not batches:
        return np.array([])
    return torch.cat(batches).numpy()


def evaluate_ood_with_msp(model, dataloader_id, dataloaders_ood, device="cuda"):
    """Evaluate OOD detection with the maximum-softmax-probability (MSP) baseline.

    Each sample is scored with the negated MSP of ``model(images)``. The
    in-distribution scores are compared with the scores of every OOD loader
    using ``get_auroc``, ``get_fpr95`` and ``get_aupr``.

    Args:
        model: callable returning class logits; it is switched to eval mode.
        dataloader_id: mapping whose ``"id"`` entry yields ``(images, labels)``
            batches of in-distribution data.
        dataloaders_ood: mapping ``ood_name -> loader`` of ``(images, labels)``.
        device: device the images are moved to.

    Returns:
        dict ``ood_name -> {"AUROC", "FPR95", "AUPR"}`` holding the raw metric
        values (not scaled to percent).
    """
    model.eval()
    results = {}

    with torch.no_grad():
        id_scores = _negative_msp(model, dataloader_id["id"], device, "ID")

        for dataset_name, dataloader in dataloaders_ood.items():
            ood_scores = _negative_msp(model, dataloader, device, dataset_name)
            results[dataset_name] = {
                "AUROC": get_auroc(ood_scores, id_scores),
                "FPR95": get_fpr95(ood_scores, id_scores),
                "AUPR": get_aupr(ood_scores, id_scores),
            }

    return results

def get_datamodule(
        dataset,
        val_tfms,
        batch_size=32,
        num_workers=4,
        ):
    """Build and set up a torch_uncertainty datamodule for OOD evaluation.

    Args:
        dataset: ``"cifar100"`` or ``"imagenet-1k"``.
        val_tfms: transform applied to test images.
        batch_size: test batch size (default 32, as before).
        num_workers: dataloader worker count (default 4, as before).

    Returns:
        The set-up datamodule (built with ``eval_ood=True``), with
        ``num_classes`` set to 100 or 1000.

    Raises:
        ValueError: if ``dataset`` is not supported.
    """
    if dataset not in ("cifar100", "imagenet-1k"):
        raise ValueError(f"Unsupported dataset: {dataset}")

    common_kwargs = dict(
        batch_size=batch_size,
        test_transform=val_tfms,
        num_workers=num_workers,
        eval_ood=True,
        pin_memory=True,
        persistent_workers=False,
    )

    if dataset == "cifar100":
        cifar100_path = "data/cifar100_torch_uncertainty"
        os.makedirs(cifar100_path, exist_ok=True)
        dm = CIFAR100DataModule(root=cifar100_path, **common_kwargs)
        # NOTE(review): unlike the ImageNet branch, setup() is called without a
        # stage here; kept as-is since callers index several loaders returned
        # by test_dataloader() — confirm before switching to setup("test").
        dm.setup()
        dm.num_classes = 100
    else:
        dm = ImageNetDataModule(root="data/in1k_torch_uncertainty", **common_kwargs)
        dm.setup("test")
        dm.num_classes = 1000

    return dm

def _cls_test_dataloader(name, val_tfms):
    """Build the datamodule for classification dataset `name` and return its test loader."""
    if name in ("imagenet-1k", "cifar100"):
        # torch_uncertainty datamodules: test_dataloader() returns a list and
        # only its first loader is used.
        tu_kwargs = dict(
            batch_size=64,
            test_transform=val_tfms,
            num_workers=4,
            eval_ood=False,
            pin_memory=True,
            persistent_workers=False,
        )
        if name == "imagenet-1k":
            dm = ImageNetDataModule(root="data/in1k_torch_uncertainty", **tu_kwargs)
        else:
            dm = CIFAR100DataModule(root="data/cifar100_torch_uncertainty", **tu_kwargs)
        dm.setup("test")
        return dm.test_dataloader()[0]

    # Project datamodules: constructed from the transform alone; their
    # test_dataloader() result is used directly (not indexed).
    local_datamodules = {
        "cifar10": CIFAR10DataModule,
        "sun397": Sun397DataModule,
        "food101": Food101DataModule,
        "oxford_pet": OxfordIIITPetDataModule,
        "eurosat": EuroSATDataModule,
        "dtd": DTDDataModule,
        "caltech101": Caltech101DataModule,
    }
    dm = local_datamodules[name](val_tfms)
    dm.setup()
    return dm.test_dataloader()


def get_dataloaders_cls(
        dataset_names,
        val_tfms
    ):
    """Return test dataloaders for the requested classification datasets.

    Args:
        dataset_names: iterable of dataset keys (a single string is accepted
            and treated as one key). Duplicates are ignored.
        val_tfms: transform applied to test images.

    Returns:
        dict ``dataset_key -> test dataloader``, built in a fixed canonical
        order (imagenet-1k, cifar100, cifar10, sun397, food101, oxford_pet,
        eurosat, dtd, caltech101) regardless of the order requested.

    Raises:
        ValueError: if any requested key is not a supported dataset.
    """
    supported = (
        "imagenet-1k", "cifar100", "cifar10", "sun397", "food101",
        "oxford_pet", "eurosat", "dtd", "caltech101",
    )
    if isinstance(dataset_names, str):
        # Otherwise membership would be a substring test: "cifar10" in "cifar100".
        dataset_names = [dataset_names]
    requested = set(dataset_names)
    unknown = sorted(requested.difference(supported))
    if unknown:
        raise ValueError(f"Unsupported dataset(s): {unknown}")

    os.makedirs("data", exist_ok=True)
    return {
        name: _cls_test_dataloader(name, val_tfms)
        for name in supported
        if name in requested
    }

def get_vilu_scores(
        clip_model,
        text_feats,
        vilu,
        dataloaders,
        device
):
    """Compute negated ViLU scores for every sample of every dataloader.

    For each batch, images are embedded with ``clip_model.encode_image``,
    L2-normalized, and passed together with ``text_feats`` through ``vilu``.
    The sigmoid of ViLU's output is negated and collected.

    Args:
        clip_model: open_clip model exposing ``encode_image``.
        text_feats: text embeddings passed unchanged to ``vilu``.
        vilu: ViLU module, called as ``vilu(visual_feats, text_feats)``.
        dataloaders: mapping ``name -> loader`` yielding ``(images, labels)``.
        device: device (string or torch.device) the images are moved to.

    Returns:
        dict ``name -> 1-D float32 numpy array`` of ``-sigmoid(vilu(...))``
        scores; the array is empty if the loader yields no batches.
    """
    on_cuda = torch.device(device).type == "cuda"
    vilu_scores_dict = {}

    with torch.no_grad():
        for dataset_name, dataloader in dataloaders.items():
            batch_scores = []
            for images, _ in tqdm(dataloader, desc=f"{dataset_name}", leave=False):
                images = images.to(device)
                visual_feats = clip_model.encode_image(images)
                visual_feats /= visual_feats.norm(dim=-1, keepdim=True)

                # Mixed precision only when the inputs actually live on a GPU;
                # CUDA autocast has no effect on CPU tensors anyway.
                with torch.autocast(device_type="cuda", enabled=on_cuda):
                    vilu_logits = vilu(visual_feats, text_feats).squeeze(-1)

                # Under autocast the output may be float16; taking the sigmoid in
                # float32 avoids coarse, heavily tied scores near 0 and 1 that
                # distort ranking metrics (AUROC / FPR95 / AUPR).
                batch_scores.append(torch.sigmoid(vilu_logits.float()).cpu())

            scores = torch.cat(batch_scores) if batch_scores else torch.empty(0)
            vilu_scores_dict[dataset_name] = -scores.numpy()

    return vilu_scores_dict

def eval_cls_vilu(
        clip_model,
        text_features,
        vilu,
        dataloader,
        dataset_name,
        device
        ):
    """Evaluate ViLU as a misclassification detector for zero-shot CLIP.

    For every batch, CLIP predicts the class whose text embedding has the
    highest scaled similarity with the image embedding; the prediction is
    "correct" when it equals the label. ViLU's sigmoid output and this
    binary correctness are fed to calibration metrics (ECE, adaptive ECE)
    and risk-coverage metrics (AUGRC, AURC, Cov@5, Risk@80) from
    torch_uncertainty.

    Args:
        clip_model: open_clip model exposing ``encode_image`` and ``logit_scale``.
        text_features: class text embeddings, compared with image embeddings
            via ``visual_feats @ text_features.T``.
        vilu: ViLU module, called as ``vilu(visual_feats, text_features)``.
        dataloader: loader yielding ``(images, labels)`` batches.
        dataset_name: name used for the progress bar and the printed report.
        device: device the metrics, images and labels are moved to.

    Returns:
        dict mapping "ECE", "AECE", "AUGRC", "AURC", "Cov@5" and "Risk@80" to
        the value of the corresponding metric's ``compute()``. The same values
        are printed, one per line, under ``dataset_name``.
    """
    metrics = {
        "ECE": CalibrationError(task="binary", num_classes=2),
        "AECE": AdaptiveCalibrationError(task="binary", num_classes=2),
        "AUGRC": AUGRC(),
        "AURC": AURC(),
        "Cov@5": CovAt5Risk(),
        "Risk@80": RiskAt80Cov(),
    }
    metrics = {name: metric.to(device) for name, metric in metrics.items()}

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc=f"{dataset_name}", leave=False):
            images, labels = images.to(device), labels.to(device)
            visual_feats = clip_model.encode_image(images)
            visual_feats /= visual_feats.norm(dim=-1, keepdim=True)
            logits = clip_model.logit_scale.exp() * visual_feats @ text_features.T
            correct = (logits.argmax(dim=-1) == labels).long()

            vilu_probs = torch.sigmoid(vilu(visual_feats, text_features).squeeze(-1))
            for metric in metrics.values():
                metric.update(vilu_probs, correct)

    scores = {name: metric.compute() for name, metric in metrics.items()}

    # Built line by line: the previous backslash-continued f-string embedded
    # the source indentation into the printed report.
    report = [f"{dataset_name}:"]
    report.extend(f"\t-{name}: {value:.3f}" for name, value in scores.items())
    print("\n".join(report))

    return scores

def eval_vilu_ood(
        vilu_id_scores,
        vilu_ood_scores
):
    """Print and return OOD-detection metrics for ViLU scores.

    Args:
        vilu_id_scores: scores of the in-distribution samples.
        vilu_ood_scores: mapping ``ood_name -> scores`` of OOD samples. Each
            entry is compared with ``vilu_id_scores`` via ``get_auroc``,
            ``get_fpr95`` and ``get_aupr``.

    Returns:
        dict ``ood_name -> {"AUROC", "FPR95", "AUPR"}`` with values scaled to
        percent (the same values that are printed).
    """
    results = {}
    for dataset_name, ood_scores in vilu_ood_scores.items():
        fpr95 = get_fpr95(ood_scores, vilu_id_scores) * 100
        auroc = get_auroc(ood_scores, vilu_id_scores) * 100
        aupr = get_aupr(ood_scores, vilu_id_scores) * 100
        results[dataset_name] = {"AUROC": auroc, "FPR95": fpr95, "AUPR": aupr}

        print(f"[{dataset_name}]\tAUROC:{auroc:.3f},\tFPR95: {fpr95:.3f}\tAUPR:{aupr:.3f}")

    return results

def get_text_features(
        dataset_name,
        clip_model,
        tokenizer,
        device
        ):
    """Encode the prompts of `dataset_name` into L2-normalized text embeddings.

    Args:
        dataset_name: key passed to ``get_prompts(dataset=...)``.
        clip_model: open_clip model exposing ``encode_text``.
        tokenizer: tokenizer returned by ``open_clip.get_tokenizer``.
        device: device the tokenized prompts are moved to.

    Returns:
        Tensor of text embeddings with each row divided by its L2 norm
        (presumably one row per class prompt; depends on get_prompts).
    """
    with torch.no_grad():
        tokens = tokenizer(get_prompts(dataset=dataset_name)).to(device)
        text_feats = clip_model.encode_text(tokens)
        text_feats /= text_feats.norm(dim=-1, keepdim=True)
    return text_feats

def clip_forward(self, x):
    """Zero-shot CLIP logits for images `x`.

    Written as an unbound method, presumably to be bound as a model's
    ``forward`` (e.g. via ``types.MethodType``). ``self`` must provide
    ``encode_image``, ``logit_scale`` and ``text_logits``. Despite its name,
    ``text_logits`` holds the class text embeddings used as classifier
    weights (presumably L2-normalized, e.g. from get_text_features; confirm).

    Returns:
        ``logit_scale.exp() * normalize(encode_image(x)) @ text_logits.T``.
    """
    image_feats = self.encode_image(x)
    # Out-of-place on purpose: an in-place ``/=`` overwrites the tensor that
    # autograd saved for the norm's backward, so backpropagating through this
    # forward raised a RuntimeError.
    image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)
    return self.logit_scale.exp() * image_feats @ self.text_logits.T

def _resolve_device(device):
    """Map the ``--device`` CLI value to a torch device string.

    Empty/None -> ``"cuda"`` if available else ``"cpu"``. A bare index such as
    ``"1"`` -> ``"cuda:1"``. Anything else (``"cpu"``, ``"cuda:1"``) is used
    verbatim instead of being mangled into ``"cuda:cpu"``.
    """
    if not device:
        return "cuda" if torch.cuda.is_available() else "cpu"
    if device.isdigit():
        return f"cuda:{device}"
    return device


def _build_vilu():
    """Instantiate the ViLU architecture shared by every checkpoint in VILU_PATH."""
    return viluAttention(
        concat=True,
        identity_init=True,
        n_iter_freeze_proj=1000,
        use_predicted_caption=True,
        use_attention=True,
    )


def _load_vilu_weights(vilu, dataset_name):
    """Load the ``state_dict`` of the checkpoint VILU_PATH[dataset_name] into `vilu`."""
    # weights_only=False unpickles arbitrary Python objects: only point
    # VILU_PATH at trusted checkpoint files.
    # Loading to CPU avoids pinning tensors to cuda:0 when another device was
    # selected; load_state_dict copies them onto vilu's own device.
    ckpt = torch.load(VILU_PATH[dataset_name], weights_only=False, map_location="cpu")
    vilu.load_state_dict(ckpt["state_dict"])


def _run_ood(dataset_ood, clip_model, tokenizer, val_tfms, device):
    """Score ID and OOD test splits of `dataset_ood` with ViLU and print OOD metrics."""
    dm = get_datamodule(dataset_ood, val_tfms)
    text_features = get_text_features(
        dataset_name=dataset_ood,
        clip_model=clip_model,
        tokenizer=tokenizer,
        device=device,
    )

    # Positions of the OOD splits in the list returned by dm.test_dataloader();
    # position 1 is used as the in-distribution split.
    ood_loader_indices = {
        "cifar100": {
            3: "cifar10",
            4: "tinyimagenet",
            5: "texture",
            6: "mnist",
            7: "svhn",
            8: "places365",
        },
        "imagenet-1k": {
            3: "ssb",
            4: "ninco",
            5: "inaturalist",
            6: "texture",
        },
    }[dataset_ood]

    test_loaders = dm.test_dataloader()
    dataloaders_ood = {name: test_loaders[idx] for idx, name in ood_loader_indices.items()}
    dataloader_id = {"id": test_loaders[1]}

    vilu = _build_vilu()
    _load_vilu_weights(vilu, dataset_ood)
    vilu.to(device)
    vilu.eval()

    print(f"Getting VILU scores for ID {dataset_ood}")
    vilu_id_scores = get_vilu_scores(clip_model, text_features, vilu, dataloader_id, device)
    print(f"Getting VILU scores for OOD {dataset_ood}")
    vilu_ood_scores = get_vilu_scores(clip_model, text_features, vilu, dataloaders_ood, device)

    eval_vilu_ood(vilu_id_scores["id"], vilu_ood_scores)


def _run_cls(dataset_cls, clip_model, tokenizer, val_tfms, device):
    """Evaluate ViLU misclassification detection on each requested dataset."""
    dataloaders_cls = get_dataloaders_cls(dataset_cls, val_tfms)
    text_features = {
        name: get_text_features(
            dataset_name=name,
            clip_model=clip_model,
            tokenizer=tokenizer,
            device=device,
        )
        for name in dataloaders_cls
    }

    # One module is reused; only its weights change per dataset.
    vilu = _build_vilu()
    vilu.to(device)
    vilu.eval()
    for name, dataloader in dataloaders_cls.items():
        _load_vilu_weights(vilu, name)
        eval_cls_vilu(clip_model, text_features[name], vilu, dataloader, name, device)


def main():
    """CLI entry point: load the open_clip model and run the requested evaluations.

    Raises:
        SystemExit: if neither --dataset_cls nor --dataset_ood is given.
    """
    config = parse_args()
    if not (config.dataset_ood or config.dataset_cls):
        # Explicit exit rather than `assert`, which is stripped under `python -O`.
        raise SystemExit("Use at least one of --dataset_cls or --dataset_ood")

    device = _resolve_device(config.device)

    clip_model, _, val_tfms = open_clip.create_model_and_transforms(MODEL[0], MODEL[1])
    tokenizer = open_clip.get_tokenizer(MODEL[0])
    clip_model.to(device)
    clip_model.eval()

    if config.dataset_ood:
        _run_ood(config.dataset_ood, clip_model, tokenizer, val_tfms, device)

    if config.dataset_cls:
        _run_cls(config.dataset_cls, clip_model, tokenizer, val_tfms, device)


if __name__ == "__main__":
    main()