import os
import time
import argparse
import numpy as np
from tqdm import tqdm
import multiprocessing
from subprocess import Popen, PIPE
from functools import partialmethod

# parse MSMS output
def read_msms(mesh_prefix):
    """Load the mesh written by MSMS as ``<mesh_prefix>.vert`` / ``.face``.

    Both files are read as three header lines followed by one record per
    line; the first field of the third header line is taken as the record
    count and checked against the number of records actually present.

    Vertex records must have 10 whitespace-separated fields: fields 0-2 are
    read as coordinates, fields 3-5 as the vertex normal, and the last field
    must split into 6 parts on '_'.  Face records must have 5 fields, the
    first three being one-based vertex indices.

    Returns:
        (vertices, vnormals, faces): two float arrays and one int array,
        with ``faces`` converted to zero-based vertex indices.

    Raises:
        AssertionError: if a file is missing or any format check fails.
    """
    vert_fpath = mesh_prefix + '.vert'
    face_fpath = mesh_prefix + '.face'
    assert os.path.isfile(vert_fpath)
    assert os.path.isfile(face_fpath)

    # vertices: skip the three header lines, keep the records
    with open(vert_fpath) as fh:
        vert_lines = fh.read().rstrip().split('\n')
    num_verts = int(vert_lines[2].split()[0])
    vert_records = vert_lines[3:]
    assert num_verts == len(vert_records)
    coords, normals = [], []
    for record in vert_records:
        cols = record.split()
        assert len(cols) == 10
        coords.append(cols[:3])
        normals.append(cols[3:6])
        # last column is an underscore-delimited identifier with six parts
        assert len(cols[-1].split('_')) == 6
    assert len(coords) == num_verts

    # faces: same layout, three header lines then the records
    with open(face_fpath) as fh:
        face_lines = fh.read().rstrip().split('\n')
    num_faces = int(face_lines[2].split()[0])
    face_records = face_lines[3:]
    assert num_faces == len(face_records)
    triangles = []
    for record in face_records:
        cols = record.split()
        assert len(cols) == 5
        triangles.append(cols[:3])  # still one-based at this point
    assert len(triangles) == num_faces

    # solvent excluded surface info
    vertices = np.array(coords, dtype=float)
    vnormals = np.array(normals, dtype=float)
    faces = np.array(triangles, dtype=int) - 1  # one-based -> zero-based
    # sanity check: lowest referenced vertex is 0 and no index is out of range
    assert faces.min() == 0
    assert faces.max() < num_verts

    return vertices, vnormals, faces


# run MSMS on a single .xyzrn file and return what it printed to stderr
def _run_msms(msms_bin, xyzrn_fpath, mesh_prefix):
    """Invoke MSMS with probe radius 1.5 and density 1.0; return its stderr text.

    stdout is captured and discarded.  stderr is decoded leniently so that a
    stray non-UTF-8 byte in a diagnostic cannot raise and abort the batch.
    """
    msms_args = [msms_bin, '-if', xyzrn_fpath, '-of', mesh_prefix,
                 '-probe_radius', '1.5', '-density', '1.0']
    proc = Popen(msms_args, stdout=PIPE, stderr=PIPE)
    _, stderr = proc.communicate()
    return stderr.decode('utf-8', errors='replace')


