#!/usr/bin/python3

####################################################################################################

####################################################################################################

'''
Lifelong Machine Learning Potentials (lMLP)
'''
__copyright__ = '''This code is licensed under the 3-clause BSD license.
Copyright ETH Zurich, Department of Chemistry and Applied Biosciences, Reiher Group.
See LICENSE.txt for details.'''

####################################################################################################

####################################################################################################

from os import path
from pathlib import Path
from types import SimpleNamespace
import sys
import warnings
import numpy as np
import torch
from numba import NumbaTypeSafetyWarning   # type: ignore
from .descriptors import calc_descriptor_derivative, calc_descriptor_derivative_radial
from .lmlp_base import lMLP_base
from .models import calculate_forces, calculate_forces_QMMM


####################################################################################################

####################################################################################################

class lMLP_calculator(lMLP_base):
    '''
    Lifelong Machine Learning Potential Calculator

    Loads an ensemble of lMLP models referenced by a generalization setting file and predicts
    the energy, the forces, and optionally their uncertainties for single structures. If an
    active learning file is given, structures whose uncertainties exceed the active learning
    thresholds are appended to that file.
    '''

####################################################################################################

    def __init__(self, generalization_setting_file, uncertainty_scaling=2.0,
                 active_learning_file=None, active_learning_thresholds=(3.0, 3.0, 3.0)):
        '''
        Initialization

        generalization_setting_file: lMLP generalization setting file; the ensemble model files
            it lists are resolved relative to its parent directory
        uncertainty_scaling: factor applied to the ensemble standard deviation when estimating
            energy and force uncertainties
        active_learning_file: if not None, this file is created (or truncated) here and
            structures with large uncertainties are appended to it during prediction; requires
            more than one ensemble member
        active_learning_thresholds: multiples of the test RMSEs for the energy uncertainty per
            atom, the system force uncertainty, and the QM/MM environment force uncertainty
            above which a structure is written to the active learning file
        '''
        # get settings
        super().__init__()
        self.settings = SimpleNamespace(
            generalization_setting_format='lMLP',
            generalization_setting_file=generalization_setting_file,
            generalization_dir=Path(generalization_setting_file).parent,
            uncertainty_scaling=uncertainty_scaling,
            active_learning_file=active_learning_file,
            active_learning_thresholds=active_learning_thresholds)

        # set device used by PyTorch
        self.device = 'cpu'

        # read and set generalization settings
        self.element_types, self.n_descriptors, self.ensemble, self.test_RMSEs = \
            self.read_generalization_setting([])
        self.n_element_types = len(self.element_types)

        # define default data type of PyTorch tensors
        self.dtype_torch = self.define_dtype_torch()
        torch.set_default_dtype(self.dtype_torch)

        # initialize descriptor parameters and element energy (overwritten by each loaded model)
        self.descriptor_parameters = []
        self.R_c = 0.0
        self.element_energy = {}

        # define activation function
        self.activation_function = self.define_activation_function()

        # initialize ensemble model
        self.n_ensemble = len(self.ensemble)
        self.model = [None] * self.n_ensemble
        for model_index in range(self.n_ensemble):
            self.define_model(model_index)
            self.settings.generalization_file = path.join(
                self.settings.generalization_dir, self.ensemble[model_index])
            self.read_generalization(model_index)

        # initialize uncertainty prediction: average the test RMSEs over the ensemble members
        self.test_RMSEs = np.mean(np.array(self.test_RMSEs), axis=0)

        # create active learning output file
        if self.settings.active_learning_file is not None:
            if self.n_ensemble <= 1:
                print('ERROR: More than one machine learning potential has to be in the ensemble',
                      'for active learning.')
                sys.exit()
            Path(self.settings.active_learning_file).parent.mkdir(parents=True, exist_ok=True)
            # truncate the file; structures are appended later in predict
            with open(self.settings.active_learning_file, 'w', encoding='utf-8') as f:
                f.write('')
            self.active_learning_counter = 0

