from ipi.utils.scripting import InteractiveSimulation
from metatensor.torch.atomistic import load_atomistic_model
from skipmd.ipi import get_skipmd_velocity_verlet_step
import torch


# Run configuration: pre-trained model family, its time step (the file name
# below labels it in fs), and whether the SkipMD step rescales the energy.
model_name = "universal"
delta_t = 1
rescale_energy = False

print(model_name, delta_t, rescale_energy)

# Build the interactive i-PI simulation from the XML input.
# NOTE(review): this path is relative to the current working directory.
with open("input_skipmd.xml") as input_xml:
    sim = InteractiveSimulation(input_xml)

# The model file name encodes both the model family and delta_t.
# NOTE(review): presumably the XML time step must match delta_t — confirm.
model_path = f"../../../models/{model_name}_{delta_t}fs_final.pt"
model = load_atomistic_model(model_path)

# Use the GPU whenever torch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Callable that performs one SkipMD velocity-Verlet update; used by skipmd_step.
skipmd_velocity_verlet_step = get_skipmd_velocity_verlet_step(sim, model, device)

def skipmd_step(self, **_):
    """Advance the system by one SkipMD time step.

    Installed through ``sim.set_motion_step``. ``self`` is presumably the
    i-PI motion object; it must expose ``thermostat``, ``integrator``,
    ``ensemble`` and ``dt``. Any extra keyword arguments are ignored.

    Sequence: thermostat + constraints, one SkipMD velocity-Verlet update,
    thermostat + constraints again, then advance the ensemble clock by ``dt``.
    """

    def thermostat_then_constrain():
        # A thermostat application is always followed by re-applying the
        # momentum constraints.
        self.thermostat.step()
        self.integrator.pconstraints()

    thermostat_then_constrain()
    skipmd_velocity_verlet_step(
        self,
        random_rotation=True,
        rescale_energy=rescale_energy,
    )
    thermostat_then_constrain()
    self.ensemble.time += self.dt

# Swap i-PI's default motion step for the SkipMD one, then integrate.
sim.set_motion_step(skipmd_step)
# 100000 fs (100 ps) of simulated time, split into steps of delta_t fs each.
n_steps = 100000 // delta_t
sim.run(n_steps)
