import os
import numpy as np
import matplotlib.pyplot as plt
from argparse import ArgumentParser
from src.data.molecular_dynamics import animate_2D
from src.utils import l2_err, mse, SWIM_RF_HGN, ELM_RF_HGN, ADAM_HGN
from tabulate import tabulate

# Command-line interface. Every option is mandatory; the three .npz inputs are
# the result files of the SWIM, ELM and Adam runs respectively.
argparser = ArgumentParser()
_cli_options = (
    ("-swim", "--swimnpz", str,
     "Path to .npz data file produced by main_lennard_jones.py for SWIM"),
    ("-elm", "--elmnpz", str,
     "Path to .npz data file produced by main_lennard_jones.py for ELM"),
    ("-adam", "--adamnpz", str,
     "Path to .npz data file produced by main_lennard_jones.py for Adam"),
    ("-o", "--outdir", str,
     "Output directory path"),
    ("-fl", "--framinglen", int,
     "Framing length for matplotlib animation"),
)
# Register the options in table order so the generated --help text keeps its order.
for _short_flag, _long_flag, _value_type, _help_text in _cli_options:
    argparser.add_argument(_short_flag, _long_flag, required=True, type=_value_type,
                           help=_help_text)

args = argparser.parse_args()

"""Load the data"""

def load_data(npz_path):
    """Load reference and predicted trajectories from a results ``.npz`` file.

    The archive must contain the keys ``q_traj``, ``q_traj_pred``, ``ke_traj``,
    ``ke_traj_pred``, ``pe_traj``, ``pe_traj_pred``, ``h_traj`` and
    ``h_traj_pred``.

    Args:
        npz_path: Path to the ``.npz`` archive.

    Returns:
        A 9-tuple ``(q_traj, q_traj_pred, ke_traj, ke_traj_pred, pe_traj,
        pe_traj_pred, h_traj, h_traj_pred, h_traj_error)``. Despite its name,
        ``h_traj_error`` is not a difference: it is the element-wise sum
        ``ke_traj_pred + pe_traj_pred``.
        NOTE(review): presumably this is the true energy evaluated along the
        predicted trajectory -- confirm against the producer script.

    Raises:
        KeyError: If one of the expected arrays is missing from the archive.
    """
    # np.load returns an NpzFile that keeps the archive open; the context
    # manager closes it once the arrays have been read into memory.
    with np.load(npz_path) as data:
        q_traj, q_traj_pred = data["q_traj"], data["q_traj_pred"]
        ke_traj, ke_traj_pred = data["ke_traj"], data["ke_traj_pred"]
        pe_traj, pe_traj_pred = data["pe_traj"], data["pe_traj_pred"]
        h_traj, h_traj_pred = data["h_traj"], data["h_traj_pred"]
    h_traj_error = ke_traj_pred + pe_traj_pred
    return q_traj, q_traj_pred, ke_traj, ke_traj_pred, pe_traj, pe_traj_pred, h_traj, h_traj_pred, h_traj_error

swim_filepath = args.swimnpz
elm_filepath = args.elmnpz
adam_filepath = args.adamnpz

# The reference arrays (q/ke/pe/h without a model prefix) are taken from the
# SWIM file only; the matching entries of the ELM and Adam files are discarded.
(q_traj, swim_q_traj,
 ke_traj, swim_ke_traj,
 pe_traj, swim_pe_traj,
 h_traj, swim_h_traj,
 swim_h_traj_error) = load_data(swim_filepath)
(_, elm_q_traj,
 _, elm_ke_traj,
 _, elm_pe_traj,
 _, elm_h_traj,
 elm_h_traj_error) = load_data(elm_filepath)
(_, adam_q_traj,
 _, adam_ke_traj,
 _, adam_pe_traj,
 _, adam_h_traj,
 adam_h_traj_error) = load_data(adam_filepath)

"""Plot trajectories and energies over time with MSE"""

