import ase.io
import sys
from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator
import ase.units
import ase.md
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from ase.io.trajectory import Trajectory


# Command line: <size tag> <temperature>. The temperature is passed on below
# as a value in kelvin.
size, temperature = sys.argv[1], float(sys.argv[2])

# The "normal" structure file carries no suffix; any other size tag becomes
# "_<tag>". From here on `size` holds that filename suffix, not the raw tag.
size = "" if size == "normal" else f"_{size}"

# Load the starting structure and attach the model as its calculator.
atoms = ase.io.read(f"../structures/water{size}.xyz")
atoms.calc = calculator = MetatensorCalculator(
    "pet-mad-v1.0.1.pt",
    device="cuda",
)

# Equilibration: draw initial velocities from a Maxwell-Boltzmann distribution
# at the target temperature, then run Langevin dynamics at that temperature.
MaxwellBoltzmannDistribution(atoms, temperature_K=temperature)


def print_temperature():
    """Print the current temperature of ``atoms``."""
    print(f"Temperature: {atoms.get_temperature()} K")


dyn = ase.md.Langevin(
    atoms,
    timestep=0.5 * ase.units.fs,
    temperature_K=temperature,
    friction=1.0 / (10.0 * ase.units.fs),
)
# Report the temperature at an interval of 100 steps over the 10000-step run.
dyn.attach(print_temperature, interval=100)
dyn.run(10000)

# Production run: NVE dynamics with velocity Verlet, writing frames to an ASE
# trajectory file named after the size suffix and temperature.
traj = Trajectory(f"water{size}_{temperature}.traj", "w", atoms)


def save_atoms():
    """Store the current total energy in ``atoms.info`` and write the frame."""
    atoms.info["total_energy"] = atoms.get_total_energy()
    traj.write(atoms)


dyn = ase.md.VelocityVerlet(atoms, timestep=0.25 * ase.units.fs)
# Attached with the default interval, so the callback fires on every step.
dyn.attach(save_atoms)
try:
    dyn.run(8000)
finally:
    # Close the writer even if the run is interrupted or the calculator
    # fails, so the file handle is released and the trajectory is finalized.
    traj.close()
