import argparse
import mdtraj
import os
import tqdm
import pandas as pd 
from multiprocessing import Pool
import numpy as np
from SDE_model import residue_constants as rc
import Bio.PDB as PDB


# Command-line options as (flag, type, default) triples; each becomes an
# attribute of ``args`` named after the flag without its leading dashes.
_CLI_OPTIONS = (
    ('--split', str, '/mdcath/test.csv'),
    ('--sim_dir', str, '/to_xtc_path'),
    ('--outdir', str, '/saved_numpy_path'),
    ('--num_workers', int, 1),
    ('--suffix', str, 'i40'),
    ('--stride', int, 1),
)

parser = argparse.ArgumentParser()
for _flag, _type, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default)
args = parser.parse_args()

# The output directory must exist before any array is saved into it.
os.makedirs(args.outdir, exist_ok=True)

# The split CSV is indexed by its 'name' column; the index values are the
# names iterated by main().
df = pd.read_csv(args.split, index_col='name')
names = df.index

def get_pdb_residue_info(pdb_path):
    """Map the residue numbers of a PDB file's amino acids to dense array slots.

    Args:
        pdb_path: Path of the PDB file to parse.

    Returns:
        A dict ``{residue_number: (slot, resname)}`` where ``slot`` is a
        0-based index assigned in the order residues are first encountered,
        so slots always cover ``range(len(result))``. Returns ``None`` when
        the file cannot be read or parsed (the error is printed).

    Note:
        Residues are keyed by residue number only. When the same number occurs
        more than once (another model, another chain, or an insertion code),
        only the first occurrence is kept.
    """
    pdb_parser = PDB.PDBParser(QUIET=True)
    try:
        structure = pdb_parser.get_structure("protein", pdb_path)
        valid_residues = {}
        for model in structure:
            for chain in model:
                for residue in chain:
                    if not PDB.is_aa(residue):
                        continue
                    resseq = residue.id[1]
                    # Overwriting an existing key would store slot
                    # len(valid_residues) without growing the dict, giving a
                    # duplicate or out-of-range slot. Keep the first hit.
                    if resseq in valid_residues:
                        continue
                    valid_residues[resseq] = (len(valid_residues), residue.resname)
        return valid_residues
    except Exception as e:
        # Best-effort batch tool: report and let the caller skip this protein.
        print(f"error when reading pdb: {str(e)}")
        return None

def traj_to_atom14(traj, name):
    """Convert a trajectory into a per-residue, 14-slot coordinate array.

    Args:
        traj: Trajectory exposing ``n_frames``, ``xyz`` and ``top.residues``
            (an mdtraj trajectory in this script).
        name: Base name of the topology; ``{args.sim_dir}/{name}.pdb`` is
            parsed to decide which residues get a row in the output.

    Returns:
        A ``float16`` array of shape ``(n_frames, n_valid_residues, 14, 3)``
        holding ``traj.xyz * 10`` (presumably nm -> Angstrom, given mdtraj's
        nanometre convention) in the slot given by
        ``rc.restype_name_to_atom14_names``. Slots with no matching atom stay
        zero. Returns ``None`` if the PDB could not be parsed or contains no
        amino-acid residues.
    """
    valid_residues = get_pdb_residue_info(f'{args.sim_dir}/{name}.pdb')
    if valid_residues is None:
        return None

    if not valid_residues:
        print(f"error：protein{name}.pdb does not include valid residual")
        return None

    processing_log = []
    arr = np.zeros((traj.n_frames, len(valid_residues), 14, 3), dtype=np.float16)

    for resi in traj.top.residues:
        if resi.resSeq not in valid_residues:
            processing_log.append(f"ignore residual: {resi.name}{resi.resSeq} - not in valid sequence")
            continue

        valid_idx = valid_residues[resi.resSeq][0]

        if resi.name not in rc.restype_name_to_atom14_names:
            processing_log.append(f"error：unknow residual type {resi.name}{resi.resSeq}")
            continue

        atom14_names = rc.restype_name_to_atom14_names[resi.name]
        # Unused slots are presumably padded with '' (AlphaFold convention --
        # confirm against SDE_model.residue_constants). Padding is not a real
        # atom, so it must not be reported as missing below.
        expected_atoms = {atom_name for atom_name in atom14_names if atom_name}
        atoms_found = set()

        for at in resi.atoms:
            if at.name not in expected_atoms:
                continue
            j = atom14_names.index(at.name)
            arr[:, valid_idx, j] = traj.xyz[:, at.index] * 10.0
            atoms_found.add(at.name)

        missing_atoms = expected_atoms - atoms_found
        if missing_atoms:
            processing_log.append(f"resi {resi.name}{resi.resSeq} lack atom: {missing_atoms}")

    # Emit the per-protein log in one go so parallel workers interleave less.
    if processing_log:
        print(f"\n process {name} log:")
        for log in processing_log:
            print(log)

    return arr

def process_normal_trajectory(name):
    """Convert one simulation to an atom14 array and save it as ``.npy``.

    Reads ``{args.sim_dir}/{name}_0.xtc`` with topology
    ``{args.sim_dir}/{name}_0.pdb``, drops hydrogens, superposes the
    trajectory onto itself, converts it with ``traj_to_atom14`` and writes
    ``{args.outdir}/{name}_0.npy``. If that output already exists the
    simulation is skipped.

    Args:
        name: Simulation name without the ``_0`` suffix.

    Returns:
        None. Any exception is caught and reported on stdout so a single bad
        input does not stop a batch run.
    """
    try:
        # Input and output files all carry the "_0" suffix, so the
        # already-done check has to look for the suffixed output file.
        name = name + '_0'
        out_path = f'{args.outdir}/{name}.npy'
        if os.path.exists(out_path):
            print(f"jump {name}: file already exist")
            return

        # Honour --stride (default 1 keeps every frame).
        traj = mdtraj.load(
            f'{args.sim_dir}/{name}.xtc',
            top=f'{args.sim_dir}/{name}.pdb',
            stride=args.stride,
        )

        # Keep heavy atoms only.
        traj.atom_slice([a.index for a in traj.top.atoms if a.element.symbol != 'H'], inplace=True)

        traj.superpose(traj)

        arr = traj_to_atom14(traj, name)

        if arr is not None:
            np.save(out_path, arr)
            print(f"success{name}")

    except Exception as e:
        print(f"process{name} found error: {str(e)}")

def do_job(name):
    """Process a single simulation name.

    This is the callable main() hands to ``Pool.imap`` and also calls directly
    in its serial loop; it simply forwards to ``process_normal_trajectory``.
    """
    process_normal_trajectory(name)

def main():
    """Process every simulation in the split whose output is still missing.

    Builds the job list from the module-level ``names``, then runs ``do_job``
    on each entry, in a process pool when ``args.num_workers > 1`` and
    serially otherwise, with a tqdm progress bar.
    """
    # process_normal_trajectory writes "<name>_0.npy", so that is the file
    # whose existence marks a simulation as already done.
    jobs = [name for name in names if not os.path.exists(f'{args.outdir}/{name}_0.npy')]

    if not jobs:
        print("no file need processing")
        return

    print(f"find {len(jobs)} files")

    if args.num_workers > 1:
        with Pool(args.num_workers) as p:
            # list(...) drains the lazy imap iterator so every job runs
            # before the pool is torn down.
            list(tqdm.tqdm(p.imap(do_job, jobs), total=len(jobs)))
    else:
        for job in tqdm.tqdm(jobs):
            do_job(job)

    print("done")

if __name__ == "__main__":
    main()