#
# This file computes the atomic spherical coordinates in a given set of
# neighborhoods and outputs a file with these coordinates.
#
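# It takes as command-line arguments:
#  - --input_hdf5: the name of the input HDF5 file
#  - --output_hdf5: the name of the output HDF5 file
#  - --input_key: name of the dataset to read (the output dataset uses the same name)
#  - --parallelism: degree of parallelism used for processing
#  - --radius: the neighborhood radius
#  - --n_proteins: number of proteins to process, from the beginning
# Note: the "easy" flag to include the central residue is not exposed as a
# command-line option by the parser in this file.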
#

from pyrosetta_hdf5_amino_acids import get_neighborhoods_from_protein, pad_neighborhoods
from preprocessor_hdf5_proteins import PDBPreprocessor
from argparse import ArgumentParser
import numpy as np
import h5py
import sys
import logging
from progress.bar import Bar
import traceback

def callback(np_protein, r):
    """Build the padded neighborhoods for a single protein.

    Calls ``get_neighborhoods_from_protein`` with radius ``r`` and pads the
    result with ``pad_neighborhoods`` (``padded_length=50``).

    Parameters
    ----------
    np_protein
        Protein record handed over by the preprocessor; its first entry is
        printed to identify the protein when processing fails.
    r
        Neighborhood radius forwarded to ``get_neighborhoods_from_protein``.

    Returns
    -------
    The padded neighborhoods on success, or the sentinel ``(None,)`` when any
    exception was raised (the exception is printed, not re-raised).
    """
    try:
        raw_neighborhoods = get_neighborhoods_from_protein(np_protein, r=r)
        outcome = pad_neighborhoods(raw_neighborhoods, padded_length=50)
    except Exception as err:
        # Best effort: report the failing protein and hand back a sentinel
        # so the caller can skip it instead of aborting the whole run.
        print(err)
        print('Error with ', np_protein[0])
        outcome = (None,)

    return outcome


if __name__ == "__main__":
    # Script entry point: run `callback` over proteins read through
    # PDBPreprocessor and store every returned neighborhood as one row of a
    # structured dataset in the output HDF5 file. The output dataset is
    # created under the same key that is used to read the input.
    parser = ArgumentParser()
    parser.add_argument('--input_hdf5', type=str)
    # h5py cannot open a file without a name, so reject a missing output path
    # at argument-parsing time instead of failing later inside h5py.
    parser.add_argument('--output_hdf5', type=str, required=True)
    parser.add_argument('--input_key', type=str, default='pdb_subsets/img=x-ray diffraction_max_res=2.5/split_0.8_0.2_0.0/train/pdbs')
    parser.add_argument('--parallelism', dest='parallelism', type=int, default=40)
    parser.add_argument('--radius', type=float, default=10.0)
    parser.add_argument('--n_proteins', type=int, default=700, # 11227 training proteins
                        help='number of proteins to collect amino-acids from, from the beginning')

    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG)
    ds = PDBPreprocessor(args.input_hdf5, args.input_key)

    # Number of atom slots per neighborhood row. `callback` pads with
    # padded_length=50, so the two values have to stay in sync.
    max_atoms = 50
    dt = np.dtype([
        ('res_id','S6', (6)), # S5, 5 (old) ; S6, 6 (new with 2ndary structure)
        ('atom_names', 'S4', (max_atoms)),
        ('elements', 'S1', (max_atoms)),
        ('res_ids', 'S6', (max_atoms, 6)), # S5, 5 (old) ; S6, 6 (new with 2ndary structure)
        ('coords', 'f8', (max_atoms, 3)),
        ('SASAs', 'f8', (max_atoms)),
        ('charges', 'f8', (max_atoms)),
    ])
    print(dt)
    print('writing hdf5 file')

    grow_by = 1000        # rows added each time the dataset fills up
    curr_size = grow_by   # rows currently allocated in the dataset
    n = 0                 # rows actually written so far

    # Open the output once and keep the dataset handle, instead of closing,
    # re-opening and looking the dataset up by path for every written row.
    with h5py.File(args.output_hdf5, 'w') as f:
        # Resizable 1-D dataset: maxshape=(None,) allows unlimited growth.
        dset = f.create_dataset(args.input_key,
                                shape=(curr_size,),
                                maxshape=(None,),
                                dtype=dt)

        print('calling parallel process')
        try:
            # NOTE(review): the bar maximum is ds.count() while processing is
            # limited to args.n_proteins; confirm these agree.
            with Bar('Processing', max = ds.count(), suffix='%(percent).1f%%') as bar:
                for i, neighborhoods in enumerate(ds.execute(callback,
                                                            limit = args.n_proteins,
                                                            params = {'r': args.radius},
                                                            parallelism = args.parallelism)):
                    print(i)

                    # Skip proteins that failed (callback returns (None,)) and
                    # proteins with no neighborhoods at all; indexing [0] on an
                    # empty result would otherwise raise IndexError.
                    if (neighborhoods is None
                            or len(neighborhoods) == 0
                            or neighborhoods[0] is None):
                        bar.next()
                        continue

                    for neighborhood in neighborhoods:

                        # Grow the dataset in fixed-size steps when it is full.
                        if n == curr_size:
                            curr_size += grow_by
                            dset.resize((curr_size,))

                        # Unpack the record into a plain tuple matching `dt`.
                        dset[n] = (*neighborhood,)
                        n += 1

                    bar.next()
        finally:
            # Trim the dataset so it holds exactly the rows written and nothing
            # more; done in `finally` so an interrupted run does not leave
            # unwritten padding rows in the output file.
            dset.resize((n,))

    print('Done with parallel computing')
