import numpy as np
import torchvision
import torch
from pytorch_grad_cam import EigenCAM, GradCAM, GradCAMPlusPlus, HiResCAM, XGradCAM
from cam_methods.ablation_cam import AblationCAM
from cam_methods.score_cam import ScoreCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from utils.util_datasets import get_imagenet
from utils.util_models import (
    resnet50_model_target_layer,
    convnext_base_model_target_layer,
    vgg16_model_target_layer,
)

from cam_methods.eigen_cam import EigenCamMultiVec
from cam_methods.tensor_cam import TSM, MTSM

from cam_methods.opticam import OptiCAM

import time
import argparse
from utils.metrics import AverageDrop, AverageGain, AverageIncrease


def eval(dataset, model, target_layer, cam_method, batch_size, model_name, nb_batches):
    """Score a CAM method with Average Drop / Increase / Gain and print the result.

    For every batch a saliency map is computed for the labelled class, the
    image is multiplied by that map, and the softmax probability of the
    labelled class is compared between the original and the masked image.

    NOTE: the name shadows the ``eval`` builtin inside this module; it is kept
    so existing callers keep working.

    Args:
        dataset: Iterable yielding ``(img, labels)`` batches. Images are
            normalised here with fixed per-channel mean/std (the usual
            ImageNet statistics), so they are presumably expected
            un-normalised -- confirm against the dataset loader.
        model: Classifier to explain; switched to eval mode.
        target_layer: Layer handed to ``cam_method`` as its only target layer.
        cam_method: CAM class, instantiated as
            ``cam_method(model=model, target_layers=[target_layer])``.
        batch_size: Kept for backward compatibility. The number of samples is
            now read from each batch so a short final batch is split correctly.
        model_name: Unused; kept for backward compatibility.
        nb_batches: Number of batches to evaluate. A non-positive value
            evaluates the whole dataset.

    Returns:
        None. The metrics and the elapsed time are printed.
    """
    model.eval()
    cam_model = cam_method(model=model, target_layers=[target_layer])
    ad, ai, ag = AverageDrop(), AverageIncrease(), AverageGain()
    print(cam_method.__name__, end=" ")
    start_time = time.time()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Shaped (1, 3, 1, 1) so they broadcast over (N, 3, H, W) image batches.
    means = torch.tensor(
        [0.485, 0.456, 0.406], dtype=torch.float32, device=device
    ).view(1, 3, 1, 1)
    stds = torch.tensor(
        [0.229, 0.224, 0.225], dtype=torch.float32, device=device
    ).view(1, 3, 1, 1)

    for nb_batch, (img, labels) in enumerate(dataset):
        img = img.to(device)
        labels = labels.to(device)
        # Use the real batch length: the last batch of a loader may be shorter
        # than the configured batch size.
        n_samples = img.shape[0]

        cam_labels = [ClassifierOutputTarget(i) for i in labels]
        cam = cam_model(input_tensor=(img - means) / stds, targets=cam_labels)
        # The CAM comes back as a NumPy array; the added axis lets it broadcast
        # over the channel dimension (assumes an (N, H, W) map -- TODO confirm).
        cam = torch.from_numpy(cam).unsqueeze(1).type_as(img)

        masked_img = img * cam

        # Score original and masked images in a single forward pass.
        combin = torch.concat((img, masked_img), dim=0)
        combin = (combin - means) / stds
        # detach() releases the autograd graph; only the values are needed.
        output = torch.softmax(model(combin), dim=-1).detach().cpu()

        full_probs, masked_probs = output[:n_samples], output[n_samples:]

        # Probability of the labelled class for each sample, as float64 to
        # match what the metrics previously received.
        class_idx = labels.cpu().long().unsqueeze(1)
        targets = full_probs.gather(1, class_idx).squeeze(1).double()
        preds = masked_probs.gather(1, class_idx).squeeze(1).double()

        ad.update(preds, targets)
        ai.update(preds, targets)
        ag.update(preds, targets)

        # nb_batch is zero-based, so stop once exactly nb_batches were processed.
        if nb_batch + 1 == nb_batches:
            break

    ad, ai, ag = ad.compute(), ai.compute(), ag.compute()
    print(f"Average Drop : {ad} Average Increase : {ai} Average Gain : {ag}")

    print(f"Execution time : {time.time()-start_time}")


if __name__ == "__main__":
    # CLI name -> CAM class. Defined before the parser so the valid names can
    # be enforced (and listed in --help) through argparse ``choices``.
    methods = {
        "ablation": AblationCAM,
        "eigen": EigenCAM,
        "grad": GradCAM,
        "grad++": GradCAMPlusPlus,
        "highres": HiResCAM,
        "score": ScoreCAM,
        "xgrad": XGradCAM,
        "eigenMultivec": EigenCamMultiVec,
        "tsm": TSM,
        "mtsm": MTSM,
        "opti": OptiCAM,
    }
    # CLI name -> factory returning ``(model, target_layer)``.
    models = {
        "resnet50": resnet50_model_target_layer,
        "vgg16": vgg16_model_target_layer,
        "convnext_b": convnext_base_model_target_layer,
    }

    parser = argparse.ArgumentParser(description="CAM evaluation.")
    # ``choices`` replaces the former ``assert`` checks, which are stripped
    # under ``python -O`` and gave no usable message on a bad value.
    parser.add_argument(
        "--cam_method",
        type=str,
        default="grad",
        choices=sorted(methods),
        help="CAM method to evaluate",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="resnet50",
        choices=sorted(models),
        help="Model to explain",
    )
    parser.add_argument("--batch_size", type=int, default=20, help="Batch size")
    parser.add_argument(
        "--nb_batches",
        type=int,
        default=1,
        help="Controls how many batches are evaluated",
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=224,
        help="Image size forwarded to the dataset loader",
    )
    args = parser.parse_args()

    # Loading the dataset
    dataset = get_imagenet(args)

    model, target_layer = models[args.model]()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    target_layer = target_layer.to(device)

    eval(
        dataset=dataset,
        model=model,
        target_layer=target_layer,
        cam_method=methods[args.cam_method],
        batch_size=args.batch_size,
        model_name=args.model,
        nb_batches=args.nb_batches,
    )
