from typing import Callable

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.linen.initializers import Initializer

from egxc.utils.typing import FloatN, FloatNxF, FloatNxT
from egxc.xc_energy.features import (
    ueg_e_x,
    ueg_spin_pol_e_kin_factor,
    ueg_spin_pol_e_x_factor,
    weizsacker_kinetic_energy_density,
)
from egxc.xc_energy.functionals.base import BaseEnergyFunctional

from .nn.mlp import FeatureMLP


def scaled_shifted_softplus(
    x: jax.Array, scale: float = 1.0, min_value: float = -1.0, x_shift: float = 0.5413
) -> jax.Array:
    """
    Softplus activation, rescaled and shifted to mimic the value range of elu.

    Computes ``softplus(scale * (x + x_shift)) / scale + min_value``.

    - For very negative ``x`` the output approaches ``min_value``. This is the
      lower bound of elu, which was used by Nagai et al. 2020.
    - For large ``x`` it grows like ``x + x_shift + min_value``.
    - Unlike elu it is smooth everywhere.

    With the defaults (``scale=1``, ``min_value=-1``), ``x_shift`` is
    ``log(e - 1) ~= 0.5413``, so the output at ``x = 0`` is (up to the rounding
    of that constant) zero. For other ``scale``/``min_value`` values the default
    shift does not make the function zero preserving.

    NOTE: when using this activation change variance scaling factor to 2.2 (with default init)
    """
    shifted = (x + x_shift) * scale
    return min_value + nn.softplus(shifted) / scale


