#!/usr/bin/python3

####################################################################################################

####################################################################################################

'''
Lifelong Machine Learning Potentials (lMLP)
'''
__copyright__ = '''This code is licensed under the 3-clause BSD license.
Copyright ETH Zurich, Department of Chemistry and Applied Biosciences, Reiher Group.
See LICENSE.txt for details.'''

####################################################################################################

####################################################################################################

from typing import Any, List
import torch
from torch import Tensor
from .standardization import Standardization


####################################################################################################

####################################################################################################

@torch.jit.interface
class ModuleInterface(torch.nn.Module):
    '''
    TorchScript module interface declaring the call signature of a single atomic neural network

    HDNNP.forward annotates the atomic neural network it selects from its ModuleList with this
    interface. The selecting index is an integer that is taken from a function argument, i.e., it
    is only known at run time and is not an integer literal.
    '''

####################################################################################################

    def forward(self, input: torch.Tensor) -> torch.Tensor:   # type: ignore[empty-body]
        '''
        Signature declaration only (the body is intentionally empty): one tensor in, one tensor out

        Return: energy_prediction_torch
        '''
        pass   # pylint: disable=unnecessary-pass


####################################################################################################

####################################################################################################

class HDNNP(torch.nn.Module):
    '''
    HDNNP model

    One feed-forward atomic neural network is created per element type. The energy prediction of a
    system is the sum over the outputs of the atomic neural networks evaluated for every atom.
    '''

####################################################################################################

    def __init__(self, n_element_types: int, n_descriptors: int, n_neurons_hidden_layers: List[int],
                 activation_function: Any, scale_shift_layer: bool) -> None:
        '''
        Initialization

        n_element_types: number of atomic neural networks to create (one per element type)
        n_descriptors: number of input features of every atomic neural network
        n_neurons_hidden_layers: number of neurons of each hidden layer; an empty list yields a
            single linear layer mapping the descriptors directly to the output
        activation_function: callable that is called without arguments and returns a new activation
            module (for example, a torch.nn activation class)
        scale_shift_layer: if True, a Standardization layer is placed in front of the first linear
            layer of every atomic neural network
        '''
        # initialize HDNNP parameters
        super().__init__()
        self.n_element_types = n_element_types
        self.n_descriptors = n_descriptors
        self.n_neurons_hidden_layers = n_neurons_hidden_layers
        self.N_hidden_layers = len(n_neurons_hidden_layers)
        self.activation_function = activation_function
        self.scale_shift_layer = scale_shift_layer

        # initialize HDNNP architecture
        # consecutive entries of layer_sizes are the input and output size of one hidden layer
        layer_sizes = [self.n_descriptors] + list(self.n_neurons_hidden_layers)
        self.atomic_neural_networks = torch.nn.ModuleList()
        for _ in range(self.n_element_types):
            layers: List[torch.nn.Module] = []
            if self.scale_shift_layer:
                layers.append(Standardization(self.n_descriptors))
            for n_inputs, n_outputs in zip(layer_sizes[:-1], layer_sizes[1:]):
                layers.append(torch.nn.Linear(n_inputs, n_outputs))
                layers.append(self.activation_function())
            # output layer without activation function yielding a single value per atom
            layers.append(torch.nn.Linear(layer_sizes[-1], 1))
            self.atomic_neural_networks.append(torch.nn.Sequential(*layers))

####################################################################################################

    def forward(self, elements_int_sys: List[int], descriptors_torch: Tensor,
                n_atoms_sys: int) -> Tensor:
        '''
        Sum the outputs of the atomic neural networks of the first n_atoms_sys atoms

        elements_int_sys: for every atom, the index of the atomic neural network to be applied
        descriptors_torch: descriptors indexed by atom along the first dimension
        n_atoms_sys: number of atoms that contribute to the energy prediction

        Return: energy_prediction_torch (tensor with one entry)
        '''
        # calculate energy prediction
        # the atomic energies take dtype and device from the descriptors instead of the global
        # defaults, so that neither precision is lost nor data is moved to another device
        E = torch.empty([n_atoms_sys], dtype=descriptors_torch.dtype,
                        device=descriptors_torch.device)
        for i in range(n_atoms_sys):
            # the interface annotation is used because the ModuleList index is not a literal
            ann: ModuleInterface = self.atomic_neural_networks[elements_int_sys[i]]
            E[i] = ann.forward(descriptors_torch[i])[0]
        energy_prediction_torch = torch.sum(E, 0, keepdim=True)

        return energy_prediction_torch


####################################################################################################

####################################################################################################

