#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
from matplotlib import pyplot as plt
import tifffile
import torch as torch
import argparse

import ct_experiment_utils as ceu
from folder_locations import get_experiments_path, get_fashion_mnist_data_path

# Path counts swept: num_paths = 4**j for j in range(_NUM_PATH_EXPONENTS).
# This is also the number of figure columns.
_NUM_PATH_EXPONENTS = 7
# MMBS step counts in plotting order (largest first). Each one gets its own
# figure row below the MBShap row.
_MMBS_STEPS = (32, 16, 8, 4, 2, 1)


def _load_heatmaps(heatmaps_path):
    """Read every sweep heatmap TIFF from ``heatmaps_path`` into a dict.

    Keys are ``"ig"``, ``("mbshap", j)`` and ``("mmbs", j, k)``. Here ``j``
    indexes the path count ``4**j`` and ``k`` indexes ``_MMBS_STEPS``.
    Arrays whose first dimension is 1 are reduced with ``[0, 0, :, :]``.
    """
    heatmaps = {"ig": tifffile.imread(str(heatmaps_path / "heatmap_ig.tiff"))}
    for j in range(_NUM_PATH_EXPONENTS):
        num_paths = 4**j
        heatmaps[("mbshap", j)] = tifffile.imread(
            str(heatmaps_path / f"heatmap_mbshap_paths_{num_paths}.tiff"))
        for k, steps in enumerate(_MMBS_STEPS):
            heatmaps[("mmbs", j, k)] = tifffile.imread(
                str(heatmaps_path / f"heatmap_mmbs_paths_{num_paths}_steps_{steps}.tiff"))

    # NOTE(review): assumes heatmaps with a leading singleton dimension are
    # 4-D, e.g. (1, C, H, W) -- confirm against the writer of these TIFFs.
    # Build a new dict rather than mutating the one being iterated.
    return {
        key: heatmap[0, 0, :, :] if heatmap.shape[0] == 1 else heatmap
        for key, heatmap in heatmaps.items()
    }


def _plot_sweep(image, heatmaps):
    """Lay out the input image and all heatmaps on one grid and return the figure.

    Layout (panels are intentionally untitled):
      * row 0: input image, IG heatmap, then blank panels;
      * row 1: MBShap with path count 4**j in column j;
      * row 2 + k: MMBS with step count _MMBS_STEPS[k] and path count 4**j
        in column j.
    All heatmaps share one symmetric "RdBu" colour scale. Its limits are
    +/- the 95th percentile of |value| over every heatmap.
    """
    ncols = _NUM_PATH_EXPONENTS
    fig, axs = plt.subplots(nrows=2 + len(_MMBS_STEPS), ncols=ncols, figsize=(16, 16))
    # np.stack requires all heatmaps to share one shape after _load_heatmaps.
    v_95 = np.percentile(np.abs(np.stack(list(heatmaps.values()))), 95)
    heatmap_style = dict(vmin=-v_95, vmax=v_95, cmap="RdBu")

    axs[0, 0].imshow(image)
    axs[0, 1].imshow(heatmaps["ig"], **heatmap_style)
    # Fill the unused top-row panels with a blank (white) image.
    for j in range(2, ncols):
        axs[0, j].imshow(np.zeros((2, 2)), vmin=0, vmax=1, cmap="binary")

    for j in range(ncols):
        axs[1, j].imshow(heatmaps[("mbshap", j)], **heatmap_style)
        for k in range(len(_MMBS_STEPS)):
            axs[k + 2, j].imshow(heatmaps[("mmbs", j, k)], **heatmap_style)

    # Every panel holds an image; hide all axes.
    for ax in axs.flat:
        ax.set_axis_off()
    fig.tight_layout()
    return fig


def main():
    """Plot the Fashion MNIST sweep heatmaps of one experiment to an SVG.

    Reads the heatmaps from ``<experiments>/<--experiment>/heatmaps``. Writes
    ``fashion_mnist_sweep.svg`` into a newly created experiment folder.
    """
    parser = argparse.ArgumentParser(description="Plot the Fashion MNIST sweep results.")
    # Required: without it the path join below would fail on None.
    parser.add_argument("--experiment", required=True, help="Name of the experiment folder")
    args = parser.parse_args()

    heatmaps_path = get_experiments_path() / args.experiment / "heatmaps"
    # Load all inputs before creating the output folder, so a missing file
    # does not leave an empty experiment folder behind.
    heatmaps = _load_heatmaps(heatmaps_path)
    test_data = torch.load(get_fashion_mnist_data_path() / 'test_data.pt')

    experiment_path = ceu.make_new_experiment_folder(get_experiments_path())

    # NOTE(review): presumably test_data is (N, C, H, W); [0, 0] picks the
    # first sample's first channel -- confirm.
    fig = _plot_sweep(test_data[0, 0], heatmaps)
    fig.savefig(experiment_path / "fashion_mnist_sweep.svg")
    plt.close(fig)


if __name__ == "__main__":
    main()