# use MSMS to compute molecular solvent excluded surface
def compute_ses(data_dir, msms_bin):
    """Run MSMS for the ligand, receptor and complex in ``data_dir``.

    Expects ``ligand.xyzrn``, ``receptor.xyzrn`` and ``complex.xyzrn`` inside
    ``data_dir``; returns silently if any of them is missing.  MSMS is run on
    the three inputs in that order, writing ``<name>.vert`` / ``<name>.face``
    next to them.  If MSMS reports 'ERROR' on stderr, or any expected output
    file is absent afterwards, a 'skip' message is printed and nothing is
    saved.

    On success writes ``ligand_msms.npz`` and ``receptor_msms.npz`` (arrays
    ``verts`` and ``faces``) and ``complex_msms.npz`` (``verts`` only).

    Args:
        data_dir: directory holding the three .xyzrn inputs; outputs are
            written to the same directory.
        msms_bin: path to the MSMS executable.

    Returns:
        None in every case.
    """
    names = ('ligand', 'receptor', 'complex')
    xyzrn_fpaths = {name: os.path.join(data_dir, name + '.xyzrn')
                    for name in names}
    if not all(os.path.isfile(fpath) for fpath in xyzrn_fpaths.values()):
        return

    # run MSMS on ligand, receptor, complex; stop at the first failure
    mesh_prefixes = {name: os.path.join(data_dir, name) for name in names}
    for name in names:
        errmsg = _run_msms(msms_bin, xyzrn_fpaths[name], mesh_prefixes[name])
        # skip if MSMS failed
        if 'ERROR' in errmsg:
            print(f'skip {data_dir} {name}\n  {errmsg}', flush=True)
            return

    # MSMS can exit without an 'ERROR' message and still write nothing
    if not all(os.path.isfile(mesh_prefixes[name] + ext)
               for name in names for ext in ('.vert', '.face')):
        print(f'skip {data_dir} due to missing MSMS output', flush=True)
        return

    # save ligand and receptor surfaces (vertices and faces)
    # NOTE(review): faces are stored as float32, unchanged from before so
    # existing readers keep working; an integer dtype may be more appropriate.
    for name in ('ligand', 'receptor'):
        verts, _, faces = read_msms(mesh_prefixes[name])
        np.savez(os.path.join(data_dir, name + '_msms.npz'),
                 verts=verts.astype(np.float32),
                 faces=faces.astype(np.float32))

    # save complex surface (vertices only)
    complex_verts, _, _ = read_msms(mesh_prefixes['complex'])
    np.savez(os.path.join(data_dir, 'complex_msms.npz'),
             verts=complex_verts.astype(np.float32))


if __name__ == '__main__':
    # Command-line options:
    #   --serial     process pairs one by one in this process
    #   -j N         number of worker processes when not serial (default 4)
    #   --mute-tqdm  disable tqdm progress bars
    parser = argparse.ArgumentParser()
    parser.add_argument('--serial', action='store_true')
    parser.add_argument('-j', type=int, default=4)
    parser.add_argument('--mute-tqdm', action='store_true')
    args = parser.parse_args()
    print(args)

    # optionally mute tqdm by forcing disable=True on every tqdm instance
    if args.mute_tqdm:
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

    # MSMS
    # '~' is only expanded by a shell; os.path.exists and Popen take the path
    # literally, so expand it here before checking or running the binary.
    msms_bin = os.path.expanduser('~/MSMS/msms.x86_64Linux2.2.6.1')
    if not os.path.exists(msms_bin):
        raise FileNotFoundError(f'MSMS binary not found: {msms_bin}')

    # DIPS
    DIPS_mesh_dir = './DIPS_mesh/'
    if not os.path.exists(DIPS_mesh_dir):
        raise FileNotFoundError(f'DIPS mesh directory not found: {DIPS_mesh_dir}')
    pair_paths = [os.path.join(DIPS_mesh_dir, pair_name)
                  for pair_name in os.listdir(DIPS_mesh_dir)]

    # DIPS timer
    start = time.time()

    if not args.serial:
        pool_args = [(data_dir, msms_bin) for data_dir in pair_paths]
        # the context manager terminates the workers on exit, including when
        # starmap raises, so no worker processes are left behind
        with multiprocessing.Pool(processes=args.j) as pool:
            pool.starmap(compute_ses, tqdm(pool_args), chunksize=10)
        print('All processes successfully finished')
    else:
        for data_dir in tqdm(pair_paths):
            compute_ses(data_dir, msms_bin)

    print(f'step4 DIPS elapsed time: {(time.time()-start):.2f}s\n')


