#!/usr/bin/python3

####################################################################################################

####################################################################################################

'''
Lifelong Machine Learning Potentials (lMLP)
'''
__copyright__ = '''This code is licensed under the 3-clause BSD license.
Copyright ETH Zurich, Department of Chemistry and Applied Biosciences, Reiher Group.
See LICENSE.txt for details.'''

####################################################################################################

####################################################################################################

from ast import literal_eval
from os import path
import configparser
import sys
import numpy as np
import torch
from numba import typed   # type: ignore
from .descriptors import prepare_periodic_cell, get_periodic_images, get_triple_properties
from .models import HDNNP
from .stanh import sTanh


####################################################################################################

####################################################################################################

class lMLP_base():
    '''
    Lifelong Machine Learning Potential Base Class
    '''

####################################################################################################

    def __init__(self):
        '''
        Initialization
        '''
        # define unit conversion factors
        self.Bohr2Angstrom = 0.529177210903   # CODATA 2018
        self.Hartree2eV = 27.211386245988   # CODATA 2018
        self.eV2kJ_mol = 1.602176634 * 60.2214076   # CODATA 2018

        # disable debugging functions of PyTorch
        torch.autograd.profiler.emit_nvtx(False)
        torch.autograd.profiler.profile(False)
        torch.autograd.set_detect_anomaly(False)

####################################################################################################

    def define_dtype_torch(self):
        '''
        Implementation: float, double

        Return: dtype_torch
        '''
        # implemented torch dtypes
        dtype_torch_list = ['float', 'double']

        # torch dtype float
        if self.settings.dtype_torch == 'float':
            dtype_torch = torch.float

        # torch dtype double
        elif self.settings.dtype_torch == 'double':
            dtype_torch = torch.double

        # not implemented torch dtype
        else:
            print('ERROR: Using the data type {0} is not yet implemented for PyTorch tensors.'
                  .format(self.settings.dtype_torch),
                  '\nPlease use one of the following data types:')
            for dty_tor in dtype_torch_list:
                print('{0}'.format(dty_tor))
            sys.exit()

        return dtype_torch

####################################################################################################

    def read_generalization_setting(self, spezified_settings):
        '''
        Implementation: lMLP

        Modify: generalization_format, model_type, descriptor_type, descriptor_radial_type,
                descriptor_angular_type, descriptor_scaling_type, scale_shift_layer,
                n_neurons_hidden_layers, activation_function_type, dtype_torch, QMMM,
                MM_atomic_charge_max

        Return: element_types, n_descriptors, ensemble, test_RMSEs
        '''
        # implemented file formats
        generalization_setting_format_list = ['lMLP']

        # check existance of file
        if not path.isfile(self.settings.generalization_setting_file):
            print('ERROR: Generalization setting file {0} does not exist.'.format(
                self.settings.generalization_setting_file))
            sys.exit()

        # read file format lMLP
        if self.settings.generalization_setting_format == 'lMLP':
            element_types, n_descriptors, ensemble, test_RMSEs = self.read_settings(
                spezified_settings)

        # not implemented file format
        else:
            print('ERROR: Generalization setting format {0} is not yet implemented.'
                  .format(self.settings.generalization_setting_format),
                  '\nPlease use one of the following formats:')
            for gen_set_format in generalization_setting_format_list:
                print('{0}'.format(gen_set_format))
            sys.exit()

        return element_types, n_descriptors, ensemble, test_RMSEs

