import os
import toml
import numpy as np
from argparse import ArgumentParser
from src.data.mass_spring import edge_index_chain
from src.utils import generate_results_table, animate_2D, plot_traj_energy_mse, ADAM_HGN, ELM_RF_HGN, SWIM_RF_HGN, plot_snaps, mse, l2_err


argparser = ArgumentParser()

# The ground-truth and the three prediction files share the same help text;
# register them in the original order so `--help` output is unchanged.
_NPZ_HELP = "Path to .npz data file produced by simulate_model.py"
for short_flag, long_flag in (
    ("-true", "--truenpz"),
    ("-swim", "--swimnpz"),
    ("-elm", "--elmnpz"),
    ("-adam", "--adamnpz"),
):
    argparser.add_argument(short_flag, long_flag, required=True, type=str,
                           help=_NPZ_HELP)

argparser.add_argument("-o", "--outdir", required=True, type=str,
                       help="Output directory path")
argparser.add_argument("-t", "--toml", required=True, type=str,
                       help="Path to .toml config file")
args = argparser.parse_args()

# Only the [animation] table of the TOML config is used by this script.
config = toml.load(args.toml)
anim_config = config["animation"]

"""Read simulation data"""

# Load ground truth first, then the ELM, ADAM and SWIM predictions.
true, pred_elm, pred_adam, pred_swim = (
    np.load(npz_path)
    for npz_path in (args.truenpz, args.elmnpz, args.adamnpz, args.swimnpz)
)

# Report the shape of each loaded trajectory. The tags are padded to four
# characters so the messages line up.
for tag, loaded in (
    ("True", true),
    ("ELM ", pred_elm),
    ("ADAM", pred_adam),
    ("SWIM", pred_swim),
):
    print(f"-> {tag} trajectory read of shape {loaded['q'].shape}")

"""Tabulate predictions"""
# One error table per model. generate_results_table is not told which model
# it is scoring, so the model name is printed as a header above each table;
# otherwise the three tables cannot be told apart in the console output.
for pred, optim_name in zip([pred_elm, pred_adam, pred_swim],
                            ["ELM", "ADAM", "SWIM"]):
    q_pred, e_pred = pred["q"], pred["e"]
    title, table = generate_results_table(
        true["q"], q_pred, true["e"], e_pred,
        start_idx=1, num_evals=5
    )
    print('\n'.join([
        f"== {optim_name} ==", title, table
    ]))

# The data includes node displacements only. To draw a chain, node positions
# are rebuilt further below by offsetting each node by a fixed spring length
# (see construct_q_from_displacements). q is unpacked as
# (n_steps, n_nodes, dof).
n_steps, n_nodes, dof = true["q"].shape
def construct_q_from_displacements(q, n_nodes, spring_length):
    """Return chain node positions rebuilt from per-node displacements.

    Node ``i`` is shifted along the first coordinate (``q[..., 0]``) by
    ``i * spring_length``, which lays the nodes out as a chain with the given
    rest spacing. The other coordinates are copied unchanged.

    The input array is not modified; a new array is returned. Integer or
    boolean input is promoted to float so fractional offsets are not truncated.

    Args:
        q: Displacement array whose second-to-last axis indexes the
            ``n_nodes`` nodes and whose last axis indexes coordinates, e.g.
            ``(n_steps, n_nodes, dof)`` as used in this script.
        n_nodes: Number of nodes in the chain. Must equal ``q.shape[-2]``.
        spring_length: Rest spacing between neighbouring nodes.

    Returns:
        A new array with the same shape as ``q`` holding absolute positions.
    """
    positions = np.array(q, copy=True)
    if not np.issubdtype(positions.dtype, np.inexact):
        # Otherwise the in-place add would truncate fractional offsets.
        positions = positions.astype(float)
    # The offsets have shape (n_nodes,), which broadcasts against the node
    # axis of q[..., 0] for inputs of any rank.
    positions[..., 0] += np.arange(n_nodes) * spring_length
    return positions

"""Animate 3 systems side by side"""
edge_index = edge_index_chain(n_nodes)

# Predicted trajectories to animate, in display order, paired with their
# legend labels.
anim_runs = (
    (pred_adam, "(Adam) HGN"),
    (pred_elm, "(ELM) RF-HGN"),
    (pred_swim, "(SWIM) RF-HGN"),
)

