from typing import Any, Tuple

import jax.numpy as jnp
import jax.scipy as jsp

from egxc.utils.linalg import safe_norm
from egxc.utils.typing import (
    PRECISION,
    Float1,
    Float2xBxB,
    FloatBxB,
    FloatBxBxBxB,
    FloatN,
    FloatNx3,
    FloatNxB,
    FloatNxBx3,
    FloatQxBxB,
)
from egxc.xc_energy.features import (
    combine_from_spin_resolved,
    fermi_wave_vector,
    ueg_e_x,
    ueg_tau,
    wigner_seitz_radius,
)
from egxc.xc_energy.functionals.base import BaseEnergyFunctional
from egxc.xc_energy.functionals.classical.hybrid import (
    density_fitted_exact_exchange,
    exact_exchange,
)
from egxc.xc_energy.functionals.classical.lsda import (
    _pw92_correlation_components,
    pw92_correlation_energy_density,
)
from egxc.xc_energy.functionals.dispersion.vv10 import (
    VV10_PARAMS,
    VV10_wB97M_V_PARAMS,
    vv10_energy,
)


def range_separation_factor(n: FloatN, omega: float) -> FloatN:
    """
    Evaluate the range-separation exchange factor F_sigma on a density grid.

    The factor is a closed-form function of the dimensionless ratio
    a = omega / k_F, where k_F is obtained from ``fermi_wave_vector(n)``.

    Args:
        n: electron density in bohr^-3
        omega: range-separation parameter in bohr^-1

    Returns:
        The factor F_sigma, evaluated element-wise for every entry of ``n``.
    """
    ratio = omega / fermi_wave_vector(n)
    ratio_cubed = ratio**3

    # Error-function and Gaussian pieces of the bracketed expression.
    erf_part = 2 * jnp.sqrt(jnp.pi) * jsp.special.erf(1 / ratio)
    gauss_part = (2 * ratio - ratio_cubed) * jnp.exp(-1 / ratio**2)

    bracket = erf_part - 3 * ratio + ratio_cubed + gauss_part
    return 1.0 - (2 / 3) * ratio * bracket


class BaseRangeSeparatedHybrid(BaseEnergyFunctional):
    """
    Common base for range-separated hybrids.

    The short-range and long-range exact-exchange energies are weighted by two
    independent fractions and added to the grid-integrated local energy.
    """

    short_range_fraction: float
    long_range_fraction: float
    use_density_fitting: bool
    spin_restricted: bool

    is_graph_based = False

    def __call__(
        self, weights: FloatN, *feats: FloatN, **non_local_kwargs: Any
    ) -> Float1:
        """Return the integrated local energy plus the non-local contribution."""
        local_energy = self.integrate_energy_density(weights, *feats)
        return local_energy + self.non_local_contribution(**non_local_kwargs)

    def non_local_contribution(self, **non_local_kwargs: Any) -> Float1:
        """Non-local energy; for the base class this is exact exchange only."""
        return self.exact_exchange_contribution(**non_local_kwargs)

    def exact_exchange_contribution(
        self,
        density_matrix: FloatBxB | Float2xBxB,
        eri_sr_tensor: FloatBxBxBxB | FloatQxBxB,
        eri_lr_tensor: FloatBxBxBxB | FloatQxBxB,
        **unused_non_local_kwargs: Any,  # additional kwargs for child classes
    ) -> Float1:
        """
        Weighted sum of short-range and long-range exact exchange.

        Args:
            density_matrix: density matrix passed through to the exchange kernel
            eri_sr_tensor: short-range ERI tensor (4-index, or 3-index when
                density fitting is enabled)
            eri_lr_tensor: long-range ERI tensor, same layout as the short-range one
            **unused_non_local_kwargs: ignored; lets subclasses forward extra kwargs

        Returns:
            ``short_range_fraction * E_sr + long_range_fraction * E_lr``.
        """
        assert not (eri_sr_tensor is None or eri_lr_tensor is None), (
            'ERI tensors must be provided'
        )
        # Both ranges use the same kernel; only the ERI tensor differs.
        exchange_kernel = (
            density_fitted_exact_exchange if self.use_density_fitting else exact_exchange
        )
        e_sr = exchange_kernel(density_matrix, eri_sr_tensor, self.spin_restricted)
        e_lr = exchange_kernel(density_matrix, eri_lr_tensor, self.spin_restricted)

        out = self.short_range_fraction * e_sr + self.long_range_fraction * e_lr
        assert out.dtype == PRECISION.xc_energy
        return out


