import argparse
import multiprocessing
import os
import shutil
import time
from collections import deque
from functools import partial, partialmethod

import numpy as np
import pandas as pd
import pymesh
from tqdm import tqdm

from geomlib.trimesh import TriMesh
from core.chemistry import atom_type_dict, res_name_dict

# read xyzrn file for atomic features
def read_xyzrn_file(xyzrn_fpath):
    """Parse an ``.xyzrn`` atom file into a float array of per-atom features.

    Every non-blank line must contain exactly 6 whitespace-separated fields.
    The last field is an underscore-joined id made of 6 parts:
    ``res_name_res_id_atom_name_atom_id_charge_radius``.

    Args:
        xyzrn_fpath: path to the ``.xyzrn`` file; it must exist.

    Returns:
        np.ndarray of dtype float and shape (N, 6), one row per atom. The
        columns are:
        - the first three fields of the line (presumably x, y, z; confirm
          against the file writer);
        - the residue code from ``res_name_dict``;
        - the atom-type code from ``atom_type_dict``, keyed on the first
          character of the atom name;
        - the charge parsed from the id.
        An empty file yields shape (0, 6).

    Raises:
        AssertionError: if the file is missing, a line is malformed, or a
            residue / atom type is unknown. These are ``assert`` checks and
            are skipped under ``python -O``.
    """
    assert os.path.isfile(xyzrn_fpath)
    atom_info = []
    with open(xyzrn_fpath, 'r') as f:
        for line in f:
            line_info = line.split()
            if not line_info:
                # Tolerate blank lines, e.g. an extra newline at end of file.
                continue
            assert len(line_info) == 6
            id_parts = line_info[-1].split('_')
            assert len(id_parts) == 6
            res_name, _res_id, atom_name, _atom_id, charge, _radius = id_parts
            # The atom type is looked up by the first character of the atom name.
            atom_type = atom_name[:1]
            assert res_name in res_name_dict
            assert atom_type in atom_type_dict
            atom_info.append(line_info[:3] +
                             [res_name_dict[res_name],
                              atom_type_dict[atom_type],
                              float(charge)])

    # reshape keeps the result 2-D even when no atoms were read
    return np.array(atom_info, dtype=float).reshape(-1, 6)


# skip surface with poor mesh quality
def check_mesh_validity(mesh, check_triangles=False):
    """Inspect a triangle mesh for defects that make it unusable downstream.

    Args:
        mesh: a mesh object exposing ``vertices``, ``faces``, ``num_faces``,
            ``enable_connectivity()`` and ``get_vertex_adjacent_vertices(i)``.
            This is the pymesh interface used by the callers in this file.
        check_triangles: if True, also flag faces that have an angle with
            ``1 - cos^2 < 1e-5`` (near 0 or 180 degrees) or a zero-length edge.

    Returns:
        A tuple ``(groups, has_isolated_verts, has_duplicate_verts,
        has_abnormal_triangles)``:
        - groups: list of sorted int arrays, one per connected component of
          the vertex-adjacency graph, ordered largest first. Equal sizes keep
          their discovery order.
        - has_isolated_verts: True if some vertex is referenced by no face.
        - has_duplicate_verts: True if some face does not have exactly three
          distinct vertex indices.
        - has_abnormal_triangles: as described for ``check_triangles``;
          always False when that flag is off.
    """
    mesh.enable_connectivity()
    verts, faces = np.asarray(mesh.vertices), np.asarray(mesh.faces)
    num_verts = len(verts)

    # Connected components via BFS. The deque gives O(1) pops, and each
    # component collects its own members instead of diffing a full-size
    # visited mask once per component.
    visited = np.zeros(num_verts, dtype=bool)
    groups = []
    for seed in range(num_verts):
        if visited[seed]:
            continue
        visited[seed] = True
        queue = deque([seed])
        members = [seed]
        while queue:
            curr = queue.popleft()
            for nbr in mesh.get_vertex_adjacent_vertices(curr):
                if not visited[nbr]:
                    visited[nbr] = True
                    queue.append(nbr)
                    members.append(nbr)
        groups.append(np.sort(np.asarray(members, dtype=np.intp)))
    groups = sorted(groups, key=len, reverse=True)
    assert sum(len(g) for g in groups) == visited.sum() == num_verts

    # A vertex that appears in no face is isolated.
    has_isolated_verts = num_verts != len(np.unique(faces))

    # Count the distinct indices in each face: sort each row, then count the
    # non-zero steps between neighbours. A valid triangle has exactly 3.
    sorted_faces = np.sort(faces, axis=1)
    n_distinct = 1 + np.count_nonzero(np.diff(sorted_faces, axis=1), axis=1)
    has_duplicate_verts = np.count_nonzero(n_distinct == 3) != mesh.num_faces

    if check_triangles:
        v1 = verts[faces[:, 0]]
        v2 = verts[faces[:, 1]]
        v3 = verts[faces[:, 2]]
        # e_k is the edge opposite vertex k.
        e1 = v3 - v2
        e2 = v1 - v3
        e3 = v2 - v1
        L1 = np.linalg.norm(e1, axis=1)
        L2 = np.linalg.norm(e2, axis=1)
        L3 = np.linalg.norm(e3, axis=1)
        # Zero-length edges produce nan cosines. A nan fails the >= test
        # below, so such faces count as abnormal; silence the warnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            # cosine of the interior angle at v1, v2, v3 respectively
            cos1 = np.einsum('ij,ij->i', -e2, e3) / (L2 * L3)
            cos2 = np.einsum('ij,ij->i', e1, -e3) / (L1 * L3)
            cos3 = np.einsum('ij,ij->i', -e1, e2) / (L1 * L2)
            cos123 = np.stack((cos1, cos2, cos3), axis=-1)
            # 1 - cos^2 = sin^2; near zero means the angle is near 0 or 180 degrees.
            face_ok = np.all(1 - cos123**2 >= 1E-5, axis=-1)
        has_abnormal_triangles = faces.shape[0] != np.count_nonzero(face_ok)
    else:
        has_abnormal_triangles = False

    return groups, has_isolated_verts, has_duplicate_verts, has_abnormal_triangles


