import numpy as np
import torchvision
import torch
from pytorch_grad_cam import EigenCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from cam_methods.tensor_cam import TSM, MTSM
from cam_methods.eigen_cam import EigenCamMultiVec
from utils.util_models import (
    moco_v2_model_target_layer,
    swav_model_target_layer,
    vicreg_resnet_model_target_layer,
    vicreg_convnext_model_target_layer,
    barlowtwins_model_target_layer,
)
from utils.util_datasets import get_imagenet
import time
import argparse
from utils.metrics import AverageDrop, AverageGain, AverageIncrease


def eval(
    dataset,
    model,
    target_layer,
    cam_method,
    batch_size,
    model_name,
    nb_batches,
    cls_model,
):
    """Score a CAM method with Average Drop / Increase / Gain and print the results.

    For every batch a saliency map is computed with ``cam_method`` on ``model``
    at ``target_layer``. The input images are multiplied by that map, and both
    the original and the masked images are classified by ``cls_model``. The
    softmax probability of the ground-truth label is collected on the original
    images ("targets") and on the masked images ("preds"). These are fed to
    ``AverageDrop``, ``AverageIncrease`` and ``AverageGain``.

    Note: this function shadows the builtin ``eval``; the name is kept so that
    existing callers keep working.

    Args:
        dataset: iterable of ``(img, labels)`` batches. ImageNet mean/std
            normalization is applied here, so ``img`` is presumably un-normalized
            RGB in NCHW layout -- TODO confirm against ``get_imagenet``.
        model: network explained by the CAM.
        target_layer: layer of ``model`` handed to ``cam_method``.
        cam_method: CAM class, constructed as
            ``cam_method(model=..., target_layers=[...])``. When called, it must
            return a numpy array of per-image saliency maps.
        batch_size: unused. The original/masked split uses each batch's real
            size, so a short final batch is handled correctly. Kept for
            interface compatibility.
        model_name: unused; kept for interface compatibility.
        nb_batches: maximum number of batches to process. A non-positive value
            processes the whole dataset.
        cls_model: classifier whose softmax outputs are scored.
    """
    model.eval()
    cls_model.eval()
    cam_model = cam_method(model=model, target_layers=[target_layer])
    ad, ai, ag = AverageDrop(), AverageIncrease(), AverageGain()
    print(cam_method.__name__)
    start_time = time.time()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # ImageNet normalization statistics, shaped (1, 3, 1, 1) to broadcast over NCHW.
    means = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, -1, 1, 1)
    stds = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, -1, 1, 1)

    for nb_batch, (img, labels) in enumerate(dataset):
        img = img.to(device)
        # Labels are only used for CPU-side indexing and as plain int CAM targets.
        labels = labels.long().cpu()
        n = img.shape[0]

        # The CAM call stays outside no_grad: gradient-based methods need autograd.
        cam_labels = [ClassifierOutputTarget(c) for c in labels.tolist()]
        cam = cam_model(input_tensor=(img - means) / stds, targets=cam_labels)
        cam = torch.from_numpy(cam).unsqueeze(1).to(img)  # add channel axis, match dtype/device

        masked_img = img * cam
        combin = (torch.cat((img, masked_img), dim=0) - means) / stds
        with torch.no_grad():
            probs = torch.softmax(cls_model(combin), dim=-1).cpu()

        # Probability of the true class: first n rows are originals, last n are masked.
        rows = torch.arange(n)
        targets = probs[:n][rows, labels].double()
        preds = probs[n:][rows, labels].double()

        ad.update(preds, targets)
        ai.update(preds, targets)
        ag.update(preds, targets)

        # nb_batch is 0-based, so stop once nb_batches batches have been processed.
        if nb_batches > 0 and nb_batch + 1 >= nb_batches:
            break
    ad, ai, ag = ad.compute(), ai.compute(), ag.compute()
    print(f"Average Drop : {ad} Average Increase : {ai} Average Gain : {ag}")
    print(f"Execution time : {time.time()-start_time}")


