import argparse
import collections
import os
import time

import matplotlib.pyplot as plt
import mdshare
import mdtraj as md
import numpy
import numpy as np
from scipy.spatial.distance import cdist
from scipy.stats import gaussian_kde
from simtk.openmm import LangevinIntegrator
from simtk.openmm.app import ForceField
from simtk.openmm.app import Simulation, PDBFile

# Keep the openmm star imports after the simtk imports above so the openmm
# bindings of Simulation/PDBFile/ForceField/LangevinIntegrator take precedence.
from openmm.app import *
from openmm import *
from openmm.unit import *

from energy_landscape import find_local_peaks
from utils.plotter import compute_dihedral

def find_modes(XY, Z, neighborhood=10):
    """Return the grid coordinates of the local maxima of a gridded density.

    Args:
        XY: array of shape (2, Z.size) whose columns are the grid points in
            the same row-major order as ``Z.ravel()`` -- e.g. built as
            ``np.vstack([X.ravel(), Y.ravel()])`` with ``Z`` of shape ``X.shape``.
        Z: 2-D array of density values on that grid.
        neighborhood: peak-neighbourhood size forwarded to
            ``find_local_peaks`` as ``l`` (default 10, the previous hard-coded value).

    Returns:
        Array of shape (n_peaks, 2); row k is the column of ``XY`` at the
        k-th peak. Empty (0, 2) when no peaks are found.
    """
    twodinds = find_local_peaks(Z, l=neighborhood)
    # Assumes find_local_peaks yields (row, col) index pairs into Z -- TODO confirm.
    # Row-major flattening is row * n_cols + col; ravel_multi_index uses the
    # column count, so non-square grids are handled too (the row count only
    # coincides with it for square grids).
    rows_cols = np.asarray(list(twodinds), dtype=np.intp).reshape(-1, 2)
    inds = np.ravel_multi_index((rows_cols[:, 0], rows_cols[:, 1]), Z.shape)
    print(inds.tolist())
    return XY[:, inds].T

# Reference backbone conformers as (phi, psi) pairs in degrees: one table for
# the default (vacuum) run, one for the --solvent run. Minimized mode
# structures are later kept only if they lie close to an entry of the table
# selected below.
vaccum_conformers = np.array([
    [-65, 60],
    [-130.4, 151.6],
    [62.3, -57.4],
])
solvent_conformers = np.array([
    [-58.3, 148.98],
    [61.6, 37.6],
    [49.3, 50.0],
    [-56.5, -34.7],
    [-90.6, -7.8],
])

# --solvent / --no-solvent write the same destination; the last flag given
# wins and the default is solvent=False.
parser = argparse.ArgumentParser(description="My parser")
for flag, flag_action in (('--solvent', 'store_true'), ('--no-solvent', 'store_false')):
    parser.add_argument(flag, dest='solvent', action=flag_action)
parser.set_defaults(solvent=False)

args = parser.parse_args()
solvent = args.solvent
if solvent:
    conformers = solvent_conformers
else:
    conformers = vaccum_conformers

# Topology shared by OpenMM (system construction) and MDTraj (trajectory
# loading). The trajectories are DCD files; the variable keeps the name `xtc`
# because later code refers to it.
topology_path = 'molecule_files/alanine-dipeptide.pdb'
pdb = PDBFile(topology_path)

if solvent:
    trajectory_path = 'molecule_files/alanine-dipeptide-water_traj_500ns.dcd'
    out_dir = "image_results/alanine-dipeptide_water/Sim/"
else:
    trajectory_path = 'molecule_files/alanine-dipeptide_traj_500ns.dcd'
    out_dir = "image_results/alanine-dipeptide_nowater/Sim/"
xtc = md.load(trajectory_path, top=topology_path)

os.makedirs(out_dir, exist_ok=True)

# Tag embedded in output file names.
mod = '500ns'

# Atom indices of the two backbone dihedrals. Assuming the usual ACE-ALA-NME
# atom order in the PDB (not visible here -- confirm):
#   phi = C(ACE) - N - CA - C,  psi = N - CA - C - N(NME)
phi_indices = [4, 6, 8, 14]
psi_indices = [6, 8, 14, 16]

# Build the OpenMM system used to score and minimize trajectory frames. The
# --solvent variant adds amber99_obc.xml (OpenMM's OBC implicit-solvent
# parameters) on top of amber99sbnmr.xml.
if solvent:
    force_field_files = ("amber99sbnmr.xml", "amber99_obc.xml")
