import numpy as np
import torchvision
import torch
from pytorch_grad_cam import (EigenCAM, GradCAM, GradCAMPlusPlus,
                              HiResCAM, XGradCAM)
from cam_methods.ablation_cam import AblationCAM
from cam_methods.score_cam import ScoreCAM
from utils.util_models import *
from utils.util_datasets import get_imagenet

from cam_methods.eigen_cam import EigenCamMultiVec
from cam_methods.tensor_cam import TSM, MTSM
from cam_methods.opticam import OptiCAM
import matplotlib.pyplot as plt
import argparse
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import os

def _save_figure(img, out_file, cam=None):
    """Draw ``img`` (optionally with ``cam`` overlaid) on a borderless axes and save it."""
    fig = plt.figure(frameon=False)
    try:
        # An axes spanning the whole figure, with ticks/frame hidden, so the
        # saved file contains only the image content.
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        fig.add_axes(ax)
        # permute(1, 2, 0) moves the leading axis last before handing the
        # array to imshow (assumes a channel-first image tensor — TODO confirm).
        ax.imshow(img.permute(1, 2, 0).numpy(), aspect='auto')
        if cam is not None:
            ax.imshow(cam, aspect='auto', alpha=.6, cmap='jet')
        fig.savefig(out_file)
    finally:
        # pyplot keeps every figure it creates alive until it is explicitly
        # closed; without this, repeated calls accumulate open figures.
        plt.close(fig)


def save_image(imgs, cams, model_name, method_name):
    """Save the input images and their CAM overlays under ``figures/<model_name>``.

    The plain images are written as ``img_<i>.png`` only when ``img_0.png`` is
    not already present in the output directory; the overlays are always
    written as ``img_<i>_<method_name>.png``.

    Args:
        imgs: Iterable of image tensors; each one is permuted with
            ``(1, 2, 0)`` and converted with ``.numpy()`` before plotting.
        cams: Indexable collection where ``cams[i]`` is the map drawn on top
            of ``imgs[i]``.
        model_name: Name of the sub-directory of ``figures`` to write into.
        method_name: Suffix used in the overlay file names.
    """
    path = f'figures/{model_name}'
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(path, exist_ok=True)

    if not os.path.exists(f'{path}/img_0.png'):
        print('Saving original images first.')
        for i, img in enumerate(imgs):
            _save_figure(img, f'{path}/img_{i}.png')

    for i, img in enumerate(imgs):
        _save_figure(img, f'{path}/img_{i}_{method_name}.png', cam=cams[i])
    print('Saved')

def plot(dataset, model, target_layer, cam_method, model_name):
    """Compute CAMs for the first five images of the first batch and save them.

    Args:
        dataset: Iterable whose first item unpacks into ``(images, labels)``
            (presumably a DataLoader over ImageNet — verify against caller).
        model: Model handed to the CAM constructor.
        target_layer: Layer passed as the single entry of ``target_layers``.
        cam_method: CAM class; instantiated with ``model``, ``target_layers``
            and ``reshape_transform=None``. Its ``__name__`` is used in the
            output file names.
        model_name: Name forwarded to ``save_image`` for the output directory.
    """
    num_images = 5

    torch.manual_seed(1)
    method_name = cam_method.__name__
    print(f"Saving CAMs for : {method_name}")

    # Per-channel normalisation statistics shaped (1, 3, 1, 1) so they
    # broadcast over a batch (presumably the ImageNet mean/std — TODO confirm).
    means = torch.from_numpy(np.array([0.485, 0.456, 0.406])).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float()
    stds = torch.from_numpy(np.array([0.229, 0.224, 0.225])).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float()
    cam_model = cam_method(model=model, target_layers=[target_layer], reshape_transform=None)

    sample_images, labels = next(iter(dataset))
    # Keep images and targets aligned: one target per retained image. The
    # previous code used labels[5], which selected a single target (index 5)
    # instead of the first five.
    sample_images = sample_images[:num_images]
    targets = [ClassifierOutputTarget(label) for label in labels[:num_images]]

    normed_images = (sample_images - means) / stds
    cams = cam_model(input_tensor=normed_images, targets=targets)

    # The un-normalised images are saved so the figures show the original pixels.
    save_image(sample_images, cams, model_name, method_name)



if __name__ == '__main__':
    # Command-line key -> CAM class.
    methods = {
        'ablation': AblationCAM,
        'eigen': EigenCAM,
        'grad': GradCAM,
        'grad++': GradCAMPlusPlus,
        'highres': HiResCAM,
        'score': ScoreCAM,
        'xgrad': XGradCAM,
        'tsm': TSM,
        'mtsm': MTSM,
        'opti': OptiCAM,
        'eigenMultivec': EigenCamMultiVec,
    }
    # Command-line key -> zero-argument factory returning (model, target_layer).
    models = {
        'resnet50': resnet50_model_target_layer,
        'vgg16': vgg16_model_target_layer,
        'convnext_b': convnext_base_model_target_layer,
        'vicregl_resnet': vicreg_resnet_model_target_layer,
        'vicregl_convnext': vicreg_convnext_model_target_layer,
        'barlowtwins': barlowtwins_model_target_layer,
        'moco': moco_v2_model_target_layer,
        'swav': swav_model_target_layer,
    }

    # `choices` makes argparse reject unknown keys with a usage message;
    # the earlier `assert` checks were silently skipped under `python -O`.
    parser = argparse.ArgumentParser(description='CAM evaluation.')
    parser.add_argument('--cam_method', type=str, default='grad',
                        choices=list(methods), help='CAM method to visualise')
    parser.add_argument('--model', type=str, default='resnet50',
                        choices=list(models), help='Model to explain')
    parser.add_argument('--batch_size', type=int, default=20, help='Batch size')
    parser.add_argument('--image_size', type=int, default=224, help='Image size')
    args = parser.parse_args()

    model, target_layer = models[args.model]()
    # NOTE(review): not read anywhere in the code visible here; kept in case
    # later code depends on it.
    supervised = args.model in ('resnet50', 'vgg16', 'convnext_b')

    # The whole namespace is handed to get_imagenet (presumably it reads
    # batch_size / image_size — verify against utils.util_datasets).
    plot(
        dataset=get_imagenet(args),
        model=model,
        target_layer=target_layer,
        cam_method=methods[args.cam_method],
        model_name=args.model,
    )