import os
import time
import argparse
from tqdm import tqdm
import multiprocessing
from subprocess import Popen, PIPE
from functools import partialmethod

def _run_pdb2pqr(data_dir, pdb2pqr_bin, name):
    """Run pdb2pqr (AMBER force field) on ``<data_dir>/<name>.pdb``.

    The output is written to ``<data_dir>/<name>.pqr``.

    Args:
        data_dir: Directory holding the input ``.pdb`` file.
        pdb2pqr_bin: Path to the pdb2pqr executable.
        name: File stem, e.g. ``'ligand'`` or ``'receptor'``; also used in the
            failure message.

    Returns:
        True on success, False if pdb2pqr reported a failure.

    Raises:
        FileNotFoundError: if the input ``.pdb`` file does not exist.
    """
    pdb_fpath = os.path.join(data_dir, f'{name}.pdb')
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.isfile(pdb_fpath):
        raise FileNotFoundError(pdb_fpath)
    pqr_fpath = os.path.join(data_dir, f'{name}.pqr')

    proc = Popen([pdb2pqr_bin, '--ff=AMBER', pdb_fpath, pqr_fpath],
                 stdout=PIPE, stderr=PIPE)
    _, stderr = proc.communicate()
    # errors='replace': a stray non-UTF-8 byte in the log must not crash the worker.
    errmsg = stderr.decode('utf-8', errors='replace')

    if 'CRITICAL' in errmsg:
        print(f'{data_dir} {name} failed', flush=True)
        # Sanity check kept from the original: a CRITICAL run is expected to
        # leave no output file behind.
        assert not os.path.isfile(pqr_fpath)
        return False
    if proc.returncode != 0:
        # A non-zero exit without a CRITICAL log line is still a failure.
        print(f'{data_dir} {name} failed', flush=True)
        return False
    return True


def convert_pdb_to_pqr(data_dir, pdb2pqr_bin):
    """Convert the ligand and receptor PDB files of one complex to PQR.

    Expects ``ligand.pdb`` and ``receptor.pdb`` in ``data_dir``. Writes
    ``ligand.pqr`` and ``receptor.pqr`` next to them. The receptor is only
    processed if the ligand conversion succeeded.

    Args:
        data_dir: Directory of a single complex.
        pdb2pqr_bin: Path to the pdb2pqr executable.

    Returns:
        True if both conversions succeeded, False otherwise. Each failure
        also prints a "<data_dir> <ligand|receptor> failed" line.

    Raises:
        FileNotFoundError: if a required ``.pdb`` input is missing.
    """
    # `and` short-circuits, so the receptor is skipped after a ligand failure.
    return (_run_pdb2pqr(data_dir, pdb2pqr_bin, 'ligand')
            and _run_pdb2pqr(data_dir, pdb2pqr_bin, 'receptor'))


if __name__ == '__main__':
    # Script entry point: run pdb2pqr on every complex directory under the
    # DIPS mesh directory, either serially or with a process pool.
    from functools import partial

    parser = argparse.ArgumentParser()
    parser.add_argument('--pdb2pqr-bin', type=str, default='/usr/local/bin/pdb2pqr30')
    parser.add_argument('--serial', action='store_true')
    parser.add_argument('-j', type=int, default=4)
    parser.add_argument('--mute-tqdm', action='store_true')
    # Previously hard-coded input location; the default is unchanged.
    parser.add_argument('--data-dir', type=str, default='./DIPS_mesh/')
    args = parser.parse_args()
    print(args)

    # optionally mute tqdm
    if args.mute_tqdm:
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

    # pdb2pqr
    pdb2pqr_bin = args.pdb2pqr_bin
    if not os.path.exists(pdb2pqr_bin):
        parser.error(f'pdb2pqr binary not found: {pdb2pqr_bin}')

    # DIPS
    DIPS_mesh_dir = args.data_dir
    if not os.path.isdir(DIPS_mesh_dir):
        parser.error(f'data directory not found: {DIPS_mesh_dir}')
    # Only sub-directories are treated as complexes. A stray file would
    # otherwise make convert_pdb_to_pqr fail on a missing ligand.pdb.
    # Sorted so serial runs process complexes in a deterministic order.
    pair_paths = sorted(
        os.path.join(DIPS_mesh_dir, pair_name)
        for pair_name in os.listdir(DIPS_mesh_dir)
        if os.path.isdir(os.path.join(DIPS_mesh_dir, pair_name))
    )

    # DIPS timer
    start = time.time()

    if not args.serial:
        worker = partial(convert_pdb_to_pqr, pdb2pqr_bin=pdb2pqr_bin)
        with multiprocessing.Pool(processes=args.j) as pool:
            # Advance the bar as results come back, so it tracks completed work.
            # Wrapping the *input* of starmap only tracked task submission.
            results = pool.imap_unordered(worker, pair_paths, chunksize=10)
            for _ in tqdm(results, total=len(pair_paths)):
                pass
        print('All processes successfully finished')
    else:
        for data_dir in tqdm(pair_paths):
            convert_pdb_to_pqr(data_dir, pdb2pqr_bin)

    print(f'step2 DIPS elapsed time: {(time.time()-start):.2f}s\n')


