import torch

import numpy as np
from pytorch_grad_cam import EigenCAM
from cam_methods.tensor_cam import TSM, MTSM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import argparse
from cam_methods.eigen_cam import EigenCamMultiVec
from utils.util_models import (
    moco_v2_model_target_layer,
    swav_model_target_layer,
    vicreg_resnet_model_target_layer,
    vicreg_convnext_model_target_layer,
    barlowtwins_model_target_layer,
)
from utils.util_datasets import get_pascal


def main(model, data_loader, target_layer, cam_method):
    """Score thresholded CAM maps against ground-truth masks and print the mIoU.

    For each batch the images are normalised, a CAM is computed for
    ``target_layer`` and binarised at the thresholds 0.4, 0.5, ..., 0.9.
    Each binarised map is compared with the pixels where the mask equals 1,
    and the intersection-over-union is averaged over all evaluated images.

    Images whose union is empty (no mask pixel equal to 1 and no CAM pixel
    above the threshold) have an undefined IoU; they are left out of the
    average instead of turning it into NaN. If no image could be scored for
    a threshold (e.g. an empty ``data_loader``), ``nan`` is printed for it.

    Args:
        model: Network passed to ``cam_method``. Input images are moved to
            the device of its first parameter (CPU if it has none).
        data_loader: Iterable yielding ``(images, masks)`` tensor pairs.
        target_layer: Layer passed to ``cam_method`` as the single entry of
            ``target_layers``.
        cam_method: Callable invoked as
            ``cam_method(model=..., target_layers=[...])``. The object it
            returns is called with ``input_tensor`` and ``targets``.
            NOTE(review): that call is assumed to return a NumPy array with
            the batch as leading axis and the spatial axes last -- confirm.

    Returns:
        None. One line per threshold is printed to stdout.
    """
    cam_model = cam_method(model=model, target_layers=[target_layer])

    # Follow the model's device so inputs and weights always live together,
    # wherever the caller placed the model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Per-channel normalisation constants, shaped (1, 3, 1, 1) so they
    # broadcast over a batch of 3-channel images. The values match the
    # commonly used ImageNet channel means / standard deviations.
    means = torch.tensor(
        [0.485, 0.456, 0.406], dtype=torch.float32, device=device
    ).view(1, -1, 1, 1)
    stds = torch.tensor(
        [0.229, 0.224, 0.225], dtype=torch.float32, device=device
    ).view(1, -1, 1, 1)

    # Step k stands for the threshold k / 10.
    threshold_steps = range(4, 10)
    iou_sum = dict.fromkeys(threshold_steps, 0.0)
    iou_count = dict.fromkeys(threshold_steps, 0)

    for images, masks in data_loader:
        images = images.to(device)
        # Every image is explained with respect to output index 0.
        targets = [ClassifierOutputTarget(0) for _ in range(len(images))]

        normalized = ((images - means) / stds).type_as(images)
        cam = cam_model(input_tensor=normalized, targets=targets)

        # Masks are only consumed by NumPy, so they never need to visit the
        # compute device.
        foreground = masks.squeeze().detach().cpu().numpy() == 1

        for step in threshold_steps:
            predicted = cam >= step / 10.0

            # Pixel counts per image; boolean masks avoid allocating
            # full-size float arrays for every threshold.
            intersection = np.atleast_1d(
                np.logical_and(foreground, predicted).sum(axis=(-2, -1))
            )
            union = np.atleast_1d(
                np.logical_or(foreground, predicted).sum(axis=(-2, -1))
            )

            # Skip images with an empty union: 0 / 0 would yield NaN.
            scorable = union > 0
            ious = intersection[scorable] / union[scorable]

            # Accumulate per image (not per batch) so a smaller final batch
            # does not get a larger weight in the overall mean.
            iou_sum[step] += float(ious.sum())
            iou_count[step] += int(scorable.sum())

    for step in threshold_steps:
        if iou_count[step]:
            mean_iou = iou_sum[step] / iou_count[step]
        else:
            mean_iou = float("nan")
        print(f"thresh : {step/10.} mIoU : {mean_iou}")


if __name__ == "__main__":
    # Registries mapping each command-line choice to its implementation.
    # They are defined before the parser so argparse can validate against
    # their keys.
    methods = {
        "eigen": EigenCAM,
        "tsm": TSM,
        "mtsm": MTSM,
        "eigenMultivec": EigenCamMultiVec,
    }
    models = {
        "vicregl_resnet": vicreg_resnet_model_target_layer,
        "vicregl_convnext": vicreg_convnext_model_target_layer,
        "barlowtwins": barlowtwins_model_target_layer,
        "moco": moco_v2_model_target_layer,
        "swav": swav_model_target_layer,
    }

    parser = argparse.ArgumentParser(description="CAM evaluation.")
    # `choices` replaces the former `assert` checks: it still works under
    # `python -O` and reports a proper usage error listing the valid values.
    parser.add_argument(
        "--cam_method",
        type=str,
        default="eigen",
        choices=list(methods),
        help="CAM method to evaluate",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="vicregl_resnet",
        choices=list(models),
        help="Model to build and evaluate",
    )
    args = parser.parse_args()

    # NOTE(review): the device is hard-coded to the CPU. The original code
    # auto-detected CUDA and then immediately overwrote the result, so CPU
    # was always the effective device -- confirm before enabling CUDA.
    device = "cpu"

    # Build the model: each factory returns the model, the layer to explain
    # and a third module (`cls_model`), all of which are moved to `device`.
    model, target_layer, cls_model = models[args.model]()
    model = model.to(device)
    target_layer = target_layer.to(device)
    cls_model = cls_model.to(device)

    main(
        model=model,
        data_loader=get_pascal(),
        target_layer=target_layer,
        cam_method=methods[args.cam_method],
    )
