import argparse
import numpy as np
import os
from ase.io import read, Trajectory
from ase import Atoms
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from ase.md.nose_hoover_chain import NoseHooverChainNVT
from ase.units import fs, kB
import ase.units as units
from tqdm import tqdm
from mace.calculators import MACECalculator
from rdkit import Chem

def _build_parser():
    """Create the command-line parser for the MD run and stability analysis."""
    parser = argparse.ArgumentParser(description="Run MD simulation with MLFF model and calculate stability metric.")
    parser.add_argument('--xyz_file_path', type=str, required=True, help='Path to the XYZ file.')
    parser.add_argument('--output_path', type=str, required=True, help='Output directory path.')
    parser.add_argument('--checkpoint_path', type=str, required=True, help='Path to the MACE model checkpoint.')
    parser.add_argument('--smiles', type=str, required=True, help='SMILES string for the molecule.')
    parser.add_argument('--temperature', type=float, default=300.0, help='Temperature in Kelvin.')
    parser.add_argument('--timestep', type=float, default=1.0, help='Timestep in femtoseconds.')
    parser.add_argument('--total_steps', type=int, default=100000, help='Total simulation steps.')
    parser.add_argument('--log_interval', type=int, default=100, help='Logging interval in steps.')
    parser.add_argument('--tdamp', type=float, default=100.0, help='Damping time for Nose-Hoover in fs.')
    parser.add_argument('--first_n', type=int, default=-1, help='Number of initial frames for equilibrium bond lengths (-1 for auto).')
    parser.add_argument('--device', type=str, default='cuda', help='Device passed to the MACE calculator.')
    return parser


