import os
import time
import argparse
import numpy as np
from tqdm import tqdm
import multiprocessing
from subprocess import Popen, PIPE
from functools import partialmethod

# parse MSMS output
def read_msms(mesh_prefix):
    """Load an MSMS surface from ``mesh_prefix + '.vert'`` and ``'.face'``.

    Both files are read as text.  Their first three lines are treated as a
    header whose third line starts with the number of records that follow;
    every remaining line is one whitespace-separated record.

    Args:
        mesh_prefix: path prefix shared by the ``.vert`` and ``.face`` files.

    Returns:
        Tuple ``(vertices, vnormals, faces)``: float arrays built from the
        first three and the next three fields of every vertex record, and an
        int array of zero-based vertex indices taken from the first three
        fields of every face record.

    Raises:
        AssertionError: if a file is missing, a declared record count does
            not match the file contents, a record has an unexpected number
            of fields, or the face indices are out of range.
    """
    vert_path = mesh_prefix + '.vert'
    face_path = mesh_prefix + '.face'
    assert os.path.isfile(vert_path)
    assert os.path.isfile(face_path)

    def load_records(path):
        # split the file into its declared record count and its record lines
        with open(path) as handle:
            lines = handle.read().rstrip().split('\n')
        declared = int(lines[2].split()[0])
        records = lines[3:]
        assert declared == len(records)
        return declared, records

    # vertices: 10 fields per record (xyz, normal, ..., identifier)
    num_verts, vert_records = load_records(vert_path)
    coords, normals = [], []
    for record in vert_records:
        fields = record.split()
        assert len(fields) == 10
        coords.append(fields[:3])
        normals.append(fields[3:6])
        # the trailing identifier must consist of 6 '_'-separated parts
        assert len(fields[-1].split('_')) == 6
    assert len(coords) == num_verts

    # faces: 5 fields per record, the first three are one-based vertex ids
    num_faces, face_records = load_records(face_path)
    triangles = []
    for record in face_records:
        fields = record.split()
        assert len(fields) == 5
        triangles.append(fields[:3])
    assert len(triangles) == num_faces

    # solvent excluded surface info
    vertices = np.array(coords, dtype=float)
    vnormals = np.array(normals, dtype=float)
    faces = np.array(triangles, dtype=int) - 1  # shift to zero-based indexing
    assert faces.min() == 0
    assert faces.max() < num_verts

    return vertices, vnormals, faces


# use MSMS to compute molecular solvent excluded surface
def compute_ses(data_root, pdb_id, probe_radius, msms_bin):
    """Run MSMS on the ligand and receptor of one complex and save the meshes.

    Reads ``ligand.xyzrn`` and ``receptor.xyzrn`` from ``data_root/pdb_id``,
    runs MSMS on each and writes ``ligand_msms.npz`` / ``receptor_msms.npz``
    (arrays ``verts`` and ``faces``) into the same directory.  The complex is
    skipped (nothing is saved) when the ligand input is missing, MSMS reports
    an ``ERROR`` on stderr, an MSMS output file is missing, or either surface
    has fewer than 1000 vertices.

    Args:
        data_root: directory holding one sub-directory per complex.
        pdb_id: name of the complex sub-directory.
        probe_radius: value passed to MSMS via ``-probe_radius``.
        msms_bin: path to the MSMS executable.

    Returns:
        None in every case.

    Raises:
        AssertionError: if the ligand input exists but ``receptor.xyzrn``
            does not, or if ``read_msms`` rejects the MSMS output.
    """
    # specify IO dir
    data_dir = os.path.join(data_root, pdb_id)

    # ligand: entries without ligand input are skipped silently
    lig_xyzrn_fpath = os.path.join(data_dir, 'ligand.xyzrn')
    if not os.path.isfile(lig_xyzrn_fpath):
        return
    lig_mesh_prefix = os.path.join(data_dir, 'ligand')
    lig_errmsg = _run_msms(msms_bin, lig_xyzrn_fpath, lig_mesh_prefix, probe_radius)
    if 'ERROR' in lig_errmsg:
        print(f'skip {pdb_id} ligand\n  {lig_errmsg}', flush=True)
        return

    # receptor: must be present whenever the ligand is
    rec_xyzrn_fpath = os.path.join(data_dir, 'receptor.xyzrn')
    assert os.path.isfile(rec_xyzrn_fpath)
    rec_mesh_prefix = os.path.join(data_dir, 'receptor')
    rec_errmsg = _run_msms(msms_bin, rec_xyzrn_fpath, rec_mesh_prefix, probe_radius)
    if 'ERROR' in rec_errmsg:
        print(f'skip {pdb_id} receptor\n  {rec_errmsg}', flush=True)
        return

    # MSMS may finish without an ERROR message and still not write its output
    expected_outputs = [prefix + ext
                        for prefix in (lig_mesh_prefix, rec_mesh_prefix)
                        for ext in ('.vert', '.face')]
    if not all(os.path.isfile(fpath) for fpath in expected_outputs):
        print(f'skip {pdb_id} due to missing MSMS output', flush=True)
        return

    lig_verts, _, lig_faces = read_msms(lig_mesh_prefix)
    rec_verts, _, rec_faces = read_msms(rec_mesh_prefix)
    if min(len(lig_verts), len(rec_verts)) < 1000:
        # flush like the other skip messages: buffered output of a pool
        # worker is lost when the worker process is terminated
        print(f'skip {pdb_id} with small vert size: {len(lig_verts)}, {len(rec_verts)}',
              flush=True)
        return

    # save ligand surface
    lig_mesh_fpath = os.path.join(data_dir, 'ligand_msms.npz')
    np.savez(lig_mesh_fpath, verts=lig_verts, faces=lig_faces)

    # save receptor surface
    rec_mesh_fpath = os.path.join(data_dir, 'receptor_msms.npz')
    np.savez(rec_mesh_fpath, verts=rec_verts, faces=rec_faces)


