#
# This file computes the atomic spherical coordinates in a given set of
# neighborhoods and outputs a file with these coordinates.
#
# It takes as arguments:
#  - The name of the output file
#  - Name of central residue dataset
#  - Number of threads
#  - The neighborhood radius
#  - "easy" flag to include central res
#

# Workaround for "RuntimeError: received 0 items of ancdata": lift the soft
# limit on open file descriptors for this process up to the hard limit.
# (Presumably the error appears when worker processes exhaust descriptors --
# confirm against the environment where it was first observed.)
import resource

rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)  # (soft, hard) before the change
print(rlimit)
hard_limit = rlimit[1]
resource.setrlimit(resource.RLIMIT_NOFILE, (hard_limit, hard_limit))

from preprocessor import HDF5Preprocessor
import numpy as np
import itertools
import os
from argparse import ArgumentParser
from progress.bar import Bar
import h5py
import sys
from projections import RadialSphericalTensor, MultiChannelRadialSphericalTensor, ZernickeRadialFunctions
sys.path.append('..')
from utils.argparse_utils import *
# sys.path.append('DIR')
# from utils.posterity import get_metadata, record_metadata
from coordinates import protein
import torch

def get_data_id(args, ignore_params):
    """Build the tag used in output file names from the parsed arguments.

    Every attribute of ``args`` contributes one ``name=value`` pair unless
    its name is listed in ``ignore_params`` or its value is ``None``.
    Pairs are ordered by attribute name and joined with ``-``.

    Args:
        args: Namespace-like object; its attributes are read via ``vars``.
        ignore_params: Container of attribute names to leave out of the tag.

    Returns:
        The tag string; empty when no attribute qualifies.
    """
    # Identity check on purpose: ``== None`` would invoke a value's own
    # ``__eq__`` and can mis-classify (or raise for) array-like values.
    params = {
        name: value
        for name, value in vars(args).items()
        if name not in ignore_params and value is not None
    }
    return '-'.join(str(name) + '=' + str(params[name]) for name in sorted(params))

def callback(coords_list, weights_list, frame, nb_id, rst, mul_rst, rst_normalization, mul_rst_normalization, n_channels):
    """Project one neighborhood's point cloud and return its coefficients.

    Args:
        coords_list: Per-channel coordinate tensors (concatenated along dim 0
            when ``n_channels == 1``).
        weights_list: Per-channel weights passed to ``rst.with_peaks_at``;
            ignored when ``n_channels == 1``.
        frame: ``None``, or an iterable of array-likes that is converted to
            tensors and stacked (presumably the three axes of a local
            frame -- confirm against the caller).
        nb_id: Neighborhood identifier; returned unchanged.
        rst: Object providing ``with_peaks_at(coords, weights, normalization=...)``.
        mul_rst: Object providing ``combine(stacked, normalization=...)``;
            only used when ``n_channels != 1``.
        rst_normalization: Forwarded to ``rst.with_peaks_at``.
        mul_rst_normalization: Forwarded to ``mul_rst.combine``.
        n_channels: Number of channels; ``1`` pools all atoms together.

    Returns:
        ``(coeffs, frame, nb_id)`` with numpy arrays for ``coeffs`` and
        ``frame`` (``frame`` is ``None`` when none was supplied), or
        ``(None, None, nb_id)`` if the projection failed or contained NaNs.
    """
    # Expected channel counts: 7 with side chains [C, O, N, S, H, SASA, charge],
    # 4 for backbone only [CA, C, N, O]; a single pooled channel is also allowed.
    assert (n_channels == len(coords_list) or n_channels == 1)
    assert (n_channels == len(weights_list) or n_channels == 1)

    try:
        # Stays empty on the single-channel path, so the NaN report below can
        # iterate over it safely whatever the channel count is.
        disentangled_coeffs = []
        if n_channels == 1:
            coeffs = rst.with_peaks_at(torch.cat(coords_list, dim=0), None, normalization=rst_normalization)
        else:
            for coords, weights in zip(coords_list, weights_list):
                disentangled_coeffs.append(rst.with_peaks_at(coords, weights, normalization=rst_normalization))
            coeffs = mul_rst.combine(torch.stack(disentangled_coeffs, dim=0), normalization=mul_rst_normalization)

        if torch.any(torch.isnan(coeffs)):
            # Dump per-channel diagnostics for however many channels exist,
            # then fail this neighborhood through the handler below.
            print('------ NAN ------')
            print(nb_id)
            for channel_coeffs in disentangled_coeffs:
                print(torch.any(torch.isnan(channel_coeffs)))
            for coords in coords_list:
                print(coords)
            print('------ NAN ------')
            raise ValueError('NaN values found in projection coefficients')

        if frame is None:
            # Either no frame was requested or the central residue could not
            # provide one; the caller decides whether that is an error.
            return (coeffs.numpy(), None, nb_id)

        frame = torch.stack(tuple(torch.tensor(axis) for axis in frame))

    except Exception as e:
        # Best effort: report and let the caller skip this neighborhood.
        print(e)
        print(nb_id)
        print('Failed in callback')
        return (None, None, nb_id)

    return (coeffs.numpy(), frame.numpy(), nb_id)


