#!/usr/bin/python3

####################################################################################################

####################################################################################################

'''
Lifelong Machine Learning Potentials (lMLP)
'''
__copyright__ = '''This code is licensed under the 3-clause BSD license.
Copyright ETH Zurich, Department of Chemistry and Applied Biosciences, Reiher Group.
See LICENSE.txt for details.'''

####################################################################################################

####################################################################################################

from os import path
from copy import deepcopy
from pathlib import Path
from time import time
from typing import Tuple
import configparser
import gc
import sys
import warnings
import numpy as np
import torch
from numba import NumbaTypeSafetyWarning, typed   # type: ignore
from numpy.typing import NDArray
from scipy.stats import rankdata   # type: ignore
from torch import Tensor
from .core import CoRe
from .descriptors import calc_descriptor_derivative, calc_descriptor_derivative_radial
from .lmlp_base import lMLP_base
from .models import calculate_forces, calculate_forces_QMMM
from .performance import ncjit, ncfjit


####################################################################################################

####################################################################################################

class lMLP(lMLP_base):
    '''
    Lifelong Machine Learning Potential
    '''

####################################################################################################

    def __init__(self, settings) -> None:
        '''
        Initialization

        Stores, validates, and prints the settings and then prepares everything that is needed
        before run() can be called: element types, PyTorch device and default data type, random
        number generators, unit conversion, supplemental potential, descriptor parameters,
        activation function, model, loss functions, and optimizer. If a restart (or a
        prediction-only run) is requested, the previous generalization setting is read early and
        the previous generalization is read after the model has been defined.

        Input: settings: object providing the settings as attributes (see check_settings)
        '''
        # get settings
        # (the start time is taken first so that the reported initialization time covers
        # everything done in this method)
        time_start = time()
        super().__init__()
        self.settings = settings

        # check settings
        self.check_settings()

        # print header and settings
        self.print_settings()

        # import memory_profiler if requested
        # (imported here and not at module level, so it is only needed for memory_evaluation)
        if self.settings.memory_evaluation:
            from memory_profiler import memory_usage   # type: ignore
            self.memory_usage = memory_usage

        # read and set previous generalization settings if requested
        # (a prediction-only run always implies a restart; note that this overwrites
        # settings.restart after the settings have been printed)
        if self.settings.prediction_only:
            self.settings.restart = True
        if self.settings.restart:
            # the third returned value is not used here
            self.element_types, self.n_descriptors, _, self.test_RMSEs = \
                self.read_generalization_setting([
                    self.settings.generalization_format, self.settings.model_type,
                    self.settings.descriptor_type, self.settings.descriptor_radial_type,
                    self.settings.descriptor_angular_type, self.settings.descriptor_scaling_type,
                    self.settings.scale_shift_layer, self.settings.n_neurons_hidden_layers,
                    self.settings.activation_function_type, self.settings.dtype_torch,
                    self.settings.QMMM, self.settings.MM_atomic_charge_max])
        else:
            self.test_RMSEs = []

            # get and order unique element types
            # (only without restart; otherwise they are taken from the generalization setting)
            self.element_types = self.get_element_types()
        self.n_element_types = len(self.element_types)

        # check device used by PyTorch
        self.check_device()

        # define default data type of PyTorch tensors
        # (torch.set_default_dtype changes the process-wide PyTorch default)
        self.dtype_torch = self.define_dtype_torch()
        torch.set_default_dtype(self.dtype_torch)

        # construct a NumPy random number generator for training and test set splitting and
        # selection of fitting structures
        self.rng = np.random.default_rng(self.settings.seed)

        # get unit conversion factors
        self.get_unit_conversion()

        # read supplemental pair potential
        # (skipped only for a restart from a generalization format starting with 'lMLP' without
        # pair contributions; empty placeholders are set in that case)
        if not (self.settings.restart and self.settings.generalization_format.startswith('lMLP')
                and not self.settings.pair_contributions):
            self.read_supplemental_potential()
        else:
            self.supplemental_potential_parameters, self.element_energy = {}, {}

        # read parameters of descriptors which are used as input for the machine learning potential
        # (skipped for a restart from a generalization format starting with 'lMLP'; placeholders
        # are set instead and n_descriptors from the generalization setting is kept)
        if not (self.settings.restart and self.settings.generalization_format.startswith('lMLP')):
            self.read_descriptor_parameters()
            self.n_descriptors = len(self.descriptor_parameters)
        else:
            self.descriptor_parameters, self.R_c = [], 0.0

        # define activation function
        self.activation_function = self.define_activation_function()

        # set seed of PyTorch random number generator for weight_initialization == 'default'
        torch.manual_seed(self.settings.seed)

        # define machine learning potential
        # (self.model is set to a placeholder before define_model is called)
        self.model = [None]
        self.define_model()

        # define loss function and optimizer and initialize training state
        self.energy_loss_function, self.forces_loss_function = self.define_loss_function()
        self.define_optimization()
        self.training_state = []

        # read and set previous model, loss function, optimizer, training state, and/or descriptor
        # parameters if requested
        if self.settings.restart:
            self.read_generalization()

        # prepare transfer learning if requested
        if self.settings.transfer_learning:
            self.prepare_transfer_learning()

        # initialize the maximal memory usage and report time and memory of the initialization;
        # the initialization time is stored because run() adds it to the total time
        self.max_memory_usage = 0.0
        self.print_time_memory('initialization:', time_start)
        self.time_init = time() - time_start

####################################################################################################

    def run(self) -> None:
        '''
        Evaluation

        Executes the complete workflow for the settings given at initialization:
        1. Read the episodic memory and calculate the descriptors (always).
        2. Unless prediction_only is set: subtract the atomic and pair contributions from the
           reference data, split the data into training and test set, initialize the weights
           (only without restart), and fit the model.
        3. If prediction_only or write_prediction is set: predict energy and forces for all
           structures, add the atomic and pair contributions, optionally print the RMSEs
           (prediction_only with prediction_RMSE), and write the predictions.
        4. Unless prediction_only is set: write the generalization setting and the generalization
           and, for selection_scheme == 'lADS', the modified episodic memory.
        Timing information is printed after each stage and the total time (including the
        initialization time) at the end.
        '''
        # read episodic memory
        # (time_total measures the whole run, time_start is reset for the individual stages)
        time_total = time()
        time_start = time()
        elements, positions, lattices, atomic_classes, atomic_charges, energy, forces, \
            n_structures, n_atoms, name = self.read_episodic_memory()

        # order atoms by atomic type
        # (only for QM/MM; otherwise all atoms belong to the system and no reordering is stored)
        if self.settings.QMMM:
            elements, positions, atomic_classes, atomic_charges, forces, n_atoms_sys, reorder = \
                self.atomic_type_ordering(elements, positions, atomic_classes, atomic_charges,
                                          forces, n_structures, n_atoms)
        else:
            n_atoms_sys = n_atoms
            reorder = []

        # determine element integer list
        elements_int_sys = self.convert_element_names(elements, n_structures, n_atoms_sys)

        # determine neighbor indices and descriptors and their derivatives as a function of the
        # Cartesian coordinates
        descriptors_torch, descriptor_derivatives_torch, neighbor_indices, \
            descriptor_neighbor_derivatives_torch_env, neighbor_indices_env, active_atoms, \
            n_atoms_active, MM_gradients = self.calculate_descriptors(
                elements_int_sys, positions, lattices, atomic_classes, atomic_charges, n_structures,
                n_atoms, n_atoms_sys, name)

        # write descriptor parameters and values
        if self.settings.descriptor_output_file is not None:
            self.write_descriptor_output(descriptors_torch, elements, n_structures, n_atoms_sys)
        self.print_time_memory('descriptor calculation:', time_start, measure_memory=False)
        gc.collect()

        # training stage (skipped for prediction_only)
        if not self.settings.prediction_only:
            # subtract supplemental pair potential contributions including atomic energies
            # in episodic memory reference data
            time_start = time()
            energy, forces = self.calculate_atomic_and_pair_contributions(
                '-', elements, positions, lattices, energy, forces, n_structures, n_atoms_sys)

            # write modified input.data (disabled debug output)
            #    self.write_prediction(
            #        'input.data-modified', elements, positions, lattices, atomic_classes,
            #        atomic_charges, energy, forces, n_structures, n_atoms, reorder, name)

            # convert NumPy arrays to PyTorch tensors
            # (for QM/MM only the forces of the active atoms are used as reference)
            energy_torch = [torch.tensor([structure], dtype=self.dtype_torch)
                            for structure in energy]
            if self.settings.QMMM:
                forces_torch = [torch.tensor(forces[n][active_atoms[n]], dtype=self.dtype_torch)
                                for n in range(n_structures)]
            else:
                forces_torch = [torch.tensor(structure, dtype=self.dtype_torch)
                                for structure in forces]

            # define splitting of training and test data taking into account the previous
            # training state
            train, test, n_structures_train, n_structures_test, assignment = \
                self.define_train_test_splitting(name, n_structures)

            # the weights are only initialized if the model is not restarted
            if not self.settings.restart:
                # set seed of PyTorch random number generator for
                # weight_initialization != 'default'
                torch.manual_seed(self.settings.seed)

                # initialize model weights
                if self.settings.scale_shift_layer:
                    self.initialize_scale_shift_layer(train, elements_int_sys, descriptors_torch)
                self.initialize_weights(train, elements_int_sys, energy, n_structures_train,
                                        n_atoms_sys)

            # freeze and unfreeze model weights
            self.freeze_weights()

            # fit model
            # (selection is only used below for writing the modified episodic memory)
            selection = self.fit_model(
                train, test, elements_int_sys, descriptors_torch, descriptor_derivatives_torch,
                neighbor_indices, descriptor_neighbor_derivatives_torch_env, neighbor_indices_env,
                energy_torch, forces_torch, n_structures_train, n_structures_test, n_atoms_active,
                n_atoms_sys, MM_gradients, name)
            self.print_time_memory('model fit:', time_start, measure_memory=False)
            gc.collect()

        # calculate energy and forces
        # (prediction stage; the results are collected in Numba typed lists)
        time_start = time()
        if self.settings.prediction_only or self.settings.write_prediction:
            energy_prediction = typed.List()
            forces_prediction = typed.List()
            for n in range(n_structures):
                energy_prediction_torch, forces_prediction_torch = self.calculate_energy_forces(
                    n, elements_int_sys, descriptors_torch, descriptor_derivatives_torch,
                    neighbor_indices, n_atoms_sys, descriptor_neighbor_derivatives_torch_env,
                    neighbor_indices_env, n_atoms_active, MM_gradients, name, create_graph=False)
                energy_prediction.append(float(energy_prediction_torch.cpu().detach().numpy()[0]))
                if self.settings.QMMM:
                    # predicted forces are written to the rows of the active atoms, all other
                    # rows of the (n_atoms[n], 3) array stay zero
                    forces_prediction.append(np.zeros((n_atoms[n], 3)))
                    forces_prediction[-1][active_atoms[n]] = \
                        forces_prediction_torch.cpu().detach().numpy().astype(float)
                else:
                    forces_prediction.append(forces_prediction_torch.cpu().detach().numpy().astype(float))

            # add supplemental pair potential contributions including atomic energies to
            # prediction data
            energy_prediction, forces_prediction = self.calculate_atomic_and_pair_contributions(
                '+', elements, positions, lattices, energy_prediction, forces_prediction, n_structures,
                n_atoms_sys)
            self.print_time_memory('model prediction:', time_start)

            # calculate RMSE for prediction
            # (only for prediction_only; the RMSEs are converted to the output units)
            time_start = time()
            if self.settings.prediction_only:
                if self.settings.prediction_RMSE:
                    E_RMSE, F_RMSE, F_RMSE_env = calculate_RMSE(
                        np.array(energy_prediction), np.array(energy), forces_prediction, forces,
                        n_structures, n_atoms_sys, self.settings.QMMM, typed.List(active_atoms),
                        n_atoms_active)
                    print('Energy RMSE:                   {0} {1}/atom'.format(
                        round(E_RMSE * self.energy_conversion, 6), self.settings.energy_unit))
                    print('Force RMSE:                    {0} {1}/{2}'.format(
                        round(F_RMSE * self.force_conversion, 6), self.settings.energy_unit,
                        self.settings.length_unit))
                    if self.settings.QMMM:
                        print('Environment Force RMSE:        {0} {1}/{2}'.format(
                            round(F_RMSE_env * self.force_conversion, 6), self.settings.energy_unit,
                            self.settings.length_unit))
                    print('')

                # no training and test set splitting exists for prediction_only, hence an empty
                # assignment is passed to write_prediction
                assignment = []

            # write predictions
            self.write_prediction(
                self.settings.prediction_file, elements, positions, lattices, atomic_classes,
                atomic_charges, energy_prediction, forces_prediction, n_structures, n_atoms, reorder,
                name, assignment)

        # finalization of a training run (skipped for prediction_only)
        if not self.settings.prediction_only:
            # write generalization setting
            self.write_generalization_setting()

            # write generalization
            self.write_generalization()

            # write episodic memory for new, bad, and redundant training data and test data
            # and print names of bad training data
            if self.settings.selection_scheme == 'lADS':
                self.write_modified_episodic_memory(
                    selection, train, test, elements, positions, lattices, atomic_classes,
                    atomic_charges, energy, forces, n_structures, n_structures_test, n_atoms,
                    n_atoms_sys, reorder, name)

        # report the time of the last stage and the total time including the initialization
        self.print_time_memory('finalization:', time_start)
        print('Total time:                    {0:.3f} s'.format(
            round(time() - time_total + self.time_init, 3)))

####################################################################################################

    def print_settings(self):
        '''
        Output: Header and settings

        Prints the program header followed by the values of the settings in four groups: input and
        output settings, training and prediction settings, additional training settings (only if
        prediction_only is False) or additional prediction settings (only if prediction_only is
        True), and computational settings. Each setting is printed as 'name:' left-aligned to a
        fixed width followed by its value. All settings, including frozen_layers, are printed with
        a colon after the name.
        '''
        # width of the label column (setting name plus colon, padded with spaces)
        label_width = 31

        def format_setting(name):
            # one output line for the setting called name; the leading newline together with the
            # separator of print reproduces the layout of one setting per line
            return '\n{0:<{2}}{1}'.format(name + ':', getattr(self.settings, name), label_width)

        def print_group(names):
            # all values are formatted before anything is printed; the entries are passed as
            # separate arguments so that print joins them with its default separator
            print(*[format_setting(name) for name in names])

        # print header
        header = (
            r'|\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/|',
            r'|/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\|',
            r'|\/\/\/\/\/\/\/\/\/\/\/\                                                    /\/\/\/\/\/\/\/\/\/\/\/|',
            r'|/\/\/\/\/\/\/\/\/\/\/\/        Lifelong Machine Learning Potentials        \/\/\/\/\/\/\/\/\/\/\/\|',
            r'|\/\/\/\/\/\/\/\/\/\/\/\                 Version 28-05-2024                 /\/\/\/\/\/\/\/\/\/\/\/|',
            r'|/\/\/\/\/\/\/\/\/\/\/\/                 Dr. Marco Eckhoff                  \/\/\/\/\/\/\/\/\/\/\/\|',
            r'|\/\/\/\/\/\/\/\/\/\/\/\                     ETH Zurich                     /\/\/\/\/\/\/\/\/\/\/\/|',
            r'|/\/\/\/\/\/\/\/\/\/\/\/  Department of Chemistry and Applied Biosciences   \/\/\/\/\/\/\/\/\/\/\/\|',
            r'|\/\/\/\/\/\/\/\/\/\/\/\  Vladimir-Prelog-Weg 2, 8093 Zurich, Switzerland   /\/\/\/\/\/\/\/\/\/\/\/|',
            r'|/\/\/\/\/\/\/\/\/\/\/\/           lifelong_ml@phys.chem.ethz.ch            \/\/\/\/\/\/\/\/\/\/\/\|',
            r'|\/\/\/\/\/\/\/\/\/\/\/\                                                    /\/\/\/\/\/\/\/\/\/\/\/|',
            r'|/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\|',
            r'|\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/|',
        )
        for header_line in header:
            print(header_line)

        # print input and output settings
        print_group((
            'episodic_memory_format',
            'episodic_memory_file',
            'supplemental_potential_type',
            'supplemental_potential_file',
            'descriptor_type',
            'descriptor_parameter_file',
            'generalization_setting_format',
            'generalization_setting_file',
            'generalization_format',
            'generalization_file',
            'prediction_format',
            'prediction_file',
        ))

        # print training and prediction settings
        print_group((
            'element_types',
            'model_type',
            'atomic_force_max',
            'prediction_only',
            'pair_contributions',
            'QMMM',
            'MM_atomic_charge_max',
            'transfer_learning',
            'scale_shift_layer',
            'n_neurons_hidden_layers',
            'activation_function_type',
            'descriptor_radial_type',
            'descriptor_angular_type',
            'descriptor_scaling_type',
            'descriptor_on_disk',
            'descriptor_disk_dir',
            'descriptor_cache_file',
            'descriptor_output_file',
            'memory_evaluation',
            'dtype_torch',
            'device',
            'energy_unit',
            'length_unit',
        ))

        # print additional training settings
        if not self.settings.prediction_only:
            print_group((
                'seed',
                'restart',
                'weight_initialization',
                'fit_forces',
                'energy_preoptimization_step',
                'training_fraction',
                'fit_fraction',
                'optimizer',
                'n_epochs',
                'learning_rate',
                'step_sizes',
                'etas',
                'betas',
                'weight_decay',
                'score_history',
                'frozen',
                'foreach',
                'frozen_layers',
                'loss_function',
                'loss_parameters',
                'loss_E_scaling',
                'selection_scheme',
                'selection_measure',
                'selection_range',
                'selection_thresholds',
                'selection_strikes',
                'selection_small_strikes',
                'exclusion_strikes',
                'fraction_redundant_max',
                'gradient_backtracking',
                'stationary_point_prob_factor',
                'fraction_good_max',
                'n_fraction_intervals',
                'write_new_episodic_memory',
                'print_bad_data_names',
                'late_data_scheme',
                'late_data_fraction',
                'late_data_epoch',
                'RMSE_interval',
                'write_weights_interval',
                'write_prediction',
            ))

        # print additional prediction settings
        else:
            print_group(('prediction_RMSE',))

        # print computational settings (followed by an empty line)
        print(format_setting('n_threads') + '\n')