else:
    force_field_files = ("amber99sbnmr.xml",)
forcefield = ForceField(*force_field_files)

# NOTE(review): createSystem's default nonbondedMethod is NoCutoff, so this
# 3 nm cutoff is presumably not applied -- confirm intent.
system = forcefield.createSystem(
    pdb.topology,
    nonbondedCutoff=3 * nanometer,
    constraints=HBonds,
)

# 300 K, 1/ps friction, 2 fs step (the units OpenMM assumes for bare numbers).
integrator = LangevinIntegrator(300 * kelvin, 1 / picosecond, 0.002 * picoseconds)
simulation = Simulation(pdb.topology, system, integrator)

# Backbone dihedrals for every frame. Each call receives a single atom
# quartet, so each result has one column; MDTraj returns radians.
phi_angle, psi_angle = (
    md.compute_dihedrals(xtc, [quartet]) for quartet in (phi_indices, psi_indices)
)
phi_angle_degrees, psi_angle_degrees = np.degrees(phi_angle), np.degrees(psi_angle)
print(phi_angle_degrees.shape)

# (2, n_frames) array: row 0 is phi, row 1 is psi (degrees).
x = phi_angle_degrees[:, 0]
y = psi_angle_degrees[:, 0]
xy = np.vstack((x, y))
print(xy.shape)
print(phi_angle_degrees.shape)

# Potential energy of every `period`-th trajectory frame, evaluated with the
# force field above (no minimization). Values are stripped of units; OpenMM
# reports potential energy in kJ/mol.
energies = []
period = 1
N = phi_angle_degrees.shape[0]
# Actual frame numbers 0, period, 2*period, ... Dividing them by `period`
# (as before) mapped them onto 0..ceil(N/period)-1, i.e. the *first* frames
# rather than a strided subsample, whenever period > 1.
indices = np.arange(0, N, period)

for i in indices:
    simulation.context.setPositions(xtc.openmm_positions(i))
    # Only the energy is needed here, so positions are not requested.
    state = simulation.context.getState(getEnergy=True)
    energy = state.getPotentialEnergy()
    energies.append(energy.value_in_unit(energy.unit))

# Ramachandran scatter of the evaluated frames, coloured by potential energy.
fig, ax = plt.subplots(figsize=(20, 20))
cax = ax.scatter(
    phi_angle_degrees[indices, 0],
    psi_angle_degrees[indices, 0],
    c=energies,
    cmap='plasma',
)
cbar = fig.colorbar(cax)
cbar.ax.tick_params(labelsize=25)
ax.set_xlim(-180, 180)
ax.set_ylim(-180, 180)
for tick_kind, tick_size in (('major', 30), ('minor', 25)):
    ax.tick_params(axis='both', which=tick_kind, labelsize=tick_size)
fig.tight_layout()
fig.savefig(f"{out_dir}/Ramachandran{mod}Plot.png")

# Gaussian KDE of the (phi, psi) samples, evaluated on a 50 x 50 grid that
# covers [-180, 180] degrees in both angles.
kde = gaussian_kde(xy)

(xmin, xmax), (ymin, ymax) = (-180, 180), (-180, 180)
xgrid, ygrid = np.linspace(xmin, xmax, 50), np.linspace(ymin, ymax, 50)
X, Y = np.meshgrid(xgrid, ygrid)
# Grid points as columns, in the same row-major order as X.ravel().
positions = np.stack([X.ravel(), Y.ravel()])

# Density back on the grid layout, then persisted for later reuse.
Z = kde(positions).reshape(X.shape)
np.savez(f'{out_dir}/ramachandran{mod}.npz', X=X, Y=Y, Z=Z)

# Assign each evaluated frame to its nearest KDE mode and pick, per mode, the
# lowest-energy frame within `thres` degrees of it.
modes = find_modes(positions, Z)
if len(modes) == 0:
    raise RuntimeError("find_modes found no density peaks on the KDE grid")

# Distances on the (phi, psi) torus: each angular difference is wrapped into
# [0, 180] degrees, so -179 and 179 are 2 degrees apart rather than 358.
frame_xy = xy[:, indices].T
diff = np.abs(frame_xy[:, None, :] - modes[None, :, :]) % 360.0
diff = np.minimum(diff, 360.0 - diff)
d = np.sqrt((diff ** 2).sum(axis=-1))  # (n_frames, n_modes)
clusterInd = np.argmin(d, axis=1)
clusterDist = np.min(d, axis=1)
counter = collections.Counter(clusterInd)
print(counter)

