import os
import numpy as np
import torch as tc
from argparse import ArgumentParser
from src.data.molecular_dynamics import animate_2D, get_x0_and_box, get_equil_r, simulate
from time import time

# Command-line interface. Particle counts, integration step size / step count,
# the model path and the output directory are mandatory; every other option
# falls back to the default given below.
argparser = ArgumentParser()
argparser.add_argument("-nx", "--n_obj_x", type=int, required=True,
                       help="Number of particles along x (total count is n_obj_x * n_obj_y)")
argparser.add_argument("-ny", "--n_obj_y", type=int, required=True,
                       help="Number of particles along y (total count is n_obj_x * n_obj_y)")
argparser.add_argument("-dt", "--delta_t", type=float, required=True,
                       help="Time step size (delta t)")
argparser.add_argument("-T", "--n_steps", type=int, required=True,
                       help="Total number of steps")
argparser.add_argument("-m", "--mass", type=float, default=1.0,
                       help="Mass of a single particle")
argparser.add_argument("-eps", "--epsilon", type=float, default=1.0,
                       help="Epsilon parameter in Lennard-Jones")
argparser.add_argument("-sig", "--sigma", type=float, default=1.0,
                       help="Sigma parameter in Lennard-Jones")
# The help text used to be a copy of the epsilon one; describe the cutoff instead.
argparser.add_argument("-cutoff", "--cutoff", type=float, default=np.inf,
                       help="Cutoff parameter in Lennard-Jones, default=inf")
argparser.add_argument("-p", "--path", type=str, required=True,
                       help="Saved model path")
argparser.add_argument("-q_noise", "--q_noise", type=float, default=0.1,
                       help="Uniformly disturb initial state q, default=0.1")
argparser.add_argument("-p_noise", "--p_noise", type=float, default=0.0,
                       help="Uniformly disturb initial state p, default=0.0")
argparser.add_argument("-c", "--cont", action="store_true",
                       help="Evaluate from equilibrium position")
argparser.add_argument("-e_dt", "--equil_delta_t", type=float, default=5e-3,
                       help="Equilibration time step size (delta t) only if --cont is specified")
argparser.add_argument("-e_T", "--equil_n_steps", type=int, default=10_000,
                       help="Equilibration Total number of steps only if --cont is specified")
argparser.add_argument("-o", "--outdir", required=True,
                       help="Output dir to save the animation")
args = argparser.parse_args()

# Fixed seed: the generator built here is handed to the initial-state builder,
# so the random disturbance is the same on every run.
random_state = 47285628
rng = np.random.default_rng(random_state)

# Load the trained model from the path given on the command line.
# NOTE(review): weights_only=False unpickles the full object, which can execute
# arbitrary code -- only point --path at checkpoints you trust.
model = tc.load(args.path, weights_only=False)
print(f"-> Model loaded '{args.path}'")

# NumPy dtype for the initial state: float32 when the model reports float32,
# float64 for anything else.
# NOTE(review): assumes the loaded object exposes a `dtype` attribute -- confirm.
if model.dtype == tc.float32:
    dtype_np = np.float32
else:
    dtype_np = np.float64

# Build the initial state and the box bounds from the CLI settings and the seeded generator.
r_eq = get_equil_r(args.sigma)
x0, box_start, box_end = get_x0_and_box(
    args.n_obj_x, args.n_obj_y, args.sigma, r_eq,
    args.q_noise, args.p_noise, dtype_np, rng,
)

print(f"-> Simulation box is created with {args.n_obj_x * args.n_obj_y} number of particles")

# Optional equilibration (--cont): run the simulator without a model, using the
# equilibration step size / step count, and restart from where that run ends.
if args.cont:
    print(f"-> Equilibrating to get an initial condition (--cont specified) equil_delta_t={args.equil_delta_t} and equil_n_steps={args.equil_n_steps}")
    x_traj, _, _, _, _ = simulate(
        x0,
        args.mass,
        args.epsilon,
        args.sigma,
        args.cutoff,
        args.equil_n_steps,
        args.equil_delta_t,
        bc="reflective",
        box_start=box_start,
        box_end=box_end,
    )
    # Only the trajectory is kept; its last entry becomes the new initial state.
    x0 = x_traj[-1]

# Both rollouts start from the same x0 and use the same physical parameters,
# integration settings and box; the only difference is the `model` keyword.
section_break = "\n\n" + "-"*40 + "\n\n"
sim_args = (x0, args.mass, args.epsilon, args.sigma, args.cutoff, args.n_steps, args.delta_t)
sim_kwargs = dict(bc="reflective", box_start=box_start, box_end=box_end)