####################################################################################################

    def check_settings(self):
        '''
        Implementation: bool: prediction_only, pair_contributions, QMMM, transfer_learning,
                              scale_shift_layer, memory_evaluation, restart, fit_forces,
                              energy_preoptimization_step, foreach, gradient_backtracking,
                              write_new_episodic_memory, print_bad_data_names, write_prediction,
                              prediction_RMSE
                        int: seed, n_epochs, score_history, exclusion_strikes, n_fraction_intervals,
                             late_data_epoch, RMSE_interval, write_weights_interval, n_threads
                        float: atomic_force_max, MM_atomic_charge_max, training_fraction,
                               fit_fraction, learning_rate, weight_decay, frozen, loss_E_scaling,
                               fraction_redundant_max, stationary_point_prob_factor,
                               fraction_good_max, late_data_fraction
                        string: episodic_memory_file, supplemental_potential_file,
                                descriptor_parameter_file, generalization_setting_file,
                                generalization_file, prediction_file, descriptor_disk_dir,
                                descriptor_output_file
                        list: element_types, n_neurons_hidden_layers, frozen_layers
                        tuple: step_sizes, etas, betas, loss_parameters, selection_range,
                               selection_thresholds, selection_strikes, selection_small_strikes
        '''
        # check boolean settings
        if not isinstance(self.settings.prediction_only, bool):
            print('ERROR: prediction_only has to be of type bool instead of {0}.'.format(
                self.settings.prediction_only.__class__.__name__))
            sys.exit()
        if not isinstance(self.settings.pair_contributions, bool):
            print('ERROR: pair_contributions has to be of type bool instead of {0}.'.format(
                self.settings.pair_contributions.__class__.__name__))
            sys.exit()
        if not isinstance(self.settings.QMMM, bool):
            print('ERROR: QMMM has to be of type bool instead of {0}.'.format(
                self.settings.QMMM.__class__.__name__))
            sys.exit()
        if not isinstance(self.settings.transfer_learning, bool):
            print('ERROR: transfer_learning has to be of type bool instead of {0}.'.format(
                self.settings.transfer_learning.__class__.__name__))
            sys.exit()
        if not isinstance(self.settings.scale_shift_layer, bool):
            print('ERROR: scale_shift_layer has to be of type bool instead of {0}.'.format(
                self.settings.scale_shift_layer.__class__.__name__))
            sys.exit()
        if not isinstance(self.settings.memory_evaluation, bool):
            print('ERROR: memory_evaluation has to be of type bool instead of {0}.'.format(
                self.settings.memory_evaluation.__class__.__name__))
            sys.exit()
        if not isinstance(self.settings.restart, bool):
            print('ERROR: restart has to be of type bool instead of {0}.'.format(
                self.settings.restart.__class__.__name__))
            sys.exit()
        if not isinstance(self.settings.fit_forces, bool):
            print('ERROR: fit_forces has to be of type bool instead of {0}.'.format(
                self.settings.fit_forces.__class__.__name__))
            sys.exit()
        if not isinstance(self.settings.energy_preoptimization_step, bool):
            print('ERROR: energy_preoptimization_step has to be of type bool instead of {0}.'.format(
                self.settings.energy_preoptimization_step.__class__.__name__))
            sys.exit()
        if not isinstance(self.settings.foreach, bool):
            print('ERROR: foreach has to be of type bool instead of {0}.'.format(
                self.settings.foreach.__class__.__name__))
            sys.exit()
        if not isinstance(self.settings.gradient_backtracking, bool):
            print('ERROR: gradient_backtracking has to be of type bool instead of {0}.'.format(
                self.settings.gradient_backtracking.__class__.__name__))
            sys.exit()
        if not isinstance(self.settings.write_new_episodic_memory, bool):
            print('ERROR: write_new_episodic_memory has to be of type bool instead of {0}.'.format(
                self.settings.write_new_episodic_memory.__class__.__name__))
            sys.exit()
        if not isinstance(self.settings.print_bad_data_names, bool):
            print('ERROR: print_bad_data_names has to be of type bool instead of {0}.'.format(
                self.settings.print_bad_data_names.__class__.__name__))
            sys.exit()
        if not isinstance(self.settings.write_prediction, bool):
            print('ERROR: write_prediction has to be of type bool instead of {0}.'.format(
                self.settings.write_prediction.__class__.__name__))
            sys.exit()
        if not isinstance(self.settings.prediction_RMSE, bool):
            print('ERROR: prediction_RMSE has to be of type bool instead of {0}.'.format(
                self.settings.prediction_RMSE.__class__.__name__))
            sys.exit()

        # check integer settings
        if not isinstance(self.settings.seed, int):
            print('ERROR: seed has to be of type int instead of {0}.'.format(
                self.settings.seed.__class__.__name__))
            sys.exit()

        if not isinstance(self.settings.n_epochs, int):
            print('ERROR: n_epochs has to be of type int instead of {0}.'.format(
                self.settings.n_epochs.__class__.__name__))
            sys.exit()
        elif self.settings.n_epochs < 0:
            print('ERROR: n_epochs has to be larger then or equal to 0 instead of {0}.'.format(
                self.settings.n_epochs))
            sys.exit()

        if not isinstance(self.settings.score_history, int):
            print('ERROR: score_history has to be of type int instead of {0}.'.format(
                self.settings.score_history.__class__.__name__))
            sys.exit()
        elif self.settings.score_history < 0:
            print('ERROR: score_history has to be larger than or equal to 0 instead of {0}.'.format(
                self.settings.score_history))
            sys.exit()

        if not isinstance(self.settings.exclusion_strikes, int):
            print('ERROR: exclusion_strikes has to be of type int instead of {0}.'.format(
                self.settings.exclusion_strikes.__class__.__name__))
            sys.exit()
        elif self.settings.exclusion_strikes <= 0:
            print('ERROR: exclusion_strikes has to be larger than 0 instead of {0}.'.format(
                self.settings.exclusion_strikes))
            sys.exit()

        if not isinstance(self.settings.n_fraction_intervals, int):
            print('ERROR: n_fraction_intervals has to be of type int instead of {0}.'.format(
                self.settings.n_fraction_intervals.__class__.__name__))
            sys.exit()
        elif self.settings.n_fraction_intervals <= 0:
            print('ERROR: n_fraction_intervals has to be larger than 0 instead of {0}.'.format(
                self.settings.n_fraction_intervals))
            sys.exit()

        if not isinstance(self.settings.late_data_epoch, int):
            print('ERROR: late_data_epoch has to be of type int instead of {0}.'.format(
                self.settings.late_data_epoch.__class__.__name__))
            sys.exit()
        elif self.settings.late_data_epoch < 0 or self.settings.late_data_epoch > self.settings.n_epochs:
            print('ERROR: late_data_epoch has to be larger than or equal to 0',
                  'and smaller than or equal to n_epochs = {1} instead of {0}.'.format(
                      self.settings.late_data_epoch, self.settings.n_epochs))
            sys.exit()

        if not isinstance(self.settings.RMSE_interval, int):
            print('ERROR: RMSE_interval has to be of type int instead of {0}.'.format(
                self.settings.RMSE_interval.__class__.__name__))
            sys.exit()
        elif self.settings.RMSE_interval <= 0:
            print('ERROR: RMSE_interval has to be larger than 0 instead of {0}.'.format(
                self.settings.RMSE_interval))
            sys.exit()

        if not (isinstance(self.settings.write_weights_interval, int)
                or self.settings.write_weights_interval is None):
            print('ERROR: write_weights_interval has to be of type int or NoneType',
                  'instead of {0}.'.format(self.settings.write_weights_interval.__class__.__name__))
            sys.exit()
        elif self.settings.write_weights_interval is not None:
            if self.settings.write_weights_interval % self.settings.RMSE_interval != 0:
                print('ERROR: write_weights_interval has to be a multiple of RMSE_interval',
                      '({0} != N * {1}).'.format(
                          self.settings.write_weights_interval, self.settings.RMSE_interval))
                sys.exit()

        if not isinstance(self.settings.n_threads, int):
            print('ERROR: n_threads has to be of type int instead of {0}.'.format(
                self.settings.n_threads.__class__.__name__))
            sys.exit()
        elif self.settings.n_threads <= 0:
            print('ERROR: n_threads has to be larger than 0 instead of {0}.'.format(
                self.settings.n_threads))
            sys.exit()

        # check float settings
        if not isinstance(self.settings.atomic_force_max, float):
            print('ERROR: atomic_force_max has to be of type float instead of {0}.'.format(
                self.settings.atomic_force_max.__class__.__name__))
            sys.exit()
        elif self.settings.atomic_force_max <= 0.0:
            print('ERROR: atomic_force_max has to be larger than 0 instead of {0}.'.format(
                self.settings.atomic_force_max))
            sys.exit()

        if not isinstance(self.settings.MM_atomic_charge_max, float):
            print('ERROR: MM_atomic_charge_max has to be of type float instead of {0}.'.format(
                self.settings.MM_atomic_charge_max.__class__.__name__))
            sys.exit()
        elif self.settings.QMMM and self.settings.MM_atomic_charge_max <= 0.0:
            print('ERROR: MM_atomic_charge_max has to be larger than 0 instead of {0}.'.format(
                self.settings.MM_atomic_charge_max))
            sys.exit()

        if not isinstance(self.settings.training_fraction, float):
            print('ERROR: training_fraction has to be of type float instead of {0}.'.format(
                self.settings.training_fraction.__class__.__name__))
            sys.exit()
        elif self.settings.training_fraction <= 0.0 or self.settings.training_fraction > 1.0:
            print('ERROR: training_fraction has to be larger than 0 and smaller than or equal to 1',
                  'instead of {0}.'.format(self.settings.training_fraction))
            sys.exit()

        if not isinstance(self.settings.fit_fraction, (float, int)):
            print('ERROR: fit_fraction has to be of type float or int instead of {0}.'.format(
                self.settings.fit_fraction.__class__.__name__))
            sys.exit()
        elif isinstance(self.settings.fit_fraction, float):
            if self.settings.fit_fraction <= 0.0 or self.settings.fit_fraction > 1.0:
                print('ERROR: Float value of fit_fraction has to be larger than 0 and smaller than',
                      'or equal to 1 instead of {0}.'.format(self.settings.fit_fraction))
                sys.exit()
        elif isinstance(self.settings.fit_fraction, int):
            if self.settings.fit_fraction < 1:
                print('ERROR: Integer value of fit_fraction has to be larger than or equal to 1',
                      'instead of {0}.'.format(self.settings.fit_fraction))
                sys.exit()

        if not isinstance(self.settings.learning_rate, float):
            print('ERROR: learning_rate has to be of type float instead of {0}.'.format(
                self.settings.learning_rate.__class__.__name__))
            sys.exit()
        elif self.settings.learning_rate <= 0.0:
            print('ERROR: learning_rate has to be larger than 0 instead of {0}.'.format(
                self.settings.learning_rate))
            sys.exit()

        if not isinstance(self.settings.weight_decay, float):
            print('ERROR: weight_decay has to be of type float instead of {0}.'.format(
                self.settings.weight_decay.__class__.__name__))
            sys.exit()
        elif self.settings.weight_decay < 0.0 or self.settings.weight_decay >= 1.0:
            print('ERROR: weight_decay has to be larger than or equal to 0 and smaller than 1',
                  'instead of {0}.'.format(self.settings.weight_decay))
            sys.exit()

        if not isinstance(self.settings.frozen, float):
            print('ERROR: frozen has to be of type float instead of {0}.'.format(
                self.settings.frozen.__class__.__name__))
            sys.exit()
        elif self.settings.frozen < 0.0 or self.settings.frozen >= 1.0:
            print('ERROR: frozen has to be larger than or equal to 0 and smaller than 1',
                  'instead of {0}.'.format(self.settings.frozen))
            sys.exit()

        if not isinstance(self.settings.loss_E_scaling, float):
            print('ERROR: loss_E_scaling has to be of type float instead of {0}.'.format(
                self.settings.loss_E_scaling.__class__.__name__))
            sys.exit()
        elif self.settings.loss_E_scaling <= 0.0:
            print('ERROR: loss_E_scaling has to be larger than 0 instead of {0}.'.format(
                self.settings.loss_E_scaling))
            sys.exit()

        if not isinstance(self.settings.fraction_redundant_max, float):
            print('ERROR: fraction_redundant_max has to be of type float instead of {0}.'.format(
                self.settings.fraction_redundant_max.__class__.__name__))
            sys.exit()
        elif self.settings.fraction_redundant_max <= 0.0 or self.settings.fraction_redundant_max > 1.0:
            print('ERROR: fraction_redundant_max has to be larger than 0 and smaller than or equal to 1',
                  'instead of {0}.'.format(self.settings.fraction_redundant_max))
            sys.exit()

        if not isinstance(self.settings.stationary_point_prob_factor, float):
            print('ERROR: stationary_point_prob_factor has to be of type float instead of {0}.'.format(
                self.settings.stationary_point_prob_factor.__class__.__name__))
            sys.exit()
        elif self.settings.stationary_point_prob_factor < 1.0:
            print('ERROR: stationary_point_prob_factor has to be larger than or equal to 1',
                  'instead of {0}.'.format(self.settings.stationary_point_prob_factor))
            sys.exit()

        if not isinstance(self.settings.fraction_good_max, float):
            print('ERROR: fraction_good_max has to be of type float instead of {0}.'.format(
                self.settings.fraction_good_max.__class__.__name__))
            sys.exit()
        elif self.settings.fraction_good_max <= 0.0 or self.settings.fraction_good_max >= 1.0:
            print('ERROR: fraction_good_max has to be larger than 0 and smaller than 1',
                  'instead of {0}.'.format(self.settings.fraction_good_max))
            sys.exit()

        if not isinstance(self.settings.late_data_fraction, float):
            print('ERROR: late_data_fraction has to be of type float instead of {0}.'.format(
                self.settings.late_data_fraction.__class__.__name__))
            sys.exit()
        elif self.settings.late_data_fraction < 0.0 or self.settings.late_data_fraction >= 1.0:
            print('ERROR: late_data_fraction has to be larger than or equal to 0 and smaller than 1',
                  'instead of {0}.'.format(self.settings.late_data_fraction))
            sys.exit()

        # check string settings
        if not isinstance(self.settings.episodic_memory_file, str):
            print('ERROR: episodic_memory_file has to be of type str instead of {0}.'.format(
                self.settings.episodic_memory_file.__class__.__name__))
            sys.exit()

        if not isinstance(self.settings.supplemental_potential_file, str):
            print('ERROR: supplemental_potential_file has to be of type str instead of {0}.'.format(
                self.settings.supplemental_potential_file.__class__.__name__))
            sys.exit()

        if not isinstance(self.settings.descriptor_parameter_file, str):
            print('ERROR: descriptor_parameter_file has to be of type str instead of {0}.'.format(
                self.settings.descriptor_parameter_file.__class__.__name__))
            sys.exit()

        if not isinstance(self.settings.generalization_setting_file, str):
            print('ERROR: generalization_setting_file has to be of type str instead of {0}.'.format(
                self.settings.generalization_setting_file.__class__.__name__))
            sys.exit()

        if not isinstance(self.settings.generalization_file, str):
            print('ERROR: generalization_file has to be of type str instead of {0}.'.format(
                self.settings.generalization_file.__class__.__name__))
            sys.exit()

        if not isinstance(self.settings.prediction_file, str):
            print('ERROR: prediction_file has to be of type str instead of {0}.'.format(
                self.settings.prediction_file.__class__.__name__))
            sys.exit()

        if not isinstance(self.settings.descriptor_disk_dir, str):
            print('ERROR: descriptor_disk_dir has to be of type str instead of {0}.'.format(
                self.settings.descriptor_disk_dir.__class__.__name__))
            sys.exit()

        if not (isinstance(self.settings.descriptor_output_file, str)
                or self.settings.descriptor_output_file is None):
            print('ERROR: descriptor_output_file has to be of type str or NoneType',
                  'instead of {0}.'.format(self.settings.descriptor_output_file.__class__.__name__))
            sys.exit()

        # check list settings
        if not isinstance(self.settings.element_types, list):
            print('ERROR: element_types has to be of type list instead of {0}.'.format(
                self.settings.element_types.__class__.__name__))
            sys.exit()
        else:
            if len(self.settings.element_types) < 1:
                print('ERROR: element_types has to contain at least one entry.')
                sys.exit()
            for i in self.settings.element_types:
                if not isinstance(i, str):
                    print('ERROR: List entries of element_types have to be of type str',
                          'instead of {0}.'.format(i.__class__.__name__))
                    sys.exit()

        if not isinstance(self.settings.n_neurons_hidden_layers, list):
            print('ERROR: n_neurons_hidden_layers has to be of type list instead of {0}.'.format(
                self.settings.n_neurons_hidden_layers.__class__.__name__))
            sys.exit()
        else:
            if len(self.settings.n_neurons_hidden_layers) < 1:
                print('ERROR: n_neurons_hidden_layers has to contain at least one entry.')
                sys.exit()
            for i in self.settings.n_neurons_hidden_layers:
                if not isinstance(i, int):
                    print('ERROR: List entries of n_neurons_hidden_layers have to be of type int',
                          'instead of {0}.'.format(i.__class__.__name__))
                    sys.exit()
                elif i <= 0:
                    print('ERROR: List entries of n_neurons_hidden_layers have to be larger than 0',
                          'instead of {0}.'.format(i))
                    sys.exit()

        if not isinstance(self.settings.frozen_layers, list):
            print('ERROR: frozen_layers has to be of type list instead of {0}.'.format(
                self.settings.frozen_layers.__class__.__name__))
            sys.exit()
        else:
            for i in self.settings.frozen_layers:
                if not isinstance(i, int):
                    print('ERROR: List entries of frozen_layers have to be of type int',
                          'instead of {0}.'.format(i.__class__.__name__))
                    sys.exit()

        # check tuple settings
        if not isinstance(self.settings.step_sizes, tuple):
            print('ERROR: step_sizes has to be of type tuple instead of {0}.'.format(
                self.settings.step_sizes.__class__.__name__))
            sys.exit()
        else:
            if len(self.settings.step_sizes) != 2:
                print('ERROR: step_sizes has to contain two entries.')
                sys.exit()
            for f in self.settings.step_sizes:
                if not isinstance(f, float):
                    print('ERROR: Tuple entries of step_sizes have to be of type float',
                          'instead of {0}.'.format(f.__class__.__name__))
                    sys.exit()
            if self.settings.step_sizes[0] <= 0.0:
                print('ERROR: First tuple entry of step_sizes has to be larger than 0',
                      'instead of {0}.'.format(self.settings.step_sizes[0]))
                sys.exit()
            if self.settings.step_sizes[1] < self.settings.step_sizes[0]:
                print('ERROR: Second tuple entry of step_sizes ({0})'.format(self.settings.step_sizes[1]),
                      'has to be larger than or equal to first tuple entry of step_sizes ({0}).'.format(
                          self.settings.step_sizes[0]))
                sys.exit()

        if not isinstance(self.settings.etas, tuple):
            print('ERROR: etas has to be of type tuple instead of {0}.'.format(
                self.settings.etas.__class__.__name__))
            sys.exit()
        else:
            if len(self.settings.etas) != 2:
                print('ERROR: etas has to contain two entries.')
                sys.exit()
            for f in self.settings.etas:
                if not isinstance(f, float):
                    print('ERROR: Tuple entries of etas have to be of type float instead of {0}.'.format(
                        f.__class__.__name__))
                    sys.exit()
            if self.settings.etas[0] <= 0.0 or self.settings.etas[0] > 1.0:
                print('ERROR: First tuple entry of etas has to be larger than or equal to 0',
                      'and smaller than 1 instead of {0}.'.format(self.settings.etas[0]))
                sys.exit()
            if self.settings.etas[1] < 1.0:
                print('ERROR: Second tuple entry of etas has to be larger than or equal to 0',
                      'and smaller than 1 instead of {0}.'.format(self.settings.etas[1]))
                sys.exit()

        if not isinstance(self.settings.betas, tuple):
            print('ERROR: betas has to be of type tuple instead of {0}.'.format(
                self.settings.betas.__class__.__name__))
            sys.exit()
        else:
            if len(self.settings.betas) != 4:
                print('ERROR: betas has to contain four entries.')
                sys.exit()
            for f in self.settings.betas:
                if not isinstance(f, float):
                    print('ERROR: Tuple entries of betas have to be of type float',
                          'instead of {0}.'.format(f.__class__.__name__))
                    sys.exit()
            if self.settings.betas[0] < 0.0 or self.settings.betas[0] >= 1.0:
                print('ERROR: First tuple entry of betas has to be larger than or equal to 0',
                      'and smaller than 1 instead of {0}.'.format(self.settings.betas[0]))
                sys.exit()
            if self.settings.betas[1] < 0.0 or self.settings.betas[1] >= 1.0:
                print('ERROR: Second tuple entry of betas has to be larger than or equal to 0',
                      'and smaller than 1 instead of {0}.'.format(self.settings.betas[1]))
                sys.exit()
            if self.settings.betas[2] <= 0.0:
                print('ERROR: Third tuple entry of betas has to be larger than 0 instead of {0}.'.format(
                    self.settings.betas[2]))
                sys.exit()
            if self.settings.betas[3] >= 1.0:
                print('ERROR: Fourth tuple entry of betas has to be smaller than 1 instead of {0}.'.format(
                    self.settings.betas[3]))
                sys.exit()

        if not isinstance(self.settings.loss_parameters, tuple):
            print('ERROR: loss_parameters has to be of type tuple instead of {0}.'.format(
                self.settings.loss_parameters.__class__.__name__))
            sys.exit()
        else:
            if len(self.settings.loss_parameters) != 2:
                print('ERROR: loss_parameters has to contain two entries.')
                sys.exit()
            for f in self.settings.loss_parameters:
                if not isinstance(f, float):
                    print('ERROR: List entries of loss_parameters have to be of type float',
                          'instead of {0}.'.format(f.__class__.__name__))
                    sys.exit()
                elif f < 0.0:
                    print('ERROR: List entries of loss_parameters have to be larger than or equal to 0',
                          'instead of {0}.'.format(f))
                    sys.exit()

        if not isinstance(self.settings.selection_range, tuple):
            print('ERROR: selection_range has to be of type tuple instead of {0}.'.format(
                self.settings.selection_range.__class__.__name__))
            sys.exit()
        else:
            if len(self.settings.selection_range) != 2:
                print('ERROR: selection_range has to contain two entries.')
                sys.exit()
            for f in self.settings.selection_range:
                if not isinstance(f, float):
                    print('ERROR: List entries of selection_range have to be of type float',
                          'instead of {0}.'.format(f.__class__.__name__))
                    sys.exit()
                elif f <= 0.0:
                    print('ERROR: List entries of selection_range have to be larger than 0',
                          'instead of {0}.'.format(f))
                    sys.exit()
            if self.settings.selection_range[0] >= 1.0:
                print('ERROR: First list entry of selection_range has to be smaller than 1',
                      'instead of {0}.'.format(f))
                sys.exit()
            if self.settings.selection_range[1] <= 1.0:
                print('ERROR: Second list entry of selection_range has to be larger than 1',
                      'instead of {0}.'.format(f))
                sys.exit()

        if not isinstance(self.settings.selection_thresholds, tuple):
            print('ERROR: selection_thresholds has to be of type tuple instead of {0}.'.format(
                self.settings.selection_thresholds.__class__.__name__))
            sys.exit()
        else:
            if len(self.settings.selection_thresholds) != 4:
                print('ERROR: selection_thresholds has to contain four entries.')
                sys.exit()
            else:
                for f in self.settings.selection_thresholds:
                    if not isinstance(f, float):
                        print('ERROR: List entries of selection_thresholds have to be of type float',
                              'instead of {0}.'.format(f.__class__.__name__))
                        sys.exit()
                    elif f <= 0.0:
                        print('ERROR: List entries of selection_thresholds have to be larger than 0',
                              'instead of {0}.'.format(f))
                        sys.exit()

        if not isinstance(self.settings.selection_strikes, tuple):
            print('ERROR: selection_strikes has to be of type tuple instead of {0}.'.format(
                self.settings.selection_strikes.__class__.__name__))
            sys.exit()
        else:
            if len(self.settings.selection_strikes) != 2:
                print('ERROR: selection_stikes has to contain two entries.')
                sys.exit()
            for i in self.settings.selection_strikes:
                if not isinstance(i, int):
                    print('ERROR: List entries of selection_strikes have to be of type int',
                          'instead of {0}.'.format(i.__class__.__name__))
                    sys.exit()
                elif i <= 0:
                    print('ERROR: List entries of selection_strikes have to be larger than 0',
                          'instead of {0}.'.format(i))
                    sys.exit()

        if not isinstance(self.settings.selection_small_strikes, tuple):
            print('ERROR: selection_small_strikes has to be of type tuple instead of {0}.'.format(
                self.settings.selection_small_strikes.__class__.__name__))
            sys.exit()
        else:
            if len(self.settings.selection_small_strikes) != 2:
                print('ERROR: selection_small_stikes has to contain two entries.')
                sys.exit()
            for i in self.settings.selection_small_strikes:
                if not isinstance(i, int):
                    print('ERROR: List entries of selection_small_strikes have to be of type int',
                          'instead of {0}.'.format(i.__class__.__name__))
                    sys.exit()
                elif i <= 0:
                    print('ERROR: List entries of selection_small_strikes have to be larger than 0',
                          'instead of {0}.'.format(i))
                    sys.exit()

####################################################################################################

    def check_device(self):
        '''
        Validate the torch device requested in the settings and store it as device.
        If cuda is requested but not available, a warning is printed and the settings
        are switched to cpu.

        Implementation: cpu, cuda

        Modify: device
        '''
        # torch devices which are supported
        supported_devices = ['cpu', 'cuda']

        # abort if the requested device is not supported
        requested = self.settings.device
        if requested not in supported_devices:
            print('ERROR: Using device {0} is not yet implemented in check device.'
                  .format(requested),
                  '\nPlease use one of the following devices:')
            print('\n'.join(supported_devices))
            sys.exit()

        # fall back to cpu if cuda is requested but cannot be used
        if requested == 'cuda' and not torch.cuda.is_available():
            print('WARNING: Device cuda is not available.',
                  'Calculation is performed on device cpu instead.\n')
            self.settings.device = 'cpu'

        # store the device which is finally used
        self.device = self.settings.device

####################################################################################################

    def get_atomic_numbers(self):
        '''
        Provide the mapping from chemical element symbol to atomic number.

        Implementation: H to Xe

        Return: atomic_numbers
        '''
        # implemented chemical elements ordered by increasing atomic number (one row per period)
        element_symbols = (
            'H', 'He',
            'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne',
            'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar',
            'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co',
            'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr',
            'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh',
            'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe',
        )

        # the atomic number is the one-based position of the symbol in the ordered tuple
        return {symbol: number for number, symbol in enumerate(element_symbols, start=1)}

####################################################################################################

    def get_element_types(self):
        '''
        Determine the unique element types of the settings ordered by increasing atomic number.
        The program is terminated if an element type is not implemented.

        Return: element_types
        '''
        atomic_numbers = self.get_atomic_numbers()

        # remove duplicates from the requested element types
        unique_elements = np.unique(np.array(self.settings.element_types))

        # abort if any requested element type is unknown
        not_implemented = np.array([symbol not in atomic_numbers for symbol in unique_elements])
        if np.any(not_implemented):
            print('ERROR: Elements {0} are not yet implemented.'.format(
                unique_elements[not_implemented]),
                  '\nOnly elements from H to Xe are implemented.')
            sys.exit()

        # order the unique element types by their atomic number
        element_numbers = np.array([atomic_numbers[symbol] for symbol in unique_elements])
        return unique_elements[np.argsort(element_numbers)]

####################################################################################################

    def get_unit_conversion(self):
        '''
        Set the energy and force conversion factors for the units requested in the settings.
        The factors are 1.0 for eV and Angstrom. The energy unit is handled before the
        length unit and the program is terminated at the first unit which is not implemented.

        Implementation: energy_unit: eV, Hartree, kJ_mol
                        length_unit: Angstrom, Bohr, pm

        Modify: energy_conversion, force_conversion
        '''
        # implemented energy and length units
        energy_units = ['eV', 'Hartree', 'kJ_mol']
        length_units = ['Angstrom', 'Bohr', 'pm']

        # abort if the energy unit is not implemented
        energy_unit = self.settings.energy_unit
        if energy_unit not in energy_units:
            print('ERROR: Energy unit {0} is not yet implemented for unit conversion.'
                  .format(energy_unit),
                  '\nPlease use one of the following units:')
            print('\n'.join(energy_units))
            sys.exit()

        # determine energy conversion factor
        if energy_unit == 'eV':
            self.energy_conversion = 1.0
        elif energy_unit == 'Hartree':
            self.energy_conversion = 1.0 / self.Hartree2eV
        else:
            self.energy_conversion = self.eV2kJ_mol

        # abort if the length unit is not implemented
        length_unit = self.settings.length_unit
        if length_unit not in length_units:
            print('ERROR: Length unit {0} is not yet implemented for unit conversion.'
                  .format(length_unit),
                  '\nPlease use one of the following units:')
            print('\n'.join(length_units))
            sys.exit()

        # determine force conversion factor (energy factor adjusted for the length unit)
        if length_unit == 'Angstrom':
            self.force_conversion = self.energy_conversion
        elif length_unit == 'Bohr':
            self.force_conversion = self.energy_conversion * self.Bohr2Angstrom
        else:
            self.force_conversion = self.energy_conversion / 100.0

####################################################################################################

    def print_time_memory(self, task_name, time_start, measure_memory=True):
        '''
        Print the time elapsed since time_start for the given task. If memory evaluation is
        enabled in the settings, the maximum memory usage is printed as well. It is measured
        anew if measure_memory is True, otherwise the stored value is reused.

        Output: time - time_start, max_memory_usage

        Modify: max_memory_usage
        '''
        # report the elapsed time
        elapsed = round(time() - time_start, 3)
        print('Time {0:25} {1:.3f} s'.format(task_name, elapsed))

        # report the memory usage only if memory evaluation is requested
        if self.settings.memory_evaluation:
            # update the stored value by a new measurement
            if measure_memory:
                self.max_memory_usage = self.memory_usage(
                    interval=1e-6 - 1e-9, timeout=1e-6, max_usage=True)
            memory = round(self.max_memory_usage, 3)
            print('Memory usage:                  {0:.3f} MB'.format(memory))

        # separate the report from subsequent output
        print('')

####################################################################################################

    def read_supplemental_potential(self):
        '''
        Read the supplemental potential file with the reader matching the supplemental
        potential type of the settings. The program is terminated if the file does not
        exist or if the type is not implemented (the file is checked first).

        Requirement: Last parameter has to be the atomic energy of the central atom

        Implementation: element_energy, MieRc

        Modify: supplemental_potential_parameters, element_energy
        '''
        # supplemental potential types with an implemented reader
        implemented_types = ['element_energy', 'MieRc']

        # abort if the file is missing
        potential_file = self.settings.supplemental_potential_file
        if not path.isfile(potential_file):
            print('ERROR: Supplemental potential file {0} does not exist.'.format(
                potential_file))
            sys.exit()

        potential_type = self.settings.supplemental_potential_type

        # file format element_energy
        if potential_type == 'element_energy':
            self.read_element_energy()
            return

        # file format MieRc
        if potential_type == 'MieRc':
            self.read_MieRc()
            return

        # any other file format is not implemented
        print('ERROR: Reading of supplemental potential type {0} is not yet implemented.'
              .format(potential_type),
              '\nPlease use one of the following types:')
        print('\n'.join(implemented_types))
        sys.exit()

####################################################################################################

    def read_element_energy(self):
        '''
        Read the atomic energy of each element from the supplemental potential file. The
        energies are multiplied by Hartree2eV. Empty lines and comment lines (first
        non-blank character is #) are ignored. The program is terminated if pair
        contributions are enabled, if a line does not match the file format, or if the
        energy of an element in element_types is missing.

        File format: [element] [E_atom in E_h]

        Modify: supplemental_potential_parameters, element_energy
        '''
        # check if pair contributions are disabled
        if self.settings.pair_contributions:
            print('ERROR: Pair contributions are not available for supplemental potential type element_energy')
            sys.exit()

        # read data (the comment test uses the stripped line so that indented comments are skipped)
        with open(self.settings.supplemental_potential_file, encoding='utf-8') as f:
            data = [line.split() for line in f
                    if line.strip() and not line.lstrip().startswith('#')]

        # order data and report lines with a missing or non-numeric energy
        element_energy = {}
        for atom in data:
            try:
                element_energy[atom[0]] = float(atom[1]) * self.Hartree2eV
            except (IndexError, ValueError):
                print('ERROR: Line "{0}" in {1} does not match the format [element] [E_atom].'.format(
                    ' '.join(atom), self.settings.supplemental_potential_file))
                sys.exit()
        self.supplemental_potential_parameters = {}
        self.element_energy = element_energy

        # check if all elements have element energies
        for ele in self.element_types:
            if ele not in self.element_energy:
                print('ERROR: Element energy of element {0} is not specified in {1}.'.format(
                    ele, self.settings.supplemental_potential_file))
                sys.exit()

####################################################################################################

    def read_MieRc(self):
        '''
        Read the MieRc parameters of each element pair from the supplemental potential file.
        The last parameter (E_atom) is multiplied by Hartree2eV. The element energies are
        taken from the entries [element]-[element]. Empty lines and comment lines (first
        non-blank character is #) are ignored. The pair parameters are only kept if pair
        contributions are enabled. The program is terminated if a line does not match the
        file format or if the entry [element]-[element] of an element in element_types is
        missing.

        File format: [element_i]-[element_j] [C_n in Ang^-n] [C_m in Ang^-m] [n] [m] [R_c in Ang] [E_atom in E_h]

        Modify: supplemental_potential_parameters, element_energy
        '''
        # read data (the comment test uses the stripped line so that indented comments are skipped)
        with open(self.settings.supplemental_potential_file, encoding='utf-8') as f:
            data = [line.split() for line in f
                    if line.strip() and not line.lstrip().startswith('#')]

        # order data and report lines with missing or non-numeric parameters
        parameters = {}
        for atom in data:
            try:
                if len(atom) < 7:
                    raise ValueError('too few columns')
                values = [float(value) for value in atom[1:7]]
            except ValueError:
                print('ERROR: Line "{0}" in {1} does not match the format'.format(
                    ' '.join(atom), self.settings.supplemental_potential_file),
                      '[element_i]-[element_j] [C_n] [C_m] [n] [m] [R_c] [E_atom].')
                sys.exit()
            # the atomic energy is the last parameter
            values[-1] = values[-1] * self.Hartree2eV
            parameters[atom[0]] = values

        # check if all elements have element energies
        element_energy = {}
        for ele in self.element_types:
            pair = ele + '-' + ele
            if pair not in parameters:
                print('ERROR: Element energy of element {0} is not specified in {1}'.format(
                    ele, self.settings.supplemental_potential_file),
                      '(entry {0} is missing).'.format(pair))
                sys.exit()
            element_energy[ele] = parameters[pair][-1]

        # the pair parameters are only required for enabled pair contributions
        self.supplemental_potential_parameters = parameters if self.settings.pair_contributions else {}
        self.element_energy = element_energy

