from typing import Tuple

import jax
import jax.numpy as jnp

from egxc.utils.typing import FloatN
from egxc.xc_energy.features import (
    ueg_e_x,
    ueg_spin_pol_e_x_factor,
    wigner_seitz_radius,
)
from egxc.xc_energy.functionals.base import BaseEnergyFunctional


def e_x_spin_polarized_uniform_electron_gas(n: FloatN, zeta: FloatN) -> FloatN:
    """
    Exchange energy density of the spin-polarized uniform electron gas.

    The uniform-electron-gas exchange ``ueg_e_x(n)`` is scaled by the spin
    polarization factor ``ueg_spin_pol_e_x_factor(zeta)``.

    Args:
        n (FloatN): electron density
        zeta (FloatN): spin polarization
    """
    base_exchange = ueg_e_x(n)
    spin_scaling = ueg_spin_pol_e_x_factor(zeta)
    return base_exchange * spin_scaling


def vwn5_correlation_energy_density(n: FloatN, zeta: FloatN, use_rpa=False) -> FloatN:
    """
    Compute the VWN5 correlation energy density.
    Vosko, S. H.; Wilk, L.; Nusair, M. Accurate Spin-Dependent Electron Liquid
    Correlation Energies for Local Spin Density Calculations: A Critical Analysis.
    Can. J. Phys. 1980, 58 (8), 1200–1211. https://doi.org/10.1139/p80-159.

    Args:
        n (FloatN): electron density
        zeta (FloatN): spin polarization
        use_rpa (bool): if True, use the VWN parameter set fitted to random phase
            approximation (RPA) correlation energies instead of the VWN5 set

    Returns:
        FloatN: correlation energy density, with the broadcast shape of ``n`` and
        ``zeta`` (for 1-D inputs this is the same shape as the input).

    implementation taken from H. Helal's and A Fitzgibbon's
    "MESS: Modern Electronic Structure Simulations" package
    https://arxiv.org/abs/2406.03121
    """

    # https://gitlab.com/libxc/libxc/-/blob/devel/maple/vwn.mpl?ref_type=heads
    # paramagnetic (eps_0) / ferromagnetic (eps_1) / spin stiffness (alpha)
    A = jnp.array([0.0310907, 0.5 * 0.0310907, -1 / (6 * jnp.pi**2)])

    if use_rpa:
        x0 = jnp.array([-0.409286, -0.743294, -0.228344])
        b = jnp.array([13.0720, 20.1231, 1.06835])
        c = jnp.array([42.7198, 101.578, 11.4813])
    else:
        # https://math.nist.gov/DFTdata/atomdata/node5.html
        x0 = jnp.array([-0.10498, -0.32500, -4.75840e-3])
        b = jnp.array([3.72744, 7.06042, 1.13107])
        c = jnp.array([12.9352, 18.0578, 13.0045])

    def f(xi):
        """Spin interpolation function; f(0) = 0 and f(+-1) = 1."""
        u = (1 + xi) ** (4 / 3) + (1 - xi) ** (4 / 3) - 2
        v = 2 * (2 ** (1 / 3) - 1)
        return u / v

    # f''(0) in closed form (~1.709921, libxc's `fpp`). This avoids
    # differentiating f twice with jax.grad on every call.
    F2 = 4 / (9 * (2 ** (1 / 3) - 1))

    # Append a trailing axis so each density point is evaluated against all three
    # parameter sets at once; leading dimensions of `n` are preserved.
    rs = jnp.power(3 / (4 * jnp.pi * n), 1 / 3)[..., None]
    x = jnp.sqrt(rs)
    X = rs + b * x + c  # X(x) = x^2 + b x + c, with x^2 = rs
    X0 = x0**2 + b * x0 + c
    Q = jnp.sqrt(4 * c - b**2)
    atan_term = jnp.arctan(Q / (2 * x + b))

    u = jnp.log(x**2 / X) + 2 * b / Q * atan_term
    v = jnp.log((x - x0) ** 2 / X) + 2 * (b + 2 * x0) / Q * atan_term
    ec = A * (u - b * x0 / X0 * v)
    e0, e1, alpha = ec[..., 0], ec[..., 1], ec[..., 2]
    # Rewriting of eps_0 + alpha f/f''(0) (1 - zeta^4) + (eps_1 - eps_0) f zeta^4
    beta = F2 * (e1 - e0) / alpha - 1
    eps_c = e0 + alpha * f(zeta) / F2 * (1 + beta * zeta**4)
    return eps_c


def _pw92_analytic_base_form(n: FloatN, coeffs: jax.Array) -> FloatN:
    """
    Evaluate the PW92 analytic fitting form G(r_s) for one parameter row.

        G(r_s) = -2 A (1 + a1 r_s)
                 * ln(1 + 1 / (2 A (b1 r_s^(1/2) + b2 r_s + b3 r_s^(3/2) + b4 r_s^(p+1))))

    Args:
        n (FloatN): electron density, converted to r_s with ``wigner_seitz_radius``.
        coeffs (jax.Array): the seven parameters ``(p, A, a1, b1, b2, b3, b4)``.
            p and A are constrained; the remaining parameters were fitted (see PW92).

    Returns:
        FloatN: G evaluated at the Wigner-Seitz radius of ``n``.
    """
    p, A, a1, b1, b2, b3, b4 = coeffs
    rs = wigner_seitz_radius(n)
    # polynomial in sqrt(r_s) that appears inside the logarithm
    poly = b1 * rs ** (1 / 2) + b2 * rs + b3 * rs ** (3 / 2) + b4 * rs ** (p + 1)
    prefactor = -2 * A * (1 + a1 * rs)
    return prefactor * jnp.log1p(1 / (2 * A * poly))