####################################################################################################

    def read_settings(self, spezified_settings):
        '''
        Modify: generalization_format, model_type, descriptor_type, descriptor_radial_type,
                descriptor_angular_type, descriptor_scaling_type, scale_shift_layer,
                n_neurons_hidden_layers, activation_function_type, dtype_torch, QMMM,
                MM_atomic_charge_max

        Return: element_types, n_descriptors, ensemble, test_RMSEs
        '''
        # initialize configparser
        config = configparser.ConfigParser()
        files = config.read(self.settings.generalization_setting_file)

        # check generalization setting file
        if len(files) != 1:
            print('ERROR: Generalization setting file {0} does not exist or is broken.'.format(
                self.settings.generalization_setting_file))
            sys.exit()
        generalization_settings = config['settings']

        # get generalization settings
        self.settings.generalization_format = generalization_settings['generalization_format']
        ensemble = literal_eval(generalization_settings['ensemble'])
        test_RMSEs = literal_eval(generalization_settings['test_RMSEs'])
        self.settings.model_type = generalization_settings['model_type']
        element_types = np.array(literal_eval(generalization_settings['element_types']))
        self.settings.descriptor_type = generalization_settings['descriptor_type']
        self.settings.descriptor_radial_type = generalization_settings['descriptor_radial_type']
        self.settings.descriptor_angular_type = generalization_settings['descriptor_angular_type']
        self.settings.descriptor_scaling_type = generalization_settings['descriptor_scaling_type']
        n_descriptors = literal_eval(generalization_settings['n_descriptors'])
        self.settings.scale_shift_layer = literal_eval(generalization_settings['scale_shift_layer'])
        self.settings.n_neurons_hidden_layers = literal_eval(generalization_settings['n_neurons_hidden_layers'])
        self.settings.activation_function_type = generalization_settings['activation_function_type']
        self.settings.dtype_torch = generalization_settings['dtype_torch']
        self.settings.QMMM = literal_eval(generalization_settings['QMMM'])
        self.settings.MM_atomic_charge_max = literal_eval(generalization_settings['MM_atomic_charge_max'])

        # check generalization settings
        if len(spezified_settings) > 0:
            if self.settings.generalization_format != spezified_settings[0]:
                print('ERROR: Spezified generalization_format {0} is not equal to'
                      .format(spezified_settings[0]),
                      'generalization setting generalization_format {0} (file {1}).'
                      .format(self.settings.generalization_format,
                              self.settings.generalization_setting_file))
                sys.exit()
            if self.settings.model_type != spezified_settings[1]:
                print('ERROR: Spezified model_type {0} is not equal to'
                      .format(spezified_settings[1]),
                      'generalization setting model_type {0} (file {1}).'
                      .format(self.settings.model_type,
                              self.settings.generalization_setting_file))
                sys.exit()
            if self.settings.descriptor_type != spezified_settings[2]:
                print('ERROR: Spezified descriptor_type {0} is not equal to'
                      .format(spezified_settings[2]),
                      'generalization setting descriptor_type {0} (file {1}).'
                      .format(self.settings.descriptor_type,
                              self.settings.generalization_setting_file))
                sys.exit()
            if self.settings.descriptor_radial_type != spezified_settings[3]:
                print('ERROR: Spezified descriptor_radial_type {0} is not equal to'
                      .format(spezified_settings[3]),
                      'generalization setting descriptor_radial_type {0} (file {1}).'
                      .format(self.settings.descriptor_radial_type,
                              self.settings.generalization_setting_file))
                sys.exit()
            if self.settings.descriptor_angular_type != spezified_settings[4]:
                print('ERROR: Spezified descriptor_angular_type {0} is not equal to'
                      .format(spezified_settings[4]),
                      'generalization setting descriptor_angular_type {0} (file {1}).'
                      .format(self.settings.descriptor_angular_type,
                              self.settings.generalization_setting_file))
                sys.exit()
            if self.settings.descriptor_scaling_type != spezified_settings[5]:
                print('ERROR: Spezified descriptor_scaling_type {0} is not equal to'
                      .format(spezified_settings[5]),
                      'generalization setting descriptor_scaling_type {0} (file {1}).'
                      .format(self.settings.descriptor_scaling_type,
                              self.settings.generalization_setting_file))
                sys.exit()
            if self.settings.scale_shift_layer != spezified_settings[6]:
                print('ERROR: Spezified scale_shift_layer {0} is not equal to'
                      .format(spezified_settings[6]),
                      'generalization setting scale_shift_layer {0} (file {1}).'
                      .format(self.settings.scale_shift_layer,
                              self.settings.generalization_setting_file))
                sys.exit()
            if self.settings.n_neurons_hidden_layers != spezified_settings[7]:
                print('ERROR: Spezified n_neurons_hidden_layers {0} is not equal to'
                      .format(spezified_settings[7]),
                      'generalization setting n_neurons_hidden_layers {0} (file {1}).'
                      .format(self.settings.n_neurons_hidden_layers,
                              self.settings.generalization_setting_file))
                sys.exit()
            if self.settings.activation_function_type != spezified_settings[8]:
                print('ERROR: Spezified activation_function_type {0} is not equal to'
                      .format(spezified_settings[8]),
                      'generalization setting activation_function_type {0} (file {1}).'
                      .format(self.settings.activation_function_type,
                              self.settings.generalization_setting_file))
                sys.exit()
            if self.settings.dtype_torch != spezified_settings[9]:
                print('ERROR: Spezified dtype_torch {0} is not equal to'
                      .format(spezified_settings[9]),
                      'generalization setting dtype_torch {0} (file {1}).'
                      .format(self.settings.dtype_torch,
                              self.settings.generalization_setting_file))
                sys.exit()
            if self.settings.QMMM != spezified_settings[10]:
                print('ERROR: Spezified QMMM {0} is not equal to'
                      .format(spezified_settings[10]),
                      'generalization setting QMMM {0} (file {1}).'
                      .format(self.settings.QMMM,
                              self.settings.generalization_setting_file))
                sys.exit()
            if self.settings.MM_atomic_charge_max != spezified_settings[11]:
                print('ERROR: Spezified MM_atomic_charge_max {0} is not equal to'
                      .format(spezified_settings[11]),
                      'generalization setting MM_atomic_charge_max {0} (file {1}).'
                      .format(self.settings.MM_atomic_charge_max,
                              self.settings.generalization_setting_file))
                sys.exit()

        return element_types, n_descriptors, ensemble, test_RMSEs