def _run_msms(msms_bin, xyzrn_fpath, mesh_prefix, probe_radius):
    """Run MSMS once (density 1.0) and return its stderr decoded as UTF-8."""
    msms_args = [msms_bin, '-if', xyzrn_fpath, '-of', mesh_prefix,
                 '-probe_radius', str(probe_radius), '-density', '1.0']
    proc = Popen(msms_args, stdout=PIPE, stderr=PIPE)
    _, stderr = proc.communicate()
    return stderr.decode('utf-8')
    

if __name__ == '__main__':
    # Script entry point: run compute_ses on every entry of the data
    # directory, either serially or in a process pool, and report the
    # elapsed wall-clock time.

    # local import: only this script section needs it
    from functools import partial

    parser = argparse.ArgumentParser()
    parser.add_argument('--probe-radius', type=float, default=1.5,
                        help='probe radius passed to MSMS')
    parser.add_argument('--serial', action='store_true',
                        help='process entries one by one instead of in a pool')
    parser.add_argument('-j', type=int, default=4,
                        help='number of worker processes (ignored with --serial)')
    parser.add_argument('--mute-tqdm', action='store_true',
                        help='disable the progress bar')
    parser.add_argument('--msms-bin', default='~/MSMS/msms.x86_64Linux2.2.6.1',
                        help='path to the MSMS executable')
    parser.add_argument('--data-dir', default='./DB5_mesh/',
                        help='directory with one sub-directory per complex')
    args = parser.parse_args()
    print(args)

    # optionally mute tqdm
    if args.mute_tqdm:
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

    # MSMS: '~' has to be expanded explicitly, neither os.path nor Popen
    # does it
    msms_bin = os.path.expanduser(args.msms_bin)
    if not os.path.isfile(msms_bin):
        parser.error(f'MSMS binary not found: {msms_bin}')

    # DB5 IO
    db5_mesh_dir = args.data_dir
    if not os.path.isdir(db5_mesh_dir):
        parser.error(f'data directory not found: {db5_mesh_dir}')

    # sorted so that the processing order is reproducible between runs
    pdb_ids = sorted(os.listdir(db5_mesh_dir))

    start = time.time()

    if not args.serial:
        worker = partial(compute_ses, db5_mesh_dir,
                         probe_radius=args.probe_radius, msms_bin=msms_bin)
        # the context manager terminates the pool even if a worker raises
        with multiprocessing.Pool(processes=args.j) as pool:
            # wrap the *result* iterator so the bar advances as entries
            # finish rather than when they are submitted
            finished = pool.imap_unordered(worker, pdb_ids, chunksize=1)
            for _ in tqdm(finished, total=len(pdb_ids)):
                pass
        print('All processes successfully finished')
    else:
        for pdb_id in tqdm(pdb_ids):
            compute_ses(db5_mesh_dir, pdb_id, args.probe_radius, msms_bin)

    print(f'DB5 step3 elapsed time: {(time.time()-start):.1f}s\n')


