
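'''
Computes per-neighborhood statistics (counts of selected atom types and average SASA)
from raw HDF5 data and saves them to a gzipped pickle, as dicts keyed by stringified data_ids.

The extraction of arrays parallel to a given list of data_ids (step 3) is currently
commented out at the bottom of the file.
'''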

from ctypes.wintypes import ATOM
import os, sys
from re import A
import h5py
import gzip, pickle
import numpy as np
from tqdm import tqdm
import argparse

sys.path.append('../../..')
from utils.argparse_utils import *

# Display label for every atom name that gets counted (atom names are 4-byte,
# space-padded byte strings; the CA label is LaTeX markup).
ATOM_TYPES_TO_GET_STRINGNAMES = {
    b' CA ': '$\\text{C}_{\\alpha}$',
    b' N  ': 'N',
    b' C  ': 'C',
    b' O  ': 'O',
}
# The same atom names as an ordered list (dict insertion order) and as a set.
ATOM_TYPES_TO_GET = list(ATOM_TYPES_TO_GET_STRINGNAMES)
ATOM_TYPES_TO_GET_SET = set(ATOM_TYPES_TO_GET_STRINGNAMES)

# Positions of the atom-type and SASA fields inside one neighborhood record.
IDX_OF_ATOM_TYPES, IDX_OF_SASA = 1, 5

def stringify(data_id):
    '''
    Builds a single string key out of a data_id.

    Every element of `data_id` is decoded as UTF-8 (so each element must
    provide a `.decode` method, e.g. bytes) and the decoded pieces are
    concatenated with underscores, in order.
    '''
    decoded_fields = (field.decode('utf-8') for field in data_id)
    return '_'.join(decoded_fields)

if __name__ == '__main__':
    # Script entry point: read neighborhoods from an HDF5 dataset, compute
    # per-neighborhood statistics, and dump them as a gzipped pickle.
    parser = argparse.ArgumentParser()
    # Both paths are mandatory. Without `required=True` a missing --output_file
    # would only fail inside gzip.open() after every neighborhood had already
    # been processed, and a missing --raw_hdf5_data would fail inside h5py.
    parser.add_argument('--raw_hdf5_data', type=str, required=True)
    parser.add_argument('--hdf5_key', type=str, default='pdb_subsets/img=x-ray diffraction_max_res=2.5/split_0.8_0.2_0.0/val/pdbs') # pdb_subsets/img=x-ray diffraction_max_res=2.5/split_0.8_0.2_0.0/{train-val}/pdbs, pdb_list
    parser.add_argument('--output_file', type=str, required=True)

    args = parser.parse_args()

    # Load the whole dataset into memory; np.unique(axis=0) drops records that
    # are exact duplicates of one another (and returns them sorted).
    with h5py.File(args.raw_hdf5_data, 'r') as f:
        raw_data = np.unique(np.array(f[args.hdf5_key]), axis=0)

    # Statistics are stored directly in dicts indexed by stringified data_ids.
    # If two records share a data_id, the later record overwrites the earlier.
    num_atoms_dict = {}
    avg_sasa_dict = {}
    num_atoms_of_type_dict = {atom_type: {} for atom_type in ATOM_TYPES_TO_GET}

    for neighborhood in raw_data:
        # Field 0 of a record is its data_id.
        data_id = stringify(neighborhood[0])
        atom_types = neighborhood[IDX_OF_ATOM_TYPES]

        # Count each atom type of interest, and their total.
        num_atoms_of_interest = 0
        for atom_type in ATOM_TYPES_TO_GET:
            num_atoms_of_type = np.sum(atom_types == atom_type)
            num_atoms_of_type_dict[atom_type][data_id] = num_atoms_of_type
            num_atoms_of_interest += num_atoms_of_type
        num_atoms_dict[data_id] = num_atoms_of_interest

        # Average the SASA field over entries whose atom type is non-empty
        # (b'' entries are presumably padding -- TODO confirm). Note that the
        # average covers ALL non-empty atoms, not only ATOM_TYPES_TO_GET, and
        # is nan when a record has no non-empty atom types.
        avg_sasa_dict[data_id] = np.mean(neighborhood[IDX_OF_SASA][atom_types != b''])

    output_dict = {
        'num_atoms': num_atoms_dict,
        'avg_sasa' : avg_sasa_dict,
        'num_atoms_of_type': num_atoms_of_type_dict,
    }
    with gzip.open(args.output_file, 'wb') as f:
        pickle.dump(output_dict, f)
    

    # ## 3) extract array of statistics of interest, parallel to the given data_ids
    # data_ids_from_projections_N = list(map(stringify, np.load(os.path.join(args.projections_data_dir, 'data_ids-test-%s.npy' % (args.id)))))
    # num_atoms_N = []
    # num_atoms_of_type_N = {}
    # for atom_type in ATOM_TYPES_TO_GET:
    #     num_atoms_of_type_N[atom_type] = []

    # for data_id in data_ids_from_projections_N:
    #     num_atoms_N.append(num_atoms_dict[data_id])
    #     for atom_type in ATOM_TYPES_TO_GET:

