from argparse import ArgumentParser
from time import perf_counter
from time import time

import numpy as np
import torch as tc

from src.data import MassSpring
from src.utils import Mesh, eval_dxdt

argparser = ArgumentParser()
argparser.add_argument("-m", "--model", type=str, help="Path to .pt model trained with any number of nodes", required=True)
argparser.add_argument(
    "-n_obj",
    type=int,
    nargs="+",
    required=True,
    help="One or more object counts; their product is the total number of objects",
)
args = vars(argparser.parse_args())
n_obj = args["n_obj"]
# A zero or negative count would produce an empty / invalid feature dimension
# further down, so reject it up front with a proper CLI error.
if any(n <= 0 for n in n_obj):
    argparser.error(f"-n_obj values must be positive integers, got {n_obj}")

# NOTE(security): weights_only=False unpickles the whole model object, which can
# execute arbitrary code. Only load checkpoints from trusted sources.
model = tc.load(args["model"], weights_only=False)

# Pick the numpy dtype matching the model's declared torch precision, so the
# randomly generated test data is cast to the same precision as the model.
_TORCH_TO_NUMPY_DTYPE = {tc.float32: np.float32, tc.float64: np.float64}
try:
    dtype_np = _TORCH_TO_NUMPY_DTYPE[model.dtype]
except KeyError:
    raise ValueError(f"Unknown precision in the model {model.dtype}") from None

# --- Experiment configuration ---------------------------------------------
# Fixed seed so the random test set is reproducible between runs.
data_seed = 129847
# dof is read from the loaded model; each object contributes 2*dof features
# (a q half and a p half) to the state vector built below.
dof = model.dof
# Number of random test samples and the uniform range they are drawn from.
n_points = 1000
x_min = -0.5
x_max = 0.5

# Physical parameters of the mass-spring system used to build the reference.
mass = 1.0
spring_constant = 1.0
meshing = Mesh(mesh_type="rectangular")

# Prepare data: sample random states and compute the reference derivatives.
rng = np.random.default_rng(data_seed)
# Total feature count: product of the object counts times 2*dof (q and p).
n_features = np.prod(n_obj) * 2 * dof
x = rng.uniform(x_min, x_max, size=(n_points, n_features)).astype(dtype_np)
# First half of the features becomes q, second half becomes p.
q, p = np.split(x, 2, axis=-1)
system = MassSpring(
    n_points, n_features, q, p, n_obj, dof, mass, spring_constant, meshing
)
x_test = tc.from_numpy(system.to_array(flatten=True))
dxdt_test = tc.from_numpy(system.dxdt(flatten=True))
L = tc.from_numpy(system.L())

# Evaluate model using the new edge_index and n_obj: the model may have been
# trained with a different number of nodes, so point it at this system's graph.
model.n_obj = n_obj
model.edge_index = system.edge_index()

# perf_counter is monotonic and high-resolution, unlike time(), which is the
# wall clock and can jump if the system clock is adjusted mid-measurement.
# NOTE(review): no tc.no_grad() here on purpose -- eval_dxdt may rely on
# autograd to compute dx/dt; confirm before wrapping it.
start = perf_counter()
test_mse, test_rel2 = eval_dxdt(model, x_test, L, dxdt_test, verbose=False)
elapsed = perf_counter() - start

print(f"Testing with n_obj={n_obj}")
print(f"- Relative L2:      {test_rel2:.2e}")
print(f"- MSE        :      {test_mse:.2e}")
print(f"Inference took {elapsed:.2f} seconds")
