import numpy as np
import pandas as pd
from openmm.app import *
from openmm import *
from openmm.unit import *
from simtk.openmm.app import Simulation, PDBFile

from simtk.openmm import LangevinIntegrator
from simtk.openmm.app import ForceField

from scipy.stats import gaussian_kde
import scipy.ndimage.filters as filters
from scipy.spatial.distance import cdist
from scipy.special import logsumexp

import matplotlib.pyplot as plt
import os
import argparse

from utils.plotter import compute_dihedral, full_cyclic_break


# Reference (phi, psi) dihedral pairs of known conformers, presumably in
# degrees (they are compared against dihedrals plotted on [-180, 180]).
# The __main__ block picks one set depending on --solvent, and discovered
# modes that lie far from every reference pair are dropped before plotting.
# NOTE: the "vaccum" spelling is kept because other code refers to this name.
vaccum_conformers = np.array([
    [-65, 60],
    [-130.4, 151.6],
    [62.3, -57.4],
])
solvent_conformers = np.array([
    [-58.3, 148.98],
    [61.6, 37.6],
    [49.3, 50.0],
    [-56.5, -34.7],
    [-90.6, -7.8],
])

def is_peak(matrix, x, y, l=1):
    """Return True when no entry near ``matrix[x][y]`` is strictly larger.

    The neighbourhood is the square window of half-width ``l`` around
    (x, y), clipped to the matrix borders. Ties do not disqualify a cell,
    so every cell of a flat plateau counts as a peak.
    """
    n_rows, n_cols = matrix.shape
    centre = matrix[x][y]

    lo_x, hi_x = max(0, x - l), min(x + l, n_rows - 1)
    lo_y, hi_y = max(0, y - l), min(y + l, n_cols - 1)

    # Any strictly larger neighbour rules the cell out; stop at the first one.
    return not any(
        matrix[i][j] > centre
        for i in range(lo_x, hi_x + 1)
        for j in range(lo_y, hi_y + 1)
    )

def find_local_peaks(matrix, l=1):
    """Return the (row, col) of every cell for which ``is_peak`` holds, in row-major order."""
    n_rows, n_cols = matrix.shape
    return [
        (r, c)
        for r in range(n_rows)
        for c in range(n_cols)
        if is_peak(matrix, r, c, l=l)
    ]

def dist(sx, sy, tx, ty):
    """Return the Euclidean distance between the points (sx, sy) and (tx, ty)."""
    return np.sqrt((sx-tx)**2 + (sy-ty)**2)

def get_coords(sourcex, sourcey, targetx, targety):
    """Build line segments from each source point to its target on a 360-periodic plane.

    For every (source, target) pair, the target is also considered shifted by
    360 towards the source along x, along y, or along both. The closest of the
    four candidates becomes the segment end. When a shifted candidate wins, a
    second segment is added. It ends at the original target and starts at the
    source shifted by the same offset, so a path that wraps around the boundary
    is drawn on both sides.

    Args:
        sourcex, sourcey, targetx, targety: equal-length iterables of scalar
            coordinates. Iteration stops at the shortest one (``zip``).

    Returns:
        list of ((x0, y0), (x1, y1)) segments.
    """
    coords = []

    for sx, sy, tx, ty in zip(sourcex, sourcey, targetx, targety):
        # Shift the target by 360 towards the source. copysign avoids the
        # former division by zero (sx - tx) / abs(sx - tx) when source and
        # target share a coordinate. In that case the shift direction does not
        # matter, because the unshifted candidate is strictly closer on that axis.
        tx_mod = tx + np.copysign(360.0, sx - tx)
        ty_mod = ty + np.copysign(360.0, sy - ty)

        candidates = [(tx, ty), (tx_mod, ty), (tx, ty_mod), (tx_mod, ty_mod)]
        best_ind = np.argmin([dist(sx, sy, *target) for target in candidates])
        best = candidates[best_ind]

        coords.append(((sx, sy), best))
        if best_ind != 0:
            # Mirror segment on the other side of the periodic boundary.
            txs, tys = best
            coords.append(((sx + tx - txs, sy + ty - tys), (tx, ty)))
    return coords

