import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import matplotlib.image as mpimg
from matplotlib.pyplot import cm
import matplotlib
import numpy as np

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from openmm import *
from openmm.app import *
from openmm.unit import *
from os import listdir
import numpy as np

# Bonds of the molecule as (atom_index, atom_index) pairs. plot_ball_and_stick
# draws one stick per pair; the indices refer to rows of a frame's coordinate array.
A = [
    (4, 1), (4, 5), (1, 0), (1, 2), (1, 3), (4, 6),
    (14, 8), (14, 15), (8, 10), (8, 9), (8, 6),
    (10, 11), (10, 12), (10, 13), (7, 6), (14, 16),
    (18, 19), (18, 20), (18, 21), (18, 16), (17, 16),
]

def plot_sphere(ax, center, radius, color, alpha=1.0):
    """
    Draw a sphere as a parametric surface on a 3D axis.

    Parameters:
    - ax: Matplotlib 3D axis (only its ``plot_surface`` method is used)
    - center: sequence, (x, y, z) coordinates of the sphere center
    - radius: float, radius of the sphere
    - color: color of the sphere
    - alpha: float, surface opacity (default 1.0)
    """
    # 100 x 100 grid over azimuth [0, 2*pi] and polar angle [0, pi].
    azimuth = np.linspace(0, 2 * np.pi, 100)
    polar = np.linspace(0, np.pi, 100)
    sin_polar = np.sin(polar)
    cx, cy, cz = center[0], center[1], center[2]

    surface = (
        cx + radius * np.outer(np.cos(azimuth), sin_polar),
        cy + radius * np.outer(np.sin(azimuth), sin_polar),
        cz + radius * np.outer(np.ones(azimuth.size), np.cos(polar)),
    )
    ax.plot_surface(*surface, color=color, rstride=4, cstride=4, alpha=alpha)

def plot_cylinder(ax, start, end, radius, color, alpha=0.8):
    """
    Adds an open cylinder (tube) between two points to a 3D plot.

    Parameters:
    - ax: Matplotlib 3D axis (only its ``plot_surface`` method is used)
    - start: sequence, (x, y, z) coordinates of the start point of the cylinder axis
    - end: sequence, (x, y, z) coordinates of the end point of the cylinder axis
    - radius: float, radius of the cylinder
    - color: color of the cylinder
    - alpha: float, surface opacity (default 0.8)

    Raises:
    - ValueError: if start and end coincide, so the axis direction is undefined.
    """
    axis = np.asarray(end, dtype=float) - np.asarray(start, dtype=float)
    mag = np.linalg.norm(axis)
    if mag == 0:
        raise ValueError("plot_cylinder: start and end points coincide")
    v = axis / mag

    # Cross v with a helper axis to get a vector perpendicular to v. The x axis
    # is unusable when v is (anti)parallel to it (cross product ~0), so fall
    # back to the y axis in that case.
    not_v = np.array([1.0, 0.0, 0.0])
    if np.isclose(abs(v[0]), 1.0):
        not_v = np.array([0.0, 1.0, 0.0])

    # (n1, n2) is an orthonormal basis of the plane perpendicular to v.
    n1 = np.cross(v, not_v)
    n1 /= np.linalg.norm(n1)
    n2 = np.cross(v, n1)

    t = np.linspace(0, mag, 100)
    theta = np.linspace(0, 2 * np.pi, 100)
    t, theta = np.meshgrid(t, theta)

    X, Y, Z = [start[i] + v[i] * t + radius * np.sin(theta) * n1[i] + radius * np.cos(theta) * n2[i] for i in [0, 1, 2]]

    ax.plot_surface(X, Y, Z, color=color, alpha=alpha)

