"""
Adapted from H. Helal's and A. Fitzgibbon's
"MESS: Modern Electronic Structure Simulations" package
https://arxiv.org/abs/2406.03121
"""

import jax.numpy as jnp

from egxc.utils.typing import FloatN
from egxc.xc_energy.features import (
    fermi_wave_vector,
    transform_s_to_abs_grad_n,
    ueg_e_x,
    ueg_spin_pol_e_x_factor,
)
from egxc.xc_energy.functionals.base import (
    BaseEnergyFunctional,
)
from egxc.xc_energy.functionals.classical import lsda


def phi_fn(zeta: FloatN) -> FloatN:
    """
    Spin-scaling factor phi(zeta) = [(1 + zeta)^(2/3) + (1 - zeta)^(2/3)] / 2.

    Equals 1 for zeta = 0 and is symmetric under zeta -> -zeta. Used to rescale
    the reduced density gradient for spin-polarized densities.
    """
    two_thirds = 2 / 3
    up_term = (1 + zeta) ** two_thirds
    down_term = (1 - zeta) ** two_thirds
    return 0.5 * (up_term + down_term)


def e_x_b88(n: FloatN, zeta: FloatN, s: FloatN) -> FloatN:
    """
    Becke 88 exchange energy density.

    Args:
        n: electron density.
        zeta: spin polarization; it only enters through
            ``ueg_spin_pol_e_x_factor(zeta)``, which multiplies the whole
            unpolarized result (uniform-gas term and gradient correction).
        s: reduced density gradient, converted back to |grad n| via
            ``transform_s_to_abs_grad_n``.

    Returns:
        (ueg_e_x(n) - gradient correction) * ueg_spin_pol_e_x_factor(zeta).
    """
    # 0.0042 is Becke's beta; the 2^(1/3) factors presumably convert the
    # per-spin B88 expression to the total density n (n_up = n_down) -- confirm
    # against the MESS implementation.
    b88_beta = 0.0042 * 2 ** (1 / 3)
    grad_norm = transform_s_to_abs_grad_n(n, s)
    reduced_x = grad_norm / n ** (4 / 3)
    denom = 1 + 6 * b88_beta * reduced_x * jnp.arcsinh(2 ** (1 / 3) * reduced_x)
    gradient_correction = b88_beta * n ** (1 / 3) * reduced_x**2 / denom
    unpolarized = ueg_e_x(n) - gradient_correction
    return unpolarized * ueg_spin_pol_e_x_factor(zeta)


def e_x_pbe(n: FloatN, zeta: FloatN, s: FloatN) -> FloatN:
    """
    PBE exchange energy density.

    Applies the PBE enhancement factor
    F(s~) = 1 + kappa - kappa / (1 + mu * s~^2 / kappa), with s~ = s / phi(zeta),
    to the spin-scaled uniform-gas exchange ueg_e_x(n) * ueg_spin_pol_e_x_factor(zeta).

    Args:
        n: electron density.
        zeta: spin polarization.
        s: reduced density gradient.

    Returns:
        The exchange energy density with the PBE enhancement factor applied.
    """
    # Same full-precision PBE beta as in e_c_pbe. With it,
    # mu = beta * pi^2 / 3 = 0.2195149727645171, the mu libxc uses for PBE
    # exchange. The previous 0.066725 gave a slightly different mu.
    # (Equation numbers presumably refer to Perdew, Burke & Ernzerhof, PRL 1996.)
    beta = 0.06672455060314922  # Eq 4
    mu = beta * jnp.pi**2 / 3  # Eq 12
    kappa = 0.8040  # Eq 14
    # Spin rescaling of the reduced gradient (approximation used by this module).
    s_tilde = s / phi_fn(zeta)
    F = 1 + kappa - kappa / (1 + mu * s_tilde**2 / kappa)
    return ueg_e_x(n) * ueg_spin_pol_e_x_factor(zeta) * F


def e_c_pbe(n: FloatN, zeta: FloatN, s: FloatN) -> FloatN:
    """
    PBE correlation energy density: PW92 LSDA correlation plus the gradient
    correction H.

    H = gamma * phi^3 * ln(1 + beta/gamma * t^2 (1 + A t^2) / (1 + A t^2 + A^2 t^4)),
    with A = beta/gamma / (exp(-ec_pw / (gamma * phi^3)) - 1) and
    t = |grad n| / (2 * phi * k_s * n), k_s = sqrt(4 * k_F / pi).

    Args:
        n: electron density.
        zeta: spin polarization.
        s: reduced density gradient, converted back to |grad n| via
            ``transform_s_to_abs_grad_n``.

    Returns:
        ec_pw + H.
    """
    # Use full-precision constants
    beta = 0.06672455060314922  # https://gitlab.com/libxc/libxc/-/blob/devel/maple/gga_exc/gga_c_pbe.mpl
    gamma = (1 - jnp.log(2.0)) / jnp.pi**2
    # Reuse the module's spin-scaling helper instead of re-deriving it inline.
    phi = phi_fn(zeta)
    phi_cubed = phi**3
    ec_pw = lsda.pw92_correlation_energy_density(n, zeta)
    # Use expm1 for better accuracy: 1/(exp(x)-1) = 1/expm1(x)
    A = beta / gamma / jnp.expm1(-ec_pw / (gamma * phi_cubed))  # Eq 8
    kf = fermi_wave_vector(n)
    ks = jnp.sqrt(4 * kf / jnp.pi)
    abs_grad_n = transform_s_to_abs_grad_n(n, s)
    t = abs_grad_n / (2 * phi * ks * n)
    delta = beta / gamma * t**2 * (1 + A * t**2) / (1 + A * t**2 + A**2 * t**4)
    H = gamma * phi_cubed * jnp.log1p(delta)  # Eq 7, use log1p for accuracy
    return ec_pw + H


