import json
import toml
import os
import numpy as np
import matplotlib.pyplot as plt
from argparse import ArgumentParser
from src.utils import mpl_setup, ADAM_HNN, ELM_RF_HGN, SWIM_RF_HGN, ELM_RF_HNN, SWIM_RF_HNN, ADAM_HGN
# mpl_setup() # sets some default params for the matplotlib figures

# Command-line interface: the experiment output directory (-d/--dir) is required;
# --all switches to the larger set of model curves selected further down.
argparser = ArgumentParser()
argparser.add_argument("-d", "--dir", type=str, help="Output directory of the experiment", required=True)
argparser.add_argument("--all", action="store_true", help="Plot every curve in the experiment", required=False, default=False)
args = vars(argparser.parse_args())
plot_all = args["all"]

print("Reading experiment results and config at", args["dir"])

# results.json holds the per-model measurement entries and the "n_nodes" x-axis used below.
# `with` guarantees the handle is closed even if parsing fails; JSON is UTF-8 by spec,
# so do not depend on the platform's locale encoding.
results_path = os.path.join(args["dir"], "results.json")
with open(results_path, encoding="utf-8") as results_file:
    results = json.load(results_file)

# The configuration is only echoed for the record; the plotting code does not use it.
config_path = os.path.join(args["dir"], "config.toml")
with open(config_path, encoding="utf-8") as config_file:
    config = toml.load(config_file)
print("Configuration is read as:")
print(config)

# x-axis of every plot: node counts, each stored in results.json as a single-element list.
n_nodes = [nx for (nx,) in results["n_nodes"]]

# Marker per model (HNN family, then HGN family).
marker_adam_hnn, marker_elm_rf_hnn, marker_swim_rf_hnn = ".", "s", "*"
marker_elm_rf_hgn, marker_swim_rf_hgn, marker_adam_hgn = "+", "x", "o"

# Sizes shared by all curves and axis labels.
linewidth, markersize, labelsize = 3, 8, 12

# Legend label, marker and color for every model key that can appear in results.json.
model_styles = {
    "adam-hnn": ("(Adam) HNN", marker_adam_hnn, ADAM_HNN),
    "elm-rf-hnn": ("(ELM) RF-HNN", marker_elm_rf_hnn, ELM_RF_HNN),
    "swim-rf-hnn": ("(SWIM) RF-HNN", marker_swim_rf_hnn, SWIM_RF_HNN),
    "adam-hgn": ("(Adam) HGN", marker_adam_hgn, ADAM_HGN),
    "elm-rf-hgn": ("(ELM) RF-HGN", marker_elm_rf_hgn, ELM_RF_HGN),
    "swim-rf-hgn": ("(SWIM) RF-HGN", marker_swim_rf_hgn, SWIM_RF_HGN),
}

# Set what to plot: the model keys (in legend order) plus per-curve line styles and legend size.
if plot_all:
    # The Adam-trained models are currently left out of the --all figure. To show all six, use
    # models = ["adam-hnn", "elm-rf-hnn", "swim-rf-hnn", "adam-hgn", "elm-rf-hgn", "swim-rf-hgn"]
    # with linestyles = ["solid", "dashed", "dashdot", "solid", "dashed", "dashdot"] and legend_fontsize = 8.
    models = ["elm-rf-hnn", "swim-rf-hnn", "elm-rf-hgn", "swim-rf-hgn"]
    linestyles = ["dotted", (0, (1, 1)), "dashdot", "dashed"]
    legend_fontsize = 10
else:
    models = ["adam-hnn", "elm-rf-hgn", "swim-rf-hgn"]
    linestyles = ["solid", "dashed", "dashdot"]
    legend_fontsize = 8

# Per-curve attributes, aligned index-for-index with `models`.
labels = [model_styles[key][0] for key in models]
markers = [model_styles[key][1] for key in models]
colors = [model_styles[key][2] for key in models]

# Plot scaling (errors and training time)

