#!/usr/bin/python3

####################################################################################################

####################################################################################################

'''
Lifelong Machine Learning Potentials (lMLP)
'''
__copyright__ = '''This code is licensed under the 3-clause BSD license.
Copyright ETH Zurich, Department of Chemistry and Applied Biosciences, Reiher Group.
See LICENSE.txt for details.'''

####################################################################################################

####################################################################################################

from typing import List, Tuple
import numpy as np
from numba import prange   # type: ignore
from numpy import pi
from numpy.typing import NDArray
from .performance import ncfjit, ncfpjit


####################################################################################################

####################################################################################################

@ncfjit
def prepare_periodic_cell(positions, lattice, n_atoms, n_atoms_sys, R_c) -> Tuple[
        NDArray, NDArray, bool, int]:
    '''
    Wrap all atoms into the periodic cell and determine whether periodic boundary conditions are
    required; if they are, replicate the cell along every lattice vector whose cell height is
    smaller than the cutoff radius R_c.

    positions holds one row of Cartesian coordinates per atom and lattice holds the three lattice
    vectors as rows. The first n_atoms_sys of the n_atoms atoms are the system atoms; for QM/MM
    (n_atoms_sys < n_atoms) the center of the system atoms is moved to the center of the cell and
    only the system atoms decide whether periodic boundary conditions are required.

    The lattice argument is not modified; the (possibly enlarged) lattice is returned as a new
    array.

    Return: positions, lattice, pbc_required, n_images_tot
        positions: wrapped Cartesian coordinates; the atoms of each replicated cell are appended
            as blocks of the current atom count
        lattice: lattice vectors of the (possibly enlarged) cell
        pbc_required: whether periodic boundary conditions are required
        n_images_tot: number of copies of the original cell contained in the returned cell
    '''
    # work on a copy: the cell enlargement below scales lattice rows in place and must not alter
    # the caller's array (otherwise repeated calls with the same array would keep enlarging it)
    lattice = lattice.copy()

    # get fractional coordinates (positions = fractional_positions @ lattice)
    positions = np.linalg.solve(lattice.T, positions.T).T

    # for QM/MM align center of system atoms and center of cell
    if n_atoms_sys < n_atoms:
        positions -= (np.sum(positions[:n_atoms_sys], axis=0) / n_atoms_sys - 0.5)

    # wrap atoms into original cell
    positions %= 1.0

    # determine if periodic boundary conditions are required according to the space between atoms in
    # the original cell and those in neighboring cells employing the minimal and maximal fractional
    # coordinates in each direction and the respective height of the cell
    normal = [np.cross(lattice[1], lattice[2]),
              np.cross(lattice[2], lattice[0]),
              np.cross(lattice[0], lattice[1])]
    normal = [normal[0] / np.sqrt(np.dot(normal[0], normal[0])),
              normal[1] / np.sqrt(np.dot(normal[1], normal[1])),
              normal[2] / np.sqrt(np.dot(normal[2], normal[2]))]
    # height[i]: distance between the two cell faces spanned by the other two lattice vectors
    height = [abs(np.dot(normal[0], lattice[0])),
              abs(np.dot(normal[1], lattice[1])),
              abs(np.dot(normal[2], lattice[2]))]
    # a direction needs periodic boundary conditions if the gap between the system atoms and their
    # periodic images, (min + 1 - max) of the fractional coordinates times the height, is < R_c
    pbc = np.array([(np.min(positions[:n_atoms_sys, 0]) + 1.0 - np.max(positions[:n_atoms_sys, 0]))
                    * height[0] < R_c,
                    (np.min(positions[:n_atoms_sys, 1]) + 1.0 - np.max(positions[:n_atoms_sys, 1]))
                    * height[1] < R_c,
                    (np.min(positions[:n_atoms_sys, 2]) + 1.0 - np.max(positions[:n_atoms_sys, 2]))
                    * height[2] < R_c])
    pbc_required = bool(np.any(pbc))

    # reconvert fractional coordinates
    positions = np.dot(positions, lattice)

    # expand cell if its heights are smaller than the cutoff radius
    n_images_tot = 1
    if pbc_required:
        for i in range(3):
            if height[i] < R_c:
                # smallest number of copies whose combined height exceeds R_c
                n_images = int(R_c / height[i]) + 1
                positions = (np.ones((n_images, n_atoms, 3)) * positions).reshape((
                    n_images * n_atoms, 3))
                for j in range(1, n_images):
                    positions[j * n_atoms:(j + 1) * n_atoms] += j * lattice[i]
                # safe: lattice is a private copy (see above)
                lattice[i] *= n_images
                n_atoms *= n_images
                n_images_tot *= n_images

    return positions, lattice, pbc_required, n_images_tot


####################################################################################################

@ncfjit
def get_periodic_images(positions, lattice, n_atoms) -> NDArray:
    '''
    Replicate the atoms into the 3x3x3 supercell of adjacent periodic images.

    Block 0 (rows 0 to n_atoms - 1) holds the unshifted positions. Blocks 1 to 26 hold the
    positions shifted by a * lattice[0] + b * lattice[1] + c * lattice[2] for a, b, c in
    (-1, 0, 1), enumerated in lexicographic order with (0, 0, 0) skipped. Consequently the shifts
    of block m and block 27 - m are opposite to each other.

    Return: positions_all
    '''
    positions_all = np.empty((27 * n_atoms, 3))
    positions_all[:n_atoms] = positions
    block = 1
    for a in range(-1, 2):
        for b in range(-1, 2):
            for c in range(-1, 2):
                if a == 0 and b == 0 and c == 0:
                    continue
                image = positions_all[block * n_atoms:(block + 1) * n_atoms]
                image[:] = positions
                # shift along the first, then the second, then the third lattice vector
                if a != 0:
                    image += a * lattice[0]
                if b != 0:
                    image += b * lattice[1]
                if c != 0:
                    image += c * lattice[2]
                block += 1

    return positions_all


####################################################################################################

