import os
import time
import pymesh
import argparse
import numpy as np
from tqdm import tqdm
import multiprocessing
from functools import partialmethod
from sklearn.neighbors import BallTree

def remove_abnormal_triangles(mesh):
    """Return a copy of *mesh* without (near-)degenerate triangles.

    For every face the cosine of each of its three corner angles is computed
    from the edge vectors.  A face is kept only when ``1 - cos**2`` (i.e.
    ``sin**2`` of the angle) is larger than ``1E-5`` at all three corners, so
    faces with a corner angle very close to 0 or 180 degrees are discarded.
    Faces whose cosines evaluate to NaN (e.g. a zero-length edge) fail the
    comparison and are discarded as well.

    Args:
        mesh: object exposing ``vertices`` and ``faces`` arrays
            (presumably a ``pymesh`` mesh -- only those two attributes are
            read here).

    Returns:
        The mesh built by ``pymesh.form_mesh`` from the original vertices and
        the surviving faces.  Vertices are not pruned here.
    """
    vertices, triangles = mesh.vertices, mesh.faces

    # corner coordinates of every triangle
    corner_a = vertices[triangles[:, 0]]
    corner_b = vertices[triangles[:, 1]]
    corner_c = vertices[triangles[:, 2]]

    # edge vectors, each one opposite to the corner it is named after
    edge_a = corner_c - corner_b
    edge_b = corner_a - corner_c
    edge_c = corner_b - corner_a
    len_a, len_b, len_c = (np.linalg.norm(edge, axis=1)
                           for edge in (edge_a, edge_b, edge_c))

    def _row_dot(lhs, rhs):
        # row-wise dot product of two (n_faces, dim) arrays
        return np.einsum('ij,ij->i', lhs, rhs)

    # cosine of the interior angle at corners a, b and c respectively
    cos_a = _row_dot(-edge_b, edge_c) / (len_b * len_c)
    cos_b = _row_dot(edge_a, -edge_c) / (len_a * len_c)
    cos_c = _row_dot(-edge_a, edge_b) / (len_a * len_b)
    cosines = np.stack((cos_a, cos_b, cos_c), axis=-1)

    # keep a face only if sin^2 of every corner angle is above the threshold
    keep_mask = np.all(1 - cosines**2 > 1E-5, axis=-1)
    kept_triangles = triangles[np.flatnonzero(keep_mask)]

    return pymesh.form_mesh(vertices, kept_triangles)


# refine a single MSMS surface mesh (shared by ligand and receptor)
def _refine_msms_mesh(msms_fpath, resolution):
    """Load one MSMS surface from *msms_fpath* and return the refined mesh.

    Args:
        msms_fpath: path to an ``.npz`` archive holding ``'verts'`` and
            ``'faces'`` arrays.
        resolution: edge-length parameter forwarded to
            ``pymesh.split_long_edges`` and ``pymesh.collapse_short_edges``.

    Returns:
        The refined ``pymesh`` mesh.
    """
    # context manager closes the archive as soon as the arrays are read
    with np.load(msms_fpath) as msms_npz:
        mesh = pymesh.form_mesh(msms_npz['verts'], msms_npz['faces'].astype(int))

    # initial clean-up and up-sampling
    mesh, _ = pymesh.remove_duplicated_vertices(mesh, 1E-6)
    mesh, _ = pymesh.remove_degenerated_triangles(mesh, 100)
    mesh, _ = pymesh.split_long_edges(mesh, resolution)

    # at most 10 passes; stop early once the vertex count changes by < 20
    num_verts = mesh.num_vertices
    for _pass in range(10):
        mesh, _ = pymesh.collapse_short_edges(mesh, 1E-6)
        mesh, _ = pymesh.collapse_short_edges(mesh, resolution)
        # NOTE(review): arguments are presumably (max angle in degrees,
        # max iterations) -- confirm against the PyMesh documentation
        mesh, _ = pymesh.remove_obtuse_triangles(mesh, 170.0, 100)
        if abs(mesh.num_vertices - num_verts) < 20:
            break
        num_verts = mesh.num_vertices

    # final topology clean-up
    mesh = pymesh.resolve_self_intersection(mesh)
    mesh, _ = pymesh.remove_duplicated_faces(mesh)
    mesh = pymesh.compute_outer_hull(mesh)
    mesh, _ = pymesh.remove_obtuse_triangles(mesh, 179.0, 100)
    mesh = remove_abnormal_triangles(mesh)
    mesh, _ = pymesh.remove_isolated_vertices(mesh)
    return mesh


