import numpy as np
import pandas as pd
from openmm.app import *
from openmm import *
from openmm.unit import *
from simtk.openmm.app import Simulation, PDBFile
import mdtraj as md

from simtk.openmm import LangevinIntegrator
from simtk.openmm.app import ForceField

import matplotlib.pyplot as plt

from utils.plotter import compute_dihedral, full_cyclic_break
from scipy.spatial.distance import cdist
import argparse


# Atom indices passed to compute_dihedral / md.compute_dihedrals to measure the
# backbone phi and psi dihedral angles.
phi_indices = [4, 6, 8, 14]
psi_indices = [6, 8, 14, 16]

# Reference (phi, psi) points, on the same [-180, 180] axes as the plots, with
# matching LaTeX labels. Modes are labelled with the nearest reference point.
# This set is used when solvent=True.
solvent_conformers = np.array([
    [-58.3, 148.98],
    [61.6, 37.6],
    [-56.5, -34.7],
])
solvent_names = [r'$\beta$', r'$\alpha_L$', r'$\alpha_R$']

# This set is used when solvent=False.
vacuum_conformers = np.array([
    [-130.4, 151.6],
    [-65, 60],
    [62.3, -57.4],
])
vacuum_names = [r'$C5$', r'$C7_{eq}$', r'$C7_{ax}$']

