import os
import numpy as np
import matplotlib.pyplot as plt
from argparse import ArgumentParser
from src.utils import ADAM_HNN, ELM_RF_HNN, SWIM_RF_HNN, ADAM_HGN, ELM_RF_HGN, SWIM_RF_HGN, mpl_setup
# mpl_setup() # sets some default params for the matplotlib figures

# Command-line interface. All experiment artefacts (.npy trajectories and
# energies) are read from --node_scaling_dir; the remaining flags select which
# files to load and how to plot them.
argparser = ArgumentParser()
# Required: every input path is built with os.path.join(model_dir, ...), so a
# missing directory could only fail later with a TypeError on None.
argparser.add_argument("--node_scaling_dir", type=str, required=True, help="Path to the node scaling experiment directory containing trained models.")
argparser.add_argument("--Nx_gnn", type=int, required=True, help="Number of objects in the training chain system.")
argparser.add_argument("--Nx_test", type=int, required=True, help="Number of nodes in the test chain system (for the integration).")
argparser.add_argument("--obj_idx", type=int, required=True, help="Index of the object to plot")
argparser.add_argument("--len_traj", type=int, required=True, help="Trajectory length of the integration")
argparser.add_argument("--delta_t", type=float, required=True, help="Time step size")
# store_true already defaults to False and is never required.
argparser.add_argument("--all", action="store_true", help="Plot every curve in the experiment")

# Parse the command line once and expose each option as a module-level name.
args = vars(argparser.parse_args())
plot_all = args["all"]
# chain sizes: the one the model was trained on and the one integrated here
Nx_gnn, Nx_test = args["Nx_gnn"], args["Nx_test"]
obj_idx = args["obj_idx"]
model_dir = args["node_scaling_dir"]
# integration settings; both are also embedded in the .npy file names
len_traj, delta_t = args["len_traj"], args["delta_t"]
# 1-based time-step axis for trajectories of length len_traj
plotx = np.arange(1, len_traj + 1)
# font sizes shared by all figure text
legend_fontsize, label_fontsize = 12, 14

# One row per model curve: (file tag, line style, colour, legend label).
# Adam-trained variants (ADAM_HGN / ADAM_HNN) are currently not plotted.
_hgn_specs = [
    ("elm-rf-hgn", "dashdot", ELM_RF_HGN, "(ELM) RF-HGN"),
    ("swim-rf-hgn", "dashed", SWIM_RF_HGN, "(SWIM) RF-HGN"),
]
_hnn_specs = [
    ("elm-rf-hnn", "dotted", ELM_RF_HNN, "(ELM) RF-HNN"),
    ("swim-rf-hnn", "dotted", SWIM_RF_HNN, "(SWIM) RF-HNN"),
]

# RF-HNN curves are only added when the test chain is not larger than the
# training chain (presumably HNNs do not transfer to more nodes -- confirm),
# and only when --all is given; they are listed before the HGN curves.
_include_hnn = Nx_test <= Nx_gnn and plot_all
_model_specs = (_hnn_specs if _include_hnn else []) + _hgn_specs

# Transpose the spec rows into the four parallel lists used by the plot loop.
model_names, model_linestyles, model_colors, model_labels = (
    list(column) for column in zip(*_model_specs)
)

def plot_xy_traj(ax_x, ax_y, qx, qy, color, linestyle, linewidth=2.0, *, node=None):
    """
    Plot the q_1 and q_2 position of a single node over time.

    The time axis is the 1-based step index derived from each array's number
    of rows, matching ``plot_mse`` and ``plot_hamiltonian``. Trajectories
    whose length differs from ``--len_traj`` are therefore still plotted
    instead of failing on a shape mismatch.

    Args:
        ax_x: axes receiving the ``qx`` curve.
        ax_y: axes receiving the ``qy`` curve.
        qx, qy: position arrays indexed as ``[time_step, node]``.
        color, linestyle, linewidth: forwarded to ``Axes.plot``.
        node: index of the node to plot; defaults to the module-level
            ``obj_idx`` (the ``--obj_idx`` command-line option).
    """
    # The global is only looked up when no explicit node is given.
    idx = obj_idx if node is None else node
    for ax, q in ((ax_x, qx), (ax_y, qy)):
        steps = np.arange(1, len(q) + 1)
        ax.plot(steps, q[:, idx], c=color, linestyle=linestyle, linewidth=linewidth)

def plot_mse(ax_mse, qx_true, qy_true, qx_pred, qy_pred, label, color, linestyle, linewidth=2.0):
    """
    Plot, on a log scale, the per-step position error of a predicted trajectory.

    At each time step (row) the squared Euclidean position error
    ``(qx_true - qx_pred)**2 + (qy_true - qy_pred)**2`` is averaged over all
    nodes (columns). It is plotted against the 1-based step index, relative
    to the reference trajectory integrated with the true Hamiltonian.
    """
    squared_error = np.square(qx_true - qx_pred) + np.square(qy_true - qy_pred)
    error_per_step = squared_error.mean(axis=1)
    steps = np.arange(1, error_per_step.shape[0] + 1)
    ax_mse.semilogy(steps, error_per_step, c=color, label=label, linestyle=linestyle, linewidth=linewidth)

