import sys
import os
import numpy as np
from openmm.app import *
from openmm.unit import *
from utils.plotter import compute_dihedral
from simtk.openmm import LangevinIntegrator
# Make a site-wide PyMOL install importable. sys.path.append puts it at the END of the
# search path, so any pymol found earlier on the path still wins; it is also added after
# the openmm imports above, so it can only affect the `import pymol` below.
# NOTE(review): this path points at a Python 2.7 build while the script uses f-strings
# (Python 3.6+) — confirm this entry is still needed / still valid.
sys.path.append('/c5/shared/pymol/1.7.0.0-python-2.7.5-shared/lib/python2.7/site-packages/')

# Launch arguments for PyMOL are placed on __main__.pymol_argv before
# pymol.finish_launching() is called, so that the session starts headless.
import __main__
__main__.pymol_argv = ['pymol','-qc'] # PyMOL flags: -q quiet, -c command-line only (no GUI)
import pymol
pymol.finish_launching()

from pymol import cmd
import argparse

# Atom indices handed to compute_dihedral for the backbone phi / psi angles.
# NOTE(review): presumably 0-based indices into molecule_files/alanine-dipeptide.pdb;
# the two lists share atoms 6, 8, 14, as expected for consecutive backbone dihedrals.
phi_indices = [4, 6, 8, 14]
psi_indices = [6, 8, 14, 16]

parser = argparse.ArgumentParser(
    description=(
        "Export each frame of the alanine-dipeptide modes.pdb results to its own "
        "PDB file, print its potential energy, and render it to PNG with PyMOL."
    )
)
parser.add_argument(
    '--solvent',
    dest='solvent',
    action='store_true',
    help="add amber99_obc.xml to the force field and read results from the "
         "alanine-dipeptide_water* directories",
)
parser.add_argument(
    '--no-solvent',
    dest='solvent',
    action='store_false',
    help="use amber99sbnmr.xml only and read results from "
         "alanine-dipeptide_nowater (default)",
)
parser.set_defaults(solvent=False)

args = parser.parse_args()

# Single switch that selects both the force field and the results directories.
solvent = args.solvent

def color_by_element():
    """
    Apply this script's fixed per-element colour scheme in the PyMOL session.

    One ``cmd.color`` call is issued per element on an ``elem X`` selection, so
    atoms of elements not listed below keep whatever colour they already had.
    """
    element_palette = (
        ('N', 'purple'),   # nitrogen
        ('O', 'red'),      # oxygen
        ('S', 'green'),    # sulfur
        ('P', 'orange'),   # phosphorus
        ('C', 'gray'),     # carbon
        ('H', 'blue'),     # hydrogen
    )
    for symbol, colour in element_palette:
        cmd.color(colour, 'elem ' + symbol)

# Directory-name tag used when locating results for this run.
water_mod = "water" if solvent else "nowater"

# The OBC parameter file is only layered on top of the base force field with --solvent.
forcefield_files = ["amber99sbnmr.xml"]
if solvent:
    forcefield_files.append("amber99_obc.xml")
forcefield = ForceField(*forcefield_files)

pdb = PDBFile('molecule_files/alanine-dipeptide.pdb')
# NOTE(review): no nonbondedMethod is given, so OpenMM's default (NoCutoff) applies and
# nonbondedCutoff has no effect — confirm whether a cutoff method was intended.
system = forcefield.createSystem(pdb.topology, nonbondedCutoff=3*nanometer, constraints=HBonds)
# Explicit units; identical to the bare 300, 1, 0.002 read in OpenMM's default units.
integrator = LangevinIntegrator(300*kelvin, 1/picosecond, 0.002*picoseconds)

# Create the simulation; plotModes() uses its context to evaluate mode energies.
simulation = Simulation(pdb.topology, system, integrator)