thres = 30.0
frame_energies = np.asarray(energies, dtype=float)
modeInds = []
for j in range(len(modes)):
    in_cluster = (clusterInd == j) & (clusterDist < thres)
    if not in_cluster.any():
        # No frame lies within `thres` of this mode; skip it instead of
        # letting argmin return an arbitrary excluded frame.
        continue
    # Mask non-members with +inf (not 0.0) so the minimum is taken over the
    # cluster regardless of the sign of the energies.
    modeInds.append(int(np.argmin(np.where(in_cluster, frame_energies, np.inf))))
print(modes, modeInds)

# Energy-minimize each representative frame, record its energy, coordinates
# (nm, unit-less) and (phi, psi), then keep only those that end up within
# `thres` degrees of one of the reference conformers.
Menergy = []
Mpos = []
Mxy = []
for ind in modeInds:
    i = indices[ind]
    simulation.context.setPositions(xtc.openmm_positions(i))
    simulation.minimizeEnergy(tolerance=0.1)
    state = simulation.context.getState(getPositions=True, getEnergy=True)
    # Strip units explicitly: Mpos is later written with a nm -> angstrom factor.
    positions = state.getPositions(asNumpy=True).value_in_unit(nanometer)
    energy = state.getPotentialEnergy()
    energy = energy.value_in_unit(energy.unit)
    print(energy, xy[:, i])
    Mpos.append(positions)
    Mxy.append([compute_dihedral(positions, phi_indices), compute_dihedral(positions, psi_indices)])
    Menergy.append(energy)
Menergy = np.array(Menergy, dtype=float)
Mpos = np.array(Mpos)
# reshape keeps a (0, 2) shape when nothing was minimized.
Mxy = np.asarray(Mxy, dtype=float).reshape(-1, 2)

# Periodic (torus) distance to the nearest reference conformer, in degrees.
diff = np.abs(Mxy[:, None, :] - conformers[None, :, :]) % 360.0
diff = np.minimum(diff, 360.0 - diff)
dists = np.sqrt((diff ** 2).sum(axis=-1)).min(axis=1)
print(dists)
keep = dists < thres
Mpos = Mpos[keep]
Menergy = Menergy[keep]
Mxy = Mxy[keep]

print(Menergy, Mxy)

# Write the retained minimized structures as one multi-model PDB file.
os.makedirs(out_dir, exist_ok=True)
filename = f'{out_dir}/modes.pdb'
with open(filename, 'w') as f:
    PDBFile.writeHeader(pdb.topology, f)
    # PDB MODEL serial numbers start at 1.
    for idx, pos in enumerate(Mpos, start=1):
        # Mpos holds unit-less coordinates from OpenMM (nanometres);
        # writeModel treats unit-less values as angstroms, hence the factor 10.
        PDBFile.writeModel(pdb.topology, 10 * pos, file=f, modelIndex=idx)
    PDBFile.writeFooter(pdb.topology, f)

# Contour map of log(KDE density) with the retained minima (black dots) and
# the energy-minimized input structure (yellow star) overlaid.
fig, ax = plt.subplots(figsize=(20, 20))
# Adding 1/N keeps the logarithm finite where the KDE density vanishes.
cf = ax.contourf(X, Y, np.log(1. / N + Z), levels=25, cmap='turbo', extend='neither', alpha=0.7)
# The plotted quantity is a log-density, so the colour bar says so.
cbar = fig.colorbar(cf, ax=ax, label='log(density)')

# Minimize the structure stored in the PDB file and locate it in (phi, psi).
simulation.context.setPositions(pdb.getPositions(frame=0, asNumpy=True))
simulation.minimizeEnergy(tolerance=0.1)
startpos = simulation.context.getState(getPositions=True).getPositions(asNumpy=True)
startpos = startpos.value_in_unit(startpos.unit)
startx, starty = compute_dihedral(startpos, phi_indices), compute_dihedral(startpos, psi_indices)

ax.scatter(Mxy[:, 0], Mxy[:, 1], marker='o', s=400, c='black')
ax.scatter([startx], [starty], marker='*', s=1000, c='yellow')
ax.set_xlim([-180, 180])
ax.set_ylim([-180, 180])
ax.tick_params(axis='both', which='major', labelsize=30)
ax.tick_params(axis='both', which='minor', labelsize=25)
fig.savefig(f"{out_dir}/RamachandranPlotContour{mod}.png")