####################################################################################################

    def prepare_periodic_systems(self, elements, positions, lattice, atomic_classes, atomic_charges,
                                 energy, forces, n_atoms):
        '''
        For QM/MM, put system atoms (atomic class 1) before environment atoms (atomic class 2).
        Then prepare the periodic cell with prepare_periodic_cell and replicate the per-atom data
        (and scale the energy) if the cell was expanded into several images. The lattice is
        replaced by an empty array if periodic boundary conditions are not required.

        Return: elements, positions, lattice, atomic_classes, atomic_charges, energy, forces,
                n_atoms, reorder (indices restoring the input order of the unexpanded atoms)
        '''
        if not self.settings.QMMM:
            # without QM/MM every atom is a system atom and the input order is kept
            n_atoms_sys = n_atoms
            reorder = np.arange(n_atoms)
        else:
            # system atoms (class 1) first, followed by environment atoms (class 2)
            atom_indices = np.arange(n_atoms)
            system_atoms = atom_indices[atomic_classes == 1]
            n_atoms_sys = len(system_atoms)
            order = np.concatenate((system_atoms, atom_indices[atomic_classes == 2]))
            elements = elements[order]
            positions = positions[order]
            atomic_charges = atomic_charges[order]
            atomic_classes = atomic_classes[order]
            if len(forces) > 0:
                forces = forces[order]
            # inverse permutation of order
            reorder = np.argsort(order)

        # align center of system atoms and center of cell for QMMM, wrap atoms into original cell,
        # determine if periodic boundary conditions are required, and expand cell if its heights
        # are smaller than the cutoff radius
        positions, lattice, pbc_required, n_images_tot = prepare_periodic_cell(
            positions, lattice, n_atoms, n_atoms_sys, self.R_c)

        if not pbc_required:
            # periodic boundary conditions are not required: drop the lattice
            lattice = np.array([])
        elif n_images_tot > 1:
            # the cell was expanded: replicate the per-atom data and scale the energy accordingly
            elements = np.tile(elements, n_images_tot)
            atomic_classes = np.tile(atomic_classes, n_images_tot)
            atomic_charges = np.tile(atomic_charges, n_images_tot)
            energy *= n_images_tot
            if len(forces) > 0:
                forces = np.tile(forces, (n_images_tot, 1))
            n_atoms *= n_images_tot

        return (elements, positions, lattice, atomic_classes, atomic_charges, energy, forces,
                n_atoms, reorder)

####################################################################################################

    def atomic_type_ordering(self, elements, positions, atomic_classes, atomic_charges, forces,
                             n_structures, n_atoms):
        '''
        Return: elements, positions, atomic_classes, atomic_charges, forces, n_atoms_sys, reorder
        '''
        # order elements, positions, atomic_classes, atomic_charges, and forces by atomic class
        n_forces = len(forces)
        n_atoms_sys = np.empty(n_structures, dtype=int)
        reorder = []
        for n in range(n_structures):
            order_QM = np.arange(n_atoms[n])[atomic_classes[n] == 1]
            order = np.concatenate((order_QM, np.arange(n_atoms[n])[atomic_classes[n] == 2]))
            elements[n] = elements[n][order]
            positions[n] = positions[n][order]
            atomic_classes[n] = atomic_classes[n][order]
            atomic_charges[n] = atomic_charges[n][order]
            if n_forces:
                forces[n] = forces[n][order]
            n_atoms_sys[n] = len(order_QM)
            reorder.append(np.argsort(order))

        return elements, positions, atomic_classes, atomic_charges, forces, n_atoms_sys, reorder

