
import signal
import numpy as np
import time
import os, sys
import logging
import itertools
import functools
import warnings
from multiprocessing import Pool, TimeoutError
import h5py
import torch

sys.path.append('..')
from utils import spherical_to_cartesian__numpy

# Space-padded 4-byte atom names of the backbone atoms N, CA, C and O.
BACKBONE_ATOMS = [b' N  ', b' CA ', b' C  ', b' O  ']
# The same backbone atom names followed by CB (beta carbon); a separate list object.
BACKBONE_ATOMS_PLUS_CB = BACKBONE_ATOMS + [b' CB ']


def _compute_central_frame(nb, cartesian_coords):
    """Build an orthonormal frame ``(x, y, z)`` from the central residue's N, CA and C.

    ``x`` is the unit vector from the central CA to the central N; ``z`` is the
    unit normal of the plane spanned by CA->N and CA->C; ``y = z cross x``
    completes the frame. The central CA is taken to be at the origin because
    the neighborhood is centered on the alpha carbon.

    Raises:
        ValueError: if the central residue does not contain exactly one N and
            exactly one C atom with 3 coordinates, or if the atoms are
            degenerate (N on top of CA, or N, CA, C collinear), which would
            otherwise produce NaNs.
    """
    # An atom belongs to the central residue when every entry of its
    # res_ids row matches nb['res_id'].
    central_res = np.logical_and.reduce(nb['res_ids'] == nb['res_id'], axis=-1)
    central_names = nb['atom_names'][central_res]
    central_coords = cartesian_coords[central_res]

    central_CA_coords = np.zeros(3)  # since we centered the neighborhood on the alpha carbon
    central_N_coords = np.squeeze(central_coords[central_names == b' N  '])
    central_C_coords = np.squeeze(central_coords[central_names == b' C  '])

    # Explicit checks instead of asserts so validation still runs under `python -O`.
    for label, atom in (('N', central_N_coords), ('C', central_C_coords)):
        if atom.shape != (3,):
            raise ValueError(
                f'expected exactly one central {label} atom with 3 coordinates, '
                f'got array of shape {atom.shape}'
            )

    x = central_N_coords - central_CA_coords
    x_norm = np.linalg.norm(x)
    if x_norm == 0:
        raise ValueError('central N coincides with central CA')
    x = x / x_norm

    z = np.cross(x, central_C_coords - central_CA_coords)
    z_norm = np.linalg.norm(z)
    if z_norm == 0:
        raise ValueError('central N, CA and C are collinear')
    z = z / z_norm

    # z and x are orthonormal, so y is already a unit vector.
    y = np.cross(z, x)
    return (x, y, z)


def process_data(nb, convert_to_cartesian=True, backbone_only=False, request_frame=False, get_H=False, get_SASA=False, get_charge=False):
    """Split one neighborhood into per-atom-type coordinate tensors and pass them to the callback.

    ``process_data.callback`` and ``process_data.params`` must be set first,
    normally by ``initializer`` in each pool worker.

    Args:
        nb: one neighborhood record, indexed by field name: ``'res_id'``,
            ``'coords'``, ``'atom_names'`` (backbone mode / frame),
            ``'elements'`` (element mode), ``'res_ids'`` (frame only),
            ``'SASAs'`` (only with ``get_SASA``) and ``'charges'`` (only with
            ``get_charge``).
        convert_to_cartesian: convert ``nb['coords']`` with
            ``spherical_to_cartesian__numpy`` before use.
        backbone_only: group atoms by backbone atom name (CA, C, N, O)
            instead of by element (C, N, O, S).
        request_frame: also compute a local (x, y, z) frame from the central
            residue. It is ``None`` if that fails.
        get_H: append hydrogen coordinates (element mode only).
        get_SASA: append the coordinates of all atoms with a non-empty element,
            weighted by ``nb['SASAs']`` (element mode only).
        get_charge: like ``get_SASA`` but weighted by ``nb['charges']``.

    Returns:
        ``process_data.callback(coords, weights, frame, nb_id, **params)``,
        or ``False`` if the neighborhood could not be processed.

    Raises:
        RuntimeError: if ``process_data.callback`` has not been set.
    """
    if getattr(process_data, 'callback', None) is None:
        raise RuntimeError('process_data.callback is not set; call initializer() first')

    # Bound before the try so the error handler can always report it.
    nb_id = None
    try:
        nb_id = nb['res_id']

        if convert_to_cartesian:
            cartesian_coords = spherical_to_cartesian__numpy(nb['coords'])
        else:
            cartesian_coords = nb['coords']

        if backbone_only:
            atom_names = nb['atom_names']
            # The masks below compare against 4-byte padded names; any other
            # width would silently select no atoms.
            if len(atom_names[0]) != 4:
                print('Bug!: ', atom_names[0], file=sys.stderr)
                raise ValueError(f'expected 4-byte atom names, got {atom_names[0]!r}')
            coords = [
                torch.tensor(cartesian_coords[atom_names == name])
                for name in (b' CA ', b' C  ', b' N  ', b' O  ')
            ]
            weights = [None] * len(coords)
        else:
            elements = nb['elements']
            element_types = [b'C', b'N', b'O', b'S'] + ([b'H'] if get_H else [])
            coords = [torch.tensor(cartesian_coords[elements == el]) for el in element_types]
            weights = [None] * len(coords)
            # Optional weighted channels are built only when requested, so
            # records without 'SASAs' / 'charges' fields can still be processed.
            present = elements != b''
            if get_SASA:
                coords.append(torch.tensor(cartesian_coords[present]))
                weights.append(torch.tensor(nb['SASAs'][present]))
            if get_charge:
                coords.append(torch.tensor(cartesian_coords[present]))
                weights.append(torch.tensor(nb['charges'][present]))

        frame = None
        if request_frame:
            try:
                frame = _compute_central_frame(nb, cartesian_coords)
            except Exception as e:
                # A missing or malformed frame is not fatal; the neighborhood
                # is still passed on with frame=None.
                print(e)
                print('No central residue (or other unwanted error).')
                frame = None

    except Exception as e:
        print(e)
        print(nb_id)
        print('Failed in process_data')
        return False

    # `**None` would raise TypeError; treat missing params as "no extra kwargs".
    params = getattr(process_data, 'params', None) or {}
    return process_data.callback(coords, weights, frame, nb_id, **params)