def plot_ball_and_stick(pdb_file, output_image, period=300, edge_thickness=2, image_size=(20,20), mult=50, indices=None, names=None):
    """
    Render 3D ball-and-stick images of selected frames of a PDB file as PNGs.

    Frames 0, period, 2*period, ... are the candidates; candidate number ``ind``
    (PDB frame ``ind*period``) is drawn only if ``ind`` is in ``indices``.
    Each drawn frame is saved with a transparent background to
    ``f'{output_image}_{frame}.png'``, or ``f'{output_image}_{frame}_{name}.png'``
    when ``names`` is given.

    Parameters:
    - pdb_file: string, path to the (possibly multi-frame) PDB file. Atom rows
      must follow the hard-coded 22-atom element list below and the module-level
      bond list ``A``. Only the first 22 atoms are drawn as spheres.
    - output_image: string, filename prefix for the output PNG images
    - period: int, stride between candidate frames
    - edge_thickness: float, bonds are drawn with line width ``edge_thickness/2``
    - image_size: tuple, matplotlib figure size (width, height)
    - mult: float, scale factor applied to the atomic coordinates
    - indices: sequence of candidate numbers to draw (default: all candidates)
    - names: optional sequence aligned with ``indices``. The label for candidate
      ``ind`` is ``names[k]``, where ``k`` is the first position of ``ind`` in
      ``indices``.
    """
    # Parse the PDB file
    pdb = PDBFile(pdb_file)

    elems = ['H', 'C', 'H', 'H', 'C', 'O', 'N', 'H', 'C', 'H', 'C', 'H', 'H', 'H', 'C', 'O', 'N', 'H', 'C', 'H', 'H', 'H']
    colors = {'C': 'grey', 'H': 'blue', 'O': 'red', 'N': 'purple'}
    radii = {'C':.75, 'H':.45, 'N':.75, 'O':.75}
    node_radii = [radii[elem] for elem in elems]
    node_colors = [colors[elem] for elem in elems]
    edge_color = 'black'

    # One entry per candidate frame. Iterating this range directly (rather than
    # zipping with a floor(N/period)-long colour list) keeps the last frame.
    sampled_frames = range(0, pdb.getNumFrames(), period)
    if indices is None:
        indices = range(len(sampled_frames))

    # Map each selected candidate to its first position in `indices` (what
    # `names` is aligned with), giving O(1) membership tests and name lookup.
    position = {}
    for k, ind in enumerate(indices):
        position.setdefault(ind, k)

    for ind, frame in enumerate(sampled_frames):
        if ind not in position:
            continue

        x = np.float64(mult*pdb.getPositions(asNumpy=True, frame=frame)) # convert to float64

        # One stick per bond in A; identical endpoint pairs are drawn once.
        bonds = []
        for i, j in A:
            bond = sorted([tuple(x[i]), tuple(x[j])])
            if bond not in bonds:
                bonds.append(bond)

        fig = plt.figure(figsize=image_size)
        ax = fig.add_subplot(111, projection='3d')
        lines = Line3DCollection(bonds, linewidths=edge_thickness/2, color=edge_color, alpha=1)
        ax.add_collection3d(lines)

        ax.view_init(elev=30, azim=45, roll=15)
        ax.axis('off')
        for p, node_radius, node_color in zip(x, node_radii, node_colors):
            # add spheres at the nodes
            plot_sphere(ax, p, node_radius, node_color)

        fig.tight_layout()
        # Name files by PDB frame number; scatter_plot_with_images reads them
        # back as tmp/img_<frame>.png.
        if names:
            fig.savefig(f'{output_image}_{frame}_{names[position[ind]]}.png', transparent=True)
        else:
            fig.savefig(f'{output_image}_{frame}.png', transparent=True)
        plt.close(fig)