def plotDensitywithConformers(trajectory_data, pdb_file, image_file, image_size=(10,10), solvent=False):
    """Overlay a trajectory and minimised mode structures on a Ramachandran density map.

    Reads the precomputed density checkpoint and the structures in ``modes.pdb``
    from ``image_file``. Each mode structure is energy-minimised and labelled with
    the nearest reference conformer. The trajectory stored in ``trajectory_data``
    is drawn on top, and the figure is written to ``{image_file}/contour.pdf``.

    Args:
        trajectory_data: path to a ``.npy`` array. Phi and psi are read from
            indices 72 and 82 of its last axis and converted from radians to
            degrees.
        pdb_file: PDB file whose topology builds the OpenMM system and whose
            first frame is marked as the starting point.
        image_file: output directory. It must already contain
            ``ramachandran{water|nowater}.npz`` (arrays ``X``, ``Y``, ``Z``) and
            ``modes.pdb``.
        image_size: matplotlib figure size.
        solvent: if True, add ``amber99_obc.xml`` to the force field and use the
            solvent conformer table. Otherwise use the vacuum force field and
            conformer table.

    Raises:
        FileNotFoundError: if the density checkpoint does not exist.
    """
    import os  # explicit, rather than relying on the openmm star imports

    pdb = PDBFile(pdb_file)
    arr = np.load(trajectory_data)
    N = arr.shape[0]

    if solvent:
        forcefield = ForceField("amber99sbnmr.xml", "amber99_obc.xml")
        mod = 'water'
        conformers, names = solvent_conformers, solvent_names
    else:
        forcefield = ForceField("amber99sbnmr.xml")
        mod = 'nowater'
        conformers, names = vacuum_conformers, vacuum_names

    system = forcefield.createSystem(pdb.topology, nonbondedCutoff=3*nanometer, constraints=HBonds)
    integrator = LangevinIntegrator(300, 1, 0.002)
    simulation = Simulation(pdb.topology, system, integrator)

    # makedirs already creates missing parents. A separate call on the parent
    # directory would be os.makedirs('') (an error) when image_file has no '/'.
    os.makedirs(image_file, exist_ok=True)

    # Fail immediately with a clear error instead of continuing and later hitting
    # a NameError on X/Y/Z. The context manager closes the NpzFile.
    density_path = f'{image_file}/ramachandran{mod}.npz'
    try:
        with np.load(density_path) as density_arr:
            X, Y, Z = density_arr['X'], density_arr['Y'], density_arr['Z']
    except FileNotFoundError:
        print("Checkpoint not found")
        raise

    modes = PDBFile(f'{image_file}/modes.pdb')
    Menergy, Mxy = [], []
    for i in range(modes.getNumFrames()):
        simulation.context.setPositions(modes.getPositions(frame=i, asNumpy=True))
        simulation.minimizeEnergy(tolerance=0.1)
        state = simulation.context.getState(getPositions=True, getEnergy=True)
        positions = state.getPositions(asNumpy=True)
        positions = positions.value_in_unit(positions.unit)
        energy = state.getPotentialEnergy()
        Menergy.append(energy.value_in_unit(energy.unit))
        # Label the minimised geometry. The input frame would make the
        # minimisation above pointless.
        Mxy.append((compute_dihedral(positions, phi_indices), compute_dihedral(positions, psi_indices)))

    Menergy = np.array(Menergy)
    # Keep Mxy 2-D (n_modes, 2) even when there are no modes, so cdist works.
    Mxy = np.array(Mxy).reshape(-1, 2)
    print(Menergy, Mxy)

    Z = Z.reshape(X.shape)

    fig, ax = plt.subplots(figsize=image_size)
    # 1/N keeps log() finite where the density is zero.
    ax.contourf(X, Y, np.log(1./N + Z), levels=25, cmap='turbo', extend='neither', alpha=0.7)

    # Phi/psi columns of the trajectory array, converted from radians to degrees.
    x = arr[..., 1+21+36+14]*(180/np.pi)
    y = arr[..., 1+21+36+24]*(180/np.pi)

    if x.ndim == 1:
        # Hide the point before any unusually large step (> mean + 2 std).
        # Presumably these are +-180 degree wrap-arounds that would otherwise be
        # drawn as lines across the plot.
        step = np.sqrt(np.diff(x)**2 + np.diff(y)**2)
        mask = np.hstack([step > step.mean() + 2*step.std(), [False]])
        ax.plot(np.ma.MaskedArray(x, mask), np.ma.MaskedArray(y, mask))
    else:
        # Draw every row in blue, then every column in crimson. Each line is split
        # by full_cyclic_break.
        for xs, ys, color in ((x, y, 'blue'), (x.T, y.T, 'crimson')):
            for xRow, yRow in zip(xs, ys):
                for chunk in full_cyclic_break([list(xRow), list(yRow)]):
                    ax.plot(chunk[0], chunk[1], linestyle='-', linewidth=1.8, color=color, alpha=0.8)

    startpos = pdb.getPositions(frame=0, asNumpy=True)
    startpos = startpos.value_in_unit(startpos.unit)
    startx, starty = compute_dihedral(startpos, phi_indices), compute_dihedral(startpos, psi_indices)

    offset = 20
    bbox = dict(boxstyle="round", fc="0.8")
    start_arrowprops = dict(
        arrowstyle="->",
        connectionstyle="angle, angleA = 0, angleB = 270, rad = 80")
    ax.annotate('Starting Point', (startx, starty), xytext=(2*offset, -5*offset),
                textcoords='offset points', fontsize=60,
                bbox=bbox, arrowprops=start_arrowprops)
    ax.scatter([startx], [starty], marker='*', s=8000, c='yellow')

    # Name each mode after its nearest reference conformer in (phi, psi).
    d = np.argmin(cdist(Mxy, conformers), axis=1)
    mode_names = [names[ind] for ind in d]
    print(d, mode_names)

    mode_arrowprops = dict(
        arrowstyle="->",
        connectionstyle="angle, rad = 80")
    for name, (mx, my) in zip(mode_names, Mxy):
        ax.scatter(mx, my, marker='o', s=2000, c='black')
        ax.annotate(f'{name}', (mx, my), xytext=(2*offset, offset),
                    textcoords='offset points', fontsize=60,
                    bbox=bbox, arrowprops=mode_arrowprops)

    ax.set_xlim([-180, 180])
    ax.set_ylim([-180, 180])
    ax.tick_params(axis='both', which='major', labelsize=30)
    ax.tick_params(axis='both', which='minor', labelsize=25)
    fig.tight_layout()
    fig.savefig(f'{image_file}/contour.pdf')
    plt.close(fig)