####################################################################################################

    def convert_element_names(self, elements, n_structures, n_atoms_sys):
        '''
        Return: elements_int_sys
        '''
        # initialize list
        elements_int_sys = []

        # create integer arrays for the elements of every structure
        for n in range(n_structures):
            elements_int_sys.append(-np.ones(n_atoms_sys[n], dtype=int))
            # set correct indices for all element types
            for i in range(self.n_element_types):
                elements_int_sys[n][elements[n][:n_atoms_sys[n]] == self.element_types[i]] = i
            if np.any(elements_int_sys[n] < 0):
                print('ERROR: Not all element types are spezified in the settings.',
                      '\nAll specifications: {0}, current structure: {1}.'
                      .format(self.element_types, np.unique(elements[n][:n_atoms_sys[n]])))
                sys.exit()

        return elements_int_sys

####################################################################################################

    def calculate_descriptors(self, elements_int_sys, positions, lattices, atomic_classes,
                              atomic_charges, n_structures, n_atoms, n_atoms_sys, name):
        '''
        Limitation: All elements will have the same set of descriptors

        Implementation: ACSF, eeACSF

        Return: descriptors_torch, descriptor_derivatives_torch, neighbor_indices,
                descriptor_neighbor_derivatives_torch_env, neighbor_indices_env, active_atoms,
                n_atoms_active, MM_gradients
        '''
        # implemented descriptor types
        descriptor_type_list = ['ACSF', 'eeACSF']

        # calculate ACSFs or eeACSFs
        if self.settings.descriptor_type in ('ACSF', 'eeACSF'):
            descriptors_torch, descriptor_derivatives_torch, neighbor_indices, \
                descriptor_neighbor_derivatives_torch_env, neighbor_indices_env, active_atoms, \
                n_atoms_active, MM_gradients = self.calculate_symmetry_function(
                    elements_int_sys, positions, lattices, atomic_classes, atomic_charges, n_structures,
                    n_atoms, n_atoms_sys, name)

        # not implemented descriptor type
        else:
            print('ERROR: Calculating descriptor type {0} is not yet implemented.'
                  .format(self.settings.descriptor_type),
                  '\nPlease use one of the following types:')
            for des_type in descriptor_type_list:
                print('{0}'.format(des_type))
            sys.exit()

        return descriptors_torch, descriptor_derivatives_torch, neighbor_indices, \
            descriptor_neighbor_derivatives_torch_env, neighbor_indices_env, active_atoms, \
            n_atoms_active, MM_gradients

