import os
from pathlib import Path
import csv

from openmm import *
from openmm.app import *
from openmm.unit import *
import mdtraj as md
from sys import stdout
import numpy as np

from utils.Girsanov import LangevinSplittingGirsanov

# TODO: set data_dir to the directory containing alanine-dipeptide.pdb; the
# outputs are written there as well. An empty string means the current
# working directory.
data_dir = ''
# Load the alanine dipeptide structure (Ace-Ala-Nme). The path is built with
# pathlib so that an empty data_dir yields the relative path
# "alanine-dipeptide.pdb" instead of the filesystem-root path
# "/alanine-dipeptide.pdb" that f"{data_dir}/..." would produce.
pdb = PDBFile(str(Path(data_dir) / "alanine-dipeptide.pdb"))
# Amber14 parameters combined with the OBC2 implicit-solvent model.
forcefield = ForceField("amber14-all.xml", "implicit/obc2.xml")
# Create the system: implicit solvent, so no periodic box (CutoffNonPeriodic),
# a 1 nm nonbonded cutoff, and bonds involving hydrogen constrained.
system = forcefield.createSystem(
    pdb.topology,
    constraints=HBonds,
    nonbondedCutoff=1.0 * nanometer,
    nonbondedMethod=CutoffNonPeriodic,
)

# ---------------------
# Add φ/ψ harmonic bias
# ---------------------
# Bias centers and strength. k_bias is a bare float, so OpenMM interprets it
# in its default units, kJ/mol/rad^2 (a previously noted alternative value
# was 50 kJ/mol/rad^2).
phi0 = 0.0  # rad
psi0 = 0.0  # rad
k_bias = 1.0  # kJ/mol/rad^2

# Force group holding the bias terms. BiasEnergyAndMReporter reads the
# potential energy of group 1, so keep this value in sync with it.
BIAS_FORCE_GROUP = 1

# Periodic harmonic restraint. CustomTorsionForce reports theta in [-pi, pi],
# so a raw (theta - theta0)^2 overestimates the angular distance whenever
# theta0 != 0 (e.g. theta0 = 3, theta = -3 are only ~0.28 rad apart).
# Folding the difference into [0, pi] fixes this. For theta0 = 0 it is
# identical to the plain form, because |theta| <= pi already.
BIAS_EXPRESSION = (
    "0.5*k*min(dtheta, 2*pi - dtheta)^2; "
    "dtheta = abs(theta - theta0); "
    f"pi = {np.pi}"
)


def _atom_index(atoms, name):
    """Return the ``index`` of the first atom in *atoms* whose name is *name*.

    Raises ValueError with a descriptive message (instead of a bare
    IndexError) when no atom matches.
    """
    for atom in atoms:
        if atom.name == name:
            return atom.index
    raise ValueError(f"no atom named {name!r} in the given residue atoms")


def _add_torsion_bias(target_system, atom_indices, k, theta0):
    """Add a periodic harmonic restraint on one torsion to *target_system*.

    The restraint gets its own CustomTorsionForce (per-torsion parameters
    ``k`` and ``theta0``), assigned to BIAS_FORCE_GROUP. The force object is
    returned.
    """
    force = CustomTorsionForce(BIAS_EXPRESSION)
    force.addPerTorsionParameter("k")
    force.addPerTorsionParameter("theta0")
    force.addTorsion(*atom_indices, [k, theta0])
    force.setForceGroup(BIAS_FORCE_GROUP)
    target_system.addForce(force)
    return force


# Torsion atom indices for alanine dipeptide. The first three residues are
# assumed to be Ace, Ala, Nme in that order (this is not checked).
residues = list(pdb.topology.residues())
ace, ala, nme = residues[0], residues[1], residues[2]
ace_atoms = list(ace.atoms())
ala_atoms = list(ala.atoms())
nme_atoms = list(nme.atoms())

# φ: C(ACE) - N(ALA) - CA(ALA) - C(ALA)
phi_atoms = [
    _atom_index(ace_atoms, 'C'),
    _atom_index(ala_atoms, 'N'),
    _atom_index(ala_atoms, 'CA'),
    _atom_index(ala_atoms, 'C'),
]

# ψ: N(ALA) - CA(ALA) - C(ALA) - N(NME)
psi_atoms = [
    _atom_index(ala_atoms, 'N'),
    _atom_index(ala_atoms, 'CA'),
    _atom_index(ala_atoms, 'C'),
    _atom_index(nme_atoms, 'N'),
]

# φ bias, then ψ bias (same order of addition to the system as before).
phi_bias = _add_torsion_bias(system, phi_atoms, k_bias, phi0)
psi_bias = _add_torsion_bias(system, psi_atoms, k_bias, psi0)

