#!/usr/bin/python3

'''
Lifelong Machine Learning Potentials (lMLP)
'''
__copyright__ = '''This code is licensed under the 3-clause BSD license.
Copyright ETH Zurich, Department of Chemistry and Applied Biosciences, Reiher Group.
See LICENSE.txt for details.'''

from glob import glob
from io import StringIO
from pathlib import Path
from re import findall, search
from shutil import copy, copytree
from types import SimpleNamespace
import pytest
import lmlp

# absolute path of the directory that contains this test file (used to locate the test resources
# and the examples directory)
TESTS_DIR = str(Path(__file__).parent.absolute())


@pytest.fixture(autouse=True)
def copy_pretrained_models(tmp_path):
    '''
    Fixture that copies the pretrained models and the episodic memory test files to the temporary
    path before a test is started (it is applied automatically due to autouse=True).

    tmp_path: temporary directory provided by pytest
    '''
    copytree(f'{TESTS_DIR}/resources/lmlp/generalization', f'{tmp_path}/generalization')
    Path(f'{tmp_path}/episodic_memory').mkdir(parents=True)
    for epi_mem_file in glob(f'{TESTS_DIR}/../examples/episodic_memory/' + r'input.data_*_test'):
        # Path.name yields the file name independent of the path separator of the platform
        # (splitting at '/' fails for paths containing backslashes)
        copy(epi_mem_file, f'{tmp_path}/episodic_memory/{Path(epi_mem_file).name}')