# Two-panel figure: training time (left, ax3) and relative test error (right, ax2), both log-log.
# The MSE panel is omitted: when it behaves like the relative error it is not shown in the main text.
fig, (ax3, ax2) = plt.subplots(1, 2, figsize=(8, 2.5), dpi=100)
for ax in (ax2, ax3):
    ax.set_xscale("log", base=2)
    ax.set_yscale("log", base=10)

# Three-panel variant with an MSE panel (ax1), kept for reference:
# fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15,3), dpi=100)
# and, inside the loop, plot [entry["mse_test"] for entry in entries] on ax1.

# Model keys grouped by architecture; the two groups store training times differently (see loop).
hnn_models = ("elm-rf-hnn", "swim-rf-hnn", "adam-hnn")
hgn_models = ("elm-rf-hgn", "swim-rf-hgn", "adam-hgn")

for model, label, marker, color, linestyle in zip(models, labels, markers, colors, linestyles):
    entries = results[model]
    rel2 = [entry["relative_l2_test"] for entry in entries]
    curve_style = dict(c=color, linestyle=linestyle, linewidth=linewidth, markersize=markersize)

    if model in hnn_models:
        # Every HNN entry carries a training time, so both curves share the same x-values,
        # which may cover only a prefix of n_nodes.
        train_times = [entry["train_time"] for entry in entries]
        n_trained = len(train_times)
        ax2.plot(n_nodes[:n_trained], rel2, marker, label=label, **curve_style)
        ax3.plot(n_nodes[:n_trained], train_times, marker, **curve_style)
    elif model in hgn_models:
        # The HGN is trained only up to some node count. The last trained model is then reused to
        # evaluate the test error on the larger node counts, and those entries have no "train_time".
        # Hence the error curve spans all of n_nodes, while the training-time curve stops at the
        # last trained size (the trained entries are assumed to come first).
        train_times = [entry["train_time"] for entry in entries if "train_time" in entry]
        n_trained = len(train_times)
        if len(entries) != len(n_nodes):
            raise ValueError(
                f"results[{model!r}] has {len(entries)} entries, "
                f"expected one per node count ({len(n_nodes)})"
            )
        ax2.plot(n_nodes, rel2, marker, label=label, **curve_style)
        ax3.plot(n_nodes[:n_trained], train_times, marker, **curve_style)
        # To continue the last training time as a flat line over the inference-only node counts:
        # ax3.plot(n_nodes[n_trained-1:], np.zeros((len(n_nodes) - n_trained + 1)) + train_times[-1], **curve_style)
    else:
        raise NotImplementedError(f"The plot is not implemented for the new model {model}")

# Gather the labelled curves from every axis into one figure-level legend
# (in the plotting loop only the error panel's curves carry labels).
axis_entries = [ax.get_legend_handles_labels() for ax in fig.axes]
legend_handles = [handle for handles, _ in axis_entries for handle in handles]
legend_labels = [text for _, texts in axis_entries for text in texts]

if plot_all:
    # Anchor the legend's lower-center point at (0.5, 0.90).
    legend_placement = {"loc": 'lower center', "bbox_to_anchor": (0.5, 0.90)}
else:
    legend_placement = {"loc": 'upper center'}
fig.legend(legend_handles, legend_labels, ncol=len(legend_labels), fontsize=legend_fontsize, **legend_placement)

# Axis labels and grids; every axis label uses the same font size.
# (For the three-panel MSE variant, also label ax1 with "MSE" and include it in these loops.)
for ax in (ax2, ax3):
    ax.set_xlabel("Number of nodes", fontsize=labelsize)

ax2.set_ylabel(r"Relative error", fontsize=labelsize)
ax3.set_ylabel("Training time [s]", fontsize=labelsize)

for ax in (ax2, ax3):
    ax.grid(True)

fig.tight_layout()

# Save next to the experiment results; the --all figure gets its own file name.
file_prefix = "node_scaling_all" if plot_all else "node_scaling"
plot_path = os.path.join(args["dir"], file_prefix)

for extension in (".pdf", ".png"):
    fig.savefig(plot_path + extension)
    print(f"-> Saved under {plot_path + extension}")

# `exit()` is a helper injected by the `site` module for the interactive shell and is missing
# under `python -S`; raising SystemExit directly is what sys.exit does.
raise SystemExit(0)
