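"""
Integrate a 2-DOF mass-spring chain system, first with the gradient of the true
Hamiltonian (reference solution) and then with each trained model, saving the
resulting trajectories and energies as .npy files in the experiment directory.
"""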

import os
import sys
from argparse import ArgumentParser
from time import perf_counter, time

import numpy as np
import torch as tc

from src.utils import Mesh, flatten_TC, unflatten_TC, grad
from src.data import MassSpring, grad_H_mass_spring

argparser = ArgumentParser()
# required=True: model_dir is joined into every input and output path below, so a
# missing value must be rejected here instead of crashing later inside os.path.join.
argparser.add_argument("--node_scaling_dir", type=str, required=True, help="Path to the node scaling experiment directory containing trained models.")
argparser.add_argument("--Nx_gnn", type=int, required=True, help="Number of nodes in the training chain system (selects which trained models are loaded).")
argparser.add_argument("--Nx_test", type=int, required=True, help="Number of nodes in the test chain system (for the integration).")
argparser.add_argument("--len_traj", type=int, required=True, help="Trajectory length of the integration")
argparser.add_argument("--delta_t", type=float, required=True, help="Time step size")

args = vars(argparser.parse_args())
Nx_test = args["Nx_test"]
Nx_gnn = args["Nx_gnn"]
model_dir = args["node_scaling_dir"]
len_traj = args["len_traj"]
delta_t = args["delta_t"]

# Fail fast on inputs that would otherwise only crash after the (expensive)
# reference integration has already run.
if Nx_test < 1 or Nx_gnn < 1:
    argparser.error("--Nx_gnn and --Nx_test must be positive integers")
if len_traj < 1:
    # The initial condition is written into row 0 of the trajectory buffers.
    argparser.error("--len_traj must be a positive integer")
if not os.path.isdir(model_dir):
    argparser.error(f"--node_scaling_dir {model_dir!r} is not an existing directory")

# The *-hgn models are evaluated for every test size; the *-hnn models are only
# added when the test chain is not larger than the training chain.
model_names = ["adam-hgn", "elm-rf-hgn", "swim-rf-hgn"]
if Nx_test <= Nx_gnn:
    model_names = ["adam-hnn", "elm-rf-hnn", "swim-rf-hnn"] + model_names

# Models always come from the Nx_gnn training run, so Nx_test != Nx_gnn
# demonstrates zero-shot generalization to a different chain size.
model_paths = [os.path.join(model_dir, f"{model_name}_Nx{Nx_gnn}.pt") for model_name in model_names]
missing_models = [p for p in model_paths if not os.path.isfile(p)]
if missing_models:
    argparser.error("missing trained model file(s): " + ", ".join(missing_models))

data_seed = 1945615
dof = 2
q_min, q_max = -0.4, 0.4                    # should be small enough to keep data inside training range
p_min, p_max = 0.0, 0.0                     # degenerate interval: every initial momentum is 0
plotx = np.arange(1, len_traj + 1)

dtype_np = np.float32
mass, spring_constant = 1.0, 1.0
meshing = Mesh(mesh_type="rectangular")
# Total state size: dof position and dof momentum components for every node.
n_features = Nx_test * 2 * dof

def prepare_data(n_obj, data_seed):
    """Build a ``MassSpring`` system seeded with a random initial condition.

    Args:
        n_obj: Forwarded unchanged to ``MassSpring(n_obj=...)``; this script
            passes ``[Nx_test]``.
        data_seed: Seed for ``np.random.default_rng``. The same seed always
            produces the same initial condition.

    Returns:
        A ``MassSpring`` with ``n_points=len_traj``. Row 0 of its ``q``/``p``
        buffers holds the initial condition. The other rows are allocated with
        ``np.empty`` and left uninitialised, presumably to be filled by
        ``system.integrate``.
    """
    half = n_features // 2
    generator = np.random.default_rng(data_seed)

    # Trajectory buffers, one row per time step (positions, then momenta).
    buffer_shape = (len_traj, half)
    q_traj = np.empty(buffer_shape, dtype=np.float32)
    p_traj = np.empty(buffer_shape, dtype=np.float32)

    # Draw positions before momenta so the random stream is consumed in a fixed order.
    q_init = generator.uniform(q_min, q_max, size=half).astype(np.float32)
    p_init = generator.uniform(p_min, p_max, size=half).astype(np.float32)
    q_traj[0], p_traj[0] = q_init, p_init

    return MassSpring(
        n_points=len_traj,
        n_features=n_features,
        q=q_traj,
        p=p_traj,
        n_obj=n_obj,
        dof=dof,
        meshing=meshing,
    )

def integrate(system, integration_grad_H):
    """Integrate ``system`` with Stormer-Verlet and print summary statistics.

    Args:
        system: System object (as built by ``prepare_data``). Its ``integrate``,
            ``to_array`` and ``H`` methods are used.
        integration_grad_H: Callable passed to ``system.integrate`` as the
            Hamiltonian gradient (the analytic one or a learned model's).

    Returns:
        ``(qx, qy, hamiltonian)``: components 0 and 1 of the last axis of
        ``system.to_array()``, and the values returned by ``system.H()``.
    """
    # perf_counter is monotonic and high-resolution, unlike wall-clock time(),
    # so the reported duration cannot jump with system clock adjustments.
    start = perf_counter()
    system.integrate(len_traj, delta_t, integration_grad_H, integration_method="stormer_verlet")
    print(f"Integration took: {(perf_counter() - start):.2f} seconds")

    x = system.to_array()
    # Assumes the last axis is laid out as [q (dof components), p (dof components)]
    # per node; TODO confirm against MassSpring.to_array.
    print(f"q limits: [{np.min(x[..., :dof])}, {np.max(x[..., :dof])}]")
    print(f"p limits: [{np.min(x[..., dof:])}, {np.max(x[..., dof:])}]")
    qx, qy = x[..., 0], x[..., 1]

    hamiltonian = system.H()
    print(f"hamiltonian limits: [{np.min(hamiltonian):.2f}, {np.max(hamiltonian):.2f}]")
    return qx, qy, hamiltonian

# Reference solution: integrate the test chain with the analytic gradient of the
# true Hamiltonian, then store positions and energy next to the trained models.
system = prepare_data([Nx_test], data_seed)


def integration_grad_H(x, _sys=system):
    """Analytic dH/dx of the chain, using ``_sys``'s dof, meshing, mass and spring constant.

    ``_sys`` is bound at definition time, so later rebinding of the global
    ``system`` (the model loop below does this) cannot change what is evaluated.
    """
    return grad_H_mass_spring(x, _sys.dof, _sys.meshing, _sys.mass, _sys.spring_constant, "chain")


qx_true, qy_true, hamil_true = integrate(system, integration_grad_H)

reference_outputs = (
    (qx_true, os.path.join(model_dir, f"qx_true_{Nx_test}_{len_traj}_{delta_t:.2e}.npy")),
    (qy_true, os.path.join(model_dir, f"qy_true_{Nx_test}_{len_traj}_{delta_t:.2e}.npy")),
    (hamil_true, os.path.join(model_dir, f"energy_true_{Nx_test}_{len_traj}_{delta_t:.2e}.npy")),
)
for data, path in reference_outputs:
    with open(path, "wb") as f:
        np.save(f, data)
        print(f"-> saved data at {path}")

for model_path, model_name in zip(model_paths, model_names):
    # Integrate the same initial condition (same data_seed) with the trained model's gradient.
    system = prepare_data([Nx_test], data_seed)
    # NOTE(security): weights_only=False unpickles arbitrary Python objects; only
    # load model files from a trusted source.
    # map_location="cpu": the gradient wrapper below feeds a CPU tensor
    # (tc.from_numpy) and converts the result with .numpy(), so the model must be
    # on the CPU even if the checkpoint was saved from a GPU run.
    model = tc.load(model_path, map_location="cpu", weights_only=False)
    # Re-target the model to the test chain size and its graph connectivity.
    model.n_obj = [Nx_test]
    model.edge_index = system.edge_index()

    def model_grad_H(x, _model=model):
        """Learned Hamiltonian gradient at state ``x`` (NumPy array in, NumPy array out)."""
        x_flat = flatten_TC(tc.from_numpy(x))
        return unflatten_TC(grad(_model, x_flat), _model.n_obj, dof).detach().numpy()

    print("-> Integrating", model_name)
    qx_pred, qy_pred, hamil_pred = integrate(system, model_grad_H)
    qx_pred_path = os.path.join(model_dir, f"qx_{Nx_test}_{model_name}_{len_traj}_{delta_t:.2e}.npy")
    qy_pred_path = os.path.join(model_dir, f"qy_{Nx_test}_{model_name}_{len_traj}_{delta_t:.2e}.npy")
    hamil_pred_path = os.path.join(model_dir, f"energy_{Nx_test}_{model_name}_{len_traj}_{delta_t:.2e}.npy")
    for data, path in zip([qx_pred, qy_pred, hamil_pred], [qx_pred_path, qy_pred_path, hamil_pred_path]):
        with open(path, "wb") as f:
            np.save(f, data)
            print(f"-> saved data at {path}")
    # Squared position error summed over the x and y components, averaged over
    # axis 1 (presumably the node axis -- TODO confirm), then over all time steps.
    mse = np.mean((qx_true - qx_pred)**2 + (qy_true - qy_pred)**2, axis=1)
    print(f"=trajectory mean of all (q) MSE: {mse.mean():.2e}")
# sys.exit rather than the site-provided exit(), which is absent under `python -S`.
sys.exit(0)