####################################################################################################

    def calculate_atomic_environments(self, elements_int_sys, positions, lattice, atomic_classes,
                                      atomic_charges, n_atoms, n_atoms_sys, calc_derivatives=True,
                                      angular=True):
        '''
        Determine for each system atom i (i < n_atoms_sys) all atoms inside the cutoff sphere of
        radius self.R_c, including atoms of the adjacent periodic images if a lattice is given,
        and collect the pair and, if angular is True, triple properties of these neighbors. If
        self.settings.QMMM is True, additionally determine the active environment atoms and
        their system atom neighbors.

        Output format: neighbor_index: [ [neighbor indices j] ]_atoms
                                       (indices refer to the positions including periodic images;
                                       index % n_atoms is the atom index; the neighbors of each
                                       atom are sorted by atom index)
                       ij lists (one entry per system atom):
                           ij_0: neighboring element integers j (-1 for environment atoms in
                                 QM/MM)
                           ij_1: distances R_ij
                           ij_2: derivatives dR_ij__dalpha_i_
                           ij_3: atomic classes of the neighbors (empty without QM/MM)
                           ij_4: atomic charges of the neighbors (empty without QM/MM)
                       ijk lists (one entry per system atom, only filled if angular is True):
                           ijk_0 to ijk_12 are the 13 entries returned by get_triple_properties.
                           The first eight presumably are [elements_int_j], [elements_int_k],
                           [R_ij], [R_ik], [dR_ij__dalpha_i_], [dR_ik__dalpha_i_],
                           [cos_theta_ijk], [dcos_theta_ijk__dalpha_] (neighboring element
                           integers j and k, distances R_ij and R_ik, their derivatives, angles
                           cos(theta_ijk), and their derivatives dcos(theta_ijk)_/dalpha_i,j,k_).
                           NOTE(review): confirm order and meaning of all 13 entries against
                           get_triple_properties.

                       dR_ij__dalpha_i_: [dR_ij_dx_i, dR_ij_dy_i, dR_ij_dz_i]
                       dcos_theta_ijk__dalpha_: [dct_dx_i, dct_dy_i, dct_dz_i, dct_dx_j, dct_dy_j,
                                                dct_dz_j, dct_dx_k, dct_dy_k, dct_dz_k]

                       active_atom: indices of all system atoms followed by the environment atoms
                                    (atomic class 2) which are neighbors of any system atom
                                    (only the system atoms without QM/MM)
                       neighbor_index_env: for each active environment atom, the sorted indices
                                           of system atoms (including their periodic images)
                                           inside its cutoff sphere; a single empty array if
                                           there are no active environment atoms

        Hint: dR_ij__dalpha_j_ = -dR_ij__dalpha_i_

        Return: neighbor_index, ij_0, ij_1, ij_2, ij_3, ij_4, ijk_0, ijk_1, ijk_2, ijk_3, ijk_4,
                ijk_5, ijk_6, ijk_7, ijk_8, ijk_9, ijk_10, ijk_11, ijk_12, active_atom,
                neighbor_index_env
        '''
        # initialize lists
        neighbor_index = typed.List()
        ij_0 = typed.List()
        ij_1 = typed.List()
        ij_2 = typed.List()
        ij_3 = typed.List()
        ij_4 = typed.List()
        ijk_0 = typed.List()
        ijk_1 = typed.List()
        ijk_2 = typed.List()
        ijk_3 = typed.List()
        ijk_4 = typed.List()
        ijk_5 = typed.List()
        ijk_6 = typed.List()
        ijk_7 = typed.List()
        ijk_8 = typed.List()
        ijk_9 = typed.List()
        ijk_10 = typed.List()
        ijk_11 = typed.List()
        ijk_12 = typed.List()

        # handle periodic systems
        if len(lattice) > 0:
            # get adjacent periodic images (27 copies of the n_atoms positions)
            positions_all = get_periodic_images(positions, lattice, n_atoms)
            indices = np.arange(27 * n_atoms)
        # handle non-periodic systems
        else:
            positions_all = positions
            indices = np.arange(n_atoms)

        # determine neighbor indices lists and pair and triple properties inside the cutoff sphere
        # for each atom
        for i in range(n_atoms_sys):
            # pair properties
            # calculate distance vectors and values
            R_ij_ = positions_all - positions[i]
            R_ij = np.sqrt((R_ij_**2).sum(axis=1))
            # determine neighbor atoms inside cutoff sphere (excluding atom i itself)
            neighbors = np.less(R_ij, self.R_c)
            neighbors[i] = False
            # extract and sort properties for neighbor atoms by atom index
            neighbor_index.append(indices[neighbors])
            order = np.argsort(neighbor_index[-1] % n_atoms)
            neighbor_index[-1] = neighbor_index[-1][order]
            R_ij_ = R_ij_[neighbors][order]
            R_ij = R_ij[neighbors][order]
            # calculate derivatives of R_ij
            dR_ij__dalpha_i_ = -(R_ij_.T / R_ij).T
            # extract properties for QM/MM reference data
            if self.settings.QMMM:
                # system atoms get their element integer, environment atoms get -1
                neighbor_index_sys = neighbor_index[-1][neighbor_index[-1] % n_atoms < n_atoms_sys]
                elements_int_j = np.concatenate((
                    elements_int_sys[neighbor_index_sys % n_atoms], -np.ones(
                        len(neighbor_index[-1]) - len(neighbor_index_sys), dtype=int)))
                interaction_classes_j = atomic_classes[neighbor_index[-1] % n_atoms]
                atomic_charges_j = atomic_charges[neighbor_index[-1] % n_atoms]
            else:
                elements_int_j = elements_int_sys[neighbor_index[-1] % n_atoms]
                interaction_classes_j = np.empty(0, dtype=int)
                atomic_charges_j = np.empty(0, dtype=float)
            # append properties to ij lists
            ij_0.append(elements_int_j)
            ij_1.append(R_ij)
            ij_2.append(dR_ij__dalpha_i_)
            ij_3.append(interaction_classes_j)
            ij_4.append(atomic_charges_j)

            # triple properties
            if angular:
                ijk = get_triple_properties(
                    elements_int_j, R_ij, dR_ij__dalpha_i_, interaction_classes_j, atomic_charges_j,
                    self.settings.QMMM, calc_derivatives=calc_derivatives)
                # append properties to ijk lists
                ijk_0.append(ijk[0])
                ijk_1.append(ijk[1])
                ijk_2.append(ijk[2])
                ijk_3.append(ijk[3])
                ijk_4.append(ijk[4])
                ijk_5.append(ijk[5])
                ijk_6.append(ijk[6])
                ijk_7.append(ijk[7])
                ijk_8.append(ijk[8])
                ijk_9.append(ijk[9])
                ijk_10.append(ijk[10])
                ijk_11.append(ijk[11])
                ijk_12.append(ijk[12])

        # determine active atoms and the system neighbor indices lists of active environment atoms
        neighbor_index_env = typed.List()
        if self.settings.QMMM:
            # determine active atoms; the flattened neighbor indices must keep an integer dtype
            # even if no system atom has any neighbor (an empty np.array([]) would be float and
            # could not be used to index atomic_classes)
            if len(neighbor_index) > 0:
                neighbor_index_all = np.concatenate(list(neighbor_index))
            else:
                neighbor_index_all = np.empty(0, dtype=int)
            active_atom_env = np.unique(neighbor_index_all % n_atoms)
            active_atom_env = active_atom_env[atomic_classes[active_atom_env] == 2]
            active_atom = np.concatenate((np.arange(n_atoms_sys), active_atom_env))

            # handle periodic systems
            if len(lattice) > 0:
                # get adjacent periodic images of system atoms
                positions_sys_all = get_periodic_images(positions[:n_atoms_sys], lattice, n_atoms_sys)
                indices_sys = np.arange(27 * n_atoms_sys)
            # handle non-periodic systems
            else:
                positions_sys_all = positions[:n_atoms_sys]
                indices_sys = np.arange(n_atoms_sys)

            # determine system neighbor indices lists of active environment atoms
            if len(active_atom_env) > 0:
                for i in active_atom_env:
                    # calculate distances to system atoms
                    R_ij_env = np.sqrt(((positions_sys_all - positions[i])**2).sum(axis=1))
                    # determine neighbor atoms inside cutoff sphere
                    neighbors_env = np.less(R_ij_env, self.R_c)
                    # append neighbor indices lists of active environment atoms
                    neighbor_index_env.append(indices_sys[neighbors_env])
                    order_env = np.argsort(neighbor_index_env[-1] % n_atoms_sys)
                    neighbor_index_env[-1] = neighbor_index_env[-1][order_env]
            else:
                neighbor_index_env.append(np.arange(0))

        # set all atoms to be active if QM/MM is false
        else:
            active_atom = np.arange(n_atoms_sys)
            neighbor_index_env.append(np.arange(0))

        return neighbor_index, ij_0, ij_1, ij_2, ij_3, ij_4, ijk_0, ijk_1, ijk_2, ijk_3, ijk_4, \
            ijk_5, ijk_6, ijk_7, ijk_8, ijk_9, ijk_10, ijk_11, ijk_12, active_atom, \
            neighbor_index_env

