from functools import partial
from typing import Callable

import flax.linen as nn
import jax
import jax.numpy as jnp

from egxc.utils.typing import FloatAxN, FloatAxNx4, FloatN, FloatNx4
from egxc.xc_energy.features import (
    transform_tau_to_alpha,
    ueg_spin_pol_e_x_factor,
    ueg_tau,
)
from egxc.xc_energy.functionals.base import BaseEnergyFunctional
from egxc.xc_energy.functionals.classical.mgga import e_c_scan, e_x_scan

from .nn.mlp import DoubleMLPWithCrossConnections


@jax.jit
def _nagai22_tau_normalisation(n: FloatN, tau: FloatN) -> FloatN:
    """
    Normalise the kinetic energy density as done in Nagai et al. 2022 (see Equations 11, 12).

    The returned quantity is (tau - tau_unif) / tau_unif with tau_unif = ueg_tau(n), i.e. the
    deviation from the uniform electron gas limit relative to that limit. Note that the usual
    normalisation in the literature subtracts the Weizsäcker kinetic energy density instead.
    """
    uniform_gas_tau = ueg_tau(n)
    excess_tau = tau - uniform_gas_tau
    return excess_tau / uniform_gas_tau


@jax.jit
def _nagai22_softplus(x: jax.Array) -> jax.Array:
    """
    Rescaled softplus activation from Nagai et al. 2022 (see Appendix A).

    The input is scaled by 2 ln(2) and the output by 1 / ln(2), so that both the
    function and its derivative take the value 1 at the origin.
    """
    ln_two = jnp.log(2)
    scaled_input = 2 * ln_two * x
    return nn.softplus(scaled_input) / ln_two


@jax.jit
def _nagai22_scalarisation_function(x: jax.Array) -> jax.Array:
    """
    Scalarisation function l(x) = tanh(‖x‖²) from Nagai et al. 2022 (see Equation 2).

    The squared norm is taken over the final dimension of the input, so the result
    has the shape of the input with its last axis removed.
    """
    squared_norm = jnp.sum(jnp.square(x), axis=-1)
    return jnp.tanh(squared_norm)


@jax.jit
def _nagai22_input_transform(
    n: FloatN, zeta: FloatN, s: FloatN, tau: FloatN
) -> tuple[FloatN, FloatN, FloatN, FloatN]:
    """
    Map the mGGA descriptors to the bounded inputs of Nagai et al. 2022 (see Equation 10).

    Every descriptor is passed through tanh, which turns limits at ±inf into limits at ±1.
    The result is the tuple (n_t, zeta_t, s_t, tau_t).

    WARNING: their descriptor definitions differ from the default in the literature as follows:
    - their reduced density gradient lacks the constant factor 2(3π²)^(1/3); the incoming s is
      therefore multiplied by that constant before the tanh
    - the normalised kinetic energy density is (tau - tau_unif) / tau_unif rather than
      (tau - tau_w) / tau_unif, i.e., the Weizsäcker kinetic energy density tau_w is not used
    - the UEG kinetic energy density tau_unif is missing the spin scaling factor d(zeta)
    """
    # constant relating the conventional reduced gradient to the one used by Nagai et al.
    s_prefactor = 2 * (3 * jnp.pi**2) ** (1 / 3)

    return (
        jnp.tanh(n ** (1 / 3)),
        jnp.tanh(ueg_spin_pol_e_x_factor(zeta)),
        jnp.tanh(s_prefactor * s),
        jnp.tanh(_nagai22_tau_normalisation(n, tau)),
    )