# Return True if `mesh` passes check_mesh_validity; otherwise print the reasons and return False.
def _mesh_passes_quality_check(mesh, data_dir, label):
    groups, has_isolated_verts, has_duplicate_verts, has_abnormal_triangles \
        = check_mesh_validity(mesh, check_triangles=True)
    if (len(groups) == 1 and not has_isolated_verts
            and not has_duplicate_verts and not has_abnormal_triangles):
        return True
    print(f'skip {data_dir} due to poor refined {label} mesh quality')
    print(f'\tgroup sizes: {[len(ig) for ig in groups]}')
    print(f'\thas isolated verts: {has_isolated_verts}')
    print(f'\thas duplicate verts: {has_duplicate_verts}')
    print(f'\thas abnormal triangles: {has_abnormal_triangles}\n', flush=True)
    return False


# Build a TriMesh for `mesh` and run its Laplace-Beltrami decomposition with k = int(eigs_ratio * N) + 1.
def _laplace_beltrami(mesh, eigs_ratio):
    k = int(eigs_ratio * mesh.num_vertices) + 1
    assert k < mesh.num_vertices  # scipy eigsh must have k < N
    trimesh = TriMesh(verts=mesh.vertices, faces=mesh.faces)
    trimesh.LB_decomposition(k=k)
    return trimesh