n_steps, n_obj, dof = q_traj.shape
# Select an object to compare trajectories of different approximations
obj_idx = n_obj // 2
# Time axis shared by all line plots, counted from 1.
plotx = range(1, n_steps + 1)
legend_fontsize = 12
label_fontsize = 14

def plot_xy_traj(ax_x, ax_y, qx, qy, color, linestyle, linewidth=2.0):
    """
    Draws the x and y coordinate of object `obj_idx` over the time steps `plotx`.

    `qx` and `qy` are indexed as [step, object]; column `obj_idx` of `qx` goes
    to `ax_x` and column `obj_idx` of `qy` goes to `ax_y`. Both lines share the
    given color, linestyle and linewidth.
    """
    line_style = {"c": color, "linestyle": linestyle, "linewidth": linewidth}
    for axis, coords in ((ax_x, qx), (ax_y, qy)):
        axis.plot(plotx, coords[:, obj_idx], **line_style)

def plot_mse(ax_mse, qx_true, qy_true, qx_pred, qy_pred, label, color, linestyle, linewidth=2.0):
    """
    Draws the per-step position error of a prediction on a log-scaled y axis.

    For every time step the squared x and y deviations are added up per object
    and then averaged over axis 1 (the objects). The reference positions are
    those obtained by integrating with the true Hamiltonian of the system.
    """
    dx = qx_true - qx_pred
    dy = qy_true - qy_pred
    # Named per_step_mse so the imported `mse` helper is not shadowed.
    per_step_mse = np.mean(dx**2 + dy**2, axis=1)
    ax_mse.semilogy(plotx, per_step_mse, c=color, label=label, linestyle=linestyle, linewidth=linewidth)

def plot_hamiltonian(ax_energy, energy, label, color, linestyle, linewidth=2.0):
    """
    Draws an energy curve over the time steps `plotx` on a symlog y axis.

    The symlog scale (linear within +-0.015) is re-applied on every call.
    """
    ax_energy.set_yscale("symlog", linthresh=0.015)
    line_style = dict(label=label, c=color, linestyle=linestyle, linewidth=linewidth)
    ax_energy.plot(plotx, energy, **line_style)

# One row of four panels: x position, y position, energy and position MSE.
fig, (ax_x, ax_y, ax_energy, ax_mse) = plt.subplots(1, 4, figsize=(14, 3), dpi=100)

