#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
import numpy as np
from xai.attribution import mmbs, ig
from xai.imputation import ConstantImputation
from xai.problems import ImageNetValProblem
from matplotlib import pyplot as plt

from torchvision.transforms.functional import gaussian_blur

import ct_experiment_utils as ceu
from folder_locations import get_imagenet_val_data_path, get_experiments_path

if __name__ == "__main__":
    # Qualitative comparison of how the baseline choice changes IG and MMBS
    # attributions for one ImageNet validation sample. The result is a 3-row
    # figure: baselines (top), IG heatmaps (middle), MMBS heatmaps (bottom),
    # with the input image shown in the first column of the middle row.
    index = 96

    problem = ImageNetValProblem(
        data_path=get_imagenet_val_data_path(),
        network_name="ResNet50_V2",
        num_per_class=1,
        class_step=1,
        device="cuda:0",
    )

    img, label = problem.get_sample(index)

    print(f"Model output on the correct class: {problem.model(img)[0,label]}")

    # One baseline per figure column. The third entry is deliberately NOT
    # passed through normalize_intensity, unlike the first two.
    # NOTE(review): presumably this is "zero in network input space" versus
    # "black/white in pixel space" -- confirm against ImageNetValProblem.
    baselines = [
        problem.normalize_intensity(torch.zeros_like(img)),
        problem.normalize_intensity(torch.ones_like(img)),
        torch.zeros_like(img),
        problem.normalize_intensity(torch.rand_like(img)),
        gaussian_blur(img, 101, 25),
    ]
    n_baselines = len(baselines)

    # Each attribution is moved to the CPU, its first entry along the leading
    # axis is taken, and the remaining leading axis is summed away to obtain
    # a 2-D map (assumes a (1, C, H, W) layout -- TODO confirm).
    # NOTE(review): the trailing numeric arguments (256 for ig, 8 and 1024 for
    # mmbs) are passed through positionally; their meaning is defined in
    # xai.attribution.
    ig_heatmaps = [
        np.sum(ig(problem.model, img, label, bl, 256).cpu().numpy()[0, ...], axis=0)
        for bl in baselines
    ]

    mmbs_heatmaps = [
        np.sum(
            mmbs(problem.model, img, label, ConstantImputation(bl), 8, 1024)
            .cpu()
            .numpy()[0, ...],
            axis=0,
        )
        for bl in baselines
    ]

    # A single symmetric colour range per method (95th percentile of the
    # absolute values over all baselines) keeps the columns comparable and
    # centres the diverging colormap on zero.
    v_95_ig = np.percentile(np.abs(np.stack(ig_heatmaps)), 95)
    v_95_mmbs = np.percentile(np.abs(np.stack(mmbs_heatmaps)), 95)

    # Column 0 holds the input image; one further column per baseline. The
    # width is derived from the list so that editing `baselines` cannot
    # desynchronise the grid from the plotting loop.
    fig, axs = plt.subplots(nrows=3, ncols=n_baselines + 1, figsize=(15, 5))

    # First column: only the middle cell is used.
    axs[0, 0].set_axis_off()
    axs[2, 0].set_axis_off()
    axs[1, 0].imshow(problem.convert_img_for_imshow(img))

    columns = zip(baselines, ig_heatmaps, mmbs_heatmaps)
    for col, (bl, ig_map, mmbs_map) in enumerate(columns, start=1):
        axs[0, col].imshow(problem.convert_img_for_imshow(bl))
        im_ig = axs[1, col].imshow(
            ig_map, vmin=-v_95_ig, vmax=v_95_ig, cmap="RdBu"
        )
        im_mmbs = axs[2, col].imshow(
            mmbs_map, vmin=-v_95_mmbs, vmax=v_95_mmbs, cmap="RdBu"
        )

    # Hide ticks on every cell that shows an image but keep the frame
    # (set_axis_off would remove the frame as well).
    for ax in (axs[1, 0], *axs[:, 1:].ravel()):
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    # All images in a row share vmin/vmax, so the last image of the row is a
    # valid mappable for that row's colorbar; attach it to the last column.
    fig.colorbar(im_ig, ax=axs[1, -1], location="right")
    fig.colorbar(im_mmbs, ax=axs[2, -1], location="right")

    fig.tight_layout()

    experiment_path = ceu.make_new_experiment_folder(get_experiments_path())
    fig.savefig(experiment_path / "baselines.svg")
    plt.close(fig)
