import os
import time
import argparse
import numpy as np
from tqdm import tqdm
import multiprocessing
from functools import partialmethod
from geomlib.trimesh import TriMesh
from core.chemistry import atom_type_dict, res_type_dict

# read xyzrn file for atomic features
def read_xyzrn_file(data_dir, suffix):
    """Parse ``<data_dir>/<suffix>.xyzrn`` into a float array of atom features.

    Every line must hold six whitespace-separated fields. The last field is an
    underscore-separated identifier with six parts:
    ``resname_resid_atomname_atomid_charge_radius``.

    Each output row contains, in order:
      * the first three fields of the line (presumably x, y, z -- confirm
        against the xyzrn writer),
      * ``res_type_dict[res_name]``,
      * ``atom_type_dict`` looked up by the first character of the atom name,
      * charge and radius parsed from the identifier,
      * a flag that is 1.0 when the atom name is ``CA`` (case-insensitive).

    Raises AssertionError if the file is missing, a line is malformed, or the
    residue / atom type is not present in the lookup tables.
    """
    path = os.path.join(data_dir, '{}.xyzrn'.format(suffix))
    assert os.path.isfile(path)

    rows = []
    with open(path, 'r') as handle:
        for raw_line in handle:
            tokens = raw_line.rstrip().split()
            assert len(tokens) == 6

            # the per-atom identifier is the last column of the line
            id_parts = tokens[-1].split('_')
            assert len(id_parts) == 6
            res_name, _res_id, atom_name, _atom_id, charge, radius = id_parts

            assert res_name in res_type_dict
            element = atom_name[0]
            assert element in atom_type_dict

            is_alpha_carbon = atom_name.upper() == 'CA'
            rows.append([
                *tokens[:3],
                res_type_dict[res_name],
                atom_type_dict[element],
                float(charge),
                float(radius),
                is_alpha_carbon,
            ])

    # the leading string fields and the boolean flag are all coerced to float
    return np.array(rows, dtype=float)


def _load_mesh_with_lb_basis(mesh_fpath, min_num_eigs, eigs_ratio):
    """Load one surface mesh and compute its Laplace-Beltrami eigenbasis.

    Returns ``(verts, faces, trimesh)`` where ``trimesh`` is the ``TriMesh``
    on which ``LB_decomposition`` has already been run.
    """
    # np.load on an .npz returns a lazy archive that holds the file open and
    # re-reads an array on every key lookup: read each array once, then close.
    with np.load(mesh_fpath) as mesh:
        verts = mesh['verts']
        faces = mesh['faces']

    num_verts = verts.shape[0]
    num_eigs = max(min_num_eigs, int(eigs_ratio * num_verts) + 1)
    # scipy eigsh must have k < N
    assert num_eigs < num_verts, (
        f'{mesh_fpath}: {num_eigs} eigenpairs requested but the mesh has '
        f'only {num_verts} vertices'
    )

    trimesh = TriMesh(verts=verts, faces=faces)
    trimesh.LB_decomposition(k=num_eigs)
    return verts, faces, trimesh


def _prefixed_features(prefix, atom_info, verts, faces, trimesh):
    """Build the float32 arrays stored for one molecule, keyed ``<prefix>_*``."""
    return {
        f'{prefix}_atom_info': atom_info.astype(np.float32),
        f'{prefix}_verts': verts.astype(np.float32),
        # NOTE(review): faces are presumably integer vertex indices; they are
        # still stored as float32 so existing readers of the dataset keep
        # working -- confirm before changing the dtype.
        f'{prefix}_faces': faces.astype(np.float32),
        f'{prefix}_eigen_vals': trimesh.eigen_vals.astype(np.float32),
        f'{prefix}_eigen_vecs': trimesh.eigen_vecs.astype(np.float32),
        f'{prefix}_mass': trimesh.mass.astype(np.float32),
    }