@pytest.fixture(name="lmlp_settings")
def fixture_lmlp_settings(tmp_path, request):
    '''
    Fixture that builds the settings of an lMLP object and returns them as a SimpleNamespace object.

    tmp_path: temporary directory provided by pytest
    request: pytest request object, request.param is the settings tuple of the parametrization
             (its entries are accessed by their index)
    '''
    param = request.param
    # path prefixes which are shared by several settings
    examples_dir = f'{TESTS_DIR}/../examples'
    generalization_base = f'{tmp_path}/generalization/{param[0]}'
    # the descriptor parameter file name is composed of the descriptor type, the radial and angular
    # function types, and the file name suffix given in the settings tuple
    descriptor_parameter_name = f'{param[4]}_{param[11]}_{param[12]}_6{param[5]}.dat'

    return SimpleNamespace(
        # Input and output settings (all strings)
        # format of the episodic memory input file
        # value: 'inputdata'
        episodic_memory_format='inputdata',
        # path/name of the episodic memory input file
        episodic_memory_file=f'{tmp_path}/episodic_memory/input.data_{param[2]}_test',
        # type of the supplemental potential
        # value: 'element_energy' (recommended), 'MieRc'
        supplemental_potential_type=param[3].split('-')[0],
        # path/name of the supplemental potential input file
        supplemental_potential_file=f'{examples_dir}/supplemental_potential/{param[3]}.dat',
        # type of the descriptor
        # value: 'eeACSF' (recommended), 'ACSF'
        descriptor_type=param[4],
        # path/name of the descriptor parameter input file
        descriptor_parameter_file=f'{examples_dir}/descriptor/{descriptor_parameter_name}',
        # format of the generalization setting file
        # value: 'lMLP'
        generalization_setting_format='lMLP',
        # path/name of the generalization setting file
        generalization_setting_file=f'{generalization_base}.ini',
        # format of the generalization output file(s)
        # value: 'lMLP' (recommended), 'lMLP-only_prediction', 'RuNNer'
        generalization_format=param[6],
        # path/name(dir) of the generalization output file(dir) (dir for generalization_format == 'RuNNer')
        generalization_file=f'{generalization_base}.pt',
        # format of the prediction output file
        # value: 'inputdata'
        prediction_format='inputdata',
        # path/name of the prediction output file
        prediction_file=f'{tmp_path}/output_000.data',

        # Training and prediction settings
        # chemical element types (only those of QM atoms for QM/MM)
        # value: list of strings
        element_types=param[7],
        # machine learning model type
        # value: 'HDNNP'
        model_type='HDNNP',
        # maximal atomic force length in eV/Angstrom (structures are disregarded otherwise)
        # value: float (atomic_force_max > 0)
        atomic_force_max=15.0,
        # disable fitting
        # value: boolean
        prediction_only=param[1],
        # enable supplemental pair potential contributions (atomic energies are always enabled)
        # value: boolean (recommended: False)
        pair_contributions=False,
        # enable QM/MM reference data (electrostatic embedding)
        # value: boolean
        QMMM=param[8],
        # maximal absolute atomic charge of MM atoms in atomic units (required in descriptor calculation,
        # structures are disregarded otherwise)
        # value: float (MM_atomic_charge_max > 0)
        MM_atomic_charge_max=2.0,
        # enable preparation of training for transfer learning
        # value: boolean
        transfer_learning=False,
        # enable scaling and shifting of input values
        # value: boolean (recommended: True)
        scale_shift_layer=param[9],
        # number of neurons in each hidden layer
        # value: list of integers (integers > 0)
        n_neurons_hidden_layers=[20, 16, 12],
        # activation function type
        # value: 'sTanh' (recommended), 'Tanh', 'Tanhshrink'
        activation_function_type=param[10],
        # combined radial and cutoff function type in descriptor
        # value: 'bump' (recommended), 'gaussian_bump', 'gaussian_cos'
        descriptor_radial_type=param[11],
        # angular function type in descriptor
        # value: 'bump' (recommended), 'cos', 'cos_int'
        descriptor_angular_type=param[12],
        # scaling function type in descriptor
        # value: 'crss' (recommended), 'sqrt', 'linear'
        descriptor_scaling_type=param[13],
        # 'None': handling of descriptors in memory, 'write': (over)writing and reading of descriptor
        # derivatives and neighbor indices on/from disk, 'append': writing of nonexisting and reading of
        # preexisting descriptor derivatives and neighbor indices on/from disk
        # value: 'None' (recommended), 'write', 'append'
        descriptor_on_disk=param[14],
        # directory of the descriptor binary files
        # value: string
        descriptor_disk_dir=f'{tmp_path}/tmp_descriptor',
        # None: calculate descriptor values and run all checks, string: file name of descriptor cache
        # (either the file is read and its completeness is trusted or it is written if it is not existing)
        # value: None (recommended), string
        descriptor_cache_file=None,
        # None: descriptor output is not written, string: file name of descriptor output
        # value: None (recommended), string
        descriptor_output_file=f'{tmp_path}/descriptor.dat',
        # enable evaluation of maximal memory consumption
        # value: boolean (recommended: True)
        memory_evaluation=param[15],
        # floating point format used by PyTorch
        # value: 'double' (recommended), 'float'
        dtype_torch=param[16],
        # device used by PyTorch
        # value: 'cpu' (recommended), 'cuda'
        device='cpu',
        # energy unit in output
        # value: 'eV' (recommended), 'Hartree', 'kJ_mol'
        energy_unit=param[17],
        # length unit in output
        # value: 'Angstrom' (recommended), 'Bohr', 'pm'
        length_unit=param[18],

        # Additional training settings
        # seed of the NumPy and PyTorch random number generators
        # value: integer
        seed=227,
        # enable restart from old weights
        # value: boolean
        restart=param[19],
        # weight initialization scheme
        # value: 'sTanh' (recommended), 'Tanh', 'Tanhshrink', 'default'
        weight_initialization=param[20],
        # enable fitting of the forces
        # value: boolean (recommended: True)
        fit_forces=param[21],
        # enable energy only preoptimization step in each epoch (i.e. two optimizer steps per epoch)
        # value: boolean (recommended: True)
        energy_preoptimization_step=param[22],
        # fraction of reference structures used for training
        # value: float (0 < training_fraction <= 1)
        training_fraction=0.9,
        # float: fraction of training structures used in each training epoch,
        # integer: number of training structures used in each training epoch
        # value: float (0 < fit_fraction <= 1), integer (fit_fraction >= 1)
        fit_fraction=param[23],
        # optimizer
        # value: 'CoRe' (recommended), 'Adam', 'Rprop', etc.
        optimizer=param[24],
        # number of training epochs
        # value: integer (n_epochs >= 0)
        n_epochs=20,
        # learning rate value in the optimizer
        # value: float (learning_rate > 0)
        learning_rate=0.001,
        # step size parameters in the CoRe and Rprop optimizers
        # value: tuple of floats (float > 0, float > 0)
        step_sizes=(1e-6, 0.01),
        # eta parameters in the CoRe optimizer
        # value: tuple of floats (0 < float <= 1, float >= 1)
        etas=(0.55, 1.2),
        # beta parameters in the CoRe optimizer
        # value: tuple of floats (0 <= float < 1, 0 <= float < 1, float > 0, float < 1)
        betas=(0.7375, 0.8125, 250.0, 0.99),
        # weight decay parameter in the CoRe optimizer
        # value: float (0 <= weight_decay < 1)
        weight_decay=0.1,
        # score history parameter in the CoRe optimizer
        # value: float (score_history >= 0)
        score_history=5,
        # fraction of frozen weights to each hidden layer according to the score history
        # in the CoRe optimizer
        # value: float (0 <= frozen < 1)
        frozen=0.025,
        # enable foreach implementation in the CoRe optimizer
        # value: boolean (recommended: True)
        foreach=param[25],
        # freeze weights of given layer indices (scale_shift_layer: index 0, hidden layers: index
        # counting starts at 1 (even if scale_shift_layer == False), output layer: N_hidden_layers + 1)
        # value: list of integers (recommended: [])
        frozen_layers=param[26],
        # loss function
        # value: 'MSELoss' (recommended), 'HuberLoss', 'SmoothL1Loss', 'L1Loss'
        loss_function=param[27],
        # loss parameter for energy loss and forces loss in 'HuberLoss' and 'SmoothL1Loss'
        # value: tuple of floats (float >= 0, float >= 0)
        loss_parameters=(1.0, 1.0),
        # loss(E) is multiplied by loss_E_scaling and loss(F) is not scaled
        # value: float (loss_E_scaling > 0)
        loss_E_scaling=100.0,
        # selection scheme for the structures to be fitted if fit_fraction < 1
        # value: 'lADS' (recommended), 'random'
        selection_scheme=param[28],
        # quantities to be monitored to determine the probability factor of a structure
        # in selection scheme 'lADS'
        # value: 'E+F_losses' (recommended), 'total_loss'
        selection_measure=param[29],
        # minimum and maximum probability factor of a structure in selection scheme 'lADS'
        # value: tuple of floats (0 < float < 1, float > 1)
        selection_range=(0.1, 100.0),
        # error region boundaries are obtained by multiplying the error in the loss function
        # by the thresholds in selection scheme 'lADS'
        # value: tuple of floats (float > 0, float > 0, float > 0, float > 0)
        selection_thresholds=(0.9, 1.2, 2.0, 6.0),
        # number of subsequent de-/increases yielding min./max probability factor
        # in selection scheme 'lADS'
        # value: tuple of integers (integer > 0, integer > 0)
        selection_strikes=(3, 6),
        # number of subsequent small de-/increases yielding min./max. probability factor
        # in selection scheme 'lADS'
        # value: tuple of integers (integer > 0, integer > 0)
        selection_small_strikes=(3, 6),
        # number of subsequent high errors leading to exclusion of a structure
        # in selection scheme 'lADS'
        # value: integer (exclusion_strikes > 0)
        exclusion_strikes=2,
        # maximal fraction of new redundant training data in each epoch in selection scheme 'lADS'
        # value: float (0 < fraction_redundant_max <= 1)
        fraction_redundant_max=0.015,
        # enable backtracking of gradients for deselected training structures in selection scheme 'lADS'
        # (only available for the optimizer 'CoRe')
        # value: boolean (recommended: True)
        gradient_backtracking=param[30],
        # selection probability increase factor of structures with very small mean atomic force lengths
        # compared to those with maximum mean atomic force lengths (atomic_force_max) in selection scheme 'lADS'
        # value: float (stationary_point_prob_factor >= 1)
        stationary_point_prob_factor=param[31],
        # maximal fraction of good training data in each epoch in selection scheme 'lADS'
        # value: float (0 < fraction_good_max < 1)
        fraction_good_max=2.0 / 3.0,
        # number of changes between zero and the maximum fraction of good data
        # in selection scheme 'lADS'
        # value: integer (n_fraction_intervals > 0)
        n_fraction_intervals=20,
        # enable writing of new, bad, and redundant training data and test data
        # in selection scheme 'lADS'
        # value: boolean (recommended: False)
        write_new_episodic_memory=param[32],
        # enable printing of bad training data names in output file in selection scheme 'lADS'
        # value: boolean (recommended: True)
        print_bad_data_names=param[33],
        # scheme for adding data later in the fit (only available for selection scheme 'lADS')
        # value: 'None' (recommended), 'bottom', 'random', 'top'
        late_data_scheme='None',
        # fraction of data which is added later in the fit
        # value: float (0 <= late_data_fraction < 1)
        late_data_fraction=0.0,
        # epoch in the current run in which the late data is added
        # value: integer (0 <= late_data_epoch <= n_epochs)
        late_data_epoch=0,
        # training epoch interval of RMSE calculations
        # value: integer (RMSE_interval > 0)
        RMSE_interval=2,
        # integer: interval of training epochs in which generalization files are written,
        # (write_weights_interval has to be a multiple of RMSE_interval), None: no writing
        # value: integer (write_weights_interval > 0), None
        write_weights_interval=param[34],
        # enable writing of prediction output file
        # value: boolean
        write_prediction=True,

        # Additional prediction settings
        # enable RMSE calculation of all predicted data
        # value: boolean
        prediction_RMSE=True,

        # Computational settings
        # number of threads used by Numba
        # value: integer (n_threads >= 1)
        n_threads=1,
    )