@ncfjit
def get_triple_properties(elements_int_j, R_ij, dR_ij__dalpha_i_, interaction_classes_j,
                          atomic_charges_j, QMMM, calc_derivatives=True):
    '''
    Build the per-triple (i, j, k) properties of a central atom i from the properties of its
    neighbors.

    Every unordered pair of neighbors (j, k) with j < k is taken once, in the order returned by
    np.triu_indices(n_neighbors, k=1); pair (j, k) is stored at position
    m = j * (2 * n_neighbors - j - 3) // 2 + k - 1.

    elements_int_j, R_ij, dR_ij__dalpha_i_ (one row of three components per neighbor),
    interaction_classes_j and atomic_charges_j are per-neighbor arrays. QMMM enables the QM/MM
    properties and calc_derivatives the derivative properties; disabled properties are returned
    as zero-length arrays.

    Return: ijk (tuple of 13 arrays)
        ijk_0, ijk_1: element integers of j and k
        ijk_2, ijk_3: R_ij and R_ik
        ijk_4, ijk_5: dR_ij__dalpha_i_ rows of j and k
        ijk_6: cos(theta_ijk), clipped to [-0.999999999999, 0.999999999999]
        ijk_7, ijk_8, ijk_9: derivatives of cos(theta_ijk) with respect to atoms i, j and k
        ijk_10: sum of the interaction classes of j and k
        ijk_11, ijk_12: atomic charges of j and k
    '''
    # initialize arrays
    n_neighbors = len(R_ij)
    # number of unique neighbor pairs (j, k) with j < k
    M = (n_neighbors - 1) * n_neighbors // 2
    ijk_0, ijk_1 = np.empty(M, dtype=np.int64), np.empty(M, dtype=np.int64)
    ijk_2, ijk_3, ijk_6 = np.empty(M, dtype=float), np.empty(M, dtype=float), \
        np.empty(M, dtype=float)

    # initialize arrays for derivative properties (zero-length if derivatives are not requested)
    if calc_derivatives:
        M_prime = M
    else:
        M_prime = 0
    ijk_4, ijk_5, ijk_7, ijk_8, ijk_9 = np.empty((M_prime, 3), dtype=float), \
        np.empty((M_prime, 3), dtype=float), np.empty((M_prime, 3), dtype=float), \
        np.empty((M_prime, 3), dtype=float), np.empty((M_prime, 3), dtype=float)

    # initialize arrays for QM/MM properties (zero-length if QMMM is False)
    if QMMM:
        M_QMMM = M
    else:
        M_QMMM = 0
    ijk_10 = np.empty(M_QMMM, dtype=np.int64)
    ijk_11, ijk_12 = np.empty(M_QMMM, dtype=float), np.empty(M_QMMM, dtype=float)

    # get all unique neighbor combinations (no double counting)
    j, k = np.triu_indices(n_neighbors, k=1)
    # flat index of pair (j, k); for the row-major triu_indices order this is 0, 1, ..., M - 1
    m = j * (2 * n_neighbors - j - 3) // 2 + k - 1

    # insert properties into arrays
    ijk_0[m] = elements_int_j[j]
    ijk_1[m] = elements_int_j[k]
    ijk_2[m] = R_ij[j]
    ijk_3[m] = R_ij[k]
    # calculate cos(theta) as the pairwise dot products of the rows of dR_ij__dalpha_i_
    # (NOTE(review): equals cos(theta) provided these rows are unit vectors -- confirm with the
    # caller); the clipping keeps cos(theta) strictly inside (-1, 1)
    cos_theta = np.minimum(np.maximum(np.dot(dR_ij__dalpha_i_, dR_ij__dalpha_i_.T),
                           -0.999999999999), 0.999999999999)
    cos_theta_ijk = np.empty(M, dtype=float)
    # gather the (j, k) entries element by element
    for i in range(M):
        cos_theta_ijk[i] = cos_theta[j[i], k[i]]
    ijk_6[m] = cos_theta_ijk

    # insert derivative properties into arrays
    if calc_derivatives:
        ijk_4[m] = dR_ij__dalpha_i_[j]
        ijk_5[m] = dR_ij__dalpha_i_[k]
        # calculate the derivative of cos(theta); with e_j = dR_ij__dalpha_i_[j] and
        # e_k = dR_ij__dalpha_i_[k]: j1 = e_j / R_ik, j2 = e_j / R_ij, k1 = e_k / R_ij,
        # k2 = e_k / R_ik
        j1 = (dR_ij__dalpha_i_[j].T / R_ij[k]).T
        j2 = (dR_ij__dalpha_i_[j].T / R_ij[j]).T
        k1 = (dR_ij__dalpha_i_[k].T / R_ij[j]).T
        k2 = (dR_ij__dalpha_i_[k].T / R_ij[k]).T
        # derivative with respect to atom i (equals minus the sum of the j and k derivatives)
        ijk_7[m] = j1 + k1 - (cos_theta_ijk * (j2 + k2).T).T
        # derivative with respect to atom j
        ijk_8[m] = -k1 + (cos_theta_ijk * j2.T).T
        # derivative with respect to atom k
        ijk_9[m] = -j1 + (cos_theta_ijk * k2.T).T

    # insert QM/MM properties into arrays
    if QMMM:
        ijk_10[m] = interaction_classes_j[j] + interaction_classes_j[k]
        ijk_11[m] = atomic_charges_j[j]
        ijk_12[m] = atomic_charges_j[k]

    return (ijk_0, ijk_1, ijk_2, ijk_3, ijk_4, ijk_5, ijk_6, ijk_7, ijk_8, ijk_9, ijk_10, ijk_11,
            ijk_12)


####################################################################################################