# The lists below are parallel: entry i of each list belongs to the same model.
# To include ELM in the figure, uncomment its entry in every list.
model_q_trajs = [
    # elm_q_traj,                 # Uncomment to plot ELM too
    swim_q_traj,
    adam_q_traj,
]
model_ke_trajs = [
    # elm_ke_traj,                # Uncomment to plot ELM too
    swim_ke_traj,
    adam_ke_traj,
]
model_pe_trajs = [
    # elm_pe_traj,                # Uncomment to plot ELM too
    swim_pe_traj,
    adam_pe_traj,
]
model_h_traj_errors = [
    # elm_h_traj_error,               # Uncomment to plot ELM too
    swim_h_traj_error,
    adam_h_traj_error,
]
# NOTE(review): these values are not used by the plotting loop below; plot_mse
# recomputes the error itself. It sums the squared x/y deviations and averages
# over objects only, whereas this list averages over objects and coordinates,
# so the two differ by a factor of dof when dof == 2. Confirm which
# normalisation is intended.
model_mse_trajs = [
    # ((q_traj - elm_q_traj)**2).mean(axis=(1, 2)),     # Uncomment to plot ELM too
    ((q_traj - swim_q_traj)**2).mean(axis=(1, 2)),
    ((q_traj - adam_q_traj)**2).mean(axis=(1, 2)),
]
model_labels = [
    # "(ELM) RF-HGN",         # Uncomment to plot ELM too
    "(SWIM) RF-HGN",
    "(Adam) HGN",
]
model_colors = [
    # ELM_RF_HGN,             # Uncomment to plot ELM too
    SWIM_RF_HGN,
    ADAM_HGN,
]
# model_linestyles = ["dashdot", "dashed", "dotted"]
model_linestyles = [
    # "solid",                # Uncomment to plot ELM too
    "dashdot",
    "dashed",
]
# Plot ground truth
margin = 0.5
# The y limits follow the reference trajectory of the selected object, so a
# diverging prediction cannot blow up the scale of the position panels.
ax_x.set_ylim(np.min(q_traj[..., obj_idx, 0])-margin, np.max(q_traj[..., obj_idx, 0])+margin)
ax_y.set_ylim(np.min(q_traj[..., obj_idx, 1])-margin, np.max(q_traj[..., obj_idx, 1])+margin)
plot_xy_traj(ax_x, ax_y, q_traj[..., 0], q_traj[..., 1], "k", "solid", linewidth=4)
# ax_energy.set_ylim(-1e3, 1e4)
ax_mse.set_ylim(1e-9, 1e3)
plot_hamiltonian(ax_energy, h_traj, r"using true $\mathcal{H}$", "k", "solid", linewidth=4)
# Plot predictions
for (model_q_traj, model_h_traj_error, model_mse_traj,
     model_label, model_color, model_linestyle) in zip(model_q_trajs,
                                                       model_h_traj_errors,
                                                       model_mse_trajs,
                                                       model_labels,
                                                       model_colors,
                                                       model_linestyles):
    qx_pred, qy_pred = model_q_traj[..., 0], model_q_traj[..., 1]
    plot_xy_traj(ax_x, ax_y, qx_pred, qy_pred, model_color, model_linestyle, linewidth=3)
    plot_mse(ax_mse, q_traj[..., 0], q_traj[..., 1], qx_pred, qy_pred, model_label, model_color, model_linestyle, linewidth=3)
    plot_hamiltonian(ax_energy, model_h_traj_error, model_label, model_color, model_linestyle, linewidth=3)
# Legends: the energy panel carries one labelled line per curve (reference plus
# every model), so its handles serve as the single figure-level legend.
lines, labels = ax_energy.get_legend_handles_labels()
fig.legend(lines, labels, loc='lower center', ncol=len(labels), fontsize=legend_fontsize, bbox_to_anchor=(0.5, 0.875))

for ax in (ax_x, ax_y, ax_mse, ax_energy):
    ax.set_xlabel("Time step", fontsize=label_fontsize)
ax_x.set_ylabel(fr"$q_1$ of node {obj_idx}", fontsize=label_fontsize)
ax_y.set_ylabel(fr"$q_2$ of node {obj_idx}", fontsize=label_fontsize)
ax_mse.set_ylabel("MSE", fontsize=label_fontsize)
ax_mse.grid()

ax_energy.set_ylabel(r"True $\mathcal{H}$", fontsize=label_fontsize)
fig.tight_layout()
# Create the output directory if needed so savefig does not fail on a fresh
# path. An empty outdir means the current directory and is left alone.
if args.outdir:
    os.makedirs(args.outdir, exist_ok=True)
traj_plot_filepath = os.path.join(args.outdir, f"{n_obj}particles_traj.pdf")
fig.savefig(traj_plot_filepath)
print(f"Trajectory plot saved at '{traj_plot_filepath}'")

"""Generate results table"""