@torch.jit.script
def calculate_forces(energy_prediction_torch: Tensor, descriptors_torch: Tensor,
                     descriptor_i_derivatives_torch: List[Tensor],
                     descriptor_neighbor_derivatives_torch: List[Tensor],
                     neighbor_indices: List[List[int]], n_atoms_active: int,
                     create_graph: bool = True) -> Tensor:
    '''
    Combine the gradient of the energy prediction with respect to the descriptors with the supplied
    descriptor derivatives to obtain the forces prediction

    energy_prediction_torch: energy prediction that was calculated from descriptors_torch
    descriptors_torch: descriptors with respect to which the energy prediction is differentiated
    descriptor_i_derivatives_torch: for every atom i, derivatives of its own descriptor; the
        operations below expect the shape (number of descriptors, 3) - to be confirmed by caller
    descriptor_neighbor_derivatives_torch: for every atom i, derivatives of the descriptors of its
        neighbors; the operations below expect the shape (number of neighbors, number of
        descriptors, 3) - to be confirmed by caller
    neighbor_indices: for every atom i, the row indices of its neighbors in descriptors_torch
    n_atoms_active: number of atoms for which forces are calculated
    create_graph: passed on to torch.autograd.grad

    Return: forces_prediction_torch
    '''
    # initialize forces prediction with dtype and device of the descriptors instead of the global
    # defaults, so that the in-place updates below neither lose precision nor mix devices
    forces_prediction_torch = torch.zeros(
        [n_atoms_active, 3], dtype=descriptors_torch.dtype, device=descriptors_torch.device)
    # calculate model gradient
    model_gradient = torch.autograd.grad(
        [energy_prediction_torch], [descriptors_torch], create_graph=create_graph)[0]
    # type narrowing for TorchScript
    assert isinstance(model_gradient, Tensor)
    # combine model gradient and descriptor gradient
    for i in range(n_atoms_active):
        # contribution of the descriptor of atom i itself
        forces_prediction_torch[i] -= torch.sum(
            torch.t(model_gradient[i] * torch.t(descriptor_i_derivatives_torch[i])), 0)
        # contribution of the descriptors of the neighbors of atom i
        neighbor_gradient = model_gradient[neighbor_indices[i]]
        forces_prediction_torch[i] -= torch.sum(
            torch.t(torch.flatten(neighbor_gradient, -2, -1) * torch.t(
                torch.flatten(descriptor_neighbor_derivatives_torch[i], -3, -2))), 0)

    return forces_prediction_torch


####################################################################################################

@torch.jit.script
def calculate_forces_QMMM(energy_prediction_torch: Tensor, descriptors_torch: Tensor,
                          descriptor_i_derivatives_torch: List[Tensor],
                          descriptor_neighbor_derivatives_torch: List[Tensor],
                          neighbor_indices: List[List[int]], n_atoms_sys: int,
                          descriptor_neighbor_derivatives_torch_env: List[Tensor],
                          neighbor_indices_env: List[List[int]], n_atoms_active: int,
                          MM_gradients: List[int], create_graph: bool = True) -> Tensor:
    '''
    Combine the gradient of the energy prediction with respect to the descriptors with the supplied
    descriptor derivatives to obtain the forces prediction, where the atoms following the first
    n_atoms_sys atoms only receive contributions from the descriptors of their neighbors

    energy_prediction_torch: energy prediction that was calculated from descriptors_torch
    descriptors_torch: descriptors with respect to which the energy prediction is differentiated
    descriptor_i_derivatives_torch: for each of the first n_atoms_sys atoms, derivatives of its own
        descriptor; the operations below expect the shape (number of descriptors, 3) - to be
        confirmed by caller
    descriptor_neighbor_derivatives_torch: for each of the first n_atoms_sys atoms, derivatives of
        the descriptors of its neighbors; the operations below expect the shape (number of
        neighbors, number of descriptors, 3) - to be confirmed by caller
    neighbor_indices: for each of the first n_atoms_sys atoms, the row indices of its neighbors in
        descriptors_torch
    n_atoms_sys: number of atoms that are handled with own and neighbor descriptor contributions
    descriptor_neighbor_derivatives_torch_env: for each of the remaining atoms, derivatives of the
        descriptors of its neighbors restricted to the descriptor entries given by MM_gradients
    neighbor_indices_env: for each of the remaining atoms, the row indices of its neighbors in
        descriptors_torch
    n_atoms_active: total number of atoms for which forces are calculated
    MM_gradients: column indices selecting the entries of the model gradient that are used for the
        remaining atoms
    create_graph: passed on to torch.autograd.grad

    Return: forces_prediction_torch
    '''
    # initialize forces prediction with dtype and device of the descriptors instead of the global
    # defaults, so that the in-place updates below neither lose precision nor mix devices
    forces_prediction_torch = torch.zeros(
        [n_atoms_active, 3], dtype=descriptors_torch.dtype, device=descriptors_torch.device)
    # calculate model gradient
    model_gradient = torch.autograd.grad(
        [energy_prediction_torch], [descriptors_torch], create_graph=create_graph)[0]
    # type narrowing for TorchScript
    assert isinstance(model_gradient, Tensor)
    # combine model gradient and descriptor gradient
    for i in range(n_atoms_sys):
        # contribution of the descriptor of atom i itself
        forces_prediction_torch[i] -= torch.sum(
            torch.t(model_gradient[i] * torch.t(descriptor_i_derivatives_torch[i])), 0)
        # contribution of the descriptors of the neighbors of atom i
        neighbor_gradient = model_gradient[neighbor_indices[i]]
        forces_prediction_torch[i] -= torch.sum(
            torch.t(torch.flatten(neighbor_gradient, -2, -1) * torch.t(
                torch.flatten(descriptor_neighbor_derivatives_torch[i], -3, -2))), 0)
    # the remaining atoms are stored after the first n_atoms_sys rows of the forces prediction
    for i in range(n_atoms_active - n_atoms_sys):
        # restrict the model gradient of the neighbors to the selected descriptor entries
        env_gradient = model_gradient[neighbor_indices_env[i]][:, MM_gradients]
        forces_prediction_torch[i + n_atoms_sys] -= torch.sum(
            torch.t(torch.flatten(env_gradient, -2, -1) * torch.t(
                torch.flatten(descriptor_neighbor_derivatives_torch_env[i], -3, -2))), 0)

    return forces_prediction_torch