####################################################################################################

    def get_H_parameters(self, H_type_j, H_type_jk):
        '''
        Implementation: Elements: H to Xe
                        Element parameters: 1, n, sp, d, 6-n, 9-sp, 11-d (all parameters need to be
                                            specified for all elements)
                        Exceptions: He: sp == 8 and 9-sp == 1, main group: d == 0 and 11-d == 0,
                                    d-block: sp == 0 and 9-sp == 0

        Return: H_parameters_rad, H_parameters_ang, n_H_parameters
        '''
        # implemented H parameters
        H_parameters = {'H':  [1, 1, 1,  0, 5, 8,  0],
                        'He': [1, 1, 8,  0, 5, 1,  0],
                        'Li': [1, 2, 1,  0, 4, 8,  0],
                        'Be': [1, 2, 2,  0, 4, 7,  0],
                        'B':  [1, 2, 3,  0, 4, 6,  0],
                        'C':  [1, 2, 4,  0, 4, 5,  0],
                        'N':  [1, 2, 5,  0, 4, 4,  0],
                        'O':  [1, 2, 6,  0, 4, 3,  0],
                        'F':  [1, 2, 7,  0, 4, 2,  0],
                        'Ne': [1, 2, 8,  0, 4, 1,  0],
                        'Na': [1, 3, 1,  0, 3, 8,  0],
                        'Mg': [1, 3, 2,  0, 3, 7,  0],
                        'Al': [1, 3, 3,  0, 3, 6,  0],
                        'Si': [1, 3, 4,  0, 3, 5,  0],
                        'P':  [1, 3, 5,  0, 3, 4,  0],
                        'S':  [1, 3, 6,  0, 3, 3,  0],
                        'Cl': [1, 3, 7,  0, 3, 2,  0],
                        'Ar': [1, 3, 8,  0, 3, 1,  0],
                        'K':  [1, 4, 1,  0, 2, 8,  0],
                        'Ca': [1, 4, 2,  0, 2, 7,  0],
                        'Sc': [1, 4, 2,  1, 2, 7, 10],
                        'Ti': [1, 4, 2,  2, 2, 7,  9],
                        'V':  [1, 4, 2,  3, 2, 7,  8],
                        'Cr': [1, 4, 2,  4, 2, 7,  7],
                        'Mn': [1, 4, 2,  5, 2, 7,  6],
                        'Fe': [1, 4, 2,  6, 2, 7,  5],
                        'Co': [1, 4, 2,  7, 2, 7,  4],
                        'Ni': [1, 4, 2,  8, 2, 7,  3],
                        'Cu': [1, 4, 2,  9, 2, 7,  2],
                        'Zn': [1, 4, 2, 10, 2, 7,  1],
                        'Ga': [1, 4, 3,  0, 2, 6,  0],
                        'Ge': [1, 4, 4,  0, 2, 5,  0],
                        'As': [1, 4, 5,  0, 2, 4,  0],
                        'Se': [1, 4, 6,  0, 2, 3,  0],
                        'Br': [1, 4, 7,  0, 2, 2,  0],
                        'Kr': [1, 4, 8,  0, 2, 1,  0],
                        'Rb': [1, 5, 1,  0, 1, 8,  0],
                        'Sr': [1, 5, 2,  0, 1, 7,  0],
                        'Y':  [1, 5, 2,  1, 1, 7, 10],
                        'Zr': [1, 5, 2,  2, 1, 7,  9],
                        'Nb': [1, 5, 2,  3, 1, 7,  8],
                        'Mo': [1, 5, 2,  4, 1, 7,  7],
                        'Tc': [1, 5, 2,  5, 1, 7,  6],
                        'Ru': [1, 5, 2,  6, 1, 7,  5],
                        'Rh': [1, 5, 2,  7, 1, 7,  4],
                        'Pd': [1, 5, 2,  8, 1, 7,  3],
                        'Ag': [1, 5, 2,  9, 1, 7,  2],
                        'Cd': [1, 5, 2, 10, 1, 7,  1],
                        'In': [1, 5, 3,  0, 1, 6,  0],
                        'Sn': [1, 5, 4,  0, 1, 5,  0],
                        'Sb': [1, 5, 5,  0, 1, 4,  0],
                        'Te': [1, 5, 6,  0, 1, 3,  0],
                        'I':  [1, 5, 7,  0, 1, 2,  0],
                        'Xe': [1, 5, 8,  0, 1, 1,  0]}
        H_parameters_max = [1, 5, 8, 10, 5, 8, 10]

        # determine H parameters
        n_H_parameters = len(H_parameters['H']) - 1
        H_parameters_rad = np.array([[H_parameters[ele][H_t_j] for H_t_j in H_type_j]
                                    for ele in self.element_types])
        H_parameters_rad_scale = np.array([1.0 / H_parameters_max[H_t_j] for H_t_j in H_type_j])
        H_parameters_ang = np.array([[H_parameters[ele][H_t_jk - n_H_parameters]
                                    if H_t_jk > n_H_parameters else H_parameters[ele][H_t_jk]
                                    for H_t_jk in H_type_jk] for ele in self.element_types])
        H_parameters_ang_scale = np.array([1.0 / H_parameters_max[H_t_jk - n_H_parameters]
                                          if H_t_jk > n_H_parameters else 0.5 / H_parameters_max[H_t_jk]
                                          for H_t_jk in H_type_jk])

        return H_parameters_rad, H_parameters_ang, H_parameters_rad_scale, H_parameters_ang_scale, \
            n_H_parameters