def _pw92_correlation_components(
    n: FloatN, use_RPA=False, modified=False
) -> Tuple[FloatN, FloatN, FloatN]:
    """
    Evaluate the three PW92 fits G(r_s) at density ``n``.

    Args:
        n (FloatN): electron density
        use_RPA (bool): use the PW92 parameter sets fitted to RPA correlation
            energies instead of the standard sets
        modified (bool): use the more precise A values
            (0.0310907 / 0.01554535 / 0.0168869) instead of the rounded
            PW92 ones (0.031091 / 0.015545 / 0.016887)

    Returns:
        Tuple of (unpolarized eps_c, fully polarized eps_c, alpha_c). The third
        parameter row fits -alpha_c, hence the sign flip on return.
    """
    A_unpolar = 0.031091 if not modified else 0.0310907
    A_polar = 0.015545 if not modified else 0.01554535
    A_alpha_c = 0.016887 if not modified else 0.0168869
    # Rows: unpolarized, polarized, -alpha_c. Columns follow the unpacking order
    # of _pw92_analytic_base_form. Values from PW92 Table I.
    # fmt: off
    if use_RPA:
        coeffs = jnp.array([
            #   p,      A,          a1,        b1,      b2,      b3,      b4
            [0.75, A_unpolar, 0.082477, 5.1486, 1.6483, 0.23647, 0.20614],  # unpolarized
            [0.75, A_polar,   0.035374, 6.4869, 1.3083, 0.15180, 0.082349], # polarized
            [1.0,  A_alpha_c, 0.028829, 10.357, 3.6231, 0.47990, 0.12279],  # alpha_c
        ])
    else:
        coeffs = jnp.array([
            #   p,      A,          a1,        b1,      b2,      b3,      b4
            [1.0,  A_unpolar, 0.21370,  7.5957, 3.5876, 1.6382,  0.49294],  # unpolarized
            [1.0,  A_polar,   0.20548, 14.1189, 6.1977, 3.3662,  0.62517],  # polarized
            [1.0,  A_alpha_c, 0.11125, 10.357,  3.6231, 0.88026, 0.49671],  # alpha_c
        ])
    # fmt: on
    # Evaluate all three rows in one call: map over the row axis of coeffs while
    # n is shared (not mapped).
    out = jax.vmap(_pw92_analytic_base_form, in_axes=(None, 0))(n, coeffs)  # type: ignore
    return out[0], out[1], -out[2]


def pw92_correlation_energy_density(
    n: FloatN, zeta: FloatN, use_RPA=False, modified=False
) -> FloatN:
    """
    Correlation energy (LSDA1) of the uniform electron gas, Perdew and Wang (1992)
    (PW92). https://doi.org/10.1103/PhysRevB.45.13244

    The spin dependence is interpolated as
        eps_c = eps_c(r_s, 0) + alpha_c f(zeta) / f''(0) (1 - zeta^4)
                + (eps_c(r_s, 1) - eps_c(r_s, 0)) f(zeta) zeta^4

    Args:
        n (FloatN): electron density
        zeta (FloatN): spin polarization
        use_RPA (bool): use the RPA-fitted parameter sets
        modified (bool): use the more precise A values and f''(0)

    libxc reference implementation:
    https://github.com/ElectronicStructureLibrary/libxc/blob/master/src/lda_c_pw.c
    https://github.com/ElectronicStructureLibrary/libxc/blob/master/src/maple2c/lda_exc/lda_c_pw.c#L14
    https://github.com/ElectronicStructureLibrary/libxc/blob/master/maple/lda_exc/lda_c_pw.mpl
    https://github.com/ElectronicStructureLibrary/libxc/blob/master/maple/util.mpl
    """
    eps_para, eps_ferro, stiffness = _pw92_correlation_components(n, use_RPA, modified)
    # f''(0): rounded PW92 value, or the full-precision value when `modified`.
    fpp_zero = 1.709920934161365617563962776245 if modified else 1.709921
    # alpha_c comes from its own PW92 fit, not from the eps_ferro - eps_para difference.
    interp = ((1 + zeta) ** (4 / 3) + (1 - zeta) ** (4 / 3) - 2) / (2 ** (4 / 3) - 2)
    zeta4 = zeta**4

    stiffness_part = stiffness * interp / fpp_zero * (1 - zeta4)
    polarization_part = (eps_ferro - eps_para) * interp * zeta4
    return eps_para + stiffness_part + polarization_part


class LSDA(BaseEnergyFunctional):
    """
    Local spin density approximation.

    Exchange is the spin-polarized uniform-electron-gas exchange. Correlation is
    selected by ``key`` (case-insensitive); only ``'pw92'`` is recognised.
    """

    key: str
    is_graph_based = False

    def setup(self) -> None:
        # lower-cased key -> correlation energy density function
        known_correlations = {'pw92': pw92_correlation_energy_density}
        selected = known_correlations.get(self.key.lower())
        if selected is None:
            raise ValueError('Invalid correlation type.')
        self._correlation_energy_density = selected

    def xc_energy_density(self, density: FloatN, zeta: FloatN) -> FloatN:  # type: ignore
        """Sum of the exchange and the selected correlation energy densities."""
        exchange_part = e_x_spin_polarized_uniform_electron_gas(density, zeta)
        correlation_part = self._correlation_energy_density(density, zeta)
        return exchange_part + correlation_part