@pytest.mark.parametrize('lmlp_settings, expected', [
    # Order of the entries in each settings tuple (accessed by index in the lmlp_settings fixture):
    # generalization_file, prediction_only, episodic_memory_file, supplemental_potential_file, descriptor_type,
    # descriptor_parameter_file, generalization_format, element_types, QMMM, scale_shift_layer,
    # activation_function_type, descriptor_radial_type, descriptor_angular_type, descriptor_scaling_type,
    # descriptor_on_disk,
    # memory_evaluation, dtype_torch, energy_unit, length_unit, restart,
    # weight_initialization, fit_forces, energy_preoptimization_step, fit_fraction, optimizer,
    # foreach, frozen_layers, loss_function, selection_scheme, selection_measure,
    # gradient_backtracking, stationary_point_prob_factor, write_new_episodic_memory, print_bad_data_names,
    # write_weights_interval
    # Order of the expected values: training energy RMSE, test energy RMSE, training forces RMSE,
    # test forces RMSE
    (
        ('test_0', False, 'general', 'MieRc-PBED3_def2TZVP', 'eeACSF',
         '_zeta', 'lMLP', ['H', 'C', 'Cl', 'I'], False, False,
         'sTanh', 'bump', 'cos', 'crss', 'write',
         True, 'double', 'eV', 'Angstrom', True,
         'sTanh', True, True, 43, 'CoRe',
         True, [2], 'MSELoss', 'random', 'total_losses',
         True, 1.0, True, True, 8),
        (0.009786, 0.011280, 0.248435, 0.255627)
    ),
    (
        ('test_1', False, 'general', 'MieRc-PBED3_def2TZVP', 'eeACSF',
         '', 'lMLP-only_prediction', ['H', 'C', 'Cl', 'I'], False, False,
         'Tanh', 'gaussian_bump', 'cos', 'sqrt', 'append',
         True, 'double', 'Hartree', 'Bohr', False,
         'Tanh', False, False, 0.25, 'Adam',
         True, [], 'HuberLoss', 'lADS', 'total_loss',
         False, 1.0, False, False, 8),
        (0.004466, 0.005034, 0.023325, 0.021128)
    ),
    (
        ('test_2', False, 'general', 'element_energy-PBED3_def2TZVP', 'ACSF',
         '', 'lMLP', ['H', 'C', 'Cl', 'I'], False, True,
         'Tanhshrink', 'gaussian_cos', 'cos', 'linear', 'None',
         False, 'double', 'kJ_mol', 'pm', False,
         'Tanhshrink', True, False, 0.25, 'Rprop',
         True, [0, 2, 4], 'SmoothL1Loss', 'lADS', 'total_loss',
         True, 1.0, True, True, None),
        (10.348640, 11.864595, 0.901684, 0.728409)
    ),
    (
        ('test_3', False, 'general', 'element_energy-PBED3_def2TZVP', 'ACSF',
         '_radial', 'RuNNer', ['H', 'C', 'Cl', 'I'], False, True,
         'Tanhshrink', 'bump', 'bump', 'linear', 'write',
         False, 'double', 'kJ_mol', 'pm', True,
         'default', False, False, 0.25, 'CoRe',
         False, [], 'L1Loss', 'lADS', 'total_loss',
         True, 10.0, False, False, None),
        (6.706521, 7.045070, 4.808217, 4.831979)
    ),
    (
        ('test_4', False, 'periodic', 'element_energy-PBE0rD3', 'ACSF',
         '', 'lMLP', ['Li', 'O', 'Mn'], False, True,
         'sTanh', 'bump', 'bump', 'linear', 'None',
         True, 'double', 'eV', 'Angstrom', False,
         'sTanh', True, True, 0.25, 'CoRe',
         True, [], 'MSELoss', 'lADS', 'E+F_losses',
         True, 1.0, True, True, None),
        (0.023636, 0.027820, 0.060550, 0.081940)
    ),
    (
        ('test_5', False, 'periodic', 'element_energy-PBE0rD3', 'eeACSF',
         '_radial', 'lMLP', ['Li', 'O', 'Mn'], False, False,
         'sTanh', 'bump', 'bump', 'crss', 'None',
         True, 'float', 'eV', 'Angstrom', True,
         'sTanh', True, False, 0.25, 'Adam',
         False, [], 'HuberLoss', 'random', 'E+F_losses',
         True, 1.0, False, False, None),
        (0.031973, 0.010216, 0.003281, 0.003930)
    ),
    (
        ('test_6', False, 'QMMM', 'element_energy-GFN2', 'eeACSF',
         '_QMMM', 'lMLP', ['H', 'C', 'O'], True, True,
         'sTanh', 'bump', 'cos_int', 'linear', 'append',
         True, 'double', 'eV', 'Angstrom', True,
         'sTanh', True, True, 0.25, 'Adam',
         True, [], 'MSELoss', 'lADS', 'E+F_losses',
         True, 1.0, True, True, None),
        (0.010552, 0.015281, 0.285682, 0.474609)
    ),
    (
        ('test_7', False, 'QMMM', 'element_energy-GFN2', 'eeACSF',
         '_radial_QMMM', 'lMLP', ['H', 'C', 'O'], True, False,
         'sTanh', 'bump', 'cos_int', 'linear', 'None',
         True, 'double', 'eV', 'Angstrom', False,
         'sTanh', True, False, 0.25, 'CoRe',
         True, [], 'HuberLoss', 'random', 'E+F_losses',
         True, 1.0, False, False, None),
        (0.008706, 0.003467, 0.508426, 0.541227)
    )
], indirect=['lmlp_settings'])
def test_lmlp_training(lmlp_settings, expected, capsys):
    '''
    Runs the training of an lMLP and compares the final energy and forces RMSEs of the training and
    test data, which are parsed from the captured standard output, with known good values.
    '''
    lmlp.lMLP(lmlp_settings).run()
    output = StringIO(capsys.readouterr().out)

    # the first 100 output lines (header and settings) are not searched
    for _ in range(100):
        next(output)
    # the RMSEs stay None if no output line starts with 'Final' (the assertions fail in this case)
    rmses = [None, None, None, None]
    for line in output:
        if not line.startswith('Final'):
            continue
        columns = findall(r'\S+\|?', line)
        # columns of the training energy, test energy, training forces, and test forces RMSEs
        rmses = [float(columns[index]) for index in (8, 10, 12, 14)]
        break

    labels = ('Training energy', 'Test energy', 'Training forces', 'Test forces')
    for label, reference, rmse in zip(labels, expected, rmses):
        assert pytest.approx(reference, rel=1e-5, abs=1e-5) == rmse, \
            f'ERROR: {label} RMSE is {rmse} but it should be {reference}.'