def eval_MSE(
    dataset, model, target_layer, cam_method, batch_size, model_name, nb_batches
):
    """Measure how much CAM masking changes a model's embeddings; print the mean MSE.

    For every batch a saliency map is computed with ``cam_method``. The input
    images are multiplied by it, and ``model`` embeds both the original and the
    masked images. The per-batch MSE between the two embedding sets is averaged
    over all processed batches.

    Args:
        dataset: iterable of ``(img, labels)`` batches. ImageNet mean/std
            normalization is applied here, so ``img`` is presumably un-normalized
            RGB in NCHW layout -- TODO confirm against ``get_imagenet``.
        model: network explained by the CAM and used to produce embeddings.
        target_layer: layer of ``model`` handed to ``cam_method``.
        cam_method: CAM class, constructed as
            ``cam_method(model=..., target_layers=[...])``.
        batch_size: unused. The original/masked split uses each batch's real
            size, so a short final batch is handled correctly. Kept for
            interface compatibility.
        model_name: selects how embeddings are extracted. Must be one of
            ``vicregl_resnet``, ``vicregl_convnext``, ``barlowtwins``,
            ``moco`` or ``swav``.
        nb_batches: maximum number of batches to process. A non-positive value
            processes the whole dataset.

    Raises:
        ValueError: if ``model_name`` is not supported.
    """
    # VICRegL variants: the first model output is average-pooled to (1, dim).
    pooled_dims = {"vicregl_resnet": 2048, "vicregl_convnext": 1024}
    # These variants expose their embedding as the output of ``model.avgpool``.
    hooked_models = ("barlowtwins", "moco", "swav")
    if model_name not in pooled_dims and model_name not in hooked_models:
        raise ValueError(f"Unsupported model_name for MSE evaluation: {model_name!r}")

    model.eval()
    cam_model = cam_method(model=model, target_layers=[target_layer])
    print(cam_method.__name__)
    start_time = time.time()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # ImageNet normalization statistics, shaped (1, 3, 1, 1) to broadcast over NCHW.
    means = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, -1, 1, 1)
    stds = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, -1, 1, 1)

    activation = {}
    hook_handle = None
    if model_name in hooked_models:
        # Register the hook once (not once per batch) and remove it in `finally`.
        def _store_avgpool(_module, _inputs, out):
            activation["fc3"] = out.detach()

        hook_handle = model.avgpool.register_forward_hook(_store_avgpool)

    def _embed(batch):
        """Return one embedding per row of ``batch``, on the CPU."""
        if model_name in pooled_dims:
            out = model(batch)[0].cpu()
            # Presumably (rows, tokens, dim) -> (rows, 1, dim) -> (rows, dim); TODO confirm.
            pool = torch.nn.AdaptiveAvgPool2d(output_size=(1, pooled_dims[model_name]))
            return pool(out).squeeze()
        model(batch)  # the avgpool hook captures the embedding
        return activation["fc3"].cpu().squeeze()

    loss = []
    try:
        for nb_batch, (img, labels) in enumerate(dataset):
            img = img.to(device)
            n = img.shape[0]

            # NOTE(review): sample i targets output unit i (a batch index, not a
            # class label) -- confirm this is intended for these SSL models.
            cam_labels = [ClassifierOutputTarget(i) for i in range(len(labels))]
            # The CAM call stays outside no_grad: gradient-based methods need autograd.
            cam = cam_model(input_tensor=(img - means) / stds, targets=cam_labels)
            cam = torch.from_numpy(cam).unsqueeze(1).to(img)

            combin = (torch.cat((img, img * cam), dim=0) - means) / stds
            with torch.no_grad():
                output = _embed(combin)

            # First n rows embed the originals, last n rows the masked images.
            loss.append(
                torch.nn.functional.mse_loss(
                    output[:n], output[n:], reduction="mean"
                ).item()
            )

            # nb_batch is 0-based, so stop once nb_batches batches have been processed.
            if nb_batches > 0 and nb_batch + 1 >= nb_batches:
                break
    finally:
        if hook_handle is not None:
            hook_handle.remove()

    mse = sum(loss) / len(loss) if loss else float("nan")
    print(f"MSE : {mse}")
    print(f"Execution time : {time.time()-start_time}")


if __name__ == "__main__":
    # Registries of supported CAM methods and model loaders. They are declared
    # before the parser so argparse can validate the choices; `assert` would be
    # stripped under `python -O`.
    methods = {
        "eigen": EigenCAM,
        "eigenMultivec": EigenCamMultiVec,
        "tsm": TSM,
        "mtsm": MTSM,
    }
    models = {
        "vicregl_resnet": vicreg_resnet_model_target_layer,
        "vicregl_convnext": vicreg_convnext_model_target_layer,
        "barlowtwins": barlowtwins_model_target_layer,
        "moco": moco_v2_model_target_layer,
        "swav": swav_model_target_layer,
    }

    parser = argparse.ArgumentParser(description="CAM evaluation.")
    parser.add_argument(
        "--cam_method",
        type=str,
        default="eigen",
        choices=list(methods),
        help="CAM method to evaluate",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="vicregl_resnet",
        choices=list(models),
        help="Model to interpret",
    )
    parser.add_argument("--batch_size", type=int, default=5, help="Batch size")
    parser.add_argument(
        "--nb_batches",
        type=int,
        default=1,
        help="Number of batches to process in the validation dataset",
    )
    parser.add_argument(
        "--image_size", type=int, default=224, help="Image size of the dataset"
    )
    parser.add_argument(
        "--mse",
        action="store_true",
        help="Report embedding MSE instead of Average Drop/Increase/Gain",
    )
    args = parser.parse_args()

    # Loading the dataset
    dataset = get_imagenet(args)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, target_layer, cls_model = models[args.model]()
    model = model.to(device)
    target_layer = target_layer.to(device)
    cls_model = cls_model.to(device)

    # Arguments shared by both evaluation modes.
    common_kwargs = dict(
        dataset=dataset,
        model=model,
        target_layer=target_layer,
        cam_method=methods[args.cam_method],
        batch_size=args.batch_size,
        model_name=args.model,
        nb_batches=args.nb_batches,
    )
    if args.mse:
        eval_MSE(**common_kwargs)
    else:
        eval(**common_kwargs, cls_model=cls_model)