@jax.jit
@partial(jax.vmap, in_axes=(0, None, None))
def _nagai22_basis_polynomials(i, diff_j, diff_ij):
    """
    Basis polynomials of the Lagrange interpolation in Nagai et al. 2022 (see Equation 2).

    The function is vectorised over `i` only; `diff_j` and `diff_ij` are shared by all i.

    i: the index of the basis polynomial to compute (0 <= i < #constraints)
    diff_j: the scalarisation function l(x - x_j) for each constraint j
    diff_ij: the scalarisation function l(x_i - x_j) for each pair of constraints (i, j)

    Returns the ratio prod_{j != i} l(x - x_j) / prod_{j != i} l(x_i - x_j), where any factor
    smaller than PRODUCT_EPSILON is replaced by 1 before the product is taken.
    """
    PRODUCT_EPSILON = 1e-7

    def without_entry_i(per_constraint):
        # rotate entry i to the front and discard it, leaving the other #constraints - 1 rows
        return jnp.roll(per_constraint, -i, axis=0)[1:]

    def guarded_product(factors):
        # factors below the threshold are swapped for 1 so they do not enter the product
        return jnp.prod(jnp.where(factors < PRODUCT_EPSILON, 1.0, factors), axis=0)

    top = guarded_product(without_entry_i(diff_j))
    bottom = guarded_product(without_entry_i(diff_ij[i]))
    return top / bottom  # shape: (#grid_points)


@jax.jit
def _nagai22_lagrange_polynomial(x, fx, x0, fx0, f0):
    """
    Lagrange interpolation polynomial of Nagai et al. 2022 (see Equation 2).
    The argument names follow the notation of Equation 1.

    The shifted values fx - fx0 + f0 of all constraints are combined with the basis
    polynomials as weights, and the result is divided by the sum of those weights.
    """
    # l(x_i - x_j) for every pair of constraints
    # shape: (#constraints, #constraints, #grid_points)
    pairwise_l = _nagai22_scalarisation_function(x0[None, ...] - x0[:, None, ...])
    # l(x - x_j) for every constraint
    # shape: (#constraints, #grid_points)
    point_l = _nagai22_scalarisation_function(x - x0)

    constraint_ids = jnp.arange(x0.shape[0])
    weights = _nagai22_basis_polynomials(constraint_ids, point_l, pairwise_l)

    shifted_values = fx - fx0 + f0
    inverse_weight_sum = 1 / jnp.sum(weights, axis=0)
    return inverse_weight_sum * jnp.vecdot(weights, shifted_values, axis=0)