def initializer(init, callback, params, init_params):
    """Configure a multiprocessing pool worker.

    Runs the optional ``init`` hook, stores ``callback`` and ``params`` as
    attributes of ``process_data`` (which forwards them on every call), and
    sets SIGINT to be ignored in this process.

    Args:
        init: optional callable, run once per worker as ``init(**init_params)``.
        callback: callable that ``process_data`` invokes with each processed
            neighborhood.
        params: extra keyword arguments forwarded to ``callback``; ``None``
            means no extra arguments.
        init_params: keyword arguments for ``init``; ``None`` means none.
    """
    if init is not None:
        init(**({} if init_params is None else init_params))
    process_data.callback = callback
    # process_data unpacks params with `**`, which rejects None, and None is
    # the default that HDF5Preprocessor.execute passes through.
    process_data.params = {} if params is None else params
    # Workers ignore Ctrl-C; presumably so only the parent process reacts to it.
    signal.signal(signal.SIGINT, signal.SIG_IGN)


class HDF5Preprocessor:
    """Load neighborhoods from an HDF5 dataset and process them in a worker pool.

    On construction, up to ``limit`` rows of ``hdf5_file[hdf5_key]`` are read
    into memory and de-duplicated with ``np.unique(..., axis=0)``, which also
    returns them in sorted order. ``execute`` then maps ``process_data`` over
    the rows in a ``multiprocessing.Pool``.
    """

    def __init__(self, hdf5_file, hdf5_key, limit):
        """Read and de-duplicate the dataset.

        Args:
            hdf5_file: path of the HDF5 file (opened read-only).
            hdf5_key: name of the dataset inside the file.
            limit: maximum number of rows to read; ``None`` reads every row.
        """
        with h5py.File(hdf5_file, 'r') as f:
            # Tip: print f[hdf5_key].shape here to see how many neighborhoods
            # the dataset holds.
            data = np.unique(np.array(f[hdf5_key][:limit]), axis=0)
        self.__data = data

    def count(self):
        """Return the number of distinct neighborhoods that were loaded."""
        return self.__data.shape[0]

    def execute(self, callback, convert_to_cartesian = True, backbone_only = False, request_frame = False, get_H=False, get_SASA=False, get_charge=False, parallelism = None, limit = None, params = None, init = None, init_params = None):
        """Process the loaded neighborhoods in parallel and yield the callback results.

        Args:
            callback: invoked in the workers as
                ``callback(coords, weights, frame, nb_id, **params)``.
            convert_to_cartesian, backbone_only, request_frame, get_H,
            get_SASA, get_charge: forwarded to ``process_data``.
            parallelism: number of worker processes (``None`` lets
                ``multiprocessing.Pool`` choose).
            limit: process only the first ``limit`` loaded neighborhoods.
            params: extra keyword arguments for ``callback``.
            init: optional per-worker setup callable, called with ``init_params``.
            init_params: keyword arguments for ``init``.

        Yields:
            Callback results, in input order, for neighborhoods that were
            processed successfully. Neighborhoods ``process_data`` failed on
            (it returns ``False``) are skipped. Non-array results that are
            falsy are skipped too.
        """
        data = self.__data if limit is None else self.__data[:limit]

        process_data_hdf5 = functools.partial(
            process_data,
            convert_to_cartesian=convert_to_cartesian,
            backbone_only=backbone_only,
            request_frame=request_frame,
            get_H=get_H,
            get_SASA=get_SASA,
            get_charge=get_charge
        )

        with Pool(initializer=initializer, processes=parallelism, initargs=(init, callback, params, init_params)) as pool:
            for result in pool.imap(process_data_hdf5, data):
                # process_data signals failure with the literal False.
                if result is False:
                    continue
                # Multi-element numpy arrays / torch tensors raise on truth
                # testing, so they are always yielded. Other results keep the
                # "skip falsy" filter.
                if isinstance(result, (np.ndarray, torch.Tensor)) or result:
                    yield result

