import os
import numpy as np
import matplotlib.pyplot as plt
from argparse import ArgumentParser
from src.utils import SWIM_RF_HGN

# Compare learned models against the ground-truth integration: print mean
# per-step errors and write MSE, relative-squared-error and energy figures.
# Inputs are the .npy arrays found in the directory passed via -o; the PDFs are
# written into that same directory, as the -o help text documents.

# Object counts as they appear in the result file names: f"{[8]}" renders "[8]".
N_OBJ = [8]


def parse_args(argv=None):
    """Parse the command-line options.

    Args:
        argv: Arguments to parse; ``None`` parses ``sys.argv[1:]``.

    Returns:
        A dict with keys ``"o"`` (results and output directory), ``"len_traj"``
        and ``"delta_t"``.
    """
    argparser = ArgumentParser()
    # -o is required: every input path is built from it, and a missing value
    # would otherwise surface later as a TypeError inside os.path.join.
    argparser.add_argument("-o", type=str, required=True, help="Path to the integration results and also output dir")
    argparser.add_argument("--len_traj", type=int, required=True, help="Trajectory length of the integration")
    argparser.add_argument("--delta_t", type=float, required=True, help="Time step size")
    return vars(argparser.parse_args(argv))


def result_path(out_dir, quantity, len_traj, delta_t, model=None, n_obj=None):
    """Return the path of one saved integration result.

    Args:
        out_dir: Directory holding the results.
        quantity: Array name prefix: ``"qx"``, ``"qy"`` or ``"energy"``.
        len_traj: Trajectory length encoded in the file name.
        delta_t: Time step size, encoded in the file name with ``.2e``.
        model: Model tag such as ``"rf-hgn"``; ``None`` selects the ground truth.
        n_obj: Object counts encoded in the file name; ``None`` uses ``N_OBJ``.

    Returns:
        ``<out_dir>/<quantity>_true_<n_obj>_<len_traj>_<delta_t>.npy`` for the
        ground truth, otherwise
        ``<out_dir>/<quantity>_<n_obj>_<model>_<len_traj>_<delta_t>.npy``.
    """
    if n_obj is None:
        n_obj = N_OBJ
    suffix = f"{len_traj}_{delta_t:.2e}.npy"
    if model is None:
        name = f"{quantity}_true_{n_obj}_{suffix}"
    else:
        name = f"{quantity}_{n_obj}_{model}_{suffix}"
    return os.path.join(out_dir, name)


def load_results(out_dir, len_traj, delta_t, model=None):
    """Load and return ``(qx, qy, energy)`` for ``model`` (``None`` = ground truth)."""
    return tuple(
        np.load(result_path(out_dir, quantity, len_traj, delta_t, model))
        for quantity in ("qx", "qy", "energy")
    )


def per_step_error(qx_true, qy_true, qx, qy):
    """Return the squared position error averaged over axis 1.

    Axis 0 is plotted as the time step; axis 1 is presumably the object index
    (TODO confirm against the integration script).
    """
    return np.mean((qx_true - qx)**2 + (qy_true - qy)**2, axis=1)


def relative_error(qx_true, qy_true, qx, qy):
    """Return the per-step squared error divided by the per-step mean squared true position."""
    denom = np.mean(qx_true**2 + qy_true**2, axis=1)
    return per_step_error(qx_true, qy_true, qx, qy) / denom


def save_figure(fig, fig_path):
    """Lay out, write and close ``fig``, then report where it was saved."""
    fig.tight_layout()
    fig.savefig(fig_path)
    # Close explicitly so pyplot does not keep every created figure alive.
    plt.close(fig)
    print(f"-> figure saved under '{fig_path}'")