####################################################################################################

    def predict(self, elements, positions, lattice=None, atomic_classes=None, atomic_charges=None,
                name=None, calc_forces=True, calc_uncertainty=False):
        '''
        Prediction of the energy and forces of one structure by the model ensemble

        elements: chemical element of each atom
        positions: Cartesian coordinates of each atom
        lattice: lattice vectors for periodic systems; None or empty for non-periodic systems
        atomic_classes: class of each atom; only atoms of class 1 are checked against the
            represented elements and classes larger than 1 switch on QM/MM handling
            (default: all atoms of class 1)
        atomic_charges: atomic charges (default: all zero)
        name: optional structure name passed on when writing active learning data
        calc_forces: whether to predict forces
        calc_uncertainty: whether to estimate uncertainties (always enabled if an active
            learning file was given); requires more than one ensemble member

        Return: energy_prediction[, forces_prediction][, energy_uncertainty][, forces_uncertainty]
                (forces only if calc_forces, uncertainties only if calc_uncertainty)
        '''
        # check settings
        if calc_uncertainty and self.n_ensemble <= 1:
            print('ERROR: More than one machine learning potential has to be in the ensemble',
                  'for uncertainty quantification.')
            sys.exit()

        # get settings
        self.settings.calc_forces = calc_forces
        # active learning decisions are based on the uncertainties
        if self.settings.active_learning_file is not None:
            calc_uncertainty = True

        # get structures
        QMMM = False
        n_structures = 1
        n_atoms_original = len(elements)
        n = 0
        if lattice is None:
            lattice = np.array([])
        if atomic_classes is None:
            atomic_classes = np.ones(n_atoms_original, dtype=int)
            elements_unique = np.unique(elements)
        else:
            # boolean masking requires arrays; convert only for this check so list input is
            # handled correctly while the original objects are passed on unchanged
            elements_unique = np.unique(
                np.asarray(elements)[np.asarray(atomic_classes) == 1])
            if np.max(atomic_classes) > 1:
                QMMM = True
        if atomic_charges is None:
            atomic_charges = np.zeros(n_atoms_original)

        # check structures
        if not np.all(np.isin(elements_unique, self.element_types)):
            print('ERROR: The lMLP is not able to represent all given chemical elements.',
                  '\n{0} not all in {1}'.format(np.unique(elements), np.unique(self.element_types)))
            sys.exit()

        # prepare periodic systems
        if len(lattice) > 0:
            PBC = True
            elements, positions, lattice, atomic_classes, atomic_charges, _, _, n_atoms, \
                reorder_original = self.prepare_periodic_systems(
                    elements, positions, lattice, atomic_classes, atomic_charges, 0.0, [],
                    n_atoms_original)
        else:
            PBC = False
            n_atoms = n_atoms_original
            reorder_original = np.empty(0, dtype=int)

        # create lists of structure properties (the helper methods work on lists of structures)
        elements = [elements]
        positions = [positions]
        lattices = [lattice]
        atomic_classes = [atomic_classes]
        atomic_charges = [atomic_charges]
        n_atoms = np.array([n_atoms])
        if name is None:
            name = []
        else:
            name = [name]

        # order atoms by atomic type
        if QMMM:
            elements, positions, atomic_classes, atomic_charges, _, n_atoms_sys, reorder = \
                self.atomic_type_ordering(elements, positions, atomic_classes, atomic_charges, [],
                                          n_structures, n_atoms)
        else:
            n_atoms_sys = n_atoms
            reorder = []

        # determine element integer list
        elements_int_sys = self.convert_element_names(elements, n_structures, n_atoms_sys)

        # determine neighbor indices and descriptors and their derivatives as a function of the
        # Cartesian coordinates
        descriptors_torch, descriptor_derivatives_torch, neighbor_indices, \
            descriptor_neighbor_derivatives_torch_env, neighbor_indices_env, active_atoms, \
            n_atoms_active, MM_gradients = self.calculate_descriptors(
                elements_int_sys, positions, lattices, atomic_classes, atomic_charges, n_structures,
                n_atoms, n_atoms_sys, [])

        # calculate energy and forces of all ensemble members
        energy_prediction = np.empty(self.n_ensemble)
        if self.settings.calc_forces:
            forces_prediction = np.empty((self.n_ensemble, n_atoms[n], 3))
        for model_index in range(self.n_ensemble):
            energy_prediction_torch, forces_prediction_torch = self.calculate_energy_forces(
                model_index, n, elements_int_sys, descriptors_torch, descriptor_derivatives_torch,
                neighbor_indices, n_atoms_sys, descriptor_neighbor_derivatives_torch_env,
                neighbor_indices_env, n_atoms_active, MM_gradients, create_graph=False)
            energy_prediction[model_index] = float(energy_prediction_torch.cpu().detach().numpy()[0])
            if self.settings.calc_forces:
                if self.settings.QMMM:
                    # forces are only predicted for active atoms; all other atoms get zero
                    forces_prediction[model_index] = np.zeros((n_atoms[n], 3))
                    forces_prediction[model_index][active_atoms[n]] = \
                        forces_prediction_torch.cpu().detach().numpy().astype(float)
                else:
                    forces_prediction[model_index] = \
                        forces_prediction_torch.cpu().detach().numpy().astype(float)

        # calculate ensemble prediction and uncertainty of energy and forces including
        # element-specific atomic energies
        if calc_uncertainty:
            # scaled ensemble standard deviation of the energy per system atom, bounded from
            # below by the test RMSE and converted back with the square root of the atom number
            energy_uncertainty = max(self.test_RMSEs[0], self.settings.uncertainty_scaling * np.std(
                energy_prediction / n_atoms_sys[n], ddof=1)) * np.sqrt(n_atoms_sys[n])
        else:
            energy_uncertainty = 0.0
        energy_prediction = np.mean(energy_prediction) + np.sum(np.array([
            self.element_energy[ele] for ele in elements[n][:n_atoms_sys[n]]]))
        if self.settings.calc_forces:
            if calc_uncertainty:
                forces_uncertainty = self.settings.uncertainty_scaling * np.std(
                    forces_prediction, ddof=1, axis=0)
                # bound the system atom force uncertainties from below by the test RMSE
                forces_uncertainty[:n_atoms_sys[n]][forces_uncertainty[
                    :n_atoms_sys[n]] < self.test_RMSEs[1]] = self.test_RMSEs[1]
                if self.settings.QMMM:
                    # bound the active environment atom force uncertainties from below by the
                    # corresponding test RMSE
                    active_atoms_env = np.arange(n_atoms[n])[active_atoms[n][n_atoms_sys[n]:]]
                    for i in range(3):
                        forces_uncertainty[:, i][active_atoms_env[forces_uncertainty[:, i][
                            active_atoms_env] < self.test_RMSEs[2]]] = self.test_RMSEs[2]
            forces_prediction = np.mean(forces_prediction, axis=0)

        # write bad represented structures to a file for active learning
        if self.settings.active_learning_file is not None:
            retrain = False
            energy_uncertainty_per_atom = energy_uncertainty / np.sqrt(n_atoms_sys[n])
            if energy_uncertainty_per_atom > (
                    self.settings.active_learning_thresholds[0] * self.test_RMSEs[0]):
                retrain = True
            if self.settings.calc_forces:
                forces_uncertainty_max = np.max(forces_uncertainty[:n_atoms_sys[n]])
                if forces_uncertainty_max > (
                        self.settings.active_learning_thresholds[1] * self.test_RMSEs[1]):
                    retrain = True
                if self.settings.QMMM:
                    # the structure may contain no environment atoms (e.g. if no atomic classes
                    # were given); np.max would raise a ValueError on the empty array
                    forces_uncertainty_env = forces_uncertainty[n_atoms_sys[n]:]
                    if forces_uncertainty_env.size > 0:
                        forces_uncertainty_max_QMMM = np.max(forces_uncertainty_env)
                    else:
                        forces_uncertainty_max_QMMM = 0.0
                    if forces_uncertainty_max_QMMM > (
                            self.settings.active_learning_thresholds[2] * self.test_RMSEs[2]):
                        retrain = True
            if retrain:
                self.active_learning_counter += 1
                if self.settings.calc_forces:
                    if self.settings.QMMM:
                        uncertainty = [[energy_uncertainty_per_atom, forces_uncertainty_max,
                                        forces_uncertainty_max_QMMM]]
                    else:
                        uncertainty = [[energy_uncertainty_per_atom, forces_uncertainty_max]]
                else:
                    uncertainty = [[energy_uncertainty_per_atom]]
                self.write_inputdata(
                    self.settings.active_learning_file, elements, positions, lattices,
                    atomic_classes, atomic_charges, [energy_prediction], [forces_prediction],
                    n_structures, n_atoms, reorder, name=name, uncertainty=uncertainty,
                    counter=self.active_learning_counter, mode='a')

        # get energy of original atoms (periodic image replication increases the atom number)
        n_images = n_atoms[n] / n_atoms_original
        energy_prediction /= n_images
        if calc_uncertainty:
            energy_uncertainty /= np.sqrt(n_images)

        # reorder forces and get original atoms
        if self.settings.calc_forces:
            if QMMM:
                forces_prediction = forces_prediction[reorder[n]]
            if PBC:
                forces_prediction = forces_prediction[reorder_original][:n_atoms_original]
            if calc_uncertainty:
                if QMMM:
                    forces_uncertainty = forces_uncertainty[reorder[n]]
                if PBC:
                    forces_uncertainty = forces_uncertainty[reorder_original][:n_atoms_original]

        # return requested properties
        if self.settings.calc_forces:
            if calc_uncertainty:
                return energy_prediction, forces_prediction, energy_uncertainty, forces_uncertainty
            return energy_prediction, forces_prediction
        if calc_uncertainty:
            return energy_prediction, energy_uncertainty
        return energy_prediction

