from ase import io, units
from ase.io import Trajectory
from ase.optimize import BFGS
from ase.md.langevin import Langevin
from ase.md.verlet import VelocityVerlet
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution

import numpy as np
import sys

from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator

# Parse the command line before loading the model, so a bad invocation fails
# fast with a usage message instead of an IndexError after the CUDA model load.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} STRUCTURE_INDEX RUN_INDEX")
str_idx = int(sys.argv[1])  # frame index read from the input xyz file
run_idx = int(sys.argv[2])  # used here only to tag the output trajectory filenames

calc = MetatensorCalculator("pet-mad-v1.5.pt", extensions_directory="extensions", device="cuda")

# Load the initial structure (adjust filename as needed)
atoms = io.read('rand_strucs_new_10000.xyz', str_idx)
atoms.calc = calc

# Randomly select a temperature in [0, 1500) K (numpy's uniform excludes `high`)
target_T = np.random.uniform(low=0.0, high=1500.0)

print(f"Structure index {str_idx}:", atoms)
print("Temperature [K]:", target_T)

# OPTIMIZE GEOMETRY (CELL FIXED)
# Only atomic positions are relaxed; BFGS is given the Atoms object directly,
# so the cell is not a degree of freedom. The optimized structure is not
# written to a separate file here.
fmax_target = 0.01  # force convergence threshold (ASE force units, eV/Å)
opt = BFGS(atoms)
opt.run(fmax=fmax_target, steps=1000)
# BFGS simply stops after `steps` iterations even if unconverged, so check the
# final maximum per-atom force explicitly and flag poorly relaxed structures.
# (The calculator caches results for the final positions, so this is cheap.)
max_force = np.sqrt((atoms.get_forces() ** 2).sum(axis=1)).max()
if max_force > fmax_target:
    print(
        f"WARNING: geometry optimization not converged "
        f"(max force {max_force:.4f} > fmax {fmax_target})"
    )
print("Geometry optimization complete. *****")

# EQUILIBRATE WITH NVT
# Draw initial velocities from a Maxwell-Boltzmann distribution at target_T,
# then thermalize with a Langevin thermostat at the same temperature.
MaxwellBoltzmannDistribution(atoms, temperature_K=target_T)
dyn_nvt = Langevin(atoms,
        timestep=0.5 * units.fs,
        temperature_K=target_T,
        friction=0.01 / units.fs,
        logfile='-',
        loginterval=10,
)
# Save every 20th step. The context manager closes the writer when the run
# finishes (or raises), so the trajectory file is finalized on disk.
with Trajectory(f'str{str_idx}_NVT_{run_idx}.traj', 'w', atoms) as nvt_traj:
    dyn_nvt.attach(nvt_traj.write, interval=20)
    dyn_nvt.run(steps=20000)

print("NVT equilibration complete. *****")

# PRODUCTION WITH NVE
# Velocity Verlet integration (no thermostat) continuing from the state the
# NVT stage left in `atoms`. The timestep is half the NVT one, and a log line
# is written every step.
dyn_nve = VelocityVerlet(
    atoms,
    timestep=0.25 * units.fs,
    logfile='-',
    loginterval=1,
)

# Output file for the production frames; frames are appended by the
# write_energy observer.
nve_traj = Trajectory(
    f'str{str_idx}_NVE_{run_idx}.traj',
    'w',
    atoms,
)
     
# Observer for the NVE run: tag the current frame with its energies and
# append it to the NVE trajectory.
def write_energy():
    """Store potential, kinetic and total energy in ``atoms.info`` and write the frame.

    Takes no arguments because it is attached as a dynamics observer; it
    operates on the module-level ``atoms`` and ``nve_traj`` objects.
    """
    atoms.get_forces()  # make sure the calculator has evaluated this configuration
    e_pot = atoms.get_potential_energy()
    e_kin = atoms.get_kinetic_energy()
    atoms.info.update(
        potential_energy=e_pot,
        kinetic_energy=e_kin,
        total_energy=e_pot + e_kin,
    )
    nve_traj.write(atoms)

# Record one energy-annotated frame per NVE step. Close the trajectory
# afterwards, even if the run raises, so the output file is finalized on disk.
dyn_nve.attach(write_energy, interval=1)
try:
    dyn_nve.run(steps=10000)
finally:
    nve_traj.close()

print("NVE production complete. *****")