def generate_results_table(q_traj_pred, h_traj_pred):
    """Print an evaluation table for one predicted trajectory.

    Five time steps are sampled evenly between index 0 and ``n_steps - 2``.
    For each of them the table lists the position error returned by ``mse``
    and ``l2_err`` (both imported from ``src.utils``) between the module-level
    reference ``q_traj`` and ``q_traj_pred``, together with the reference
    value ``h_traj`` and the predicted value ``h_traj_pred`` at that step.
    Columns are labelled ``T=<index + 1>``.

    NOTE(review): the last sampled index is ``n_steps - 2`` rather than the
    final step ``n_steps - 1`` -- confirm this is intended.

    Args:
        q_traj_pred: Predicted positions, indexed like ``q_traj``.
        h_traj_pred: Predicted conserved quantity per step; each sampled entry
            must hold exactly one element because ``.item()`` is called on it.

    Returns:
        None. The caption and the table are written to stdout.
    """
    # Evaluate at 5 points in the trajectory
    n_evals_in_traj = 5
    test_indices = np.linspace(0, n_steps - 2, num=n_evals_in_traj, dtype=np.int64)
    q_mse = np.zeros(n_evals_in_traj, dtype=q_traj.dtype)
    q_rel2 = np.zeros(n_evals_in_traj, dtype=q_traj.dtype)
    h_true = np.zeros(n_evals_in_traj, dtype=h_traj.dtype)
    h_pred = np.zeros(n_evals_in_traj, dtype=h_traj.dtype)
    for idx, test_index in enumerate(test_indices):
        q_mse[idx] = mse(q_traj[test_index], q_traj_pred[test_index], verbose=False)
        q_rel2[idx] = l2_err(q_traj[test_index], q_traj_pred[test_index], verbose=False)
        h_true[idx] = h_traj[test_index].item()
        h_pred[idx] = h_traj_pred[test_index].item()

    table_title = "\nTable: Predicted trajectory evaluation, error values on positions (q) are displayed with true and predicted conserved (energy) values."
    arr_columns = [f"T={step_idx+1}" for step_idx in test_indices]
    results_table = tabulate(
        headers=[""] + arr_columns,
        tabular_data=[
            ["q MSE"] + list(q_mse),
            ["q L2 rel."] + list(q_rel2),
            ["True H"] + list(h_true),
            ["Pred H"] + list(h_pred),
        ],
        floatfmt=".3e"
    )
    # The caption used to be built but never shown; print it ahead of the table.
    print(table_title)
    print(results_table)

# Print one results table per approximation, in the order ELM, SWIM, Adam.
for approx_name, approx_q_traj, approx_h_traj in (
    ("ELM", elm_q_traj, elm_h_traj),
    ("SWIM", swim_q_traj, swim_h_traj),
    ("Adam", adam_q_traj, adam_h_traj),
):
    print(f"-> {approx_name} Approximation")
    generate_results_table(approx_q_traj, approx_h_traj)

# """Generate snapshots of the system at the results table"""
# for model_q_traj, model_ke_traj, model_pe_traj, model_h_traj_error, model_label in zip(model_q_trajs,
                                                                                 # model_ke_trajs,
                                                                                 # model_pe_trajs,
                                                                                 # model_h_traj_errors,
                                                                                 # model_labels):
#
    # rel2_traj = np.linalg.norm(q_traj - model_q_traj, axis=(1, 2)) / np.linalg.norm(q_traj, axis=(1, 2))
    # filename = f"{n_obj}simulation_{model_label}.mp4"
    # print(f"-> Simulating for {model_label}")
    # animate_2D(q_traj[::args.framinglen], ke_traj, pe_traj, h_traj,
               # q_pred=model_q_traj[::args.framinglen], ke_pred=model_ke_traj, pe_pred=model_pe_traj, h_pred=model_h_traj_error,
               # rel2=rel2_traj, framing_length=args.framinglen, filename=os.path.join(args.outdir, filename))

# Snaps of trajectories: a 3x5 grid with one row per model (ELM, Adam, SWIM)
# and one column per sampled time step.
plt.clf()
true_args = {"facecolors": "none", "edgecolors": "red"}
fig, (ax_elm, ax_adam, ax_swim) = plt.subplots(3, 5, figsize=(16, 9), dpi=100, sharex=True, sharey=True)
# Pad the shared axis limits by 20% of the largest absolute reference coordinate.
margin = 0.2 * np.max(np.abs(q_traj))
anim_xlim_min = (np.min(q_traj[..., 0]) - margin).item()
anim_xlim_max = (np.max(q_traj[..., 0]) + margin).item()
anim_ylim_min = (np.min(q_traj[..., 1]) - margin).item()
anim_ylim_max = (np.max(q_traj[..., 1]) + margin).item()