####################################################################################################

    def calculate_symmetry_function(self, elements_int_sys, positions, lattices, atomic_classes,
                                    atomic_charges, n_structures, n_atoms, n_atoms_sys, _):
        '''
        Implementation: Radial types: bump, Gaussian-bump, Gaussian-cosine
                        Angular types: bump, cosine, cosine_integer
                        Scaling types: cube root-scaled-shifted, linear, square root
                        Descriptor types: ACSF, eeACSF

        The last positional argument is not used by this implementation.

        Return: descriptors_torch, [descriptor_i_derivatives_torch,
                descriptor_neighbor_derivatives_torch], neighbor_indices,
                descriptor_neighbor_derivatives_torch_env, neighbor_indices_env, active_atoms,
                n_atoms_active, MM_gradients
        '''
        # implemented descriptor radial types
        descriptor_radial_type_list = ['bump', 'gaussian_bump', 'gaussian_cos']
        # implemented descriptor angular types
        descriptor_angular_type_list = ['bump', 'cos', 'cos_int']
        # implemented descriptor scaling types
        descriptor_scaling_type_list = ['crss', 'linear', 'sqrt']
        # implemented descriptor types
        descriptor_type_list = ['ACSF', 'eeACSF']

        # get bump radial function index
        if self.settings.descriptor_radial_type == 'bump':
            rad_func_index = 0
        # get Gaussian-bump radial function index
        elif self.settings.descriptor_radial_type == 'gaussian_bump':
            rad_func_index = 1
        # get Gaussian-cosine radial function index
        elif self.settings.descriptor_radial_type == 'gaussian_cos':
            rad_func_index = 2
        # not implemented descriptor radial type
        else:
            print('ERROR: Calculating descriptor radial type {0} is not yet implemented.'
                  .format(self.settings.descriptor_radial_type),
                  '\nPlease use one of the following types:')
            for des_rad_type in descriptor_radial_type_list:
                print('{0}'.format(des_rad_type))
            sys.exit()

        # get bump angular function index
        if self.settings.descriptor_angular_type == 'bump':
            ang_func_index = 0
        # get cosine angular function index
        elif self.settings.descriptor_angular_type == 'cos':
            ang_func_index = 1
        # get cosine integer angular function index
        elif self.settings.descriptor_angular_type == 'cos_int':
            ang_func_index = 2
        # not implemented descriptor angular type
        else:
            print('ERROR: Calculating descriptor angular type {0} is not yet implemented.'
                  .format(self.settings.descriptor_angular_type),
                  '\nPlease use one of the following types:')
            for des_ang_type in descriptor_angular_type_list:
                print('{0}'.format(des_ang_type))
            sys.exit()

        # check descriptor scaling type for QM/MM data
        if self.settings.QMMM and self.settings.descriptor_scaling_type != 'linear':
            print('ERROR: QM/MM data require descriptor scaling type linear.')
            sys.exit()
        # get cube root-scaled-shifted scaling function index
        if self.settings.descriptor_scaling_type == 'crss':
            scale_func_index = 0
        # get linear scaling function index
        elif self.settings.descriptor_scaling_type == 'linear':
            scale_func_index = 1
        # get square root scaling function index
        elif self.settings.descriptor_scaling_type == 'sqrt':
            scale_func_index = 2
        # not implemented descriptor scaling type
        else:
            print('ERROR: Calculating descriptor scaling type {0} is not yet implemented.'
                  .format(self.settings.descriptor_scaling_type),
                  '\nPlease use one of the following types:')
            for des_sca_type in descriptor_scaling_type_list:
                print('{0}'.format(des_sca_type))
            sys.exit()

        # determine number of radial and angular parameters (the first entry of each descriptor
        # parameter set is summed and reduced by the number of descriptors to count the angular
        # ones; radial parameter sets come first)
        n_parameters_ang = np.sum(np.array(
            [self.descriptor_parameters[i][0] for i in range(self.n_descriptors)])) - self.n_descriptors
        n_parameters_rad = self.n_descriptors - n_parameters_ang
        parameters_rad = np.array(self.descriptor_parameters[:n_parameters_rad])
        # get element-dependent radial and angular function index
        if self.settings.descriptor_type == 'eeACSF':
            element_types_rad = np.array([], dtype=int)
            if self.settings.QMMM:
                elem_func_index = 2
            else:
                elem_func_index = 0
        elif self.settings.descriptor_type == 'ACSF':
            element_types_rad = np.tile(
                np.arange(self.n_element_types), n_parameters_rad // self.n_element_types)
            if self.settings.QMMM:
                print('ERROR: QM/MM reference data cannot be represented by ACSFs.',
                      '\nPlease use the descriptor type eeACSFs.')
                sys.exit()
            else:
                elem_func_index = 1
        # not implemented descriptor type (otherwise elem_func_index and element_types_rad would
        # be undefined below)
        else:
            print('ERROR: Calculating descriptor type {0} is not yet implemented.'
                  .format(self.settings.descriptor_type),
                  '\nPlease use one of the following types:')
            for des_type in descriptor_type_list:
                print('{0}'.format(des_type))
            sys.exit()

        # radial and angular symmetry functions
        if n_parameters_ang > 0:
            parameters_ang = np.array(self.descriptor_parameters[n_parameters_rad:])
            n_element_types_ang = 0
            element_types_ang = np.array([], dtype=int)
            H_parameters_rad = np.array([], dtype=int)
            H_parameters_ang = np.array([], dtype=int)
            H_parameters_rad_scale = np.array([])
            H_parameters_ang_scale = np.array([])
            n_H_parameters = 0
            if self.settings.descriptor_type == 'ACSF':
                # determine angular parameters (element pairs encoded as 1000 * i + j with i <= j)
                n_element_types_ang = np.sum(np.arange(1, self.n_element_types + 1))
                element_types_ang = np.tile(np.array(
                    [1000 * i + j for i in range(self.n_element_types)
                     for j in range(i, self.n_element_types)]), n_parameters_ang // n_element_types_ang)
                # slice parameters array
                H_type_jk = np.array([], dtype=int)
                eta_ij = parameters_rad[:, 1]
                eta_ijk = parameters_ang[:, 1]
                lambda_ijk = parameters_ang[:, 2]
                zeta_ijk = parameters_ang[:, 3]
                xi_ijk = parameters_ang[:, 4]
            elif self.settings.descriptor_type == 'eeACSF':
                # slice parameters array
                H_type_j = parameters_rad[:, 1].astype(int)
                H_type_jk = parameters_ang[:, 1].astype(int)
                eta_ij = parameters_rad[:, 2]
                eta_ijk = parameters_ang[:, 2]
                lambda_ijk = parameters_ang[:, 3]
                zeta_ijk = parameters_ang[:, 4]
                xi_ijk = parameters_ang[:, 5]
                # determine H parameters
                H_parameters_rad, H_parameters_ang, H_parameters_rad_scale, H_parameters_ang_scale, \
                    n_H_parameters = self.get_H_parameters(H_type_j, H_type_jk)
            if self.settings.QMMM:
                # slice parameters array
                I_type_j = parameters_rad[:, 3].astype(int)
                I_type_jk = parameters_ang[:, 6].astype(int)
                # check if H and I types are compatible
                if np.any(np.logical_and(np.greater(H_type_jk, n_H_parameters), np.greater(I_type_jk, 2))):
                    print('ERROR: Angular symmetry function subtypes cannot be larger than',
                          '{0} for QM/MM subtypes larger than 2.'.format(n_H_parameters))
                    sys.exit()
                # determine I parameters
                n_parameters_rad_env = len(np.arange(n_parameters_rad)[I_type_j > 0])
                n_parameters_ang_env = len(np.arange(n_parameters_ang)[I_type_jk > 0])
                MM_gradients = list(np.arange(self.n_descriptors)[
                    np.concatenate((I_type_j > 0, I_type_jk > 0))])
            else:
                # slice parameters array
                I_type_j = np.array([], dtype=int)
                I_type_jk = np.array([], dtype=int)
                # determine I parameters
                n_parameters_rad_env = 0
                n_parameters_ang_env = 0
                MM_gradients = []

            # calculate symmetry function values and derivatives
            descriptors_torch = []
            descriptor_i_derivatives_torch = []
            descriptor_neighbor_derivatives_torch = []
            neighbor_indices = []
            descriptor_neighbor_derivatives_torch_env = []
            neighbor_indices_env = []
            active_atoms = []
            n_atoms_active = []
            for n in range(n_structures):
                # determine interatomic distances and angles for each atom within the cutoff sphere
                neighbor_index, ij_0, ij_1, ij_2, ij_3, ij_4, ijk_0, ijk_1, ijk_2, ijk_3, ijk_4, \
                    ijk_5, ijk_6, ijk_7, ijk_8, ijk_9, ijk_10, ijk_11, ijk_12, active_atom, \
                    neighbor_index_env = self.calculate_atomic_environments(
                        elements_int_sys[n], positions[n], lattices[n], atomic_classes[n],
                        atomic_charges[n], n_atoms[n], n_atoms_sys[n],
                        calc_derivatives=self.settings.calc_forces)
                active_atoms.append(active_atom)
                n_atoms_active.append(len(active_atom))
                # keep only neighbors that are system atoms (indices taken modulo n_atoms[n])
                neighbor_index_sys = [neighbor_index[i][neighbor_index[i] % n_atoms[n] < n_atoms_sys[n]]
                                      for i in range(len(neighbor_index))]
                neighbor_indices.append([i % n_atoms[n] for i in neighbor_index_sys])
                descriptor_i_derivatives_torch.append([])
                descriptor_neighbor_derivatives_torch.append([])
                neighbor_indices_env.append([i % n_atoms_sys[n] for i in neighbor_index_env])
                descriptor_neighbor_derivatives_torch_env.append([])
                # calculate symmetry function value and derivative contributions
                with warnings.catch_warnings():
                    warnings.filterwarnings('ignore', category=NumbaTypeSafetyWarning)
                    descriptor, descriptor_i_derivative, descriptor_neighbor_derivative, \
                        descriptor_neighbor_derivative_env = calc_descriptor_derivative(
                            ij_0, ij_1, ij_2, ij_3, ij_4, ijk_0, ijk_1, ijk_2, ijk_3, ijk_4, ijk_5,
                            ijk_6, ijk_7, ijk_8, ijk_9, ijk_10, ijk_11, ijk_12, neighbor_index,
                            n_atoms[n], n_atoms_sys[n], self.n_descriptors, elem_func_index,
                            rad_func_index, ang_func_index, scale_func_index, self.R_c, eta_ij,
                            H_parameters_rad, H_parameters_rad_scale, n_parameters_rad,
                            element_types_rad, eta_ijk, lambda_ijk, zeta_ijk, xi_ijk,
                            H_parameters_ang, H_parameters_ang_scale, n_parameters_ang, H_type_jk,
                            n_H_parameters, element_types_ang, self.settings.calc_forces,
                            self.settings.QMMM, active_atom, n_atoms_active[n], neighbor_index_env,
                            I_type_j, n_parameters_rad_env, I_type_jk, n_parameters_ang_env,
                            self.settings.MM_atomic_charge_max)
                # compile symmetry function values
                descriptors_torch.append(torch.tensor(np.array(descriptor), requires_grad=True,
                                                      dtype=self.dtype_torch))
                if self.settings.calc_forces:
                    for i in range(n_atoms_sys[n]):
                        # compile symmetry function derivatives with respect to the central atom i
                        descriptor_i_derivatives_torch[-1].append(torch.tensor(
                            descriptor_i_derivative[i], dtype=self.dtype_torch))
                        # compile symmetry function derivatives with respect to neighbor atoms of atom i
                        descriptor_neighbor_derivatives_torch[-1].append(torch.tensor(
                            descriptor_neighbor_derivative[i], dtype=self.dtype_torch))
                    # compile symmetry function derivatives with respect to neighbor active environment
                    # atoms of atom i
                    if self.settings.QMMM:
                        for i in range(n_atoms_active[n] - n_atoms_sys[n]):
                            descriptor_neighbor_derivatives_torch_env[-1].append(torch.tensor(
                                descriptor_neighbor_derivative_env[i], dtype=self.dtype_torch))

        # only radial symmetry functions
        else:
            H_parameters_rad = np.array([], dtype=int)
            H_parameters_rad_scale = np.array([])
            if self.settings.descriptor_type == 'ACSF':
                # slice parameter arrays
                eta_ij = parameters_rad[:, 1]
            elif self.settings.descriptor_type == 'eeACSF':
                # slice parameter arrays
                H_type_j = parameters_rad[:, 1].astype(int)
                eta_ij = parameters_rad[:, 2]
                # determine H parameters
                H_parameters_rad, H_parameters_ang, H_parameters_rad_scale, H_parameters_ang_scale, \
                    n_H_parameters = self.get_H_parameters(H_type_j, np.array([], dtype=int))
            if self.settings.QMMM:
                # slice parameters array
                I_type_j = parameters_rad[:, 3].astype(int)
                # determine I parameters
                n_parameters_rad_env = len(np.arange(n_parameters_rad)[I_type_j > 0])
                MM_gradients = list(np.arange(self.n_descriptors)[I_type_j > 0])
            else:
                # slice parameters array
                I_type_j = np.array([], dtype=int)
                # determine I parameters
                n_parameters_rad_env = 0
                MM_gradients = []

            # calculate symmetry function values and derivatives
            descriptors_torch = []
            descriptor_i_derivatives_torch = []
            descriptor_neighbor_derivatives_torch = []
            neighbor_indices = []
            descriptor_neighbor_derivatives_torch_env = []
            neighbor_indices_env = []
            active_atoms = []
            n_atoms_active = []
            for n in range(n_structures):
                # determine interatomic distances for each atom within the cutoff sphere
                neighbor_index, ij_0, ij_1, ij_2, ij_3, ij_4, ijk_0, ijk_1, ijk_2, ijk_3, ijk_4, \
                    ijk_5, ijk_6, ijk_7, ijk_8, ijk_9, ijk_10, ijk_11, ijk_12, active_atom, \
                    neighbor_index_env = self.calculate_atomic_environments(
                        elements_int_sys[n], positions[n], lattices[n], atomic_classes[n],
                        atomic_charges[n], n_atoms[n], n_atoms_sys[n],
                        calc_derivatives=self.settings.calc_forces, angular=False)
                active_atoms.append(active_atom)
                n_atoms_active.append(len(active_atom))
                # keep only neighbors that are system atoms (indices taken modulo n_atoms[n])
                neighbor_index_sys = [neighbor_index[i][neighbor_index[i] % n_atoms[n] < n_atoms_sys[n]]
                                      for i in range(len(neighbor_index))]
                neighbor_indices.append([i % n_atoms[n] for i in neighbor_index_sys])
                descriptor_i_derivatives_torch.append([])
                descriptor_neighbor_derivatives_torch.append([])
                neighbor_indices_env.append([i % n_atoms_sys[n] for i in neighbor_index_env])
                descriptor_neighbor_derivatives_torch_env.append([])
                # calculate symmetry function value and derivative contributions
                with warnings.catch_warnings():
                    warnings.filterwarnings('ignore', category=NumbaTypeSafetyWarning)
                    descriptor, descriptor_i_derivative, descriptor_neighbor_derivative, \
                        descriptor_neighbor_derivative_env = calc_descriptor_derivative_radial(
                            ij_0, ij_1, ij_2, ij_3, ij_4, neighbor_index, n_atoms[n],
                            n_atoms_sys[n], self.n_descriptors, elem_func_index, rad_func_index,
                            scale_func_index, self.R_c, eta_ij, H_parameters_rad,
                            H_parameters_rad_scale, n_parameters_rad, element_types_rad,
                            self.settings.calc_forces, self.settings.QMMM, active_atom,
                            n_atoms_active[n], neighbor_index_env, I_type_j, n_parameters_rad_env,
                            self.settings.MM_atomic_charge_max)
                # compile symmetry function values
                descriptors_torch.append(torch.tensor(np.array(descriptor), requires_grad=True,
                                                      dtype=self.dtype_torch))
                if self.settings.calc_forces:
                    for i in range(n_atoms_sys[n]):
                        # compile symmetry function derivatives with respect to the central atom i
                        descriptor_i_derivatives_torch[-1].append(torch.tensor(
                            descriptor_i_derivative[i], dtype=self.dtype_torch))
                        # compile symmetry function derivatives with respect to neighbor atoms of atom i
                        descriptor_neighbor_derivatives_torch[-1].append(torch.tensor(
                            descriptor_neighbor_derivative[i], dtype=self.dtype_torch))
                    # compile symmetry function derivatives with respect to neighbor active environment
                    # atoms of atom i
                    if self.settings.QMMM:
                        for i in range(n_atoms_active[n] - n_atoms_sys[n]):
                            descriptor_neighbor_derivatives_torch_env[-1].append(torch.tensor(
                                descriptor_neighbor_derivative_env[i], dtype=self.dtype_torch))

        # convert number of active atoms from list to NumPy array
        n_atoms_active = np.array(n_atoms_active)

        return descriptors_torch, [descriptor_i_derivatives_torch, descriptor_neighbor_derivatives_torch], \
            neighbor_indices, descriptor_neighbor_derivatives_torch_env, neighbor_indices_env, \
            active_atoms, n_atoms_active, MM_gradients

####################################################################################################

    def read_generalization(self, model_index=0):
        '''
        Read the model stored in self.settings.generalization_file into ensemble member
        model_index

        Implementation: lMLP, lMLP-only_prediction

        Modify: model, descriptor_parameters, R_c, element_energy
        '''
        # implemented generalization formats
        generalization_format_list = ['lMLP', 'lMLP-only_prediction']

        # read lMLP model
        if self.settings.generalization_format in ('lMLP', 'lMLP-only_prediction'):
            # check if generalization file exists
            if not path.isfile(self.settings.generalization_file):
                print('ERROR: Generalization file {0} does not exist.'.format(
                    self.settings.generalization_file))
                sys.exit()
            self.read_lMLP_model(model_index)

        # not implemented generalization format
        else:
            print('ERROR: Generalization format {0} is not yet implemented.'
                  .format(self.settings.generalization_format),
                  '\nPlease use one of the following formats:')
            for gen_format in generalization_format_list:
                print('{0}'.format(gen_format))
            sys.exit()

####################################################################################################

    def read_lMLP_model(self, model_index):
        '''
        Load the PyTorch checkpoint self.settings.generalization_file into ensemble member
        model_index

        Modify: model, descriptor_parameters, R_c, element_energy
        '''
        # read lMLP model; map all tensors to the calculator device so that checkpoints saved
        # from another device (e.g. a GPU) can be loaded
        # NOTE(review): torch.load unpickles the file -- only load trusted model files
        checkpoint = torch.load(self.settings.generalization_file, map_location=self.device)
        try:
            self.model[model_index].load_state_dict(checkpoint['model_state_dict'])
            self.descriptor_parameters = checkpoint['descriptor_parameters']
            self.R_c = checkpoint['R_c']
            self.element_energy = checkpoint['element_energy']
        except KeyError:
            print('ERROR: lMLP model file {0} is broken.'.format(
                self.settings.generalization_file))
            sys.exit()

####################################################################################################

    def calculate_energy_forces(self, model_index, n, elements_int_sys, descriptors_torch,
                                descriptor_derivatives_torch, neighbor_indices, n_atoms_sys,
                                descriptor_neighbor_derivatives_torch_env, neighbor_indices_env,
                                n_atoms_active, MM_gradients, create_graph=False):
        '''
        Predict the energy and, if self.settings.calc_forces is set, the forces of structure n
        with ensemble member model_index

        descriptor_derivatives_torch: [descriptor_i_derivatives_torch,
            descriptor_neighbor_derivatives_torch] as returned by calculate_symmetry_function
        create_graph: forwarded to calculate_forces / calculate_forces_QMMM

        Return: energy_prediction_torch, forces_prediction_torch (an empty tensor of shape
                (0, 3) if forces are not calculated)
        '''
        # predict energy of fit structures
        energy_prediction_torch = self.model[model_index](
            elements_int_sys[n], descriptors_torch[n], n_atoms_sys[n])

        # skip the force calculation if only energies are requested
        if not self.settings.calc_forces:
            return energy_prediction_torch, torch.zeros((0, 3))

        # predict forces of fit structures
        if self.settings.QMMM:
            forces_prediction_torch = calculate_forces_QMMM(
                energy_prediction_torch, descriptors_torch[n], descriptor_derivatives_torch[0][n],
                descriptor_derivatives_torch[1][n], neighbor_indices[n], n_atoms_sys[n],
                descriptor_neighbor_derivatives_torch_env[n], neighbor_indices_env[n],
                n_atoms_active[n], MM_gradients, create_graph=create_graph)
        else:
            forces_prediction_torch = calculate_forces(
                energy_prediction_torch, descriptors_torch[n], descriptor_derivatives_torch[0][n],
                descriptor_derivatives_torch[1][n], neighbor_indices[n], n_atoms_sys[n],
                create_graph=create_graph)

        return energy_prediction_torch, forces_prediction_torch