####################################################################################################

    def calculate_atomic_and_pair_contributions(self, operator, elements, positions, lattices,
                                                energy, forces, n_structures, n_atoms_sys):
        '''
        Implementation: Addition ('+')
                        Subtraction ('-')

        Return: energy, forces
        '''
        # implemented operators
        operator_list = ['+', '-']

        # prepare calculation according to operator +
        if operator == '+':
            factor = 1

        # prepare calculation according to operator -
        elif operator == '-':
            factor = -1

        # not implemented operator
        else:
            print('ERROR: Operator {0} is not yet implemented'.format(operator),
                  'for handling of atomic and pair contributions.',
                  '\nPlease use one of the following operators:')
            for op in operator_list:
                print('{0}'.format(op))
            sys.exit()

        # enabled pair contributions
        if self.settings.pair_contributions:
            if lattices:
                print('ERROR: Supplemental potential pair contributions do not support',
                      'periodic boundary conditions yet.')
                sys.exit()
            if self.settings.QMMM:
                print('ERROR: Supplemental potential pair contributions do not support',
                      'QM/MM data yet.')
                sys.exit()
            # calculate atomic and pair contributions and apply the operation on the energy and forces
            for n in range(n_structures):
                # determine unique elements for each structure
                element_types = np.unique(elements[n])
                n_elements = len(element_types)
                # loop over all elements i
                for i in range(n_elements):
                    # determine atoms of element i
                    i_atoms = elements[n] == element_types[i]
                    # loop over all elements j
                    for j in range(n_elements):
                        # determine atoms of element j
                        j_atoms = elements[n] == element_types[j]
                        # calculate pair contributions
                        E_contributions, F_contributions = self.calculate_pair_contributions(
                            positions[n][i_atoms], positions[n][j_atoms], i == j,
                            self.supplemental_potential_parameters[element_types[i] + '-' + element_types[j]])
                        # apply the operation on the energy and force data for the pair contributions
                        if len(E_contributions) > 0:
                            energy[n] += factor * np.sum(E_contributions)
                            forces[n][i_atoms] += factor * F_contributions
                    # apply the operation on the energy data for the atomic contributions
                    energy[n] += (
                        factor * len(i_atoms[i_atoms])
                        * self.supplemental_potential_parameters[element_types[i] + '-' + element_types[i]][-1])
        # only enabled atomic contributions
        else:
            # calculate atomic contributions and apply the operation on the energy data
            for n in range(n_structures):
                energy[n] += factor * np.sum(np.array([
                    self.element_energy[ele] for ele in elements[n][:n_atoms_sys[n]]]))

        return energy, forces

####################################################################################################

    def calculate_pair_contributions(self, pos_i, pos_j, ii, sup_pot_parameters):
        '''
        Return: E_contributions, F_contributions
        '''
        # Check that there are pairs for the given elements i and j, otherwise return empty arrays
        n_pos_i = len(pos_i)
        if n_pos_i == 0 or len(pos_j) == 0 or (ii and n_pos_i == 1):
            return np.array([]), np.array([])

        # initialize arrays
        E_contributions = np.zeros(n_pos_i)
        F_contributions = np.zeros((n_pos_i, 3))

        # calculate pair contributions
        for i in range(n_pos_i):
            # calculate distance vectors
            d_ = pos_j - pos_i[i]
            # delete distance vector between atom i and itself
            if ii:
                d_ = np.delete(d_, i, 0)
            # calculate distances
            d = np.sqrt((d_**2).sum(axis=1))
            # calculate pair potentials
            E_i_contributions, F_i_lengths = self.calculate_supplemental_potential(
                d, sup_pot_parameters)
            # sum energy contributions for atom i
            E_contributions[i] = np.sum(E_i_contributions)
            # determine direction of force contributions and sum them for atom i
            F_contributions[i] = (d_.T / d * F_i_lengths).T.sum(axis=0)

        return E_contributions, F_contributions

####################################################################################################

    def calculate_supplemental_potential(self, d, sup_pot_parameters):
        '''
        Requirement: Calculation of contributions for single atoms. Double counting of energy
                     contributions has to be avoided (factor 1/2), while force contributions have
                     to be taken into account entirely on each atomic site of an interaction
                     (factor 1).
                     F = -dE/dR and F_i = -F_j have to be considered.

        Implementation: MieRc

        Return: E_contributions, F_lengths
        '''
        # implemented supplemental potential types
        supplemental_potential_type_list = ['MieRc']

        # calculate supplemental potential type MieRc
        if self.settings.supplemental_potential_type == 'MieRc':
            E_i_contributions, F_i_lengths = self.MieRc(
                d, sup_pot_parameters[0], sup_pot_parameters[1], sup_pot_parameters[2],
                sup_pot_parameters[3], sup_pot_parameters[4])

        # not implemented supplemental potential type
        else:
            print('ERROR: Calculating supplemental potential type {0} is not yet implemented.'
                  .format(self.settings.supplemental_potential_type),
                  '\nPlease use one of the following types:')
            for sup_pot_type in supplemental_potential_type_list:
                print('{0}'.format(sup_pot_type))
            sys.exit()

        return E_i_contributions, F_i_lengths

####################################################################################################

    def MieRc(self, R, C_n, C_m, n, m, R_c):
        '''
        Requirement: R > 0

        Return: E, F
        '''
        # initialize arrays
        n_R = len(R)
        E = np.zeros(n_R)
        F = np.zeros(n_R)

        # determine atomic distances smaller than the cutoff radius
        interaction = R < R_c
        R_inside = R[interaction]

        # calculate repeating terms of MieRc
        A = C_n / R_inside**n - C_m / R_inside**m
        b = 1.0 - (R_inside / R_c)**2
        B = np.exp(1.0 - 1.0 / b)

        # calculate energy and force values of the pair potential under the conditions
        # explained in the function calculate_supplemental_potential
        E[interaction] = 0.5 * A * B * self.Hartree2eV
        F[interaction] = ((-n * C_n / R_inside**(n + 1) + m * C_m / R_inside**(m + 1)
                          - A * 2.0 * R_inside / (R_c * b)**2) * B) * self.Hartree2eV

        return E, F

####################################################################################################

    def read_episodic_memory(self):
        '''
        Implementation: inputdata

        Output unit: Energy: eV
                     Length: Angstrom

        Return: elements, positions, lattices, atomic_classes, atomic_charges, energy, forces,
                n_structures, n_atoms, name
        '''
        # implemented file formats
        episodic_memory_format_list = ['inputdata']

        # check existance of file
        if not path.isfile(self.settings.episodic_memory_file):
            print('ERROR: Episodic memory file {0} does not exist.'.format(
                self.settings.episodic_memory_file))
            sys.exit()

        # read file format inputdata
        if self.settings.episodic_memory_format == 'inputdata':
            elements, positions, lattices, atomic_classes, atomic_charges, energy, forces, \
                n_structures, n_atoms, name = self.read_inputdata()

        # not implemented file format
        else:
            print('ERROR: Episodic memory format {0} is not yet implemented.'
                  .format(self.settings.episodic_memory_format),
                  '\nPlease use one of the following formats:')
            for epi_mem_format in episodic_memory_format_list:
                print('{0}'.format(epi_mem_format))
            sys.exit()

        return elements, positions, lattices, atomic_classes, atomic_charges, energy, forces, \
            n_structures, n_atoms, name

####################################################################################################

    def read_inputdata(self):
        '''
        Read all structures from the episodic memory file. A structure is a sequence of
        'atom', 'lattice', 'comment' and 'energy' lines which is closed by an 'end' line.
        Lines starting with any other keyword are ignored.

        Line format: atom [x] [y] [z] [element] [atomic class] [atomic charge] [F_x] [F_y] [F_z]
                     lattice [x] [y] [z]
                     comment name|structure_id [name]
                     energy [E]

        Unit conversion: Energies are multiplied by Hartree2eV, positions and lattices by
                         Bohr2Angstrom, and forces by Hartree2eV / Bohr2Angstrom.

        Disregarded structures: Maximal absolute atomic force larger than atomic_force_max and,
                                for QM/MM data, maximal absolute atomic charge of class 2 atoms
                                larger than MM_atomic_charge_max.

        Return: elements, positions, lattices, atomic_classes, atomic_charges, energy, forces,
                n_structures, n_atoms, name
        '''
        # containers of the regarded reference data
        elements, positions, lattices = [], [], []
        atomic_classes, atomic_charges = [], []
        energy = typed.List()
        forces = typed.List()
        name = []

        # buffers of the structure which is currently read
        element, position, lattice = [], [], []
        force, atomic_class, atomic_charge = [], [], []
        E = None
        N = None

        # statistics of the disregarded structures
        n_disregarded = 0
        n_disregarded_charges = 0
        n_disregarded_forces = 0
        name_disregarded = []

        # read and order data
        with open(self.settings.episodic_memory_file, encoding='utf-8') as f:
            for raw_line in f:
                # splitting at whitespace yields no token for empty lines
                tokens = raw_line.split()
                if not tokens:
                    continue
                keyword = tokens[0]

                # append preliminary lists of current structure
                if keyword == 'atom':
                    element.append(tokens[4])
                    position.append(tokens[1:4])
                    atomic_class.append(tokens[5])
                    atomic_charge.append(tokens[6])
                    force.append(tokens[7:10])

                # append lattice list
                elif keyword == 'lattice':
                    lattice.append(tokens[1:4])

                # get name, which may be given only once per structure
                elif keyword == 'comment':
                    if tokens[1] in ('name', 'structure_id'):
                        if N is not None:
                            print('ERROR: Name/structure ID is provided more than once for a single structure',
                                  '\n(only one of these two properties can be provided for a single structure).')
                            sys.exit()
                        N = tokens[2]

                # get energy, which may be given only once per structure
                elif keyword == 'energy':
                    if E is not None:
                        print('ERROR: Energy is provided more than once for a single structure.')
                        sys.exit()
                    E = float(tokens[1]) * self.Hartree2eV

                # adjust types and units of current structure lists and append them to reference data lists
                elif keyword == 'end':
                    # the small shift guards the truncation to int against values slightly below an integer
                    atomic_class = (np.array(atomic_class).astype(float) + 1e-6).astype(int)
                    atomic_charge = np.array(atomic_charge).astype(float)
                    force = np.array(force).astype(float) * self.Hartree2eV / self.Bohr2Angstrom

                    # check if the structure is regarded
                    disregarded = False
                    if self.settings.QMMM:
                        if np.min(atomic_class) < 1 or np.max(atomic_class) > 2:
                            print('ERROR: Atomic class has to be 1 or 2 for QMMM.')
                            sys.exit()
                        atomic_charge_MM = atomic_charge[atomic_class == 2]
                        if (len(atomic_charge_MM) > 0
                                and np.max(np.absolute(atomic_charge_MM)) > self.settings.MM_atomic_charge_max):
                            n_disregarded_charges += 1
                            disregarded = True
                    force_norms = np.sqrt(np.sum(force**2, axis=1))
                    if np.max(force_norms) > self.settings.atomic_force_max:
                        n_disregarded_forces += 1
                        disregarded = True

                    if disregarded:
                        n_disregarded += 1
                        if N is not None:
                            name_disregarded.append(N)
                    else:
                        element = np.array(element).astype(str)
                        position = np.array(position).astype(float) * self.Bohr2Angstrom
                        lattice = np.array(lattice).astype(float) * self.Bohr2Angstrom
                        # prepare periodic systems
                        if len(lattice) > 0:
                            element, position, lattice, atomic_class, atomic_charge, E, force, _, _ = \
                                self.prepare_periodic_systems(element, position, lattice, atomic_class,
                                                              atomic_charge, E, force, len(element))
                        elements.append(element)
                        positions.append(position)
                        lattices.append(lattice)
                        atomic_classes.append(atomic_class)
                        atomic_charges.append(atomic_charge)
                        forces.append(force)
                        energy.append(E)
                        if N is not None:
                            name.append(N)

                    # empty structure lists and reset E and N to None
                    element, position, lattice = [], [], []
                    atomic_class, atomic_charge, force = [], [], []
                    E = None
                    N = None

        # get number of structures and atoms per structure
        n_structures = len(energy)
        n_atoms = np.array([len(structure) for structure in elements])

        # check number of structures
        if n_structures <= 0 and n_disregarded <= 0:
            print('ERROR: There is no structure in the episodic memory file.')
            sys.exit()
        if n_disregarded > 0:
            print('{0} structures are disregarded due to too high'.format(n_disregarded_forces),
                  'absolute atomic forces (atomic_force_max = {0}).'.format(self.settings.atomic_force_max))
            if self.settings.QMMM:
                print('{0} structures are disregarded due to too high'.format(n_disregarded_charges),
                      'absolute MM atomic charges (MM_atomic_charge_max = {0}).'.format(
                          self.settings.MM_atomic_charge_max),
                      '\n(These structures may already be disregarded by too high absolute atomic forces.)')
            print('')
            if n_structures <= 0:
                print('ERROR: All structures are disregarded.')
                sys.exit()

        # check if unique elements match element_types (for QM/MM only atoms of class 1 count)
        if self.settings.QMMM:
            regarded_elements = [ele for structure, classes in zip(elements, atomic_classes)
                                 for i, ele in enumerate(structure) if classes[i] == 1]
        else:
            regarded_elements = [ele for structure in elements for ele in structure]
        elements_unique = np.unique(regarded_elements)
        if self.settings.prediction_only:
            # for predictions a subset of element_types is sufficient
            if not np.all(np.isin(elements_unique, self.element_types)):
                print('ERROR: Some elements in episodic memory are not included in elements',
                      'in element_types.\n{0} not all in {1}'.format(
                          elements_unique, np.unique(self.element_types)))
                sys.exit()
        else:
            element_types_unique = np.unique(self.element_types)
            if len(elements_unique) != len(self.element_types):
                print('ERROR: Number of elements in episodic memory does not match number of elements',
                      'in element_types.\n{0} != {1}'.format(
                          elements_unique, element_types_unique))
                sys.exit()
            if np.any(elements_unique != element_types_unique):
                print('ERROR: Elements in episodic memory do not match elements in element_types.',
                      '\n{0} != {1}'.format(elements_unique, element_types_unique))
                sys.exit()

        # check if names are provided for all or zero structures
        n_names = len(name)
        if n_names != 0 and n_names != n_structures:
            print('ERROR: Names/structure IDs have to be provided for all or zero structures',
                  '\nand only one of these two properties can be provided for a single structure.')
            sys.exit()

        # check if unique names are provided for all structures if generalization_format is 'lMLP' or
        # descriptor_on_disk is not 'None'
        if self.settings.generalization_format == 'lMLP' or self.settings.descriptor_on_disk != 'None':
            if n_names != n_structures:
                print('ERROR: Names/structure IDs have to be provided for all structures',
                      '\nif generalization_format is lMLP and/or descriptor_on_disk is not None.')
                sys.exit()
            if n_names != len(np.unique(np.array(name))):
                print('ERROR: Names/structure IDs have to be unique for all structures',
                      '\nif generalization_format is lMLP and/or descriptor_on_disk is not None.')
                sys.exit()

        # print disregarded structure names (only if every disregarded structure has a name)
        if n_disregarded > 0:
            print('Disregarded training data:     {0}'.format(n_disregarded))
            if len(name_disregarded) == n_disregarded:
                for disregarded_name in name_disregarded:
                    print(disregarded_name)
            print('')

        return elements, positions, lattices, atomic_classes, atomic_charges, energy, forces, \
            n_structures, n_atoms, name

####################################################################################################

    def read_descriptor_parameters(self):
        '''
        Output requirement: The order of the descriptor parameters needs to be always the same and
                            thus independent of the order in the descriptor input file.

        Implementation: ACSF, eeACSF

        Modify: descriptor_parameters, R_c
        '''
        # implemented descriptor types
        descriptor_type_list = ['ACSF', 'eeACSF']

        # check existance of file
        if not path.isfile(self.settings.descriptor_parameter_file):
            print('ERROR: Descriptor parameter file {0} does not exist.'.format(
                self.settings.descriptor_parameter_file))
            sys.exit()

        # read descriptor file ACSF and sort descriptors in a well defined order
        if self.settings.descriptor_type == 'ACSF':
            self.read_ACSF()
            self.sort_ACSF()

        # read descriptor file eeACSF and sort descriptors in a well defined order
        elif self.settings.descriptor_type == 'eeACSF':
            self.read_eeACSF()
            self.sort_eeACSF()

        # not implemented descriptor type
        else:
            print('ERROR: Reading of descriptor type {0} is not yet implemented.'
                  .format(self.settings.descriptor_type),
                  '\nPlease use one of the following types:')
            for des_type in descriptor_type_list:
                print('{0}'.format(des_type))
            sys.exit()

####################################################################################################

    def read_ACSF(self):
        '''
        Implementation: 1: Radial ACSF
                        2: Angular ACSF

        File format: First line: [R_c in a_0]
                     Following lines: 1: [ACSF type] [eta in a_0^-2]
                                      2: [ACSF type] [eta in a_0^-2] [lambda] [zeta] [xi]

        Modify: descriptor_parameters, R_c
        '''
        # implemented ACSF types
        # for an extension the new type has to be implemented in calculate_symmetry_function
        ACSF_type_list = ['1', '2']

        # read ACSF descriptor file
        with open(self.settings.descriptor_parameter_file, encoding='utf-8') as f:
            # get cutoff radius
            for line in f:
                line = line.strip()
                if line and not line.startswith('#'):
                    self.R_c = float(line) * self.Bohr2Angstrom
                    break
            # get descriptor parameters
            self.descriptor_parameters = []
            for line in f:
                line = line.strip()
                if line and not line.startswith('#'):
                    line = line.split()
                    # read radial ACSF
                    if line[0] == '1':
                        self.descriptor_parameters.append([
                            int(line[0]), float(line[1]) / self.Bohr2Angstrom**2])
                    # read angular ACSF
                    elif line[0] == '2':
                        self.descriptor_parameters.append([
                            int(line[0]), float(line[1]) / self.Bohr2Angstrom**2, int(line[2]),
                            int(line[3]), float(line[4])])
                    # not implemented ACSF type
                    else:
                        print('ERROR: Reading of ACSF type {0} is not yet implemented.'
                              .format(line[0]),
                              '\nPlease use one of the following types:')
                        for SF_type in ACSF_type_list:
                            print('{0}'.format(SF_type))
                        sys.exit()

####################################################################################################

    def sort_ACSF(self):
        '''
        Modify: descriptor_parameters
        '''
        # sort ACSF descriptors in a well defined order
        descriptor_parameters_sorted = sorted(
            self.n_element_types * [para for para in self.descriptor_parameters if para[0] == 1],
            key=lambda l: l[1])
        descriptor_parameters_sorted.extend(sorted(
            sorted(sorted(sorted(np.sum(np.arange(1, self.n_element_types + 1))
                                 * [para for para in self.descriptor_parameters if para[0] == 2],
                                 key=lambda l: l[4]), key=lambda l: l[3]), key=lambda l: l[2]),
            key=lambda l: l[1]))

        self.descriptor_parameters = descriptor_parameters_sorted

####################################################################################################

    def read_eeACSF(self):
        '''
        Implementation: 1: Radial eeACSF (subtypes: 0, 1, 2, 3, 4, 5, 6; QM/MM subtypes: 0, 1)
                        2: Angular eeACSF (subtypes: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12;
                                           QM/MM subtypes: 0, 1, 2, 3, 4)

        File format: First line: [R_c in a_0]
                     Following lines: 1: [eeACSF type] [eeACSF subtype] [eta in a_0^-2] <I_type_j>
                                      2: [eeACSF type] [eeACSF subtype] [eta in a_0^-2] [lambda]
                                         [zeta] [xi] <I_type_jk>

        Modify: descriptor_parameters, R_c
        '''
        # implemented eeACSF types
        # for an extension the new type has to be implemented in calculate_symmetry_function
        eeACSF_type_list = ['1', '2']
        # for an extension the new type has to be implemented in elem_rad_func and in get_H_parameters
        radial_eeACSF_subtype_list = ['0', '1', '2', '3', '4', '5', '6']
        # for an extension the new type needs to be implemented in elem_ang_func and in get_H_parameters
        angular_eeACSF_subtype_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12']
        # for an extension the new type has to be implemented in elem_rad_func and int_rad_eeACSF_QMMM
        radial_eeACSF_QMMM_subtype_list = ['0', '1']
        # for an extension the new type needs to be implemented in elem_ang_func and int_ang_eeACSF_QMMM
        angular_eeACSF_QMMM_subtype_list = ['0', '1', '2', '3', '4']

        # read eeACSF descriptor file
        with open(self.settings.descriptor_parameter_file, encoding='utf-8') as f:
            # get cutoff radius
            for line in f:
                line = line.strip()
                if line and not line.startswith('#'):
                    self.R_c = float(line) * self.Bohr2Angstrom
                    break
            # get descriptor parameters
            self.descriptor_parameters = []
            for line in f:
                line = line.strip()
                if line and not line.startswith('#'):
                    line = line.split()
                    # read radial eeACSF
                    if line[0] == '1':
                        if line[1] not in radial_eeACSF_subtype_list:
                            print('ERROR: Reading of radial eeACSF subtype {0} is not yet implemented.'
                                  .format(line[1]),
                                  '\nPlease use one of the following types:')
                            for SF_type in radial_eeACSF_subtype_list:
                                print('{0}'.format(SF_type))
                            sys.exit()
                        else:
                            if self.settings.QMMM:
                                if line[3] not in radial_eeACSF_QMMM_subtype_list:
                                    print('ERROR: Reading of radial eeACSF QM/MM subtype {0}'.format(line[3]),
                                          'is not yet implemented.\nPlease use one of the following types:')
                                    for SF_type in radial_eeACSF_QMMM_subtype_list:
                                        print('{0}'.format(SF_type))
                                    sys.exit()
                                else:
                                    self.descriptor_parameters.append([
                                        int(line[0]), int(line[1]), float(line[2]) / self.Bohr2Angstrom**2,
                                        int(line[3])])
                            else:
                                self.descriptor_parameters.append([
                                    int(line[0]), int(line[1]), float(line[2]) / self.Bohr2Angstrom**2])
                    # read angular eeACSF
                    elif line[0] == '2':
                        if line[1] not in angular_eeACSF_subtype_list:
                            print('ERROR: Reading of angular eeACSF subtype {0} is not yet implemented.'
                                  .format(line[1]),
                                  '\nPlease use one of the following types:')
                            for SF_type in angular_eeACSF_subtype_list:
                                print('{0}'.format(SF_type))
                            sys.exit()
                        else:
                            if self.settings.QMMM:
                                if line[6] not in angular_eeACSF_QMMM_subtype_list:
                                    print('ERROR: Reading of angular eeACSF QM/MM subtype {0}'.format(line[6]),
                                          'is not yet implemented.\nPlease use one of the following types:')
                                    for SF_type in angular_eeACSF_QMMM_subtype_list:
                                        print('{0}'.format(SF_type))
                                    sys.exit()
                                else:
                                    self.descriptor_parameters.append([
                                        int(line[0]), int(line[1]), float(line[2]) / self.Bohr2Angstrom**2,
                                        int(line[3]), int(line[4]), float(line[5]), int(line[6])])
                            else:
                                self.descriptor_parameters.append([
                                    int(line[0]), int(line[1]), float(line[2]) / self.Bohr2Angstrom**2,
                                    int(line[3]), int(line[4]), float(line[5])])
                    # not implemented eeACSF type
                    else:
                        print('ERROR: Reading of eeACSF type {0} is not yet implemented.'
                              .format(line[0]),
                              '\nPlease use one of the following types:')
                        for SF_type in eeACSF_type_list:
                            print('{0}'.format(SF_type))
                        sys.exit()

####################################################################################################

    def sort_eeACSF(self):
        '''
        Modify: descriptor_parameters
        '''
        # sort eeACSF descriptors in a well defined order
        if self.settings.QMMM:
            descriptor_parameters_sorted = sorted(sorted(sorted(
                [para for para in self.descriptor_parameters if para[0] == 1],
                key=lambda l: l[2]), key=lambda l: l[1]), key=lambda l: l[3])
            descriptor_parameters_sorted.extend(sorted(sorted(sorted(
                sorted(sorted(sorted([para for para in self.descriptor_parameters if para[0] == 2],
                                     key=lambda l: l[5]), key=lambda l: l[4]), key=lambda l: l[3]),
                key=lambda l: l[2]), key=lambda l: l[1]), key=lambda l: l[6]))
        else:
            descriptor_parameters_sorted = sorted(sorted(
                [para for para in self.descriptor_parameters if para[0] == 1],
                key=lambda l: l[2]), key=lambda l: l[1])
            descriptor_parameters_sorted.extend(sorted(sorted(
                sorted(sorted(sorted([para for para in self.descriptor_parameters if para[0] == 2],
                                     key=lambda l: l[5]), key=lambda l: l[4]), key=lambda l: l[3]),
                key=lambda l: l[2]), key=lambda l: l[1]))

        self.descriptor_parameters = descriptor_parameters_sorted