# attach the per-vertex nearest-neighbour distance to the complex surface
def _label_dist_to_complex_surface(mesh, complex_tree):
    """Store each vertex's distance to its nearest point in *complex_tree*.

    The distances are written to the ``'dist_to_complex_surface'`` attribute
    of *mesh* (modified in place).
    """
    dist, _ = complex_tree.query(mesh.vertices, k=1)
    mesh.add_attribute('dist_to_complex_surface')
    mesh.set_attribute('dist_to_complex_surface', dist)


# refine MSMS surface mesh
def refine_mesh(data_dir, resolution):
    """Refine the ligand and receptor MSMS meshes found in *data_dir*.

    Reads ``ligand_msms.npz``, ``receptor_msms.npz`` and ``complex_msms.npz``
    from *data_dir*, refines the ligand and receptor surfaces, labels every
    refined vertex with its distance to the nearest vertex of the complex
    surface, and writes ``ligand_mesh.ply`` and ``receptor_mesh.ply`` back
    into *data_dir*.

    Args:
        data_dir: directory of one ligand/receptor pair.
        resolution: edge-length parameter used when splitting long edges and
            collapsing short edges.

    Returns:
        None.  If any of the three input files is missing the function
        returns immediately without writing anything.
    """
    # locate the input surfaces
    lig_msms_fpath = os.path.join(data_dir, 'ligand_msms.npz')
    rec_msms_fpath = os.path.join(data_dir, 'receptor_msms.npz')
    complex_msms_fpath = os.path.join(data_dir, 'complex_msms.npz')
    required = (lig_msms_fpath, rec_msms_fpath, complex_msms_fpath)
    if not all(os.path.isfile(fpath) for fpath in required):
        # deliberate best-effort: skip incomplete pairs silently
        return

    # refine ligand and receptor meshes with the same pipeline
    lig_mesh = _refine_msms_mesh(lig_msms_fpath, resolution)
    rec_mesh = _refine_msms_mesh(rec_msms_fpath, resolution)

    # label distance to the complex surface
    # (presumably used downstream to identify surface buried upon binding)
    with np.load(complex_msms_fpath) as complex_msms_npz:
        bt = BallTree(complex_msms_npz['verts'])
    _label_dist_to_complex_surface(lig_mesh, bt)
    _label_dist_to_complex_surface(rec_mesh, bt)

    # save ligand mesh
    lig_mesh_fpath = os.path.join(data_dir, 'ligand_mesh.ply')
    pymesh.save_mesh(lig_mesh_fpath, lig_mesh, 'dist_to_complex_surface', use_float=True, ascii=False)

    # save receptor mesh
    rec_mesh_fpath = os.path.join(data_dir, 'receptor_mesh.ply')
    pymesh.save_mesh(rec_mesh_fpath, rec_mesh, 'dist_to_complex_surface', use_float=True, ascii=False)


if __name__ == '__main__':
    # Script entry point: refine every pair directory under ./DIPS_mesh/,
    # either in a process pool (default) or serially (--serial).

    # only this entry point needs `partial`, so import it locally
    from functools import partial

    parser = argparse.ArgumentParser()
    parser.add_argument('--resolution', type=float, default=1.2)
    parser.add_argument('--serial', action='store_true')
    parser.add_argument('-j', type=int, default=4)
    parser.add_argument('--mute-tqdm', action='store_true')
    args = parser.parse_args()
    print(args)

    # optionally mute tqdm by forcing disable=True on every progress bar
    if args.mute_tqdm:
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

    # DIPS
    DIPS_mesh_dir = './DIPS_mesh/'
    # explicit check instead of `assert`, which is stripped under `python -O`
    if not os.path.isdir(DIPS_mesh_dir):
        raise FileNotFoundError(f'DIPS mesh directory not found: {DIPS_mesh_dir}')
    pair_paths = [os.path.join(DIPS_mesh_dir, pair_name) for pair_name in os.listdir(DIPS_mesh_dir)]

    # DIPS timer
    start = time.time()

    if not args.serial:
        # bind the resolution so each task is a single-argument call
        worker = partial(refine_mesh, resolution=args.resolution)
        # the context manager tears the pool down even if a task raises
        with multiprocessing.Pool(processes=args.j) as pool:
            # wrap the *results* in tqdm so the bar tracks completed work;
            # wrapping the inputs of starmap only measured task submission
            finished = pool.imap_unordered(worker, pair_paths, chunksize=10)
            for _ in tqdm(finished, total=len(pair_paths)):
                pass
        print('All processes successfully finished')
    else:
        for data_dir in tqdm(pair_paths):
            refine_mesh(data_dir, args.resolution)

    print(f'step5 DIPS elapsed time: {(time.time()-start):.2f}s\n')


