
import os, sys
import h5py
import numpy as np

import argparse



def extract_residue_types(aminoacids):
    """Return the residue-type label of every record in *aminoacids*.

    The label is element 0 of each record's first field, i.e. the value the
    per-record expression ``aminoacid[0][0]`` yields, read here with a single
    vectorised slice instead of a Python-level loop over the records.

    Args:
        aminoacids: 1-D structured array whose first field is a sub-array.

    Returns:
        1-D array with one label per record.
    """
    first_field = aminoacids.dtype.names[0]
    return aminoacids[first_field][:, 0]


def select_residue_types(residue_types_N, excluded=(b'Z',)):
    """Return the distinct residue types, in order of first appearance.

    Args:
        residue_types_N: 1-D array of per-record residue-type labels.
        excluded: labels to leave out. A label that does not occur in the
            data is simply ignored (no error is raised).

    Returns:
        List of labels (Python ``bytes`` for byte-string arrays).
    """
    distinct = dict.fromkeys(residue_types_N.tolist())
    return [label for label in distinct if label not in excluded]


def sample_balanced_indices(residue_types_N, residue_types, num_per_type, rng):
    """Draw ``num_per_type`` record indices, without replacement, per type.

    Every type is checked before anything is drawn so that a too-small class
    is reported up front with a clear message.

    Args:
        residue_types_N: 1-D array of per-record residue-type labels.
        residue_types: labels to sample, in the order they should be drawn.
        num_per_type: number of records to draw for each label.
        rng: ``numpy.random.Generator`` used for the draws.

    Returns:
        List with one index array per entry of ``residue_types``.

    Raises:
        ValueError: if any requested type has fewer than ``num_per_type``
            records.
    """
    all_idxs = np.arange(residue_types_N.shape[0])
    idxs_by_type = [all_idxs[residue_types_N == label] for label in residue_types]

    too_few = {
        label: len(idxs)
        for label, idxs in zip(residue_types, idxs_by_type)
        if len(idxs) < num_per_type
    }
    if too_few:
        raise ValueError(
            'Cannot sample %d examples without replacement; '
            'available examples per residue type: %r' % (num_per_type, too_few)
        )

    # One draw per type, in order, so a given seed always selects the same records.
    return [rng.choice(idxs, num_per_type, replace=False) for idxs in idxs_by_type]


def main():
    """Build a class-balanced subset of an amino-acid HDF5 dataset.

    Reads the records stored under ``--input_key`` in ``--input_hdf5``, draws
    ``--num_examples_per_residue_type`` records of every residue type except
    ``Z`` and writes them, grouped by type, under the same key in
    ``--output_hdf5``. The output path is passed through ``str.format`` with
    the total number of selected records, so it may contain a ``{}``
    placeholder.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_hdf5', type=str, required=True)
    parser.add_argument('--input_key', type=str, default='pdb_subsets/img=x-ray diffraction_max_res=2.5/split_0.8_0.2_0.0/train/pdbs')
    parser.add_argument('--output_hdf5', type=str, required=True)
    parser.add_argument('--num_examples_per_residue_type', type=int, default=4000)
    parser.add_argument('--seed', type=int, default=10000000)

    args = parser.parse_args()

    rng = np.random.default_rng(args.seed)

    with h5py.File(args.input_hdf5, 'r') as f:
        all_aminoacids_N = np.array(f[args.input_key])
    N = all_aminoacids_N.shape[0]
    print(N)

    residue_types_N = extract_residue_types(all_aminoacids_N)
    residue_types = select_residue_types(residue_types_N)
    size_of_toy_aminoacids_dataset = args.num_examples_per_residue_type * len(residue_types)
    args.output_hdf5 = args.output_hdf5.format(size_of_toy_aminoacids_dataset)

    # Sample before touching the output path: if a class is too small we fail
    # here instead of leaving a partially filled file behind.
    sampled_idxs = sample_balanced_indices(
        residue_types_N, residue_types, args.num_examples_per_residue_type, rng
    )

    max_atoms = 50
    dt = np.dtype([
        ('res_id', 'S6', (6,)),  # S5, 5 (old) ; S6, 6 (new with 2ndary structure)
        ('atom_names', 'S4', (max_atoms,)),
        ('elements', 'S1', (max_atoms,)),
        ('res_ids', 'S6', (max_atoms, 6)),  # S5, 5 (old) ; S6, 6 (new with 2ndary structure)
        ('coords', 'f8', (max_atoms, 3)),
        ('SASAs', 'f8', (max_atoms,)),
        ('charges', 'f8', (max_atoms,)),
    ])

    with h5py.File(args.output_hdf5, 'w') as f:
        dset = f.create_dataset(args.input_key,
                                shape=(size_of_toy_aminoacids_dataset,),
                                maxshape=(None,),
                                dtype=dt)
        n = 0
        for idxs in sampled_idxs:
            if len(idxs) == 0:
                continue
            # Convert field-by-field into the output dtype in memory, then
            # issue a single HDF5 write per residue type instead of one per record.
            batch = np.empty(len(idxs), dtype=dt)
            for i, aminoacid in enumerate(all_aminoacids_N[idxs]):
                batch[i] = (*aminoacid,)
            dset[n:n + len(batch)] = batch
            n += len(batch)
            print('\r%d/%d' % (n, size_of_toy_aminoacids_dataset), end='')
            sys.stdout.flush()
    print()


if __name__ == '__main__':
    main()
    


    


    