def scatter_plot_with_images(trajectory_data, mol_data, output_image, period=250, image_size=(10, 10)):
    """
    Overlay rendered molecule snapshots on a (phi, psi) dihedral plot.

    Every ``period``-th frame of ``mol_data`` is rendered by
    ``plot_ball_and_stick`` to ``tmp/img_<frame>.png``. Each snapshot is placed
    at the matching trajectory point. The plot also shows contours of
    ``log(1/250000 + Z)`` from ``ramachandran.npz`` (arrays X, Y, Z) and a
    scatter of every trajectory point coloured by column 0. The figure is saved
    to ``f"{output_image}.png"``.

    Parameters:
    - trajectory_data: string, path to a ``.npy`` array with one row per frame.
      Column 0 is used as the colour value; the x/y angles (in radians) are read
      from columns ``1+21+36+11`` and ``1+21+36+19``.
    - mol_data: string, path to the PDB file whose frames are rendered
    - output_image: string, output path without the ``.png`` extension
    - period: int, stride between trajectory rows that get a snapshot image
    - image_size: tuple, figure size (width, height); snapshots are rendered at
      a tenth of this size (integer division)
    """
    import os

    phi_col = 1 + 21 + 36 + 11
    psi_col = 1 + 21 + 36 + 19
    to_deg = 180 / np.pi

    arr = np.load(trajectory_data)
    N = arr.shape[0]

    # plot_ball_and_stick saves its snapshots under tmp/; make sure it exists.
    os.makedirs('tmp', exist_ok=True)
    plot_ball_and_stick(mol_data, 'tmp/img', image_size=(image_size[0]//10, image_size[1]//10), period=period)

    sampled = arr[::period]
    x, y = sampled[:, phi_col]*to_deg, sampled[:, psi_col]*to_deg
    image_paths = [f'tmp/img_{ind}.png' for ind in range(0, N, period)]
    print(x.shape, y.shape, len(image_paths))

    # Read the reference grid and release the .npz file handle immediately.
    with np.load('ramachandran.npz') as ramachandran_arr:
        rama_x, rama_y, rama_z = ramachandran_arr['X'], ramachandran_arr['Y'], ramachandran_arr['Z']

    fig, ax = plt.subplots(figsize=image_size)

    # Invisible scatter (s=0) so the axes cover the snapshot positions.
    ax.scatter(x, y, s=0)

    # Plot each snapshot image at its (phi, psi) point.
    for xi, yi, img_path in zip(x, y, image_paths):
        img = mpimg.imread(img_path)
        imagebox = OffsetImage(img, zoom=1)
        ab = AnnotationBbox(imagebox, (xi, yi), frameon=False)
        ax.add_artist(ab)

    energies = arr[:, 0]
    ax.contour(rama_x, rama_y, np.log(1./250000 + rama_z), levels=25, cmap='plasma')
    points = ax.scatter(arr[:, phi_col]*to_deg, arr[:, psi_col]*to_deg, s=4, c=energies, cmap='seismic')
    fig.colorbar(points)

    ax.set_xlim([-180, 180])
    ax.set_ylim([-180, 180])
    ax.tick_params(axis='both', which='major', labelsize=20)
    ax.tick_params(axis='both', which='minor', labelsize=14)
    fig.savefig(f"{output_image}.png")
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    

def compute_dihedral(x, atom_indices):
    """
    Return the signed dihedral angle, in degrees, defined by four atoms.

    Parameters:
    - x: array of positions indexed as ``x[atom, coordinate]`` (a single frame)
    - atom_indices: the four atom indices (i, j, k, l) spanning the dihedral

    The result lies in [-180, 180]. It is ``arctan2`` of
    ``((n_ijk x n_jkl) . r_jk, |r_jk| * (n_ijk . n_jkl))``, where
    ``n_ijk = r_ij x r_jk`` and ``n_jkl = r_jk x r_kl``.
    """
    assert len(atom_indices) == 4
    p_i, p_j, p_k, p_l = (x[idx, :] for idx in atom_indices)

    bond_a = p_j - p_i
    bond_b = p_k - p_j
    bond_c = p_l - p_k

    normal_a = np.cross(bond_a, bond_b)
    normal_b = np.cross(bond_b, bond_c)

    central_length = np.sqrt(np.sum(bond_b * bond_b))
    sin_part = np.sum(np.cross(normal_a, normal_b) * bond_b)
    cos_part = central_length * np.sum(normal_a * normal_b)
    return np.arctan2(sin_part, cos_part) * (180 / np.pi)

def cyclic_break(arr1, *arrs):
    """
    Split a sequence of periodic angles (degrees) where it wraps around.

    A step between consecutive elements of ``arr1`` larger than 180 in
    absolute value is treated as a crossing of the periodic boundary. At each
    crossing, the current chunk is closed with the new value shifted by -360
    (upward jump) or +360 (downward jump), so the chunk stays continuous up to
    the boundary. A new chunk then starts with the unshifted value.

    Parameters:
    - arr1: sequence (list or 1-D array) of angles in degrees
    - *arrs: companion sequences, presumably the same length as arr1
      (not checked); they are sliced at the same break points

    Returns:
    - chunks: list of lists, the pieces of arr1 with shifted boundary values
    - arrs_chunks: one entry per chunk, each a list holding the matching slice
      of every companion sequence. Consecutive slices share their boundary
      element, so every slice has the same length as its chunk.

    An empty ``arr1`` yields a single empty chunk.
    """
    if len(arr1) == 0:
        return [[]], [[arr[0:] for arr in arrs]]

    chunks = [[arr1[0]]]
    arrs_chunks = []
    start = 0
    prev_elem = arr1[0]
    for idx in range(1, len(arr1)):
        elem = arr1[idx]
        jump = elem - prev_elem
        if jump > 180:
            wrapped = elem - 360
        elif jump < -180:
            wrapped = elem + 360
        else:
            wrapped = None

        if wrapped is None:
            chunks[-1].append(elem)
        else:
            # Close the current chunk just past the boundary, then start a new one.
            chunks[-1].append(wrapped)
            chunks.append([elem])
            arrs_chunks.append([arr[start:idx + 1] for arr in arrs])
            start = idx
        prev_elem = elem
    arrs_chunks.append([arr[start:] for arr in arrs])
    return chunks, arrs_chunks

def full_cyclic_break(arrs):
    """
    Split several parallel angle sequences at every wrap-around of any of them.

    ``arrs`` is a list of parallel sequences. Each sequence in turn is used as
    the driving array for ``cyclic_break``, which splits it together with all
    the others. Every resulting piece is refined again by the next sequence.

    Returns a list of pieces. Each piece is a list with the same layout as
    ``arrs``, where position ``i`` holds that piece's portion of ``arrs[i]``.
    """
    pieces = [arrs]
    for axis in range(len(arrs)):
        refined = []
        for piece in pieces:
            companions = piece[:axis] + piece[axis + 1:]
            driver_chunks, companion_chunks = cyclic_break(piece[axis], *companions)
            for driver_chunk, rest in zip(driver_chunks, companion_chunks):
                # Re-insert the driving array's chunk at its original position.
                refined.append(rest[:axis] + [driver_chunk] + rest[axis:])
        pieces = refined
    return pieces

if __name__ =="__main__":

    # Atom indices of the four atoms defining the phi and psi dihedrals.
    phi_indices = [4, 6, 8, 14]
    psi_indices = [6, 8, 14, 16]

    ethick, mult = 5, 5
    # mods = ["water", "nowater"]
    mods = ["water"]
    # mods = ["nowater"]

    epsilons = [0.1, 0.01]
    # (method, results-directory suffix, number of trajectories), rendered in this order.
    runs = [
        ('TrueLie', "", 1),
        ('Opt', f"_{epsilons[0]}_{epsilons[1]}", 1),
        ('Lie', "", 3),
    ]

    for mod in mods:
        for method, method_mod, n_traj in runs:
            for ind in range(n_traj):
                results_dir = f"image_results/alanine-dipeptide_{mod}/min_op_{method}_traj_{ind+1}_1{method_mod}"
                pdbfilename = f"{results_dir}/modes.pdb"
                pdb = PDBFile(pdbfilename)
                indices = range(pdb.getNumFrames())

                # Label each frame with its rounded (phi, psi). Plain floats keep
                # the filename text as e.g. "(-63.0, 141.0)" on every NumPy
                # version (NumPy >= 2 would render "np.float64(-63.0)").
                names = []
                for num in indices:
                    positions = pdb.getPositions(asNumpy=True, frame=num)
                    positions = positions.value_in_unit(positions.unit)
                    phi = compute_dihedral(positions, phi_indices)
                    psi = compute_dihedral(positions, psi_indices)
                    names.append((float(np.rint(phi)), float(np.rint(psi))))

                plot_ball_and_stick(pdbfilename, f"{results_dir}/mode", period=1, edge_thickness=ethick, mult=mult, indices=indices, names=names)
        
    # for state_type in ['alanine-dipeptide']:
    #     hessian_type = 'min_op'
    #     mutation_type = 'cannonical_lie'


    #     directory_path = f'./results/{hessian_type}/{mutation_type}/{state_type}/'
    #     file_types = ['npy']

    #     names  = [(dir_content.split('.')[0].split('_')[1], dir_content.split('.')[0].split('_')[2])
    #             for dir_content in listdir(directory_path)
    #             if dir_content.split('.')[-1] in file_types]
    #     bases = [max([int(name[1]) for name in names if int(name[0]) == ind]) for ind in range(1,4)]
    #     print(bases)

    #     for ind in range(3):
    #         for basis_idx in range(bases[ind]+1):
    #             scatter_plot_with_images(f'./results/{hessian_type}/{mutation_type}/{state_type}/traj_{ind+1}_{basis_idx}.npy', f'./results/{hessian_type}/{mutation_type}/{state_type}/eigen_{ind+1}_{basis_idx}.pdb', f'./image_results/{state_type}/{hessian_type}_{mutation_type}_{ind+1}_{basis_idx}', period=200, image_size=(20, 20))

    #     hessian_type = 'min_op'
    #     mutation_type = 'lie'

    #     for ind in range(3):
    #         scatter_plot_with_images(f'./results/{hessian_type}/{mutation_type}/{state_type}/traj_{ind+1}.npy', f'./results/{hessian_type}/{mutation_type}/{state_type}/eigen_{ind+1}.pdb', f'./image_results/{state_type}/{hessian_type}_{mutation_type}_{ind+1}', period=200, image_size=(20, 20))


    # def plot_eigenvectors(pdb_file, vectors, output_image, edge_thickness=2, cmap='seismic', image_size=(10,10)):
    #     """
    #     Plots a 3D ball-and-stick model of a molecule from a PDB file and saves it as a PNG image.
        
    #     Parameters:
    #     - pdb_file: string, path to the PDB file
    #     - output_image: string, path to save the output PNG image
    #     - ball_size: float, size of the balls representing atoms
    #     - stick_size: float, thickness of the sticks representing bonds
    #     """
    #     # Parse the PDB file
    #     pdb = PDBFile(pdb_file)

    #     elems = ['H', 'C', 'H', 'H', 'C', 'O', 'N', 'H', 'C', 'H', 'C', 'H', 'H', 'H', 'C', 'O', 'N', 'H', 'C', 'H', 'H', 'H']
    #     radii = {'C':.75, 'H':.45, 'N':.75, 'O':.75}
    #     node_radii = [radii[elem] for elem in elems]

    #     N = pdb.getNumFrames()
    #     for ind in range(vectors.shape[1]):
    #         vector = vectors[:, ind]
    #         vector = (vector - np.amin (vector))/(np.amax(vector) - np.amin(vector))
    #         node_colors = [matplotlib.colors.rgb2hex(c) for c in cm.seismic(vector.reshape(-1))]
    #         edge_color = ['black']
    #         posn = 40*pdb.getPositions(asNumpy=True, frame=0)

    #         fig = plt.figure(figsize = image_size)
    #         ax = fig.add_subplot(111, projection='3d')

    #         x = np.float64(posn) # convert to float64
    #         bonds = []
    #         for i, j in A:
    #             # plot_cylinder(axs[ind], x[i], x[j], edge_thickness/2, edge_color)
    #             bond = sorted([tuple(x[i]), tuple(x[j])])
    #             if bond not in bonds:
    #                 bonds.append(bond)
    #         lines = Line3DCollection(bonds, linewidths=edge_thickness/2, color=edge_color, alpha=1)
    #         ax.add_collection3d(lines)
    #         # ax1.set_xlim3d(-RADIUS / 2, RADIUS / 2)
    #         # ax1.set_zlim3d(-RADIUS / 2, RADIUS / 2)
    #         # ax1.set_ylim3d(-RADIUS / 2, RADIUS / 2)

    #         ax.view_init(elev=30, azim=45, roll=15)
    #         ax.axis('off')
    #         for p, node_radius, node_color in zip(x, node_radii, node_colors):
    #             # add spheres at the nodes
    #             plot_sphere(ax, p, node_radius, node_color)

    #         plt.tight_layout()
    #         plt.savefig(f'{output_image}_{ind+1}.png')

        
    # for state_type in ['random']:
        # hessian_type = 'min_op'
        # mutation_type = 'cannonical_lie'

        # data = np.load(f'./results/{hessian_type}/{mutation_type}/{state_type}/eigenvectors.npz')
        # for ind in range(3):
        #     vectors = data[f'arr_{ind}']
        #     plot_eigenvectors('alanine-dipeptide.pdb', vectors, f'./image_results/{hessian_type}_{mutation_type}_{state_type}_eigenvectors_{ind+1}')