from glob import glob
from os.path import exists

import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

# Global plot theme: seaborn "paper" context on a dark grid, with all text scaled up 5x.
sns.set(style="darkgrid", context="paper", font_scale=5)
# Disable matplotlib's automatic layout; the figures below call tight_layout() themselves.
plt.rcParams.update({"figure.autolayout": False})

if __name__ == "__main__":
    datasets_names = {"mnist": "MNIST", "fashion_mnist": "FashionMNIST", "cifar10": "CIFAR10"}
    models_names = {"simple": "SimpleNN", "lenet": "LeNet", "alexnet": "AlexNet", "vgg": "VGG", "resnet": "ResNet"}
    optimizers_names = {"sgd": "SGD", "siopt": "SIOPT", "sgd_adv_iter": "SGD Adv. Iter."}
    metrics_names = {"loss": "Loss", "train_acc": "TrainAcc", "test_acc": "TestAcc", "test_rob": "TestRob"}

    single_stage_logs = [pd.read_csv(f"{run_folder}/losses_single_stage.txt", sep="\t", encoding="utf-8",
                                     header=0, index_col=None)
                         for dataset in datasets_names for run_folder in glob(f"results/{dataset}/runs/*")
                         if exists(f"{run_folder}/losses_single_stage.txt")]
    single_stage_logs = pd.concat(single_stage_logs, ignore_index=True)

    multi_stage_logs = [pd.read_csv(f"{run_folder}/results_multi_stages.txt", sep="\t", encoding="utf-8",
                                    header=0, index_col=None)
                        for dataset in datasets_names for run_folder in glob(f"results/{dataset}/runs/*")
                        if exists(f"{run_folder}/results_multi_stages.txt")]
    multi_stage_logs = pd.concat(multi_stage_logs, ignore_index=True)

    # Single stage loss
    plot_data_single = single_stage_logs.set_index(["Dataset", "Model", "Optimizer", "Iteration"])
    plot_data_single = plot_data_single.reindex(list(datasets_names.keys()), level=0)
    plot_data_single = plot_data_single.reindex(list(models_names.keys()), level=1)
    plot_data_single = plot_data_single.reindex(list(optimizers_names.keys()), level=2)
    plot_data_single = plot_data_single.reset_index()
    plot_data_single.Dataset = plot_data_single.Dataset.map(datasets_names)
    plot_data_single.Model = plot_data_single.Model.map(models_names)
    plot_data_single.Optimizer = plot_data_single.Optimizer.map(optimizers_names)

    plt.figure(figsize=(len(datasets_names) * 11.7, len(models_names) * 8.3))
    g = sns.relplot(x="Iteration", y="Loss", hue="Optimizer", row="Model", col="Dataset", data=plot_data_single,
                    linewidth=5, palette="Set1", kind="line",
                    height=8.3, aspect=11.7 / 8.3, facet_kws={"sharey": False, "legend_out": True})
    for (model, dataset), ax in g.axes_dict.items():
        ax.set_title(f"{model} | {dataset}")
    for lh in g.legend.get_lines():
        lh.set_linewidth(5)
    g.tight_layout()
    plt.savefig("results/single_stage_loss.pdf", format="pdf")

    plt.figure(figsize=(len(datasets_names) * 11.7, 8.3))
    g = sns.relplot(x="Iteration", y="Loss", hue="Optimizer", col="Dataset", data=plot_data_single,
                    linewidth=5, palette="Set1", kind="line", ci="sd",
                    height=8.3, aspect=11.7 / 8.3, facet_kws={"sharey": False, "legend_out": True})
    for dataset, ax in g.axes_dict.items():
        ax.set_title(dataset)
    for lh in g.legend.get_lines():
        lh.set_linewidth(5)
    g.tight_layout()
    plt.savefig("results/single_stage_loss_agg.pdf", format="pdf")

    # Continual learning
    plot_data_continual = multi_stage_logs[multi_stage_logs.Continual].set_index(
        ["Dataset", "Model", "Optimizer", "Stage"]
    )
    plot_data_continual = plot_data_continual.reindex(list(datasets_names.keys()), level=0)
    plot_data_continual = plot_data_continual.reindex(list(models_names.keys()), level=1)
    plot_data_continual = plot_data_continual.reindex(list(optimizers_names.keys()), level=2)
    plot_data_continual = plot_data_continual.reset_index()
    plot_data_continual.Dataset = plot_data_continual.Dataset.map(datasets_names)
    plot_data_continual.Model = plot_data_continual.Model.map(models_names)
    plot_data_continual.Optimizer = plot_data_continual.Optimizer.map(optimizers_names)

    for metric_key, metric in metrics_names.items():
        plt.figure(figsize=(len(datasets_names) * 11.7, len(models_names) * 8.3))
        g = sns.relplot(x="Stage", y=metric, hue="Optimizer", row="Model", col="Dataset", data=plot_data_continual,
                        linewidth=5, palette="Set1", kind="line",
                        height=8.3, aspect=11.7 / 8.3, facet_kws={"sharey": False, "legend_out": True})
        for (model, dataset), ax in g.axes_dict.items():
            ax.set_title(f"{model} | {dataset}")
        for lh in g.legend.get_lines():
            lh.set_linewidth(5)
        g.tight_layout()
        plt.savefig(f"results/continual_{metric_key}.pdf", format="pdf")

    # Adversarial learning
    plot_data_adversarial = multi_stage_logs[~multi_stage_logs.Continual].set_index(
        ["Dataset", "Model", "Optimizer", "Adversarial", "Stage"]
    )
    plot_data_adversarial = plot_data_adversarial.reindex(list(datasets_names.keys()), level=0)
    plot_data_adversarial = plot_data_adversarial.reindex(list(models_names.keys()), level=1)
    plot_data_adversarial = plot_data_adversarial.reindex(list(optimizers_names.keys()), level=2)
    plot_data_adversarial = plot_data_adversarial.reset_index()
    plot_data_adversarial.Dataset = plot_data_adversarial.Dataset.map(datasets_names)
    plot_data_adversarial.Model = plot_data_adversarial.Model.map(models_names)
    plot_data_adversarial.Optimizer = plot_data_adversarial.Optimizer.map(optimizers_names)

    for metric_key, metric in metrics_names.items():
        plt.figure(figsize=(len(datasets_names) * 11.7, len(models_names) * 8.3))
        g = sns.relplot(x="Stage", y=metric, hue="Optimizer", style="Adversarial",
                        row="Model", col="Dataset", data=plot_data_adversarial,
                        linewidth=5, palette="Set1", kind="line",
                        height=8.3, aspect=11.7 / 8.3, facet_kws={"sharey": False, "legend_out": True})
        for (model, dataset), ax in g.axes_dict.items():
            ax.set_title(f"{model} | {dataset}")
        for lh in g.legend.get_lines():
            lh.set_linewidth(5)
        g.tight_layout()
        plt.savefig(f"results/adversarial_{metric_key}.pdf", format="pdf")
