#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
from xai.problems import ImageNetValProblem
import argparse
import tifffile
from matplotlib import pyplot as plt
import pandas as pd

import ct_experiment_utils as ceu
from folder_locations import get_imagenet_val_data_path, get_experiments_path

def _rep_sort_key(path):
    """Sort key ordering heatmap files by their ``_rep_<n>`` suffix.

    Numeric suffixes sort numerically (``rep_2`` before ``rep_10``); any
    file whose suffix is not a plain number sorts after them, by file name.
    """
    rep = path.stem.rsplit("_rep_", 1)[-1]
    return (0, int(rep), path.name) if rep.isdecimal() else (1, 0, path.name)


def heatmap_paths(folder, index, method_name):
    """Return the heatmap TIFF paths for one image and method, in replicate order.

    ``Path.glob`` yields matches in arbitrary, filesystem-dependent order.
    Callers slice "the first p replicates", so the result is sorted to make
    that selection deterministic.

    Args:
        folder: Experiment folder (``pathlib.Path``) containing a ``heatmaps`` subfolder.
        index: Image index embedded in the file name.
        method_name: Method name embedded in the file name.

    Returns:
        List of paths matching ``heatmap_{index}_{method_name}_rep_*.tiff``.
    """
    pattern = f"heatmap_{index}_{method_name}_rep_*.tiff"
    return sorted((folder / "heatmaps").glob(pattern), key=_rep_sort_key)


def load_heatmaps(folders, index, method_name):
    """Read every replicate heatmap for one image and method from all folders.

    Folders are visited in the given order. Within each folder, files are
    read in ascending replicate number (see ``heatmap_paths``), so
    ``result[:p]`` is a reproducible subset.

    Args:
        folders: Iterable of ``pathlib.Path`` experiment folders.
        index: Image index embedded in the file names.
        method_name: Method name embedded in the file names.

    Returns:
        List of arrays as returned by ``tifffile.imread``. The list is empty
        when nothing matches.
    """
    return [
        tifffile.imread(str(path))
        for folder in folders
        for path in heatmap_paths(folder, index, method_name)
    ]

def load_durations(folders, method_names, divisor=10):
    """Return the mean recorded duration per method, divided by ``divisor``.

    Concatenates ``durations.csv`` from every folder in ``folders``. Each
    file is expected to have ``method`` and ``duration`` columns. For each
    name in ``method_names``, the ``duration`` values of matching rows are
    averaged and divided by ``divisor``.

    Args:
        folders: Iterable of ``pathlib.Path`` experiment folders.
        method_names: Methods to report.
        divisor: Value each mean is divided by. Defaults to 10, the constant
            previously hard-coded here. It is presumably the number of
            explanations timed per recorded duration. TODO: confirm against
            the script that writes ``durations.csv``.

    Returns:
        Dict mapping each method name to its scaled mean duration. The value
        is NaN when the method has no rows.

    Raises:
        ValueError: If ``folders`` is empty (nothing to concatenate).
    """
    # ignore_index gives the combined frame a unique index, so the boolean
    # mask below aligns row-for-row across the concatenated folders.
    df = pd.concat(
        [pd.read_csv(folder / "durations.csv") for folder in folders],
        ignore_index=True,
    )
    return {
        method_name: np.mean(df.loc[df["method"] == method_name, "duration"]) / divisor
        for method_name in method_names
    }

def main():
    """Plot averaged MBShap / MMBS heatmaps for one image and report timings.

    Reads ``durations.csv`` and the per-replicate heatmap TIFFs from every
    experiment folder given on the command line. Prints the mean durations
    and the MBShap-over-MMBS-8 speedup. Then draws a 3x6 grid:

    - column 0 shows the input image (middle row only);
    - columns 1-5 show the mean of the first 1, 2, 4, 8 and 16 replicates,
      one method per row.

    The figure is saved as ``baselines.svg`` in a freshly created experiment
    folder.
    """
    parser = argparse.ArgumentParser(description="Plot the MMBS-MBShap comparison results.")
    # All three arguments are needed downstream. Without `required`, a missing
    # one surfaces much later as an opaque crash.
    parser.add_argument("--experiment_names", required=True, help="Names of the experiment folders (comma separated).")
    parser.add_argument("--index", type=int, required=True, help="Index of the input image.")
    parser.add_argument("--network_name", required=True, help="Name of the network architecture.")
    args = parser.parse_args()

    index = args.index
    # Strip whitespace and drop empty entries, so "a, b" or a trailing comma
    # do not produce bogus folder paths.
    folders = [
        get_experiments_path() / name.strip()
        for name in args.experiment_names.split(",")
        if name.strip()
    ]

    method_names = ["mbshap", "mmbs_8", "mmbs_16"]  # one plot row per method, in this order
    num_columns = 5                                 # heatmap columns: 1, 2, 4, ..., 2**(num_columns-1) replicates
    max_reps = 2 ** (num_columns - 1)

    durations = load_durations(folders, method_names)
    print(durations)
    print(f"speedup = {durations['mbshap']/durations['mmbs_8']}")

    heatmaps_per_method = [load_heatmaps(folders, index, name) for name in method_names]
    for name, heatmaps in zip(method_names, heatmaps_per_method):
        if not heatmaps:
            raise SystemExit(f"No heatmaps found for method {name!r} and index {index}.")
        if len(heatmaps) < max_reps:
            # heatmaps[:p] silently truncates, so the later columns would
            # repeat the same subset without this notice.
            print(f"warning: only {len(heatmaps)} heatmaps for {name!r}; "
                  f"columns asking for more than that reuse all of them")

    # Stack once; reused for both the shape report and the colour limit.
    all_heatmaps = np.stack([h for heatmaps in heatmaps_per_method for h in heatmaps])
    print(all_heatmaps.shape)
    # Symmetric colour limit: 95th percentile of |channel-summed attribution|
    # over every replicate of every method (axis 3 is the per-heatmap axis 2).
    v_95 = np.percentile(np.abs(np.sum(all_heatmaps, axis=3)), 95)

    problem = ImageNetValProblem(
        data_path = get_imagenet_val_data_path(),
        network_name = args.network_name,
        num_per_class=1,
        class_step=1,
        device = "cpu")

    img, label = problem.get_sample(index)

    fig, axs = plt.subplots(nrows=len(method_names), ncols=num_columns + 1, figsize=(15, 5))

    # Column 0: input image in the middle row; the other two cells stay blank.
    axs[1, 0].imshow(problem.convert_img_for_imshow(img))
    axs[0, 0].set_axis_off()
    axs[1, 0].get_xaxis().set_visible(False)
    axs[1, 0].get_yaxis().set_visible(False)
    axs[2, 0].set_axis_off()

    img_dict = {}
    for i in range(num_columns):
        p = 2**i  # number of replicates averaged in this column
        for j, heatmaps in enumerate(heatmaps_per_method):
            heatmap = np.sum(np.mean(np.stack(heatmaps[:p]), axis=0), axis=2)
            ax = axs[j, i + 1]
            img_dict[(i, j)] = ax.imshow(heatmap, vmin=-v_95, vmax=v_95, cmap="RdBu")
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

    # One colourbar per row, attached to the last column. All images share
    # the same limits, so any image in the row would do.
    for j in range(len(method_names)):
        fig.colorbar(img_dict[(num_columns - 1, j)], ax=axs[j, num_columns], location="right")

    plt.tight_layout()

    experiment_path = ceu.make_new_experiment_folder(get_experiments_path())
    plt.savefig(experiment_path / "baselines.svg", dpi=200)
    plt.close()


if __name__ == "__main__":
    main()
