import os
import time
import shutil
import argparse
import pandas as pd
from tqdm import tqdm
import multiprocessing
from subprocess import Popen, PIPE
from functools import partialmethod

def convert_db5_to_pqr(data_dir, out_root, pdb_id, pdb2pqr_bin):
    """Strip non-ATOM records from one complex's PDB files and run pdb2pqr.

    For both the ligand (``{pdb_id}_l_b.pdb``) and the receptor
    (``{pdb_id}_r_b.pdb``) found in ``data_dir``, the ATOM records are copied
    to ``{out_root}/{pdb_id}/{ligand|receptor}.pdb`` and the executable at
    ``pdb2pqr_bin`` is invoked with ``--ff=AMBER`` to write the matching
    ``.pqr`` file next to it.

    Args:
        data_dir: Directory containing the input ``*_l_b.pdb`` / ``*_r_b.pdb``.
        out_root: Directory under which the ``pdb_id`` sub-directory is created.
        pdb_id: Identifier used for the input file names and the output folder.
        pdb2pqr_bin: Path of the pdb2pqr executable to run.

    Raises:
        FileExistsError: If ``{out_root}/{pdb_id}`` already exists.
        FileNotFoundError: If an expected input PDB file is missing.
        ValueError: If an input line is neither blank, ATOM nor HETATM.
        RuntimeError: If the presence of the ``.pqr`` file contradicts what
            pdb2pqr reported on stderr.

    A run whose stderr contains ``CRITICAL`` is reported on stdout and
    processing continues with the next file; it is not raised as an error.
    """
    out_dir = os.path.join(out_root, pdb_id)
    os.makedirs(out_dir, exist_ok=False)

    for prefix, chain_tag in (('ligand', 'l'), ('receptor', 'r')):
        fname = f'{pdb_id}_{chain_tag}_b.pdb'

        # remove HETATM
        input_fpath = os.path.join(data_dir, fname)
        if not os.path.isfile(input_fpath):
            raise FileNotFoundError(f'DB5 input file not found: {input_fpath}')
        intermediate_fpath = os.path.join(out_dir, prefix + '.pdb')
        # Stream line by line; explicit raises (not asserts) so the checks
        # survive `python -O`.
        with open(input_fpath, 'r') as fin, open(intermediate_fpath, 'w') as fout:
            for line in fin:
                if line[:4] == 'ATOM':
                    fout.write(line)
                elif line.strip() != '' and line[:6] != 'HETATM':
                    raise ValueError(
                        f'unexpected record in {input_fpath}: {line.rstrip()!r}'
                    )

        # call pdb2pqr
        output_fpath = os.path.join(out_dir, prefix + '.pqr')
        proc_args = [pdb2pqr_bin, '--ff=AMBER', intermediate_fpath, output_fpath]
        proc = Popen(proc_args, stdout=PIPE, stderr=PIPE)
        _, stderr = proc.communicate()
        # errors='replace': undecodable bytes in the tool's output must not
        # abort the conversion; only the 'CRITICAL' marker matters here.
        errmsg = stderr.decode('utf-8', errors='replace')
        if 'CRITICAL' in errmsg:
            print(f'{fname} failed', flush=True)
            if os.path.isfile(output_fpath):
                raise RuntimeError(
                    f'pdb2pqr reported a CRITICAL error for {fname} '
                    f'but still wrote {output_fpath}'
                )
        elif not os.path.isfile(output_fpath):
            raise RuntimeError(
                f'pdb2pqr did not produce {output_fpath} for {fname}'
            )


if __name__ == '__main__':
    # Command-line driver: reads the DB5 id list, recreates the output
    # directory from scratch and converts every complex, either serially or
    # with a pool of worker processes.
    parser = argparse.ArgumentParser()
    parser.add_argument('--pdb2pqr-bin', type=str, default='/usr/local/bin/pdb2pqr30')
    parser.add_argument('--serial', action='store_true')
    parser.add_argument('-j', type=int, default=4)
    parser.add_argument('--mute-tqdm', action='store_true')
    args = parser.parse_args()
    print(args)

    # optionally mute tqdm
    if args.mute_tqdm:
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

    # pdb2pqr
    # Explicit raise instead of assert so the check survives `python -O`.
    pdb2pqr_bin = args.pdb2pqr_bin
    if not os.path.exists(pdb2pqr_bin):
        raise FileNotFoundError(f'pdb2pqr executable not found: {pdb2pqr_bin}')

    # DB5 IO
    db5_difficulty_df = pd.read_csv('db5_difficulty.csv', header=0, index_col=None, sep=',')
    db5_pdb_ids = db5_difficulty_df['pdb_id'].unique()
    db5_pdb_dir = './DB5_pdbs/'
    if not os.path.exists(db5_pdb_dir):
        raise FileNotFoundError(f'DB5 PDB directory not found: {db5_pdb_dir}')
    db5_mesh_dir = './DB5_mesh/'
    # Any previous output is discarded so each run starts from an empty tree.
    if os.path.exists(db5_mesh_dir):
        shutil.rmtree(db5_mesh_dir)
    os.makedirs(db5_mesh_dir, exist_ok=False)

    start = time.time()

    if not args.serial:
        pool_args = [(db5_pdb_dir, db5_mesh_dir, pdb_id, pdb2pqr_bin) for pdb_id in db5_pdb_ids]
        # starmap() consumes its whole input before any task runs, so wrapping
        # the argument list in tqdm showed 100% immediately. Submitting tasks
        # individually and ticking the bar from the completion callback makes
        # the bar follow the work actually done. The `with` block terminates
        # the pool even when a worker raises.
        with multiprocessing.Pool(processes=args.j) as pool, tqdm(total=len(pool_args)) as pbar:
            pending = [
                pool.apply_async(
                    convert_db5_to_pqr, task, callback=lambda _result: pbar.update()
                )
                for task in pool_args
            ]
            # get() re-raises any exception from the corresponding worker.
            for job in pending:
                job.get()
        print('All processes successfully finished')
    else:
        for pdb_id in tqdm(db5_pdb_ids):
            convert_db5_to_pqr(db5_pdb_dir, db5_mesh_dir, pdb_id, pdb2pqr_bin)

    print(f'DB5 step1 elapsed time: {(time.time()-start):.1f}s\n')