def _write_mode_files(modes_pdb, results_dir):
    """Write one single-model PDB per frame, print its energy, and return (phis, psis)."""
    phis, psis = [], []
    for j in range(modes_pdb.getNumFrames()):
        pos = modes_pdb.getPositions(asNumpy=True, frame=j)

        # Evaluate the frame with the module-level simulation and report its energy.
        simulation.context.setPositions(pos)
        energy = simulation.context.getState(getEnergy=True).getPotentialEnergy()
        print(j, energy)

        phis.append(compute_dihedral(pos, phi_indices))
        psis.append(compute_dihedral(pos, psi_indices))

        with open(f"{results_dir}/modes/mode_{j}.pdb", "w") as f:
            PDBFile.writeHeader(modes_pdb.topology, f)
            PDBFile.writeModel(modes_pdb.topology, pos, file=f, modelIndex=0)
            PDBFile.writeFooter(modes_pdb.topology, f)
    return phis, psis


def _render_mode_images(results_dir, image_dir, phis, psis, pdb_name="mode"):
    """Render each exported mode PDB to ``{image_dir}/{pdb_name}_{j}_{phi}_{psi}.png``."""
    # Global PyMOL setting (transparent PNG background); it does not change per frame.
    cmd.set('ray_opaque_background', 0)
    for j, (phi, psi) in enumerate(zip(phis, psis)):
        obj = f"{pdb_name}_{j}"
        # The reference is loaded and deleted per frame, so at most one mode and
        # the reference are present in the session at a time.
        cmd.load("molecule_files/alanine-dipeptide.pdb", "alanine")
        cmd.orient("alanine")
        cmd.load(f"{results_dir}/modes/mode_{j}.pdb", obj)
        cmd.enable("alanine")
        cmd.align(obj, "alanine")
        cmd.center(obj)
        cmd.zoom(obj)
        print(cmd.get_names())
        color_by_element()
        # Hide the reference so only the aligned mode appears in the image.
        cmd.disable("alanine")
        cmd.png(f"{image_dir}/{obj}_{phi:.2f}_{psi:.2f}.png")
        cmd.delete(obj)
        cmd.delete("alanine")


def plotModes(results_dir, image_dir):
    """
    Export every frame of ``{results_dir}/modes.pdb`` and render it with PyMOL.

    For each frame j this writes ``{results_dir}/modes/mode_{j}.pdb`` and prints
    the frame's potential energy under the module-level ``simulation``. It also
    computes phi/psi with ``compute_dihedral`` and saves a PNG, aligned onto the
    reference alanine-dipeptide structure, as
    ``{image_dir}/mode_{j}_{phi:.2f}_{psi:.2f}.png``.

    If ``modes.pdb`` cannot be loaded, a message is printed and nothing is
    created (no output directories either).

    Args:
        results_dir: Directory expected to contain ``modes.pdb``.
        image_dir: Directory that receives the rendered PNG files.
    """
    pdb_file = f"{results_dir}/modes.pdb"
    try:
        modes_pdb = PDBFile(pdb_file)
    except Exception as exc:
        # Best-effort, as before: a missing/unreadable result set is skipped,
        # but it is now reported, and Ctrl-C (BaseException) is no longer swallowed.
        print(f"Skipping {pdb_file}: {exc}")
        return

    os.makedirs(f"{results_dir}/modes", exist_ok=True)
    os.makedirs(f"{image_dir}", exist_ok=True)

    phis, psis = _write_mode_files(modes_pdb, results_dir)
    _render_mode_images(results_dir, image_dir, phis, psis)

# Number of trajectories plotted per method (directories traj_1 ... traj_N).
elems = {"Lie": 3, "TrueLie": 1, "Opt": 1}

results_root = f"image_results/alanine-dipeptide_{water_mod}"

# Only the "Opt" runs carry an extra suffix in their directory names.
for method, n_traj in elems.items():
    suffixes = ["_0.5_0.01", "_0.1_0.01"] if method == "Opt" else [""]
    for traj in range(1, n_traj + 1):
        for suffix in suffixes:
            run_name = f"min_op_{method}_traj_{traj}_1{suffix}"
            plotModes(f"{results_root}/{run_name}", f"{results_root}/images/{run_name}")

plotModes(f"{results_root}/Sim", f"{results_root}/images/Sim")

if solvent:
    # Additional Opt runs kept under the separate "_water_water" results directory.
    water_water_root = "image_results/alanine-dipeptide_water_water"
    for suffix in ("_0.1_0.01", "_0.5_0.01"):
        run_name = f"min_op_Opt_traj_1_1{suffix}"
        plotModes(f"{water_water_root}/{run_name}", f"{water_water_root}/images/{run_name}")

cmd.quit()