from skipmd.utils.energy_error import get_energy_error
import sys
import numpy as np
import matplotlib.pyplot as plt


# Inputs that are fixed for this experiment; kept as named constants so they
# are easy to find and change.
TRAJECTORY_PATH = "trj_pet_mad_nvt.xyz"
REFERENCE_MODEL_PATH = "../../models/pet-mad-v1.0.1.pt"


def model_tag(model_name):
    """Return the short tag used to name the output file for *model_name*.

    The tag is the last ``_``-separated piece of the second-to-last
    ``.``-separated component of the path, e.g. ``"models/skip_v3.pt"`` ->
    ``"v3"``. If the name contains no ``.``, the whole name is used in place
    of that component instead of raising ``IndexError``.
    """
    parts = model_name.split(".")
    stem = parts[-2] if len(parts) >= 2 else parts[0]
    return stem.split("_")[-1]


def main(argv=None):
    """Compare MD and SkipMD energies for one model, report RMSE, save arrays.

    Usage: ``python <script> MODEL_PATH``

    Calls ``get_energy_error(TRAJECTORY_PATH, REFERENCE_MODEL_PATH, MODEL_PATH)``.
    It presumably returns two equal-length array-likes of energies (MD and
    SkipMD); verify against ``skipmd.utils.energy_error``. Both arrays are
    stacked and saved to ``energy_errors_<tag>.npy``, where ``<tag>`` comes
    from :func:`model_tag`.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) != 2:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit(f"usage: {argv[0]} MODEL_PATH")
    model_name = argv[1]

    energies_md, energies_skipmd = get_energy_error(
        TRAJECTORY_PATH, REFERENCE_MODEL_PATH, model_name
    )
    energies_md = np.asarray(energies_md)
    energies_skipmd = np.asarray(energies_skipmd)

    # Previously computed and silently discarded; report it so the run
    # produces a visible summary. Units are whatever get_energy_error returns.
    rmse = np.sqrt(np.mean((energies_md - energies_skipmd) ** 2))
    print(f"Energy RMSE ({model_name}): {rmse:.6g}")

    np.save(
        f"energy_errors_{model_tag(model_name)}.npy",
        np.array([energies_md, energies_skipmd]),
    )


if __name__ == "__main__":
    main()