####################################################################################################

    def define_activation_function(self):
        '''
        Implementation: Activation function: sTanh, Tanh, Tanhshrink, SiLU

        Return: activation_function
        '''
        # implemented activation function types
        activation_function_type_list = ['sTanh', 'Tanh', 'Tanhshrink', 'SiLU']

        # activation function type sTanh
        if self.settings.activation_function_type == 'sTanh':
            activation_function = sTanh

        # activation function type Tanh
        elif self.settings.activation_function_type == 'Tanh':
            activation_function = torch.nn.Tanh

        elif self.settings.activation_function_type == 'SiLU':
            activation_function = torch.nn.SiLU

        # activation function type Tanhshrink
        elif self.settings.activation_function_type == 'Tanhshrink':
            activation_function = torch.nn.Tanhshrink

        # not implemented activation function type
        else:
            print('ERROR: Activation function type {0} is not yet implemented.'
                  .format(self.settings.activation_function_type),
                  '\nPlease use one of the following activation function types:')
            for act_fct_typ in activation_function_type_list:
                print('{0}'.format(act_fct_typ))
            sys.exit()

        return activation_function

####################################################################################################

    def define_model(self, model_index=0):
        '''
        Implementation: Model: HDNNP (with/without scale and shift layer)

        Modify: model
        '''
        # implemented model types
        model_type_list = ['HDNNP']

        # HDNNP model
        if self.settings.model_type == 'HDNNP':
            self.model[model_index] = torch.jit.script(HDNNP(
                self.n_element_types, self.n_descriptors, self.settings.n_neurons_hidden_layers,
                self.activation_function, self.settings.scale_shift_layer).to(self.device))

        # not implemented model type
        else:
            print('ERROR: Model type {0} is not yet implemented.'
                  .format(self.settings.model_type),
                  '\nPlease use one of the following model types:')
            for mod_typ in model_type_list:
                print('{0}'.format(mod_typ))
            sys.exit()

