import ase.io
from metatensor.torch.atomistic import load_atomistic_model
import numpy as np
import sys
from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator
import matplotlib.pyplot as plt
from skipmd.ase.bussi import Bussi
import ase.units


# Command-line flag forwarded to the Bussi integrator as ``rescale_energy``.
# Accept "True"/"False" case-insensitively; anything else (or a missing
# argument) previously fell through silently to False or raised IndexError.
if len(sys.argv) < 2 or sys.argv[1].lower() not in ("true", "false"):
    sys.exit(f"usage: {sys.argv[0]} True|False  (rescale_energy)")
rescale_energy = sys.argv[1].lower() == "true"

model_path = "../../models/water_4fs.pt"
# Strip the ".pt" extension only if it is actually there.
model_path_no_pt = (
    model_path[: -len(".pt")] if model_path.endswith(".pt") else model_path
)
# The integration time step is encoded in the model file name as
# "<name>_<N>fs", e.g. "water_4fs.pt" -> 4 (fs). Parse the basename only, so
# underscores in directory names cannot interfere.
model_name = model_path_no_pt.rsplit("/", 1)[-1]
time_lag = int(model_name.split("_")[-1].split("fs")[0])

# Single source of truth for the compute device, used by both the calculator
# and the model below (previously the calculator hard-coded its own "cuda").
device = "cuda"

# Read the starting configuration.
atoms = ase.io.read("water.xyz")

# Calculator attached to ``atoms``: the energies recorded during the run
# (get_potential_energy / get_total_energy) come from this model.
calculator = MetatensorCalculator("../../models/pet-mad-v1.0.1.pt", device=device)
atoms.calc = calculator

model = load_atomistic_model(model_path)
model = model.to(device)

# Bussi thermostat integrator. The time step is the lag parsed from the model
# file name. NOTE(review): 450.0 is presumably the target temperature in K —
# confirm against skipmd's Bussi signature.
dyn = Bussi(
    atoms,
    time_lag * ase.units.fs,
    450.0,
    model,
    time_constant=10 * ase.units.fs,
    device=device,
    rescale_energy=rescale_energy,
)

n_steps = 1000  # number of integrator steps passed to dyn.run

# Seed each time series with the configuration before any MD step.
temperatures, potential_energies, kinetic_energies, total_energies = (
    [atoms.get_temperature()],
    [atoms.get_potential_energy()],
    [atoms.get_kinetic_energy()],
    [atoms.get_total_energy()],
)
def get_energies():
    """Append the current energies and temperature of ``atoms`` to the series.

    Attached to ``dyn`` as an observer. The pre-run state is already stored
    when the lists are created, so a call made while ``dyn.nsteps == 0`` only
    prints progress and records nothing. NOTE(review): recent ASE versions
    invoke observers once before the first step, which previously duplicated
    the first sample; the guard is a no-op if the integrator does not do that.
    """
    step = dyn.nsteps
    if step % 100 == 0:
        print(f"step {step}")
    if step == 0:
        # Initial state was recorded explicitly before dyn.run().
        return
    potential = atoms.get_potential_energy()
    kinetic = atoms.get_kinetic_energy()
    potential_energies.append(potential)
    kinetic_energies.append(kinetic)
    # Total energy is potential + kinetic; reuse the values fetched above.
    total_energies.append(potential + kinetic)
    temperatures.append(atoms.get_temperature())

# Record observables after every step, then run the dynamics.
dyn.attach(get_energies)
dyn.run(n_steps)

# Turn the recorded lists into arrays (rebinding the module-level names).
potential_energies, kinetic_energies, total_energies, temperatures = (
    np.array(series)
    for series in (potential_energies, kinetic_energies, total_energies, temperatures)
)

# Persist each series; the file name records whether rescaling was enabled.
outputs = (
    (f"potential_energies_{rescale_energy}_nvt.npy", potential_energies),
    (f"kinetic_energies_{rescale_energy}_nvt.npy", kinetic_energies),
    (f"total_energies_{rescale_energy}_nvt.npy", total_energies),
    (f"temperatures_{rescale_energy}_nvt.npy", temperatures),
)
for filename, data in outputs:
    np.save(filename, data)

