import os
import time
import argparse
import numpy as np
from tqdm import tqdm
import multiprocessing
from functools import partialmethod
from core.chemistry import atom_type_dict, res_type_dict

def parse_pqr_file(pqr_fpath):
    """Parse the ATOM records of a fixed-width PQR file.

    Only lines starting with ``ATOM`` are read; every other line is ignored.
    Each ATOM line must be exactly 69 characters followed by a newline and
    follow the column layout sliced below.

    Args:
        pqr_fpath: path to the ``.pqr`` file.

    Returns:
        A tuple ``(xyz, rn)``:
        ``xyz`` is a float array of atomic coordinates, one ``[x, y, z]`` row
        per ATOM record (shape ``(0,)`` when there are no ATOM records);
        ``rn`` is a list with one string per atom of the form
        ``"<radius> 1 <res>_<res_id>_<atom>_<atom_id>_<charge>_<radius>"``.

    Raises:
        ValueError: if an ATOM record is malformed, if its atom-name initial is
            not a key of ``atom_type_dict``, or if its residue name is not a
            key of ``res_type_dict``.
    """
    def require(ok, line_no, reason):
        # Explicit check instead of ``assert`` so validation is not stripped
        # under ``python -O`` and failures name the offending file and line.
        if not ok:
            raise ValueError(
                f'{pqr_fpath}:{line_no}: malformed ATOM record ({reason})')

    xyz_list = []  # atomic coordinates
    rn_list = []  # "<radius> 1 <description>" per atom
    with open(pqr_fpath, 'r') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.startswith('ATOM'):
                continue
            # fixed-width record: 69 data columns + trailing newline
            require(len(line) == 70 and line[69] == '\n', line_no,
                    f'expected 69 columns + newline, got {len(line)} chars')
            atom_id = int(line[6:11])  # 1-based indexing
            require(line[11] == ' ', line_no, 'column 12 must be blank')
            atom_name = line[12:16].strip()
            require(bool(atom_name) and atom_name[0] in atom_type_dict,
                    line_no, f'unknown atom type in atom name {atom_name!r}')
            require(line[16] == ' ', line_no, 'column 17 must be blank')
            res_name = line[17:20]
            require(res_name in res_type_dict, line_no,
                    f'unknown residue name {res_name!r}')
            res_id = int(line[22:26].strip())  # 1-based indexing
            x = float(line[30:38])
            y = float(line[38:46])
            z = float(line[46:54])
            require(line[54] == ' ', line_no, 'column 55 must be blank')
            charge = float(line[55:62])
            require(line[62] == ' ', line_no, 'column 63 must be blank')
            radius = float(line[63:69])
            xyz_list.append([x, y, z])
            full_id = f'{res_name}_{res_id:d}_{atom_name}_{atom_id:d}_{charge:.4f}_{radius:.4f}'
            rn_list.append(str(radius) + ' 1 ' + full_id)

    return np.array(xyz_list, dtype=float), rn_list


# prepare input for MSMS surface computation
def convert_pqr_to_xyzrn(data_dir):
    """Convert ``ligand.pqr`` and ``receptor.pqr`` in *data_dir* to ``.xyzrn``.

    Writes ``ligand.xyzrn`` and ``receptor.xyzrn`` into *data_dir*. Does
    nothing (returns ``None``) unless both PQR files exist. Both inputs are
    parsed before anything is written, so a malformed receptor file does not
    leave a lone ``ligand.xyzrn`` behind.

    Args:
        data_dir: directory expected to contain ``ligand.pqr`` and
            ``receptor.pqr``.

    Raises:
        ValueError: propagated from ``parse_pqr_file`` on malformed input.
    """
    names = ('ligand', 'receptor')
    pqr_fpaths = [os.path.join(data_dir, f'{name}.pqr') for name in names]
    if not all(os.path.isfile(fpath) for fpath in pqr_fpaths):
        return

    # parse everything first, then write
    parsed = [parse_pqr_file(fpath) for fpath in pqr_fpaths]
    for name, (xyz, rn) in zip(names, parsed):
        _write_xyzrn(os.path.join(data_dir, f'{name}.xyzrn'), xyz, rn)


def _write_xyzrn(xyzrn_fpath, xyz, rn):
    """Write one ``"x y z <rn entry>"`` line per atom to *xyzrn_fpath*."""
    with open(xyzrn_fpath, 'w') as f:
        f.writelines(
            '{:.6f} {:.6f} {:.6f} '.format(*coords) + desc + '\n'
            for coords, desc in zip(xyz, rn)
        )


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Convert ligand/receptor PQR files to .xyzrn MSMS inputs.')
    parser.add_argument('--serial', action='store_true',
                        help='process entries one at a time, without a pool')
    parser.add_argument('-j', type=int, default=4,
                        help='number of worker processes (ignored with --serial)')
    parser.add_argument('--mute-tqdm', action='store_true',
                        help='disable tqdm progress bars')
    parser.add_argument('--data-dir', default='./DB5_mesh/',
                        help='directory holding one sub-directory per PDB id')
    args = parser.parse_args()
    print(args)

    # optionally mute tqdm: patch __init__ so every bar is created disabled
    if args.mute_tqdm:
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

    # DB5 IO
    db5_mesh_dir = args.data_dir
    if not os.path.isdir(db5_mesh_dir):
        parser.error(f'data directory not found: {db5_mesh_dir}')

    start = time.time()

    # sorted for a deterministic processing order; entries lacking either PQR
    # file are skipped inside convert_pqr_to_xyzrn
    pdb_dirs = [os.path.join(db5_mesh_dir, pdb_id)
                for pdb_id in sorted(os.listdir(db5_mesh_dir))]

    if not args.serial:
        # imap_unordered yields as each task completes, so the bar tracks
        # finished work; worker exceptions are re-raised here. Leaving the
        # ``with`` block terminates the pool.
        with multiprocessing.Pool(processes=args.j) as pool:
            results = pool.imap_unordered(convert_pqr_to_xyzrn, pdb_dirs)
            for _ in tqdm(results, total=len(pdb_dirs)):
                pass
        print('All processes successfully finished')
    else:
        for pdb_dir in tqdm(pdb_dirs):
            convert_pqr_to_xyzrn(pdb_dir)

    print(f'DB5 step2 elapsed time: {(time.time()-start):.1f}s\n')


