from typing import Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp

from egxc.utils.typing import FloatN
from egxc.xc_energy.features import (
    transform_tau_to_alpha,
    ueg_e_x,
    ueg_spin_pol_e_x_factor,
    weizsacker_kinetic_energy_density,
)
from egxc.xc_energy.functionals.base import BaseEnergyFunctional
from egxc.xc_energy.functionals.classical.lsda import pw92_correlation_energy_density

from .nn.mlp import FeatureMLP


def I_transform(x: jax.Array, a: float) -> jax.Array:
    """Squash ``x`` smoothly into the open interval ``(-1, a - 1)``.

    Mathematically this is ``a / (1 + (a - 1) * exp(x)) - 1``: it is 0 at
    ``x = 0``, tends to ``a - 1`` as ``x -> -inf`` and to ``-1`` as
    ``x -> +inf``. Intended for ``a > 1``, where both denominators below are
    strictly positive.

    The expression is evaluated using only ``exp(-|x|)``. Computing ``exp(x)``
    directly overflows to ``inf`` for large positive ``x``; the forward value
    is still correct then, but reverse-mode differentiation multiplies a zero
    cotangent by ``inf`` and yields NaN gradients.

    Args:
        x: Unbounded input (scalar or array).
        a: Upper limit of ``1 + I_transform(x, a)``.

    Returns:
        Array of the same shape as ``x`` with values in ``(-1, a - 1)``.
    """
    positive = x > 0
    # exp(-|x|) <= 1 never overflows. jnp.where is used instead of jnp.abs so
    # that the derivative at exactly x == 0 is the one of exp(x).
    decay = jnp.exp(jnp.where(positive, -x, x))
    # x > 0: numerator and denominator of the original form scaled by exp(-x).
    upper = a * decay / (decay + (a - 1))
    # x <= 0: the original form, where decay == exp(x).
    lower = a / (1 + (a - 1) * decay)
    return jnp.where(positive, upper, lower) - 1


def _input_transform(
    n: FloatN, zeta: FloatN, s: FloatN, tau: FloatN
) -> Tuple[FloatN, FloatN, FloatN, FloatN]:
    """Map the raw inputs onto the features fed to the networks.

    Returns, in this order:
        * ``log(n ** (1 / 3) + 1e-5)``
        * ``log(ueg_spin_pol_e_x_factor(zeta) + 1e-5)``
        * ``(1 - exp(-s ** 2)) * log(1 + s)``
        * ``log(1 + t) - log(2)`` with ``t = transform_tau_to_alpha(n, zeta, s, tau)``
          (presumably the iso-orbital indicator alpha -- confirm against the helper)
    """
    # Small offset keeping the logarithms finite when their argument is zero.
    log_offset = 1e-5

    spin_factor = ueg_spin_pol_e_x_factor(zeta)
    alpha = transform_tau_to_alpha(n, zeta, s, tau)

    log_cbrt_n = jnp.log(n ** (1 / 3) + log_offset)
    log_spin = jnp.log(spin_factor + log_offset)
    # -expm1(-s^2) equals 1 - exp(-s^2), so this feature is 0 at s == 0.
    s_feature = -jnp.expm1(-(s**2)) * jnp.log1p(s)
    # Shifted by log(2) so that the feature is 0 when alpha == 1.
    alpha_feature = jnp.log1p(alpha) - jnp.log(2)
    return log_cbrt_n, log_spin, s_feature, alpha_feature


class XCDiff(BaseEnergyFunctional):
    """
    xc-diff functional introduced by
    Sebastian Dick and Marivi Fernandez-Serra.
    “Highly Accurate and Constrained Density Functional
    Obtained with Differentiable Programming.”
    Physical Review B 104, no. 16 (October 12, 2021): L161109.
    https://doi.org/10.1103/PhysRevB.104.L161109.

    Two FeatureMLP networks produce enhancement factors that multiply
    ``ueg_e_x(n)`` (exchange) and ``pw92_correlation_energy_density(n, zeta)``
    (correlation); the sum of both products is the returned energy density.
    """

    # Default used in their publication:
    n_layers: int = 4  # including the final output layer
    hidden_dim: int = 16
    is_graph_based = False
    # When True, the tau argument is ignored and replaced by
    # weizsacker_kinetic_energy_density(n, s).
    orbital_free: bool = False

    def setup(self) -> None:
        """Create the exchange and correlation networks with identical settings."""
        kernel_init = nn.initializers.variance_scaling(
            2.2,  # fitting GELU activation
            'fan_in',
            'uniform',
        )
        # Both networks share the same positional configuration and initializer.
        mlp_args = (self.n_layers, self.hidden_dim, nn.gelu, False)
        self.x_net = FeatureMLP(*mlp_args, kernel_init=kernel_init)
        self.c_net = FeatureMLP(*mlp_args, kernel_init=kernel_init)

    def xc_energy_density(  # type: ignore
        self, n: FloatN, zeta: FloatN, s: FloatN, tau: FloatN
    ) -> FloatN:
        """Return the sum of the exchange and correlation energy densities.

        The exchange network sees only the transformed ``s`` and ``tau``
        features; the correlation network sees all four transformed features.
        Column 0 of each network output is scaled by a factor that is zero when
        both the ``s`` and ``tau`` features are zero, so that both enhancement
        factors reduce to exactly 1 there.
        """
        # Replace orbital-dependent kinetic energy density by an orbital-free approximation
        kinetic = weizsacker_kinetic_energy_density(n, s) if self.orbital_free else tau

        log_n, log_spin, s_feat, alpha_feat = _input_transform(n, zeta, s, kinetic)
        ueg_switch = s_feat + jnp.tanh(alpha_feat) ** 2

        # NOTE(review): the SCAN exchange bound is usually quoted as 1.174; confirm
        # that 1.147 is intentional. Changing it alters results of trained parameters.
        exchange_raw = self.x_net(s_feat, alpha_feat)[:, 0]
        exchange_factor = 1 + I_transform(exchange_raw * ueg_switch, 1.147)
        exchange = ueg_e_x(n) * exchange_factor

        correlation_raw = self.c_net(log_n, log_spin, s_feat, alpha_feat)[:, 0]
        correlation_factor = 1 + I_transform(correlation_raw * ueg_switch, 2)
        correlation = pw92_correlation_energy_density(n, zeta) * correlation_factor

        return exchange + correlation