class wB97M_V(BaseRangeSeparatedHybrid):
    """
    wB97M-V range separated hybrid functional.

    The total energy assembled by this class is the grid-integrated semilocal
    exchange-correlation energy plus a non-local part made of range-separated
    exact exchange and VV10 dispersion (see ``non_local_contribution``).

    This functional requires actual spin-resolved density matrices to work correctly.
    Use SpinResolvedDensityFeatures to compute the required features.

    LibXC Reference Implementations:
    https://gitlab.com/libxc/libxc/-/blob/devel/maple/b97mv.mpl
    https://gitlab.com/libxc/libxc/-/blob/devel/src/hyb_gga_xc_wb97.c
    https://gitlab.com/libxc/libxc/-/blob/devel/src/hyb_mgga_xc_wb97mv.c
    https://gitlab.com/libxc/libxc/-/blob/devel/maple/mgga_exc/hyb_mgga_xc_wb97mv.mpl
    """

    short_range_fraction: float = 0.15  # a_SR^HF
    long_range_fraction: float = 1.0  # a_LR^HF
    omega: float = 0.30  # bohr^-1
    use_density_fitting: bool = False
    match_pyscf: bool = True  # True selects VV10_PARAMS, False VV10_wB97M_V_PARAMS
    # False: the density matrix is summed over its leading axis before VV10
    spin_restricted: bool = False

    requires_spin_resolved_features: bool = (
        False  # __call__ accepts 4 (unresolved) or 6 (spin-resolved) features
    )

    # --- ωB97M-V semilocal parameterization (from LibXC hyb_mgga_xc_wb97mv.c) ---
    # Exchange coefficients (u^0 w^0, u^1 w^0, u^0 w^1)
    _coeff_x_c0: float = 0.85
    _coeff_x_cu: float = 1.007
    _coeff_x_cw: float = 0.259

    # Same-spin correlation coefficients
    # maps to exponents: (w^0 u^0), (w^0 u^4), (w^1 u^0), (w^2 u^0), (w^4 u^3)
    _coeff_ss_c0: float = 0.443
    _coeff_ss_c1: float = -1.437
    _coeff_ss_c2: float = -4.535
    _coeff_ss_c3: float = -3.390
    _coeff_ss_c4: float = 4.278

    # Opposite-spin correlation coefficients
    # maps to exponents: (w^0 u^0), (w^1 u^0), (w^2 u^0), (w^2 u^1), (w^6 u^0), (w^6 u^1)
    _coeff_os_c0: float = 1.0
    _coeff_os_c1: float = 1.358
    _coeff_os_c2: float = 2.924
    _coeff_os_c3: float = -8.812
    _coeff_os_c4: float = -1.390
    _coeff_os_c5: float = 9.142

    # gamma parameters of the bounded gradient variable u = gamma*x^2 / (1 + gamma*x^2)
    # for exchange, same-spin and opposite-spin correlation (see _ux).
    # Deliberately left unannotated, unlike the coefficients above.
    _gamma_x = 0.004
    _gamma_ss = 0.2
    _gamma_os = 0.006

    # Screening thresholds (LibXC-style)
    # From hyb_mgga_xc_wb97mv.c line 75: 1e-13 is the tolerance parameter
    _zeta_threshold: float = 1e-13  # lower bound applied to (1 ± zeta)
    _dens_threshold: float = 1e-13  # spin densities at or below this are screened out

    def __call__(
        self, weights: FloatN, *feats: FloatN, **non_local_kwargs: Any
    ) -> Float1:
        # Handle both spin-resolved and non-spin-resolved features
        # spin-resolved: (n_up, n_dn, grad_n_up, grad_n_dn, tau_up, tau_dn) - 6 args
        # non-spin-resolved: (n, zeta, s, tau) - 4 args
        if len(feats) == 6:
            e_loc = self.integrate_spin_resolved_energy_density(weights, *feats)
        elif len(feats) == 4:
            e_loc = self.integrate_energy_density(weights, *feats)
        else:
            raise ValueError(
                f'Expected 4 (non-spin-resolved) or 6 (spin-resolved) features, got {len(feats)}'
            )
        e_glob = self.non_local_contribution(**non_local_kwargs)
        return e_loc + e_glob

    def _ux(self, gamma: float, x: FloatN) -> FloatN:
        return (gamma * x * x) / (1.0 + gamma * x * x)

    def _wx_ss(self, t: FloatN) -> FloatN:
        # Maple uses w = (K - t)/(K + t), with t = 2 * tau_sigma / n_sigma^(5/3)
        K = (3.0 / 10.0) * (6.0 * jnp.pi**2) ** (2.0 / 3.0)
        return (K - t) / (K + t)

    def _wx_os(self, t_up: FloatN, t_dn: FloatN) -> FloatN:
        # w_os = (K*(t_up + t_dn) - 2 t_up t_dn)/(K*(t_up + t_dn) + 2 t_up t_dn)
        K = (3.0 / 10.0) * (6.0 * jnp.pi**2) ** (2.0 / 3.0)
        num = K * (t_up + t_dn) - 2.0 * t_up * t_dn
        den = K * (t_up + t_dn) + 2.0 * t_up * t_dn
        return num / den

    def _g_x(self, s_sigma: FloatN, t_sigma: FloatN) -> FloatN:
        # g for exchange: terms (u^0 w^0), (u^1 w^0), (u^0 w^1)
        u = self._ux(self._gamma_x, s_sigma)
        w = self._wx_ss(t_sigma)
        c0 = self._coeff_x_c0
        c_u = self._coeff_x_cu
        c_w = self._coeff_x_cw
        return c0 + c_u * u + c_w * w

    def _g_ss_single(self, s_sigma: FloatN, t_sigma: FloatN) -> FloatN:
        # Same-spin correlation single-channel g: exponents (w^0 u^0), (w^0 u^4), (w^1 u^0), (w^2 u^0), (w^4 u^3)
        c0 = self._coeff_ss_c0
        c1 = self._coeff_ss_c1
        c2 = self._coeff_ss_c2
        c3 = self._coeff_ss_c3
        c4 = self._coeff_ss_c4
        u = self._ux(self._gamma_ss, s_sigma)
        w = self._wx_ss(t_sigma)
        return c0 + c1 * (u**4) + c2 * w + c3 * (w**2) + c4 * (w**4) * (u**3)

    def _g_os(self, s_up: FloatN, s_dn: FloatN, t_up: FloatN, t_dn: FloatN) -> FloatN:
        # Opposite-spin correlation: exponents (w^0 u^0), (w^1 u^0), (w^2 u^0), (w^2 u^1), (w^6 u^0), (w^6 u^1)
        c0 = self._coeff_os_c0
        c1 = self._coeff_os_c1
        c2 = self._coeff_os_c2
        c3 = self._coeff_os_c3
        c4 = self._coeff_os_c4
        c5 = self._coeff_os_c5
        s_bar = jnp.sqrt(0.5 * (s_up * s_up + s_dn * s_dn))
        u = self._ux(self._gamma_os, s_bar)
        w = self._wx_os(t_up, t_dn)
        return c0 + c1 * w + c2 * (w**2) + c3 * (w**2) * u + c4 * (w**6) + c5 * (w**6) * u

    def _sr_attenuation_erf(self, a: FloatN) -> FloatN:
        # Maple attenuation_erf with large-a smoothing (enforce_smooth_lr)
        # Base expression (attenuation_erf0)
        inv2a = 1.0 / (2.0 * a)
        exp_term = jnp.exp(-1.0 / (4.0 * a * a))
        # base branch expression (not used directly in smoothed output)
        _ = (
            jnp.sqrt(jnp.pi) * jsp.special.erf(inv2a)
            + 2.0 * a * (exp_term - 1.0)
            - (2.0 * a * a * (exp_term - 1.0) + 0.5)
        )
        # f0 (unused here) = 1.0 - (8.0 / 3.0) * a * _

        # Large-a asymptotic series up to O(a^{-6}): 1/36 a^{-2} - 1/960 a^{-4} + 1/26880 a^{-6}
        # series (unused here): (1/36)a^-2 - (1/960)a^-4 + (1/26880)a^-6

        a_cut = jnp.array(1.35, dtype=a.dtype)
        a_small = jnp.minimum(a, a_cut)
        a_large = jnp.maximum(a, a_cut)

        # Evaluate small-a branch at a_small, large-a series at a_large
        inv2a_small = 1.0 / (2.0 * a_small)
        exp_small = jnp.exp(-1.0 / (4.0 * a_small * a_small))
        term_small = (
            jnp.sqrt(jnp.pi) * jsp.special.erf(inv2a_small)
            + 2.0 * a_small * (exp_small - 1.0)
            - (2.0 * a_small * a_small * (exp_small - 1.0) + 0.5)
        )
        f_small = 1.0 - (8.0 / 3.0) * a_small * term_small

        inv_a2_large = 1.0 / jnp.clip(a_large * a_large, 1e-30)
        inv_a4 = inv_a2_large * inv_a2_large
        inv_a6 = inv_a4 * inv_a2_large
        inv_a8 = inv_a4 * inv_a4
        inv_a10 = inv_a8 * inv_a2_large
        inv_a12 = inv_a6 * inv_a6
        inv_a14 = inv_a12 * inv_a2_large
        inv_a16 = inv_a8 * inv_a8
        # Maple2C series up to a^{-16}
        series_large = (
            (1.0 / 36.0) * inv_a2_large
            - (1.0 / 960.0) * inv_a4
            + (1.0 / 26880.0) * inv_a6
            - (1.0 / 829440.0) * inv_a8
            + (1.0 / 28385280.0) * inv_a10
            - (1.0 / 10734796800.0) * inv_a12
            + (1.0 / 445906944000.0) * inv_a14
            - (1.0 / 20214448128000.0) * inv_a16
        )

        return jnp.where(a >= a_cut, series_large, f_small)

    def _z_thr(self, z: FloatN) -> FloatN:
        zth = self._zeta_threshold
        cond_lo = (1.0 + z) <= zth
        cond_hi = (1.0 - z) <= zth
        return jnp.where(cond_lo, zth - 1.0, jnp.where(cond_hi, 1.0 - zth, z))

    def _opz_pow_n(self, z: FloatN, power: float) -> FloatN:
        zth = self._zeta_threshold
        return jnp.where(
            (1.0 + z) <= zth, jnp.array(zth, dtype=z.dtype) ** power, (1.0 + z) ** power
        )

    def _screen_dens_zeta(self, n: FloatN, z: FloatN) -> jnp.ndarray:
        # True if should screen (set to zero)
        n_spin = 0.5 * self._opz_pow_n(z, 1.0) * n  # (1+z_eff)/2 * n_total
        return (n_spin <= self._dens_threshold) | ((1.0 + z) <= self._zeta_threshold)

    def _sr_lda_exchange_spin_att(
        self, n_total: FloatN, zeta: FloatN, sign: int, omega: float
    ) -> Tuple[FloatN, FloatN]:
        """
        Short-range attenuation factor and spin density for one spin channel.

        The attenuation argument is a = omega / (2 k_F,sigma) with
        k_F,sigma = (6 pi^2 n_sigma)^(1/3). Written in terms of the spin-scaled
        Wigner-Seitz radius rs_sigma = rs * (2 / (1 ± zeta))^(1/3) this is

            a = (4 / (9 pi))^(1/3) * (omega / 2) * rs_sigma / 2^(1/3),

        which is the same argument ``xc_energy_density_spin_resolved`` feeds
        into ``_sr_attenuation_erf``.

        Args:
            n_total: total electron density
            zeta: spin polarization (n_up - n_down) / n_total
            sign: +1 for the up channel, -1 for the down channel
            omega: range-separation parameter in bohr^-1

        Returns:
            Tuple ``(attenuation, n_sigma)`` for the selected spin channel.
        """
        one_pm_zeta = 1.0 + sign * zeta
        n_sigma = 0.5 * n_total * one_pm_zeta

        rs_total = wigner_seitz_radius(n_total, epsilon=1e-30)
        # Clip to keep rs_sigma finite for an empty spin channel (1 ± zeta == 0).
        rs_sigma = rs_total * (2.0 / jnp.clip(one_pm_zeta, 1e-30)) ** (1.0 / 3.0)

        a_cnst = (4.0 / (9.0 * jnp.pi)) ** (1.0 / 3.0) * (omega * 0.5)
        # Fully polarized spin scaling contributes the 1 / 2^(1/3) factor.
        a = a_cnst * rs_sigma / 2.0 ** (1.0 / 3.0)
        return self._sr_attenuation_erf(a), n_sigma

    def xc_energy_density(  # type: ignore
        self, n: FloatN, zeta: FloatN, s: FloatN, tau: FloatN
    ) -> FloatN:
        """
        Local exchange-correlation energy density for wB97M-V (non-spin-resolved interface).

        Builds approximate spin-resolved inputs from (n, zeta, s, tau) and forwards
        them to ``xc_energy_density_spin_resolved``. Density, kinetic energy density
        and gradient magnitude are each split between the spin channels in the
        proportions (1 + zeta)/2 and (1 - zeta)/2.

        Args:
            n: Total electron density
            zeta: Spin polarization (n_up - n_down) / n
            s: Reduced density gradient
            tau: Kinetic energy density
        """
        up_weight = 1.0 + zeta
        down_weight = 1.0 - zeta

        # Invert s = |grad_n| / (2 * (3π²)^(1/3) * n^(4/3)) for |grad_n|.
        cbrt_3pi2 = (3.0 * jnp.pi**2) ** (1.0 / 3.0)
        grad_norm = s * 2.0 * cbrt_3pi2 * jnp.clip(n, 1e-30) ** (4.0 / 3.0)

        # Only the magnitude of each spin-channel gradient is used downstream,
        # so each gradient is represented by a vector along the x-axis.
        def along_x(magnitude: FloatN) -> FloatNx3:
            return jnp.stack(
                [magnitude, jnp.zeros_like(n), jnp.zeros_like(n)], axis=-1
            )

        return self.xc_energy_density_spin_resolved(
            n * up_weight / 2.0,
            n * down_weight / 2.0,
            along_x(grad_norm * up_weight / 2.0),
            along_x(grad_norm * down_weight / 2.0),
            tau * up_weight / 2.0,
            tau * down_weight / 2.0,
        )

    def xc_energy_density_spin_resolved(  # type: ignore
        self,
        n_up: FloatN,
        n_down: FloatN,
        grad_n_up: FloatNx3,
        grad_n_down: FloatNx3,
        tau_up: FloatN,
        tau_down: FloatN,
    ) -> FloatN:
        """
        Local exchange-correlation energy density for wB97M-V.

        The result is the sum of
          * short-range LDA exchange per spin channel, scaled by the exchange
            inhomogeneity factor ``_g_x``, and
          * PW92 correlation split into same-spin and opposite-spin pieces
            (Stoll decomposition), scaled by ``_g_ss_single`` and ``_g_os``.

        Args:
            n_up, n_down: Spin-resolved electron densities
            grad_n_up, grad_n_down: Spin-resolved density gradient vectors; their
                norms are converted to reduced gradients inside this method
            tau_up, tau_down: Spin-resolved kinetic energy densities

        Returns:
            Sum of the local exchange and correlation terms at every grid point.
        """
        # NOTE(review): only the total density `n` is used below; `s` and `tau`
        # returned here are never read.
        n, s, tau = combine_from_spin_resolved(
            n_up, n_down, grad_n_up, grad_n_down, tau_up, tau_down
        )

        # Detect closed-shell: for zeta≈0, use total density for reduced quantities
        # to match LibXC's behavior and avoid incorrect power-law scaling
        # Use 1e-4 threshold to catch numerically closed-shell systems
        # NOTE(review): with n substituted for n_sigma the reduced quantities
        # differ from the per-spin definition by a constant factor — confirm
        # against the LibXC reference values.
        zeta_for_scaling = (n_up - n_down) / jnp.clip(n, 1e-30)
        is_closed_shell = jnp.abs(zeta_for_scaling) < 1e-4

        abs_grad_n_up = safe_norm(grad_n_up)
        abs_grad_n_down = safe_norm(grad_n_down)

        # Spin-channel reduced gradients use k_F,σ = (6π^2 n_σ)^{1/3}
        # s_σ = |∇n_σ| / (2 k_F,σ n_σ) = |∇n_σ| / (2 (6π^2)^{1/3} n_σ^{4/3})
        # For closed-shell, use total density n instead of n_up, n_down to avoid scaling issues
        n_for_s_up = jnp.where(is_closed_shell, n, n_up)
        n_for_s_dn = jnp.where(is_closed_shell, n, n_down)

        denom_up = (
            2.0
            * (6.0 * jnp.pi**2) ** (1.0 / 3.0)
            * jnp.clip(n_for_s_up, 1e-30) ** (4.0 / 3.0)
        )
        denom_dn = (
            2.0
            * (6.0 * jnp.pi**2) ** (1.0 / 3.0)
            * jnp.clip(n_for_s_dn, 1e-30) ** (4.0 / 3.0)
        )
        s_up = abs_grad_n_up / denom_up
        s_down = abs_grad_n_down / denom_dn

        # Dimensionless t-sigma as ratio of uniform KE to orbital KE (absorbing constants into mapping)
        # t_sigma definition per Maple: t = 2 * tau_sigma / n_sigma^(5/3)
        # For closed-shell, use total density n instead of n_up, n_down
        n_for_t_up = jnp.where(is_closed_shell, n, n_up)
        n_for_t_dn = jnp.where(is_closed_shell, n, n_down)

        n_up_c = jnp.clip(n_for_t_up, 1e-30)
        n_dn_c = jnp.clip(n_for_t_dn, 1e-30)
        t_up = 2.0 * tau_up / (n_up_c ** (5.0 / 3.0))
        t_dn = 2.0 * tau_down / (n_dn_c ** (5.0 / 3.0))

        # Short-range exchange per Maple: (1±z)/2 * lda_x_erf_spin(rs_sigma,1) * g_x
        # zeta_raw is the same expression as zeta_for_scaling above; `zeta` is
        # its thresholded version and is only passed to the total PW92 call.
        zeta_raw = (n_up - n_down) / jnp.clip(n, 1e-30)
        zeta = self._z_thr(zeta_raw)
        rs = wigner_seitz_radius(n, epsilon=1e-30)
        # Common constants
        RS_FACTOR = (3.0 / (4.0 * jnp.pi)) ** (1.0 / 3.0)
        X_FACTOR_C = (3.0 / 8.0) * (3.0 / jnp.pi) ** (1.0 / 3.0) * (4.0 ** (2.0 / 3.0))
        a_const = (4.0 / (9.0 * jnp.pi)) ** (1.0 / 3.0) * (self.omega * 0.5)
        two_to_third = 2.0 ** (1.0 / 3.0)

        # Up channel
        factor_up = 0.5 * self._opz_pow_n(zeta_raw, 1.0)
        # Spin-scaled Wigner-Seitz radius: rs * (2 / (1 + zeta))^{1/3}
        rs_up_eff = rs * (2.0 / jnp.clip(1.0 + zeta_raw, 1e-30)) ** (1.0 / 3.0)
        eps_x_erf_up = (
            -RS_FACTOR
            * X_FACTOR_C
            / jnp.clip(rs_up_eff, 1e-30)
            * self._sr_attenuation_erf(a_const * rs_up_eff / two_to_third)
        )
        # Screened points (tiny spin density or empty channel) contribute zero.
        ex_up = jnp.where(
            ~self._screen_dens_zeta(n, zeta_raw),
            factor_up * eps_x_erf_up * self._g_x(s_up, t_up),
            0.0,
        )

        # Down channel
        # Same as the up channel with zeta -> -zeta.
        factor_dn = 0.5 * self._opz_pow_n(-zeta_raw, 1.0)
        rs_dn_eff = rs * (2.0 / jnp.clip(1.0 - zeta_raw, 1.0e-30)) ** (1.0 / 3.0)
        eps_x_erf_dn = (
            -RS_FACTOR
            * X_FACTOR_C
            / jnp.clip(rs_dn_eff, 1e-30)
            * self._sr_attenuation_erf(a_const * rs_dn_eff / two_to_third)
        )
        ex_dn = jnp.where(
            ~self._screen_dens_zeta(n, -zeta_raw),
            factor_dn * eps_x_erf_dn * self._g_x(s_down, t_dn),
            0.0,
        )

        ex_local = ex_up + ex_dn

        # Correlation via B97M-V expansions with Stoll decomposition (util.mpl)
        # NOTE(review): RS_FACTOR and rs_tot repeat the RS_FACTOR / rs values
        # computed in the exchange section above.
        RS_FACTOR = (3.0 / (4.0 * jnp.pi)) ** (1.0 / 3.0)
        rs_tot = wigner_seitz_radius(n, epsilon=1e-30)
        # Stoll parallel pieces for +z and -z (fully polarized baseline)
        rs_par_up = (
            rs_tot
            * (2.0 ** (1.0 / 3.0))
            * jnp.power(jnp.clip(1.0 + zeta_raw, 1e-30), -1.0 / 3.0)
        )
        rs_par_dn = (
            rs_tot
            * (2.0 ** (1.0 / 3.0))
            * jnp.power(jnp.clip(1.0 - zeta_raw, 1e-30), -1.0 / 3.0)
        )
        # Convert a Wigner-Seitz radius back to a density, because the PW92
        # routine is called with a density as its first argument.
        n_from_rs = lambda rs: (RS_FACTOR / jnp.clip(rs, 1e-30)) ** 3
        # PW92 evaluated with the second argument (polarization) set to 1.
        f_pw_par_up = pw92_correlation_energy_density(
            n_from_rs(rs_par_up), jnp.array(1.0, dtype=rs_par_up.dtype), modified=True
        )
        f_pw_par_dn = pw92_correlation_energy_density(
            n_from_rs(rs_par_dn), jnp.array(1.0, dtype=rs_par_dn.dtype), modified=True
        )
        stoll_par_up = jnp.where(
            ~self._screen_dens_zeta(n, zeta_raw),
            0.5 * self._opz_pow_n(zeta_raw, 1.0) * f_pw_par_up,
            0.0,
        )
        stoll_par_dn = jnp.where(
            ~self._screen_dens_zeta(n, -zeta_raw),
            0.5 * self._opz_pow_n(-zeta_raw, 1.0) * f_pw_par_dn,
            0.0,
        )
        # Perpendicular baseline
        # Use threshold zeta for PW92 to avoid numerical issues near zeta = ±1
        f_pw_tot = pw92_correlation_energy_density(n, zeta, modified=True)
        # Opposite-spin part = total minus the two same-spin parts.
        stoll_perp = f_pw_tot - stoll_par_up - stoll_par_dn
        # g factors
        g_ss_up = self._g_ss_single(s_up, t_up)
        g_ss_dn = self._g_ss_single(s_down, t_dn)
        g_os = self._g_os(s_up, s_down, t_up, t_dn)
        ec_local = stoll_par_up * g_ss_up + stoll_par_dn * g_ss_dn + stoll_perp * g_os
        return ex_local + ec_local

    def e_x_local(self, n: FloatN, s: FloatN, tau: FloatN) -> FloatN:
        """
        Spin-unresolved local exchange: UEG exchange times an enhancement factor.

        Not used by the hybrid energy path in this class; retained for completeness.
        The kinetic variable here is w = (t - 1) / (t + 1) with t = ueg_tau(n) / tau.
        """
        t = ueg_tau(n) / tau
        grad_var = self._ux(self._gamma_x, s)
        kin_var = (t - 1.0) / (t + 1.0)
        enhancement = (
            self._coeff_x_c0 + self._coeff_x_cu * grad_var + self._coeff_x_cw * kin_var
        )
        return ueg_e_x(n) * enhancement

    def e_c_local(self, n_up, n_down, s_up, s_down, tau_up, tau_down) -> FloatN:
        """
        Local correlation from the simple PW92 same-/opposite-spin split.

        Not used by the hybrid energy path in this class; retained for completeness.

        Args:
            n_up, n_down: spin-resolved electron densities
            s_up, s_down: spin-resolved reduced density gradients
            tau_up, tau_down: spin-resolved kinetic energy densities

        Returns:
            ``eps_c_ss * g_ss + eps_c_os * g_os`` at every grid point.
        """
        n_up_c = jnp.clip(n_up, 1e-30)
        n_dn_c = jnp.clip(n_down, 1e-30)
        t_up = 2.0 * tau_up / (n_up_c ** (5.0 / 3.0))
        t_dn = 2.0 * tau_down / (n_dn_c ** (5.0 / 3.0))

        n_total = n_up + n_down
        zeta = (n_up - n_down) / jnp.clip(n_total, 1e-30)
        eps_c_ss, eps_c_os = self._pw92_spin_decomposition(n_total, zeta)

        # `_pw92_spin_decomposition` returns one lumped same-spin energy, so the
        # two per-channel factors are combined with spin-density weights
        # (1 ± zeta)/2. In the fully polarized limit this reduces to the factor
        # of the occupied channel; at zeta = 0 the same-spin energy vanishes.
        weight_up = 0.5 * (1.0 + zeta)
        weight_dn = 0.5 * (1.0 - zeta)
        g_ss = weight_up * self._g_ss_single(s_up, t_up) + weight_dn * self._g_ss_single(
            s_down, t_dn
        )
        g_os = self._g_os(s_up, s_down, t_up, t_dn)
        return eps_c_ss * g_ss + eps_c_os * g_os

    def _pw92_spin_decomposition(self, n: FloatN, zeta: FloatN) -> Tuple[FloatN, FloatN]:
        """
        Split PW92 correlation into same-spin and opposite-spin pieces.

        The split uses the spin interpolation function
        f(zeta) = ((1+zeta)^(4/3) + (1-zeta)^(4/3) - 2) / (2^(4/3) - 2):
        the polarized component is weighted by f and the unpolarized one by 1 - f.

        Returns:
            Tuple ``(eps_c_ss, eps_c_os)``.
        """
        ec_unpolarized, ec_polarized, _ = _pw92_correlation_components(n, False, False)

        four_thirds = 4.0 / 3.0
        numerator = (1 + zeta) ** four_thirds + (1 - zeta) ** four_thirds - 2.0
        f_zeta = numerator / (2.0**four_thirds - 2.0)

        same_spin = ec_polarized * f_zeta
        opposite_spin = ec_unpolarized * (1.0 - f_zeta)
        return same_spin, opposite_spin

    def non_local_contribution(  # type: ignore
        self,
        density_matrix: FloatBxB | Float2xBxB,
        eri_sr_tensor: FloatBxBxBxB | FloatQxBxB,
        eri_lr_tensor: FloatBxBxBxB | FloatQxBxB,
        grid_coords: FloatNx3,
        grid_weights: FloatN,
        grid_aos: FloatNxB,
        grid_grad_aos: FloatNxBx3,
    ) -> Float1:
        """
        Non-local energy: range-separated exact exchange plus VV10 dispersion.

        Args:
            density_matrix: density matrix; when ``spin_restricted`` is False it is
                summed over its leading axis before being passed to VV10
            eri_sr_tensor: short-range ERI tensor
            eri_lr_tensor: long-range ERI tensor
            grid_coords, grid_weights, grid_aos, grid_grad_aos: grid data
                forwarded unchanged to ``vv10_energy``

        Returns:
            Exact-exchange energy plus VV10 dispersion energy.
        """
        exchange_energy = self.exact_exchange_contribution(
            density_matrix, eri_sr_tensor, eri_lr_tensor
        )

        # VV10 uses a spin-summed density matrix
        if self.spin_restricted:
            total_dm = density_matrix
        else:
            total_dm = density_matrix.sum(axis=0)

        vv10_params = VV10_PARAMS if self.match_pyscf else VV10_wB97M_V_PARAMS
        dispersion_energy = vv10_energy(
            total_dm,  # type: ignore
            grid_coords,
            grid_weights,
            grid_aos,
            grid_grad_aos,
            vv10_params,
        )
        return exchange_energy + dispersion_energy
