"""Script for processing PDB files to remove their padding residues."""

import argparse
import functools as fn
import os
import multiprocessing as mp
import time
import glob
from Bio.PDB import PDBParser
from Bio.PDB.PDBIO import PDBIO

# Three-letter residue names that process_pdb keeps; any residue whose name
# is not listed here is detached from the chain before writing.
ALL_RESTYPES = {
    'ALA', 'ARG', 'ASN', 'ASP', 'CYS',
    'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
    'LEU', 'LYS', 'MET', 'PHE', 'PRO',
    'SER', 'THR', 'TRP', 'TYR', 'VAL',
}


# Command-line interface: two positional directories plus optional flags.
parser = argparse.ArgumentParser(description='PDB reduce script.')
parser.add_argument(
    'src_dir', type=str, help='Directory with source PDB files.')
parser.add_argument(
    'dest_dir', type=str, help='Directory to save processed PDB files.')
parser.add_argument(
    '--num_processes', type=int, default=20, help='Number of processes.')
parser.add_argument(
    '--verbose', action='store_true', help='Whether to log everything.')

class DataError(Exception):
    """Raised when an input PDB file cannot be processed as expected."""

def process_pdb(pdb_path: str, dest_dir: str) -> None:
    """Write a copy of a single-chain PDB file keeping only standard residues.

    The file is parsed, every residue whose name is not in ``ALL_RESTYPES``
    is detached from its chain, and the remaining chain is saved under
    ``dest_dir`` using the same file name as the input.

    Args:
        pdb_path: Path of the PDB file to read.
        dest_dir: Directory the filtered file is written to.

    Raises:
        DataError: If the parsed structure does not contain exactly one chain.
    """
    # Named ``pdb_parser`` so it does not shadow the module-level argparse
    # ``parser``.
    pdb_parser = PDBParser()
    pdb_file = os.path.basename(pdb_path)
    # Strip only a trailing '.pdb' so names containing '.pdb' elsewhere are
    # left intact in the structure id.
    suffix = '.pdb'
    if pdb_file.endswith(suffix):
        structure_id = pdb_file[:-len(suffix)]
    else:
        structure_id = pdb_file
    structure = pdb_parser.get_structure(structure_id, pdb_path)

    chains = list(structure.get_chains())
    # Report an empty structure as a DataError instead of letting
    # ``chains[0]`` raise an IndexError that callers do not expect.
    if not chains:
        raise DataError(f'No chains found for {pdb_path}')
    if len(chains) > 1:
        raise DataError(f'Multiple chains found for {pdb_path}')
    monomer_chain = chains[0]

    # Snapshot the residues first: detaching while iterating the chain's own
    # generator would mutate the collection being traversed.
    for res in list(monomer_chain.get_residues()):
        if res.get_resname() not in ALL_RESTYPES:
            monomer_chain.detach_child(res.id)

    io = PDBIO()
    io.set_structure(monomer_chain)
    io.save(os.path.join(dest_dir, pdb_file))

def process_fn(pdb_path, dest_dir=None, verbose=None):
    """Run ``process_pdb`` on one file and report whether it succeeded.

    Args:
        pdb_path: Path of the PDB file to process.
        dest_dir: Directory the processed file is written to.
        verbose: When truthy, print a line describing the outcome.

    Returns:
        True if the file was processed, False if ``process_pdb`` raised
        ``DataError``. Any other exception propagates to the caller.
    """
    began = time.time()
    try:
        process_pdb(pdb_path, dest_dir)
    except DataError as err:
        if verbose:
            print(f'Failed {pdb_path}: {err}')
        return False

    duration = time.time() - began
    if verbose:
        print(f'Finished {pdb_path} in {duration:2.2f}s')
    return True

def main(args):
    """Process every file in ``args.src_dir`` with a pool of workers.

    Args:
        args: Parsed command-line namespace providing ``src_dir``,
            ``dest_dir``, ``num_processes`` and ``verbose``.
    """
    # Get all PDB files to read. Sub-directories matched by the wildcard are
    # skipped: passing one to a worker raises an error other than DataError,
    # which pool.map re-raises and which would abort the whole run.
    all_pdb_paths = [
        path for path in glob.glob(os.path.join(args.src_dir, '*'))
        if os.path.isfile(path)
    ]
    total_num_paths = len(all_pdb_paths)
    dest_dir = args.dest_dir
    # exist_ok avoids the race between checking for the directory and
    # creating it.
    os.makedirs(dest_dir, exist_ok=True)
    print(f'Files will be written to {dest_dir}')

    # Process each PDB file; each worker returns True on success.
    _process_fn = fn.partial(
        process_fn,
        verbose=args.verbose,
        dest_dir=dest_dir)
    with mp.Pool(processes=args.num_processes) as pool:
        worker_callbacks = pool.map(_process_fn, all_pdb_paths)
    succeeded = sum(worker_callbacks)
    print(
        f'Finished processing {succeeded}/{total_num_paths} files')


if __name__ == "__main__":
    # Parse the command line and hand the resulting namespace to main().
    main(parser.parse_args())