if __name__ == "__main__":
    # Script entry point: parse options, project every neighborhood produced
    # by the HDF5 preprocessor through `callback`, and save projections,
    # frames (if requested), amino-acid labels and neighborhood ids as .npy
    # files under <projection>/<neigh_kind>/.
    parser = ArgumentParser()
    # Required: the path is inspected and opened unconditionally below.
    parser.add_argument('--input_hdf5', type=str, required=True)
    parser.add_argument('--split', type=str, default='train')
    parser.add_argument('--input_key', type=str, default='pdb_subsets/img=x-ray diffraction_max_res=2.5/split_0.8_0.2_0.0/{}/pdbs')
    # NOTE(review): parsed but not forwarded -- ds.execute below is called
    # with parallelism=None. Confirm whether that is intentional.
    parser.add_argument('--parallelism', type=int, default=40)
    # NOTE(review): only used to name the output directory; the radial basis
    # constructed below is always ZernickeRadialFunctions.
    parser.add_argument('--projection', type=str, default='zernicke',
                        help='family of radial functions to use for projections')
    parser.add_argument('--rmax', type=int, default=20,
                        help='maximum radial order')
    parser.add_argument('--rcut', type=float, default=10.0,
                        help='max radius value to expect in point clouds') # 10 for Mike's casp12 data
    parser.add_argument('--lmax', type=int, default=4,
                        help='maximum spherical order')
    parser.add_argument('--complex_sph', type=str_to_bool, default=False,
                        help='whether to use a basis of complex spherical harmonics')
    parser.add_argument('--n_channels', type=int, default=4,
                        help='Number of channels to expect from the point cloud. If 1, then atom identity is not considered, and all atoms (possibly including Hs if get_H is set to True) are considered.')
    # Reference neighborhood counts observed for various dataset sizes:
    #   [training: 1318541 for the first 2500 proteins]
    #   [training H,E,C: 626052 for 11226 proteins with H,E,C | validation H,E,C: 159468 | testing H,E,C: 11579]
    #   [training: 1042144 for first 1200 proteins; validation: 259102 for first 500 proteins]
    #   [training: 277113 for first 5000 proteins with H,E,C; validation: 55129 for first 1000 proteins with H,E,C]
    #   [training: 107505 for first 2000 proteins with H,E,C AAs; validation: 21820 for first 400 proteins with H,E,C]
    #   [531847 for first training 1000 proteins, ]
    #   [126025 for the first 250 training proteins, 30553 for the first 50 validation proteins]
    #   [48317 for first 100 training proteins, 13443 for first 20 validation proteins]
    parser.add_argument('--n_neigh', type=int, default=80000,
                        help='number of neighborhoods to collect, from the beginning')
    # str_to_bool (not bool): with type=bool any non-empty string, including
    # "False", would be parsed as True.
    parser.add_argument('--convert_to_cartesian', type=str_to_bool, default=False)
    parser.add_argument('--rst_normalization', type=str, default='square')
    parser.add_argument('--mul_rst_normalization', type=str, default=None)
    parser.add_argument('--neigh_kind', type=str, default='residue_only',
                        help='for filesaving purposes')
    parser.add_argument('--backbone_only', type=str_to_bool, default=False)
    parser.add_argument('--request_frame', type=str_to_bool, default=True)
    parser.add_argument('--get_H', type=str_to_bool, default=False,
                        help='Only used if backbone_only is False.')
    parser.add_argument('--get_SASA', type=str_to_bool, default=False,
                        help='Only used if backbone_only is False.')
    parser.add_argument('--get_charge', type=str_to_bool, default=False,
                        help='Only used if backbone_only is False.')

    args = parser.parse_args()

    # Arguments that must not appear in the output file tag.
    ignore_params = ['input_hdf5', 'input_key', 'parallelism', 'projection', 'split', 'neigh_kind', 'convert_to_cartesian', 'request_frame', 'backbone_only']
    if args.backbone_only:
        ignore_params += ['get_H', 'get_SASA', 'get_charge']

    # The test split is stored under a fixed key; other splits fill the template.
    args.input_key = args.input_key.format(args.split) if args.split != 'test' else 'pdb_list'

    if 'toy_aminoacids' in args.input_hdf5: # spaghetti-y piece just for compatibility purposes, makes my life easiest at this moment
        args.split = 'all'

    # Create the output directory up front so a missing directory is reported
    # before the processing starts rather than when saving the results.
    out_dir = os.path.join(args.projection, args.neigh_kind)
    os.makedirs(out_dir, exist_ok=True)

    nRadialFunctions = args.rmax + 1
    p_val = 1 # parity of l-even irreps (1 = even, -1 = odd); must be 1 for projection onto spherical harmonics basis
    p_arg = 1 # p_val*p_arg = parity of l-odd irreps; doesn't matter for the projection onto spherical harmonics basis, only affects how the representation transforms
              # setting this to 1 since the CGNet operates under such assumption, but it really doesn't matter in the context of this script

    ds = HDF5Preprocessor(args.input_hdf5, args.input_key, args.n_neigh)
    OnRadialFunctions = ZernickeRadialFunctions(args.rcut, nRadialFunctions, args.lmax, complex_sph = args.complex_sph, record_zeros = False)
    rst = RadialSphericalTensor(nRadialFunctions, OnRadialFunctions, args.lmax, p_val, p_arg)
    mul_rst = MultiChannelRadialSphericalTensor(rst, args.n_channels)

    labels = []
    projections = []
    frames = []
    data_ids = []
    n_skipped = 0  # neighborhoods rejected so far
    n_seen = 0     # neighborhoods received so far
    with Bar('Processing', max = ds.count(), suffix='%(percent).1f%%') as bar:
        for proj, frame, nb_id in ds.execute(callback,
                                  limit = None,
                                  convert_to_cartesian = args.convert_to_cartesian,
                                  backbone_only = args.backbone_only,
                                  request_frame = args.request_frame,
                                  get_H = args.get_H,
                                  get_SASA = args.get_SASA,
                                  get_charge = args.get_charge,
                                  params = {'rst': rst, 'mul_rst': mul_rst, 'n_channels': args.n_channels, 'rst_normalization': args.rst_normalization, 'mul_rst_normalization': args.mul_rst_normalization},
                                  parallelism = None):
            n_seen += 1

            # callback signals failure by returning no projection.
            if proj is None:
                print(nb_id,' returned error')
                n_skipped += 1
                bar.next()
                continue

            # First field of the id is the residue name, stored as bytes.
            res_name = nb_id[0].decode('utf-8')
            if res_name not in protein.aa_to_ind_short:
                print('Got invalid residue type "{}".'.format(res_name))
                n_skipped += 1
                bar.next()
                continue

            if args.request_frame and frame is None:
                print(nb_id,' returned None frame when frame was requested')
                n_skipped += 1
                bar.next()
                continue

            if n_seen % 100 == 0:
                print('\n\n Status ', n_skipped/n_seen, '\n\n')

            data_ids.append(nb_id)
            projections.append(proj)
            frames.append(frame)
            labels.append(protein.aa_to_ind_short[res_name])

            bar.next()

    data_id = get_data_id(args, ignore_params)
    print(data_id)

    # np.vstack raises on an empty list; stop with a clear message instead.
    if not projections:
        sys.exit('No valid neighborhoods were collected; nothing to save.')

    def out_path(kind):
        # <projection>/<neigh_kind>/<kind>-<split>-<data_id>.npy
        return os.path.join(out_dir, '{}-{}-'.format(kind, args.split) + data_id + '.npy')

    projections = np.vstack(projections)
    labels = np.hstack(labels)

    np.save(out_path('projections'), projections)

    if args.request_frame:
        frames = np.vstack(frames).reshape(-1, 3, 3)
        np.save(out_path('frames'), frames)

    np.save(out_path('aa_labels'), labels)
    np.save(out_path('data_ids'), np.array(data_ids))