@ncfpjit
def calc_descriptor_derivative(
        ij_0, ij_1, ij_2, ij_3, ij_4, ijk_0, ijk_1, ijk_2, ijk_3, ijk_4, ijk_5, ijk_6, ijk_7, ijk_8,
        ijk_9, ijk_10, ijk_11, ijk_12, neighbor_index, n_atoms, n_atoms_sys, n_descriptors,
        elem_func_index, rad_func_index, ang_func_index, scale_func_index, R_c, eta_ij,
        H_parameters_rad, H_parameters_rad_scale, n_parameters_rad, element_types_rad, eta_ijk,
        lambda_ijk, zeta_ijk, xi_ijk, H_parameters_ang, H_parameters_ang_scale,
        n_parameters_ang, H_type_jk, n_H_parameters, element_types_ang, calc_derivatives, QMMM,
        active_atom, n_atoms_active, neighbor_index_env, I_type_j, n_parameters_rad_env,
        I_type_jk, n_parameters_ang_env, MM_atomic_charge_max) -> Tuple[
            List[NDArray], List[NDArray], List[NDArray], List[NDArray]]:
    '''
    Calculate the radial and angular symmetry function descriptors of all system atoms and, if
    calc_derivatives is True, their derivatives.

    ij_0 to ij_4 hold, per system atom, the neighbor properties in the argument order of
    calc_rad_sym_func (element integers, distances, distance derivatives, interaction classes,
    atomic charges); ijk_0 to ijk_12 hold, per system atom, the triple properties returned by
    get_triple_properties. neighbor_index[i] lists the neighbor indices of system atom i.

    Return: descriptor, descriptor_i_derivative, descriptor_neighbor_derivative,
            descriptor_neighbor_derivative_env
        descriptor[i]: radial followed by angular symmetry function values of atom i
        descriptor_i_derivative[i]: their derivatives with respect to atom i
        descriptor_neighbor_derivative[i][j]: derivative of the descriptor of the j-th
            system-atom neighbor of atom i with respect to atom i
        descriptor_neighbor_derivative_env[i][j]: for QM/MM, derivative of the last
            n_parameters_rad_env radial and last n_parameters_ang_env angular symmetry functions
            of the j-th system-atom neighbor of environment atom i with respect to that
            environment atom (a single empty placeholder otherwise)
    '''
    # calculate symmetry function values and derivatives
    # (per-atom result lists, pre-filled with empty placeholder arrays of the matching rank)
    descriptor = [np.zeros(0)] * n_atoms_sys
    descriptor_i_derivative = [np.zeros((0, 3))] * n_atoms_sys
    descriptor_neighbor_derivative = [np.zeros((0, 0, 3))] * n_atoms_sys
    descriptor_j_derivatives_rad = [np.zeros((0, 0, 3))] * n_atoms_sys
    descriptor_j_derivatives_ang = [np.zeros((0, 0, 3))] * n_atoms_sys
    descriptor_k_derivatives_ang = [np.zeros((0, 0, 3))] * n_atoms_sys
    # radial and angular symmetry functions (prange: numba parallel loop over system atoms)
    for i in prange(n_atoms_sys):
        G_rad, dG_rad__dalpha_i_, dG_rad__dalpha_j_ = calc_rad_sym_func(
            ij_0[i], ij_1[i], ij_2[i], ij_3[i], ij_4[i], elem_func_index, rad_func_index,
            scale_func_index, R_c, eta_ij, H_parameters_rad, H_parameters_rad_scale,
            n_parameters_rad, element_types_rad, calc_derivatives, QMMM, I_type_j,
            MM_atomic_charge_max)
        G_ang, dG_ang__dalpha_i_, dG_ang__dalpha_j_, dG_ang__dalpha_k_ = calc_ang_sym_func(
            ijk_0[i], ijk_1[i], ijk_2[i], ijk_3[i], ijk_4[i], ijk_5[i], ijk_6[i], ijk_7[i],
            ijk_8[i], ijk_9[i], ijk_10[i], ijk_11[i], ijk_12[i], elem_func_index, rad_func_index,
            ang_func_index, scale_func_index, R_c, eta_ijk, lambda_ijk, zeta_ijk, xi_ijk,
            H_parameters_ang, H_parameters_ang_scale, n_parameters_ang, H_type_jk, n_H_parameters,
            element_types_ang, calc_derivatives, QMMM, I_type_jk, MM_atomic_charge_max)
        # compile symmetry function values
        descriptor[i] = np.concatenate((G_rad, G_ang))
        # compile symmetry function derivatives with respect to central atom i
        if calc_derivatives:
            descriptor_i_derivative[i] = np.concatenate((dG_rad__dalpha_i_, dG_ang__dalpha_i_))
            # compile symmetry function derivatives with respect to neighbor atoms j and k
            descriptor_j_derivatives_rad[i] = dG_rad__dalpha_j_
            descriptor_j_derivatives_ang[i] = dG_ang__dalpha_j_
            descriptor_k_derivatives_ang[i] = dG_ang__dalpha_k_
    # calculate symmetry function derivatives with respect to neighbor atoms of atom i
    if calc_derivatives:
        for i in prange(n_atoms_sys):
            # number of neighbors of atom i that are system atoms (index modulo n_atoms below
            # n_atoms_sys); NOTE(review): the loop below visits the first n_neighbors entries of
            # neighbor_index[i], which assumes the system-atom neighbors are stored first --
            # confirm against the neighbor-list construction
            n_neighbors = len(neighbor_index[i][neighbor_index[i] % n_atoms < n_atoms_sys])
            if n_neighbors > 0:
                descriptor_neighbor_derivative[i] = np.zeros((n_neighbors, n_descriptors, 3))
                for j in range(n_neighbors):
                    # index under which atom i appears in the neighbor list of its j-th neighbor:
                    # for a neighbor in periodic image m = neighbor_index[i][j] // n_atoms > 0,
                    # atom i lies in the opposite image 27 - m (get_periodic_images stores
                    # opposite shifts in images m and 27 - m)
                    if neighbor_index[i][j] >= n_atoms:
                        index = (27 - neighbor_index[i][j] // n_atoms) * n_atoms + i
                    else:
                        index = i
                    descriptor_neighbor_derivative[i][j] = calc_descriptor_neighbor_derivative(
                        descriptor_j_derivatives_rad[neighbor_index[i][j] % n_atoms],
                        descriptor_j_derivatives_ang[neighbor_index[i][j] % n_atoms],
                        descriptor_k_derivatives_ang[neighbor_index[i][j] % n_atoms], index,
                        neighbor_index[neighbor_index[i][j] % n_atoms], n_descriptors,
                        n_parameters_rad, n_parameters_ang, 0, 0)

    # calculate symmetry function derivatives with respect to neighbor active environment atoms of
    # atom i
    if QMMM and calc_derivatives:
        # only the last n_parameters_rad_env radial and the last n_parameters_ang_env angular
        # symmetry functions are differentiated with respect to environment atoms
        n_descriptors_env = n_parameters_rad_env + n_parameters_ang_env
        n_parameters_rad_env_start = n_parameters_rad - n_parameters_rad_env
        n_parameters_ang_env_start = n_parameters_ang - n_parameters_ang_env
        n_atoms_env = n_atoms_active - n_atoms_sys
        descriptor_neighbor_derivative_env = [np.zeros((0, 0, 3))] * n_atoms_env
        for i in prange(n_atoms_env):
            # neighbor_index_env[i] lists system-atom neighbors of environment atom
            # active_atom[i + n_atoms_sys]; the image is recovered with // n_atoms_sys
            # (presumably images are encoded in blocks of n_atoms_sys -- confirm with caller)
            n_neighbors_env = len(neighbor_index_env[i])
            if n_neighbors_env > 0:
                descriptor_neighbor_derivative_env[i] = np.zeros((
                    n_neighbors_env, n_descriptors_env, 3))
                for j in range(n_neighbors_env):
                    if neighbor_index_env[i][j] >= n_atoms_sys:
                        index = ((27 - neighbor_index_env[i][j] // n_atoms_sys) * n_atoms
                                 + active_atom[i + n_atoms_sys])
                    else:
                        index = active_atom[i + n_atoms_sys]
                    descriptor_neighbor_derivative_env[i][j] = calc_descriptor_neighbor_derivative(
                        descriptor_j_derivatives_rad[neighbor_index_env[i][j] % n_atoms_sys],
                        descriptor_j_derivatives_ang[neighbor_index_env[i][j] % n_atoms_sys],
                        descriptor_k_derivatives_ang[neighbor_index_env[i][j] % n_atoms_sys],
                        index, neighbor_index[neighbor_index_env[i][j] % n_atoms_sys],
                        n_descriptors_env, n_parameters_rad_env, n_parameters_ang_env,
                        n_parameters_rad_env_start, n_parameters_ang_env_start)
    else:
        # placeholder so that both branches yield the same list type
        descriptor_neighbor_derivative_env = [np.zeros((0, 0, 3))]

    return descriptor, descriptor_i_derivative, descriptor_neighbor_derivative, \
        descriptor_neighbor_derivative_env


####################################################################################################

@ncfpjit
def calc_descriptor_derivative_radial(
        ij_0, ij_1, ij_2, ij_3, ij_4, neighbor_index, n_atoms, n_atoms_sys, n_descriptors,
        elem_func_index, rad_func_index, scale_func_index, R_c, eta_ij, H_parameters_rad,
        H_parameters_rad_scale, n_parameters_rad, element_types_rad, calc_derivatives, QMMM,
        active_atom, n_atoms_active, neighbor_index_env, I_type_j, n_parameters_rad_env,
        MM_atomic_charge_max) -> Tuple[List[NDArray], List[NDArray], List[NDArray], List[NDArray]]:
    '''
    Radial-only variant of calc_descriptor_derivative: calculate the radial symmetry function
    descriptors of all system atoms and, if calc_derivatives is True, their derivatives.

    ij_0 to ij_4 hold, per system atom, the neighbor properties in the argument order of
    calc_rad_sym_func; neighbor_index[i] lists the neighbor indices of system atom i.

    Return: descriptor, descriptor_i_derivative, descriptor_neighbor_derivative,
            descriptor_neighbor_derivative_env
        descriptor[i]: radial symmetry function values of atom i
        descriptor_i_derivative[i]: their derivatives with respect to atom i
        descriptor_neighbor_derivative[i][j]: derivative of the descriptor of the j-th
            system-atom neighbor of atom i with respect to atom i
        descriptor_neighbor_derivative_env[i][j]: for QM/MM, derivative of the last
            n_parameters_rad_env radial symmetry functions of the j-th system-atom neighbor of
            environment atom i with respect to that environment atom (a single empty placeholder
            otherwise)
    '''
    # calculate symmetry function values and derivatives
    descriptor = [np.zeros(0)] * n_atoms_sys
    descriptor_i_derivative = [np.zeros((0, 3))] * n_atoms_sys
    descriptor_neighbor_derivative = [np.zeros((0, 0, 3))] * n_atoms_sys
    descriptor_j_derivatives_rad = [np.zeros((0, 0, 3))] * n_atoms_sys
    # empty stand-in for the angular derivative arrays (not read because n_parameters_ang = 0 is
    # passed to calc_descriptor_neighbor_derivative below)
    zeros = np.zeros((0, 0, 0))
    # only radial symmetry functions
    for i in prange(n_atoms_sys):
        descriptor[i], dG_rad__dalpha_i_, dG_rad__dalpha_j_ = calc_rad_sym_func(
            ij_0[i], ij_1[i], ij_2[i], ij_3[i], ij_4[i], elem_func_index, rad_func_index,
            scale_func_index, R_c, eta_ij, H_parameters_rad, H_parameters_rad_scale,
            n_parameters_rad, element_types_rad, calc_derivatives, QMMM, I_type_j,
            MM_atomic_charge_max)
        # compile symmetry function values
        if calc_derivatives:
            # compile symmetry function derivatives with respect to central atom i
            # (concatenating with an empty array yields a new array)
            descriptor_i_derivative[i] = np.concatenate((dG_rad__dalpha_i_, np.zeros((0, 3))))
            # compile symmetry function derivatives with respect to neighbor atoms j
            descriptor_j_derivatives_rad[i] = dG_rad__dalpha_j_
    # calculate symmetry function derivatives with respect to neighbor atoms of atom i
    if calc_derivatives:
        for i in prange(n_atoms_sys):
            # number of system-atom neighbors of atom i; NOTE(review): the loop below visits the
            # first n_neighbors entries of neighbor_index[i], which assumes these neighbors are
            # stored first -- confirm against the neighbor-list construction
            n_neighbors = len(neighbor_index[i][neighbor_index[i] % n_atoms < n_atoms_sys])
            if n_neighbors > 0:
                descriptor_neighbor_derivative[i] = np.zeros((n_neighbors, n_descriptors, 3))
                for j in range(n_neighbors):
                    # atom i seen from a neighbor in image m lies in the opposite image 27 - m
                    if neighbor_index[i][j] >= n_atoms:
                        index = (27 - neighbor_index[i][j] // n_atoms) * n_atoms + i
                    else:
                        index = i
                    descriptor_neighbor_derivative[i][j] = calc_descriptor_neighbor_derivative(
                        descriptor_j_derivatives_rad[neighbor_index[i][j] % n_atoms], zeros, zeros,
                        index, neighbor_index[neighbor_index[i][j] % n_atoms], n_descriptors,
                        n_parameters_rad, 0, 0, 0)

    # calculate symmetry function derivatives with respect to neighbor active environment atoms of
    # atom i
    if QMMM and calc_derivatives:
        n_atoms_env = n_atoms_active - n_atoms_sys
        # only the last n_parameters_rad_env radial symmetry functions depend on environment atoms
        n_parameters_rad_env_start = n_parameters_rad - n_parameters_rad_env
        descriptor_neighbor_derivative_env = [np.zeros((0, 0, 3))] * n_atoms_env
        for i in prange(n_atoms_env):
            n_neighbors_env = len(neighbor_index_env[i])
            if n_neighbors_env > 0:
                descriptor_neighbor_derivative_env[i] = np.zeros((
                    n_neighbors_env, n_parameters_rad_env, 3))
                for j in range(n_neighbors_env):
                    if neighbor_index_env[i][j] >= n_atoms_sys:
                        index = ((27 - neighbor_index_env[i][j] // n_atoms_sys) * n_atoms
                                 + active_atom[i + n_atoms_sys])
                    else:
                        index = active_atom[i + n_atoms_sys]
                    descriptor_neighbor_derivative_env[i][j] = calc_descriptor_neighbor_derivative(
                        descriptor_j_derivatives_rad[neighbor_index_env[i][j] % n_atoms_sys], zeros,
                        zeros, index, neighbor_index[neighbor_index_env[i][j] % n_atoms_sys],
                        n_parameters_rad_env, n_parameters_rad_env, 0, n_parameters_rad_env_start, 0)
    else:
        # placeholder so that both branches yield the same list type
        descriptor_neighbor_derivative_env = [np.zeros((0, 0, 3))]

    return descriptor, descriptor_i_derivative, descriptor_neighbor_derivative, \
        descriptor_neighbor_derivative_env


####################################################################################################

@ncfjit
def calc_descriptor_neighbor_derivative(descriptor_j_derivatives_rad_neighbor,
                                        descriptor_j_derivatives_ang_neighbor,
                                        descriptor_k_derivatives_ang_neighbor, index,
                                        neighbor_indices_neighbor, n_descriptors,
                                        n_parameters_rad, n_parameters_ang,
                                        n_parameters_rad_start, n_parameters_ang_start) -> NDArray:
    '''
    Assemble the derivative of a neighbor atom's descriptor with respect to the atom that appears
    as `index` in that neighbor's neighbor list.

    descriptor_j_derivatives_rad_neighbor is indexed [parameter, neighbor position, component];
    descriptor_j/k_derivatives_ang_neighbor are indexed [parameter, triple, component], with the
    triples (j, k), j < k, enumerated as in get_triple_properties. Only the parameter rows
    [n_parameters_rad_start, n_parameters_rad_start + n_parameters_rad) and
    [n_parameters_ang_start, n_parameters_ang_start + n_parameters_ang) are used. An IndexError
    is raised if index is not contained in neighbor_indices_neighbor.

    Return: descriptor_neighbor_derivative (n_descriptors x 3; the first n_parameters_rad rows
            hold the radial part, the remaining rows the summed angular part)
    '''
    result = np.zeros((n_descriptors, 3))
    n_nn = len(neighbor_indices_neighbor)
    # position of the differentiated atom within the neighbor's own neighbor list
    pos = np.nonzero(neighbor_indices_neighbor == index)[0][0]

    # radial part: derivative with respect to the neighbor at position pos
    rad_stop = n_parameters_rad_start + n_parameters_rad
    result[:n_parameters_rad] = descriptor_j_derivatives_rad_neighbor[
        n_parameters_rad_start:rad_stop, pos]

    if n_parameters_ang <= 0:
        return result

    # angular part: sum over all triples that contain the atom at position pos
    ang_stop = n_parameters_ang_start + n_parameters_ang
    ang_as_j = descriptor_j_derivatives_ang_neighbor[n_parameters_ang_start:ang_stop]
    ang_as_k = descriptor_k_derivatives_ang_neighbor[n_parameters_ang_start:ang_stop]
    # triples (pos, k) with k > pos: the atom plays the role of j
    row_offset = pos * (2 * n_nn - pos - 3) // 2 - 1
    for k in range(pos + 1, n_nn):
        result[n_parameters_rad:] += ang_as_j[:, row_offset + k]
    # triples (j, pos) with j < pos: the atom plays the role of k
    for j in range(pos):
        result[n_parameters_rad:] += ang_as_k[:, j * (2 * n_nn - j - 3) // 2 + pos - 1]

    return result


####################################################################################################

@ncfjit
def calc_rad_sym_func(elements_int_j, ij_1, dR_ij__dalpha_i_, interaction_classes_j,
                      atomic_charges_j, elem_func_index, rad_func_index, scale_func_index, R_c,
                      eta_ij, H_parameters_rad, H_parameters_rad_scale, n_parameters_rad,
                      element_types_rad, calc_derivatives, QMMM, I_type_j,
                      MM_atomic_charge_max) -> Tuple[NDArray, NDArray, NDArray]:
    '''
    Calculate the radial symmetry functions of one central atom i and their derivatives.

    Requirement: 0 <= R_ij < R_c

    elements_int_j, ij_1 (the distances R_ij), dR_ij__dalpha_i_ (one row of three components per
    neighbor), interaction_classes_j and atomic_charges_j describe the neighbors j of atom i.
    elem_func_index selects the element-dependent weights (0: elem_rad_eeACSF, 1: elem_rad_ACSF,
    2: elem_rad_eeACSF_QMMM), rad_func_index the radial function (0: rad_bump,
    1: rad_gaussian_bump, 2: rad_gaussian_cos) and scale_func_index the scaling function
    (0: crss, 1: linear, 2: sqrt); other index values are not handled.

    Return: G_rad, dG_rad__dalpha_i_, dG_rad__dalpha_j_
        G_rad: scaled symmetry function values, shape (n_parameters_rad,)
        dG_rad__dalpha_i_: derivatives with respect to atom i, shape (n_parameters_rad, 3)
        dG_rad__dalpha_j_: derivatives with respect to each neighbor j, shape
            (n_parameters_rad, n_neighbors, 3)
        If calc_derivatives is False and neighbors exist, the derivative arrays are empty
        placeholders of shape (0, 3) and (0, 0, 3).
    '''
    # check if neighbors exist
    n_interactions = len(elements_int_j)
    if n_interactions == 0:
        return np.zeros(n_parameters_rad), np.zeros((n_parameters_rad, 3)), \
            np.zeros((n_parameters_rad, 0, 3))

    # determine element-dependent radial function
    if elem_func_index == 0:
        elem_j = elem_rad_eeACSF(
            elements_int_j, n_parameters_rad, H_parameters_rad, H_parameters_rad_scale,
            n_interactions)
    elif elem_func_index == 1:
        elem_j = elem_rad_ACSF(
            elements_int_j, n_parameters_rad, element_types_rad, n_interactions)
    elif elem_func_index == 2:
        elem_j = elem_rad_eeACSF_QMMM(
            elements_int_j, interaction_classes_j, atomic_charges_j, n_parameters_rad,
            H_parameters_rad, H_parameters_rad_scale, n_interactions, MM_atomic_charge_max)

    # determine interaction-dependent radial function
    if QMMM:
        int_j = int_rad_eeACSF_QMMM(interaction_classes_j, I_type_j, n_parameters_rad,
                                    n_interactions)
        elem_j = elem_j * int_j

    # calculate combined radial and cutoff function and its derivative
    # (distances broadcast to one row per neighbor and one column per symmetry function)
    R_ij = ij_1.repeat(n_parameters_rad).reshape((-1, n_parameters_rad))
    if rad_func_index == 0:
        rad_ij, drad_ij = rad_bump(R_ij, eta_ij, R_c)
    elif rad_func_index == 1:
        rad_ij, drad_ij = rad_gaussian_bump(R_ij, eta_ij, R_c)
    elif rad_func_index == 2:
        rad_ij, drad_ij = rad_gaussian_cos(R_ij, eta_ij, R_c)

    # calculate unscaled radial eeACSFs (sum over neighbors)
    G_rad = np.sum(elem_j * rad_ij, axis=0)

    # calculate radial eeACSF derivatives
    if calc_derivatives:
        # dscale is evaluated at the unscaled G_rad and enters via the chain rule
        if scale_func_index == 0:
            dscale = dscale_crss(G_rad)
        elif scale_func_index == 1:
            dscale = dscale_linear(G_rad)
        elif scale_func_index == 2:
            dscale = dscale_sqrt(G_rad)
        # derivative with respect to neighbor j uses dR_ij/dalpha_j = -dR_ij/dalpha_i;
        # the transpose of the (3, n_interactions, n_parameters_rad) array gives
        # (n_parameters_rad, n_interactions, 3)
        dG_rad__dalpha_j_ = ((-dscale * elem_j * drad_ij) * np.ones((
            3, n_interactions, n_parameters_rad))).T * dR_ij__dalpha_i_
        # derivative with respect to atom i is minus the sum over all neighbors j
        dG_rad__dalpha_i_ = -np.sum(dG_rad__dalpha_j_, axis=1)

    # calculate radial eeACSFs
    if scale_func_index == 0:
        G_rad = scale_crss(G_rad)
    elif scale_func_index == 1:
        G_rad = scale_linear(G_rad)
    elif scale_func_index == 2:
        G_rad = scale_sqrt(G_rad)

    if not calc_derivatives:
        return G_rad, np.zeros((0, 3)), np.zeros((0, 0, 3))

    return G_rad, dG_rad__dalpha_i_, dG_rad__dalpha_j_


####################################################################################################

@ncfjit
def calc_ang_sym_func(elements_int_j, elements_int_k, ijk_2, ijk_3, dR_ij__dalpha_i_,
                      dR_ik__dalpha_i_, ijk_6, dcos_theta_ijk__dalpha_i_, dcos_theta_ijk__dalpha_j_,
                      dcos_theta_ijk__dalpha_k_, interaction_classes_jk, atomic_charges_j,
                      atomic_charges_k, elem_func_index, rad_func_index, ang_func_index,
                      scale_func_index, R_c, eta_ijk, lambda_ijk, zeta_ijk, xi_ijk, H_parameters_ang,
                      H_parameters_ang_scale, n_parameters_ang, H_type_jk, n_H_parameters,
                      element_types_ang, calc_derivatives, QMMM, I_type_jk,
                      MM_atomic_charge_max) -> Tuple[NDArray, NDArray, NDArray, NDArray]:
    '''
    Calculate the angular symmetry functions of one central atom i and their derivatives.

    Requirement: 0 <= R_ij < R_c, 0 <= R_ik < R_c

    The per-triple inputs follow the output of get_triple_properties (ijk_2, ijk_3: R_ij, R_ik;
    ijk_6: cos(theta_ijk)). elem_func_index selects the element-dependent weights
    (0: elem_ang_eeACSF, 1: elem_ang_ACSF, 2: elem_ang_eeACSF_QMMM), rad_func_index the radial
    function (0: rad_bump, 1: rad_gaussian_bump, 2: rad_gaussian_cos), ang_func_index the angular
    function (0: ang_bump, 1: ang_cos, 2: ang_cos_int) and scale_func_index the scaling function
    (0: crss, 1: linear, 2: sqrt); other index values are not handled.

    Return: G_ang, dG_ang__dalpha_i_, dG_ang__dalpha_j_, dG_ang__dalpha_k_
        G_ang: scaled symmetry function values, shape (n_parameters_ang,)
        dG_ang__dalpha_i_: derivatives with respect to atom i, shape (n_parameters_ang, 3)
        dG_ang__dalpha_j_, dG_ang__dalpha_k_: derivatives with respect to atoms j and k of each
            triple, shape (n_parameters_ang, n_triples, 3)
        If calc_derivatives is False and triples exist, the derivative arrays are empty
        placeholders of shape (0, 3) and (0, 0, 3).
    '''
    # check if neighbors exist
    n_interactions = len(elements_int_j)
    if n_interactions == 0:
        return np.zeros(n_parameters_ang), np.zeros((n_parameters_ang, 3)), \
            np.zeros((n_parameters_ang, 0, 3)), np.zeros((n_parameters_ang, 0, 3))

    # determine element-dependent angular function
    if elem_func_index == 0:
        elem_jk = elem_ang_eeACSF(
            elements_int_j, elements_int_k, n_parameters_ang, H_parameters_ang,
            H_parameters_ang_scale, H_type_jk, n_H_parameters, n_interactions)
    elif elem_func_index == 1:
        elem_jk = elem_ang_ACSF(
            elements_int_j, elements_int_k, n_parameters_ang, element_types_ang, n_interactions)
    elif elem_func_index == 2:
        elem_jk = elem_ang_eeACSF_QMMM(
            elements_int_j, elements_int_k, interaction_classes_jk, atomic_charges_j,
            atomic_charges_k, n_parameters_ang, H_parameters_ang, H_parameters_ang_scale, H_type_jk,
            n_H_parameters, I_type_jk, n_interactions, MM_atomic_charge_max)

    # determine interaction-dependent angular function
    if QMMM:
        int_jk = int_ang_eeACSF_QMMM(interaction_classes_jk, I_type_jk, n_parameters_ang,
                                     n_interactions)
        elem_jk = elem_jk * int_jk

    # calculate radial functions and their derivatives
    # (distances broadcast to one row per triple and one column per symmetry function)
    R_ij = ijk_2.repeat(n_parameters_ang).reshape((-1, n_parameters_ang))
    R_ik = ijk_3.repeat(n_parameters_ang).reshape((-1, n_parameters_ang))
    if rad_func_index == 0:
        rad_ij, drad_ij = rad_bump(R_ij, eta_ijk, R_c)
        rad_ik, drad_ik = rad_bump(R_ik, eta_ijk, R_c)
    elif rad_func_index == 1:
        rad_ij, drad_ij = rad_gaussian_bump(R_ij, eta_ijk, R_c)
        rad_ik, drad_ik = rad_gaussian_bump(R_ik, eta_ijk, R_c)
    elif rad_func_index == 2:
        rad_ij, drad_ij = rad_gaussian_cos(R_ij, eta_ijk, R_c)
        rad_ik, drad_ik = rad_gaussian_cos(R_ik, eta_ijk, R_c)

    # calculate angular function and its derivative
    cos_theta_ijk = ijk_6.repeat(n_parameters_ang).reshape((-1, n_parameters_ang))
    if ang_func_index == 0:
        ang_ijk, dang_ijk = ang_bump(cos_theta_ijk, lambda_ijk, xi_ijk)
    elif ang_func_index == 1:
        ang_ijk, dang_ijk = ang_cos(cos_theta_ijk, lambda_ijk, zeta_ijk, xi_ijk)
    elif ang_func_index == 2:
        ang_ijk, dang_ijk = ang_cos_int(cos_theta_ijk, lambda_ijk, zeta_ijk)

    # calculate unscaled angular eeACSFs (sum over triples)
    x = elem_jk * rad_ij * rad_ik
    G_ang = np.sum(x * ang_ijk, axis=0)

    # calculate angular eeACSF derivatives
    if calc_derivatives:
        # dscale is evaluated at the unscaled G_ang and enters via the chain rule
        if scale_func_index == 0:
            dscale = dscale_crss(G_ang)
        elif scale_func_index == 1:
            dscale = dscale_linear(G_ang)
        elif scale_func_index == 2:
            dscale = dscale_sqrt(G_ang)
        y = dscale * elem_jk * ang_ijk
        z = np.ones((3, n_interactions, n_parameters_ang))
        # contributions through cos(theta_ijk) (a), R_ij (b) and R_ik (c), each transposed to
        # shape (n_parameters_ang, n_interactions, 3)
        a = ((dscale * x * dang_ijk) * z).T
        b = ((y * drad_ij * rad_ik) * z).T * dR_ij__dalpha_i_
        c = ((y * rad_ij * drad_ik) * z).T * dR_ik__dalpha_i_
        dG_ang__dalpha_i_ = np.sum(a * dcos_theta_ijk__dalpha_i_ + b + c, axis=1)
        # dR_ij/dalpha_j = -dR_ij/dalpha_i and dR_ik/dalpha_k = -dR_ik/dalpha_i
        dG_ang__dalpha_j_ = a * dcos_theta_ijk__dalpha_j_ - b
        dG_ang__dalpha_k_ = a * dcos_theta_ijk__dalpha_k_ - c

    # calculate angular eeACSFs
    if scale_func_index == 0:
        G_ang = scale_crss(G_ang)
    elif scale_func_index == 1:
        G_ang = scale_linear(G_ang)
    elif scale_func_index == 2:
        G_ang = scale_sqrt(G_ang)

    if not calc_derivatives:
        return G_ang, np.zeros((0, 3)), np.zeros((0, 0, 3)), np.zeros((0, 0, 3))

    return G_ang, dG_ang__dalpha_i_, dG_ang__dalpha_j_, dG_ang__dalpha_k_


####################################################################################################

@ncfjit
def elem_rad_eeACSF(elements_int_j, n_parameters_rad, H_parameters_rad, H_parameters_rad_scale,
                    n_interactions) -> NDArray:
    '''
    Element-dependent weights of the radial eeACSFs, one row per neighbor j.

    Row r is the row H_parameters_rad[elements_int_j[r]], stored as int64 (non-integer entries
    are truncated), multiplied element-wise by H_parameters_rad_scale.

    Return: H_j
    '''
    H_int = np.zeros((n_interactions, n_parameters_rad), dtype=np.int64)
    for row in range(n_interactions):
        element = elements_int_j[row]
        H_int[row] = H_parameters_rad[element]
    return H_parameters_rad_scale * H_int


####################################################################################################

@ncfjit
def elem_rad_ACSF(elements_int_j, n_parameters_rad, element_types_rad, n_interactions) -> NDArray:
    '''
    Element selection mask of the radial ACSFs.

    Entry [r, p] is 1.0 if neighbor r has the element element_types_rad[p], otherwise 0.0.

    Return: S_j
    '''
    # one row per neighbor, one column per radial symmetry function
    per_neighbor = elements_int_j.repeat(n_parameters_rad).reshape((-1, n_parameters_rad))
    per_function = element_types_rad.repeat(n_interactions).reshape((-1, n_interactions)).T
    selected = per_function == per_neighbor
    return selected * 1.0


####################################################################################################

@ncfjit
def elem_rad_eeACSF_QMMM(elements_int_j, interaction_classes_j, atomic_charges_j, n_parameters_rad,
                         H_parameters_rad, H_parameters_rad_scale, n_interactions,
                         MM_atomic_charge_max) -> NDArray:
    '''
    Element- or charge-dependent weights of the radial eeACSFs for QM/MM, one row per neighbor j.

    Neighbors of interaction class 2 get the constant weight
    0.5 * atomic_charges_j / MM_atomic_charge_max in every column; all other neighbors get the
    row H_parameters_rad[elements_int_j] multiplied element-wise by H_parameters_rad_scale.

    Return: H_j
    '''
    weights = np.zeros((n_interactions, n_parameters_rad))
    for n in range(n_interactions):
        if interaction_classes_j[n] != 2:
            weights[n] = H_parameters_rad_scale * H_parameters_rad[elements_int_j[n]]
            continue
        weights[n] = (0.5 / MM_atomic_charge_max) * atomic_charges_j[n]
    return weights


####################################################################################################

@ncfjit
def elem_ang_eeACSF(elements_int_j, elements_int_k, n_parameters_ang, H_parameters_ang,
                    H_parameters_ang_scale, H_type_jk, n_H_parameters, n_interactions) -> NDArray:
    '''
    Hint: Only contributions of angles ijk (not ikj) are taken into account.

    Element-dependent weights of the angular eeACSFs, one row per triple and one column per
    angular symmetry function. The integer element parameters of j and k are summed for
    H_type_jk <= n_H_parameters; for H_type_jk > n_H_parameters the absolute difference plus one
    is used (plus zero if both parameters are zero). The result is multiplied element-wise by
    H_parameters_ang_scale.

    Return: H_jk
    '''
    H_j = np.zeros((n_interactions, n_parameters_ang), dtype=np.int64)
    H_k = np.zeros((n_interactions, n_parameters_ang), dtype=np.int64)
    for row in range(n_interactions):
        H_j[row] = H_parameters_ang[elements_int_j[row]]
        H_k[row] = H_parameters_ang[elements_int_k[row]]
    # columns using the difference instead of the sum of the element parameters
    is_difference = np.greater(H_type_jk, n_H_parameters)
    sign = 1 - 2 * is_difference
    both_zero = np.logical_and(np.equal(H_j, 0), np.equal(H_k, 0))
    offset = is_difference.repeat(n_interactions).reshape((-1, n_interactions)).T * (
        1 - 1 * both_zero)
    return H_parameters_ang_scale * (np.absolute(H_j + sign * H_k) + offset)


####################################################################################################

@ncfjit
def elem_ang_ACSF(elements_int_j, elements_int_k, n_parameters_ang, element_types_ang,
                  n_interactions) -> NDArray:
    '''
    Hint: Contributions of angles ijk and ikj are both taken into account.

    Element selection weights of the angular ACSFs. An element pair is encoded as
    1000 * element_a + element_b; entry [r, p] is 2.0 if element_types_ang[p] matches triple r
    in either order (jk or kj), otherwise 0.0.

    Return: S_jk
    '''
    targets = element_types_ang.repeat(n_interactions).reshape((-1, n_interactions)).T
    pair_jk = (1000 * elements_int_j + elements_int_k).repeat(n_parameters_ang).reshape(
        (-1, n_parameters_ang))
    pair_kj = (1000 * elements_int_k + elements_int_j).repeat(n_parameters_ang).reshape(
        (-1, n_parameters_ang))
    matches = (targets == pair_jk) | (targets == pair_kj)
    # factor 2 accounts for both orderings ijk and ikj of each unique neighbor pair
    return matches * 2.0


####################################################################################################

@ncfjit
def elem_ang_eeACSF_QMMM(elements_int_j, elements_int_k, interaction_classes_jk, atomic_charges_j,
                         atomic_charges_k, n_parameters_ang, H_parameters_ang,
                         H_parameters_ang_scale, H_type_jk, n_H_parameters, I_type_jk,
                         n_interactions, MM_atomic_charge_max) -> NDArray:
    '''
    Element/charge-dependent term of angular eeACSFs for QM/MM interactions.

    Per interaction i, depending on interaction_classes_jk[i]:
      2: element terms of j and k, H_parameters_ang_scale * H_parameters_ang[element].
      3: the element term of the neighbour with a non-negative element index, weighted by the
         charge of the other neighbour / MM_atomic_charge_max; the result's sign is kept.
      else: 0.5 / MM_atomic_charge_max**2 times the product of both charges; sign is kept.
    (NOTE(review): presumably 2 = QM-QM, 3 = QM-MM, 4 = MM-MM — confirm.)
    Parameters with H_type_jk > n_H_parameters combine the terms as a difference instead of a
    sum; for those with I_type_jk == 0 a bias of H_parameters_ang_scale is added unless both
    terms are zero. For n_interactions == 0 an empty (0, n_parameters_ang) array is returned.

    Hint: Only contributions of angles ijk (not ikj) are taken into account.

    Return: H_jk
    '''
    # determine element-dependent term
    H_j = np.zeros((n_interactions, n_parameters_ang))
    H_k = np.zeros((n_interactions, n_parameters_ang))
    H_sign = np.ones((n_interactions, n_parameters_ang))
    for i in range(n_interactions):
        if interaction_classes_jk[i] == 2:
            H_j[i] = H_parameters_ang_scale * H_parameters_ang[elements_int_j[i]]
            H_k[i] = H_parameters_ang_scale * H_parameters_ang[elements_int_k[i]]
        elif interaction_classes_jk[i] == 3:
            # element term of the neighbour with a non-negative element index, weighted by the
            # charge of the other neighbour
            if elements_int_j[i] >= 0:
                H_j[i] = (H_parameters_ang_scale * H_parameters_ang[elements_int_j[i]]
                          * atomic_charges_k[i] / MM_atomic_charge_max)
            else:
                H_j[i] = (H_parameters_ang_scale * H_parameters_ang[elements_int_k[i]]
                          * atomic_charges_j[i] / MM_atomic_charge_max)
            H_sign[i] = np.sign(H_j[i])
        else:
            H_j[i] = (0.5 / MM_atomic_charge_max**2) * atomic_charges_j[i] * atomic_charges_k[i]
            H_sign[i] = np.sign(H_j[i])

    # +1 (sum) or -1 (difference) per parameter
    factor = -2.0 * np.greater(H_type_jk, n_H_parameters) + 1.0
    # per-parameter bias, broadcast over all interactions below; the former
    # repeat/reshape((-1, n_interactions)) raised a ValueError for n_interactions == 0
    bias = np.equal(I_type_jk, 0) * np.greater(H_type_jk, n_H_parameters) * H_parameters_ang_scale
    # switch the bias off where both element terms vanish
    bias = bias * (-1.0 * np.logical_and(np.equal(H_j, 0), np.equal(H_k, 0)) + 1.0)
    H_jk = H_sign * np.absolute(H_j + factor * H_k) + bias

    return H_jk


####################################################################################################

@ncfjit
def int_rad_eeACSF_QMMM(interaction_classes_j, I_type_j, n_parameters_rad, n_interactions) -> NDArray:
    '''
    Interaction-type selection term of radial eeACSFs.

    Entry [i, p] is 1.0 if interaction class 1 meets I_type_j[p] == 0 or interaction class 2
    meets I_type_j[p] == 1; every other combination (including other classes) gives 0.0.

    Return: I_j
    '''
    type_0 = np.equal(I_type_j, 0)
    type_1 = np.equal(I_type_j, 1)
    I_j = np.zeros((n_interactions, n_parameters_rad))
    for i in range(n_interactions):
        interaction_class = interaction_classes_j[i]
        if interaction_class == 1:
            I_j[i][type_0] = 1.0
        elif interaction_class == 2:
            I_j[i][type_1] = 1.0

    return I_j


####################################################################################################

@ncfjit
def int_ang_eeACSF_QMMM(interaction_classes_jk, I_type_jk, n_parameters_ang, n_interactions) -> NDArray:
    '''
    Interaction-type selection term of angular eeACSFs.

    Entry [i, p] is 1.0 if interaction class 2, 3 or 4 of pair i meets I_type_jk[p] equal to
    0, 1 or 2, respectively; every other combination (including other classes) gives 0.0.

    Return: I_jk
    '''
    type_0 = np.equal(I_type_jk, 0)
    type_1 = np.equal(I_type_jk, 1)
    type_2 = np.equal(I_type_jk, 2)
    I_jk = np.zeros((n_interactions, n_parameters_ang))
    for i in range(n_interactions):
        interaction_class = interaction_classes_jk[i]
        if interaction_class == 2:
            I_jk[i][type_0] = 1.0
        elif interaction_class == 3:
            I_jk[i][type_1] = 1.0
        elif interaction_class == 4:
            I_jk[i][type_2] = 1.0

    return I_jk


####################################################################################################

@ncfpjit
def rad_bump(R, eta, R_c) -> Tuple[NDArray, NDArray]:
    '''
    Bump function serving as combined radial and cutoff function, and its derivative.

    rad = exp(eta - eta / (1 - (R / R_c)**2)), drad = d rad / d R.
    The denominator vanishes at R == R_c, so R is presumably below R_c (TODO confirm at caller).

    Return: rad, drad
    '''
    one_minus_sq = 1.0 - (R / R_c)**2
    value = np.exp(eta - eta / one_minus_sq)
    # chain rule: d/dR (-eta / u) = (eta / u**2) * du/dR with du/dR = -2 R / R_c**2
    slope = (-2.0 * eta / R_c**2) * R / one_minus_sq**2

    return value, value * slope


####################################################################################################

@ncfpjit
def rad_gaussian_bump(R, eta, R_c) -> Tuple[NDArray, NDArray]:
    '''
    Gaussian radial function times a bump cutoff function, and its derivative.

    rad = exp(-eta * R**2) * exp(1 - 1 / (1 - (R / R_c)**2)), drad = d rad / d R.
    The cutoff denominator vanishes at R == R_c.

    Return: rad, drad
    '''
    # bump cutoff and its logarithmic derivative (d cutoff / d R) / cutoff
    u = 1.0 - (R / R_c)**2
    cutoff = np.exp(1.0 - 1.0 / u)
    dlog_cutoff = (-2.0 / R_c**2) * R / u**2

    # Gaussian and its logarithmic derivative (d gauss / d R) / gauss
    gauss = np.exp(-eta * R**2)
    dlog_gauss = (-2.0 * eta) * R

    # product rule in logarithmic form: d(g * c) / dR = g * c * (dlog_c + dlog_g)
    rad = gauss * cutoff
    return rad, rad * (dlog_cutoff + dlog_gauss)


####################################################################################################

@ncfpjit
def rad_gaussian_cos(R, eta, R_c) -> Tuple[NDArray, NDArray]:
    '''
    Gaussian radial function times a cosine cutoff function, and its derivative.

    rad = exp(-eta * R**2) * (0.5 + 0.5 * cos(pi * R / R_c)), drad = d rad / d R.

    Return: rad, drad
    '''
    # cosine cutoff and its derivative
    phase = (pi / R_c) * R
    cutoff = 0.5 + 0.5 * np.cos(phase)
    dcutoff = (-0.5 * pi / R_c) * np.sin(phase)

    # Gaussian and its logarithmic derivative (d gauss / d R) / gauss
    gauss = np.exp(-eta * R**2)
    dlog_gauss = (-2.0 * eta) * R

    # product rule: d(g * c) / dR = g * (dc + (dg / g) * c)
    rad = gauss * cutoff
    drad = gauss * (dcutoff + dlog_gauss * cutoff)

    return rad, drad


####################################################################################################

@ncfpjit
def ang_bump(cos_theta, lambda_ijk, xi_ijk) -> Tuple[NDArray, NDArray]:
    '''
    Bump angular function of the normalized angle theta / pi, and its derivative.

    ang = exp(xi_ijk - xi_ijk / (1 - (lambda_ijk - theta / pi)**2)), theta = arccos(cos_theta);
    dang = d ang / d cos_theta, which divides by zero at cos_theta = +-1.

    Return: ang, dang
    '''
    offset = lambda_ijk - (np.arccos(cos_theta) / pi)
    gap = 1.0 - offset**2
    ang = np.exp(xi_ijk - xi_ijk / gap)
    # chain rule through theta = arccos(cos_theta):
    # d theta / d cos_theta = -1 / sqrt(1 - cos_theta**2)
    dang = ang * ((-2.0 / pi) * xi_ijk) * offset / gap**2 / np.sqrt(1.0 - cos_theta**2)

    return ang, dang


####################################################################################################

@ncfpjit
def ang_cos(cos_theta, lambda_ijk, zeta_ijk, xi_ijk) -> Tuple[NDArray, NDArray]:
    '''
    Cosine angular function with angular frequency zeta_ijk, and its derivative.

    ang = x**xi_ijk with x = 0.5 + 0.5 * lambda_ijk * cos(zeta_ijk * theta),
    theta = arccos(cos_theta); dang = d ang / d cos_theta.

    The derivative uses x**(xi_ijk - 1) instead of ang / x, so it stays finite (for
    xi_ijk >= 1) where x is exactly zero, e.g. cos(zeta_ijk * theta) == -1 for lambda_ijk == 1.
    The former ang / x gave 0/0 = NaN there. The factor 1 / sqrt(1 - cos_theta**2) still divides
    by zero at cos_theta = +-1.

    Return: ang, dang
    '''
    theta = zeta_ijk * np.arccos(cos_theta)
    x = 0.5 + (0.5 * lambda_ijk) * np.cos(theta)
    ang = x**xi_ijk
    # d x / d cos_theta = 0.5 * lambda_ijk * zeta_ijk * sin(theta) / sqrt(1 - cos_theta**2)
    dang = (0.5 * lambda_ijk * zeta_ijk * xi_ijk) * x**(xi_ijk - 1.0) * np.sin(theta) / np.sqrt(
        1.0 - cos_theta**2)

    return ang, dang


####################################################################################################

@ncfpjit
def ang_cos_int(cos_theta, lambda_ijk, zeta_ijk) -> Tuple[NDArray, NDArray]:
    '''
    Cosine angular function and its derivative with respect to cos_theta.

    ang = x**zeta_ijk with x = 0.5 + 0.5 * lambda_ijk * cos_theta;
    dang = 0.5 * lambda_ijk * zeta_ijk * x**(zeta_ijk - 1).

    The derivative uses x**(zeta_ijk - 1) instead of ang / x, so it stays finite (for
    zeta_ijk >= 1) where x is exactly zero, i.e. cos_theta == -lambda_ijk (e.g. -1 for
    lambda_ijk == 1). The former ang / x gave 0/0 = NaN there.

    Return: ang, dang
    '''
    x = 0.5 + (0.5 * lambda_ijk) * cos_theta
    ang = x**zeta_ijk
    dang = (0.5 * lambda_ijk * zeta_ijk) * x**(zeta_ijk - 1.0)

    return ang, dang


####################################################################################################

@ncfjit
def scale_crss(G) -> NDArray:
    '''
    Cube root-scaled-shifted scaling function, 3 * ((G + 1)**(1/3) - 1).

    Return: scale
    '''
    cube_root = (G + 1.0)**(1.0 / 3.0)
    return 3.0 * (cube_root - 1.0)


####################################################################################################

@ncfjit
def dscale_crss(G) -> NDArray:
    '''
    Derivative of the cube root-scaled-shifted scaling function, (G + 1)**(-2/3).

    Return: dscale
    '''
    shifted = G + 1.0
    return shifted**(-2.0 / 3.0)


####################################################################################################

@ncfjit
def scale_linear(G) -> NDArray:
    '''
    Linear (identity) scaling function; G itself is returned, not a copy.

    Return: scale
    '''
    scale = G
    return scale


####################################################################################################

@ncfjit
def dscale_linear(G) -> NDArray:
    '''
    Derivative of the linear scaling function: a float64 array of ones with the shape of G.

    Return: dscale
    '''
    return np.full(G.shape, 1.0)


####################################################################################################

@ncfjit
def scale_sqrt(G) -> NDArray:
    '''
    Square root scaling function, sqrt(G) (NaN for negative entries).

    Return: scale
    '''
    root = np.sqrt(G)
    return root


####################################################################################################

@ncfpjit
def dscale_sqrt(G) -> NDArray:
    '''
    Derivative of the square root scaling function, 0.5 / sqrt(G).

    Entries with G <= 0 (where the derivative is infinite or undefined) are set to 0.0.

    Return: dscale
    '''
    positive = G > 0.0
    dscale = np.zeros(G.shape)
    dscale[positive] = 0.5 / np.sqrt(G[positive])

    return dscale
