from ipi.utils.scripting import InteractiveSimulation
from metatensor.torch.atomistic import load_atomistic_model
from skipmd.ipi import get_skipmd_velocity_verlet_step
import torch
from ipi.utils.depend import dstrip
from ipi.utils.units import Constants
import numpy as np


# Model family and skipMD step length; delta_t (in fs, per the "{delta_t}fs"
# checkpoint naming below) also sets the run length at the end of the script.
model_name = "universal"
delta_t = 1

print(model_name, delta_t)

# Build the i-PI simulation from the XML input (thermostat, barostat, cell, ...).
with open("input_skipmd.xml", "r") as input_xml:
    sim = InteractiveSimulation(input_xml)

# Checkpoints follow the naming scheme <model_name>_<delta_t>fs_final.pt.
model = load_atomistic_model(f"../../../models/{model_name}_{delta_t}fs_final.pt")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Callable that performs the skipMD replacement of the velocity-Verlet core;
# used by the custom motion step defined below.
skipmd_velocity_verlet_step = get_skipmd_velocity_verlet_step(sim, model, device)

# NOTE(review): nothing here modifies the barostat timesteps. `baro.qdt` and
# `baro.pdt` keep whatever values i-PI derives from input_skipmd.xml. TODO
# confirm they match the dual (qbaro/pbaro applied twice per step) splitting.
# `system` is not referenced elsewhere in this script; presumably it is kept
# as a handle for interactive inspection.
system = sim.simulation.syslist[0]

def pbaro(baro):
    """Kick the barostat momentum ``baro.p`` using the level-0 stress.

    The kick has length ``baro.pdt[0]``. Only the innermost MTS level (0) is
    considered. This presumably assumes the coupling between the barostat and
    the particle momenta involves only the fastest force. The increment is::

        3 * pdt[0] * (V * (P_int - nbeads * pext) + kb * T)

    where ``P_int`` is one third of the trace of ``baro.stress_mts(0)``.

    Args:
        baro: barostat object exposing ``pdt``, ``stress_mts``, ``beads``,
            ``cell``, ``pext``, ``temp`` and a mutable ``p``.
    """
    step = baro.pdt[0]

    # Isotropic internal pressure: mean of the diagonal of the level-0 stress.
    pressure = np.trace(baro.stress_mts(0)) / 3.0
    n_beads = baro.beads.nbeads

    # Pressure mismatch weighted by the cell volume, plus the thermal k_B*T term.
    volume_term = baro.cell.V * (pressure - n_beads * baro.pext)
    thermal_term = Constants.kb * baro.temp

    # In-place update so the barostat's own momentum storage is modified.
    baro.p += 3.0 * step * (volume_term + thermal_term)

def qbaro(baro):
    """Propagation step for the cell (adjusting atomic positions and momenta).

    Applies the isotropic dilation generated by the barostat velocity
    ``v = baro.p[0] / baro.m[0]`` over a time ``baro.qdt``:

    * row 0 of ``baro.nm.qnm`` (presumably the centroid mode) and the cell
      matrix ``baro.cell.h`` are scaled by ``exp(v * qdt)``;
    * row 0 of ``baro.nm.pnm`` is scaled by ``exp(-v * qdt)``.

    NOTE: no free-particle drift (``p / m``) is added to the positions here.
    Presumably the skipMD step is responsible for the position update, which
    is why no per-atom masses are needed.

    Args:
        baro: barostat object exposing ``p``, ``m``, ``qdt``, ``nm`` (with
            ``qnm``/``pnm`` arrays) and ``cell`` (with matrix ``h``). All
            arrays are modified in place.
    """
    v = baro.p[0] / baro.m[0]
    # By i-PI convention this is half of the inner-loop timestep in integrators
    # that use a barostat -- TODO confirm for this custom motion step.
    halfdt = baro.qdt
    expq = np.exp(v * halfdt)
    expp = np.exp(-v * halfdt)

    baro.nm.qnm[0, :] *= expq
    baro.nm.pnm[0, :] *= expp
    baro.cell.h *= expq

def skipmd_opbabpo(self, **_):
    """Custom NPT motion step installed via ``sim.set_motion_step``.

    The sequence is mirror-symmetric around the skipMD step:

    1. particle thermostat, barostat thermostat, constraints;
    2. barostat momentum kick (``pbaro``), then cell/coordinate scaling (``qbaro``);
    3. the skipMD step (with energy rescaling and random rotation enabled);
    4. the same operations as in 1-2, in reverse order.

    Finally the ensemble time is advanced by ``self.dt``. The statement order
    defines the splitting, so do not reorder it.

    Args:
        self: presumably the i-PI motion (dynamics) object. It must expose
            ``thermostat``, ``barostat``, ``integrator``, ``ensemble`` and ``dt``.
        **_: any extra keyword arguments from the caller are ignored.
    """

    # First half: thermostats (particles, then barostat), then constraints.
    self.thermostat.step()
    self.barostat.thermostat.step()
    self.integrator.pconstraints()

    # Barostat momentum kick, then dilation of cell, positions and momenta.
    pbaro(self.barostat)
    qbaro(self.barostat)

    # skipMD replacement of the velocity-Verlet core (model-driven update).
    skipmd_velocity_verlet_step(self, rescale_energy=True, random_rotation=True)

    # Mirror of the barostat half: dilation first, then the momentum kick.
    qbaro(self.barostat)
    pbaro(self.barostat)

    # Mirror of the thermostat half: barostat thermostat, then particles.
    self.barostat.thermostat.step()
    self.thermostat.step()
    self.integrator.pconstraints()

    # The custom step must advance the simulation clock itself.
    self.ensemble.time += self.dt

# Replace i-PI's built-in integration step with the skipMD NPT sequence.
sim.set_motion_step(skipmd_opbabpo)
# 100_000 fs = 100 ps, assuming the timestep in input_skipmd.xml equals delta_t fs.
sim.run(100_000 // delta_t)