def gen_dataset(data_dir, out_dir, eigs_ratio):
    """Convert one ligand/receptor pair directory into a compressed ``.npz``.

    ``data_dir`` must contain ``ligand_mesh.ply``, ``receptor_mesh.ply``,
    ``ligand.xyzrn`` and ``receptor.xyzrn``. The output is written to
    ``out_dir/<basename of data_dir>.npz``.

    The function returns without writing anything when:
    - either mesh file is missing;
    - either side has fewer than 1000 or more than 10000 atoms or mesh
      vertices;
    - either mesh fails ``check_mesh_validity``. In this case the reasons are
      printed.

    Args:
        data_dir: directory holding the pair's input files.
        out_dir: existing directory that receives the ``.npz`` file.
        eigs_ratio: fraction of the vertex count used as the number of
            Laplace-Beltrami eigenpairs; ``k = int(eigs_ratio * N) + 1``.
    """
    # IO
    lig_mesh_fpath = os.path.join(data_dir, 'ligand_mesh.ply')
    rec_mesh_fpath = os.path.join(data_dir, 'receptor_mesh.ply')
    if not (os.path.isfile(lig_mesh_fpath) and os.path.isfile(rec_mesh_fpath)):
        return

    # atomic features
    lig_atom_info = read_xyzrn_file(os.path.join(data_dir, 'ligand.xyzrn'))
    rec_atom_info = read_xyzrn_file(os.path.join(data_dir, 'receptor.xyzrn'))

    # load surface mesh
    lig_mesh = pymesh.load_mesh(lig_mesh_fpath)
    rec_mesh = pymesh.load_mesh(rec_mesh_fpath)

    # size filters on both atom counts and vertex counts
    n_atoms = (len(lig_atom_info), len(rec_atom_info))
    n_verts = (lig_mesh.num_vertices, rec_mesh.num_vertices)
    if min(n_atoms) < 1000 or max(n_atoms) > 10000 or \
       min(n_verts) < 1000 or max(n_verts) > 10000:
        return

    # Validate both meshes before either expensive eigendecomposition, so a
    # bad receptor no longer wastes a full ligand LB decomposition.
    if not _mesh_passes_quality_check(lig_mesh, data_dir, 'ligand'):
        return
    if not _mesh_passes_quality_check(rec_mesh, data_dir, 'receptor'):
        return

    # labelling helpers
    lig_dist_to_complex_surface = lig_mesh.get_vertex_attribute('vertex_dist_to_complex_surface')
    rec_dist_to_complex_surface = rec_mesh.get_vertex_attribute('vertex_dist_to_complex_surface')

    # Laplace-Beltrami bases
    lig_trimesh = _laplace_beltrami(lig_mesh, eigs_ratio)
    rec_trimesh = _laplace_beltrami(rec_mesh, eigs_ratio)

    # save features; normpath/basename also handles a trailing separator
    fname = os.path.basename(os.path.normpath(data_dir))
    fout = os.path.join(out_dir, f'{fname}.npz')
    # Face indices are stored as float32. This is exact here because the
    # vertex counts are capped at 10000 above, well below 2**24.
    np.savez_compressed(fout, # ligand
                        lig_verts=lig_mesh.vertices.astype(np.float32),
                        lig_faces=lig_mesh.faces.astype(np.float32),
                        lig_dist_to_complex_surface=lig_dist_to_complex_surface.astype(np.float32),
                        lig_eigen_vals=lig_trimesh.eigen_vals.astype(np.float32),
                        lig_eigen_vecs=lig_trimesh.eigen_vecs.astype(np.float32),
                        lig_mass=lig_trimesh.mass.astype(np.float32),
                        lig_atom_info=lig_atom_info.astype(np.float32),
                        # receptor
                        rec_verts=rec_mesh.vertices.astype(np.float32),
                        rec_faces=rec_mesh.faces.astype(np.float32),
                        rec_dist_to_complex_surface=rec_dist_to_complex_surface.astype(np.float32),
                        rec_eigen_vals=rec_trimesh.eigen_vals.astype(np.float32),
                        rec_eigen_vecs=rec_trimesh.eigen_vecs.astype(np.float32),
                        rec_mass=rec_trimesh.mass.astype(np.float32),
                        rec_atom_info=rec_atom_info.astype(np.float32)
                        )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--eigs-ratio', type=float, default=0.06)
    parser.add_argument('--serial', action='store_true')
    parser.add_argument('-j', type=int, default=4)
    parser.add_argument('--mute-tqdm', action='store_true')
    args = parser.parse_args()
    print(args)

    # optionally mute tqdm by defaulting every progress bar to disabled
    if args.mute_tqdm:
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

    # DIPS: every entry of the mesh directory is treated as one pair directory
    DIPS_mesh_dir = './DIPS_mesh/'
    assert os.path.exists(DIPS_mesh_dir)
    pair_paths = [os.path.join(DIPS_mesh_dir, pair_name) for pair_name in os.listdir(DIPS_mesh_dir)]
    DIPS_dataset_dir = './dataset_DIPS/'
    # always start from an empty output directory
    if os.path.exists(DIPS_dataset_dir):
        shutil.rmtree(DIPS_dataset_dir)
    os.makedirs(DIPS_dataset_dir, exist_ok=False)

    # DIPS timer
    start = time.time()

    if not args.serial:
        # Bind the per-run constants so each task carries only its pair path.
        worker = partial(gen_dataset, out_dir=DIPS_dataset_dir, eigs_ratio=args.eigs_ratio)
        # Iterating imap_unordered lets tqdm count *completed* pairs. Wrapping
        # starmap's input only showed submission, since starmap consumes the
        # input up front. The context manager terminates the pool on exit,
        # including when a worker raises.
        with multiprocessing.Pool(processes=args.j) as pool:
            results = pool.imap_unordered(worker, pair_paths, chunksize=10)
            for _ in tqdm(results, total=len(pair_paths)):
                pass
        print('All processes successfully finished')
    else:
        for data_dir in tqdm(pair_paths):
            gen_dataset(data_dir, DIPS_dataset_dir, args.eigs_ratio)

    print(f'step6 DIPS elapsed time: {(time.time()-start):.2f}s\n')


