# nohup python generate_aldp_samples_vacuum.py 0 --resume > generate_aldp_samples_vacuum_0.log 2>&1 &
# nohup python generate_aldp_samples_vacuum.py 1 --resume > generate_aldp_samples_vacuum_1.log 2>&1 &
# CUDA_VISIBLE_DEVICES=0 python generate_aldp_samples_vacuum.py 0 --resume > generate_aldp_samples_vacuum.log 2>&1 &


import os
import sys
from openmm import unit, app, Platform
from openmmtools import states, mcmc, multistate, testsystems
from openmmtools.cache import ContextCache
import logging
import argparse

# Configure the root logger once for the whole script: INFO level, with a
# wall-clock timestamp (hours:minutes:seconds) and the level name on each line.
# Loggers that propagate to the root logger share this format.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%H:%M:%S',
    level=logging.INFO,
)

# --- 0. ARGUMENT HANDLING ---
parser = argparse.ArgumentParser(description="Run or resume ALDP vacuum PT simulation.")
parser.add_argument("index", type=str, help="Index for the run")
parser.add_argument("--resume", action="store_true", help="Resume from the existing checkpoint file")
parser.add_argument(
    "--output-base",
    default=None,
    help="Base path for the output files (default: aldp_samples_vacuum_<index>)",
)
args = parser.parse_args()

# --- 1. CONFIGURATION ---
run_index = args.index
# Each run gets its own storage files, keyed by the run index. Runs launched
# side by side (index 0, 1, ...) therefore never write to, or delete, each
# other's data.
OUTPUT_BASE_NAME = args.output_base or f"aldp_samples_vacuum_{run_index}"
OUTPUT_FILE = f"{OUTPUT_BASE_NAME}.nc"
# Intended to match the checkpoint file name MultiStateReporter derives from
# OUTPUT_FILE by default (no explicit checkpoint_storage is passed to it).
CHECKPOINT_FILE = f"{OUTPUT_BASE_NAME}_checkpoint.nc"
N_REPLICAS = 10  # Number of replicas (one per temperature)
N_ITERATIONS = 500_000  # Number of exchange attempts
STEPS_PER_ITERATION = 1000  # MD steps between swaps (2 ps at dt = 2 fs)
T_MIN = 300.0 * unit.kelvin  # lowest temperature of the ladder
T_MAX = 1500.0 * unit.kelvin  # highest temperature of the ladder

if not args.resume:
    # A fresh run starts from empty storage: discard files left by a previous run.
    for path in (OUTPUT_FILE, CHECKPOINT_FILE):
        if os.path.exists(path):
            os.remove(path)
    logging.info("Starting NEW simulation: %s", OUTPUT_FILE)
else:
    # Resuming reads both the analysis file and its checkpoint file, so
    # refuse to continue unless both are present.
    missing = [path for path in (OUTPUT_FILE, CHECKPOINT_FILE) if not os.path.exists(path)]
    if missing:
        logging.error("Cannot resume. Missing file(s): %s", ", ".join(missing))
        sys.exit(1)
    logging.info("RESUMING simulation from: %s", OUTPUT_FILE)

# Alanine dipeptide in vacuum, taken from the openmmtools test-system collection.
testsystem = testsystems.AlanineDipeptideVacuum()
n_atoms = testsystem.topology.getNumAtoms()

# NetCDF storage for the run. Every atom is selected for the analysis file,
# with position_interval=1 and a checkpoint every 100 iterations.
reporter = multistate.MultiStateReporter(
    OUTPUT_FILE,
    analysis_particle_indices=tuple(range(n_atoms)),
    checkpoint_interval=100,
    position_interval=1,
)

if args.resume:
    # Rebuild the sampler from the existing storage. This restores the states,
    # positions and current iteration. from_storage is a classmethod, so it
    # must be called on the same class that created the storage
    # (ParallelTemperingSampler, below). Calling it on ReplicaExchangeSampler
    # would resume as a generic replica-exchange sampler instead.
    sampler = multistate.ParallelTemperingSampler.from_storage(reporter)
else:
    # Langevin dynamics between exchange attempts. Velocities are not
    # reassigned at the start of each iteration (reassign_velocities=False).
    # Alternatives considered: mcmc.LangevinSplittingDynamicsMove (4 fs
    # timestep) and mcmc.GHMCMove, with the same n_steps.
    move = mcmc.LangevinDynamicsMove(
        timestep=2.0 * unit.femtoseconds,
        collision_rate=1.0 / unit.picoseconds,
        n_steps=STEPS_PER_ITERATION,
        reassign_velocities=False,
    )

    # Parallel tempering: replicas share one system and differ only in
    # temperature. online_analysis_interval=None turns off online analysis.
    sampler = multistate.ParallelTemperingSampler(
        mcmc_moves=move,
        number_of_iterations=N_ITERATIONS,
        online_analysis_interval=None,
    )
    reference_state = states.ThermodynamicState(system=testsystem.system, temperature=T_MIN)

    logging.info("Initializing Replica Exchange with %d replicas...", N_REPLICAS)
    # A single SamplerState (the test system's positions) is passed as the
    # starting configuration. The temperature ladder spans T_MIN..T_MAX with
    # N_REPLICAS temperatures.
    sampler.create(
        reference_state,
        states.SamplerState(positions=testsystem.positions),
        reporter,
        min_temperature=T_MIN,
        max_temperature=T_MAX,
        n_temperatures=N_REPLICAS,
    )

# Run OpenMM on CUDA in single precision. Other options: get_fastest_platform()
# (openmmtools.utils), or Platform.getPlatformByName('CPU' / 'OpenCL').
platform = Platform.getPlatformByName('CUDA')

# One ContextCache, bound to that platform, serves both the energy evaluations
# and the MCMC propagation of the sampler.
cache = ContextCache(platform=platform, platform_properties=dict(Precision='single'))
sampler.energy_context_cache = sampler.sampler_context_cache = cache

# A resumed sampler carries the iteration target restored from storage.
# Overwrite it with N_ITERATIONS so that raising N_ITERATIONS extends a
# finished run.
if args.resume:
    logging.info(
        "Simulation already completed %d iterations. Extending to %d iterations.",
        sampler.iteration,
        N_ITERATIONS,
    )
    sampler.number_of_iterations = N_ITERATIONS

# Report what this invocation will actually do. On resume, only the iterations
# not yet completed are run.
remaining = max(N_ITERATIONS - sampler.iteration, 0)
logging.info("Starting simulation: %d of %d iterations remaining...", remaining, N_ITERATIONS)
sampler.run()
logging.info("Done! Data saved to %s", OUTPUT_FILE)

# Print replica-mixing statistics from the storage written above, then close
# the NetCDF storage explicitly rather than relying on interpreter teardown.
analyzer = multistate.ReplicaExchangeAnalyzer(reporter)
try:
    analyzer.show_mixing_statistics()
finally:
    reporter.close()



