"""
Creates data for 2D molecular dynamics
"""
import os
import numpy as np
import torch as tc
from time import time
from argparse import ArgumentParser
from src.data.molecular_dynamics import get_x0_and_box, simulate, get_equil_r, animate_2D

# Fixed seed so that repeated runs of this script produce the same dataset.
random_state = 1234
# NumPy generator used when sampling initial conditions (passed to get_x0_and_box).
rng = np.random.default_rng(random_state)
# Seed torch as well (CPU default generator, current CUDA device, all CUDA devices).
for _seed_torch in (tc.manual_seed, tc.cuda.manual_seed, tc.cuda.manual_seed_all):
    _seed_torch(random_state)
del _seed_torch

# Command-line interface. The total particle count is n_obj_x * n_obj_y.
argparser = ArgumentParser(description="Create data for 2D molecular dynamics")
argparser.add_argument("-nx", "--n_obj_x", type=int, help="Number of particles along x (total = n_obj_x * n_obj_y)", required=True)
argparser.add_argument("-ny", "--n_obj_y", type=int, help="Number of particles along y (total = n_obj_x * n_obj_y)", required=True)
argparser.add_argument("-dt", "--delta_t", type=float, help="Time step size (delta t)", required=True)
argparser.add_argument("-T", "--n_steps", type=int, help="Total number of steps", required=True)
argparser.add_argument("-m", "--mass", type=float, help="Mass of a single particle", required=False, default=1.0)
argparser.add_argument("-eps", "--epsilon", type=float, help="Epsilon parameter in Lennard-Jones", required=False, default=1.0)
argparser.add_argument("-sig", "--sigma", type=float, help="Sigma parameter in Lennard-Jones", required=False, default=1.0)
argparser.add_argument("-cutoff", "--cutoff", type=float, help="Cutoff for the Lennard-Jones interaction (default: inf, i.e. no cutoff)", required=False, default=np.inf)
# Anything other than "single" used to silently fall back to double precision;
# restricting the choices makes typos fail loudly instead.
argparser.add_argument("-prec", "--precision", type=str, help="single or double", default="double", choices=["single", "double"])
argparser.add_argument("-q_noise", "--q_noise", type=float, help="Uniformly disturb initial state q", default=0.1)
argparser.add_argument("-p_noise", "--p_noise", type=float, help="Uniformly disturb initial state p", default=0.0)
# The default must be one of the (lowercase) choices: argparse does not check
# defaults against `choices`, and the rest of the script compares with "none".
argparser.add_argument("-bc", "--bc", type=str, help="Boundary condition", default="none", choices=["none", "reflective", "periodic"])
argparser.add_argument("-o", "--outdir", type=str, help="Output directory to save the created data", required=False, default="data")
argparser.add_argument("-e", "--equil", action="store_true", default=False, required=False,
                       help="Equilibrate before collecting data")
argparser.add_argument("-e_dt", "--equil_delta_t", type=float, help="Equilibration time step size (delta t); only used if --equil is specified", required=False, default=5e-3)
argparser.add_argument("-e_T", "--equil_n_steps", type=int, help="Equilibration total number of steps; only used if --equil is specified", required=False, default=10_000)
argparser.add_argument("-ns", "--n_simulations", type=int, help="Number of simulations", required=False, default=1_000)
argparser.add_argument("-fl", "--framing_length", type=int, help="Framing length", required=False, default=250)
args = argparser.parse_args()

# Derived configuration shared by the simulation loop below.
n_obj = args.n_obj_x * args.n_obj_y  # total number of particles
dtype_np = np.float64 if args.precision != "single" else np.float32
r_eq = get_equil_r(args.sigma)

# Common stem for the output files (.mp4 animation and .npy dataset).
filename_prefix = "_".join([
    f"{args.n_simulations}sim",
    f"{n_obj}particles",
    f"{args.n_steps}steps",
    f"{args.delta_t}deltat",
    f"{args.cutoff}cutoff",
])

# Both the animation and the dataset are written into args.outdir; np.save does
# not create missing directories, so make sure it exists up front.
os.makedirs(args.outdir, exist_ok=True)

# With bc == "none" the integrator gets walls pushed effectively to infinity
# and the animation is drawn without a box.
unbounded = args.bc == "none"
far_wall = 10**9

data = []
for sim_idx in range(1, args.n_simulations + 1):
    x0, box_start, box_end = get_x0_and_box(args.n_obj_x, args.n_obj_y, args.sigma, r_eq, args.q_noise, args.p_noise, dtype_np, rng)
    # Box bounds handed to the integrator (same for equilibration and production).
    sim_box_start = -far_wall if unbounded else box_start
    sim_box_end = far_wall if unbounded else box_end

    if args.equil:
        print(f"-> Equilibrating to get an initial condition (--equil specified) equil_delta_t={args.equil_delta_t} and equil_n_steps={args.equil_n_steps}")
        x_traj, _, _, _, _ = simulate(x0, args.mass, args.epsilon, args.sigma, args.cutoff,
                                      args.equil_n_steps, args.equil_delta_t,
                                      bc=args.bc,
                                      box_start=sim_box_start,
                                      box_end=sim_box_end)
        # Start the production run from the last equilibrated state.
        x0 = x_traj[-1]

    print("\n\n" + "-"*40 + "\n\n")
    print(f"-> Simulating {sim_idx}/{args.n_simulations}")
    t0 = time()
    x_traj, ke_traj, pe_traj, h_traj, grad_H_traj = simulate(
        x0, args.mass, args.epsilon, args.sigma, args.cutoff,
        args.n_steps, args.delta_t,
        bc=args.bc,
        box_start=sim_box_start,
        box_end=sim_box_end)
    t1 = time()
    # Columns 2: of the state are reported as momenta p (columns :2 are animated as positions).
    print(f"-> Simulation min p = {np.min(x_traj[..., 2:]):.5f}, max p = {np.max(x_traj[..., 2:]):.5f}")
    print(f"-> Took {(t1 - t0):.2f} seconds")
    data.append(x_traj)

    # Only the first simulation is animated, to visualize an example.
    if sim_idx > 1:
        continue
    t0 = time()
    # NOTE(review): positions are subsampled by framing_length but the energy
    # trajectories are passed in full — presumably animate_2D handles that via
    # its framing_length argument; confirm.
    animate_2D(x_traj[::args.framing_length][..., :2],
               ke_traj, pe_traj, h_traj,
               box_start=None if unbounded else box_start,
               box_end=None if unbounded else box_end,
               delta_t=args.delta_t, framing_length=args.framing_length,
               filename=os.path.join(args.outdir, filename_prefix + ".mp4"))
    t1 = time()
    print(f"-> Took {(t1 - t0):.2f} seconds")

datapath = os.path.join(args.outdir, filename_prefix + ".npy")
if data:
    # Resulting array has a leading axis of length n_simulations.
    np.save(datapath, np.stack(data, axis=0))
    print(f"-> Dataset saved at '{datapath}'")
else:
    # np.stack([]) would raise an opaque "need at least one array" error.
    print("-> No simulations were run (--n_simulations < 1); nothing saved")
