#!/usr/bin/python3

from os import environ
from types import SimpleNamespace
import argparse
# Command-line arguments; their values are inserted into lMLP_settings below.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--activation_function_type",
    help="activation function type passed to lMLP (e.g. 'sTanh', 'Tanh', 'Tanhshrink')")
parser.add_argument(
    "--weight_initialization",
    help="weight initialization scheme passed to lMLP (e.g. 'sTanh', 'Tanh', 'Tanhshrink', 'default')")
# seed and n_neurons_hidden_layers are required: the script calls torch.manual_seed(args.seed) and
# max(args.n_neurons_hidden_layers), which both fail on the argparse default None.
parser.add_argument(
    "--seed", type=int, required=True,
    help="seed of the NumPy and PyTorch random number generators")
parser.add_argument(
    "--n_neurons_hidden_layers", nargs="+", type=int, required=True,
    help="number of neurons in each hidden layer")
args = parser.parse_args()


# Settings of the lifelong machine learning potential (passed to lmlp.lMLP below); the entries
# n_neurons_hidden_layers, activation_function_type, seed, and weight_initialization are taken
# from the command-line arguments
lMLP_settings = SimpleNamespace(
    # Input and output settings (all strings)
    # format of the episodic memory input file
    # value: 'inputdata'
    episodic_memory_format = 'inputdata',
    # path/name of the episodic memory input file
    episodic_memory_file = 'episodic_memory/input.data_B',
    # type of the supplemental potential
    # value: 'element_energy' (recommended), 'MieRc'
    supplemental_potential_type = 'element_energy',
    # path/name of the supplemental potential input file
    supplemental_potential_file = 'supplemental_potential/element_energy-PBED3_def2TZVP.dat',
    # type of the descriptor
    # value: 'eeACSF' (recommended), 'ACSF'
    descriptor_type = 'eeACSF',
    # path/name of the descriptor parameter input file
    descriptor_parameter_file = 'descriptor/eeACSF_bump_cos_int_linear_12.dat',
    # format of the generalization setting file
    # value: 'lMLP'
    generalization_setting_format = 'lMLP',
    # path/name of the generalization setting file
    generalization_setting_file = 'generalization/lMLP_setting_000.ini',
    # format of the generalization output file(s)
    # value: 'lMLP' (recommended), 'lMLP-only_prediction', 'RuNNer'
    generalization_format = 'lMLP',
    # path/name of the generalization output file (a directory if generalization_format == 'RuNNer')
    generalization_file = 'generalization/lMLP_model_000.pt',
    # format of the prediction output file
    # value: 'inputdata'
    prediction_format = 'inputdata',
    # path/name of the prediction output file
    prediction_file = 'output_000.data',

    # Training and prediction settings
    # chemical element types (only those of QM atoms for QM/MM)
    # value: list of strings
    element_types = ['H', 'C', 'N', 'O', 'F', 'S', 'Cl', 'Se', 'Br', 'I'],
    # machine learning model type
    # value: 'HDNNP'
    model_type = 'HDNNP',
    # maximal atomic force length in eV/Angstrom (structures are disregarded otherwise)
    # value: float (atomic_force_max > 0)
    atomic_force_max = 20.0,
    # disable fitting
    # value: boolean
    prediction_only = False,
    # enable supplemental pair potential contributions (atomic energies are always enabled)
    # value: boolean (recommended: False)
    pair_contributions = False,
    # enable QM/MM reference data (electrostatic embedding)
    # value: boolean
    QMMM = False,
    # maximal absolute atomic charge of MM atoms in atomic units (required in descriptor calculation,
    # structures are disregarded otherwise)
    # value: float (MM_atomic_charge_max > 0)
    MM_atomic_charge_max = 2.5,
    # enable preparation of training for transfer learning
    # value: boolean
    transfer_learning = False,
    # enable scaling and shifting of input values
    # value: boolean (recommended: True)
    scale_shift_layer = True,
    # number of neurons in each hidden layer (command-line argument --n_neurons_hidden_layers)
    # value: list of integers (integers > 0)
    n_neurons_hidden_layers = args.n_neurons_hidden_layers,
    # activation function type (command-line argument --activation_function_type)
    # value: 'sTanh' (recommended), 'Tanh', 'Tanhshrink'
    activation_function_type = args.activation_function_type,
    # combined radial and cutoff function type in descriptor
    # value: 'bump' (recommended), 'gaussian_bump', 'gaussian_cos'
    descriptor_radial_type = 'bump',
    # angular function type in descriptor
    # value: 'bump' (recommended), 'cos', 'cos_int'
    descriptor_angular_type = 'cos_int',
    # scaling function type in descriptor
    # value: 'crss' (recommended), 'sqrt', 'linear'
    descriptor_scaling_type = 'crss',
    # 'None': handling of descriptors in memory, 'write': (over)writing and reading of descriptor
    # derivatives and neighbor indices on/from disk, 'append': writing of nonexisting and reading of
    # preexisting descriptor derivatives and neighbor indices on/from disk
    # value: 'None' (recommended), 'write', 'append'
    descriptor_on_disk = 'None',
    # directory of the descriptor binary files
    # value: string
    descriptor_disk_dir = '',
    # None: calculate descriptor values and run all checks, string: file name of descriptor cache
    # (either the file is read and its completeness is trusted or it is written if it is not existing)
    # value: None (recommended), string
    descriptor_cache_file = None,
    # None: descriptor output is not written, string: file name of descriptor output
    # value: None (recommended), string
    descriptor_output_file = None,
    # enable evaluation of maximal memory consumption
    # value: boolean (recommended: True)
    memory_evaluation = True,
    # floating point format used by PyTorch
    # value: 'double' (recommended), 'float'
    dtype_torch = 'double',
    # device used by PyTorch
    # value: 'cpu' (recommended), 'cuda'
    device = 'cpu',
    # energy unit in output
    # value: 'eV' (recommended), 'Hartree', 'kJ_mol'
    energy_unit = 'eV',
    # length unit in output
    # value: 'Angstrom' (recommended), 'Bohr', 'pm'
    length_unit = 'Angstrom',

    # Additional training settings
    # seed of the NumPy and PyTorch random number generators (command-line argument --seed)
    # value: integer
    seed = args.seed,
    # enable restart from old weights
    # value: boolean
    restart = False,
    # weight initialization scheme (command-line argument --weight_initialization)
    # value: 'sTanh' (recommended), 'Tanh', 'Tanhshrink', 'default'
    weight_initialization = args.weight_initialization,
    # enable fitting of the forces
    # value: boolean (recommended: True)
    fit_forces = True,
    # enable only energy preoptimization step in each epoch
    # value: boolean (recommended: True)
    energy_preoptimization_step = True,
    # fraction of reference structures used for training
    # value: float (0 < training_fraction <= 1)
    training_fraction = 0.9,
    # fraction of training structures used in each training epoch
    # value: float (0 < fit_fraction <= 1)
    fit_fraction = 0.1,
    # optimizer
    # value: 'CoRe' (recommended), 'Adam', 'Rprop', etc.
    optimizer = 'CoRe',
    # number of training epochs
    # value: integer (n_epochs >= 0)
    n_epochs = 2000,
    # learning rate value in the optimizer
    # value: float (learning_rate > 0)
    learning_rate = 0.001,
    # step size parameters in the CoRe and Rprop optimizers
    # value: tuple of floats (float > 0, float > 0)
    step_sizes = (1e-6, 0.01),
    # eta parameters in the CoRe optimizer
    # value: tuple of floats (0 < float <= 1, float >= 1)
    etas = (0.55, 1.2),
    # beta parameters in the CoRe optimizer
    # value: tuple of floats (0 <= float < 1, 0 <= float < 1, float > 0, float < 1)
    betas = (0.7375, 0.8125, 250.0, 0.99),
    # weight decay parameter in the CoRe optimizer
    # value: float (0 <= weight_decay < 1)
    weight_decay = 0.1,
    # score history parameter in the CoRe optimizer
    # value: float (score_history >= 0)
    score_history = 250,
    # fraction of frozen weights to each hidden layer according to the score history in the CoRe optimizer
    # value: float (0 <= frozen < 1)
    frozen = 0.025,
    # enable foreach implementation in the CoRe optimizer
    # value: boolean (recommended: True)
    foreach = True,
    # freeze weights of given layer indices (scale_shift_layer: index 0, hidden layers: index
    # counting starts at 1 (even if scale_shift_layer == False), output layer: N_hidden_layers + 1)
    # value: list of integers (recommended: [])
    frozen_layers = [],
    # loss function
    # value: 'MSELoss' (recommended), 'HuberLoss', 'SmoothL1Loss', 'L1Loss'
    loss_function = 'MSELoss',
    # loss parameter for energy loss and forces loss in 'HuberLoss' and 'SmoothL1Loss'
    # value: tuple of floats (float >= 0, float >= 0)
    loss_parameters = (1.0, 1.0),
    # loss(E) is multiplied by loss_E_scaling and loss(F) is not scaled
    # value: float (loss_E_scaling > 0)
    loss_E_scaling = 10.9**2,
    # selection scheme for the structures to be fitted if fit_fraction < 1
    # value: 'lADS' (recommended), 'random'
    selection_scheme = 'lADS',
    # quantities to be monitored to determine the probability factor of a structure in selection scheme 'lADS'
    # value: 'E+F_losses' (recommended), 'total_loss'
    selection_measure = 'total_loss',
    # minimum and maximum probability factor of a structure in selection scheme 'lADS'
    # value: tuple of floats (0 < float < 1, float > 1)
    selection_range = (0.1, 100.0),
    # error region boundaries are obtained by multiplying the error in the loss function by the thresholds
    # in selection scheme 'lADS'
    # value: tuple of floats (float > 0, float > 0, float > 0, float > 0)
    selection_thresholds = (0.9, 2.0, 3.0, 6.0),
    # number of subsequent de-/increases yielding min./max. probability factor in selection scheme 'lADS'
    # value: tuple of integers (integer > 0, integer > 0)
    selection_strikes = (15, 100),
    # number of subsequent small de-/increases yielding min./max. probability factor in selection scheme 'lADS'
    # value: tuple of integers (integer > 0, integer > 0)
    selection_small_strikes = (45, 300),
    # number of subsequent high errors leading to exclusion of a structure in selection scheme 'lADS'
    # value: integer (exclusion_strikes > 0)
    exclusion_strikes = 5,
    # maximal fraction of new redundant training data in each epoch in selection scheme 'lADS'
    # value: float (0 < fraction_redundant_max <= 1)
    fraction_redundant_max = 0.015,
    # enable backtracking of gradients for deselected training structures in selection scheme 'lADS'
    # (only available for the optimizer 'CoRe')
    # value: boolean (recommended: True)
    gradient_backtracking = True,
    # selection probability increase factor of structures with very small mean atomic force lengths
    # compared to those with maximum mean atomic force lengths (atomic_force_max) in selection scheme 'lADS'
    # value: float (stationary_point_prob_factor >= 1)
    stationary_point_prob_factor = 1.0,
    # maximal fraction of good training data in each epoch in selection scheme 'lADS'
    # value: float (0 < fraction_good_max < 1)
    fraction_good_max = 2.0 / 3.0,
    # number of changes between zero and the maximum fraction of good data in selection scheme 'lADS'
    # value: integer (n_fraction_intervals > 0)
    n_fraction_intervals = 20,
    # enable writing of new, bad, and redundant training data and test data in selection scheme 'lADS'
    # value: boolean (recommended: False)
    write_new_episodic_memory = False,
    # enable printing of bad training data names in output file in selection scheme 'lADS'
    # value: boolean (recommended: True)
    print_bad_data_names = True,
    # scheme for adding data later in the fit (only available for selection scheme 'lADS')
    # value: 'None' (recommended), 'bottom', 'random', 'top'
    late_data_scheme = 'None',
    # fraction of data which is added later in the fit
    # value: float (0 <= late_data_fraction < 1)
    late_data_fraction = 0.0,
    # epoch in which the late data is added (0: no addition)
    # value: integer (0 <= late_data_epoch <= n_epochs)
    late_data_epoch = 0,
    # training epoch interval of RMSE calculations
    # value: integer (RMSE_interval > 0)
    RMSE_interval = 10,
    # integer: interval of training epochs in which generalization files are written
    # (write_weights_interval has to be a multiple of RMSE_interval), None: no writing
    # value: integer (write_weights_interval > 0), None
    write_weights_interval = None,
    # enable writing of prediction output file
    # value: boolean
    write_prediction = True,

    # Additional prediction settings
    # enable RMSE calculation of all predicted data
    # value: boolean
    prediction_RMSE = True,

    # Computational settings
    # number of threads used by Numba (exported as NUMBA_NUM_THREADS unless that is already set)
    # value: integer (n_threads >= 1)
    n_threads = 1)

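# Environment variables
# NUMBA_NUM_THREADS overrides the Numba number of threads specified by n_threads
# NUMBA_JIT=0 disables all Numba jit compilation
# PYTORCH_JIT=0 disables all PyTorch jit compilation
# command: python3 -u input_lmlp.py --activation_function_type <type> --weight_initialization <scheme> \
#     --seed <integer> --n_neurons_hidden_layers <n_1> [<n_2> ...] > output_000.dat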

####################################################################################################

####################################################################################################

# Export the Numba thread count before lmlp is imported (presumably lmlp imports Numba, which
# reads the variable at import time); a value already present in the environment takes precedence.
environ.setdefault('NUMBA_NUM_THREADS', str(lMLP_settings.n_threads))

import lmlp
lMLP = lmlp.lMLP(lMLP_settings)

import torch
import scipy
import numpy as np

def get_effective_rank(matrix, return_singular_values=False):
    """Return the effective rank of ``matrix``.

    The effective rank (see https://ieeexplore.ieee.org/abstract/document/7098875) is
    ``exp(H(p))``, where ``p`` are the singular values normalized to sum to one and ``H`` is their
    Shannon entropy with the natural logarithm.

    Args:
        matrix: torch tensor whose singular values are analysed (in this script, a square
            activation map of one layer). It may track gradients and may live on any device.
        return_singular_values: if True, also return the unnormalized singular values.

    Returns:
        The effective rank as a NumPy floating-point scalar. NaN is mapped to 0, which happens
        for example for an all-zero matrix, whose normalization divides 0 by 0. If
        ``return_singular_values`` is True, the tuple ``(effective_rank, singular_values)`` is
        returned instead, with the singular values as a detached tensor.
    """
    # The result is never back-propagated, so do not let svdvals build an autograd graph for it
    # (the hook-recorded activations may still be attached to one).
    S = torch.linalg.svdvals(matrix.detach())
    singular_values = S.clone() if return_singular_values else None
    p = S / torch.sum(S)
    # SciPy needs a host NumPy array; .cpu() keeps this working for tensors on other devices.
    erank = np.nan_to_num(torch.e ** scipy.stats.entropy(p.cpu().numpy()))
    if return_singular_values:
        return erank, singular_values
    return erank


# Read the reference data (episodic memory) and map the element symbols to integers.
(elements, positions, lattices, atomic_classes, atomic_charges, energy, forces, n_structures,
 n_atoms, name) = lMLP.read_episodic_memory()
n_atoms_sys = n_atoms
elements_int_sys = lMLP.convert_element_names(elements, n_structures, n_atoms_sys)

# Descriptors, their derivatives with respect to the Cartesian coordinates, and neighbor indices,
# followed by the remaining values returned by calculate_descriptors.
(descriptors_torch, descriptor_derivatives_torch, neighbor_indices,
 descriptor_neighbor_derivatives_torch_env, neighbor_indices_env, active_atoms,
 n_atoms_active, MM_gradients) = lMLP.calculate_descriptors(
    elements_int_sys, positions, lattices, atomic_classes, atomic_charges, n_structures,
    n_atoms, n_atoms_sys, name)
print("Loaded")

# Split the structures into training and test sets, then initialize the scale/shift layer and the
# weights using the training split.
(train, test, n_structures_train, n_structures_test,
 assignment) = lMLP.define_train_test_splitting(name, n_structures)
lMLP.initialize_scale_shift_layer(train, elements_int_sys, descriptors_torch)
lMLP.initialize_weights(train, elements_int_sys, energy, n_structures_train, n_atoms_sys)
print("Init done")

# activations[network_id][layer_id] is the buffer of outputs recorded by the forward hook of layer
# layer_id in atomic network network_id (filled when the hooks are registered / reset).
activations = []
def get_activation(network_id, layer_id):
    """Return a forward hook that records layer outputs into ``activations[network_id][layer_id]``.

    Each call appends the (detached) output as a new row of the buffer. Once the buffer has more
    rows than columns, a random subset of ``shape[1]`` rows is kept via ``torch.randperm`` (global
    torch RNG), so the buffer never grows beyond a square shape. This is not a uniform reservoir
    sample: every stored row survives each truncation with the same probability, so recently
    added rows are more likely to be retained than older ones.

    The buffer is looked up through the module-level name ``activations`` at call time, so
    rebinding that name to fresh buffers takes effect for already registered hooks.
    """
    def hook(module, inputs, output):
        # Detach: storing the raw output would keep the autograd graph of every recorded forward
        # pass alive for as long as the buffer exists.
        buffer = torch.cat((activations[network_id][layer_id], output.detach().unsqueeze(0)), dim=0)
        n_rows, n_cols = buffer.shape[0], buffer.shape[1]
        if n_rows > n_cols:
            buffer = buffer[torch.randperm(n_rows)[:n_cols]]
        activations[network_id][layer_id] = buffer
    return hook

# Give every layer of every atomic network an empty activation buffer and register a recording
# hook on it; the handles are collected in ``hooks`` so the hooks can be removed afterwards.
hooks = []
for network_id, network in enumerate(lMLP.model[0].atomic_neural_networks):
    activations.append([torch.tensor([]) for _ in network])
    hooks.extend(
        layer.register_forward_hook(get_activation(network_id, layer_id))
        for layer_id, layer in enumerate(network)
    )

# Averaged effective rank per element index and per layer (the last layer of each network is
# excluded); the layer count is taken from the first network.
ranks = {
    element: [0] * len(activations[0][:-1]) for element in range(len(lMLP.get_element_types()))
}

# Re-seed torch's global RNG so that the structure permutation below and the row subsampling in the
# forward hooks start from the same state in every run with the same --seed.
# NOTE(review): lMLP.initialize_weights runs before each permutation; if it draws from the torch
# RNG, the number of draws depends on the network size, so the sampled structures may still differ
# between network sizes -- confirm.
torch.manual_seed(args.seed)
NUM_REPEATS = 400
# Number of recorded atoms per repeat: at least the widest layer (descriptor input or largest
# hidden layer), presumably so that every recorded activation map can fill up to a square matrix.
n_samples_needed = (
    max(descriptors_torch[0][0].shape[0], max(args.n_neurons_hidden_layers))
    if len(descriptors_torch) > 0 else 0
)
for repeat in range(NUM_REPEATS):
    # Fresh random weights in every repeat; the ranks are averaged over all repeats.
    lMLP.initialize_weights(train, elements_int_sys, energy, n_structures_train, n_atoms_sys)
    indices = torch.randperm(len(descriptors_torch))[:10_000]
    count = 0
    for i in indices:
        # Only atoms with element index 0 are fed through the model, so presumably only that
        # element's network records activations and the ranks of the other elements stay 0.
        subsystem = np.where(elements_int_sys[i] == 0)[0]
        if len(subsystem) > 0:
            idx = np.random.choice(subsystem)
            lMLP.model[0]([elements_int_sys[i][idx]], descriptors_torch[i][idx].unsqueeze(dim=0), 1)
            count += 1
        if count >= n_samples_needed:
            break

    for element in range(len(lMLP.get_element_types())):
        # The last recorded layer is excluded, matching the layout of ``ranks``.
        for i, activation in enumerate(activations[element][:-1]):
            if activation.numel() == 0:
                continue
            # Explicit check rather than assert: asserts are stripped under ``python -O`` and a
            # non-square map would then silently enter the average.
            if activation.shape[0] != activation.shape[1]:
                raise RuntimeError(f"Activation map for {element=} and {i=} is not quadratic!")
            ranks[element][i] += get_effective_rank(activation) / NUM_REPEATS

    # Empty buffers for the next repeat; the hooks look up the module-level name ``activations``
    # at call time, so rebinding it is sufficient.
    activations = [
        [torch.tensor([]) for _ in network] for network in lMLP.model[0].atomic_neural_networks
    ]

print(ranks)
for hook in hooks:
    hook.remove()

lMLP.run()