@pytest.mark.parametrize('lmlp_settings, expected', [
    # Order of the entries in each settings tuple (accessed by index in the lmlp_settings fixture):
    # generalization_file, prediction_only, episodic_memory_file, supplemental_potential_file, descriptor_type,
    # descriptor_parameter_file, generalization_format, element_types, QMMM, scale_shift_layer,
    # activation_function_type, descriptor_radial_type, descriptor_angular_type, descriptor_scaling_type,
    # descriptor_on_disk,
    # memory_evaluation, dtype_torch, energy_unit, length_unit, restart,
    # weight_initialization, fit_forces, energy_preoptimization_step, fit_fraction, optimizer,
    # foreach, frozen_layers, loss_function, selection_scheme, selection_measure,
    # gradient_backtracking, stationary_point_prob_factor, write_new_episodic_memory, print_bad_data_names,
    # write_weights_interval
    # Order of the expected values: energy RMSE, forces RMSE
    (
        ('test_0', True, 'general', 'MieRc-PBED3_def2TZVP', 'eeACSF',
         '_zeta', 'lMLP', ['H', 'C', 'Cl', 'I'], False, False,
         'sTanh', 'bump', 'cos', 'crss', 'write',
         True, 'double', 'eV', 'Angstrom', True,
         'sTanh', True, True, 43, 'CoRe',
         True, [2], 'MSELoss', 'random', 'total_losses',
         True, 1.0, True, True, 8),
        (0.010891, 0.258368)
    ),
    (
        ('test_1', True, 'general', 'MieRc-PBED3_def2TZVP', 'eeACSF',
         '', 'lMLP-only_prediction', ['H', 'C', 'Cl', 'I'], False, False,
         'Tanh', 'gaussian_bump', 'cos', 'sqrt', 'append',
         True, 'double', 'Hartree', 'Bohr', False,
         'Tanh', False, False, 0.25, 'Adam',
         True, [], 'HuberLoss', 'lADS', 'total_loss',
         False, 1.0, False, False, 8),
        (0.003129, 0.022143)
    ),
    (
        ('test_2', True, 'general', 'element_energy-PBED3_def2TZVP', 'ACSF',
         '', 'lMLP', ['H', 'C', 'Cl', 'I'], False, True,
         'Tanhshrink', 'gaussian_cos', 'cos', 'linear', 'None',
         False, 'double', 'kJ_mol', 'pm', False,
         'Tanhshrink', True, False, 0.25, 'Rprop',
         True, [0, 2, 4], 'SmoothL1Loss', 'lADS', 'total_loss',
         True, 1.0, True, True, None),
        (2.749724, 0.376958)
    ),
    (
        ('test_3', True, 'general', 'element_energy-PBED3_def2TZVP', 'ACSF',
         '_radial', 'RuNNer', ['H', 'C', 'Cl', 'I'], False, True,
         'Tanhshrink', 'bump', 'bump', 'linear', 'write',
         False, 'double', 'kJ_mol', 'pm', True,
         'default', False, False, 0.25, 'CoRe',
         False, [], 'L1Loss', 'lADS', 'total_loss',
         True, 10.0, False, False, None),
        (17.04378, 4.696379)
    ),
    (
        ('test_4', True, 'periodic', 'element_energy-PBE0rD3', 'ACSF',
         '', 'lMLP', ['Li', 'O', 'Mn'], False, True,
         'sTanh', 'bump', 'bump', 'linear', 'None',
         True, 'double', 'eV', 'Angstrom', False,
         'sTanh', True, True, 0.25, 'CoRe',
         True, [], 'MSELoss', 'lADS', 'E+F_losses',
         True, 1.0, True, True, None),
        (0.006198, 0.022534)
    ),
    (
        ('test_5', True, 'periodic', 'element_energy-PBE0rD3', 'eeACSF',
         '_radial', 'lMLP', ['Li', 'O', 'Mn'], False, False,
         'sTanh', 'bump', 'bump', 'crss', 'None',
         True, 'float', 'eV', 'Angstrom', True,
         'sTanh', True, False, 0.25, 'Adam',
         False, [], 'HuberLoss', 'random', 'E+F_losses',
         True, 1.0, False, False, None),
        (0.025608, 0.003509)
    ),
    (
        ('test_6', True, 'QMMM', 'element_energy-GFN2', 'eeACSF',
         '_QMMM', 'lMLP', ['H', 'C', 'O'], True, True,
         'sTanh', 'bump', 'cos_int', 'linear', 'append',
         True, 'double', 'eV', 'Angstrom', True,
         'sTanh', True, True, 0.25, 'Adam',
         True, [], 'MSELoss', 'lADS', 'E+F_losses',
         True, 1.0, True, True, None),
        (0.012145, 0.318088)
    ),
    (
        ('test_7', True, 'QMMM', 'element_energy-GFN2', 'eeACSF',
         '_radial_QMMM', 'lMLP', ['H', 'C', 'O'], True, False,
         'sTanh', 'bump', 'cos_int', 'linear', 'None',
         True, 'double', 'eV', 'Angstrom', False,
         'sTanh', True, False, 0.25, 'CoRe',
         True, [], 'HuberLoss', 'random', 'E+F_losses',
         True, 1.0, False, False, None),
        (0.018337, 0.507737)
    )
], indirect=['lmlp_settings'])
def test_lmlp_prediction_only(lmlp_settings, expected, capsys):
    '''
    Runs an already trained lMLP in prediction only mode and compares the energy and forces RMSEs,
    which are parsed from the captured standard output, with known good values.
    '''
    lmlp.lMLP(lmlp_settings).run()
    output = capsys.readouterr().out

    # the first occurrence of each RMSE label in the output is evaluated
    energy_match = search(r'(?:Energy RMSE:\s+)(\d+\.\d+)', output)
    energy_rmse = float(energy_match.group(1))
    forces_match = search(r'(?:Force RMSE:\s+)(\d+\.\d+)', output)
    forces_rmse = float(forces_match.group(1))

    energy_reference = expected[0]
    assert pytest.approx(energy_reference, rel=1e-5, abs=1e-5) == energy_rmse, \
        f'ERROR: Energy RMSE is {energy_rmse} but it should be {energy_reference}.'
    forces_reference = expected[1]
    assert pytest.approx(forces_reference, rel=1e-5, abs=1e-5) == forces_rmse, \
        f'ERROR: Forces RMSE is {forces_rmse} but it should be {forces_reference}.'