animate_2D(
    construct_q_from_displacements(true["q"], n_nodes, anim_config["spring_length"]),
    edge_index.cpu().numpy(),
    anim_config["framing_length"],
    filename=os.path.join(args.outdir, "anim.mp4"),
    q_preds=[
        construct_q_from_displacements(run["q"], n_nodes, anim_config["spring_length"])
        for run, _ in anim_runs
    ],
    pred_labels=[label for _, label in anim_runs],
)

"""Plot: Trajectory - Energy - MSE"""
# One row per model: (loaded .npz, legend label, colour, line style).
# plot_traj_energy_mse takes parallel lists, so every per-model list is
# derived from this single table. That keeps q_preds, e_preds, labels,
# colours and line styles in the same model order.
traj_models = [
    (pred_swim, "(SWIM) RF-HGN", SWIM_RF_HGN, "dashed"),
    (pred_elm, "(ELM) RF-HGN", ELM_RF_HGN, "dashdot"),
    (pred_adam, "(Adam) HGN", ADAM_HGN, "dotted"),
]
plot_traj_energy_mse(
    q_true=construct_q_from_displacements(true["q"], n_nodes, anim_config["spring_length"]),
    e_true=true["e"],
    q_preds=[
        construct_q_from_displacements(pred["q"], n_nodes, anim_config["spring_length"])
        for pred, _, _, _ in traj_models
    ],
    e_preds=[pred["e"] for pred, _, _, _ in traj_models],
    obj_idx=n_nodes // 2,
    model_labels=[label for _, label, _, _ in traj_models],
    model_colors=[color for _, _, color, _ in traj_models],
    model_linestyles=[style for _, _, _, style in traj_models],
    filepath=os.path.join(args.outdir, "traj.pdf"),
)

"""Prepare animation snapshots"""
# One row per model: (loaded .npz, row label, colour). Each colour is the
# same per-model constant used for that model in the trajectory plot, so a
# model is drawn in one colour across figures. Previously Adam and SWIM used
# hard-coded "tab:orange" and "tab:blue" instead of their constants.
snap_models = [
    (pred_elm, "(ELM) RF-HGN", ELM_RF_HGN),
    (pred_adam, "(Adam) HGN", ADAM_HGN),
    (pred_swim, "(SWIM) RF-HGN", SWIM_RF_HGN),
]
plot_snaps(
    q_true=construct_q_from_displacements(true["q"], n_nodes, anim_config["spring_length"]),
    q_preds=[
        construct_q_from_displacements(pred["q"], n_nodes, anim_config["spring_length"])
        for pred, _, _ in snap_models
    ],
    row_labels=[label for _, label, _ in snap_models],
    colors=[color for _, _, color in snap_models],
    filepath=os.path.join(args.outdir, "snaps.pdf"),
    edge_index=edge_index.cpu().numpy(),
)

"""Prepare table: Final position MSE, relative"""
# Compare each model's final configuration q[-1] against the ground truth.
# It reports two metrics: the relative error from l2_err ("rel2") and the
# mean squared error from mse.
print()
print('-'*40)
# The inner quotes must differ from the f-string's own quotes: reusing "..."
# inside a replacement field is only legal from Python 3.12 (PEP 701) and is
# a SyntaxError on earlier versions.
print(f"Results Summary at q[{len(true['q']) - 1}]")
adam_q_last_rel2 = l2_err(true["q"][-1], pred_adam["q"][-1], verbose=False)
elm_q_last_rel2 = l2_err(true["q"][-1], pred_elm["q"][-1], verbose=False)
swim_q_last_rel2 = l2_err(true["q"][-1], pred_swim["q"][-1], verbose=False)
print(f"ADAM-HGN q[-1] rel2:      {adam_q_last_rel2:.3e}")
print(f"ELM-HGN  q[-1] rel2:      {elm_q_last_rel2:.3e}")
print(f"SWIM-HGN q[-1] rel2:      {swim_q_last_rel2:.3e}")
print('-'*40)
adam_q_last_mse = mse(true["q"][-1], pred_adam["q"][-1], verbose=False)
elm_q_last_mse = mse(true["q"][-1], pred_elm["q"][-1], verbose=False)
swim_q_last_mse = mse(true["q"][-1], pred_swim["q"][-1], verbose=False)
print(f"ADAM-HGN q[-1] mse :      {adam_q_last_mse:.3e}")
print(f"ELM-HGN  q[-1] mse :      {elm_q_last_mse:.3e}")
print(f"SWIM-HGN q[-1] mse :      {swim_q_last_mse:.3e}")
print('-'*40)