def plotSimDensity(solvent):
    """Plot the reference-simulation Ramachandran density with minimised mode conformers.

    Loads the 500 ns alanine-dipeptide trajectory (water or vacuum file, chosen by
    ``solvent``) only for its frame count ``N``. It then reads the density
    checkpoint ``{out_dir}/ramachandran500ns.npz`` (arrays ``X``, ``Y``, ``Z``) and
    the structures in ``{out_dir}/modes.pdb``. Each mode and the PDB starting
    structure are energy-minimised, placed at their (phi, psi) dihedrals, and
    annotated. Modes are labelled with the nearest reference conformer. The
    figure is saved to ``{out_dir}/RamachandranPlotContour500ns.pdf``.

    Args:
        solvent: if True, use the water trajectory, ``amber99_obc.xml`` and the
            solvent conformer table. Otherwise use the vacuum variants.

    Raises:
        FileNotFoundError: if the density checkpoint does not exist.
    """
    import os  # explicit, rather than relying on the openmm star imports

    top_file = 'molecule_files/alanine-dipeptide.pdb'
    pdb = PDBFile(top_file)

    if solvent:
        traj_file = 'molecule_files/alanine-dipeptide-water_traj_500ns.dcd'
        out_dir = "image_results/alanine-dipeptide_water/Sim/"
        forcefield = ForceField("amber99sbnmr.xml", "amber99_obc.xml")
        conformers, names = solvent_conformers, solvent_names
    else:
        traj_file = 'molecule_files/alanine-dipeptide_traj_500ns.dcd'
        out_dir = "image_results/alanine-dipeptide_nowater/Sim/"
        forcefield = ForceField("amber99sbnmr.xml")
        conformers, names = vacuum_conformers, vacuum_names

    xtc = md.load(traj_file, top=top_file)
    os.makedirs(out_dir, exist_ok=True)
    mod = '500ns'

    system = forcefield.createSystem(pdb.topology, nonbondedCutoff=3 * nanometer, constraints=HBonds)
    integrator = LangevinIntegrator(300, 1, 0.002)
    simulation = Simulation(pdb.topology, system, integrator)

    # Only the frame count is needed, to regularise log(1/N + Z). Computing
    # dihedrals over the whole trajectory just for this was wasted work.
    N = xtc.n_frames
    print(N)

    # Fail immediately with a clear error instead of a later NameError on X/Y/Z.
    # The context manager closes the NpzFile.
    density_path = f'{out_dir}/ramachandran{mod}.npz'
    try:
        with np.load(density_path) as arr:
            X, Y, Z = arr['X'], arr['Y'], arr['Z']
    except FileNotFoundError:
        print("Checkpoint not found.")
        raise
    Z = Z.reshape(X.shape)

    modes = PDBFile(f'{out_dir}/modes.pdb')
    Menergy, Mxy = [], []
    for i in range(modes.getNumFrames()):
        simulation.context.setPositions(modes.getPositions(frame=i, asNumpy=True))
        simulation.minimizeEnergy(tolerance=0.1)
        state = simulation.context.getState(getPositions=True, getEnergy=True)
        positions = state.getPositions(asNumpy=True)
        # Strip units so compute_dihedral sees a plain array, exactly as for the
        # starting structure below.
        positions = positions.value_in_unit(positions.unit)
        energy = state.getPotentialEnergy()
        energy = energy.value_in_unit(energy.unit)
        phi = compute_dihedral(positions, phi_indices)
        psi = compute_dihedral(positions, psi_indices)
        print(energy, (phi, psi))
        Menergy.append(energy)
        Mxy.append([phi, psi])

    Menergy = np.array(Menergy)
    # Keep Mxy 2-D (n_modes, 2) even when there are no modes, so cdist works.
    Mxy = np.array(Mxy).reshape(-1, 2)
    print(Menergy, Mxy)

    fig, ax = plt.subplots(figsize=(20,20))
    # 1/N keeps log() finite where the density is zero.
    ax.contourf(X, Y, np.log(1./N + Z), levels=25, cmap='turbo', extend='neither', alpha=0.7)

    # The starting structure is minimised too, before locating it on the map.
    simulation.context.setPositions(pdb.getPositions(frame=0, asNumpy=True))
    simulation.minimizeEnergy(tolerance=0.1)
    state = simulation.context.getState(getPositions=True)
    startpos = state.getPositions(asNumpy=True)
    startpos = startpos.value_in_unit(startpos.unit)
    startx, starty = compute_dihedral(startpos, phi_indices), compute_dihedral(startpos, psi_indices)

    offset = 20
    bbox = dict(boxstyle="round", fc="0.8")
    start_arrowprops = dict(
        arrowstyle="->",
        connectionstyle="angle, angleA = 0, angleB = 270, rad = 80")
    ax.annotate('Starting Point', (startx, starty), xytext=(2*offset, -5*offset),
                textcoords='offset points', fontsize=60,
                bbox=bbox, arrowprops=start_arrowprops)
    ax.scatter([startx], [starty], marker='*', s=8000, c='yellow')

    # Name each mode after its nearest reference conformer in (phi, psi).
    d = np.argmin(cdist(Mxy, conformers), axis=1)
    mode_names = [names[ind] for ind in d]
    print(d, mode_names)

    mode_arrowprops = dict(
        arrowstyle="->",
        connectionstyle="angle, angleA = 0, angleB = 90, rad = 80")
    for name, (mx, my) in zip(mode_names, Mxy):
        ax.scatter(mx, my, marker='o', s=2000, c='black')
        ax.annotate(f'{name}', (mx, my), xytext=(2*offset, offset),
                    textcoords='offset points', fontsize=60,
                    bbox=bbox, arrowprops=mode_arrowprops)

    ax.set_xlim([-180, 180])
    ax.set_ylim([-180, 180])
    ax.tick_params(axis='both', which='major', labelsize=30)
    ax.tick_params(axis='both', which='minor', labelsize=25)
    fig.tight_layout()

    fig.savefig(f"{out_dir}/RamachandranPlotContour{mod}.pdf")
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

