#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
import numpy as np
from xai.attribution import mmbs, ig, gig, xrai, gradcam, smoothgrad
from xai.imputation import ConstantImputation
from xai.problems import ImageNetValProblem
from matplotlib import pyplot as plt
from functools import partial

import ct_experiment_utils as ceu
from folder_locations import get_imagenet_val_data_path, get_experiments_path

# ImageNet class index whose validation images are explained.
LABEL = 251
# Number of integration steps passed to IG and GIG.
IG_STEPS = 256
# Step and sample counts passed to MMBS.
MMBS_STEPS = 8
MMBS_SAMPLES = 1024
# Offsets (within the class) of the images to visualize.
INDICES = [36, 39, 5, 45]
# Images per class requested from the problem. The flat sample index below
# uses the same value as its per-class stride, so both must stay in sync.
NUM_PER_CLASS = 50
# Extent and spacing (in pixels) of the grid drawn over every panel.
# NOTE(review): assumes 224x224 network inputs -- confirm against the problem.
IMAGE_SIZE = 224
GRID_STEP = 56


def _style_axis(ax):
    """Draw a coarse grid on *ax* and hide its tick marks and tick labels."""
    ticks = np.arange(0, IMAGE_SIZE, GRID_STEP)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.tick_params(axis="both", length=0, labelbottom=False, labelleft=False)
    # Called exactly once per axis: a bare grid() toggles the grid lines on.
    ax.grid()


def _summed_map(attribution):
    """Sum *attribution* over dim 1 and return batch element 0 as a NumPy array."""
    # dim 1 is presumably the colour-channel axis -- TODO confirm.
    return torch.sum(attribution, dim=1).cpu().detach().numpy()[0, ...]


def _first_map(attribution):
    """Return element ``[0, 0]`` of *attribution* as a NumPy array."""
    return attribution.cpu().detach().numpy()[0, 0, ...]


def _compute_heatmaps(problem, img, label):
    """Compute one 2D attribution heatmap per method for *img*.

    Args:
        problem: Object providing ``model`` and ``get_gradcam_layer()``.
        img: Input tensor handed to the model and to every attribution method.
        label: Class index to explain.

    Returns:
        Dict mapping a display name to its heatmap, in plotting order.
    """
    baseline = torch.zeros_like(img)
    imputation = ConstantImputation(baseline)

    # The methods are evaluated in the same order as before so that any
    # random-number consumption inside them stays unchanged.
    heatmap_mmbs = _summed_map(
        mmbs(problem.model, img, label, imputation, MMBS_STEPS, MMBS_SAMPLES, progress_bar=True)
    )
    heatmap_ig = _summed_map(ig(problem.model, img, label, baseline, IG_STEPS))
    heatmap_xrai = _first_map(xrai(problem.model, img, label))
    heatmap_gradcam = _first_map(
        gradcam(problem.model, img, label, problem.get_gradcam_layer())
    )
    smoothed_gig = smoothgrad(
        partial(
            gig,
            model=problem.model,
            label=label,
            baseline=baseline,
            num_steps=IG_STEPS,
            fraction=0.25,
            max_dist=0.02,
        ),
        img,
        progress_bar=True,
    )
    heatmap_gig_sg = _summed_map(smoothed_gig)

    return {
        "MMBS (ours)": heatmap_mmbs,
        "GIG+SG": heatmap_gig_sg,
        "XRAI": heatmap_xrai,
        "IG": heatmap_ig,
        "GradCAM": heatmap_gradcam,
    }


def _plot_attribution_row(image, heatmaps, show_titles=False, show_colorbars=False):
    """Plot *image* followed by every heatmap in a single row of panels.

    Args:
        image: Array accepted by ``imshow`` for the left-most panel.
        heatmaps: Dict mapping a method name to a 2D heatmap array.
        show_titles: If true, put a title above every panel.
        show_colorbars: If true, attach a colorbar to every heatmap panel.

    Returns:
        The created matplotlib figure; the caller is responsible for closing it.
    """
    num_panels = len(heatmaps) + 1
    # squeeze=False keeps `axs` indexable even if only one panel is requested.
    fig, axs = plt.subplots(
        nrows=1, ncols=num_panels, figsize=(5 * num_panels, 3), squeeze=False
    )
    axs = axs[0]

    axs[0].imshow(image)
    if show_titles:
        axs[0].set_title("Input image")
    _style_axis(axs[0])

    for ax, (method_name, heatmap) in zip(axs[1:], heatmaps.items()):
        # Symmetric colour limits at the 99th percentile of |attribution|:
        # zero sits at the centre of the colormap and outliers are clipped.
        v_99 = np.percentile(np.abs(heatmap), 99)
        im = ax.imshow(heatmap, vmin=-v_99, vmax=v_99, cmap="RdBu")
        if show_titles:
            ax.set_title(method_name)
        _style_axis(ax)
        if show_colorbars:
            fig.colorbar(im, ax=ax, location="right")

    fig.tight_layout()
    return fig


def main():
    """Save one attribution-comparison figure per selected validation image."""
    problem = ImageNetValProblem(
        data_path=get_imagenet_val_data_path(),
        network_name="ResNet50_V2",
        num_per_class=NUM_PER_CLASS,
        class_step=1,
        device="cuda:0",
    )

    experiment_path = ceu.make_new_experiment_folder(get_experiments_path())
    for i in INDICES:
        # Flat index of the i-th image of class LABEL; presumably samples are
        # stored class by class with NUM_PER_CLASS entries each -- TODO confirm.
        img, _ = problem.get_sample(LABEL * NUM_PER_CLASS + i)

        print(f"Img {i} output on the correct class: {problem.model(img)[0, LABEL]}")

        heatmaps = _compute_heatmaps(problem, img, LABEL)
        fig = _plot_attribution_row(problem.convert_img_for_imshow(img), heatmaps)
        fig.savefig(experiment_path / f"attribution_maps_{i}.svg", dpi=200)
        # Close this specific figure so figures do not accumulate across images.
        plt.close(fig)


if __name__ == "__main__":
    main()