def e_c_lyp(
    n: FloatN,
    unused_zeta: FloatN,  # to match signature of e_c_pbe
    s: FloatN,
) -> FloatN:
    """
    LYP correlation energy density for a spin-unpolarized density.

    The result is -a / (1 + d n^(-1/3))
    - a b omega [C_F n^(11/3) - (1/24 + 7 delta/72) n |grad n|^2], with
    omega = exp(-c n^(-1/3)) / (1 + d n^(-1/3)) * n^(-11/3) and
    delta = c n^(-1/3) + d n^(-1/3) / (1 + d n^(-1/3)).

    ``unused_zeta`` exists only so the signature matches ``e_c_pbe``. It is
    ignored, so the closed-shell expression is returned for any zeta.
    """
    a, b, c, d = 0.04918, 0.132, 0.2533, 0.349
    c_f = 0.3 * (3 * jnp.pi**2) ** (2 / 3)  # Thomas-Fermi constant C_F

    n_m13 = n ** (-1 / 3)
    screen = 1 + d * n_m13
    omega = jnp.exp(-c * n_m13) / screen * n ** (-11 / 3)
    delta = c * n_m13 + d * n_m13 / screen
    grad_sq = transform_s_to_abs_grad_n(n, s) ** 2
    gradient_term = (1 / 24 + 7 * delta / 72) * n * grad_sq

    local_term = -a / screen
    return local_term - a * b * omega * (c_f * n ** (11 / 3) - gradient_term)


def e_x_local_b3lyp(
    n: FloatN, zeta: FloatN, s: FloatN, unused_tau: FloatN | None = None, **unused_kwargs
) -> FloatN:
    """
    The semi-local exchange energy density term used in the B3LYP functional.

    Computes 0.08 * E_x^LDA + 0.72 * E_x^B88, which equals
    0.8 * E_x^LDA + 0.72 * (E_x^B88 - E_x^LDA). The exact-exchange share of
    B3LYP is not computed here (presumably it is added elsewhere).

    Args:
        n: electron density.
        zeta: spin polarization.
        s: reduced density gradient.
        unused_tau, unused_kwargs: accepted for interface compatibility; ignored.

    Returns:
        The local B3LYP exchange energy density.
    """
    # e_x_b88 multiplies its result by ueg_spin_pol_e_x_factor(zeta). The LDA
    # share needs the same spin scaling; otherwise spin-polarized densities
    # would get an unscaled LDA term.
    lda_x = ueg_e_x(n) * ueg_spin_pol_e_x_factor(zeta)
    return 0.08 * lda_x + 0.72 * e_x_b88(n, zeta, s)


def e_c_local_b3lyp(
    n: FloatN, zeta: FloatN, s: FloatN, unused_tau: FloatN | None = None, **unused_kwargs
) -> FloatN:
    """
    The correlation energy density term used in the B3LYP functional.

    Mixes (1 - 0.81) * VWN correlation with 0.81 * LYP correlation. VWN is
    evaluated with ``use_rpa=True`` (presumably the RPA-fitted VWN variant
    common in B3LYP implementations; confirm in ``lsda``). ``e_c_lyp`` ignores
    zeta, so the LYP share is always the closed-shell expression.
    ``unused_tau`` and ``unused_kwargs`` are ignored.
    """
    lyp_weight = 0.81
    lyp_part = lyp_weight * e_c_lyp(n, zeta, s)
    vwn_part = (1 - lyp_weight) * lsda.vwn5_correlation_energy_density(n, zeta, use_rpa=True)  # type: ignore
    return vwn_part + lyp_part


class GGA(BaseEnergyFunctional):
    """
    Local density approximation.
    """

    key: str
    is_graph_based = False

    def setup(self):
        match self.key:
            case 'pbe':
                self._exchange_energy_density = e_x_pbe
                self._correlation_energy_density = e_c_pbe
            case 'b88':
                self._exchange_energy_density = e_x_b88
                self._correlation_energy_density = e_c_pbe
            case 'lyp':
                self._exchange_energy_density = e_x_pbe
                self._correlation_energy_density = e_c_lyp
            case 'local_b3lyp':
                self._exchange_energy_density = e_x_local_b3lyp
                self._correlation_energy_density = e_c_local_b3lyp
            case _:
                raise ValueError('Invalid GGA type.')

    def xc_energy_density(  # type: ignore
        self, n: FloatN, zeta: FloatN, s: FloatN, unused_tau: FloatN | None = None
    ) -> FloatN:
        e_x = self._exchange_energy_density(n, zeta, s)
        e_c = self._correlation_energy_density(n, zeta, s)
        return e_x + e_c
