import matplotlib.pyplot as plt
from pytorch_grad_cam import (
    EigenCAM,
    GradCAM,
    GradCAMPlusPlus,
    HiResCAM,
    XGradCAM,
    EigenGradCAM,
)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from cam_methods.eigen_cam import EigenCamMultiVec
from cam_methods.tensor_cam import TSM, MTSM
import torch
import numpy as np
from utils.util_datasets import get_imagenet
from utils.util_models import (
    moco_v2_model_target_layer,
    swav_model_target_layer,
    vicreg_resnet_model_target_layer,
    vicreg_convnext_model_target_layer,
    barlowtwins_model_target_layer,
)


def main():
    """Plot a grid comparing CAM heatmaps on one batch of ImageNet images.

    Column ``c`` shows image ``c`` of the batch. Row 0 shows the plain input
    image; row ``k`` (``k >= 1``) shows the same image overlaid with the
    heatmap produced by ``cam_methods[k - 1]``. The model and target layer
    come from ``swav_model_target_layer()``. Each method's name is printed as
    it runs, and the figure is displayed with ``plt.show()``.
    """
    # Setting parameters
    image_size = 224
    cols = 6

    dataset = get_imagenet(image_size=image_size, batch_size=cols)
    sample_images, labels = next(iter(dataset))
    # NOTE(review): the dataset labels are used as output indices for the
    # target; confirm the chosen model actually emits class scores.
    targets = [ClassifierOutputTarget(int(label)) for label in labels]

    model, target_layer, _ = swav_model_target_layer()

    cam_methods = [EigenCAM, TSM]

    # ImageNet channel statistics, shaped (1, 3, 1, 1) so they broadcast over
    # a (batch, channel, height, width) tensor. Presumably get_imagenet yields
    # unnormalized images in [0, 1] -- confirm.
    means = torch.tensor(
        [0.485, 0.456, 0.406], device=sample_images.device
    ).view(1, 3, 1, 1)
    stds = torch.tensor(
        [0.229, 0.224, 0.225], device=sample_images.device
    ).view(1, 3, 1, 1)
    # Normalize exactly once, into a separate tensor: sample_images stays
    # unnormalized for display, and every CAM method sees the same input.
    input_tensor = (sample_images - means) / stds

    n_rows = len(cam_methods) + 1
    # squeeze=False keeps axarr 2-D even when there is a single row/column.
    _, axarr = plt.subplots(n_rows, cols, figsize=(15, 7), squeeze=False)
    for col in range(cols):
        # Images are (C, H, W); imshow expects (H, W, C).
        image = sample_images[col].permute(1, 2, 0)
        for row in range(n_rows):
            ax = axarr[row, col]
            ax.imshow(image)
            # Hide tick marks and labels while keeping the y-axis label.
            ax.set_xticks([])
            ax.set_yticks([])
    axarr[0, 0].set_ylabel("Input image")

    # Getting CAM methods
    for row, method in enumerate(cam_methods, start=1):
        method_name = method.__name__
        print(method_name)

        cam_model = method(model=model, target_layers=[target_layer])
        cam = cam_model(input_tensor=input_tensor, targets=targets)

        axarr[row, 0].set_ylabel(method_name)
        for col in range(cols):
            # Semi-transparent heatmap drawn over the image already on the axes.
            axarr[row, col].imshow(cam[col].squeeze(), cmap="jet", alpha=0.6)

    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    # Build and show the CAM comparison figure only when run as a script,
    # never as a side effect of importing this module.
    main()