# Ground truth simulation (no model passed)
print(section_break)
print("-> Simulating ground truth trajectory using the true Hamiltonian of the system")
t0 = time()
x_traj, ke_traj, pe_traj, h_traj, dHdx_traj = simulate(*sim_args, **sim_kwargs)
t1 = time()
print(f"-> Took {(t1 - t0):.2f} seconds")

# Model prediction (same call, with the loaded model plugged in)
print(section_break)
print("-> Predicting with the trained model")
t0 = time()
x_traj_pred, ke_traj_pred, pe_traj_pred, h_traj_pred, dHdx_traj_pred = simulate(
    *sim_args, model=model, **sim_kwargs)
t1 = time()
print(f"-> Took {(t1 - t0):.2f} seconds")

def _print_mean_std(label, values):
    """Print ``label`` followed by the mean and standard deviation of ``values``, separated by '|'."""
    print(label, np.mean(values), "|", np.std(values))


# The first two entries of the last axis are used as positions (q); the
# remaining entries are treated as the momentum part when splitting dH/dx.
q_true = x_traj[..., :2]
q_pred = x_traj_pred[..., :2]

# RMSE reduced over axes 1 and 2, leaving one value per index of axis 0.
# NOTE(review): assumes the arrays are laid out (step, particle, component) -- confirm.
rmse_traj = np.sqrt(((q_true - q_pred)**2).mean(axis=(1, 2)))
# NOTE(review): the two dH/dx RMSE series below are neither printed nor passed
# to the animation -- either report them or drop them.
rmse_dHdq_traj = np.sqrt(((dHdx_traj[..., :2] - dHdx_traj_pred[..., :2])**2).mean(axis=(1, 2)))
rmse_dHdp_traj = np.sqrt(((dHdx_traj[..., 2:] - dHdx_traj_pred[..., 2:])**2).mean(axis=(1, 2)))

# Mean squared errors over all elements of each pair of arrays.
print("-> Trajectory MSE:", ((x_traj - x_traj_pred)**2).mean())
print("-> KE Part    MSE:", ((ke_traj - ke_traj_pred)**2).mean())
print("-> PE Part    MSE:", ((pe_traj - pe_traj_pred)**2).mean())
print("-> dH/dx      MSE:", ((dHdx_traj - dHdx_traj_pred)**2).mean())

# Hamiltonian statistics, raw and as log|H| (an exact zero in H yields -inf here).
# All four lines now print "mean | stdev" with the same '|' separator their labels promise.
logabs_h_traj = np.log(np.abs(h_traj))
logabs_h_traj_pred = np.log(np.abs(h_traj_pred))
_print_mean_std("-> H true mean | stdev :", h_traj)
_print_mean_std("   logabs H true mean | stdev :", logabs_h_traj)
_print_mean_std("-> H pred mean | stdev :", h_traj_pred)
_print_mean_std("   logabs H pred mean | stdev :", logabs_h_traj_pred)

# Animation
# Only every `framing_length`-th position frame is handed over; the energy and
# RMSE series are passed in full together with `framing_length`.
# NOTE(review): presumably animate_2D subsamples those series itself -- confirm.
framing_length = 20

# Make sure the target directory exists so that saving the video cannot fail
# on a missing folder after both simulations have already been run.
# An empty --outdir means "current directory", so there is nothing to create.
if args.outdir:
    os.makedirs(args.outdir, exist_ok=True)

animate_2D(
    q_true[::framing_length], ke_traj, pe_traj, h_traj,
    box_start=box_start, box_end=box_end,
    q_pred=q_pred[::framing_length],
    ke_pred=ke_traj_pred, pe_pred=pe_traj_pred, h_pred=h_traj_pred,
    rmse=rmse_traj,
    delta_t=args.delta_t, framing_length=framing_length,
    filename=os.path.join(args.outdir, f"{args.n_obj_x*args.n_obj_y}n_obj_2DOF_cutoff{args.cutoff}.mp4"),
)


# t0 = time()
# animate_2D(x_traj[::framing_length][..., :2], ke_traj, pe_traj, filename=animation_filename_prefix+f"{sim_idx}_2D_{n_obj}particles_{len(x_traj)}steps_{args.delta_t}deltat.mp4")
# t1 = time()
# print(f"-> Took {(t1 - t0):.2f} seconds")

# Evaluation outline (notes):
# 1. Load the dataset (only if the model was trained on the last members only).
# 2. Choose x0, either
#    - randomly initialized with q_noise = 0.1 and p_noise = 0.0 (parameters), or
#    - the last snapshot of the loaded data (if trained on the last members only).
# 3. Simulate the trajectory.
# 4. Evaluate the trajectory MSE.
# 5. Animate.