####################################################################################################

    def calculate_symmetry_function(self, elements_int_sys, positions, lattices, atomic_classes,
                                    atomic_charges, n_structures, n_atoms, n_atoms_sys, name):
        '''
        Implementation: Radial types: bump, Gaussian-bump, Gaussian-cosine
                        Angular types: bump, cosine, cosine_integer
                        Scaling types: cube root-scaled-shifted, linear, square root

        Return: descriptors_torch, [descriptor_i_derivatives_torch,
                descriptor_neighbor_derivatives_torch], neighbor_indices,
                descriptor_neighbor_derivatives_torch_env, neighbor_indices_env, active_atoms,
                n_atoms_active, MM_gradients
        '''
        # implemented descriptor radial types
        descriptor_radial_type_list = ['bump', 'gaussian_bump', 'gaussian_cos']
        # implemented descriptor angular types
        descriptor_angular_type_list = ['bump', 'cos', 'cos_int']
        # implemented descriptor radial types
        descriptor_scaling_type_list = ['crss', 'linear', 'sqrt']
        # implemented descriptor on disk settings
        descriptor_on_disk_list = ['None', 'append', 'write']

        # read descriptor cache file if requested and available
        if self.settings.descriptor_cache_file is not None:
            if self.settings.descriptor_on_disk != 'None':
                if path.isfile(self.settings.descriptor_cache_file):
                    descriptors_torch, active_atoms, n_atoms_active, MM_gradients \
                        = self.read_descriptor_cache()
                    return descriptors_torch, [[], []], [], [], [], active_atoms, n_atoms_active, \
                        MM_gradients
            else:
                print('WARNING: A descriptor cache file can only be used if descriptor_on_disk is not None.',
                      '\nThe descriptor cache file is set to None.\n')
                self.settings.descriptor_cache_file = None

        # get bump radial function index
        if self.settings.descriptor_radial_type == 'bump':
            rad_func_index = 0
        # get Gaussian-bump radial function index
        elif self.settings.descriptor_radial_type == 'gaussian_bump':
            rad_func_index = 1
        # get Gaussian-cosine radial function index
        elif self.settings.descriptor_radial_type == 'gaussian_cos':
            rad_func_index = 2
        # not implemented descriptor radial type
        else:
            print('ERROR: Calculating descriptor radial type {0} is not yet implemented.'
                  .format(self.settings.descriptor_radial_type),
                  '\nPlease use one of the following types:')
            for des_rad_type in descriptor_radial_type_list:
                print('{0}'.format(des_rad_type))
            sys.exit()

        # get bump angular function index
        if self.settings.descriptor_angular_type == 'bump':
            ang_func_index = 0
        # get cosine angular function index
        elif self.settings.descriptor_angular_type == 'cos':
            ang_func_index = 1
        # get cosine integer angular function index
        elif self.settings.descriptor_angular_type == 'cos_int':
            ang_func_index = 2
        # not implemented descriptor angular type
        else:
            print('ERROR: Calculating descriptor angular type {0} is not yet implemented.'
                  .format(self.settings.descriptor_angular_type),
                  '\nPlease use one of the following types:')
            for des_ang_type in descriptor_angular_type_list:
                print('{0}'.format(des_ang_type))
            sys.exit()

        # check descriptor scaling type for QM/MM data
        if self.settings.QMMM and self.settings.descriptor_scaling_type != 'linear':
            print('ERROR: QM/MM data require descriptor scaling type linear.')
            sys.exit()
        # get cube root-scaled-shifted scaling function index
        if self.settings.descriptor_scaling_type == 'crss':
            scale_func_index = 0
        # get linear scaling function index
        elif self.settings.descriptor_scaling_type == 'linear':
            scale_func_index = 1
        # get square root scaling function index
        elif self.settings.descriptor_scaling_type == 'sqrt':
            scale_func_index = 2
        # not implemented descriptor scaling type
        else:
            print('ERROR: Calculating descriptor scaling type {0} is not yet implemented.'
                  .format(self.settings.descriptor_scaling_type),
                  '\nPlease use one of the following types:')
            for des_sca_type in descriptor_scaling_type_list:
                print('{0}'.format(des_sca_type))
            sys.exit()

        # check if descriptor on disk setting is implemented
        if self.settings.descriptor_on_disk not in descriptor_on_disk_list:
            print('ERROR: Descriptor on disk setting {0} is not yet implemented.'
                  .format(self.settings.descriptor_on_disk),
                  '\nPlease use one of the following settings:')
            for des_dis_setting in descriptor_on_disk_list:
                print('{0}'.format(des_dis_setting))
            sys.exit()
        # create descriptor directory if it does not exist
        if self.settings.descriptor_on_disk != 'None':
            Path(self.settings.descriptor_disk_dir).mkdir(parents=True, exist_ok=True)

        # determine number of radial and angular parameters
        n_parameters_ang = np.sum(np.array(
            [self.descriptor_parameters[i][0] for i in range(self.n_descriptors)])) - self.n_descriptors
        n_parameters_rad = self.n_descriptors - n_parameters_ang
        parameters_rad = np.array(self.descriptor_parameters[:n_parameters_rad])
        # get element-dependent radial and angular function index
        if self.settings.descriptor_type == 'eeACSF':
            element_types_rad = np.array([], dtype=int)
            if self.settings.QMMM:
                elem_func_index = 2
            else:
                elem_func_index = 0
        elif self.settings.descriptor_type == 'ACSF':
            element_types_rad = np.tile(
                np.arange(self.n_element_types), n_parameters_rad // self.n_element_types)
            if self.settings.QMMM:
                print('ERROR: QM/MM reference data cannot be represented by ACSFs.',
                      '\nPlease use the descriptor type eeACSFs.')
                sys.exit()
            else:
                elem_func_index = 1

        # radial and angular symmetry functions
        if n_parameters_ang > 0:
            parameters_ang = np.array(self.descriptor_parameters[n_parameters_rad:])
            n_element_types_ang = 0
            element_types_ang = np.array([], dtype=int)
            H_parameters_rad = np.array([], dtype=int)
            H_parameters_ang = np.array([], dtype=int)
            H_parameters_rad_scale = np.array([])
            H_parameters_ang_scale = np.array([])
            n_H_parameters = 0
            if self.settings.descriptor_type == 'ACSF':
                # determine angular parameters
                n_element_types_ang = np.sum(np.arange(1, self.n_element_types + 1))
                element_types_ang = np.tile(np.array(
                    [1000 * i + j for i in range(self.n_element_types)
                     for j in range(i, self.n_element_types)]), n_parameters_ang // n_element_types_ang)
                # slice parameters array
                H_type_jk = np.array([], dtype=int)
                eta_ij = parameters_rad[:, 1]
                eta_ijk = parameters_ang[:, 1]
                lambda_ijk = parameters_ang[:, 2]
                zeta_ijk = parameters_ang[:, 3]
                xi_ijk = parameters_ang[:, 4]
            elif self.settings.descriptor_type == 'eeACSF':
                # slice parameters array
                H_type_j = parameters_rad[:, 1].astype(int)
                H_type_jk = parameters_ang[:, 1].astype(int)
                eta_ij = parameters_rad[:, 2]
                eta_ijk = parameters_ang[:, 2]
                lambda_ijk = parameters_ang[:, 3]
                zeta_ijk = parameters_ang[:, 4]
                xi_ijk = parameters_ang[:, 5]
                # determine H parameters
                H_parameters_rad, H_parameters_ang, H_parameters_rad_scale, H_parameters_ang_scale, \
                    n_H_parameters = self.get_H_parameters(H_type_j, H_type_jk)
            if self.settings.QMMM:
                # slice parameters array
                I_type_j = parameters_rad[:, 3].astype(int)
                I_type_jk = parameters_ang[:, 6].astype(int)
                # check if H and I types are compatible
                if np.any(np.logical_and(np.greater(H_type_jk, n_H_parameters), np.greater(I_type_jk, 2))):
                    print('ERROR: Angular symmetry function subtypes cannot be larger than',
                          '{0} for QM/MM subtypes larger than 2.'.format(n_H_parameters))
                    sys.exit()
                # determine I parameters
                n_parameters_rad_env = len(np.arange(n_parameters_rad)[I_type_j > 0])
                n_parameters_ang_env = len(np.arange(n_parameters_ang)[I_type_jk > 0])
                MM_gradients = list(np.arange(self.n_descriptors)[
                    np.concatenate((I_type_j > 0, I_type_jk > 0))])
            else:
                # slice parameters array
                I_type_j = np.array([], dtype=int)
                I_type_jk = np.array([], dtype=int)
                # determine I parameters
                n_parameters_rad_env = 0
                n_parameters_ang_env = 0
                MM_gradients = []

            # calculate symmetry function values and derivatives
            descriptors_torch = []
            descriptor_i_derivatives_torch = []
            descriptor_neighbor_derivatives_torch = []
            neighbor_indices = []
            descriptor_neighbor_derivatives_torch_env = []
            neighbor_indices_env = []
            active_atoms = []
            n_atoms_active = []
            calc_derivatives = True
            if self.settings.descriptor_on_disk != 'None':
                descriptor_i_derivatives_torch.append([])
                descriptor_neighbor_derivatives_torch.append([])
                descriptor_neighbor_derivatives_torch_env.append([])
            for n in range(n_structures):
                # check if descriptor file already exists if requested
                if self.settings.descriptor_on_disk != 'None':
                    if self.settings.transfer_learning:
                        N = name[n][18:]
                    else:
                        N = name[n]
                    descriptor_file = self.settings.descriptor_disk_dir + '/descriptors_' + N + '.pt'
                    if self.settings.descriptor_on_disk == 'append' and path.isfile(descriptor_file):
                        calc_derivatives = False
                    else:
                        calc_derivatives = True
                    descriptor_i_derivatives_torch[-1] = []
                    descriptor_neighbor_derivatives_torch[-1] = []
                    descriptor_neighbor_derivatives_torch_env[-1] = []
                # determine interatomic distances and angles for each atom within the cutoff sphere
                neighbor_index, ij_0, ij_1, ij_2, ij_3, ij_4, ijk_0, ijk_1, ijk_2, ijk_3, ijk_4, \
                    ijk_5, ijk_6, ijk_7, ijk_8, ijk_9, ijk_10, ijk_11, ijk_12, active_atom, \
                    neighbor_index_env = self.calculate_atomic_environments(
                        elements_int_sys[n], positions[n], lattices[n], atomic_classes[n],
                        atomic_charges[n], n_atoms[n], n_atoms_sys[n],
                        calc_derivatives=calc_derivatives)
                active_atoms.append(active_atom)
                n_atoms_active.append(len(active_atom))
                neighbor_index_sys = [neighbor_index[i][neighbor_index[i] % n_atoms[n] < n_atoms_sys[n]]
                                      for i in range(len(neighbor_index))]
                # append lists if descriptors are not written to the disk
                if self.settings.descriptor_on_disk == 'None':
                    neighbor_indices.append([i % n_atoms[n] for i in neighbor_index_sys])
                    descriptor_i_derivatives_torch.append([])
                    descriptor_neighbor_derivatives_torch.append([])
                    neighbor_indices_env.append([i % n_atoms_sys[n] for i in neighbor_index_env])
                    descriptor_neighbor_derivatives_torch_env.append([])
                # calculate symmetry function value and derivative contributions
                with warnings.catch_warnings():
                    warnings.filterwarnings('ignore', category=NumbaTypeSafetyWarning)
                    descriptor, descriptor_i_derivative, descriptor_neighbor_derivative, \
                        descriptor_neighbor_derivative_env = calc_descriptor_derivative(
                            ij_0, ij_1, ij_2, ij_3, ij_4, ijk_0, ijk_1, ijk_2, ijk_3, ijk_4, ijk_5,
                            ijk_6, ijk_7, ijk_8, ijk_9, ijk_10, ijk_11, ijk_12, neighbor_index,
                            n_atoms[n], n_atoms_sys[n], self.n_descriptors, elem_func_index,
                            rad_func_index, ang_func_index, scale_func_index, self.R_c, eta_ij,
                            H_parameters_rad, H_parameters_rad_scale, n_parameters_rad,
                            element_types_rad, eta_ijk, lambda_ijk, zeta_ijk, xi_ijk,
                            H_parameters_ang, H_parameters_ang_scale, n_parameters_ang, H_type_jk,
                            n_H_parameters, element_types_ang, calc_derivatives, self.settings.QMMM,
                            active_atom, n_atoms_active[n], neighbor_index_env, I_type_j,
                            n_parameters_rad_env, I_type_jk, n_parameters_ang_env,
                            self.settings.MM_atomic_charge_max)
                # compile symmetry function values
                descriptors_torch.append(torch.tensor(np.array(descriptor), requires_grad=True,
                                                      dtype=self.dtype_torch))
                if calc_derivatives:
                    for i in range(n_atoms_sys[n]):
                        # compile symmetry function derivatives with respect to the central atom i
                        descriptor_i_derivatives_torch[-1].append(torch.tensor(
                            descriptor_i_derivative[i], dtype=self.dtype_torch))
                        # compile symmetry function derivatives with respect to neighbor atoms of atom i
                        descriptor_neighbor_derivatives_torch[-1].append(torch.tensor(
                            descriptor_neighbor_derivative[i], dtype=self.dtype_torch))
                    # compile symmetry function derivatives with respect to neighbor active environment
                    # atoms of atom i
                    if self.settings.QMMM:
                        for i in range(n_atoms_active[n] - n_atoms_sys[n]):
                            descriptor_neighbor_derivatives_torch_env[-1].append(torch.tensor(
                                descriptor_neighbor_derivative_env[i], dtype=self.dtype_torch))
                    # write descriptors on disk if requested
                    if self.settings.descriptor_on_disk != 'None':
                        self.write_descriptors(
                            descriptor_file,
                            [descriptor_i_derivatives_torch[-1], descriptor_neighbor_derivatives_torch[-1]],
                            [i % n_atoms[n] for i in neighbor_index_sys],
                            descriptor_neighbor_derivatives_torch_env[-1],
                            [i % n_atoms_sys[n] for i in neighbor_index_env])

        # only radial symmetry functions
        else:
            H_parameters_rad = np.array([], dtype=int)
            H_parameters_rad_scale = np.array([])
            if self.settings.descriptor_type == 'ACSF':
                # slice parameter arrays
                eta_ij = parameters_rad[:, 1]
            elif self.settings.descriptor_type == 'eeACSF':
                # slice parameter arrays
                H_type_j = parameters_rad[:, 1].astype(int)
                eta_ij = parameters_rad[:, 2]
                # determine H parameters
                H_parameters_rad, H_parameters_ang, H_parameters_rad_scale, H_parameters_ang_scale, \
                    n_H_parameters = self.get_H_parameters(H_type_j, np.array([], dtype=int))
            if self.settings.QMMM:
                # slice parameters array
                I_type_j = parameters_rad[:, 3].astype(int)
                # determine I parameters
                n_parameters_rad_env = len(np.arange(n_parameters_rad)[I_type_j > 0])
                MM_gradients = list(np.arange(self.n_descriptors)[I_type_j > 0])
            else:
                # slice parameters array
                I_type_j = np.array([], dtype=int)
                # determine I parameters
                n_parameters_rad_env = 0
                MM_gradients = []

            # calculate symmetry function values and derivatives
            descriptors_torch = []
            descriptor_i_derivatives_torch = []
            descriptor_neighbor_derivatives_torch = []
            neighbor_indices = []
            descriptor_neighbor_derivatives_torch_env = []
            neighbor_indices_env = []
            active_atoms = []
            n_atoms_active = []
            calc_derivatives = True
            if self.settings.descriptor_on_disk != 'None':
                descriptor_i_derivatives_torch.append([])
                descriptor_neighbor_derivatives_torch.append([])
                descriptor_neighbor_derivatives_torch_env.append([])
            for n in range(n_structures):
                # check if descriptor file already exists if requested
                if self.settings.descriptor_on_disk != 'None':
                    if self.settings.transfer_learning:
                        N = name[n][18:]
                    else:
                        N = name[n]
                    descriptor_file = self.settings.descriptor_disk_dir + '/descriptors_' + N + '.pt'
                    if self.settings.descriptor_on_disk == 'append' and path.isfile(descriptor_file):
                        calc_derivatives = False
                    else:
                        calc_derivatives = True
                    descriptor_i_derivatives_torch[-1] = []
                    descriptor_neighbor_derivatives_torch[-1] = []
                    descriptor_neighbor_derivatives_torch_env[-1] = []
                # determine interatomic distances for each atom within the cutoff sphere
                neighbor_index, ij_0, ij_1, ij_2, ij_3, ij_4, ijk_0, ijk_1, ijk_2, ijk_3, ijk_4, \
                    ijk_5, ijk_6, ijk_7, ijk_8, ijk_9, ijk_10, ijk_11, ijk_12, active_atom, \
                    neighbor_index_env = self.calculate_atomic_environments(
                        elements_int_sys[n], positions[n], lattices[n], atomic_classes[n],
                        atomic_charges[n], n_atoms[n], n_atoms_sys[n],
                        calc_derivatives=calc_derivatives, angular=False)
                active_atoms.append(active_atom)
                n_atoms_active.append(len(active_atom))
                neighbor_index_sys = [neighbor_index[i][neighbor_index[i] % n_atoms[n] < n_atoms_sys[n]]
                                      for i in range(len(neighbor_index))]
                # append lists if descriptors are not written to the disk
                if self.settings.descriptor_on_disk == 'None':
                    neighbor_indices.append([i % n_atoms[n] for i in neighbor_index_sys])
                    descriptor_i_derivatives_torch.append([])
                    descriptor_neighbor_derivatives_torch.append([])
                    neighbor_indices_env.append([i % n_atoms_sys[n] for i in neighbor_index_env])
                    descriptor_neighbor_derivatives_torch_env.append([])
                # calculate symmetry function value and derivative contributions
                with warnings.catch_warnings():
                    warnings.filterwarnings('ignore', category=NumbaTypeSafetyWarning)
                    descriptor, descriptor_i_derivative, descriptor_neighbor_derivative, \
                        descriptor_neighbor_derivative_env = calc_descriptor_derivative_radial(
                            ij_0, ij_1, ij_2, ij_3, ij_4, neighbor_index, n_atoms[n],
                            n_atoms_sys[n], self.n_descriptors, elem_func_index, rad_func_index,
                            scale_func_index, self.R_c, eta_ij, H_parameters_rad,
                            H_parameters_rad_scale, n_parameters_rad, element_types_rad,
                            calc_derivatives, self.settings.QMMM, active_atom, n_atoms_active[n],
                            neighbor_index_env, I_type_j, n_parameters_rad_env,
                            self.settings.MM_atomic_charge_max)
                # compile symmetry function values
                descriptors_torch.append(torch.tensor(np.array(descriptor), requires_grad=True,
                                                      dtype=self.dtype_torch))
                if calc_derivatives:
                    for i in range(n_atoms_sys[n]):
                        # compile symmetry function derivatives with respect to the central atom i
                        descriptor_i_derivatives_torch[-1].append(torch.tensor(
                            descriptor_i_derivative[i], dtype=self.dtype_torch))
                        # compile symmetry function derivatives with respect to neighbor atoms of atom i
                        descriptor_neighbor_derivatives_torch[-1].append(torch.tensor(
                            descriptor_neighbor_derivative[i], dtype=self.dtype_torch))
                    # compile symmetry function derivatives with respect to neighbor active environment
                    # atoms of atom i
                    if self.settings.QMMM:
                        for i in range(n_atoms_active[n] - n_atoms_sys[n]):
                            descriptor_neighbor_derivatives_torch_env[-1].append(torch.tensor(
                                descriptor_neighbor_derivative_env[i], dtype=self.dtype_torch))
                    # write descriptors on disk if requested
                    if self.settings.descriptor_on_disk != 'None':
                        self.write_descriptors(
                            descriptor_file,
                            [descriptor_i_derivatives_torch[-1], descriptor_neighbor_derivatives_torch[-1]],
                            [i % n_atoms[n] for i in neighbor_index_sys],
                            descriptor_neighbor_derivatives_torch_env[-1],
                            [i % n_atoms_sys[n] for i in neighbor_index_env])

        # convert number of active atoms from list to NumPy array
        n_atoms_active = np.array(n_atoms_active)

        # write descriptor cache file if requested
        if self.settings.descriptor_cache_file is not None:
            Path(self.settings.descriptor_cache_file).parent.mkdir(parents=True, exist_ok=True)
            self.write_descriptor_cache(descriptors_torch, active_atoms, n_atoms_active, MM_gradients)

        # measure maximal memory usage
        if self.settings.memory_evaluation:
            self.max_memory_usage = self.memory_usage(interval=1e-6 - 1e-9, timeout=1e-6, max_usage=True)

        # return empty lists for descriptor_derivatives_torch and neighbor_indices if descriptors
        # are saved on disk
        if self.settings.descriptor_on_disk != 'None':
            return descriptors_torch, [[], []], [], [], [], active_atoms, n_atoms_active, MM_gradients

        return descriptors_torch, [descriptor_i_derivatives_torch, descriptor_neighbor_derivatives_torch], \
            neighbor_indices, descriptor_neighbor_derivatives_torch_env, neighbor_indices_env, \
            active_atoms, n_atoms_active, MM_gradients

####################################################################################################

    def read_descriptor_cache(self):
        '''
        Load the descriptor cache file and compare the settings stored in it with the current
        settings. The program exits with an error message at the first setting that differs.

        Return: descriptors_torch, active_atoms, n_atoms_active, MM_gradients
        '''
        # names of the compared settings in the order in which they are stored in cache_settings
        compared_settings = (
            'descriptor_type',
            'descriptor_radial_type',
            'descriptor_angular_type',
            'descriptor_scaling_type',
            'atomic_force_max',
            'QMMM',
            'MM_atomic_charge_max',
        )

        # read cache_settings, descriptors_torch, active_atoms, n_atoms_active, and MM_gradients
        checkpoint = torch.load(self.settings.descriptor_cache_file)
        cache_settings = checkpoint['cache_settings']
        descriptors_torch = checkpoint['descriptors_torch']
        active_atoms = checkpoint['active_atoms']
        n_atoms_active = checkpoint['n_atoms_active']
        MM_gradients = checkpoint['MM_gradients']

        # check cache settings one after the other and stop at the first mismatch
        for position, setting_name in enumerate(compared_settings):
            cached_value = cache_settings[position]
            current_value = getattr(self.settings, setting_name)
            if cached_value == current_value:
                continue
            print('ERROR: Spezified {0} {1} is not equal to'
                  .format(setting_name, current_value),
                  'descriptor cache {0} {1} (file {2}).'
                  .format(setting_name, cached_value, self.settings.descriptor_cache_file))
            sys.exit()

        return descriptors_torch, active_atoms, n_atoms_active, MM_gradients

####################################################################################################

    def write_descriptor_cache(self, descriptors_torch, active_atoms, n_atoms_active, MM_gradients):
        '''
        Save the given descriptor data together with the settings they were calculated with
        in the descriptor cache file.

        Output: descriptor cache file
        '''
        settings = self.settings

        # settings which are stored alongside the data (the order defines the cache layout)
        cache_settings = [
            settings.descriptor_type,
            settings.descriptor_radial_type,
            settings.descriptor_angular_type,
            settings.descriptor_scaling_type,
            settings.atomic_force_max,
            settings.QMMM,
            settings.MM_atomic_charge_max,
        ]

        # collect cache_settings, descriptors_torch, active_atoms, n_atoms_active, and MM_gradients
        cache_content = {
            'cache_settings': cache_settings,
            'descriptors_torch': descriptors_torch,
            'active_atoms': active_atoms,
            'n_atoms_active': n_atoms_active,
            'MM_gradients': MM_gradients,
        }

        # write the cache file
        torch.save(cache_content, settings.descriptor_cache_file)

####################################################################################################

    def write_descriptors(self, descriptor_file, descriptor_derivatives_torch, neighbor_indices,
                          descriptor_neighbor_derivatives_torch_env, neighbor_indices_env):
        '''
        Save the descriptor derivatives and neighbor indices of one structure, including those
        of the environment, in the given descriptor file.

        Output: descriptors_[name].pt file
        '''
        # collect descriptor_derivatives_torch, neighbor_indices, and their environment counterparts
        file_content = {
            'descriptor_derivatives_torch': descriptor_derivatives_torch,
            'neighbor_indices': neighbor_indices,
            'descriptor_neighbor_derivatives_torch_env': descriptor_neighbor_derivatives_torch_env,
            'neighbor_indices_env': neighbor_indices_env,
        }

        # write the descriptor file
        torch.save(file_content, descriptor_file)

####################################################################################################

    def write_descriptor_output(self, descriptors_torch, elements, n_structures, n_atoms_sys):
        '''
        Write the descriptor values with the writer matching the descriptor type. The program
        exits with an error message if the descriptor type is not implemented.

        Implementation: ACSF, eeACSF

        Output: Files including descriptor values
        '''
        # implemented descriptor types
        descriptor_type_list = ['ACSF', 'eeACSF']

        # not implemented descriptor type
        if self.settings.descriptor_type not in descriptor_type_list:
            print('ERROR: Writing descriptor type {0} is not yet implemented.'
                  .format(self.settings.descriptor_type),
                  '\nPlease use one of the following types:')
            for implemented_type in descriptor_type_list:
                print('{0}'.format(implemented_type))
            sys.exit()

        # write ACSFs or eeACSFs
        self.write_sym_fcts(descriptors_torch, elements, n_structures, n_atoms_sys)

####################################################################################################

    def write_sym_fcts(self, descriptors_torch, elements, n_structures, n_atoms_sys):
        '''
        Write the symmetry function parameters and values to the descriptor output file.

        The file consists of a 'Parameters' section with one line per symmetry function and a
        'Values' section with one line per atom (structure number starting at 1, atomic number,
        and all descriptor values of the atom). The program exits with an error message if the
        descriptor type or a symmetry function type is not implemented.

        Implementation: ACSF: 1, 2
                        eeACSF: 1, 2

        Input: descriptors_torch: per structure, tensor indexed as [atom][descriptor]
               elements: per structure, elements of the atoms (keys of the atomic numbers dictionary)
               n_structures: number of structures to write
               n_atoms_sys: per structure, number of atoms whose values are written

        Output: Files including symmetry function parameters and values
        '''
        # implemented symmetry function types
        descriptor_type_list = ['ACSF', 'eeACSF']
        ACSF_type_list = ['1', '2']
        eeACSF_type_list = ['1', '2']

        # write symmetry functions parameters
        with open(self.settings.descriptor_output_file, 'w', encoding='utf-8') as f:
            f.write('Parameters\n')
            # write ACSFs
            if self.settings.descriptor_type == 'ACSF':
                # j and k are the element type indices written for the current symmetry function
                j = 0
                k = 0
                for para in self.descriptor_parameters:
                    # write radial ACSFs (eta is multiplied by Bohr2Angstrom**2 before writing)
                    if para[0] == 1:
                        f.write('{0:1d} {1:2} {2:9.6f}\n'.format(
                            para[0], self.element_types[j], round(para[1] * self.Bohr2Angstrom**2, 6)))
                        j = (j + 1) % self.n_element_types
                    # write angular ACSFs
                    elif para[0] == 2:
                        # the last column is xi which is stored at index 4 for ACSFs (index 3 is
                        # zeta and is already written as the sixth column)
                        f.write('{0:1d} {1:2} {2:2} {3:9.6f} {4:2d} {5:2d} {6:9.6f}\n'.format(
                            para[0], self.element_types[j], self.element_types[k],
                            round(para[1] * self.Bohr2Angstrom**2, 6), para[2], para[3],
                            round(para[4], 6)))
                        # advance to the next element pair (j, k) with k >= j
                        k += 1
                        if k == self.n_element_types:
                            j = (j + 1) % self.n_element_types
                            k = j
                    # not implemented ACSFs
                    else:
                        print('ERROR: Writing of ACSF type {0} is not yet implemented.'
                              .format(para[0]),
                              '\nPlease use one of the following types:')
                        for SF_type in ACSF_type_list:
                            print('{0}'.format(SF_type))
                        sys.exit()
            # write eeACSFs
            elif self.settings.descriptor_type == 'eeACSF':
                for para in self.descriptor_parameters:
                    # write radial eeACSFs (eta is multiplied by Bohr2Angstrom**2 before writing)
                    if para[0] == 1:
                        f.write('{0:1d} {1:2d} {2:9.6f}\n'.format(
                            para[0], para[1], round(para[2] * self.Bohr2Angstrom**2, 6)))
                    # write angular eeACSFs
                    elif para[0] == 2:
                        f.write('{0:1d} {1:2d} {2:9.6f} {3:2d} {4:2d} {5:9.6f}\n'.format(
                            para[0], para[1], round(para[2] * self.Bohr2Angstrom**2, 6),
                            para[3], para[4], round(para[5], 6)))
                    # not implemented eeACSFs
                    else:
                        print('ERROR: Writing of eeACSF type {0} is not yet implemented.'
                              .format(para[0]),
                              '\nPlease use one of the following types:')
                        for SF_type in eeACSF_type_list:
                            print('{0}'.format(SF_type))
                        sys.exit()
            # not implemented symmetry function type
            else:
                print('ERROR: Writing descriptor type {0} is not yet implemented.'
                      .format(self.settings.descriptor_type),
                      '\nPlease use one of the following types:')
                for des_type in descriptor_type_list:
                    print('{0}'.format(des_type))
                sys.exit()

            # get atomic numbers dictionary
            atomic_numbers = self.get_atomic_numbers()

            # write symmetry function values
            f.write('Values\n')
            for n in range(n_structures):
                # convert the descriptor tensor of structure n to a NumPy array on the CPU
                descriptor = descriptors_torch[n].cpu().detach().numpy()
                for i in range(n_atoms_sys[n]):
                    f.write('{0:5d} {1:2d} '.format(n + 1, atomic_numbers[elements[n][i]]))
                    for j in range(self.n_descriptors):
                        f.write('{0:10.6f} '.format(round(descriptor[i][j], 6)))
                    f.write('\n')

####################################################################################################

    def read_generalization(self, model_index=0):
        '''
        Implementation: lMLP, lMLP-only_prediction, RuNNer

        Load a stored generalization into the model with index model_index. The reader is chosen
        by settings.generalization_format. The program exits with an error message if the format
        is not implemented or if settings.generalization_file does not exist (it has to be a file
        for the lMLP formats and a directory for the RuNNer format).

        Modify: model, descriptor_parameters, R_c, element_energy, optimizer, training_state
        '''
        # implemented generalization formats:
        # (existence check, error message if the check fails, name of the reader method)
        generalization_readers = {
            'lMLP': (
                path.isfile, 'ERROR: Generalization file {0} does not exist.', 'read_lMLP'),
            'lMLP-only_prediction': (
                path.isfile, 'ERROR: Generalization file {0} does not exist.', 'read_lMLP_model'),
            'RuNNer': (
                path.isdir, 'ERROR: Generalization directory {0} does not exist.', 'read_weights')}
        generalization_format = self.settings.generalization_format

        # not implemented generalization format
        if generalization_format not in generalization_readers:
            print('ERROR: Generalization format {0} is not yet implemented.'
                  .format(generalization_format),
                  '\nPlease use one of the following formats:')
            for implemented_format in generalization_readers:
                print('{0}'.format(implemented_format))
            sys.exit()

        # check if generalization file or directory exists
        exists, missing_message, reader_name = generalization_readers[generalization_format]
        if not exists(self.settings.generalization_file):
            print(missing_message.format(self.settings.generalization_file))
            sys.exit()

        # read the generalization with the reader of the requested format
        getattr(self, reader_name)(model_index)

####################################################################################################

    def read_lMLP(self, model_index):
        '''
        Restore an lMLP model, the optimizer state, and the training state from the checkpoint
        stored in settings.generalization_file. The program exits with an error message if a
        required checkpoint entry is missing.

        Modify: model, descriptor_parameters, R_c, element_energy, energy_loss_function,
                forces_loss_function, optimizer, training_state
        '''
        # checkpoint entries which are copied to attributes of the same name
        attribute_keys = ('descriptor_parameters', 'R_c', 'element_energy',
                          'energy_loss_function', 'forces_loss_function')
        # checkpoint entries which form the training state (in this order)
        training_state_keys = ('name', 'train', 'selection_hist', 'exclusion_hist', 'loss_old',
                               'loss_selection_thresholds', 'fraction_good', 'n_previous_epochs')

        # read lMLP model and training state
        checkpoint = torch.load(self.settings.generalization_file)
        try:
            self.model[model_index].load_state_dict(checkpoint['model_state_dict'])
            for key in attribute_keys:
                setattr(self, key, checkpoint[key])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            training_state = [checkpoint[key] for key in training_state_keys]
        except KeyError:
            print('ERROR: lMLP model and training state file {0} is incomplete.'
                  .format(self.settings.generalization_file),
                  '\nPlease try generalization_format lMLP-only_prediction',
                  'to load a model only for inference.')
            sys.exit()

        # compile training state (it stays empty if no structure names are stored)
        self.training_state = training_state if len(training_state[0]) > 0 else []

####################################################################################################

    def read_lMLP_model(self, model_index):
        '''
        Restore only the lMLP model and the data required for predictions from the checkpoint
        stored in settings.generalization_file. The program exits with an error message if a
        required checkpoint entry is missing.

        Modify: model, descriptor_parameters, R_c, element_energy
        '''
        # checkpoint entries which are copied to attributes of the same name
        attribute_keys = ('descriptor_parameters', 'R_c', 'element_energy')

        # read lMLP model
        checkpoint = torch.load(self.settings.generalization_file)
        try:
            self.model[model_index].load_state_dict(checkpoint['model_state_dict'])
            for key in attribute_keys:
                setattr(self, key, checkpoint[key])
        except KeyError:
            print('ERROR: lMLP model file {0} is broken.'.format(
                self.settings.generalization_file))
            sys.exit()

####################################################################################################

    def read_weights(self, model_index):
        '''
        Read the weights and biases of all atomic neural networks from RuNNer weight files and
        write them into the model with index model_index.

        For each element type the file
        '<settings.generalization_file>/weights.<atomic number with three digits>.data'
        is read. Only the first whitespace-separated entry of each line is used. The values are
        consumed line by line in the following order: for every layer first all weights (outer
        loop over the input index k, inner loop over the output index j) and then all biases.

        If settings.scale_shift_layer is True, the first layer is treated as an element-wise
        layer: only the diagonal entries (j == k) of its weight block are stored, the other lines
        are read and discarded, and its bias is set to -value / weight.

        The program exits with an error message if a weights file does not exist.

        Modify: model
        '''
        # determine atomic neural network architecture
        # N_i: number of layers with parameters, N_j: number of outputs of each layer,
        # N_k: number of inputs of each layer
        if self.settings.scale_shift_layer:
            N_i = self.model[model_index].N_hidden_layers + 2
            N_j = ([self.model[model_index].n_descriptors]
                   + self.model[model_index].n_neurons_hidden_layers + [1])
            N_k = (2 * [self.model[model_index].n_descriptors]
                   + self.model[model_index].n_neurons_hidden_layers)
        else:
            N_i = self.model[model_index].N_hidden_layers + 1
            N_j = self.model[model_index].n_neurons_hidden_layers + [1]
            N_k = ([self.model[model_index].n_descriptors]
                   + self.model[model_index].n_neurons_hidden_layers)

        # get atomic numbers string dictionary
        atomic_numbers = self.get_atomic_numbers()

        # read weights and update model
        for ele in range(self.n_element_types):
            weights_file = '{0}/weights.{1:03}.data'.format(
                self.settings.generalization_file, atomic_numbers[self.element_types[ele]])
            if not path.isfile(weights_file):
                print('ERROR: Weights file {0} does not exist.'.format(weights_file))
                sys.exit()
            with open(weights_file, encoding='utf-8') as f:
                # i counts the layers with parameters, i_layer is the position in the atomic
                # neural network which can also contain entries without a weight attribute
                i_layer = 0
                for i in range(N_i):
                    # skip one entry without weights (presumably an activation function)
                    if not hasattr(self.model[model_index].atomic_neural_networks[ele][i_layer], 'weight'):
                        i_layer += 1
                    # read weights (one value per line)
                    for k in range(N_k[i]):
                        for j in range(N_j[i]):
                            if i == 0 and self.settings.scale_shift_layer:
                                # scale and shift layer: keep only the diagonal entries
                                if j == k:
                                    self.model[model_index].atomic_neural_networks[ele][i].weight.data[j] = \
                                        float(f.readline().strip().split()[0])
                                else:
                                    # off-diagonal entries are consumed but not used
                                    f.readline()
                            else:
                                self.model[model_index].atomic_neural_networks[ele][i_layer].weight.data[j][k] = \
                                    float(f.readline().strip().split()[0])
                    # read biases (one value per line)
                    for j in range(N_j[i]):
                        if i == 0 and self.settings.scale_shift_layer:
                            # the stored bias is the negative file value divided by the weight
                            # which was read for the same descriptor
                            self.model[model_index].atomic_neural_networks[ele][i].bias.data[j] = (
                                -float(f.readline().strip().split()[0])
                                / self.model[model_index].atomic_neural_networks[ele][i].weight.data[j])
                        else:
                            self.model[model_index].atomic_neural_networks[ele][i_layer].bias.data[j] = \
                                float(f.readline().strip().split()[0])
                    i_layer += 1

####################################################################################################

    def prepare_transfer_learning(self):
        '''
        Check the prerequisites of transfer learning (restart = True and generalization_format =
        'lMLP'), rename the generalization setting file ('.ini' -> '_T.ini') and the
        generalization file ('.pt' -> '_T.pt') in the settings, and read the supplemental
        potential again. The program exits with an error message if a prerequisite is not
        fulfilled or if a file name does not contain the expected substring.

        Modify: generalization_setting_file, generalization_file, supplemental_potential_parameters,
                element_energy
        '''
        settings = self.settings

        # check that restart is True
        if not settings.restart:
            print('ERROR: Transfer learning requires restart = True.')
            sys.exit()

        # check that generalization format is lMLP
        if settings.generalization_format != 'lMLP':
            print("ERROR: Transfer learning requires generalization_format = 'lMLP'.")
            sys.exit()

        # modify generalization setting file name
        if '.ini' not in settings.generalization_setting_file:
            print("ERROR: The generalization setting file name does not contain the substring '.ini'.",
                  '\nAutomatic renaming of the generalization setting file is therefore not possible.')
            sys.exit()
        settings.generalization_setting_file = settings.generalization_setting_file.replace(
            '.ini', '_T.ini')

        # modify generalization file name
        if '.pt' not in settings.generalization_file:
            print("ERROR: The generalization file name does not contain the substring '.pt'.",
                  '\nAutomatic renaming of the generalization file is therefore not possible.')
            sys.exit()
        settings.generalization_file = settings.generalization_file.replace('.pt', '_T.pt')

        # read modified supplemental potential
        self.read_supplemental_potential()

####################################################################################################

    def initialize_scale_shift_layer(self, train, elements_int_sys, descriptors_torch,
                                     model_index=0):
        '''
        Initialize the scale and shift layer (entry 0 of each atomic neural network) from the
        descriptor values of all atoms in the training structures. For each element type the
        weights are set to the inverse standard deviation of each descriptor (limited to at most
        10.0) and the biases are set to the mean of each descriptor.

        Modify: model
        '''
        # define maximal weight in scale and shift layer initialization
        max_weight = 10.0

        # collect the element index and the descriptor vector of every atom in the training data
        elements_int_flat = np.array([atom for n in train for atom in list(elements_int_sys[n])])
        descriptor_rows = []
        for n in train:
            descriptor_rows.extend(descriptors_torch[n].cpu().detach().tolist())
        descriptors_flat = np.array(descriptor_rows)

        for ele in range(self.n_element_types):
            descriptors_ele = descriptors_flat[elements_int_flat == ele]

            # calculate scale and shift layer weight initialization
            # (a vanishing standard deviation yields inf which is replaced by max_weight)
            with np.errstate(divide='ignore'):
                inverse_std = 1.0 / np.std(descriptors_ele, axis=0)
            weight_init = torch.tensor(
                np.where(inverse_std > max_weight, max_weight, inverse_std),
                dtype=self.dtype_torch)
            bias_init = torch.tensor(np.mean(descriptors_ele, axis=0), dtype=self.dtype_torch)

            # update model
            scale_shift_layer = self.model[model_index].atomic_neural_networks[ele][0]
            scale_shift_layer.weight.data = weight_init
            scale_shift_layer.bias.data = bias_init

####################################################################################################

    def initialize_weights(self, train, elements_int_sys, energy, n_structures_train, n_atoms_sys,
                           model_index=0):
        '''
        Implementation: Weight initialization: default, sTanh, Tanh, Tanhshrink

        Apply the weight initialization requested by settings.weight_initialization. 'default'
        leaves the model unchanged. 'sTanh', 'Tanh', and 'Tanhshrink' are delegated to
        initialize_weights_tanh if the training data set contains more than one structure,
        otherwise a warning is printed and the model is left unchanged. The program exits with an
        error message for any other setting.

        Modify: model
        '''
        # implemented weight initializations
        weight_initialization_list = ['default', 'sTanh', 'Tanh', 'Tanhshrink']
        weight_initialization = self.settings.weight_initialization

        # default weight initialization
        if weight_initialization == 'default':
            return

        # not implemented weight initialization
        if weight_initialization not in ('sTanh', 'Tanh', 'Tanhshrink'):
            print('ERROR: Weight initialization {0} is not yet implemented.'
                  .format(weight_initialization),
                  '\nPlease use one of the following weight initializations:')
            for implemented_initialization in weight_initialization_list:
                print('{0}'.format(implemented_initialization))
            sys.exit()

        # weight initialization sTanh, Tanh, and Tanhshrink
        if len(train) <= 1:
            print('WARNING: Weight initialization {0}'.format(weight_initialization),
                  'requires more than one structure in the training data set.',
                  '\nThe default weight initialization is used instead of {0}.\n'
                  .format(weight_initialization))
            return
        self.initialize_weights_tanh(
            train, elements_int_sys, energy, n_structures_train, n_atoms_sys, model_index)

####################################################################################################

    def initialize_weights_tanh(self, train, elements_int_sys, energy, n_structures_train,
                                n_atoms_sys, model_index):
        '''
        Initialize the atomic neural networks for the weight initializations sTanh, Tanh, and
        Tanhshrink.

        All torch.nn.Linear layers except for the last entry of each atomic neural network obtain
        uniformly distributed weights in the range +-magic_number_1 / sqrt(layer.weight.size(1))
        and zero biases. The last entry obtains uniformly distributed weights in the range
        +-std(energy per atom) / magic_number_2 / sqrt(layer.weight.size(1)) and a constant bias
        equal to a mean energy per atom. This mean is determined for each element type by a
        least squares fit of the training energies to the stoichiometries if the stoichiometries
        of the training structures allow to distinguish element types. Otherwise, the mean energy
        per atom of all training structures is used for all element types.

        The program exits with an error message if settings.weight_initialization is not one of
        the three supported values or if not all element types are present in the training data.

        Modify: model
        '''
        # define magic numbers sTanh
        if self.settings.weight_initialization == 'sTanh':
            magic_number_1 = np.sqrt(3)
            magic_number_2 = 0.885

        # define magic numbers Tanh
        elif self.settings.weight_initialization == 'Tanh':
            magic_number_1 = 5.377
            magic_number_2 = 0.798

        # define magic numbers Tanhshrink
        elif self.settings.weight_initialization == 'Tanhshrink':
            magic_number_1 = 2.882
            magic_number_2 = 0.620

        # not implemented weight initialization
        else:
            print('ERROR: Magic numbers of weight initialization {0} are not yet implemented.'.format(
                self.settings.weight_initialization))
            sys.exit()

        # check if all elements of the reference data set are in the training data set
        n_element_types_train = len(np.unique(np.array(
            [ele for n in train for ele in elements_int_sys[n]])))
        if self.n_element_types != n_element_types_train:
            print('ERROR: For the weight initializations sTanh, Tanh, and Tanhshrink all element types',
                  '\nof the reference data set have to be in the training data set as well',
                  '\nbut only {0} of {1} element types are included in the training data set.'
                  .format(n_element_types_train, self.n_element_types))
            sys.exit()

        # determine required energy arrays
        # NOTE(review): n_atoms_sys[train] requires that n_atoms_sys supports indexing with the
        # index collection train (presumably a NumPy array) - confirm against the caller
        energy_train = np.array(energy)[train]
        energy_per_atom = energy_train / n_atoms_sys[train]
        energy_per_atom_std = np.std(energy_per_atom)

        # determine stoichiometry
        if n_structures_train >= self.n_element_types:
            # stoichiometry[m][ele] is the number of atoms of element type ele in the m-th
            # training structure
            stoichiometry = np.zeros((n_structures_train, self.n_element_types), dtype=int)
            # determine unique stoichiometries
            for m, n in enumerate(train):
                indices, counts = np.unique(elements_int_sys[n], return_counts=True)
                stoichiometry[m][indices] = counts
            # assignments[ele] is the element type which represents ele in the fit
            # (initially every element type represents itself)
            assignments = np.arange(self.n_element_types)
            # determine varied stoichiometries
            # two element types are merged if the ratio of their atom counts (reduced by the
            # greatest common divisor) is identical in all training structures
            for ele_1 in range(self.n_element_types - 1):
                if assignments[ele_1] == ele_1:
                    for ele_2 in range(ele_1 + 1, self.n_element_types):
                        if assignments[ele_2] == ele_2:
                            # the division by a greatest common divisor of zero (both counts
                            # are zero) is tolerated
                            with np.errstate(divide='ignore'):
                                indices = np.unique((stoichiometry[:, [ele_1, ele_2]].T // np.gcd(
                                    stoichiometry[:, ele_1], stoichiometry[:, ele_2])).T, axis=0)
                            if len(indices) == 1:
                                # merge the atom counts of ele_2 into the column of ele_1
                                stoichiometry[:, ele_1] += stoichiometry[:, ele_2]
                                assignments[ele_2] = ele_1
            # remove the columns of merged element types starting with the last column so that
            # the indices of the columns which still have to be checked stay valid
            n_stoichiometry = self.n_element_types
            for ele in range(-1, -self.n_element_types - 1, -1):
                if assignments[ele] != self.n_element_types + ele:
                    stoichiometry = np.delete(stoichiometry, self.n_element_types + ele, axis=1)
                    n_stoichiometry -= 1
            # convert the representing element types into indices of the remaining columns
            assignments = rankdata(assignments, method='dense') - 1
        else:
            n_stoichiometry = 1

        # determine mean energy per atom
        if n_stoichiometry > 1:
            # least squares solution of stoichiometry * x = energy_train yields one energy per
            # remaining stoichiometry column which is distributed to the assigned element types
            energy_per_atom_mean = np.linalg.lstsq(stoichiometry, energy_train, rcond=None)[0]
            energy_per_atom_mean = [energy_per_atom_mean[assignments[ele]]
                                    for ele in range(self.n_element_types)]
        else:
            # the same mean energy per atom is used for all element types
            energy_per_atom_mean = np.mean(energy_per_atom)
            energy_per_atom_mean = [energy_per_atom_mean for ele in range(self.n_element_types)]

        # calculate weights and update model
        for ele in range(self.n_element_types):
            # all entries except for the last one: only torch.nn.Linear layers are modified
            for layer in self.model[model_index].atomic_neural_networks[ele][:-1]:
                if isinstance(layer, torch.nn.Linear):
                    torch.nn.init.uniform_(
                        layer.weight,
                        a=-magic_number_1 / np.sqrt(layer.weight.size(1)),
                        b=magic_number_1 / np.sqrt(layer.weight.size(1)))
                    torch.nn.init.zeros_(layer.bias)
            # last entry: the weight range is scaled by the standard deviation of the energy per
            # atom and the bias is set to the mean energy per atom of the element type
            layer = self.model[model_index].atomic_neural_networks[ele][-1]
            torch.nn.init.uniform_(
                layer.weight,
                a=-energy_per_atom_std / magic_number_2 / np.sqrt(layer.weight.size(1)),
                b=energy_per_atom_std / magic_number_2 / np.sqrt(layer.weight.size(1)))
            torch.nn.init.constant_(layer.bias, energy_per_atom_mean[ele])

####################################################################################################

    def freeze_weights(self, model_index=0):
        '''
        Freeze the weights and biases of the layers listed in settings.frozen_layers and unfreeze
        the weights and biases of all other layers. Only entries of the atomic neural networks
        which have a weight attribute are counted as layers. The layer numbers in
        settings.frozen_layers are reduced by one if no scale and shift layer is used.

        Modify model
        '''
        # adjust numbering according to scale_shift_layer
        shift = 0 if self.settings.scale_shift_layer else 1
        frozen_layers = np.array(self.settings.frozen_layers) - shift

        # freeze and unfreeze weights
        for ele in range(self.n_element_types):
            layers_with_weights = [
                layer for layer in self.model[model_index].atomic_neural_networks[ele]
                if hasattr(layer, 'weight')]
            for i, layer in enumerate(layers_with_weights):
                trainable = i not in frozen_layers
                layer.weight.requires_grad_(trainable)
                layer.bias.requires_grad_(trainable)

####################################################################################################

    def define_optimization(self, model_index=0):
        '''
        Implementation: Adadelta, Adadelta2, Adagrad, Adam, Adam2, Adamax, Adamax2, CoRe, NAG, NAG2,
                        RMSprop, RMSprop2, Rprop, Rprop2, SGD, SGDmomentum, SGDmomentum2

        Create the optimizer requested by settings.optimizer for the parameters of the model with
        index model_index. All optimizers use settings.learning_rate. The optimizers with the
        suffix 2 use the non-default hyperparameters listed below. The program exits with an
        error message if the optimizer is not implemented.

        Modify: optimizer
        '''
        # implemented optimizers
        optimizer_list = ['Adadelta', 'Adadelta2', 'Adagrad', 'Adam', 'Adam2', 'Adamax', 'Adamax2',
                          'CoRe', 'NAG', 'NAG2', 'RMSprop', 'RMSprop2', 'Rprop', 'Rprop2', 'SGD',
                          'SGDmomentum', 'SGDmomentum2']

        # PyTorch optimizers: optimizer class and hyperparameters which differ from its defaults
        torch_optimizers = {
            'Adadelta': (torch.optim.Adadelta, {}),
            'Adadelta2': (torch.optim.Adadelta, {'rho': 0.975}),
            'Adagrad': (torch.optim.Adagrad, {}),
            'Adam': (torch.optim.Adam, {}),
            'Adam2': (torch.optim.Adam, {'betas': (0.9, 0.99)}),
            'Adamax': (torch.optim.Adamax, {}),
            'Adamax2': (torch.optim.Adamax, {'betas': (0.9, 0.935)}),
            'NAG': (torch.optim.SGD, {'momentum': 0.9, 'nesterov': True}),
            'NAG2': (torch.optim.SGD, {'momentum': 0.94, 'nesterov': True}),
            'RMSprop': (torch.optim.RMSprop, {}),
            'RMSprop2': (torch.optim.RMSprop, {'alpha': 0.9925}),
            'Rprop': (torch.optim.Rprop, {}),
            'Rprop2': (torch.optim.Rprop, {'etas': (0.7375, 1.2)}),
            'SGD': (torch.optim.SGD, {}),
            'SGDmomentum': (torch.optim.SGD, {'momentum': 0.9}),
            'SGDmomentum2': (torch.optim.SGD, {'momentum': 0.94})}

        optimizer_name = self.settings.optimizer

        # create CoRe optimizer
        if optimizer_name == 'CoRe':
            n_neurons = self.settings.n_neurons_hidden_layers

            # determine frozen list (one entry for the weights and one for the biases of
            # each layer, the scale and shift layer and the output layer obtain zeros)
            if self.settings.frozen > 0:
                frozen_list = [0, 0] if self.settings.scale_shift_layer else []
                frozen_list += [int(self.settings.frozen * self.n_descriptors * n_neurons[0]),
                                int(self.settings.frozen * n_neurons[0])]
                for n_in, n_out in zip(n_neurons[:-1], n_neurons[1:]):
                    frozen_list += [int(self.settings.frozen * n_in * n_out),
                                    int(self.settings.frozen * n_out)]
                frozen_list += [0, 0]
            else:
                frozen_list = [0]

            # determine weight decay list (one entry for the weights and one for the biases
            # of each layer)
            if self.settings.weight_decay > 0.0:
                weight_decay_list = [0.01, 0.01] if self.settings.scale_shift_layer else []
                weight_decay_list += [self.settings.weight_decay, self.settings.weight_decay] * len(
                    n_neurons) + [0.0, 0.0]
            else:
                weight_decay_list = [0.0]

            self.optimizer = CoRe(
                self.model[model_index].parameters(), lr=self.settings.learning_rate,
                step_sizes=self.settings.step_sizes, betas=self.settings.betas,
                etas=self.settings.etas, weight_decay=weight_decay_list,
                score_history=self.settings.score_history, frozen=frozen_list,
                foreach=self.settings.foreach)

        # create PyTorch optimizer
        elif optimizer_name in torch_optimizers:
            optimizer_class, hyperparameters = torch_optimizers[optimizer_name]
            # the step size limits of Rprop and Rprop2 are taken from the settings
            if optimizer_class is torch.optim.Rprop:
                hyperparameters['step_sizes'] = self.settings.step_sizes
            self.optimizer = optimizer_class(
                self.model[model_index].parameters(), lr=self.settings.learning_rate,
                **hyperparameters)

        # not implemented optimizer
        else:
            print('ERROR: Optimizer {0} is not yet implemented.'
                  .format(optimizer_name),
                  '\nPlease use one of the following optimizers:')
            for implemented_optimizer in optimizer_list:
                print('{0}'.format(implemented_optimizer))
            sys.exit()

####################################################################################################

    def define_loss_function(self):
        '''
        Implementation: HuberLoss, L1Loss, MSELoss, SmoothL1Loss

        Create the loss functions of energies and forces requested by settings.loss_function. All
        loss functions sum up the individual contributions (reduction='sum'). HuberLoss and
        SmoothL1Loss take their parameter (delta and beta, respectively) from
        settings.loss_parameters (first entry: energies, second entry: forces). The program exits
        with an error message if the loss function is not implemented.

        Return: energy_loss_function, forces_loss_function
        '''
        # implemented loss functions and the keyword of their parameter (None: no parameter)
        loss_parameter_keywords = {
            'HuberLoss': 'delta', 'L1Loss': None, 'MSELoss': None, 'SmoothL1Loss': 'beta'}
        loss_function = self.settings.loss_function

        # not implemented loss function
        if loss_function not in loss_parameter_keywords:
            print('ERROR: Loss function {0} is not yet implemented.'
                  .format(loss_function),
                  '\nPlease use one of the following loss functions:')
            for implemented_loss_function in loss_parameter_keywords:
                print('{0}'.format(implemented_loss_function))
            sys.exit()

        # the PyTorch loss classes have the same names as the settings
        loss_class = getattr(torch.nn.modules.loss, loss_function)
        keyword = loss_parameter_keywords[loss_function]

        # create loss functions without parameter
        if keyword is None:
            return loss_class(reduction='sum'), loss_class(reduction='sum')

        # create loss functions with separate parameters for energies and forces
        energy_loss_function = loss_class(
            reduction='sum', **{keyword: self.settings.loss_parameters[0]})
        forces_loss_function = loss_class(
            reduction='sum', **{keyword: self.settings.loss_parameters[1]})
        return energy_loss_function, forces_loss_function

####################################################################################################

    def define_train_test_splitting(self, name, n_structures):
        '''
        Split the reference structures into training and test data.

        Without a previous training state, a fraction of settings.training_fraction (rounded down)
        of all structures is randomly selected as training data. With a previous training state,
        structures whose names belong to the previous training data stay training data, structures
        whose names are unknown to the previous training state are randomly distributed according
        to settings.training_fraction, and all other structures become test data. In addition,
        the histories of the previous training state are mapped onto the order of the new
        training data, whereby new training structures obtain initial values (ones for the
        selection history, zeros for the exclusion history, and NaN for the losses and the loss
        selection thresholds).

        The program exits with an error message if the training data set is empty. An empty test
        data set only leads to a warning.

        name:         names of the structures (have to be hashable, e.g., strings)
        n_structures: number of structures

        Modify: training_state

        Return: train, test, n_structures_train, n_structures_test, assignment
                (train and test are index arrays, assignment is a list which contains 'train' or
                'test' for each structure)
        '''
        previous_state_exists = len(self.training_state) > 0

        # assign previous training state to current data
        if previous_state_exists:
            name_old_train = [self.training_state[0][i] for i in self.training_state[1]]
            # hash-based lookups avoid searching the name lists for every structure; for
            # repeated names the first position is kept in agreement with list.index
            index_old_train = {}
            for i, N in enumerate(name_old_train):
                index_old_train.setdefault(N, i)
            name_old = set(self.training_state[0])
            old_train = [i for i in range(n_structures) if name[i] in index_old_train]
            new = [i for i in range(n_structures) if name[i] not in name_old]
            n_structures_old_train = len(old_train)
            n_structures_new = len(new)
        else:
            new = np.arange(n_structures)
            n_structures_old_train = 0
            n_structures_new = n_structures

        # determine number of training and test structures
        n_structures_new_train = int(n_structures_new * self.settings.training_fraction)
        n_structures_train = n_structures_old_train + n_structures_new_train
        n_structures_test = n_structures - n_structures_train

        # check if training and test structures exist
        if n_structures_train < 1:
            print('ERROR: The training data set does not contain any structures.',
                  '\nPlease increase the training fraction or increase the number of',
                  'reference structures in the episodic memory.')
            sys.exit()
        if n_structures_test < 1:
            print('WARNING: The test data set does not contain any structures.\n')

        # determine splitting
        selection = np.zeros(n_structures, dtype=bool)
        if n_structures_old_train > 0:
            selection[old_train] = True
        if n_structures_new_train > 0:
            # mark the requested number of new structures and shuffle the marks randomly
            selection_new = np.ones(n_structures_new, dtype=bool)
            selection_new[n_structures_new_train:] = False
            self.rng.shuffle(selection_new)
            selection[new] = selection_new

        # create arrays of training and test data indices
        train = np.arange(n_structures)[selection]
        test = np.arange(n_structures)[np.invert(selection)]

        # create list of assignments
        assignment = ['train' if s else 'test' for s in selection]

        # complete training state
        if previous_state_exists:
            # positions in the new training data which belong to previous training structures
            order_new_train = np.zeros(n_structures, dtype=bool)
            order_new_train[old_train] = True
            order_new_train = order_new_train[train]
            # positions of these structures in the previous training data (same order)
            order_old_train = [index_old_train[N] for N in name if N in index_old_train]
            selection_hist = np.ones((n_structures_train, 2))
            selection_hist[order_new_train] = self.training_state[2][order_old_train]
            exclusion_hist = np.zeros(n_structures_train, dtype=int)
            exclusion_hist[order_new_train] = self.training_state[3][order_old_train]
            loss_old = np.empty((n_structures_train, 2))
            loss_old.fill(np.nan)
            loss_old[order_new_train] = self.training_state[4][order_old_train]
            loss_selection_thresholds = np.empty((n_structures_train, 2, 4))
            loss_selection_thresholds.fill(np.nan)
            loss_selection_thresholds[order_new_train] = self.training_state[5][order_old_train]
            self.training_state = [name, train, selection_hist, exclusion_hist, loss_old,
                                   loss_selection_thresholds, self.training_state[6],
                                   self.training_state[7]]

        return train, test, n_structures_train, n_structures_test, assignment

####################################################################################################

    def fit_model(self, train, test, elements_int_sys, descriptors_torch,
                  descriptor_derivatives_torch, neighbor_indices,
                  descriptor_neighbor_derivatives_torch_env, neighbor_indices_env, energy_torch,
                  forces_torch, n_structures_train, n_structures_test, n_atoms_active, n_atoms_sys,
                  MM_gradients, name):
        '''
        Implementation: energy_unit: eV, Hartree, kJ_mol
                        length_unit: Angstrom, Bohr, pm
                        selection_scheme: lADS, random

        Output: Optimized model parameters

        Modify: model, training_state

        Return: selection
        '''
        # implemented energy and length units
        energy_unit_list = ['eV', 'Hartree', 'kJ_mol']
        length_unit_list = ['Angstrom', 'Bohr', 'pm']

        # implemented selection schemes
        selection_scheme_list = ['lADS', 'random']

        # implemented selection measures
        selection_measure_list = ['E+F_losses', 'total_loss']

        # implemented late data schemes
        late_data_scheme_list = ['None', 'bottom', 'random', 'top']

        # check existence of requested energy and length units
        if (
                self.settings.energy_unit not in energy_unit_list
                or self.settings.length_unit not in length_unit_list):
            print('ERROR: Combination of energy unit {0} and length unit {1}'
                  .format(self.settings.energy_unit, self.settings.length_unit),
                  'is not yet implemented for the fit output.',
                  '\nPlease use one of the following combinations:')
            for ene_unit in energy_unit_list:
                for len_unit in length_unit_list:
                    print('{0} and {1}'.format(ene_unit, len_unit))
            sys.exit()

        # set energy_preoptimization to False if forces are not fitted
        if self.settings.energy_preoptimization_step and not self.settings.fit_forces:
            print('WARNING: Energy preoptimization step is not sensible if force fitting is disabled.',
                  '\nTherefore, energy preoptimization step is disabled.\n')
            self.settings.energy_preoptimization_step = False

        # set loss_E_scaling to 1.0 if forces are not fitted
        if not self.settings.fit_forces and self.settings.loss_E_scaling != 1.0:
            print('WARNING: loss_E_scaling is set to 1 since forces are not fitted.\n')
            self.settings.loss_E_scaling = 1.0

        # determine number of structures used in an epoch
        if isinstance(self.settings.fit_fraction, float):
            n_structures_fit = int(self.settings.fit_fraction * n_structures_train)
            if n_structures_fit < 1:
                print('WARNING: The number of structures to be fitted in an epoch was set to zero',
                      'due to a too small fit fraction.',
                      '\nTherefore, the fit fraction has been increased from {0} to {1}.\n'
                      .format(self.settings.fit_fraction, round(1.0 / n_structures_train, 6)))
                n_structures_fit = 1
        else:
            n_structures_fit = min(self.settings.fit_fraction, n_structures_train)

        # initialize selection scheme 'lADS'
        n_previous_epochs = 0
        if self.settings.selection_scheme == 'lADS':
            # not implemented selection measure
            if self.settings.selection_measure not in selection_measure_list:
                print('ERROR: Selection measure {0} is not yet implemented.'
                      .format(self.settings.selection_measure),
                      '\nPlease use one of the following selection measures:')
                for sel_mea in selection_measure_list:
                    print('{0}'.format(sel_mea))
                sys.exit()
            if self.settings.selection_measure == 'E+F_losses':
                if not self.settings.fit_forces:
                    print('ERROR: Selection measure E+F_losses requires fitting of forces',
                          '(i.e., fit_forces = True).')
                    sys.exit()
            if self.settings.selection_range[0] >= 1.0 or self.settings.selection_range[1] <= 1.0:
                print('ERROR: Selection range requires [selection_min < 1.0, selection_max > 1.0]')
                sys.exit()
            selection_min = self.settings.selection_range[0]
            selection_max = self.settings.selection_range[1]
            decrease_factor = self.settings.selection_range[0]**(
                1.0 / self.settings.selection_strikes[0])
            increase_factor = self.settings.selection_range[1]**(
                1.0 / self.settings.selection_strikes[1])
            decrease_factor_small = self.settings.selection_range[0]**(
                1.0 / self.settings.selection_small_strikes[0])
            increase_factor_small = self.settings.selection_range[1]**(
                1.0 / self.settings.selection_small_strikes[1])
            n_redundant_max = max(1, int(self.settings.fraction_redundant_max * n_structures_fit
                                         + 0.999999))
            fraction_change = self.settings.fraction_good_max / float(
                self.settings.n_fraction_intervals)
            if not self.training_state:
                selection_hist = np.ones((n_structures_train, 2))
                exclusion_hist = np.zeros(n_structures_train, dtype=int)
                loss_old = np.empty((n_structures_train, 2))
                loss_old.fill(np.nan)
                loss_old_mean = np.array([np.inf, np.inf])
                loss_selection_thresholds = np.empty((n_structures_train, 2, 4))
                loss_selection_thresholds.fill(np.nan)
                fraction_good = 0.0
                n_good = 0
                n_bad = n_structures_fit
            # restart previous training state
            else:
                selection_hist = self.training_state[2]
                exclusion_hist = self.training_state[3]
                loss_old = self.training_state[4]
                with warnings.catch_warnings():
                    warnings.filterwarnings('ignore', message='Mean of empty slice')
                    loss_old_mean = np.nanmean(loss_old[selection_hist[:, 0] > 0.0], axis=0)
                loss_old_mean[np.isnan(loss_old_mean)] = np.inf
                loss_selection_thresholds = self.training_state[5]
                # reset fraction good to zero for transfer learning
                if self.settings.transfer_learning:
                    fraction_good = 0.0
                    n_good = 0
                    n_bad = n_structures_fit
                else:
                    fraction_good = self.training_state[6]
                    n_good = int(fraction_good * n_structures_fit)
                    n_bad = n_structures_fit - n_good
                n_previous_epochs = self.training_state[7]
            # prepare backtracking of gradients for deselected training structures
            if self.settings.gradient_backtracking:
                if self.settings.optimizer != 'CoRe':
                    print('WARNING: Gradient backtracking for deselected training structures',
                          'is only available for the CoRe optimizer.',
                          '\nTherefore, gradient backtracking is disabled.\n')
                    self.settings.gradient_backtracking = False
                else:
                    selection_hist_factors = selection_max / selection_max**(np.arange(
                        self.settings.selection_strikes[1]) / self.settings.selection_strikes[1])
                    selection_max_exclusion = selection_max**(
                        self.settings.exclusion_strikes / self.settings.selection_strikes[1])
                    exclusion_hist_factors = selection_max_exclusion / selection_max_exclusion**(
                        np.arange(self.settings.exclusion_strikes) / self.settings.selection_strikes[1])
                    n_groups = 2 * self.n_element_types * (1 * self.settings.scale_shift_layer + len(
                        self.settings.n_neurons_hidden_layers) + 1)
            if self.settings.stationary_point_prob_factor > 1.0:
                mean_atomic_force_len = np.array([float(torch.mean(torch.sqrt(torch.sum(
                    forces_torch[n]**2, axis=1)))) for n in train])
                a = (self.settings.stationary_point_prob_factor - 1.0) / self.settings.atomic_force_max**2
                probability_forces = 1.0 / (a * mean_atomic_force_len**2 + 1.0)
            else:
                probability_forces = np.ones(n_structures_train)
            # initialize late data scheme
            if self.settings.late_data_scheme == 'random':
                late = np.zeros(n_structures_train, dtype=bool)
                late[:int(n_structures_train * self.settings.late_data_fraction)] = True
                self.rng.shuffle(late)
                selection_hist[late] = -4.0 * np.ones(2)
            elif self.settings.late_data_scheme == 'bottom':
                late = np.ones(n_structures_train, dtype=bool)
                late[:int(n_structures_train * (1.0 - self.settings.late_data_fraction))] = False
                selection_hist[late] = -4.0 * np.ones(2)
            elif self.settings.late_data_scheme == 'top':
                late = np.zeros(n_structures_train, dtype=bool)
                late[:int(n_structures_train * self.settings.late_data_fraction)] = True
                selection_hist[late] = -4.0 * np.ones(2)
            elif self.settings.late_data_scheme != 'None':
                print('ERROR: Late data scheme {0} is not yet implemented.'
                      .format(self.settings.late_data_scheme),
                      '\nPlease use one of the following late data schemes:')
                for lat_dat_sch in late_data_scheme_list:
                    print('{0}'.format(lat_dat_sch))
                sys.exit()

        # initialize selection scheme random
        elif self.settings.selection_scheme == 'random':
            selection = np.zeros(n_structures_train, dtype=bool)
            selection[:n_structures_fit] = True

        # not implemented selection scheme
        else:
            print('ERROR: Selection scheme {0} is not yet implemented.'
                  .format(self.settings.selection_scheme),
                  '\nPlease use one of the following selection schemes:')
            for sel_sch in selection_scheme_list:
                print('{0}'.format(sel_sch))
            sys.exit()

        # set writing interval for weights
        if self.settings.write_weights_interval is None:
            write_weights_interval = np.nan
        else:
            write_weights_interval = self.settings.write_weights_interval

        # set memory evaluation epoch
        if self.settings.memory_evaluation:
            memory_evaluation_epoch = n_previous_epochs + self.settings.n_epochs - 1
        else:
            memory_evaluation_epoch = np.nan

        # define energy units and write header of fitting progress
        E_unit = {'eV': 'eV    ', 'Hartree': 'E_h   ', 'kJ_mol': 'kJ/mol'}
        F_unit = {'eV': {'Angstrom': 'eV/Ang    ', 'Bohr': 'eV/a_0    ', 'pm': 'eV/pm     '},
                  'Hartree': {'Angstrom': 'E_h/Ang   ', 'Bohr': 'E_h/a_0   ', 'pm': 'E_h/pm    '},
                  'kJ_mol': {'Angstrom': 'kJ/mol/Ang', 'Bohr': 'kJ/mol/a_0', 'pm': 'kJ/mol/pm '}}
        if self.settings.QMMM:
            print('Epoch | Loss(E_train) | Loss(F_train) | Time / s |',
                  'RMSE(E_train) | RMSE(E_test)  | RMSE(F_train) | RMSE(F_test)  |RMSE(train_env)|RMSE(test_env)')
            print('      |               |               |          |',
                  '/ {0}      | / {0}      | / {1}  | / {1}  | / {1}  | / {1}'.format(
                      E_unit[self.settings.energy_unit],
                      F_unit[self.settings.energy_unit][self.settings.length_unit]))
        else:
            print('Epoch | Loss(E_train) | Loss(F_train) | Time / s |',
                  'RMSE(E_train) | RMSE(E_test)  | RMSE(F_train) | RMSE(F_test)')
            print('      |               |               |          |',
                  '/ {0}      | / {0}      | / {1}  | / {1}'.format(
                      E_unit[self.settings.energy_unit],
                      F_unit[self.settings.energy_unit][self.settings.length_unit]))
        print('————————————————————————————————————————————————————————'
              + '————————————————————————————————————————————————————————', end='')
        if self.settings.QMMM:
            print('————————————————————————————————')
        else:
            print('')

        # print assignment of data for selection scheme 'lADS'
        if self.settings.selection_scheme == 'lADS':
            self.print_results_lADS(selection_hist[:, 0])

        # perform training epochs
        n_atoms_train = n_atoms_sys[train]
        n_atoms_test_sum = np.sum(n_atoms_sys[test])
        if self.settings.QMMM:
            n_atoms_train_env = n_atoms_active[train] - n_atoms_sys[train]
            n_atoms_test_sum_env = np.sum(n_atoms_active[test] - n_atoms_sys[test])
        for epoch in range(n_previous_epochs, n_previous_epochs + self.settings.n_epochs):
            time_start = time()
            # choice of training data for selection scheme 'lADS'
            if self.settings.selection_scheme == 'lADS':
                # update selection_hist according to late data scheme and reset n_structures_fit
                # in case of updates
                if self.settings.late_data_scheme != 'None' and (epoch - n_previous_epochs
                                                                 == self.settings.late_data_epoch):
                    selection_hist[late] = 1.0 * np.ones(2)
                    if isinstance(self.settings.fit_fraction, float):
                        n_structures_fit = max(1, int(self.settings.fit_fraction * n_structures_train))
                    else:
                        n_structures_fit = min(self.settings.fit_fraction, n_structures_train)
                    n_good = int(fraction_good * n_structures_fit)
                    n_bad = n_structures_fit - n_good
                # update number of structures used in an epoch if required
                n_structures_fit_max = len(selection_hist[:, 0][selection_hist[:, 0] > 0.0])
                if n_structures_fit > n_structures_fit_max:
                    if n_structures_fit_max < 1:
                        print('ERROR: All structures of the training data set have been disregarded.')
                        sys.exit()
                    print('WARNING: The number of fitted structures per epoch has been decreased to',
                          '{0} due to disregard of structures.\n'.format(n_structures_fit_max))
                    n_structures_fit = n_structures_fit_max
                    n_redundant_max = max(1, int(self.settings.fraction_redundant_max * n_structures_fit
                                                 + 0.999999))
                    n_good = int(fraction_good * n_structures_fit)
                    n_bad = n_structures_fit - n_good
                # choose bad structures
                with warnings.catch_warnings():
                    warnings.filterwarnings('ignore', message='All-NaN slice encountered')
                    loss_old_max = np.nanmax(loss_old[selection_hist[:, 0] > 0.0], axis=0)
                    if np.nanmax(loss_old_max) <= 0.0:
                        if np.max(loss_old[selection_hist[:, 0] > 0.0, 0]) <= 0.0:
                            print('All selected training structures yield a loss of zero. Training is stopped.')
                            break
                if self.settings.selection_measure == 'E+F_losses':
                    if loss_old_max[0] > 0.0:
                        probability_loss = np.clip(loss_old[:, 0] / loss_old_max[0],
                                                   loss_old_mean[0] / loss_old_max[0], None)
                    else:
                        probability_loss = np.ones(n_structures_train)
                    if loss_old_max[1] > 0.0:
                        probability_loss *= np.clip(loss_old[:, 1] / loss_old_max[1],
                                                    loss_old_mean[1] / loss_old_max[1], None)
                    probability_loss = np.sqrt(probability_loss)
                elif self.settings.selection_measure == 'total_loss':
                    if loss_old_max[0] > 0.0:
                        probability_loss = np.clip(loss_old[:, 0] / loss_old_max[0],
                                                   loss_old_mean[0] / loss_old_max[0], None)
                    else:
                        probability_loss = np.ones(n_structures_train)
                probability_bad = (selection_hist[:, 0] * np.absolute(selection_hist[:, 1])
                                   * probability_forces * probability_loss)
                with warnings.catch_warnings():
                    warnings.filterwarnings('ignore', message='All-NaN slice encountered')
                    probability_bad_max = np.nanmax(np.array([selection_max, np.nanmax(probability_bad)]))
                probability_bad[np.isnan(probability_bad)] = probability_bad_max
                probability_bad[selection_hist[:, 0] <= 0.0] = 0.0
                probability_bad /= np.sum(probability_bad)
                fit = self.rng.choice(np.arange(n_structures_train), size=n_bad, replace=False,
                                      p=probability_bad)
                # choose good structures
                if n_good > 0:
                    option = np.setdiff1d(np.arange(n_structures_train), fit, assume_unique=True)
                    probability_good = (
                        selection_hist[:, 0][option] * np.absolute(selection_hist[:, 1][option])
                        * probability_forces[option] * (1.0 - probability_loss[option]))
                    with warnings.catch_warnings():
                        warnings.filterwarnings('ignore', message='All-NaN slice encountered')
                        probability_good_min = np.nanmin(np.array([
                            selection_min, np.nanmin(probability_good[probability_good > 0.0])]))
                    probability_good[probability_loss[option] >= 1.0] = probability_good_min
                    probability_good[np.isnan(probability_good)] = probability_good_min
                    probability_good[selection_hist[:, 0][option] <= 0.0] = 0.0
                    probability_good /= np.sum(probability_good)
                    fit = np.concatenate((fit, self.rng.choice(
                        option, size=n_good, replace=False, p=probability_good)))

            # choice of training data for selection scheme random
            elif self.settings.selection_scheme == 'random':
                self.rng.shuffle(selection)
                fit = np.arange(n_structures_train)[selection]

            # initialize losses and SEs
            E_loss = 0.0
            if self.settings.fit_forces:
                F_loss = 0.0
            if self.settings.selection_scheme == 'lADS':
                loss_new_fit = np.empty((n_structures_fit, 2))
                loss_new_fit.fill(np.nan)
            if epoch % self.settings.RMSE_interval == 0:
                E_SE_train = 0.0
                E_SE_test = 0.0
                if self.settings.fit_forces:
                    F_SE_train = 0.0
                    F_SE_test = 0.0
                    if self.settings.QMMM:
                        F_SE_train_env = 0.0
                        F_SE_test_env = 0.0
            # calculate sum of model parameter gradients for fit structures
            n_atoms_fit_sum = np.sum(n_atoms_train[fit])
            if self.settings.QMMM:
                n_atoms_fit_sum_env = np.sum(n_atoms_train_env[fit])

            # energy preoptimization step
            if self.settings.energy_preoptimization_step:
                for m in range(n_structures_fit):
                    n = train[fit[m]]
                    # predict energy of fit structures
                    energy_prediction_torch, forces_prediction_torch = self.calculate_energy_forces(
                        n, elements_int_sys, descriptors_torch, descriptor_derivatives_torch,
                        neighbor_indices, n_atoms_sys, descriptor_neighbor_derivatives_torch_env,
                        neighbor_indices_env, n_atoms_active, MM_gradients, name, calc_forces=False)
                    # calculate energy loss
                    loss_E = self.energy_loss_function(
                        energy_prediction_torch / n_atoms_sys[n], energy_torch[n] / n_atoms_sys[n]) * (
                            self.settings.loss_E_scaling / n_structures_fit)
                    # calculate energy loss gradients
                    loss_E.backward()
                # optimize weights
                self.optimizer.step()
                # reset gradients to zero
                self.optimizer.zero_grad(set_to_none=True)

            # optimization step
            for m in range(n_structures_fit):
                n = train[fit[m]]
                # predict energy and forces of fit structures
                energy_prediction_torch, forces_prediction_torch = self.calculate_energy_forces(
                    n, elements_int_sys, descriptors_torch, descriptor_derivatives_torch,
                    neighbor_indices, n_atoms_sys, descriptor_neighbor_derivatives_torch_env,
                    neighbor_indices_env, n_atoms_active, MM_gradients, name,
                    calc_forces=self.settings.fit_forces)
                # calculate loss
                loss_E = self.energy_loss_function(
                    energy_prediction_torch / n_atoms_sys[n], energy_torch[n] / n_atoms_sys[n]) * (
                        self.settings.loss_E_scaling / n_structures_fit)
                if self.settings.fit_forces:
                    loss_F = self.forces_loss_function(forces_prediction_torch, forces_torch[n]) / (
                        3 * n_atoms_fit_sum)
                    loss = loss_E + loss_F
                else:
                    loss = loss_E
                # calculate loss gradients
                loss.backward()
                # get losses for printing
                with torch.no_grad():
                    E_loss += loss_E.item()
                    if self.settings.fit_forces:
                        F_loss += loss_F.item()
                    # get losses and loss selection thresholds for selection scheme 'lADS'
                    if self.settings.selection_scheme == 'lADS':
                        if self.settings.selection_measure == 'E+F_losses':
                            loss_new_fit[m, 0] = loss_E.item() * n_structures_fit
                            loss_new_fit[m, 1] = loss_F.item() * n_atoms_fit_sum / n_atoms_sys[n]
                            for i in range(4):
                                loss_selection_thresholds[fit[m], 0, i] = self.energy_loss_function(
                                    (self.settings.selection_thresholds[i] * energy_prediction_torch
                                     / n_atoms_sys[n]),
                                    (self.settings.selection_thresholds[i] * energy_torch[n]
                                     / n_atoms_sys[n])).item() * self.settings.loss_E_scaling
                                loss_selection_thresholds[fit[m], 1, i] = self.forces_loss_function(
                                    self.settings.selection_thresholds[i] * forces_prediction_torch,
                                    self.settings.selection_thresholds[i] * forces_torch[n]).item() / (
                                        3 * n_atoms_sys[n])
                        elif self.settings.selection_measure == 'total_loss':
                            if self.settings.fit_forces:
                                loss_new_fit[m, 0] = loss_E.item() * n_structures_fit + loss_F.item() * (
                                    n_atoms_fit_sum / n_atoms_sys[n])
                            else:
                                loss_new_fit[m, 0] = loss_E.item() * n_structures_fit
                            for i in range(4):
                                loss_selection_thresholds[fit[m], 0, i] = self.energy_loss_function(
                                    (self.settings.selection_thresholds[i] * energy_prediction_torch
                                     / n_atoms_sys[n]),
                                    (self.settings.selection_thresholds[i] * energy_torch[n]
                                     / n_atoms_sys[n])).item() * self.settings.loss_E_scaling
                                if self.settings.fit_forces:
                                    loss_selection_thresholds[fit[m], 0, i] += self.forces_loss_function(
                                        self.settings.selection_thresholds[i] * forces_prediction_torch,
                                        self.settings.selection_thresholds[i] * forces_torch[n]).item() / (
                                            3 * n_atoms_sys[n])
                    # calculate SEs of fit structures
                    if epoch % self.settings.RMSE_interval == 0:
                        E_SE_train += float(((
                            energy_prediction_torch[0] - energy_torch[n][0]) / n_atoms_sys[n])**2)
                        if self.settings.fit_forces:
                            F_SE_train += calculate_F_SE(forces_prediction_torch[:n_atoms_sys[n]],
                                                         forces_torch[n][:n_atoms_sys[n]])
                            if self.settings.QMMM:
                                F_SE_train_env += calculate_F_SE(forces_prediction_torch[n_atoms_sys[n]:],
                                                                 forces_torch[n][n_atoms_sys[n]:])
            # calculate RMSEs
            if epoch % self.settings.RMSE_interval == 0:
                for n in test:
                    # predict energy and forces of test structures
                    energy_prediction_torch, forces_prediction_torch = self.calculate_energy_forces(
                        n, elements_int_sys, descriptors_torch, descriptor_derivatives_torch,
                        neighbor_indices, n_atoms_sys, descriptor_neighbor_derivatives_torch_env,
                        neighbor_indices_env, n_atoms_active, MM_gradients, name, create_graph=False,
                        calc_forces=self.settings.fit_forces)
                    with torch.no_grad():
                        E_SE_test += float(((
                            energy_prediction_torch[0] - energy_torch[n][0]) / n_atoms_sys[n])**2)
                        if self.settings.fit_forces:
                            F_SE_test += calculate_F_SE(forces_prediction_torch[:n_atoms_sys[n]],
                                                        forces_torch[n][:n_atoms_sys[n]])
                            if self.settings.QMMM:
                                F_SE_test_env += calculate_F_SE(forces_prediction_torch[n_atoms_sys[n]:],
                                                                forces_torch[n][n_atoms_sys[n]:])
                # calculate RMSEs
                E_RMSE_train = np.sqrt(E_SE_train / n_structures_fit)
                if n_structures_test > 0:
                    E_RMSE_test = np.sqrt(E_SE_test / n_structures_test)
                else:
                    E_RMSE_test = np.nan
                if self.settings.fit_forces:
                    F_RMSE_train = np.sqrt(F_SE_train / (3 * n_atoms_fit_sum))
                    if self.settings.QMMM:
                        if n_atoms_fit_sum_env > 0:
                            F_RMSE_train_env = np.sqrt(F_SE_train_env / (3 * n_atoms_fit_sum_env))
                        else:
                            F_RMSE_train_env = np.nan
                    if n_structures_test > 0:
                        F_RMSE_test = np.sqrt(F_SE_test / (3 * n_atoms_test_sum))
                    else:
                        F_RMSE_test = np.nan
                    if self.settings.QMMM:
                        if n_atoms_test_sum_env > 0:
                            F_RMSE_test_env = np.sqrt(F_SE_test_env / (3 * n_atoms_test_sum_env))
                        else:
                            F_RMSE_test_env = np.nan
                else:
                    F_RMSE_test = np.nan
                    if self.settings.QMMM:
                        F_RMSE_test_env = np.nan
                if self.settings.QMMM:
                    self.test_RMSEs = [round(E_RMSE_test, 8), round(F_RMSE_test, 8),
                                       round(F_RMSE_test_env, 8)]
                else:
                    self.test_RMSEs = [round(E_RMSE_test, 8), round(F_RMSE_test, 8)]

            # update exclusion_hist and selection_hist
            if self.settings.selection_scheme == 'lADS':
                # prepare backtracking of gradients for deselected training structures
                if self.settings.gradient_backtracking:
                    selection_hist_old = deepcopy(selection_hist[:, 0])
                # prepare update of exclusion_hist and selection_hist,
                # whereby loss_selection_thresholds_mean includes the loss of redundant data
                with warnings.catch_warnings():
                    warnings.filterwarnings('ignore', message='Mean of empty slice')
                    loss_selection_thresholds_mean = np.nanmean(
                        loss_selection_thresholds[selection_hist[:, 0] > -1.1], axis=0)
                # update exclusion_hist and selection_hist employing energy and force losses
                if self.settings.selection_measure == 'E+F_losses':
                    selection_hist, exclusion_hist = update_hist_E_F_losses(
                        fit, selection_hist, exclusion_hist, loss_old[fit], loss_new_fit,
                        loss_selection_thresholds_mean, decrease_factor, decrease_factor_small,
                        increase_factor_small, increase_factor, self.settings.exclusion_strikes,
                        selection_min, selection_max, n_redundant_max, n_structures_train)
                # update exclusion_hist and selection_hist employing the total loss
                elif self.settings.selection_measure == 'total_loss':
                    selection_hist[:, 0], exclusion_hist = update_hist_total_loss(
                        fit, selection_hist[:, 0], exclusion_hist, loss_old[fit, 0],
                        loss_new_fit[:, 0], loss_selection_thresholds_mean[0], decrease_factor,
                        decrease_factor_small, increase_factor_small, increase_factor,
                        self.settings.exclusion_strikes, selection_min, selection_max,
                        n_redundant_max, n_structures_train)
                # update fraction_good and update n_good and n_bad accordingly,
                # whereby loss_new_mean includes the loss of redundant data
                loss_old[fit] = loss_new_fit
                with warnings.catch_warnings():
                    warnings.filterwarnings('ignore', message='Mean of empty slice')
                    loss_new_mean = np.nanmean(loss_old[selection_hist[:, 0] > -1.1], axis=0)
                if np.any(loss_new_mean > loss_old_mean):
                    fraction_good = np.clip(fraction_good + fraction_change, 0.0,
                                            self.settings.fraction_good_max)
                else:
                    fraction_good = np.clip(fraction_good - fraction_change, 0.0,
                                            self.settings.fraction_good_max)
                n_good = int(fraction_good * n_structures_fit)
                n_bad = n_structures_fit - n_good
                loss_old_mean = loss_new_mean

            # write generalization setting and generalization
            if epoch % write_weights_interval == 0:
                if self.settings.selection_scheme == 'lADS':
                    self.training_state = [
                        name, train, selection_hist, exclusion_hist, loss_old,
                        loss_selection_thresholds, fraction_good, epoch]
                self.write_generalization_setting()
                self.write_generalization()

            # optimize weights
            self.optimizer.step()
            # measure maximal memory usage
            if epoch == memory_evaluation_epoch:
                self.max_memory_usage = self.memory_usage(
                    interval=1e-6 - 1e-9, timeout=1e-6, max_usage=True)
            # reset gradients to zero
            self.optimizer.zero_grad(set_to_none=True)

            # backtracking of gradients for deselected training structures
            if self.settings.selection_scheme == 'lADS':
                if self.settings.gradient_backtracking:
                    # deselection based on selection_hist
                    self.backtracking_gradients(
                        selection_hist, selection_hist_old, [-2.1, -1.9], selection_hist_factors,
                        epoch, probability_forces, probability_loss, probability_bad_max,
                        fraction_good, n_structures_train, n_structures_fit, n_groups,
                        elements_int_sys, descriptors_torch, descriptor_derivatives_torch,
                        neighbor_indices, n_atoms_sys, n_atoms_fit_sum,
                        descriptor_neighbor_derivatives_torch_env, neighbor_indices_env,
                        n_atoms_active, MM_gradients, energy_torch, forces_torch, name)
                    # deselection based on exclusion_hist
                    self.backtracking_gradients(
                        selection_hist, selection_hist_old, [-3.1, -2.9], exclusion_hist_factors,
                        epoch, probability_forces, probability_loss, probability_bad_max,
                        fraction_good, n_structures_train, n_structures_fit, n_groups,
                        elements_int_sys, descriptors_torch, descriptor_derivatives_torch,
                        neighbor_indices, n_atoms_sys, n_atoms_fit_sum,
                        descriptor_neighbor_derivatives_torch_env, neighbor_indices_env,
                        n_atoms_active, MM_gradients, energy_torch, forces_torch, name)

            # print results of the epoch
            print('{0:5d} | {1:13.8f}'.format(epoch, round(E_loss, 8)), end='')
            if self.settings.fit_forces:
                print(' | {0:13.8f}'.format(round(F_loss, 8)), end='')
            else:
                print(' | {0}'.format('             '), end='')
            print(' | {0:8.3f}'.format(round(time() - time_start, 3)), end='')
            if epoch % self.settings.RMSE_interval == 0:
                print(' | {0:13.6f} | {1:13.6f}'.format(
                    round(E_RMSE_train * self.energy_conversion, 6),
                    round(E_RMSE_test * self.energy_conversion, 6)), end='')
            else:
                print(' | {0} | {0}'.format('             '), end='')
            if epoch % self.settings.RMSE_interval == 0 and self.settings.fit_forces:
                print(' | {0:13.6f} | {1:13.6f}'.format(
                    round(F_RMSE_train * self.force_conversion, 6),
                    round(F_RMSE_test * self.force_conversion, 6)), end='')
            else:
                print(' | {0} | {0}'.format('             '), end='')
            if self.settings.QMMM:
                if epoch % self.settings.RMSE_interval == 0 and self.settings.fit_forces:
                    print(' | {0:13.6f} | {1:13.6f}'.format(
                        round(F_RMSE_train_env * self.force_conversion, 6),
                        round(F_RMSE_test_env * self.force_conversion, 6)), end='')
                else:
                    print(' | {0} | {0}'.format('             '), end='')
            print('')
            if self.settings.selection_scheme == 'lADS':
                self.print_results_lADS(selection_hist[:, 0])

        # initialize losses and SEs
        time_start = time()
        E_loss = 0.0
        F_loss = 0.0
        E_SE_train = 0.0
        E_SE_test = 0.0
        F_SE_train = 0.0
        F_SE_test = 0.0
        if self.settings.QMMM:
            F_SE_train_env = 0.0
            F_SE_test_env = 0.0
        # perform final epoch for all activated training structures
        n_atoms_train_sum_final_env = 0
        if self.settings.selection_scheme == 'lADS':
            # update selection_hist according to late data scheme
            if self.settings.late_data_scheme != 'None' and self.settings.n_epochs == self.settings.late_data_epoch:
                selection_hist[late] = 1.0 * np.ones(2)
            train_final = train[selection_hist[:, 0] > 0.0]
            n_structures_train_final = len(train_final)
            n_atoms_train_sum_final = np.sum(n_atoms_train[selection_hist[:, 0] > 0.0])
            if self.settings.QMMM:
                n_atoms_train_sum_final_env = np.sum(n_atoms_train_env[selection_hist[:, 0] > 0.0])
        else:
            train_final = train
            n_structures_train_final = n_structures_train
            n_atoms_train_sum_final = np.sum(n_atoms_train)
            if self.settings.QMMM:
                n_atoms_train_sum_final_env = np.sum(n_atoms_train_env)
        for n in train_final:
            # predict energy and forces of all activated training structures
            energy_prediction_torch, forces_prediction_torch = self.calculate_energy_forces(
                n, elements_int_sys, descriptors_torch, descriptor_derivatives_torch,
                neighbor_indices, n_atoms_sys, descriptor_neighbor_derivatives_torch_env,
                neighbor_indices_env, n_atoms_active, MM_gradients, name, create_graph=False)
            # calculate losses for printing
            with torch.no_grad():
                E_loss += self.energy_loss_function(
                    energy_prediction_torch / n_atoms_sys[n], energy_torch[n] / n_atoms_sys[n]).item() * (
                        self.settings.loss_E_scaling / n_structures_train_final)
                F_loss += self.forces_loss_function(forces_prediction_torch, forces_torch[n]).item() / (
                    3 * n_atoms_train_sum_final)
                # calculate SEs of all activated training structures
                E_SE_train += float(((
                    energy_prediction_torch[0] - energy_torch[n][0]) / n_atoms_sys[n])**2)
                F_SE_train += calculate_F_SE(forces_prediction_torch[:n_atoms_sys[n]],
                                             forces_torch[n][:n_atoms_sys[n]])
                if self.settings.QMMM:
                    F_SE_train_env += calculate_F_SE(forces_prediction_torch[n_atoms_sys[n]:],
                                                     forces_torch[n][n_atoms_sys[n]:])
        # calculate RMSEs
        for n in test:
            # predict energy and forces of test structures
            energy_prediction_torch, forces_prediction_torch = self.calculate_energy_forces(
                n, elements_int_sys, descriptors_torch, descriptor_derivatives_torch,
                neighbor_indices, n_atoms_sys, descriptor_neighbor_derivatives_torch_env,
                neighbor_indices_env, n_atoms_active, MM_gradients, name, create_graph=False)
            # calculate SEs of test structures
            with torch.no_grad():
                E_SE_test += float(((
                    energy_prediction_torch[0] - energy_torch[n][0]) / n_atoms_sys[n])**2)
                F_SE_test += calculate_F_SE(forces_prediction_torch[:n_atoms_sys[n]],
                                            forces_torch[n][:n_atoms_sys[n]])
                if self.settings.QMMM:
                    F_SE_test_env += calculate_F_SE(forces_prediction_torch[n_atoms_sys[n]:],
                                                    forces_torch[n][n_atoms_sys[n]:])
        # calculate RMSEs
        E_RMSE_train = np.sqrt(E_SE_train / n_structures_train_final)
        F_RMSE_train = np.sqrt(F_SE_train / (3 * n_atoms_train_sum_final))
        if n_structures_test > 0:
            E_RMSE_test = np.sqrt(E_SE_test / n_structures_test)
            F_RMSE_test = np.sqrt(F_SE_test / (3 * n_atoms_test_sum))
        else:
            E_RMSE_test = np.nan
            F_RMSE_test = np.nan
        if self.settings.QMMM:
            if n_atoms_train_sum_final_env > 0:
                F_RMSE_train_env = np.sqrt(F_SE_train_env / (3 * n_atoms_train_sum_final_env))
            else:
                F_RMSE_train_env = np.nan
            if n_atoms_test_sum_env > 0:
                F_RMSE_test_env = np.sqrt(F_SE_test_env / (3 * n_atoms_test_sum_env))
            else:
                F_RMSE_test_env = np.nan
            self.test_RMSEs = [round(E_RMSE_test, 8), round(F_RMSE_test, 8),
                               round(F_RMSE_test_env, 8)]
        else:
            self.test_RMSEs = [round(E_RMSE_test, 8), round(F_RMSE_test, 8)]

        # print results of final epoch
        print('Final | {0:13.8f} | {1:13.8f} | {2:8.3f} | {3:13.6f} | {4:13.6f} | {5:13.6f} | {6:13.6f}'
              .format(round(E_loss, 8), round(F_loss, 8), round(time() - time_start, 3),
                      round(E_RMSE_train * self.energy_conversion, 6),
                      round(E_RMSE_test * self.energy_conversion, 6),
                      round(F_RMSE_train * self.force_conversion, 6),
                      round(F_RMSE_test * self.force_conversion, 6)), end='')
        if self.settings.QMMM:
            print(' | {0:13.6f} | {1:13.6f}'.format(round(F_RMSE_train_env * self.force_conversion, 6),
                                                    round(F_RMSE_test_env * self.force_conversion, 6)))
        else:
            print('')
        if self.settings.selection_scheme == 'lADS':
            self.print_results_lADS(selection_hist[:, 0])
        print('')

        # get training state and training data selection assignments for selection scheme 'lADS'
        if self.settings.selection_scheme == 'lADS':
            self.training_state = [name, train, selection_hist, exclusion_hist, loss_old,
                                   loss_selection_thresholds, fraction_good,
                                   n_previous_epochs + self.settings.n_epochs]
            indices = np.arange(n_structures_train)
            selection = [
                indices[selection_hist[:, 0] > 0.0],
                indices[np.logical_and(np.less(-3.1, selection_hist[:, 0]),
                                       np.less(selection_hist[:, 0], -1.9))],
                indices[np.logical_and(np.less(-1.1, selection_hist[:, 0]),
                                       np.less(selection_hist[:, 0], -0.9))]]
        else:
            selection = []

        # measure maximal memory usage
        if memory_evaluation_epoch == -1:
            self.max_memory_usage = self.memory_usage(
                interval=1e-6 - 1e-9, timeout=1e-6, max_usage=True)

        return selection

####################################################################################################

    def calculate_energy_forces(self, n, elements_int_sys, descriptors_torch,
                                descriptor_derivatives_torch, neighbor_indices, n_atoms_sys,
                                descriptor_neighbor_derivatives_torch_env, neighbor_indices_env,
                                n_atoms_active, MM_gradients, name, create_graph=True,
                                calc_forces=True, model_index=0):
        '''
        Predict the energy and, if requested, the forces of structure n with model model_index.

        If settings.descriptor_on_disk is not 'None', the descriptor derivatives and neighbor
        indices are obtained from read_descriptors(name[n]) and replace the passed arguments.
        Otherwise, the passed containers are indexed by n. If calc_forces is False, an empty
        tensor of shape (0, 3) is returned in place of the forces.

        Return: energy_prediction_torch, forces_prediction_torch
        '''
        # descriptor_derivatives_torch and neighbor_indices are on disk
        stored_on_disk = self.settings.descriptor_on_disk != 'None'
        if stored_on_disk:
            (descriptor_derivatives_torch, neighbor_indices,
             descriptor_neighbor_derivatives_torch_env, neighbor_indices_env) = \
                self.read_descriptors(name[n])

        # predict energy of the structure
        energy_prediction_torch = self.model[model_index](
            elements_int_sys[n], descriptors_torch[n], n_atoms_sys[n])

        # no force prediction
        if not calc_forces:
            return energy_prediction_torch, torch.zeros((0, 3))

        # data read from disk are used as returned, data in memory are indexed by n
        if stored_on_disk:
            derivatives_0 = descriptor_derivatives_torch[0]
            derivatives_1 = descriptor_derivatives_torch[1]
            neighbors = neighbor_indices
        else:
            derivatives_0 = descriptor_derivatives_torch[0][n]
            derivatives_1 = descriptor_derivatives_torch[1][n]
            neighbors = neighbor_indices[n]

        # predict forces of the structure
        if self.settings.QMMM:
            if stored_on_disk:
                derivatives_env = descriptor_neighbor_derivatives_torch_env
                neighbors_env = neighbor_indices_env
            else:
                derivatives_env = descriptor_neighbor_derivatives_torch_env[n]
                neighbors_env = neighbor_indices_env[n]
            forces_prediction_torch = calculate_forces_QMMM(
                energy_prediction_torch, descriptors_torch[n], derivatives_0, derivatives_1,
                neighbors, n_atoms_sys[n], derivatives_env, neighbors_env, n_atoms_active[n],
                MM_gradients, create_graph=create_graph)
        else:
            forces_prediction_torch = calculate_forces(
                energy_prediction_torch, descriptors_torch[n], derivatives_0, derivatives_1,
                neighbors, n_atoms_sys[n], create_graph=create_graph)

        return energy_prediction_torch, forces_prediction_torch

####################################################################################################

    def read_descriptors(self, name):
        '''
        Load the descriptor data of one structure from settings.descriptor_disk_dir.

        The file read is <descriptor_disk_dir>/descriptors_<N>.pt, where N is name without its
        first 18 characters if settings.transfer_learning is set and name itself otherwise.

        Return: descriptor_derivatives_torch, neighbor_indices,
                descriptor_neighbor_derivatives_torch_env, neighbor_indices_env
        '''
        # for transfer learning, the first 18 characters of name are not part of the file name
        file_label = name[18:] if self.settings.transfer_learning else name
        descriptor_file = ''.join(
            [self.settings.descriptor_disk_dir, '/descriptors_', file_label, '.pt'])

        # read descriptor_derivatives_torch and neighbor_indices
        stored = torch.load(descriptor_file)
        entries = ('descriptor_derivatives_torch', 'neighbor_indices',
                   'descriptor_neighbor_derivatives_torch_env', 'neighbor_indices_env')

        return tuple(stored[entry] for entry in entries)

####################################################################################################

    def backtracking_gradients(self, selection_hist, selection_hist_old, hist_values, hist_factors,
                               epoch, probability_forces, probability_loss, probability_bad_max,
                               fraction_good, n_structures_train, n_structures_fit, n_groups,
                               elements_int_sys, descriptors_torch, descriptor_derivatives_torch,
                               neighbor_indices, n_atoms_sys, n_atoms_fit_sum,
                               descriptor_neighbor_derivatives_torch_env, neighbor_indices_env,
                               n_atoms_active, MM_gradients, energy_torch, forces_torch, name):
        '''
        Subtract gradient contributions of newly deselected training structures from the
        optimizer state entries prev_1 and prev_2.

        A structure is treated as newly deselected if selection_hist[:, 0] lies strictly between
        hist_values[0] and hist_values[1] and its selection_hist_old value is greater than 0.0.
        For these structures, the loss gradients are recomputed with the current model,
        weighted by backtracking factors derived from settings.betas, and subtracted from
        prev_1 (gradient) and prev_2 (squared gradient). Negative prev_2 entries are set to
        zero. If settings.energy_preoptimization_step is set, a second pass does the same
        for an energy-only loss. The parameter gradients are reset afterwards.

        Modify: optimizer (prev_1, prev_2)
        '''
        # determine new deselected structures
        # (the product of the boolean arrays acts as an element-wise logical and)
        backtracking_structures = np.arange(n_structures_train)[
            np.greater(selection_hist[:, 0], hist_values[0])
            * np.less(selection_hist[:, 0], hist_values[1])
            * np.greater(selection_hist_old, 0.0)]

        # determine backtracking steps
        if len(backtracking_structures) > 0:
            # summed weight of all structures with selection_hist[:, 0] > 0.0; NaN entries
            # are replaced by probability_bad_max before summation
            # NOTE(review): presumably the unnormalized selection probabilities of the currently
            # selected structures - confirm against the selection code
            prob_sum = (np.prod(selection_hist[selection_hist[:, 0] > 0.0], axis=1)
                        * probability_forces[selection_hist[:, 0] > 0.0]
                        * ((1.0 - fraction_good) * probability_loss[selection_hist[:, 0] > 0.0]
                           + fraction_good * (1.0 - probability_loss[selection_hist[:, 0] > 0.0])))
            prob_sum[np.isnan(prob_sum)] = probability_bad_max
            prob_sum = np.sum(prob_sum)
            for n in backtracking_structures:
                # weight of structure n for every entry of hist_factors
                prob = hist_factors * (probability_forces[n] * (
                    (1.0 - fraction_good) * probability_loss[n]
                    + fraction_good * (1.0 - probability_loss[n])))
                # cumulative sum of the reciprocal clipped values, shifted to start at zero
                # NOTE(review): presumably the expected numbers of optimizer steps since the
                # structure was fitted in earlier epochs - TODO confirm
                backtracking_steps = np.cumsum(1.0 / (np.clip(
                    prob / (prob_sum + prob) * n_structures_fit, 0.0, 1.0)))
                backtracking_steps -= backtracking_steps[0]
                # two optimizer steps are counted per epoch with energy preoptimization
                # (assumption - TODO confirm)
                if self.settings.energy_preoptimization_step:
                    backtracking_steps *= 2
                # determine backtracking factors
                # betas_1 varies between betas[1] and betas[0] with a Gaussian dependence on
                # (epoch - backtracking_steps) / betas[2]
                betas_1 = (self.settings.betas[1] + (self.settings.betas[0] - self.settings.betas[1])
                           * np.exp(-((epoch - backtracking_steps) / self.settings.betas[2])**2))
                # NOTE(review): backtracking_factors is overwritten for every structure, while
                # the gradients of all structures accumulate through loss.backward(); the update
                # after this loop uses the factors of the last structure only - confirm intent
                backtracking_factors = [
                    np.sum((1.0 - betas_1) * betas_1**backtracking_steps),
                    np.sum((1.0 - self.settings.betas[3]) * self.settings.betas[3]**backtracking_steps)]

                # predict energy and forces of deselected structures
                energy_prediction_torch, forces_prediction_torch = self.calculate_energy_forces(
                    n, elements_int_sys, descriptors_torch, descriptor_derivatives_torch,
                    neighbor_indices, n_atoms_sys, descriptor_neighbor_derivatives_torch_env,
                    neighbor_indices_env, n_atoms_active, MM_gradients, name,
                    calc_forces=self.settings.fit_forces)
                # calculate loss of deselected structures
                loss_E = self.energy_loss_function(
                    energy_prediction_torch / n_atoms_sys[n], energy_torch[n] / n_atoms_sys[n]) * (
                        self.settings.loss_E_scaling / n_structures_fit)
                if self.settings.fit_forces:
                    loss_F = self.forces_loss_function(forces_prediction_torch, forces_torch[n]) / (
                        3 * n_atoms_fit_sum)
                    loss = loss_E + loss_F
                else:
                    loss = loss_E
                # calculate loss gradients of deselected structures
                # (gradients of all deselected structures are accumulated in .grad)
                loss.backward()
            # backtrack gradients in prev_1 and prev_2
            # NOTE(review): relies on optimizer.state_dict() handing out the optimizer's own
            # state tensors so that the in-place operations take effect - TODO confirm
            for i in range(n_groups):
                if self.optimizer.param_groups[0]['params'][i].grad is not None:
                    self.optimizer.state_dict()['state'][i]['prev_1'] -= backtracking_factors[0] \
                        * self.optimizer.param_groups[0]['params'][i].grad
                    self.optimizer.state_dict()['state'][i]['prev_2'] -= backtracking_factors[1] \
                        * self.optimizer.param_groups[0]['params'][i].grad**2
                    # ensure that prev_2 is not negative
                    self.optimizer.state_dict()['state'][i]['prev_2'][
                        self.optimizer.state_dict()['state'][i]['prev_2'] < 0.0] = 0.0
            # reset gradients to zero
            self.optimizer.zero_grad(set_to_none=True)

            # backtracking gradients for energy preoptimization steps
            if self.settings.energy_preoptimization_step:
                # determine backtracking steps
                for n in backtracking_structures:
                    prob = hist_factors * (probability_forces[n] * (
                        (1.0 - fraction_good) * probability_loss[n]
                        + fraction_good * (1.0 - probability_loss[n])))
                    backtracking_steps = np.cumsum(1.0 / (np.clip(
                        prob / (prob_sum + prob) * n_structures_fit, 0.0, 1.0)))
                    backtracking_steps -= backtracking_steps[0]
                    # same steps as above, doubled and shifted by one
                    backtracking_steps *= 2
                    backtracking_steps += 1
                    # determine backtracking factors
                    betas_1 = (self.settings.betas[1] + (self.settings.betas[0] - self.settings.betas[1])
                               * np.exp(-((epoch - backtracking_steps) / self.settings.betas[2])**2))
                    # NOTE(review): as above, only the factors of the last structure are used
                    # in the update after this loop
                    backtracking_factors = [
                        np.sum((1.0 - betas_1) * betas_1**backtracking_steps),
                        np.sum((1.0 - self.settings.betas[3]) * self.settings.betas[3]**backtracking_steps)]

                    # predict energy of deselected structures
                    energy_prediction_torch, forces_prediction_torch = self.calculate_energy_forces(
                        n, elements_int_sys, descriptors_torch, descriptor_derivatives_torch,
                        neighbor_indices, n_atoms_sys, descriptor_neighbor_derivatives_torch_env,
                        neighbor_indices_env, n_atoms_active, MM_gradients, name, calc_forces=False)
                    # calculate loss of deselected structures
                    loss = self.energy_loss_function(
                        energy_prediction_torch / n_atoms_sys[n], energy_torch[n] / n_atoms_sys[n]) * (
                            self.settings.loss_E_scaling / n_structures_fit)
                    # calculate loss gradients of deselected structures
                    loss.backward()
                # backtrack gradients in prev_1 and prev_2
                for i in range(n_groups):
                    if self.optimizer.param_groups[0]['params'][i].grad is not None:
                        self.optimizer.state_dict()['state'][i]['prev_1'] -= backtracking_factors[0] \
                            * self.optimizer.param_groups[0]['params'][i].grad
                        self.optimizer.state_dict()['state'][i]['prev_2'] -= backtracking_factors[1] \
                            * self.optimizer.param_groups[0]['params'][i].grad**2
                        # ensure that prev_2 is not negative
                        self.optimizer.state_dict()['state'][i]['prev_2'][
                            self.optimizer.state_dict()['state'][i]['prev_2'] < 0.0] = 0.0
                # reset gradients to zero
                self.optimizer.zero_grad(set_to_none=True)

####################################################################################################

    def print_results_lADS(self, selection_hist):
        '''
        Print how many entries of selection_hist fall into each deselection category.

        An entry is counted for a category if it lies strictly inside the value band of that
        category: (-1.1, -0.9) redundant, (-2.1, -1.9) bad, (-3.1, -2.9) failed, and
        (-4.1, -3.9) late.

        Output: n_redundant_data, n_bad_data, n_failed_data, n_later_data
        '''
        # open value bands in the order redundant, bad, failed, late
        bands = ((-1.1, -0.9), (-2.1, -1.9), (-3.1, -2.9), (-4.1, -3.9))
        values = np.asarray(selection_hist)
        counts = []
        for lower, upper in bands:
            inside = (values > lower) & (values < upper)
            counts.append(np.count_nonzero(inside))
        print('Redundant: {0}, Bad: {1}, Failed: {2}, Late: {3}'.format(*counts))

####################################################################################################

    def write_generalization_setting(self):
        '''
        Write the generalization setting file in the format given by
        settings.generalization_setting_format.

        The parent directory of settings.generalization_setting_file is created if it does not
        exist. For a format which is not implemented, the implemented formats are printed and
        the program exits.

        Implementation: lMLP

        Output: Generalization setting file
        '''
        # implemented generalization setting formats
        implemented_formats = ['lMLP']

        # create generalization setting directory if it does not exist
        setting_directory = Path(self.settings.generalization_setting_file).parent
        setting_directory.mkdir(parents=True, exist_ok=True)

        # write lMLP generalization setting
        if self.settings.generalization_setting_format == 'lMLP':
            self.write_setting()
            return

        # not implemented generalization setting format
        print('ERROR: Generalization setting format {0} is not yet implemented.'
              .format(self.settings.generalization_setting_format),
              '\nPlease use one of the following formats:')
        for implemented_format in implemented_formats:
            print('{0}'.format(implemented_format))
        sys.exit()

####################################################################################################

    def write_setting(self):
        '''
        Write the section 'settings' with the model and descriptor settings to
        settings.generalization_setting_file.

        All values are converted to strings by configparser. The ensemble entry holds the part
        of settings.generalization_file after its last '/'.

        Output: setting.ini file
        '''
        # collect the entries of the settings section
        settings = self.settings
        options = {
            'generalization_format': settings.generalization_format,
            'ensemble': [settings.generalization_file.split('/')[-1]],
            'test_RMSEs': [self.test_RMSEs],
            'model_type': settings.model_type,
            'element_types': list(self.element_types),
            'descriptor_type': settings.descriptor_type,
            'descriptor_radial_type': settings.descriptor_radial_type,
            'descriptor_angular_type': settings.descriptor_angular_type,
            'descriptor_scaling_type': settings.descriptor_scaling_type,
            'n_descriptors': self.n_descriptors,
            'scale_shift_layer': settings.scale_shift_layer,
            'n_neurons_hidden_layers': settings.n_neurons_hidden_layers,
            'activation_function_type': settings.activation_function_type,
            'dtype_torch': settings.dtype_torch,
            'QMMM': settings.QMMM,
            'MM_atomic_charge_max': settings.MM_atomic_charge_max}

        # initialize configparser
        config = configparser.ConfigParser()
        config.read_dict({'settings': options})

        # write generalization setting file
        with open(settings.generalization_setting_file, 'w', encoding='utf-8') as f:
            config.write(f)

####################################################################################################

    def write_generalization(self, model_index=0):
        '''
        Write model model_index in the format given by settings.generalization_format.

        The parent directory of settings.generalization_file is created if it does not exist.
        For a format which is not implemented, the implemented formats are printed and the
        program exits.

        Implementation: lMLP, lMLP-only_prediction, RuNNer

        Output: Generalization file(s)
        '''
        # writer method of each implemented generalization format
        writers = {
            'lMLP': 'write_lMLP',   # lMLP model and training state
            'lMLP-only_prediction': 'write_lMLP_model',   # lMLP model
            'RuNNer': 'write_weights'}   # atomic neural network weights in RuNNer format

        # create generalization directory if it does not exist
        generalization_directory = Path(self.settings.generalization_file).parent
        generalization_directory.mkdir(parents=True, exist_ok=True)

        # not implemented generalization format
        requested_format = self.settings.generalization_format
        if requested_format not in writers:
            print('ERROR: Generalization format {0} is not yet implemented.'
                  .format(requested_format),
                  '\nPlease use one of the following formats:')
            for implemented_format in writers:
                print('{0}'.format(implemented_format))
            sys.exit()

        # write generalization in the requested format
        getattr(self, writers[requested_format])(model_index)

####################################################################################################

    def write_lMLP(self, model_index):
        '''
        Save model model_index, the descriptor parameters, the loss functions, the optimizer
        state, and the training state to settings.generalization_file.

        If self.training_state is empty, empty defaults are stored for the training state
        entries.

        Output: lMLP.pt file
        '''
        training_state_keys = (
            'name', 'train', 'selection_hist', 'exclusion_hist', 'loss_old',
            'loss_selection_thresholds', 'fraction_good', 'n_previous_epochs')
        has_training_state = len(self.training_state) > 0

        # entries which are stored independent of the training state
        checkpoint = {
            'model_state_dict': self.model[model_index].state_dict(),
            'descriptor_parameters': self.descriptor_parameters,
            'R_c': self.R_c,
            'element_energy': self.element_energy,
            'energy_loss_function': self.energy_loss_function,
            'forces_loss_function': self.forces_loss_function,
            'optimizer_state_dict': self.optimizer.state_dict()}

        # training state entries in the order of self.training_state
        if has_training_state:
            for position, key in enumerate(training_state_keys):
                checkpoint[key] = self.training_state[position]
        else:
            empty_values = (
                [], np.array([]), np.array([]), np.array([]), np.array([]), np.array([]), 0.0, 0)
            checkpoint.update(zip(training_state_keys, empty_values))

        # write lMLP model and training state
        torch.save(checkpoint, self.settings.generalization_file)

####################################################################################################

    def write_lMLP_model(self, model_index):
        '''
        Save model model_index together with the descriptor parameters, R_c, and the element
        energies to settings.generalization_file. No optimizer or training state is stored.

        Output: lMLP-only_prediction.pt file
        '''
        # collect the data required for predictions
        prediction_checkpoint = {
            'model_state_dict': self.model[model_index].state_dict(),
            'descriptor_parameters': self.descriptor_parameters,
            'R_c': self.R_c,
            'element_energy': self.element_energy}

        # write lMLP model
        torch.save(prediction_checkpoint, self.settings.generalization_file)

####################################################################################################

    def write_weights(self, model_index):
        '''
        Write the weights and biases of the atomic neural networks of model model_index to one
        file per element type in the directory settings.generalization_file.

        Only layers with a weight attribute are written. Each weight is written as an 'a' line
        and each bias as a 'b' line; both carry a counter running over all parameters of the
        element's network. If settings.scale_shift_layer is set, the weights of the first
        written layer are placed on the diagonal of a matrix and its biases are written as
        -bias * weight.

        Output: weights.XXX.data files
        '''
        weight_line = '{0:18.10f}  a {1:9d} {2:5d} {3:5d} {4:5d} {5:5d}\n'
        bias_line = '{0:18.10f}  b {1:9d} {2:5d} {3:5d}\n'

        # create weights directory if it does not exist
        weights_directory = self.settings.generalization_file
        Path(weights_directory).mkdir(parents=True, exist_ok=True)

        # determine atomic neural network architecture
        # (N_j: number of outputs, N_k: number of inputs of each written layer)
        model = self.model[model_index]
        scale_shift = self.settings.scale_shift_layer
        n_inputs = [model.n_descriptors]
        n_hidden = model.n_neurons_hidden_layers
        if scale_shift:
            N_j = n_inputs + n_hidden + [1]
            N_k = 2 * n_inputs + n_hidden
        else:
            N_j = n_hidden + [1]
            N_k = n_inputs + n_hidden

        # get atomic numbers string dictionary
        atomic_numbers = self.get_atomic_numbers()

        # write weights
        for ele in range(self.n_element_types):
            weights_file = '{0}/weights.{1:03}.data'.format(
                weights_directory, atomic_numbers[self.element_types[ele]])
            with open(weights_file, 'w', encoding='utf-8') as f:
                counter = 0
                layers = (layer for layer in model.atomic_neural_networks[ele]
                          if hasattr(layer, 'weight'))
                for i, layer in enumerate(layers):
                    is_scale_shift_layer = i == 0 and scale_shift
                    # write weights a
                    layer_weights = layer.weight.data.cpu().detach().numpy()
                    if is_scale_shift_layer:
                        weights = layer_weights * np.eye(len(layer_weights))
                    else:
                        weights = layer_weights
                    for k in range(N_k[i]):
                        for j in range(N_j[i]):
                            counter += 1
                            f.write(weight_line.format(
                                weights[j][k], counter, i, k + 1, i + 1, j + 1))
                    # write weights b
                    biases = layer.bias.data.cpu().detach().numpy()
                    if is_scale_shift_layer:
                        biases = -biases * layer_weights
                    for j in range(N_j[i]):
                        counter += 1
                        f.write(bias_line.format(biases[j], counter, i + 1, j + 1))

####################################################################################################

    def write_modified_episodic_memory(self, selection, train, test, elements, positions, lattices,
                                       atomic_classes, atomic_charges, energy, forces, n_structures,
                                       n_structures_test, n_atoms, n_atoms_sys, reorder, name):
        '''
        Write the training data split into new, bad, and redundant structures as well as the test
        data to separate episodic memory files and optionally print the names of the bad data.

        If self.settings.write_new_episodic_memory is set, the atomic and pair contributions are
        added to energy and forces and four files are written with write_prediction. Their names
        are self.settings.episodic_memory_file plus the suffix _new, _bad, _redundant, or _test.
        The training subsets are train[selection[0]], train[selection[1]], and
        train[selection[2]], the test subset is test. For QMMM, reorder is restricted to the
        written structures, otherwise it is passed on unchanged.

        If self.settings.print_bad_data_names is set, the names of the structures in
        train[selection[1]] are printed, or a warning if no names are provided.

        name may be None or empty. In this case it is forwarded unchanged to write_prediction.

        Output: new, bad, and redundant training data and test data episodic memory files
        '''
        # write modified episodic memory
        if self.settings.write_new_episodic_memory:
            # add atomic and pair contributions
            energy, forces = self.calculate_atomic_and_pair_contributions(
                '+', elements, positions, lattices, energy, forces, n_structures, n_atoms_sys)

            # names are optional (the print section below handles their absence as well), so
            # they must not be indexed if they are missing
            has_names = name is not None and len(name) > 0

            def write_subset(suffix, select, n_select):
                '''
                Write the structures with indices select to the episodic memory file with suffix
                '''
                if self.settings.QMMM:
                    reorder_select = [reorder[i] for i in select]
                else:
                    reorder_select = reorder
                name_select = [name[i] for i in select] if has_names else name
                self.write_prediction(
                    self.settings.episodic_memory_file + suffix, [elements[i] for i in select],
                    [positions[i] for i in select], [lattices[i] for i in select],
                    [atomic_classes[i] for i in select], [atomic_charges[i] for i in select],
                    [energy[i] for i in select], [forces[i] for i in select], n_select,
                    n_atoms[select], reorder_select, name_select)

            # write new, bad, and redundant episodic memory data
            for suffix, selection_index in (('_new', 0), ('_bad', 1), ('_redundant', 2)):
                select = train[selection[selection_index]]
                write_subset(suffix, select, len(select))
            # write test data
            write_subset('_test', test, n_structures_test)

        # print names of bad and failed training data
        if self.settings.print_bad_data_names:
            if name:
                select = train[selection[1]]
                if len(select) > 0:
                    print('Bad and failed training data:  {0}'.format(len(select)))
                    for i in select:
                        print(name[i])
                    print('')
            else:
                print('WARNING: Bad and failed training data names cannot be printed',
                      'as no names are provided.\n')

####################################################################################################

    def write_prediction(self, prediction_file, elements, positions, lattices, atomic_classes,
                         atomic_charges, energy, forces, n_structures, n_atoms, reorder, name=None,
                         assignment=None):
        '''
        Write structures with their energies and forces to prediction_file in the format
        requested by self.settings.prediction_format. The directory of prediction_file is
        created if it is missing. For a format that is not implemented, an error message with
        the list of implemented formats is printed and the program exits.

        Implementation: inputdata

        Output: Prediction file including energies and forces
        '''
        # formats which can be written
        implemented_formats = ['inputdata']

        # the directory of the prediction file has to exist
        Path(prediction_file).parent.mkdir(parents=True, exist_ok=True)

        # stop for a format which is not implemented
        if self.settings.prediction_format != 'inputdata':
            print('ERROR: Prediction format {0} is not yet implemented.'
                  .format(self.settings.prediction_format),
                  '\nPlease use one of the following formats:')
            for implemented_format in implemented_formats:
                print('{0}'.format(implemented_format))
            sys.exit()

        # write inputdata file
        self.write_inputdata(prediction_file, elements, positions, lattices, atomic_classes,
                             atomic_charges, energy, forces, n_structures, n_atoms, reorder,
                             name, assignment)


####################################################################################################

####################################################################################################

@ncjit
def update_hist_E_F_losses(fit, selection_hist, exclusion_hist, loss_old_fit, loss_new_fit,
                           loss_selection_thresholds_mean, decrease_factor, decrease_factor_small,
                           increase_factor_small, increase_factor, exclusion_strikes, selection_min,
                           selection_max, n_redundant_max, n_structures_train) -> Tuple[
                               NDArray, NDArray]:
    '''
    Update the selection and exclusion histories of the fitted structures from their old and
    new losses, treating the two loss columns separately.

    The entries of fit are used as indices into selection_hist and exclusion_hist, and row n of
    loss_old_fit and loss_new_fit belongs to index fit[n]. Column i of selection_hist is updated
    from column i of the losses using the four thresholds loss_selection_thresholds_mean[i, 0]
    to loss_selection_thresholds_mean[i, 3] (presumably column 0 is the energy and column 1 the
    force loss, as the function name suggests, to be confirmed against the caller).

    Both history arrays are modified in place. Deselected structures are flagged by negative
    values in both columns of selection_hist:
    -3.0: exclusion_hist reached exclusion_strikes
    -2.0: combined excess of both columns above 1.0 is larger than (selection_max - 1.0)**2
    -1.0: redundant, i.e., combined inverse measure of both columns below 1.0 is larger than
          (selection_min**-1 - 1.0)**2 (at most n_redundant_max structures per call)

    Return: selection_hist, exclusion_hist
    '''
    # update exclusion_hist
    # a structure gets a strike if at least one of its new losses exceeds the threshold with
    # index 3 of the respective column, otherwise its strike counter is reset to zero
    worst_bool = np.logical_or(loss_new_fit[:, 0] > loss_selection_thresholds_mean[0, 3],
                               loss_new_fit[:, 1] > loss_selection_thresholds_mean[1, 3])
    worst = fit[worst_bool]
    exclusion_hist[worst] += 1
    neutral = fit[np.logical_not(worst_bool)]
    exclusion_hist[neutral] = 0

    # update selection_hist
    for i in (0, 1):
        # the clipping is applied before the factors are multiplied
        # loss at or above threshold 0: value is restricted to the range [1.0, selection_max]
        neutral = fit[loss_new_fit[:, i] >= loss_selection_thresholds_mean[i, 0]]
        selection_hist[:, i][neutral] = np.clip(selection_hist[:, i][neutral], 1.0, selection_max)
        # loss at or below threshold 1: value is restricted to the range [selection_min, 1.0]
        neutral = fit[loss_new_fit[:, i] <= loss_selection_thresholds_mean[i, 1]]
        selection_hist[:, i][neutral] = np.clip(selection_hist[:, i][neutral], selection_min, 1.0)
        # direction of the loss change, equal losses count as lower
        loss_lower = loss_new_fit[:, i] <= loss_old_fit[:, i]
        loss_higher = loss_new_fit[:, i] > loss_old_fit[:, i]
        # loss below threshold 0: decrease_factor if the loss did not increase,
        # decrease_factor_small if it increased
        good_bool = loss_new_fit[:, i] < loss_selection_thresholds_mean[i, 0]
        best = fit[np.logical_and(good_bool, loss_lower)]
        selection_hist[:, i][best] *= decrease_factor
        good = fit[np.logical_and(good_bool, loss_higher)]
        selection_hist[:, i][good] *= decrease_factor_small
        # loss above threshold 1 up to threshold 2: increase_factor_small, but only if the loss
        # increased
        bad_bool = np.logical_and(loss_new_fit[:, i] > loss_selection_thresholds_mean[i, 1],
                                  loss_new_fit[:, i] <= loss_selection_thresholds_mean[i, 2])
        bad = fit[np.logical_and(bad_bool, loss_higher)]
        selection_hist[:, i][bad] *= increase_factor_small
        # loss above threshold 2: increase_factor if the loss increased, increase_factor_small
        # if it did not increase
        worse_bool = loss_new_fit[:, i] > loss_selection_thresholds_mean[i, 2]
        worst = fit[np.logical_and(worse_bool, loss_higher)]
        selection_hist[:, i][worst] *= increase_factor
        worse = fit[np.logical_and(worse_bool, loss_lower)]
        selection_hist[:, i][worse] *= increase_factor_small

    # deselect structures
    # flag -3.0: number of strikes reached exclusion_strikes
    selection_hist[exclusion_hist >= exclusion_strikes] = -3.0 * np.ones(2)
    # sum over both columns of the squared clipped values of 1 / selection_hist - 1
    # (values >= 1.0 and negative flag values contribute zero)
    selection_hist_low = (np.clip(selection_hist[:, 0]**-1 - 1.0, 0.0, None)**2
                          + np.clip(selection_hist[:, 1]**-1 - 1.0, 0.0, None)**2)
    # sum over both columns of the squared clipped values of selection_hist - 1
    # (values <= 1.0 including the negative flag values contribute zero)
    selection_hist_high = (np.clip(selection_hist[:, 0] - 1.0, 0.0, None)**2
                           + np.clip(selection_hist[:, 1] - 1.0, 0.0, None)**2)
    # limits of the two sums which correspond to one column at selection_min or selection_max
    # and the other column contributing zero
    selection_min_threshold = (selection_min**-1 - 1.0)**2
    selection_max_threshold = (selection_max - 1.0)**2
    # flag -2.0: upper limit exceeded (the factor 1.000001 is a small tolerance)
    selection_hist[selection_hist_high > 1.000001 * selection_max_threshold] = -2.0 * np.ones(2)
    # candidates for redundant structures: lower limit exceeded
    redundant = np.arange(n_structures_train)[selection_hist_low > 1.000001 * selection_min_threshold]
    n_redundant_new = len(redundant)
    if n_redundant_new > n_redundant_max:
        # the order of the candidates is taken from the argsort of the first n_redundant_new
        # entries of fit
        # NOTE(review): presumably this relies on fit being in random order to obtain a random
        # subset and on fit having at least n_redundant_new entries - confirm against the caller
        selection = np.argsort(fit[:n_redundant_new])
        # candidates beyond n_redundant_max are not flagged, their values are divided by
        # decrease_factor instead
        selection_hist[redundant[selection[n_redundant_max:]]] /= decrease_factor
        redundant = redundant[selection[:n_redundant_max]]
    # flag -1.0: redundant
    selection_hist[redundant] = -1.0 * np.ones(2)

    return selection_hist, exclusion_hist


####################################################################################################

@ncjit
def update_hist_total_loss(fit, selection_hist, exclusion_hist, loss_old_fit, loss_new_fit,
                           loss_selection_thresholds_mean, decrease_factor, decrease_factor_small,
                           increase_factor_small, increase_factor, exclusion_strikes, selection_min,
                           selection_max, n_redundant_max, n_structures_train) -> Tuple[
                               NDArray, NDArray]:
    '''
    Update the selection and exclusion histories of the fitted structures from their old and
    new losses, using a single loss value per structure.

    The entries of fit are used as indices into selection_hist and exclusion_hist, and entry n
    of loss_old_fit and loss_new_fit belongs to index fit[n]. The four thresholds are
    loss_selection_thresholds_mean[0] to loss_selection_thresholds_mean[3].

    Both history arrays are modified in place. Deselected structures are flagged by negative
    values in selection_hist:
    -3.0: exclusion_hist reached exclusion_strikes
    -2.0: selection_hist is larger than selection_max
    -1.0: redundant, i.e., selection_hist is positive and smaller than selection_min (at most
          n_redundant_max structures per call)

    Return: selection_hist, exclusion_hist
    '''
    # update exclusion_hist
    # a structure gets a strike if its new loss exceeds the threshold with index 3, otherwise
    # its strike counter is reset to zero
    worst = fit[loss_new_fit > loss_selection_thresholds_mean[3]]
    exclusion_hist[worst] += 1
    neutral = fit[loss_new_fit <= loss_selection_thresholds_mean[3]]
    exclusion_hist[neutral] = 0

    # update selection_hist
    # the clipping is applied before the factors are multiplied
    # loss at or above threshold 0: value is restricted to the range [1.0, selection_max]
    neutral = fit[loss_new_fit >= loss_selection_thresholds_mean[0]]
    selection_hist[neutral] = np.clip(selection_hist[neutral], 1.0, selection_max)
    # loss at or below threshold 1: value is restricted to the range [selection_min, 1.0]
    neutral = fit[loss_new_fit <= loss_selection_thresholds_mean[1]]
    selection_hist[neutral] = np.clip(selection_hist[neutral], selection_min, 1.0)
    # direction of the loss change, equal losses count as lower
    loss_lower = loss_new_fit <= loss_old_fit
    loss_higher = loss_new_fit > loss_old_fit
    # loss below threshold 0: decrease_factor if the loss did not increase,
    # decrease_factor_small if it increased
    good_bool = loss_new_fit < loss_selection_thresholds_mean[0]
    best = fit[np.logical_and(good_bool, loss_lower)]
    selection_hist[best] *= decrease_factor
    good = fit[np.logical_and(good_bool, loss_higher)]
    selection_hist[good] *= decrease_factor_small
    # loss above threshold 1 up to threshold 2: increase_factor_small, but only if the loss
    # increased
    bad_bool = np.logical_and(loss_new_fit > loss_selection_thresholds_mean[1],
                              loss_new_fit <= loss_selection_thresholds_mean[2])
    bad = fit[np.logical_and(bad_bool, loss_higher)]
    selection_hist[bad] *= increase_factor_small
    # loss above threshold 2: increase_factor if the loss increased, increase_factor_small if
    # it did not increase
    worse_bool = loss_new_fit > loss_selection_thresholds_mean[2]
    worst = fit[np.logical_and(worse_bool, loss_higher)]
    selection_hist[worst] *= increase_factor
    worse = fit[np.logical_and(worse_bool, loss_lower)]
    selection_hist[worse] *= increase_factor_small

    # deselect structures
    # flag -3.0: number of strikes reached exclusion_strikes
    selection_hist[exclusion_hist >= exclusion_strikes] = -3.0
    # flag -2.0: selection_max exceeded (the factor 1.000001 is a small tolerance)
    selection_hist[selection_hist > 1.000001 * selection_max] = -2.0
    # candidates for redundant structures: value below selection_min, the second condition
    # skips the negative flag values (assuming selection_min is positive)
    redundant = np.arange(n_structures_train)[np.logical_and(selection_hist < 0.999999 * selection_min,
                                              selection_hist > 0.000001 * selection_min)]
    n_redundant_new = len(redundant)
    if n_redundant_new > n_redundant_max:
        # the order of the candidates is taken from the argsort of the first n_redundant_new
        # entries of fit
        # NOTE(review): presumably this relies on fit being in random order to obtain a random
        # subset and on fit having at least n_redundant_new entries - confirm against the caller
        selection = np.argsort(fit[:n_redundant_new])
        # candidates beyond n_redundant_max are not flagged, their values are divided by
        # decrease_factor instead
        selection_hist[redundant[selection[n_redundant_max:]]] /= decrease_factor
        redundant = redundant[selection[:n_redundant_max]]
    # flag -1.0: redundant
    selection_hist[redundant] = -1.0

    return selection_hist, exclusion_hist


####################################################################################################

@ncfjit
def calculate_RMSE(energy_prediction, energy, forces_prediction, forces, n_structures, n_atoms_sys,
                   QMMM, active_atoms, n_atoms_active) -> Tuple[float, float, float]:
    '''
    Root mean squared errors of the predicted energies and forces.

    E_RMSE: RMSE of the energy error divided by n_atoms_sys, averaged over n_structures
    F_RMSE: RMSE of the forces of the first n_atoms_sys[n] atoms of every structure n,
            normalized by 3 * sum(n_atoms_sys)
    F_RMSE_env: only for QMMM, otherwise NaN. RMSE of the forces of the atoms selected by
                active_atoms[n] without the first n_atoms_sys[n] atoms of this selection,
                normalized by 3 * sum(n_atoms_active - n_atoms_sys)

    Return: E_RMSE, F_RMSE, F_RMSE_env
    '''
    # calculate energy and force RMSE values
    E_RMSE = np.sqrt(np.sum(((
        energy_prediction - energy) / n_atoms_sys)**2) / n_structures)
    # the squared force errors are summed per structure first because only the first
    # n_atoms_sys[n] atoms of structure n are taken into account
    F_RMSE = np.sqrt(np.sum(np.array([
        np.sum((forces_prediction[n][:n_atoms_sys[n]] - forces[n][:n_atoms_sys[n]]).flatten()**2)
        for n in range(n_structures)])) / (3 * np.sum(n_atoms_sys)))
    if QMMM:
        # force RMSE of the active atoms which follow the first n_atoms_sys[n] selected atoms
        F_RMSE_env = np.sqrt(np.sum(np.array([
            np.sum((forces_prediction[n][active_atoms[n]][n_atoms_sys[n]:]
                    - forces[n][active_atoms[n]][n_atoms_sys[n]:]).flatten()**2)
            for n in range(n_structures)])) / (3 * np.sum(n_atoms_active - n_atoms_sys)))
    else:
        # no environment force error without QMMM
        F_RMSE_env = np.nan

    return E_RMSE, F_RMSE, F_RMSE_env


####################################################################################################

@torch.jit.script
def calculate_F_SE(forces_prediction_torch: Tensor, forces_torch: Tensor) -> float:
    '''
    Sum of the squared differences between all elements of the predicted and the reference
    forces, converted to a Python float.

    Return: F_SE
    '''
    # flattened deviation of the predicted forces from the reference forces
    deviation = torch.flatten(forces_prediction_torch - forces_torch)
    # forces squared error summed over all elements
    squared_error = torch.sum(deviation**2)

    return float(squared_error)
