#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns

import ct_experiment_utils as ceu
from folder_locations import get_experiments_path

if __name__ == "__main__":
    # Plot the average network output vs. fraction of input removed for every
    # attribution method, one panel per (experiment folder, panel title) pair,
    # and save the figure as plot.svg in a fresh experiment folder.
    combined_path = get_experiments_path() / "calc_method_comparison_experiments" / "combined"
    in_experiment_tuples = [
        (combined_path / "Fashion", "Fashion MNIST"),
        (combined_path / "ViT", "ImageNet ViT"),
        (combined_path / "ResNet", "ImageNet ResNet"),
    ]

    experiment_path = ceu.make_new_experiment_folder(get_experiments_path())

    # Legend labels and curve-file suffixes; the two lists are index-aligned
    # (latex_names[k] is the label for files matching "*_{col_names[k]}.csv").
    latex_names = ["MMBS (ours)", "MMBS + SG", "IG", "IG + SG", "GIG (paper)", "GIG (paper) + SG", "GIG (Saliency)", "GIG (Saliency) + SG", "XRAI (B + W)", "XRAI (zero)", "GradCAM", "Random"]
    col_names = ["mmbs", "mmbs_sg", "ig", "ig_sg", "gig_paper", "gig_paper_sg", "gig_saliency", "gig_saliency_sg", "xrai", "xrai_bl", "gradcam", "random"]

    # One distinct colour per method. The default seaborn palette has only 10
    # colours for these 12 methods: seaborn < 0.12 raises a ValueError for a
    # too-short palette list, and newer versions cycle it, so two pairs of
    # methods would end up sharing a colour.
    palette = sns.color_palette("husl", n_colors=len(col_names))

    n_panels = len(in_experiment_tuples)
    plt.figure(figsize=(11, 2.8))
    for i, (in_experiment_path, network_name) in enumerate(in_experiment_tuples):
        # Panels are laid out in a single row: a 1 x n_panels grid.
        plt.subplot(1, n_panels, i + 1)
        curves_dir = in_experiment_path / "curves"

        # The x-axis is read from one curve file of the first method; the code
        # below assumes every curve file in this folder uses the same x values.
        first_curve_path = next(curves_dir.glob(f"*_{col_names[0]}.csv"), None)
        if first_curve_path is None:
            raise FileNotFoundError(f"No '*_{col_names[0]}.csv' curve files found in {curves_dir}")
        fracs_removed = np.loadtxt(first_curve_path, delimiter=",", skiprows=1, usecols=[0,])

        df_list = []
        for method_name, legend_name in zip(col_names, latex_names):
            # Keep only curves whose output at the first step exceeds the output
            # at the last step. NOTE(review): other curves are silently dropped
            # from the average — confirm this filter is intended.
            outs_list = []
            for curve_path in curves_dir.glob(f"*_{method_name}.csv"):
                outs = np.loadtxt(curve_path, delimiter=",", skiprows=1, usecols=[1,])
                if outs[0] > outs[-1]:
                    outs_list.append(outs)

            if not outs_list:
                # No usable curves: add a single placeholder point at x=-1, which
                # lies outside the x-limits set below, so the method still
                # appears as a hue/style level (and in the legend) for this panel.
                df_list.append(pd.DataFrame(
                    {"Fraction removed" : [-1, ],
                     "Method" : [legend_name, ],
                     "Average network output" : 0}))
                continue

            # Element-wise mean over all kept curves of this method.
            outs_mean = np.mean(np.stack(outs_list, axis=0), axis=0)
            df_list.append(pd.DataFrame(
                {"Fraction removed" : fracs_removed,
                 "Method" : [legend_name, ]*len(fracs_removed),
                 "Average network output" : outs_mean}))

        df = pd.concat(df_list)

        # Only the last panel draws the legend, placed outside the axes to the
        # right. Explicit hue/style orders keep colours and dash patterns
        # identical across panels.
        show_legend = i == n_panels - 1
        sns.lineplot(data=df, x="Fraction removed", y="Average network output",
                     hue="Method", style="Method",
                     hue_order=latex_names, style_order=latex_names,
                     palette=palette, legend=show_legend)
        if show_legend:
            plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
        plt.xlim(-0.04, 1.04)
        plt.title(network_name)
    plt.tight_layout()
    plt.savefig(experiment_path / "plot.svg")