####################################################################################################

    def write_inputdata(self, inputdata_file, elements, positions, lattices, atomic_classes,
                        atomic_charges, energy, forces, n_structures, n_atoms, reorder, name=None,
                        assignment=None, uncertainty=None, counter=1, mode='w'):
        '''
        Output: inputdata file
        '''
        # write inputdata file
        with open(inputdata_file, mode, encoding='utf-8') as f:
            for n in range(n_structures):
                f.write('begin\n')
                f.write('comment number {0}\n'.format(counter))
                if name:
                    f.write('comment name {0}\n'.format(name[n]))
                if assignment:
                    f.write('comment data {0}\n'.format(assignment[n]))
                if uncertainty:
                    n_uncertainty = len(uncertainty[n])
                    f.write('comment uncertainty energy {0:.8f}'.format(
                        round(uncertainty[n][0] / self.Hartree2eV, 8)))
                    if n_uncertainty > 1:
                        f.write(' forces_max {0:.6f}'.format(
                            round(uncertainty[n][1] / self.Hartree2eV * self.Bohr2Angstrom, 6)))
                    if n_uncertainty > 2:
                        f.write(' forces_max_QMMM {0:.6f}\n'.format(
                            round(uncertainty[n][2] / self.Hartree2eV * self.Bohr2Angstrom, 6)))
                    else:
                        f.write('\n')
                if len(lattices[n]) > 0:
                    for i in range(3):
                        f.write('lattice {0:>10.6f} {1:>10.6f} {2:>10.6f}\n'
                                .format(round(lattices[n][i][0] / self.Bohr2Angstrom, 6),
                                        round(lattices[n][i][1] / self.Bohr2Angstrom, 6),
                                        round(lattices[n][i][2] / self.Bohr2Angstrom, 6)))
                element = elements[n]
                position = positions[n] / self.Bohr2Angstrom
                atomic_class = atomic_classes[n].astype(float)
                atomic_charge = atomic_charges[n]
                force = forces[n] / self.Hartree2eV * self.Bohr2Angstrom
                if self.settings.QMMM:
                    element = element[reorder[n]]
                    position = position[reorder[n]]
                    atomic_class = atomic_class[reorder[n]]
                    atomic_charge = atomic_charge[reorder[n]]
                    force = force[reorder[n]]
                for i in range(n_atoms[n]):
                    f.write('atom {0:>10.6f} {1:>10.6f} {2:>10.6f} {3:2} {4:3.1f} {5:>6.3f} '
                            .format(round(position[i][0], 6), round(position[i][1], 6),
                                    round(position[i][2], 6), element[i], atomic_class[i],
                                    round(atomic_charge[i], 3)))
                    f.write('{0:>10.6f} {1:>10.6f} {2:>10.6f}\n'.format(
                        round(force[i][0], 6), round(force[i][1], 6), round(force[i][2], 6)))
                f.write('energy {0:.8f}\ncharge 0.0\nend\n'.format(
                    round(energy[n] / self.Hartree2eV, 8)))
                counter += 1
