from ipi.utils.scripting import InteractiveSimulation
from ipi.utils.depend import dstrip
from ipi.utils.units import Constants
import numpy as np


# Build the interactive simulation from the i-PI XML input. The encoding is
# explicit so decoding does not depend on the platform locale.
with open("input_petmad.xml", "r", encoding="utf-8") as input_xml:
    sim = InteractiveSimulation(input_xml)

# NOTE(review): an earlier comment here claimed the barostat q-step timestep
# is set equal to the p-step one ("dual split"), but nothing in this script
# does so: qbaro uses baro.qdt and pbaro uses baro.pdt[0] exactly as the
# integrator provides them. Confirm the two are equal for this input
# (presumably the case with a single MTS level).
# Handle to the first system; the step functions below do not use it.
system = sim.simulation.syslist[0]

def pbaro(baro):
    """Kick the barostat momentum ``baro.p`` by one ``baro.pdt[0]`` step.

    The increment is ``3 * dt * (V * (P_int - nbeads * P_ext) + kB * T)``,
    where ``dt = baro.pdt[0]``, ``V = baro.cell.V``, ``P_int`` is one third
    of the trace of the level-0 stress ``baro.stress_mts(0)``,
    ``P_ext = baro.pext`` and ``T = baro.temp``. Only the level-0 stress
    enters; the force coupling is assumed to come from that level alone.

    Args:
        baro: barostat object whose ``p`` attribute is incremented.
    """
    step = baro.pdt[0]

    # Isotropic internal pressure from the level-0 stress tensor.
    internal_pressure = np.trace(baro.stress_mts(0)) / 3.0
    n_beads = baro.beads.nbeads

    # Pressure imbalance times volume, plus the kinetic (kB*T) contribution.
    pv_work = baro.cell.V * (internal_pressure - n_beads * baro.pext)
    thermal_term = Constants.kb * baro.temp

    baro.p += 3.0 * step * (pv_work + thermal_term)

def qbaro(baro):
    """Scale the centroid and the cell by the current barostat velocity.

    With ``v = baro.p[0] / baro.m[0]`` and ``dt = baro.qdt``, this multiplies
    the centroid normal-mode positions ``baro.nm.qnm[0]`` and the cell matrix
    ``baro.cell.h`` by ``exp(v * dt)``, and the centroid normal-mode momenta
    ``baro.nm.pnm[0]`` by ``exp(-v * dt)``. Other normal modes are untouched.

    Unlike a full barostat position step, no free-particle drift of the
    positions is applied here. ``opbabpo`` advances the positions separately
    with its own velocity-Verlet update, so this function needs no masses.

    Args:
        baro: barostat object; its ``nm.qnm``, ``nm.pnm`` and ``cell.h``
            arrays are updated.
    """
    v = baro.p[0] / baro.m[0]
    # Presumably set by the integrator to half of the (inner) step -- confirm.
    halfdt = baro.qdt
    expq, expp = np.exp(v * halfdt), np.exp(-v * halfdt)

    baro.nm.qnm[0, :] *= expq
    baro.nm.pnm[0, :] *= expp
    baro.cell.h *= expq

def opbabpo(self, **_):
    """Advance the simulation by one time step ``self.dt`` (custom NPT step).

    Registered with ``sim.set_motion_step`` in this script, so ``self`` is
    presumably the i-PI motion object being propagated (confirm against
    ``InteractiveSimulation``). Keyword arguments are accepted and ignored.

    The step is palindromic around an atomic kick-drift-kick:
    thermostats -> barostat momentum (``pbaro``) -> cell/centroid scaling
    (``qbaro``) -> atomic velocity Verlet -> ``qbaro`` -> ``pbaro`` ->
    thermostats. Statement order is part of the algorithm.
    """
    # Opening thermostat half: particles first, then the barostat's own
    # thermostat; the closing half below runs them in the opposite order.
    self.thermostat.step()
    self.barostat.thermostat.step()
    self.integrator.pconstraints()

    # Barostat momentum kick, then scaling of cell and centroid q/p.
    pbaro(self.barostat)
    qbaro(self.barostat)

    # Atomic velocity Verlet over the full self.dt: kick, drift, kick.
    # NOTE(review): the drift moves every bead by p/m only, with no
    # ring-polymer spring propagation, so this presumably assumes
    # nbeads == 1 -- confirm.
    self.beads.p[:] += dstrip(self.forces.f) * self.dt * 0.5
    self.beads.q[:] += dstrip(self.beads.p) / dstrip(self.beads.m3) * self.dt
    # Forces are read again after q moved; presumably i-PI's dependency
    # tracking recomputes them at the new positions -- confirm.
    self.beads.p[:] += dstrip(self.forces.f) * self.dt * 0.5

    # Mirror image of the barostat operations above.
    qbaro(self.barostat)
    pbaro(self.barostat)

    # Closing thermostat half (reverse order of the opening half).
    self.barostat.thermostat.step()
    self.thermostat.step()
    self.integrator.pconstraints()

    self.ensemble.time += self.dt

# Register opbabpo as the motion step, then propagate.
sim.set_motion_step(opbabpo)
# ~100 ps if input_petmad.xml sets a 0.25 fs timestep -- confirm.
n_steps = 400_000
sim.run(n_steps)