def save_error_plot(fig_path, plotx, curves, ylabel):
    """Plot per-step error curves on a logarithmic y-axis and save the figure.

    Args:
        fig_path: Output file path.
        plotx: x values (time steps).
        curves: Iterable of ``(values, label, color, linestyle)``.
        ylabel: y-axis label.
    """
    fig, ax = plt.subplots(1, 1, figsize=(3, 2), dpi=100)
    for values, label, color, linestyle in curves:
        ax.semilogy(plotx, values, label=label, c=color, linestyle=linestyle, linewidth=3)
    ax.grid(True)
    ax.set_xlabel("Time step")
    ax.set_ylabel(ylabel)
    # No legend is drawn on the error figures; only the energy figure has one.
    save_figure(fig, fig_path)


def save_energy_plot(fig_path, plotx, energy_true, curves):
    """Plot the true energy and each model's energy on a log y-axis and save it.

    Args:
        fig_path: Output file path.
        plotx: x values (time steps).
        energy_true: Energy obtained with the true Hamiltonian.
        curves: Iterable of ``(values, label, color, linestyle)``.
    """
    fig, ax = plt.subplots(1, 1, figsize=(3, 2), dpi=100)
    ax.semilogy(plotx, energy_true, label=r"Using true $\mathcal{H}$", c="k", linestyle="solid", linewidth=4)
    for values, label, color, linestyle in curves:
        ax.semilogy(plotx, values, label=label, c=color, linestyle=linestyle, linewidth=3)
    ax.grid(True)
    ax.legend(fontsize=6)
    ax.set_xlabel("Time step")
    ax.set_ylabel(r"$\mathcal{H}$")
    ax.tick_params(axis='y', labelsize=8)  # Set y-axis tick label font size
    save_figure(fig, fig_path)


def main(argv=None):
    """Load all results, print the mean errors and write the three figures."""
    args = parse_args(argv)
    out_dir = args["o"]
    print("outdir is ", out_dir)
    len_traj = args["len_traj"]
    delta_t = args["delta_t"]

    # (file-name tag, name in printed summaries, legend label, colour, line style).
    # An HNN baseline saved as "<q>_[8]_hnn_..." can be compared by adding
    # ("hnn", "hnn", "HNN", "tab:blue", "dashed") to this list.
    models = [
        ("plain-rf-hgn", "plain rf-hgn", "RF-HGN (Non-invariant)", "tab:blue", "dashdot"),
        ("rf-hgn", "rf-hgn", "RF-HGN", SWIM_RF_HGN, "dashdot"),
    ]

    qx_true, qy_true, energy_true = load_results(out_dir, len_traj, delta_t)
    predictions = {tag: load_results(out_dir, len_traj, delta_t, model=tag) for tag, *_ in models}
    plotx = np.arange(1, len(qx_true) + 1)

    # Mean squared error per time step.
    mse_curves = []
    for tag, name, label, color, linestyle in models:
        qx, qy, _ = predictions[tag]
        mse = per_step_error(qx_true, qy_true, qx, qy)
        print(f"-> mean mse {name}:", np.mean(mse))
        mse_curves.append((mse, label, color, linestyle))
    save_error_plot(os.path.join(out_dir, "traj_tran_rot_mse.pdf"), plotx, mse_curves, "MSE")

    # Relative squared error per time step.
    rel_curves = []
    for tag, name, label, color, linestyle in models:
        qx, qy, _ = predictions[tag]
        rel = relative_error(qx_true, qy_true, qx, qy)
        print(f"-> mean rel {name}:", np.mean(rel))
        rel_curves.append((rel, label, color, linestyle))
    save_error_plot(os.path.join(out_dir, "traj_tran_rot_rel.pdf"), plotx, rel_curves, "Relative squared error")

    # Energy along the trajectory.
    energy_curves = [
        (predictions[tag][2], label, color, linestyle)
        for tag, _, label, color, linestyle in models
    ]
    save_energy_plot(os.path.join(out_dir, "traj_tran_rot_energy.pdf"), plotx, energy_true, energy_curves)


if __name__ == "__main__":
    main()