# Same sampling as the results table: 5 steps spread evenly over [0, n_steps - 2].
n_evals_in_traj = 5
test_indices = np.linspace(0, n_steps - 2, num=n_evals_in_traj, dtype=np.int64)
# Legend handles and labels collected by plot_snap across all rows.
lines = []
labels = []


def plot_snap(row_axes, q_traj, q_traj_pred, row_label, color):
    """Scatter reference and predicted positions for one row of snapshot axes.

    Axis i of ``row_axes`` shows the object positions at step
    ``test_indices[i]``: the reference ``q_traj`` as hollow red markers drawn
    on top (zorder 5) and ``q_traj_pred`` as filled markers below (zorder 1).
    The legend entries of the first axis in the row are appended to the
    module-level ``lines`` and ``labels`` lists.

    The reference markers get their legend label only while ``labels`` does
    not contain it yet, and only on the first axis of the row, so the
    figure-level legend shows the reference exactly once.

    Args:
        row_axes: Iterable of axes forming one row of the grid.
        q_traj: Reference positions, indexed [step, object, coordinate].
        q_traj_pred: Predicted positions, indexed like ``q_traj``.
        row_label: Legend label for the predicted markers.
        color: Face color of the predicted markers.
    """
    true_label = r"using true $H$"
    needs_true_label = true_label not in labels
    for col, (row_axis, test_index) in enumerate(zip(row_axes, test_indices)):
        row_axis.set_xlim(anim_xlim_min, anim_xlim_max)
        row_axis.set_ylim(anim_ylim_min, anim_ylim_max)

        true_kwargs = dict(true_args, s=80, zorder=5)
        if needs_true_label and col == 0:
            true_kwargs["label"] = true_label
        row_axis.scatter(q_traj[test_index][..., 0], q_traj[test_index][..., 1], **true_kwargs)
        # NOTE(review): the edge color is fixed to "tab:blue" for every row,
        # independent of `color` -- confirm this is intended.
        row_axis.scatter(q_traj_pred[test_index][..., 0], q_traj_pred[test_index][..., 1], s=80, edgecolors="tab:blue", zorder=1, label=row_label, c=color)
    row_handles, row_labels = row_axes[0].get_legend_handles_labels()
    lines.extend(row_handles)
    labels.extend(row_labels)


# NOTE(review): the trajectory figure colors these models with SWIM_RF_HGN and
# ADAM_HGN, while the Adam and SWIM rows here use literal tab colors -- confirm.
plot_snap(ax_elm, q_traj, elm_q_traj, "(ELM) RF-HGN", color=ELM_RF_HGN)
plot_snap(ax_adam, q_traj, adam_q_traj, "(Adam) HGN", color="tab:orange")
plot_snap(ax_swim, q_traj, swim_q_traj, "(SWIM) RF-HGN", color="tab:blue")
# Only the bottom row gets x labels.
for ax, test_index in zip(ax_swim, test_indices):
    ax.set_xlabel(f"Time step {test_index+1}")
# Legends
fig.legend(lines, labels, loc='lower center', ncol=len(labels), fontsize=legend_fontsize, bbox_to_anchor=(0.5, 0.95))
fig.tight_layout()
# Create the output directory if needed so savefig does not fail on a fresh
# path. An empty outdir means the current directory and is left alone.
if args.outdir:
    os.makedirs(args.outdir, exist_ok=True)
traj_plot_filepath = os.path.join(args.outdir, f"{n_obj}particles_snaps.pdf")
fig.savefig(traj_plot_filepath)
print(f"-> Snaps are saved at '{traj_plot_filepath}'")