if __name__ == "__main__":
    # Command-line entry point: render every trajectory plot plus the reference
    # simulation density, for either the solvent or the vacuum setup.
    parser = argparse.ArgumentParser(description="My parser")
    parser.add_argument('--solvent', action='store_true')
    parser.add_argument('--no-solvent', dest='solvent', action='store_false')
    parser.set_defaults(solvent=False)
    solvent = parser.parse_args().solvent

    if solvent:
        mod = 'water'
    else:
        mod = 'nowater'
    top_K = 1

    def _plot_runs(state_type, hessian_type, mutation_type, n_inds, eps=None):
        """Call plotDensitywithConformers for every (ind, basis_idx) of one result set.

        When ``eps`` is given, ``_{eps[0]}_{eps[1]}`` is appended to every file
        tag.
        """
        results_dir = f'./results/{hessian_type}/{mutation_type}/{state_type}'
        image_dir = f'./image_results/{state_type}_{mod}'
        for ind in range(n_inds):
            for basis_idx in range(top_K):
                tag = f'{ind+1}_{basis_idx+1}'
                if eps is not None:
                    tag = f'{tag}_{eps[0]}_{eps[1]}'
                plotDensitywithConformers(
                    f'{results_dir}/traj_{tag}.npy',
                    f'{results_dir}/eigen_{tag}.pdb',
                    f'{image_dir}/{hessian_type}_{mutation_type}_traj_{tag}',
                    image_size=(20, 20),
                    solvent=solvent,
                )

    def _plot_opt(state_type):
        """Plot the 'min_op'/'Opt' results for both eps settings, in order."""
        for eps in ((0.1, 0.01), (0.5, 0.01)):
            _plot_runs(state_type, 'min_op', 'Opt', 1, eps)

    for run in range(3):
        run_states = [f'alanine-dipeptide_run_{run}']
        if solvent:
            run_states.append(f'alanine-dipeptide_water_run_{run}')
        for state_type in run_states:
            _plot_opt(state_type)

    plotSimDensity(solvent)

    _plot_opt('alanine-dipeptide')
    _plot_runs('alanine-dipeptide', 'min_op', 'TrueLie', 1)
    _plot_runs('alanine-dipeptide', 'min_op', 'Lie', 3)

    if solvent:
        _plot_opt('alanine-dipeptide_water')