import argparse
import multiprocessing
import os
import random
import shutil
import time
from functools import partial, partialmethod

import pandas as pd
from tqdm import tqdm

def parse_df_to_pdb(df, out_fpath):
    """Write the non-hydrogen atoms of an atom table as fixed-column PDB ATOM records.

    Rows whose element is ``H`` or whose (stripped) atom name starts with
    ``H`` are skipped. Every written record is 78 columns plus a newline.

    Args:
        df: DataFrame with columns ``aid``, ``atom_name``, ``resname``,
            ``chain``, ``residue``, ``x``, ``y``, ``z`` and ``element``.
            ``atom_name`` and ``residue`` are expected to be strings.
        out_fpath: Path of the PDB file to (over)write.

    Raises:
        ValueError: if a residue name is not 3 characters, a chain id is not
            1 character, or a formatted record overflows its columns (e.g. an
            atom name longer than 4 characters or an out-of-range coordinate).
            In every case nothing is written to ``out_fpath``.
    """
    # Validate the fixed-width fields for every row (hydrogens included)
    # before touching the output file.
    for res_name in df['resname']:
        if len(res_name) != 3:
            raise ValueError(f'{out_fpath}: residue name {res_name!r} is not 3 characters')
    for chain_id in df['chain']:
        if len(chain_id) != 1:
            raise ValueError(f'{out_fpath}: chain id {chain_id!r} is not 1 character')

    lines = []
    rows = zip(df['aid'], df['atom_name'], df['resname'], df['chain'],
               df['residue'], df['x'], df['y'], df['z'], df['element'])
    for aid, atom_name, res_name, chain_id, res_id, x, y, z, elem in rows:
        name = atom_name.strip()
        elem = str(elem).rjust(2)
        # startswith() also tolerates a blank atom name (indexing [0] would not).
        if elem.strip() == 'H' or name.startswith('H'):
            continue
        # Columns 13-16: names shorter than 4 chars start at column 14,
        # 4-char names occupy all four columns (PDB convention).
        name_field = name if len(name) == 4 else ' ' + name.ljust(3)
        line = (
            'ATOM  '                            # 1-6 record name
            + str(aid).rjust(5)                 # 7-11 atom serial number
            + ' '                               # 12
            + name_field                        # 13-16 atom name
            + ' '                               # 17 altLoc (blank)
            + res_name                          # 18-20 residue name
            + ' '                               # 21
            + chain_id                          # 22 chain identifier
            + res_id.strip()[:4].rjust(4)       # 23-26 residue sequence number
            + '    '                            # 27-30 (insertion code blank)
            + f'{x:.3f}'.rjust(8)               # 31-38 x
            + f'{y:.3f}'.rjust(8)               # 39-46 y
            + f'{z:.3f}'.rjust(8)               # 47-54 z
            + '  1.00'                          # 55-60 occupancy
            + ' ' * 16                          # 61-76 (B-factor etc. blank)
            + elem                              # 77-78 element symbol
            + '\n'
        )
        if len(line) != 79:
            raise ValueError(f'{out_fpath}: malformed PDB record {line!r}')
        lines.append(line)

    # Build every record first so an error never leaves a truncated file.
    with open(out_fpath, 'w') as f:
        f.writelines(lines)