class Nagai2022(BaseEnergyFunctional):
    """
    Ryo Nagai, Ryosuke Akashi, and Osamu Sugino.
    “Machine-learning-based exchange correlation functional with physical asymptotic constraints.”
    Physical Review Research 4 (11 February 2022): 013106.
    https://doi.org/10.1103/PhysRevResearch.4.013106.

    The exchange and correlation outputs of a neural network are passed through a Lagrange
    interpolation that enforces the constraints built in `make_constraints_x` and
    `make_constraints_c`, and the results multiply the SCAN exchange and correlation energy
    densities.
    """

    # Default used in their publication:
    n_layers: int = 4  # including the final output layer
    hidden_dim: int = 100
    activation: Callable[[jax.Array], jax.Array] = _nagai22_softplus
    is_graph_based = False

    def setup(self) -> None:
        """
        Create the two-branch network: (n_layers - 1) hidden layers of width hidden_dim
        followed by a single output unit.
        """
        self.xc_net = DoubleMLPWithCrossConnections(
            dims=(self.n_layers - 1) * [self.hidden_dim] + [1],
            cross_connections=[2],
            activation=self.activation,
            init_last_layer_to_zero=False,
            apply_activation_to_output=True,
        )

    def call_NN(self, features: FloatNx4) -> tuple[jax.Array, jax.Array]:
        """
        Evaluate the network on transformed features (n_t, zeta_t, s_t, tau_t).

        The first branch only receives the last two features (s_t, tau_t), the second branch
        receives all four. Returns the two outputs with their trailing axis removed; callers
        use the first as the exchange and the second as the correlation output.
        """
        y1, y2 = self.xc_net(features[..., 2:], features)
        # Only drop the trailing output axis (assumed to have length 1 because the last
        # entry of `dims` is 1). A bare squeeze would also remove any other axis of length 1,
        # e.g. the grid axis for a single grid point, which breaks the broadcasting of
        # fx - fx0 + f0 in the Lagrange interpolation.
        return jnp.squeeze(y1, axis=-1), jnp.squeeze(y2, axis=-1)

    def xc_energy_density(  # type: ignore
        self, n: FloatN, zeta: FloatN, s: FloatN, tau: FloatN
    ) -> FloatN:
        """
        Return the sum of the SCAN exchange and correlation energy densities, each multiplied
        by its constrained network factor.
        """
        n_t, zeta_t, s_t, tau_t = _nagai22_input_transform(n, zeta, s, tau)
        features = jnp.stack([n_t, zeta_t, s_t, tau_t], axis=-1)

        NN_x, NN_c = self.call_NN(features)
        alpha = transform_tau_to_alpha(n, zeta, s, tau)
        e_x = e_x_scan(n, s, alpha) * self.e_x_nagai2022(features, NN_x)
        e_c = e_c_scan(n, zeta, s, alpha) * self.e_c_nagai2022(features, NN_c)

        return e_x + e_c

    def e_x_nagai2022(self, x: FloatNx4, fx: FloatN) -> FloatN:
        """
        Apply the exchange constraints to the exchange network output `fx` at features `x`.
        """
        x0, f0 = self.make_constraints_x(x)
        fx0, _ = self.call_NN(x0)  # shape: (#constraints, #grid_points)

        # project onto (s_t, tau_t) for the lagrange polynomial
        x = x[..., 2:]
        x0 = x0[..., 2:]

        return _nagai22_lagrange_polynomial(x, fx, x0, fx0, f0)

    def e_c_nagai2022(self, x: FloatNx4, fx: FloatN) -> FloatN:
        """
        Apply the correlation constraints to the correlation network output `fx` at
        features `x`. Unlike the exchange case, all four features enter the interpolation.
        """
        x0, f0 = self.make_constraints_c(x)
        _, fx0 = self.call_NN(x0)  # shape: (#constraints, #grid_points)

        return _nagai22_lagrange_polynomial(x, fx, x0, fx0, f0)

    def make_constraints_x(self, features: FloatNx4) -> tuple[FloatAxNx4, FloatAxN]:
        """
        Build the exchange constraint points and their target values.

        Returns (x0_x, f0_x): the constraint features per grid point and the value the
        constrained factor has to take there (1 for both constraints).
        """
        n_t, zeta_t, _, tau_t = tuple(jnp.transpose(features))
        zeros = jnp.zeros_like(n_t)
        ones = jnp.ones_like(n_t)

        x0_x = jnp.stack(  # shape: (#constraints, #grid_points, 4)
            [
                jnp.stack([n_t, zeta_t, zeros, zeros], axis=-1),  # constraint X3
                jnp.stack([n_t, zeta_t, ones, tau_t], axis=-1),  # constraint X4
            ],
            axis=0,
        )
        f0_x = jnp.stack([ones, ones], axis=0)  # shape: (#constraints, #grid_points)

        return x0_x, f0_x

    def make_constraints_c(self, features: FloatNx4) -> tuple[FloatAxNx4, FloatAxN]:
        """
        Build the correlation constraint points and their target values.

        Returns (x0_c, f0_c). The targets of C1 and C4 are 1; the target of C2 is
        f_c(n_t, 0, s_t, tau_t) - f_c(0, 0, s_t, tau_t) + 1, where f_c is the correlation
        output of the network, which requires two extra forward passes.
        """
        n_t, zeta_t, s_t, tau_t = tuple(jnp.transpose(features))
        zeros = jnp.zeros_like(n_t)
        ones = jnp.ones_like(n_t)

        # Forward passes for constraint C2
        features_c0 = jnp.stack([n_t, zeros, s_t, tau_t], axis=-1)
        features_c00 = jnp.stack([zeros, zeros, s_t, tau_t], axis=-1)
        _, f_c0 = self.call_NN(features_c0)
        _, f_c00 = self.call_NN(features_c00)

        x0_c = jnp.stack(  # shape: (#constraints, #grid_points, 4)
            [
                jnp.stack([n_t, zeta_t, zeros, zeros], axis=-1),  # constraint C1
                jnp.stack([zeros, zeta_t, s_t, tau_t], axis=-1),  # constraint C2
                jnp.stack([ones, zeta_t, s_t, tau_t], axis=-1),  # constraint C4
            ],
            axis=0,
        )
        f0_c = jnp.stack(
            [ones, f_c0 - f_c00 + ones, ones], axis=0
        )  # shape: (#constraints, #grid_points)

        return x0_c, f0_c