# ---------------------
# Integrator and platform settings (NVT)
# ---------------------
# Langevin dynamics with a custom splitting from utils.Girsanov. The plain,
# unweighted equivalent would be
# LangevinIntegrator(300*kelvin, 1/picosecond, 2*femtoseconds).
# The integrator also receives nstxout, presumably the interval over which its
# "M" path-weight variable is accumulated (TODO confirm in utils.Girsanov).
# The same nstxout drives the DCD and gr.csv outputs, so frames and weights
# line up.
nstxout = 20
integrator = LangevinSplittingGirsanov(
    nstxout=nstxout,  # steps per output
    temperature=300 * kelvin,
    collision_rate=1 / picosecond,
    timestep=2 * femtoseconds,
    splitting="R O V O R",
)

platform = Platform.getPlatformByName('CPU')
properties = {'Threads': '2'}  # use 2 threads
simulation = Simulation(pdb.topology, system, integrator, platform, properties)

# pdb.positions already carries nanometer units.
simulation.context.setPositions(pdb.positions)

print("Start simulation!")
# Relax the starting structure; the bias forces are part of the system, so
# they are included in the minimization.
simulation.minimizeEnergy()
# NOTE(review): velocities are never initialized, so dynamics start from rest
# and the thermostat has to heat the system. Consider
# simulation.context.setVelocitiesToTemperature(300*kelvin) here.

# Output setup
total_steps = 500_000_000  # 1 μs / 2 fs
log_interval = 10_000  # steps between StateDataReporter lines
output_dir = Path(data_dir)
output_dir.mkdir(parents=True, exist_ok=True)  # create the directory if missing
# One trajectory frame every nstxout (= 20) steps.
simulation.reporters.append(DCDReporter(str(output_dir / "trajectory.dcd"), nstxout))
simulation.reporters.append(
    StateDataReporter(
        str(output_dir / "logger.txt"),
        log_interval,
        temperature=True,
        progress=True,
        speed=True,
        totalSteps=total_steps,
        remainingTime=True,
        separator="\t",
    )
)

class BiasEnergyAndMReporter(object):
    """OpenMM reporter that logs the integrator's ``M`` variable and the bias energy.

    Every ``reportInterval`` steps, one CSV row ``[step, M, bias_energy]`` is
    appended to ``file``. ``bias_energy`` is the potential energy of force
    group 1 in kJ/mol; this script assigns the φ/ψ restraints to that group.
    The value of ``M`` is written under the header ``logM``. Presumably the
    integrator already stores the log path weight there (TODO confirm in
    utils.Girsanov).

    The file is reopened in append mode for each row instead of being held
    open, so rows already written survive an interrupted run.
    """

    def __init__(self, file, reportInterval, integrator, context):
        """Create/truncate the CSV file and write its header row.

        Parameters
        ----------
        file : str
            Path of the CSV file to write.
        reportInterval : int
            Number of steps between reports.
        integrator
            Integrator exposing a global variable named ``"M"``.
        context
            OpenMM Context whose force-group-1 energy is reported.
        """
        self.file = file
        self.reportInterval = reportInterval
        self.integrator = integrator
        self.context = context
        self.step = 0  # step number of the most recent report
        with open(file, 'w', newline='') as f:
            csv.writer(f).writerow(['Step', 'logM', 'BiasEnergy_kJmol'])  # header

    def describeNextReport(self, simulation):
        """Return ``(steps, positions, velocities, forces, energies)`` for OpenMM.

        ``steps`` is the number of steps until the next multiple of
        ``reportInterval``. No State data is requested, because ``report``
        gets the bias energy from its own group-filtered ``getState`` call.
        """
        steps = self.reportInterval - (simulation.currentStep % self.reportInterval)
        return (steps, False, False, False, False)

    def report(self, simulation, state):
        """Append ``[current step, M, group-1 potential energy in kJ/mol]`` to the file."""
        M = self.integrator.getGlobalVariableByName("M")
        bias_state = self.context.getState(getEnergy=True, groups={1})  # bias forces only
        bias_energy = bias_state.getPotentialEnergy().value_in_unit(kilojoules_per_mole)
        self.step = simulation.currentStep
        with open(self.file, 'a', newline='') as f:
            csv.writer(f).writerow([self.step, M, bias_energy])

    def finalize(self):
        """No-op: there is no open handle to release.

        ``self.file`` is a path string, and every write opens and closes its
        own handle. Calling ``close()`` on the string would raise
        AttributeError.
        """

# Log the integrator's M variable and the group-1 bias energy to gr.csv at
# the same cadence as the trajectory frames (every nstxout steps).
gr_reporter = BiasEnergyAndMReporter(
    file=str(output_dir / "gr.csv"),
    reportInterval=nstxout,
    integrator=integrator,
    context=simulation.context,
)
simulation.reporters.append(gr_reporter)

# Production run: advance the full step budget with all reporters attached.
simulation.step(total_steps)
print("Simulation finished!")