import torch

import numpy as np
from cam_methods.stats import Stats
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import argparse
import os
from utils.util_models import (
    moco_v2_model_target_layer,
    swav_model_target_layer,
    vicreg_resnet_model_target_layer,
    vicreg_convnext_model_target_layer,
    barlowtwins_model_target_layer,
)
from utils.util_datasets import get_pascal


def main(model, model_name, data_loader, target_layer, tensor):
    """Collect singular values over a dataset and save them as a ``.npy`` file.

    A ``Stats`` object is built around ``model`` and ``target_layer`` and is
    called once per batch on the normalized images. The arrays it accumulates
    in ``cam_model.singular_values`` are then concatenated along axis 0,
    transposed, and written to
    ``arrays/distribution_singular_values_<model_name>_<tensor|eigen>.npy``
    (relative to the current working directory).

    Args:
        model: Network handed to ``Stats``; inputs are placed on the device of
            its first parameter (CPU if it has no parameters).
        model_name: Name embedded in the output file name.
        data_loader: Iterable yielding ``(images, masks)`` pairs. Only the
            images are used.
        target_layer: Layer passed to ``Stats`` as its single target layer.
        tensor: Forwarded to ``Stats``; also selects the ``tensor`` / ``eigen``
            suffix of the output file name.
    """
    cam_model = Stats(model=model, target_layers=[target_layer], tensor=tensor)

    # Keep the inputs on the same device as the model instead of hard-coding
    # one, so the batches can never end up on a different device than the
    # weights they are fed to.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cpu")

    # Per-channel normalization constants (the usual ImageNet mean/std
    # values), shaped (1, 3, 1, 1) so they broadcast over a batch.
    # NOTE(review): assumes images are (batch, 3, H, W) - confirm against
    # the data loader.
    means = torch.tensor(
        [0.485, 0.456, 0.406], dtype=torch.float32, device=device
    ).view(1, 3, 1, 1)
    stds = torch.tensor(
        [0.229, 0.224, 0.225], dtype=torch.float32, device=device
    ).view(1, 3, 1, 1)

    for batch in data_loader:
        images, _ = batch  # the masks are not needed here
        images = images.to(device)

        # Every image in the batch gets the same fixed target index 0.
        labels = [ClassifierOutputTarget(0) for _ in range(len(images))]

        sample_images = ((images - means) / stds).type_as(images)
        cam_model(input_tensor=sample_images, targets=labels)

    singular_values = np.concatenate(cam_model.singular_values, axis=0)
    singular_values = np.transpose(singular_values)
    print(singular_values.shape)

    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs("arrays", exist_ok=True)

    # Built on a single logical line: a replacement field broken across
    # lines inside a single-quoted f-string only parses on Python >= 3.12.
    variant = "tensor" if tensor else "eigen"
    file_name = (
        f"arrays/distribution_singular_values_{model_name}_{variant}.npy"
    )
    print(file_name)

    with open(file_name, "wb") as f:
        np.save(f, singular_values)


if __name__ == "__main__":
    # Maps each accepted --model value to the factory that builds it. Each
    # factory is called with no arguments and returns a 3-tuple
    # (model, target_layer, cls_model).
    models = {
        "vicregl_resnet": vicreg_resnet_model_target_layer,
        "vicregl_convnext": vicreg_convnext_model_target_layer,
        "barlowtwins": barlowtwins_model_target_layer,
        "moco": moco_v2_model_target_layer,
        "swav": swav_model_target_layer,
    }

    parser = argparse.ArgumentParser(description="CAM evaluation.")
    # `choices` lets argparse reject unknown names with a usage message;
    # an `assert` would be stripped when Python runs with -O.
    parser.add_argument(
        "--model",
        type=str,
        default="vicregl_resnet",
        choices=list(models),
        help="model to test",
    )
    parser.add_argument(
        "--eigen", default=False, action="store_true", help="EigenCAM ?"
    )
    args = parser.parse_args()

    # Computation is pinned to the CPU.
    device = "cpu"

    # The third element returned by the factory is not used by this script.
    model, target_layer, _cls_model = models[args.model]()
    model = model.to(device)
    target_layer = target_layer.to(device)

    main(
        model=model,
        model_name=args.model,
        data_loader=get_pascal(),
        target_layer=target_layer,
        # --eigen switches the Stats object out of its "tensor" mode.
        tensor=not args.eigen,
    )