# Four-atom index lists handed to compute_dihedral: phi_indices gives the
# x-axis (phi) angle and psi_indices the y-axis (psi) angle of the
# Ramachandran plots. The psi quadruplet reuses the last three phi atoms.
phi_indices = [4, 6, 8, 14]
psi_indices = phi_indices[1:] + [16]


class EnergyLandscape:
    """Search for low-energy conformational modes and plot the (phi, psi) landscape.

    Modes found by ``optimize_sim`` accumulate across calls in three parallel
    lists:
      * ``Mode_pos``      - atom positions (units stripped; presumably nm, since
                            they are scaled by 10 when written to PDB),
      * ``Mode_energies`` - potential energies (units stripped),
      * ``Mode_xy``       - (phi, psi) dihedral pairs.
    """

    def __init__(self):
        self.Mode_pos = []
        self.Mode_energies = []
        self.Mode_xy = []

    def _record_mode(self, xy, energy, pos, merge_below, new_above=20.0):
        """Merge a candidate mode into the known modes, or register it as a new one.

        The distance between two (phi, psi) points is the mean absolute
        difference of their components (no periodic wrap-around). The nearest
        distance is capped at 30.0, so with no known modes the candidate is
        always added.

        * nearest < merge_below: replace the nearest mode if the candidate has
          lower energy.
        * nearest > new_above:   append the candidate as a new mode.
        * otherwise:             discard the candidate.
        """
        dists = np.array([0.5 * (abs(m[0] - xy[0]) + abs(m[1] - xy[1])) for m in self.Mode_xy])
        nearest = np.min(dists, initial=30.0)
        if nearest < merge_below:
            dind = np.argmin(dists)
            if self.Mode_energies[dind] > energy:
                self.Mode_energies[dind] = energy
                self.Mode_pos[dind] = pos
                self.Mode_xy[dind] = xy
        elif nearest > new_above:
            self.Mode_pos.append(pos)
            self.Mode_energies.append(energy)
            self.Mode_xy.append(xy)

    @staticmethod
    def _minimize(simulation, positions):
        """Load ``positions``, energy-minimize, and return (positions, potential energy) with units stripped."""
        simulation.context.setPositions(positions)
        simulation.minimizeEnergy(tolerance=0.1)
        state = simulation.context.getState(getPositions=True, getEnergy=True)
        pos = state.getPositions(asNumpy=True)
        pos = pos.value_in_unit(pos.unit)
        energy = state.getPotentialEnergy()
        return pos, energy.value_in_unit(energy.unit)

    @staticmethod
    def _run_dynamics(simulation, steps, log_period):
        """Run ``steps`` MD steps and return the reported frames.

        The frames are reported every ``log_period`` steps to 'temp.pdb' and
        'temp.csv' in the working directory.

        Returns:
            list of (positions, potential_energy) for the last
            ``steps // log_period`` reported frames, or None if stepping raised.
        """
        # The reporters are created inline (no local references) so that popping
        # them drops the last reference and, presumably, lets OpenMM flush and
        # close the files before they are read back below.
        simulation.reporters.append(PDBReporter('temp.pdb', log_period))
        simulation.reporters.append(StateDataReporter('temp.csv', log_period, potentialEnergy=True))
        try:
            simulation.step(steps)
        except Exception:
            # Best effort: an unstable run is skipped rather than aborting the sweep.
            return None
        finally:
            simulation.reporters.pop()
            simulation.reporters.pop()

        temppdb = PDBFile('temp.pdb')
        tempenergies = pd.read_csv('temp.csv')
        tempenergies = tempenergies[tempenergies.columns[-1]]
        n_frames = temppdb.getNumFrames()
        n_reports = steps // log_period
        frames = []
        for ind in range(n_reports):
            positions = temppdb.getPositions(asNumpy=True, frame=(n_frames - n_reports) + ind)
            positions = positions.value_in_unit(positions.unit)
            frames.append((positions, tempenergies[ind]))
        return frames

    def optimize_sim(self, pdb, simulation, N, period, steps, offset=0, dim=1, solvent=False):
        """Minimize and briefly simulate every ``period``-th frame of a gridded trajectory.

        The frames of ``pdb`` are treated as a ``dim``-dimensional grid with
        ``N`` entries per axis. At the innermost level the frame index is
        ``offset*N + index``. For each selected frame, the structure is first
        energy-minimized. Frames whose minimized energy exceeds 100 are skipped.
        The remaining frames are run for ``steps`` MD steps.

        Along each run, every time a lower energy is reached, the previous best
        structure is offered to ``_record_mode`` (merge radius 20). The final
        best structure is offered with merge radius 5.

        Args:
            pdb: multi-frame PDBFile holding the starting structures.
            simulation: OpenMM Simulation used for minimization and dynamics.
            N: grid size along each trajectory axis.
            period: stride over grid entries.
            steps: MD steps per run. About 10 frames per run are reported.
            offset: frame offset of the current sub-grid (used by recursion).
            dim: number of grid axes still to iterate.
            solvent: only forwarded to recursive calls.

        Returns:
            (Optx, Opty, Energies, Indices): for each completed run, the lists
            of phi angles, psi angles and potential energies of its reported
            frames, plus the run's grid index tuple.
        """
        log_period = max(1, steps // 10)  # guard: steps < 10 used to give a zero interval
        Indices = []
        Optx, Opty, Energies = [], [], []

        for index in range(0, N, period):
            if dim != 1:
                optx, opty, energies, indices = self.optimize_sim(
                    pdb, simulation, N, period, steps,
                    offset=offset*N + index, dim=dim - 1, solvent=solvent)
                Optx.extend(optx)
                Opty.extend(opty)
                Energies.extend(energies)
                Indices.extend([(index // period, *ind) for ind in indices])
                continue

            mode_pos, best_energy = self._minimize(simulation, pdb.getPositions(frame=offset*N + index))
            best_xy = (compute_dihedral(mode_pos, phi_indices), compute_dihedral(mode_pos, psi_indices))
            # Skip frames that stay high in energy after minimization. This must
            # happen before any reporter is attached, otherwise the reporters
            # would be left on the simulation.
            if best_energy > 100:
                continue

            frames = self._run_dynamics(simulation, steps, log_period)
            if frames is None:
                continue

            optx, opty, energies = [], [], []
            for positions, energy in frames:
                phi = compute_dihedral(positions, phi_indices)
                psi = compute_dihedral(positions, psi_indices)
                optx.append(phi)
                opty.append(psi)
                energies.append(energy)
                if energy < best_energy:
                    # Archive the previous best before replacing it.
                    self._record_mode(best_xy, best_energy, mode_pos, merge_below=20.0)
                    best_energy, mode_pos, best_xy = energy, positions, (phi, psi)
            self._record_mode(best_xy, best_energy, mode_pos, merge_below=5.0)

            Optx.append(optx)
            Opty.append(opty)
            Energies.append(energies)
            Indices.append((index // period,))
        return Optx, Opty, Energies, Indices

    @staticmethod
    def _draw_trajectory(ax, x, y):
        """Overlay the trajectory path (in degrees) on ``ax``.

        For a 1-D path, a point is masked (which breaks the line there) when
        the step leaving it is longer than mean + 2*std of all step lengths.
        For a 2-D grid, rows are drawn in blue and columns in crimson, each
        split by ``full_cyclic_break``.
        """
        if x.ndim == 1:
            step_len = np.sqrt(np.diff(x)**2 + np.diff(y)**2)
            mask = np.hstack([step_len > step_len.mean() + 2*step_len.std(), [False]])
            ax.plot(np.ma.MaskedArray(x, mask), np.ma.MaskedArray(y, mask))
            return
        for xs, ys, color in ((x, y, 'blue'), (x.T, y.T, 'crimson')):
            for x_row, y_row in zip(xs, ys):
                for chunk in full_cyclic_break([list(x_row), list(y_row)]):
                    ax.plot(chunk[0], chunk[1], linestyle='-', linewidth=1.8, color=color, alpha=0.8)

    def plot_energy_landscape(self, trajectory_data, pdb_file, image_file, period = 10, steps=1000, image_size = (10, 10), solvent=False, conformers=None, mod=None):
        """Run the mode search for one trajectory and write plots/data into ``image_file``.

        Files written to the directory ``image_file``:
          * ``Ramachandran{mod}Plot.png`` - sampled (phi, psi) coloured by energy,
          * ``ramachandran{mod}.npz``     - Gaussian KDE of the samples on a 100x100 grid,
          * ``modes.pdb``                 - retained modes, one model each,
          * ``contour.png``               - log-density contour with the trajectory overlaid.

        Args:
            trajectory_data: path to a ``.npy`` array. Some columns of its last
                axis are read as angles and converted from radians to degrees.
                NOTE(review): the column layout is not visible in this file.
            pdb_file: multi-frame PDB with the starting structures.
            image_file: output directory (created with parents if missing).
            period: stride over trajectory grid entries.
            steps: MD steps per run.
            image_size: figure size of the contour plot.
            solvent: also load ``amber99_obc.xml`` (presumably implicit solvent).
            conformers: (k, 2) reference (phi, psi) pairs. Modes farther than
                20 from all of them are dropped. Defaults to
                ``solvent_conformers`` or ``vaccum_conformers`` depending on
                ``solvent``.
            mod: tag used in output file names. Defaults to 'water' or
                'nowater' depending on ``solvent``.
        """
        # These used to be read from globals that only the __main__ block
        # defines, which broke any caller that imported this module.
        if conformers is None:
            conformers = solvent_conformers if solvent else vaccum_conformers
        if mod is None:
            mod = 'water' if solvent else 'nowater'

        print(image_file)
        pdb = PDBFile(pdb_file)
        arr = np.load(trajectory_data)
        N = arr.shape[0]
        traj_dim = len(arr.shape) - 1

        if solvent:
            forcefield = ForceField("amber99sbnmr.xml", "amber99_obc.xml")
        else:
            forcefield = ForceField("amber99sbnmr.xml")
        system = forcefield.createSystem(pdb.topology, nonbondedCutoff=3*nanometer, constraints=HBonds)
        integrator = LangevinIntegrator(300, 1, 0.002)
        simulation = Simulation(pdb.topology, system, integrator)

        # makedirs is recursive, so this also creates missing parents. The old
        # separate parent call crashed on os.makedirs("") when image_file had no "/".
        os.makedirs(image_file, exist_ok=True)

        # Diagnostic print of the sub-sampled grid shape. The index must be a
        # tuple: NumPy >= 1.23 treats a list of index arrays as one array index.
        # NOTE(review): these columns (11, 19) differ from those plotted below (14, 24).
        grid = tuple(np.meshgrid(*[np.arange(0, N, period) for _ in range(traj_dim)]))
        x = arr[..., 1+21+36+11] * (180/np.pi)
        y = arr[..., 1+21+36+19] * (180/np.pi)
        print(x[grid].shape, y[grid].shape)

        optx, opty, energies, _ = self.optimize_sim(pdb, simulation, N, period, steps, offset=0, dim=traj_dim, solvent=solvent)
        optx = np.array(optx).flatten()
        opty = np.array(opty).flatten()
        energies = np.array(energies).flatten()

        # Keep only modes lying within 20 of some reference conformer. The
        # reshape keeps an empty mode list as a (0, 2) array that cdist accepts.
        Mpos = np.array(self.Mode_pos)
        Menergy = np.array(self.Mode_energies)
        Mxy = np.array(self.Mode_xy).reshape(-1, 2)
        keep = np.min(cdist(Mxy, conformers), axis=1) < 20.0
        Mpos, Menergy, Mxy = Mpos[keep], Menergy[keep], Mxy[keep]

        fig, ax = plt.subplots(figsize=(20, 20))
        sc = ax.scatter(optx, opty, c=energies, cmap='plasma')
        cbar = fig.colorbar(sc)
        cbar.ax.tick_params(labelsize=25)
        ax.set_xlim(-180, 180)
        ax.set_ylim(-180, 180)
        ax.tick_params(axis='both', which='major', labelsize=30)
        ax.tick_params(axis='both', which='minor', labelsize=25)
        fig.tight_layout()
        fig.savefig(f"{image_file}/Ramachandran{mod}Plot.png")
        plt.close(fig)  # previously leaked one open figure per call

        kde = gaussian_kde(np.vstack((optx, opty)))
        nGrid = 100
        xgrid = np.linspace(-180, 180, nGrid)
        ygrid = np.linspace(-180, 180, nGrid)
        X, Y = np.meshgrid(xgrid, ygrid)
        Z = kde(np.vstack([X.ravel(), Y.ravel()])).reshape(nGrid, nGrid)
        np.savez(f'{image_file}/ramachandran{mod}.npz', X=X, Y=Y, Z=Z)

        with open(f'{image_file}/modes.pdb', 'w') as f:
            PDBFile.writeHeader(pdb.topology, f)
            for idx, pos in enumerate(Mpos):
                # Scale by 10 (presumably nm -> Angstrom for unitless positions).
                PDBFile.writeModel(pdb.topology, 10*pos, file=f, modelIndex=idx)
            PDBFile.writeFooter(pdb.topology, f)
        print(Menergy, Mxy)

        fig, ax = plt.subplots(figsize=image_size)
        # 1/N keeps the log finite where the density is ~0.
        ax.contourf(X, Y, np.log(1./N + Z), levels=25, cmap='turbo', extend='neither', alpha=0.7)

        x = arr[..., 1+21+36+14] * (180/np.pi)
        y = arr[..., 1+21+36+24] * (180/np.pi)
        self._draw_trajectory(ax, x, y)

        startpos = pdb.getPositions(frame=0, asNumpy=True)
        startpos = startpos.value_in_unit(startpos.unit)
        startx, starty = compute_dihedral(startpos, phi_indices), compute_dihedral(startpos, psi_indices)
        ax.scatter(Mxy[:, 0], Mxy[:, 1], marker='o', s=400, c='black')
        ax.scatter([startx], [starty], marker='*', s=1000, c='yellow')

        ax.set_xlim([-180, 180])
        ax.set_ylim([-180, 180])
        ax.tick_params(axis='both', which='major', labelsize=30)
        ax.tick_params(axis='both', which='minor', labelsize=25)

        fig.savefig(f'{image_file}/contour.png')
        plt.close(fig)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="My parser")
    parser.add_argument('--solvent', action='store_true')
    parser.add_argument('--no-solvent', dest='solvent', action='store_false')
    parser.set_defaults(solvent=False)

    args = parser.parse_args()

    solvent = args.solvent

    hessian_type = 'min_op'
    top_K = 1
    period = 1
    # `mod` and `conformers` must stay module-level globals:
    # EnergyLandscape.plot_energy_landscape may read them from module scope.
    mod = 'water' if solvent else 'nowater'
    conformers = solvent_conformers if solvent else vaccum_conformers

    def _plot_experiment(state_type, mutation_type, n_runs, suffix=''):
        """Plot traj_{ind}_{basis}{suffix}.npy / eigen_{ind}_{basis}{suffix}.pdb for every run of one experiment."""
        result_dir = f'./results/{hessian_type}/{mutation_type}/{state_type}'
        image_dir = f'./image_results/{state_type}_{mod}'
        for ind in range(n_runs):
            for basis_idx in range(top_K):
                tag = f'{ind+1}_{basis_idx+1}{suffix}'
                plotter = EnergyLandscape()
                plotter.plot_energy_landscape(
                    f'{result_dir}/traj_{tag}.npy',
                    f'{result_dir}/eigen_{tag}.pdb',
                    f'{image_dir}/{hessian_type}_{mutation_type}_traj_{tag}',
                    period=period, image_size=(20, 20), solvent=solvent)

    def _plot_opt(state_type):
        """Plot the 'Opt' results for both eps settings, in the original order."""
        for eps in ((0.1, 0.01), (0.5, 0.01)):
            _plot_experiment(state_type, 'Opt', 1, suffix=f'_{eps[0]}_{eps[1]}')

    _plot_opt('alanine-dipeptide')
    _plot_experiment('alanine-dipeptide', 'TrueLie', 1)
    _plot_experiment('alanine-dipeptide', 'Lie', 3)

    if solvent:
        _plot_opt('alanine-dipeptide_water')

    for i in range(3):
        state_types = [f'alanine-dipeptide_run_{i}']
        if solvent:
            state_types.append(f'alanine-dipeptide_water_run_{i}')
        for state_type in state_types:
            _plot_opt(state_type)
    