import matplotlib.pyplot as plt
from pytorch_grad_cam import (
    AblationCAM,
    EigenCAM,
    GradCAM,
    GradCAMPlusPlus,
    HiResCAM,
    ScoreCAM,
    XGradCAM,
    EigenGradCAM,
)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from cam_methods.eigen_cam import EigenCamMultiVec
from cam_methods.tensor_cam import TSM, MTSM
from cam_methods.opticam import OptiCAM
import torch
import numpy as np

from utils.util_datasets import get_imagenet
from utils.util_models import *


def main(image_size=224, num_images=5, cam_methods=None):
    """Plot a batch of input images next to their CAM heatmaps.

    Loads the first batch from ``get_imagenet``, builds the model and target
    layer with ``resnet50_model_target_layer()``, runs every CAM method on the
    (ImageNet-normalised) batch and shows a grid with one row per image:
    column 0 is the raw input, and each further column overlays one method's
    heatmap on that input.

    Args:
        image_size: Image side length passed to ``get_imagenet``.
        num_images: Batch size requested from ``get_imagenet``; this is the
            maximum number of rows in the figure.
        cam_methods: Sequence of CAM classes to compare. Each is instantiated
            as ``method(model=..., target_layers=[...])`` (with extra keyword
            arguments for the KSVD variants). Defaults to
            ``[EigenCamMultiVec]``.
    """
    if cam_methods is None:
        cam_methods = [EigenCamMultiVec]

    # Load the dataset; only its first batch is visualised.
    dataset = get_imagenet(image_size=image_size, batch_size=num_images)
    sample_images, labels = next(iter(dataset))
    # A batch may hold fewer images than requested, so size rows from it.
    n_rows = len(sample_images)
    targets = [ClassifierOutputTarget(int(label)) for label in labels]

    model, target_layer = resnet50_model_target_layer()

    n_cols = len(cam_methods) + 1
    # squeeze=False keeps axarr 2-D even with a single row or column.
    _, axarr = plt.subplots(n_rows, n_cols, figsize=(15, 7), squeeze=False)

    # Draw the un-normalised input in every cell; CAM columns overlay it later.
    # NOTE(review): assumes get_imagenet yields (B, C, H, W) tensors in a
    # displayable range (i.e. not already normalised) -- confirm.
    for row in range(n_rows):
        image = sample_images[row].permute(1, 2, 0)
        for ax in axarr[row]:
            ax.imshow(image)
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.tick_params(left=False, bottom=False)
    axarr[0, 0].set_title("Input image")

    # ImageNet mean/std normalisation, applied exactly once. Doing this inside
    # the per-method loop would re-normalise the batch for every method.
    means = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(1, 3, 1, 1)
    stds = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(1, 3, 1, 1)
    input_tensor = (sample_images - means) / stds

    for col, method in enumerate(cam_methods, start=1):
        method_name = method.__name__
        print(method_name)

        # The KSVD variants need extra constructor arguments.
        if method_name in ("KSVDCAM", "KSVDCAMMultiVec"):
            cam_model = method(
                model=model,
                target_layers=[target_layer],
                use_cuda=False,
                pca_kernel="rbf",
            )
        else:
            cam_model = method(model=model, target_layers=[target_layer])

        cam = cam_model(input_tensor=input_tensor, targets=targets)

        axarr[0, col].set_title(method_name)
        for row in range(n_rows):
            ax = axarr[row, col]
            ax.imshow(cam[row].squeeze(), cmap="jet", alpha=0.7)
            ax.tick_params(left=False, bottom=False)

    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    # Run the visualisation only when executed as a script, not on import.
    main()