def _bonds_from_smiles(smiles, symbols):
    """Return the bonded atom-index pairs of the hydrogen-complete SMILES molecule.

    Args:
        smiles: SMILES string describing the simulated molecule.
        symbols: Chemical symbols of the simulated structure, in file order.

    Returns:
        List of ``(begin_index, end_index)`` tuples, one per RDKit bond.

    Raises:
        ValueError: If the SMILES cannot be parsed, or if the atom count or
            the per-index elements differ from ``symbols``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")
    mol = Chem.AddHs(mol)  # Add hydrogens if implicit

    if len(symbols) != mol.GetNumAtoms():
        raise ValueError("Number of atoms in trajectory does not match the molecule.")

    # The bond indices are applied directly to the XYZ atoms, so the two
    # orderings must agree; comparing elements catches most mismatches.
    rdkit_symbols = [atom.GetSymbol() for atom in mol.GetAtoms()]
    if list(symbols) != rdkit_symbols:
        raise ValueError(
            "Element order of the XYZ structure does not match the atom order "
            "of the SMILES-derived molecule; bond indices would be wrong."
        )

    return [(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()) for bond in mol.GetBonds()]


def _run_md(molecule, args):
    """Run the Nose-Hoover chain NVT simulation and return the trajectory path.

    Writes ``md_energies.txt``, ``md_nh.log`` and ``md_nh.traj`` into
    ``args.output_path``. The dynamics object is closed even if the run is
    interrupted, so its log and trajectory files are released before the
    trajectory is read back.
    """
    timestep_ase = args.timestep * units.fs  # Timestep in ASE units
    tdamp_ase = args.tdamp * units.fs  # tdamp in ASE units

    energy_file = os.path.join(args.output_path, 'md_energies.txt')
    traj_file = os.path.join(args.output_path, 'md_nh.traj')

    with open(energy_file, 'w') as f:
        f.write('# Step\tTime(ps)\tPotential_E(eV)\tKinetic_E(eV)\tTotal_E(eV)\tTemperature(K)\n')

    # Initialize velocities
    MaxwellBoltzmannDistribution(molecule, temperature_K=args.temperature)

    dyn = NoseHooverChainNVT(molecule, timestep=timestep_ase, temperature_K=args.temperature, tdamp=tdamp_ase,
                             logfile=os.path.join(args.output_path, 'md_nh.log'), trajectory=traj_file)

    def log_energies():
        """Append one row of energies and temperature for the current step."""
        step = dyn.get_number_of_steps()
        time_ps = (step * args.timestep) / 1000.0  # Time in picoseconds
        pot_energy = molecule.get_potential_energy()
        kin_energy = molecule.get_kinetic_energy()
        total_energy = pot_energy + kin_energy
        temp = kin_energy / (1.5 * kB * len(molecule))  # Temperature from kinetic energy
        # Re-open in append mode per row so the rows written so far survive a crash.
        with open(energy_file, 'a') as f:
            f.write(f'{step}\t{time_ps:.3f}\t{pot_energy:.6f}\t{kin_energy:.6f}\t{total_energy:.6f}\t{temp:.2f}\n')

    dyn.attach(log_energies, interval=args.log_interval)

    try:
        # One step per call so tqdm can show per-step progress.
        for _ in tqdm(range(args.total_steps)):
            dyn.run(1)
    finally:
        dyn.close()

    return traj_file


def _first_unstable_frame(frame_positions, bonds, first_n, threshold=0.5):
    """Return the index of the first frame with a bond deviating beyond ``threshold``.

    Reference bond lengths are the per-bond means over the first ``first_n``
    frames. If no frame exceeds the threshold (or there are no bonds), the
    number of frames is returned.

    Args:
        frame_positions: Sequence of ``(n_atoms, 3)`` position arrays.
        bonds: List of ``(i, j)`` atom-index pairs.
        first_n: Number of leading frames used for the reference lengths.
        threshold: Maximum tolerated absolute deviation, in the position units.
    """
    n_frames = len(frame_positions)
    if not bonds or n_frames == 0:
        return n_frames

    begin, end = np.array(bonds).T
    # lengths[k, b] is the length of bond b in frame k.
    lengths = np.array([np.linalg.norm(pos[begin] - pos[end], axis=1) for pos in frame_positions])
    eq_lengths = lengths[:first_n].mean(axis=0)
    max_dev = np.abs(lengths - eq_lengths).max(axis=1)

    unstable = np.flatnonzero(max_dev > threshold)
    return int(unstable[0]) if unstable.size else n_frames


def main():
    """Run an MD simulation with a MACE model and report bond-length stability.

    The molecule is read from the first frame of the XYZ file, simulated with
    a Nose-Hoover chain NVT thermostat, and the resulting trajectory is scanned
    for the first frame in which any bond (taken from the SMILES string)
    deviates by more than 0.5 from its reference length. The stable time is
    written to ``md_stability.txt`` in the output directory.

    Raises:
        SystemExit: If the command-line arguments are invalid.
        ValueError: If the SMILES is unparsable or inconsistent with the XYZ
            structure.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # A non-positive frame count would average over an empty slice, giving NaN
    # reference lengths and a trajectory that always looks stable.
    if args.first_n != -1 and args.first_n < 1:
        parser.error("--first_n must be a positive integer, or -1 for automatic selection.")

    os.makedirs(args.output_path, exist_ok=True)

    # Load molecule
    molecules = read(args.xyz_file_path, format='extxyz', index=':')
    molecule = molecules[0]

    # Validate the SMILES against the structure before the expensive MD run.
    bonds = _bonds_from_smiles(args.smiles, molecule.get_chemical_symbols())

    # Set up calculator
    calculator = MACECalculator(model_path=args.checkpoint_path, device=args.device)
    calculator.models[0].eval()
    molecule.calc = calculator

    traj_file = _run_md(molecule, args)

    print("MD simulation completed. Energies and temperature saved to 'md_energies.txt'.")
    print("Trajectory saved to 'md_nh.traj'.")

    # Now calculate stability
    traj = read(traj_file, index=':')

    # Determine the number of frames to use for calculating equilibrium bond lengths
    if args.first_n == -1:
        first_n = min(100, max(1, len(traj) // 10))
    else:
        first_n = args.first_n

    stable_frames = _first_unstable_frame([atoms.get_positions() for atoms in traj], bonds, first_n)

    # The stable time in ps and fs.
    # NOTE: assumes one trajectory frame per MD step (no write interval is
    # passed to the dynamics object), with frame 0 being the initial structure.
    stable_time_ps = (stable_frames - 1) * args.timestep / 1000.0 if stable_frames > 0 else 0.0
    stable_time_fs = (stable_frames - 1) * args.timestep if stable_frames > 0 else 0.0
    output_s_file = os.path.join(args.output_path, 'md_stability.txt')
    print(f"Stability of the trajectory: {stable_time_ps} ps")
    print(f"Stability of the trajectory: {stable_time_fs} fs")
    with open(output_s_file, 'w') as f:
        f.write(f"Stability of the trajectory: {stable_time_ps} ps")


# Run the simulation and analysis only when executed as a script, not on import.
if __name__ == "__main__":
    main()