#!/usr/bin/python3

'''
Lifelong Machine Learning Potentials (lMLP)
'''
__copyright__ = '''This code is licensed under the 3-clause BSD license.
Copyright ETH Zurich, Department of Chemistry and Applied Biosciences, Reiher Group.
See LICENSE.txt for details.'''

from os import remove
from pathlib import Path
import unittest
import numpy as np
from torch import load
from lmlp import lMLP_calculator

# Absolute path of the directory holding the .ini settings and reference .pt data of these tests.
RESOURCES_DIR = f'{Path(__file__).parent.absolute()}/resources/lmlp_calculator'


class lMLP_calculatorTest(unittest.TestCase):
    '''
    Lifelong Machine Learning Potentials (lMLP) calculator tests

    Every test_<i> builds a calculator from '<RESOURCES_DIR>/test_<i>.ini' and compares its
    predictions with the reference values stored in '<RESOURCES_DIR>/test_<i>_data.pt'.
    Comparisons go through self.assertTrue rather than bare assert statements so that they are
    still executed when Python runs with -O.
    '''

    @staticmethod
    def _load_checkpoint(test_name: str):
        '''
        Load the reference data of a test.

        Parameters
        ----------
        test_name : str
            Test identifier, e.g. 'test_1'.

        Returns
        -------
        Reference data read from '<RESOURCES_DIR>/<test_name>_data.pt', indexed by quantity
        name (e.g. 'elements', 'positions', 'energy_prediction').
        '''
        # torch.load unpickles the file; this is acceptable only because the file is a trusted
        # resource shipped with the tests.
        return load(f'{RESOURCES_DIR}/{test_name}_data.pt')

    @staticmethod
    def _create_calculator(test_name: str, **kwargs):
        '''
        Create a calculator from the generalization setting file of a test.

        Parameters
        ----------
        test_name : str
            Test identifier, e.g. 'test_1'.
        **kwargs
            Further keyword arguments forwarded to lMLP_calculator.

        Returns
        -------
        lMLP_calculator
            Calculator configured by '<RESOURCES_DIR>/<test_name>.ini'.
        '''
        return lMLP_calculator(
            generalization_setting_file=f'{RESOURCES_DIR}/{test_name}.ini', **kwargs)

    def _assert_value_close(self, value, reference, label: str, atol: float = 1e-8) -> None:
        '''
        Assert that a scalar prediction matches its reference (np.isclose, rtol=1e-5).
        '''
        self.assertTrue(np.isclose(value, reference, rtol=1e-5, atol=atol),
                        f'ERROR: {label} is {value} but it should be {reference}.')

    def _assert_values_close(self, values, references, label: str,
                             atol: float = 1e-8) -> None:
        '''
        Assert that every element of a prediction matches its reference (np.allclose, rtol=1e-5).
        '''
        self.assertTrue(np.allclose(values, references, rtol=1e-5, atol=atol),
                        f'ERROR: {label} are {values} but they should be {references}.')

    def _assert_energy_and_forces(self, energy, forces, checkpoint,
                                  atol: float = 1e-8) -> None:
        '''
        Compare energy and forces with 'energy_prediction' and 'forces_prediction'.
        '''
        self._assert_value_close(energy, checkpoint['energy_prediction'], 'Energy', atol)
        self._assert_values_close(forces, checkpoint['forces_prediction'], 'Forces', atol)

    def _assert_uncertainties(self, energy_uncertainty, forces_uncertainty, checkpoint,
                              suffix: str) -> None:
        '''
        Compare uncertainties with 'energy_uncertainty_<suffix>' and 'forces_uncertainty_<suffix>'.
        '''
        self._assert_value_close(
            energy_uncertainty, checkpoint[f'energy_uncertainty_{suffix}'], 'Energy uncertainty')
        self._assert_values_close(
            forces_uncertainty, checkpoint[f'forces_uncertainty_{suffix}'], 'Force uncertainties')

    @staticmethod
    def _remove_if_exists(path: str) -> None:
        '''
        Delete a file; a missing file is not an error.
        '''
        try:
            remove(path)
        except FileNotFoundError:
            pass

    def _check_active_learning(self, test_name: str, checkpoint, **predict_kwargs) -> None:
        '''
        Run the active-learning predictions of a test and compare the final uncertainties.

        Two calculators with active_learning_thresholds (0.0, 3.0, 3.0) and (3.0, 0.0, 3.0)
        predict the structure named '1' and write to the scratch file
        '<RESOURCES_DIR>/input.data_active_learning_tmp'. The file must exist afterwards; it is
        deleted when the test finishes, whether it passes or fails.

        Parameters
        ----------
        test_name : str
            Test identifier, e.g. 'test_6'.
        checkpoint
            Reference data providing 'elements', 'positions', 'energy_uncertainty_2' and
            'forces_uncertainty_2'.
        **predict_kwargs
            Extra keyword arguments for lMLP_calculator.predict (e.g. atomic classes/charges).
        '''
        active_learning_file = f'{RESOURCES_DIR}/input.data_active_learning_tmp'
        # Start clean in case an aborted earlier run left the file behind, and register the
        # removal before predicting so that a failing prediction or assertion cannot leak it.
        self._remove_if_exists(active_learning_file)
        self.addCleanup(self._remove_if_exists, active_learning_file)
        for thresholds in ((0.0, 3.0, 3.0), (3.0, 0.0, 3.0)):
            lmlp = self._create_calculator(
                test_name, active_learning_file=active_learning_file,
                active_learning_thresholds=thresholds)
            _, _, energy_uncertainty, forces_uncertainty = lmlp.predict(
                checkpoint['elements'], checkpoint['positions'], name='1', **predict_kwargs)
        self.assertTrue(Path(active_learning_file).is_file(),
                        f'ERROR: Active learning file {active_learning_file} was not written.')
        # Only the uncertainties of the last calculator are compared with the reference.
        self._assert_uncertainties(energy_uncertainty, forces_uncertainty, checkpoint, '2')

    def test_1(self) -> None:
        '''
        Test 1: energy-only prediction (calc_forces=False).
        '''
        lmlp = self._create_calculator('test_1')
        checkpoint = self._load_checkpoint('test_1')
        energy = lmlp.predict(checkpoint['elements'], checkpoint['positions'], calc_forces=False)
        self._assert_value_close(energy, checkpoint['energy_prediction'], 'Energy')

    def test_2(self) -> None:
        '''
        Test 2: energy and forces (compared with atol=1e-5).
        '''
        lmlp = self._create_calculator('test_2')
        checkpoint = self._load_checkpoint('test_2')
        energy, forces = lmlp.predict(checkpoint['elements'], checkpoint['positions'])
        self._assert_energy_and_forces(energy, forces, checkpoint, atol=1e-5)

    def test_3(self) -> None:
        '''
        Test 3: energy and forces with lattices passed as third positional argument.
        '''
        lmlp = self._create_calculator('test_3')
        checkpoint = self._load_checkpoint('test_3')
        energy, forces = lmlp.predict(
            checkpoint['elements'], checkpoint['positions'], checkpoint['lattices'])
        self._assert_energy_and_forces(energy, forces, checkpoint)

    def test_4(self) -> None:
        '''
        Test 4: energy and forces with atomic classes and atomic charges.
        '''
        lmlp = self._create_calculator('test_4')
        checkpoint = self._load_checkpoint('test_4')
        energy, forces = lmlp.predict(
            checkpoint['elements'], checkpoint['positions'],
            atomic_classes=checkpoint['atomic_classes'],
            atomic_charges=checkpoint['atomic_charges'])
        self._assert_energy_and_forces(energy, forces, checkpoint)

    def test_5(self) -> None:
        '''
        Test 5: energy and forces with atomic classes and atomic charges.
        '''
        lmlp = self._create_calculator('test_5')
        checkpoint = self._load_checkpoint('test_5')
        energy, forces = lmlp.predict(
            checkpoint['elements'], checkpoint['positions'],
            atomic_classes=checkpoint['atomic_classes'],
            atomic_charges=checkpoint['atomic_charges'])
        self._assert_energy_and_forces(energy, forces, checkpoint)

    def test_6(self) -> None:
        '''
        Test 6: uncertainties (uncertainty_scaling=20.0) and active learning.
        '''
        checkpoint = self._load_checkpoint('test_6')
        elements, positions = checkpoint['elements'], checkpoint['positions']
        lmlp = self._create_calculator('test_6', uncertainty_scaling=20.0)
        energy, _ = lmlp.predict(elements, positions, calc_forces=False, calc_uncertainty=True)
        # Also check the energy-only path; like test_1, its energy is expected to match the
        # reference independently of calc_forces.
        self._assert_value_close(energy, checkpoint['energy_prediction'], 'Energy')
        energy, forces, energy_uncertainty, forces_uncertainty = lmlp.predict(
            elements, positions, calc_uncertainty=True)
        self._assert_energy_and_forces(energy, forces, checkpoint)
        self._assert_uncertainties(energy_uncertainty, forces_uncertainty, checkpoint, '1')
        self._check_active_learning('test_6', checkpoint)

    def test_7(self) -> None:
        '''
        Test 7: uncertainties (uncertainty_scaling=1000.0) and active learning with atomic
        classes and atomic charges.
        '''
        checkpoint = self._load_checkpoint('test_7')
        elements, positions = checkpoint['elements'], checkpoint['positions']
        charge_kwargs = {'atomic_classes': checkpoint['atomic_classes'],
                         'atomic_charges': checkpoint['atomic_charges']}
        lmlp = self._create_calculator('test_7', uncertainty_scaling=1000.0)
        energy, _ = lmlp.predict(
            elements, positions, calc_forces=False, calc_uncertainty=True, **charge_kwargs)
        # Also check the energy-only path; like test_1, its energy is expected to match the
        # reference independently of calc_forces.
        self._assert_value_close(energy, checkpoint['energy_prediction'], 'Energy')
        energy, forces, energy_uncertainty, forces_uncertainty = lmlp.predict(
            elements, positions, calc_uncertainty=True, **charge_kwargs)
        self._assert_energy_and_forces(energy, forces, checkpoint)
        self._assert_uncertainties(energy_uncertainty, forces_uncertainty, checkpoint, '1')
        self._check_active_learning('test_7', checkpoint, **charge_kwargs)
