import numpy as np
import torch
from cam_methods.tensor_cam import TSM, MTSM
from pytorch_grad_cam import EigenCAM
import matplotlib.pyplot as plt
import argparse
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import os
from cam_methods.eigen_cam import EigenCamMultiVec
from utils.util_models import (
    moco_v2_model_target_layer,
    swav_model_target_layer,
    vicreg_resnet_model_target_layer,
    vicreg_convnext_model_target_layer,
    barlowtwins_model_target_layer,
)
from utils.util_datasets import get_pascal


def _save_overlay(img, out_path, overlay=None):
    """Save one image as a borderless figure, optionally with a heatmap on top.

    Args:
        img: torch tensor laid out channels-first; it is permuted to
            channels-last before being handed to ``imshow``.
        out_path: destination file path passed to ``fig.savefig``.
        overlay: optional array drawn over ``img`` with the ``"jet"``
            colormap at 60% opacity (presumably a 2-D (H, W) map - TODO
            confirm for every caller).
    """
    fig = plt.figure(frameon=False)
    try:
        # An axes covering the whole canvas with no ticks/frame, so the
        # saved file contains only the image pixels.
        ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
        ax.set_axis_off()
        fig.add_axes(ax)
        ax.imshow(img.permute(1, 2, 0).numpy(), aspect="auto")
        if overlay is not None:
            ax.imshow(overlay, aspect="auto", alpha=0.6, cmap="jet")
        fig.savefig(out_path)
    finally:
        # Always release the figure, even if rendering/saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)


def save_image(imgs, mask, cams, model_name, method_name):
    """Write mask-overlay, plain and CAM-overlay figures for a batch of images.

    Files are written to ``figures/segmentation/<model_name>/``:

    * ``img_<i>.png`` - image ``i`` with ``mask[i]`` overlaid; the whole set is
      written only when ``img_0.png`` does not exist yet.
    * ``img_<i>_org.png`` - image ``i`` on its own; written only when
      ``img_0_org.png`` does not exist yet.
    * ``img_<i>_<method_name>.png`` - image ``i`` with ``cams[i]`` overlaid;
      always (re)written.

    Args:
        imgs: iterable of channels-first image tensors.
        mask: indexable collection of per-image mask overlays.
        cams: indexable collection of per-image CAM overlays.
        model_name: subdirectory name under ``figures/segmentation``.
        method_name: suffix identifying the CAM method in output filenames.
    """
    path = f"figures/segmentation/{model_name}"
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(path, exist_ok=True)

    # The mask overlays and plain images do not depend on the CAM method, so
    # they are only rendered once per model directory.
    if not os.path.exists(f"{path}/img_0.png"):
        print("Saving original images first.")
        for i, img in enumerate(imgs):
            _save_overlay(img, f"{path}/img_{i}.png", overlay=mask[i])

    if not os.path.exists(f"{path}/img_0_org.png"):
        print("Saving original images first.")
        for i, img in enumerate(imgs):
            _save_overlay(img, f"{path}/img_{i}_org.png")

    for i, img in enumerate(imgs):
        _save_overlay(img, f"{path}/img_{i}_{method_name}.png", overlay=cams[i])
    print("Saved")


def plot(dataset, model, target_layer, cam_method, model_name, threshold=0.4):
    """Compute binarised CAMs for the first batch of *dataset* and save figures.

    Args:
        dataset: iterable yielding ``(images, masks)`` batches; only the first
            batch is used. ``images`` are presumably un-normalised RGB tensors
            shaped (N, 3, H, W) - TODO confirm against ``get_pascal``.
        model: network handed to the CAM method.
        target_layer: layer handed to the CAM method as its only target layer.
        cam_method: CAM class constructed as
            ``cam_method(model=..., target_layers=[...])``; its ``__name__``
            is used in output filenames.
        model_name: output subdirectory name passed to ``save_image``.
        threshold: CAM values ``>= threshold`` are set to 1.0 and values
            ``< threshold`` to 0.0 before saving.
    """
    torch.manual_seed(2023)
    method_name = cam_method.__name__
    print(f"Saving CAMs for : {method_name}")

    # Standard ImageNet per-channel mean/std, shaped (1, 3, 1, 1) so they
    # broadcast over a channels-first batch.
    means = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(1, 3, 1, 1)
    stds = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(1, 3, 1, 1)

    cam_model = cam_method(model=model, target_layers=[target_layer])

    images, masks = next(iter(dataset))
    # Every image in the batch is explained with respect to output index 0.
    targets = [ClassifierOutputTarget(0) for _ in range(len(images))]
    norm_imgs = (images - means) / stds

    cam = cam_model(input_tensor=norm_imgs, targets=targets)
    masks = masks.squeeze().detach().cpu().numpy()

    # Binarise the CAM in place. The result is passed on as-is: it must NOT be
    # replaced afterwards, otherwise the saved overlays lose the CAM content.
    cam[cam >= threshold] = 1.0
    cam[cam < threshold] = 0.0

    save_image(images, masks, cam, model_name, method_name)


if __name__ == "__main__":
    # Registry of selectable backbones: each entry is a zero-argument factory
    # returning ``(model, target_layer)``.
    models = {
        "vicregl_resnet": vicreg_resnet_model_target_layer,
        "vicregl_convnext": vicreg_convnext_model_target_layer,
        "barlowtwins": barlowtwins_model_target_layer,
        "moco": moco_v2_model_target_layer,
        "swav": swav_model_target_layer,
    }
    # Registry of selectable CAM classes.
    methods = {
        "eigen": EigenCAM,
        "tsm": TSM,
        "mtsm": MTSM,
        "eigenMultivec": EigenCamMultiVec,
    }

    parser = argparse.ArgumentParser(description="CAM evaluation.")
    # `choices` makes argparse reject unknown names with a usage error
    # (an `assert` would be stripped when running under `python -O`).
    parser.add_argument(
        "--model",
        type=str,
        default="vicregl_resnet",
        choices=list(models),
        help="Backbone model to explain",
    )
    parser.add_argument("--batch_size", type=int, default=10, help="Batch size")
    parser.add_argument("--image_size", type=int, default=224, help="Image size")
    parser.add_argument(
        "--cam_method",
        type=str,
        default="eigen",
        choices=list(methods),
        help="CAM method used to produce the saliency maps",
    )
    args = parser.parse_args()

    model, target_layer = models[args.model]()

    plot(
        dataset=get_pascal(args),
        model=model,
        target_layer=target_layer,
        cam_method=methods[args.cam_method],
        model_name=args.model,
    )