def plot_hamiltonian(ax_energy, energy, label, color, linestyle, linewidth=2.0):
    """
    Plot an energy series on a logarithmic y-axis against 1-based time steps.

    ``energy`` holds one value per step. ``label``, ``color``, ``linestyle``
    and ``linewidth`` are forwarded to ``Axes.semilogy``.
    """
    steps = np.arange(1, len(energy) + 1)
    ax_energy.semilogy(
        steps,
        energy,
        label=label,
        c=color,
        linestyle=linestyle,
        linewidth=linewidth,
    )

# One row of four panels: q_1, q_2, energy, and position MSE over time.
fig, axes = plt.subplots(1, 4, figsize=(14, 3), dpi=100)
ax_x, ax_y, ax_energy, ax_mse = axes

# Reference solution: trajectory and energy integrated with the true gradient
# of the system, loaded from precomputed .npy files.
qx_true_path = os.path.join(model_dir, f"qx_true_{Nx_test}_{len_traj}_{delta_t:.2e}.npy")
qy_true_path = os.path.join(model_dir, f"qy_true_{Nx_test}_{len_traj}_{delta_t:.2e}.npy")
hamil_true_path = os.path.join(model_dir, f"energy_true_{Nx_test}_{len_traj}_{delta_t:.2e}.npy")
qx_true, qy_true, hamil_true = (np.load(path) for path in (qx_true_path, qy_true_path, hamil_true_path))

# Reference curves: thick, solid, black.
plot_xy_traj(ax_x, ax_y, qx_true, qy_true, "k", "solid", linewidth=4)
plot_hamiltonian(ax_energy, hamil_true, r"using true $\mathcal{H}$", "k", "solid", linewidth=4)

for model_name, model_label, model_color, model_linestyle in zip(model_names, model_labels, model_colors, model_linestyles):
    # Trajectory and energy integrated with this trained model's gradient.
    qx_pred_path = os.path.join(model_dir, f"qx_{Nx_test}_{model_name}_{len_traj}_{delta_t:.2e}.npy")
    qy_pred_path = os.path.join(model_dir, f"qy_{Nx_test}_{model_name}_{len_traj}_{delta_t:.2e}.npy")
    hamil_pred_path = os.path.join(model_dir, f"energy_{Nx_test}_{model_name}_{len_traj}_{delta_t:.2e}.npy")
    qx_pred, qy_pred, hamil_pred = (np.load(path) for path in (qx_pred_path, qy_pred_path, hamil_pred_path))

    # Shared line style for every panel this model draws into.
    style = {"linestyle": model_linestyle, "linewidth": 3}
    plot_xy_traj(ax_x, ax_y, qx_pred, qy_pred, model_color, **style)
    plot_mse(ax_mse, qx_true, qy_true, qx_pred, qy_pred, model_label, model_color, **style)
    plot_hamiltonian(ax_energy, hamil_pred, model_label, model_color, **style)


# Legend placement depends on whether HNN curves can be present
# (test chain no larger than the training chain).
if Nx_test <= Nx_gnn:
    # One figure-level legend above the panels, built from the energy panel's
    # entries (which include the true-Hamiltonian reference curve).
    lines, labels = ax_energy.get_legend_handles_labels()
    # The anchor sits slightly lower when every curve is plotted (--all).
    legend_y = 0.88 if plot_all else 0.92
    fig.legend(lines, labels, loc='lower center', ncol=len(labels), fontsize=legend_fontsize, bbox_to_anchor=(0.5, legend_y))
else:
    ax_mse.legend(loc="lower right", fontsize=8)

# Axis labels: every panel shares the time-step x-axis.
for ax in (ax_x, ax_y, ax_mse, ax_energy):
    ax.set_xlabel("Time step", fontsize=label_fontsize)
ax_x.set_ylabel(fr"$q_1$ of node {obj_idx}", fontsize=label_fontsize)
ax_y.set_ylabel(fr"$q_2$ of node {obj_idx}", fontsize=label_fontsize)
ax_mse.set_ylabel("MSE", fontsize=label_fontsize)
ax_mse.grid()
ax_energy.set_ylabel(r"True $\mathcal{H}$", fontsize=label_fontsize)
fig.tight_layout()

# Output goes to the current working directory; "_all" marks the --all variant.
figure_suffix = "_all" if plot_all else ""
figure_path = f"traj_chain_train{Nx_gnn}_test{Nx_test}{figure_suffix}.pdf"
fig.savefig(figure_path)
print(f"-> figure saved under {figure_path}")
# No explicit exit: the script ends here with status 0. The site-provided
# exit() is meant for the interactive shell, not for programs.