class Nagai2020(BaseEnergyFunctional):
    """
    mGGA level machine learnable exchange-correlation functional.
    Ryo Nagai, Ryosuke Akashi, and Osamu Sugino.
    “Completing Density Functional Theory by Machine Learning Hidden Messages from Molecules.”
    Npj Computational Materials 6, no. 1 (May 5, 2020): 1-8.
    https://doi.org/10.1038/s41524-020-0310-0.

    https://github.com/ml-electron-project/NNfunctional

    The energy density is modelled as ``e_x_unif(n) * xi(zeta) * F(features)``, where
    ``F`` is a learned enhancement factor computed by a FeatureMLP from the log of the
    local features ``(n^(1/3), xi(zeta), s, tau_t)``. Optional non-local grid features
    can be added to the first hidden layer.

    We made some slight modifications to the original implementation to improve
    gradient-based training:
    - we replace the elu activation function with a scaled shifted softplus activation
      function. It has the same value range as elu (used by Nagai et al. 2020), is still
      zero preserving (x shift = log(e-1) = 0.5413), and is infinitely differentiable.
    - we add the correct UEG exchange energy density prefactor (3 / 4) * (3 / pi) ** (1 / 3).
    - we introduce a scale factor on the MLP output.
    - we introduce an offset that can be unlearned by the bias term of the last layer.
      LDA is known to systematically underestimate the exchange correlation energy on
      stable organic molecules, so shifting the output at initialization lets training
      start from a better point in parameter space.
    - at grid points where the kinetic-energy normalisation vanishes (e.g. zero density),
      the normalised kinetic energy feature is set to 0 instead of evaluating 0 / 0 = NaN.
    """

    # Default used in their publication:
    n_layers: int = 4  # including the final output layer
    hidden_dim: int = 100
    activation: Callable[[jax.Array], jax.Array] = (
        scaled_shifted_softplus  # NOTE: original paper uses elu but scaled shifted softplus above might be better due to higher smoothness
    )
    kernel_init: Initializer = nn.initializers.variance_scaling(
        2.2,  # 1.5 for elu activation, 2.2 for scaled shifted softplus with default parameters
        'fan_in',
        'uniform',
    )
    epsilon: float = 1e-8  # added to the stacked features before taking the log
    is_graph_based = False  # plain class attribute (no annotation, so not a module field)
    initial_enhancement_offset: float = 0.15  # original paper uses 0.0. Added before the final activation so as not to shift the value range.
    scale_factor: float = 0.1  # original paper uses 1.0
    orbital_free: bool = False  # if True, the input tau is replaced by the Weizsacker kinetic energy density

    def setup(self) -> None:
        # The last layer is not zero-initialised, so the offset/scale above shape the
        # initial enhancement factor.
        self.net = FeatureMLP(
            self.n_layers,
            self.hidden_dim,
            self.activation,
            init_last_layer_to_zero=False,
            concatenate=True,
            kernel_init=self.kernel_init,
        )

    def xc_energy_density(  # type: ignore
        self, n: FloatN, zeta: FloatN, s: FloatN, tau: FloatN
    ) -> FloatN:
        """
        Local exchange-correlation energy density per grid point.

        Returns ``e_x_unif(n) * xi(zeta) * F(x_local)``, where ``F`` is the learned
        enhancement factor evaluated on the local features only.
        """
        # https://github.com/ml-electron-project/NNfunctional/blob/master/metaGGA.py
        # NOTE: original paper omits the prefactor (3 / 4) * (3 / pi) ** (1 / 3) of the UEG exchange energy density
        e_x_unif = ueg_e_x(n)
        xi_t = ueg_spin_pol_e_x_factor(zeta)
        x_local = self.local_input_transform(n, zeta, s, tau)
        return e_x_unif * xi_t * self.learned_enhancement_factor(x_local)

    def non_local_xc_energy_density(
        self,
        non_local_grid_features: FloatNxF,
        n: FloatN,
        zeta: FloatN,
        s: FloatN,
        tau: FloatN,
    ) -> FloatN:
        """
        Exchange-correlation energy density including non-local grid features.

        The projected non-local features are added to the first-layer pre-activation
        of the local features. Because that projection is zero-initialised, this
        initially matches ``xc_energy_density``.
        """
        e_x_unif = ueg_e_x(n)
        xi_t = ueg_spin_pol_e_x_factor(zeta)
        x_local = self.local_input_transform(n, zeta, s, tau)
        x_non_local = self.non_local_input_transform(non_local_grid_features)
        return e_x_unif * xi_t * self.learned_enhancement_factor(x_local + x_non_local)

    def local_input_transform(
        self, n: FloatN, zeta: FloatN, s: FloatN, tau: FloatN
    ) -> FloatNxT:
        """
        Map the local density features to the first-layer pre-activation of the MLP.

        The features ``(n^(1/3), xi(zeta), s, tau_t)`` are stacked along the last axis,
        log-transformed (``epsilon`` is added first) and fed to the MLP input layer.
        Here ``tau_t = tau / (2 * d(zeta) * n^(5/3))``. Where that denominator is not
        positive (e.g. zero density), ``tau_t`` is set to 0 rather than 0 / 0.

        If ``orbital_free`` is set, the given ``tau`` is ignored and replaced by the
        Weizsacker kinetic energy density computed from ``n`` and ``s``.
        """
        n_t = n ** (1 / 3)
        s_t = s
        xi_t = ueg_spin_pol_e_x_factor(zeta)
        d_zeta = ueg_spin_pol_e_kin_factor(zeta)
        if self.orbital_free:
            tau = weizsacker_kinetic_energy_density(n, s)
        denom = n_t**5 * 2 * d_zeta
        # `epsilon` only protects the log below. A 0 / 0 here would put a NaN into the
        # features, which propagates into the energy density (any product with NaN is
        # NaN). The inner `where` keeps the gradient finite too; for denom > 0 the value
        # and gradient are unchanged.
        valid = denom > 0
        tau_t = jnp.where(valid, tau / jnp.where(valid, denom, 1.0), 0.0)
        x = jnp.stack([n_t, xi_t, s_t, tau_t], axis=-1)
        x = jnp.log(x + self.epsilon)
        return self.net._local_input(x)

    @nn.compact
    def non_local_input_transform(self, non_local_grid_features: FloatNxF) -> FloatNxT:
        """
        Project non-local grid features to ``hidden_dim`` (no bias).

        The kernel is zero-initialised, so the non-local contribution starts at zero.
        """
        x_non_local = nn.Dense(
            self.hidden_dim,
            kernel_init=nn.initializers.zeros,
            use_bias=False,
            name='non_local_input_transform',
        )(nn.silu(non_local_grid_features))
        return x_non_local

    def learned_enhancement_factor(self, x: FloatNxT) -> FloatN:
        """
        Evaluate the learned enhancement factor from first-layer pre-activations.

        Computes ``1 + activation(scale_factor * mlp(activation(x)) + offset)``. With the
        default activation (bounded below by -1), the factor stays positive.
        """
        # Apply activation and remaining MLP layers
        x = self.activation(x)
        x = self.net.mlp(x)[:, 0]  # first output column; MLP output assumed 2-D
        x *= self.scale_factor
        return (
            1 + self.activation(x + self.initial_enhancement_offset)
        )  # See e.g. https://github.com/ml-electron-project/NNfunctional/blob/master/LSDA.py