def convert_dill_to_pdb(fpath, out_root):
    """Convert one DIPS pair file into ``ligand.pdb`` and ``receptor.pdb``.

    The file name must look like ``<pdb_id>.<pair_name>.dill``, where
    ``pdb_id`` has 4 characters and ``pair_name`` contains exactly one
    underscore (e.g. ``1abc.pdb1_0.dill``). Both PDB files are written into
    a new directory ``<out_root>/<pdb_id>_<pair_name>``. The component with
    fewer atom rows (hydrogens included) is labelled the ligand.

    Args:
        fpath: Path to the pickled pair. It is loaded with ``pd.read_pickle``
            and must expose ``df0`` and ``df1`` atom DataFrames.
        out_root: Directory under which the per-pair directory is created.

    Raises:
        ValueError: if the file name does not follow the pattern above.
        FileExistsError: if the per-pair output directory already exists.
    """
    # basename handles the platform's path separator(s); splitting on '/'
    # breaks on Windows, where os.path.join inserts '\\'.
    fname = os.path.basename(fpath)
    fname_parsed = fname.split('.')
    if len(fname_parsed) != 3:
        raise ValueError(f'unexpected DIPS file name: {fname!r}')
    pdb_id, pair_name, _ = fname_parsed
    if len(pdb_id) != 4 or len(pair_name.split('_')) != 2:
        raise ValueError(f'unexpected DIPS file name: {fname!r}')

    # NOTE(review): read_pickle unpickles arbitrary objects; only run this
    # on trusted dataset files.
    data = pd.read_pickle(fpath)

    # Rename the source columns to the names parse_df_to_pdb reads.
    columns = {'chain_id': 'chain', 'residue_number': 'residue', 'residue_name': 'resname',
               'x_coord': 'x', 'y_coord': 'y', 'z_coord': 'z', 'element_symbol': 'element'}
    df0 = data.df0.rename(columns=columns)
    df1 = data.df1.rename(columns=columns)

    # We label the smaller component (by row count) as the ligand.
    if len(df0) < len(df1):
        lig_df, rec_df = df0, df1
    else:
        lig_df, rec_df = df1, df0

    # Create the output dir only after loading succeeds, so a bad input
    # does not leave an empty directory behind.
    out_dir = os.path.join(out_root, f'{pdb_id}_{pair_name}')
    os.makedirs(out_dir, exist_ok=False)
    parse_df_to_pdb(lig_df, os.path.join(out_dir, 'ligand.pdb'))
    parse_df_to_pdb(rec_df, os.path.join(out_dir, 'receptor.pdb'))
    

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--serial', action='store_true')
    parser.add_argument('-j', type=int, default=4)
    parser.add_argument('--mute-tqdm', action='store_true')
    parser.add_argument('--debug', default=False, action='store_true')
    parser.add_argument('-N', type=int, default=1000)
    args = parser.parse_args()
    print(args)

    # optionally mute tqdm
    if args.mute_tqdm:
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

    # filter: keep entries with < 5 original chains and < 3 pruned pairs
    filter_file = './DIPS_pair_summary.csv'
    if not os.path.isfile(filter_file):
        raise FileNotFoundError(filter_file)
    df = pd.read_csv(filter_file, header=0, delimiter=',', index_col=None)
    df_refined = df.loc[(df['original_chains'] < 5) & (df['pruned_pairs'] < 3)]
    # A set gives O(1) membership tests in the directory scan below.
    valid_pdb_ids = set(df_refined['dips_id'].str.lower())

    # DIPS
    DIPS_data_dir = './DIPS_pairs_pruned/'
    if not os.path.exists(DIPS_data_dir):
        raise FileNotFoundError(DIPS_data_dir)
    DIPS_mesh_dir = './DIPS_mesh/'
    # Start from an empty output directory.
    if os.path.exists(DIPS_mesh_dir):
        shutil.rmtree(DIPS_mesh_dir)
    os.makedirs(DIPS_mesh_dir, exist_ok=False)

    DIPS_fpath_list = []
    for subdir in os.listdir(DIPS_data_dir):
        subdir_path = os.path.join(DIPS_data_dir, subdir)
        if not os.path.isdir(subdir_path):
            continue  # ignore stray files at the top level of the data dir
        for fname in os.listdir(subdir_path):
            if not fname.endswith('.dill'):
                raise ValueError(f'unexpected non-.dill file: {os.path.join(subdir_path, fname)}')
            # Keep files whose id is in the filter and whose pair name
            # (between the first '.' and the last '_') is 'pdb1'.
            if (fname.lower().split('_')[0] in valid_pdb_ids) and \
               (fname[fname.find('.')+1:fname.rfind('_')] == 'pdb1'):
                DIPS_fpath_list.append(os.path.join(subdir_path, fname))

    if args.debug:
        # random.sample raises ValueError when asked for more items than exist.
        DIPS_fpath_list = random.sample(DIPS_fpath_list, min(args.N, len(DIPS_fpath_list)))

    print(f'length of DIPS file path list: {len(DIPS_fpath_list)}')

    # DIPS timer
    start = time.time()

    if not args.serial:
        worker = partial(convert_dill_to_pdb, out_root=DIPS_mesh_dir)
        with multiprocessing.Pool(processes=args.j) as pool:
            # imap_unordered yields as each file completes, so the bar tracks
            # real progress (starmap consumed the tqdm iterable up front).
            # Exhausting the iterator also re-raises any worker exception here.
            results = pool.imap_unordered(worker, DIPS_fpath_list, chunksize=10)
            for _ in tqdm(results, total=len(DIPS_fpath_list)):
                pass
        print('All processes successfully finished')
    else:
        for fpath in tqdm(DIPS_fpath_list):
            convert_dill_to_pdb(fpath, DIPS_mesh_dir)

    print(f'step1 DIPS elapsed time: {(time.time()-start):.2f}s\n')
    