def gen_dataset(data_root, out_dir, pdb_id, min_num_eigs, eigs_ratio):
    """Build ``<out_dir>/<pdb_id>.npz`` from the meshes in ``<data_root>/<pdb_id>``.

    The entry is skipped (returns None without writing) when the output file
    already exists or when either ``ligand_mesh.npz`` or ``receptor_mesh.npz``
    is missing. Otherwise the atom features, mesh arrays and Laplace-Beltrami
    eigen decomposition of ligand and receptor are saved as float32 arrays
    under ``lig_*`` / ``rec_*`` keys.

    Args:
        data_root: directory containing one sub-directory per complex.
        out_dir: directory receiving the ``.npz`` files.
        pdb_id: name of the sub-directory / output file stem.
        min_num_eigs: lower bound on the number of eigenpairs.
        eigs_ratio: eigenpairs requested per mesh vertex; the count used is
            ``max(min_num_eigs, int(eigs_ratio * num_verts) + 1)``.

    Raises:
        AssertionError: if an xyzrn file is missing/malformed or a mesh has
            too few vertices for the requested number of eigenpairs.
    """
    # IO
    fout = os.path.join(out_dir, f'{pdb_id}.npz')
    if os.path.isfile(fout):
        return

    data_dir = os.path.join(data_root, pdb_id)
    lig_mesh_fpath = os.path.join(data_dir, 'ligand_mesh.npz')
    rec_mesh_fpath = os.path.join(data_dir, 'receptor_mesh.npz')
    if not (os.path.isfile(lig_mesh_fpath) and os.path.isfile(rec_mesh_fpath)):
        return

    # atomic features
    lig_atom_info = read_xyzrn_file(data_dir, suffix='ligand')
    rec_atom_info = read_xyzrn_file(data_dir, suffix='receptor')

    # surface meshes and their Laplace-Beltrami bases
    lig_verts, lig_faces, lig_trimesh = _load_mesh_with_lb_basis(
        lig_mesh_fpath, min_num_eigs, eigs_ratio)
    rec_verts, rec_faces, rec_trimesh = _load_mesh_with_lb_basis(
        rec_mesh_fpath, min_num_eigs, eigs_ratio)

    features = {}
    features.update(_prefixed_features(
        'lig', lig_atom_info, lig_verts, lig_faces, lig_trimesh))
    features.update(_prefixed_features(
        'rec', rec_atom_info, rec_verts, rec_faces, rec_trimesh))

    # Write to a private temp file and rename it into place: an interrupted
    # run must not leave a truncated .npz behind, because the existence check
    # above would then skip this entry on every later run.
    tmp_fout = f'{fout}.{os.getpid()}.tmp'
    try:
        # passing a file object stops numpy from appending '.npz' to the name
        with open(tmp_fout, 'wb') as fh:
            np.savez_compressed(fh, **features)
        os.replace(tmp_fout, fout)
    finally:
        # only present when saving or renaming failed
        if os.path.isfile(tmp_fout):
            os.remove(tmp_fout)


if __name__ == '__main__':
    # Script entry point: build one dataset file per entry of ./DB5_mesh/,
    # either in a worker pool (default) or serially (--serial).
    parser = argparse.ArgumentParser()
    parser.add_argument('--eigs-ratio', type=float, default=0.08)
    parser.add_argument('--min-num-eigs', type=int, default=100)
    parser.add_argument('--serial', action='store_true')
    parser.add_argument('-j', type=int, default=4)
    parser.add_argument('--mute-tqdm', action='store_true')
    args = parser.parse_args()
    print(args)

    # optionally mute tqdm
    if args.mute_tqdm:
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

    # DB5 IO
    db5_mesh_dir = './DB5_mesh/'
    # explicit check instead of assert, which is stripped under `python -O`
    if not os.path.isdir(db5_mesh_dir):
        raise FileNotFoundError(f'DB5 mesh directory not found: {db5_mesh_dir}')
    db5_dataset_dir = './dataset_DB5/'
    os.makedirs(db5_dataset_dir, exist_ok=True)

    # os.listdir order is arbitrary; sort for reproducible runs and logs
    pdb_ids = sorted(os.listdir(db5_mesh_dir))

    start = time.time()

    if not args.serial:
        failures = []
        # the context manager tears the pool down even if an error escapes
        with multiprocessing.Pool(processes=args.j) as pool:
            # One task per complex. Waiting on the individual results makes
            # the progress bar follow completed work; wrapping the argument
            # list handed to starmap only tracked how fast tasks were
            # submitted to the pool.
            pending = [
                (pdb_id,
                 pool.apply_async(
                     gen_dataset,
                     (db5_mesh_dir, db5_dataset_dir, pdb_id,
                      args.min_num_eigs, args.eigs_ratio)))
                for pdb_id in pdb_ids
            ]
            for pdb_id, result in tqdm(pending):
                try:
                    result.get()
                except Exception as err:
                    # keep draining so the remaining workers finish their
                    # output before the pool is terminated
                    failures.append((pdb_id, err))

        if failures:
            for pdb_id, err in failures:
                print(f'{pdb_id} failed: {err!r}')
            raise RuntimeError(
                f'{len(failures)} of {len(pdb_ids)} entries failed'
            ) from failures[0][1]
        print('All processes successfully finished')
    else:
        for pdb_id in tqdm(pdb_ids):
            gen_dataset(db5_mesh_dir, db5_dataset_dir, pdb_id, args.min_num_eigs, args.eigs_ratio)

    print(f'DB5 step5 elapsed time: {(time.time()-start):.1f}s\n')


