import jax.numpy as jnp

from egxc.utils.typing import FloatN
from egxc.xc_energy.features import ueg_e_x, wigner_seitz_radius
from egxc.xc_energy.functionals.base import BaseEnergyFunctional
from egxc.xc_energy.functionals.classical import lsda


def pz81_correlation_energy_density(
    n: FloatN, unused_zeta: FloatN | None = None
) -> FloatN:
    """
    Compute the LDA Perdew-Zunger (PZ81) correlation energy density.

    Perdew, J. P.; Zunger, A. Self-Interaction Correction to Density-Functional
    Approximations for Many-Electron Systems. Phys. Rev. B 1981, 23 (10)
    https://doi.org/10.1103/PhysRevB.23.5048.

    Uses the interpolation gamma / (1 + beta1 sqrt(rs) + beta2 rs) for rs >= 1
    and the high-density expansion A ln(rs) + B + C rs ln(rs) + D rs for rs < 1.
    The spin polarization is ignored, so only the spin-unpolarized
    parametrization is evaluated.

    Args:
        n (jax.Array): electron density
        unused_zeta (jax.Array | None): spin polarization; ignored. It is only
            accepted so that this function matches the call signature of the
            other correlation energy density functions.

    Returns:
        jax.Array: correlation energy density, evaluated elementwise in the
        Wigner-Seitz radius rs computed from ``n``.
    """
    # Constants (spin-unpolarized PZ81 parametrization)
    A = 0.0311
    B = -0.048
    C = 0.002
    D = -0.0116
    gamma = -0.1423
    beta1 = 1.0529
    beta2 = 0.3334

    rs = wigner_seitz_radius(n)
    use_high = rs >= 1

    # jnp.where evaluates AND differentiates both branches at every point. If the
    # branch that is not selected saw extreme rs values (rs -> inf makes
    # rs * log(rs) and its derivative inf; rs -> 0 makes d sqrt(rs)/d rs inf),
    # the backward pass would produce 0 * inf = NaN. Substituting a harmless
    # value (1.0) outside each branch's own domain avoids this. Values and
    # gradients of the selected branch are unchanged.
    rs_high = jnp.where(use_high, rs, 1.0)
    rs_low = jnp.where(use_high, 1.0, rs)

    e_c_high = gamma / (1 + beta1 * jnp.sqrt(rs_high) + beta2 * rs_high)

    log_rs_low = jnp.log(rs_low)
    e_c_low = A * log_rs_low + B + C * rs_low * log_rs_low + D * rs_low

    return jnp.where(use_high, e_c_high, e_c_low)


class LDA(BaseEnergyFunctional):
    """
    Local density approximation.
    """

    key: str
    is_graph_based = False

    def setup(self):
        match self.key.lower():
            case 'vwn5':
                self._correlation_energy_density = lsda.vwn5_correlation_energy_density
            case 'pz81':
                self._correlation_energy_density = pz81_correlation_energy_density
            case 'pw92_spin_restricted':
                self._correlation_energy_density = lsda.pw92_correlation_energy_density

    def xc_energy_density(self, n: FloatN, _) -> FloatN:  # type: ignore
        e_x = ueg_e_x(n)
        zeta = jnp.zeros_like(n)
        e_c = self._correlation_energy_density(n, zeta)
        return e_x + e_c
