from typing import Any

import jax.numpy as jnp

from egxc.utils.typing import (
    PRECISION,
    Float1,
    Float2xBxB,
    FloatBxB,
    FloatBxBxBxB,
    FloatN,
    FloatQxBxB,
)
from egxc.xc_energy.functionals.base import BaseEnergyFunctional
from egxc.xc_energy.functionals.classical import gga


def e_xc_pbe0(n: FloatN, zeta: FloatN, s: FloatN, **unused_kwargs) -> FloatN:
    """
    Semi-local part of the PBE0 exchange-correlation energy density.

    PBE0 combines a fraction of exact exchange (0.25, added separately by
    ``Hybrid``) with the complementary 75% of PBE exchange and the full PBE
    correlation. Only the PBE part is returned here.

    Args:
        n, zeta, s: features forwarded unchanged to the ``gga`` PBE routines
            (presumably density, spin polarization and reduced gradient).
        **unused_kwargs: extra features (e.g. ``tau``, ``weights``) that are
            accepted and ignored.

    Returns:
        The per-point energy density ``0.75 * e_x_pbe + e_c_pbe``.
    """
    pbe_exchange = gga.e_x_pbe(n, zeta, s)
    pbe_correlation = gga.e_c_pbe(n, zeta, s)
    # 0.75 == 1 - 0.25: the exchange share not covered by exact exchange.
    return 0.75 * pbe_exchange + pbe_correlation


def e_xc_b3lyp(
    n: FloatN, zeta: FloatN, s: FloatN, unused_tau: FloatN | None = None, **unused_kwargs
) -> FloatN:
    """
    Semi-local part of the B3LYP exchange-correlation energy density.

    The exact-exchange share (0.2, added separately by ``Hybrid``) is not
    included. This is the sum of the local exchange and local correlation
    terms provided by ``gga`` (their internal mixing is defined there).

    Args:
        n, zeta, s: features forwarded unchanged to the ``gga`` B3LYP routines.
        unused_tau: accepted for interface compatibility and ignored.
        **unused_kwargs: any further features, ignored.

    Returns:
        ``e_x_local_b3lyp + e_c_local_b3lyp`` evaluated per point.

    TODO: check whether spin polarization is assumed to be zero.
    """
    local_exchange = gga.e_x_local_b3lyp(n, zeta, s)
    local_correlation = gga.e_c_local_b3lyp(n, zeta, s)
    return local_exchange + local_correlation


def exact_exchange(
    density_matrix: FloatBxB | Float2xBxB, eri_tensor: FloatBxBxBxB, spin_restricted: bool
) -> Float1:
    if spin_restricted:
        P = 0.5 * density_matrix
        out = 2 * jnp.einsum('ijkl,ik,jl', eri_tensor, P, P)
    else:
        Ps = density_matrix
        out = jnp.einsum('ijkl,sik,sjl', eri_tensor, Ps, Ps)
    return -0.5 * out


def density_fitted_exact_exchange(
    density_matrix: FloatBxB | Float2xBxB, df_tensor: FloatQxBxB, spin_restricted: bool
) -> Float1:
    if spin_restricted:
        P = 0.5 * density_matrix
        out = 2 * jnp.einsum('Pij,Pkl,ik,jl', df_tensor, df_tensor, P, P)
    else:
        Ps = density_matrix
        out = jnp.einsum('Pij,Pkl,sik,sjl', df_tensor, df_tensor, Ps, Ps)
    return -0.5 * out


class Hybrid(BaseEnergyFunctional):
    """
    A hybrid exchange-correlation functional: a composite of a local xc energy
    density and a fraction of "exact" Hartree-Fock exchange.

    Supported ``key`` values (case-insensitive):
        'hf_x'  -- exact-exchange fraction 1.0, zero local energy density
        'pbe0'  -- exact-exchange fraction 0.25 with ``e_xc_pbe0``
        'b3lyp' -- exact-exchange fraction 0.2 with ``e_xc_b3lyp``
    """

    key: str  # which hybrid to build (see the class docstring)
    use_density_fitting: bool  # use a 3-index DF tensor instead of the full ERIs
    spin_restricted: bool  # density matrix is the total (B, B) one, not per-spin
    is_graph_based = False

    def setup(self):
        # Each hybrid -> (exact-exchange fraction, local xc energy density).
        variants = {
            'hf_x': (1.0, lambda **kwargs: 0),
            'pbe0': (0.25, e_xc_pbe0),
            'b3lyp': (0.2, e_xc_b3lyp),
        }
        selected = variants.get(self.key.lower())
        if selected is None:
            raise ValueError(f'Invalid hybrid type: {self.key}')
        self.exact_exchange_fraction, self._xc_energy_density = selected

    def __call__(
        self, weights: FloatN, *feats: FloatN, **non_local_kwargs: Any
    ) -> Float1:
        """
        Total xc energy: the integrated local part plus the scaled exact exchange.

        Overrides the base-class method to add the non-local exact-exchange term.
        ``weights`` and ``feats`` go to ``integrate_energy_density``;
        ``non_local_kwargs`` go to ``non_local_contribution``.
        """
        local_energy = self.integrate_energy_density(weights, *feats)
        exchange_energy = self.non_local_contribution(**non_local_kwargs)
        # Both terms are expected to already be in the xc-energy precision.
        for term in (local_energy, exchange_energy):
            assert term.dtype == PRECISION.xc_energy
        return local_energy + exchange_energy

    def xc_energy_density(self, *feats: FloatN, **kwargs: FloatN) -> FloatN:
        """
        Evaluate the selected local xc energy density.

        Positional features are bound, in order, to the keyword names
        ``n, zeta, s, tau, weights`` and merged into ``kwargs``. Supplying more
        than five positional features raises ``IndexError``.
        """
        names = ('n', 'zeta', 's', 'tau', 'weights')
        kwargs.update((names[idx], feat) for idx, feat in enumerate(feats))  # type: ignore
        return self._xc_energy_density(**kwargs)  # type: ignore

    def non_local_contribution(
        self, density_matrix: FloatBxB | Float2xBxB, eri_tensor: FloatBxBxBxB | FloatQxBxB
    ) -> Float1:
        """
        Exact-exchange energy scaled by this hybrid's exact-exchange fraction.

        ``eri_tensor`` is the 4-index ERI tensor, or the 3-index density-fitting
        tensor when ``use_density_fitting`` is set.
        """
        exchange_fn = (
            density_fitted_exact_exchange if self.use_density_fitting else exact_exchange
        )
        e_hf_x = exchange_fn(density_matrix, eri_tensor, self.spin_restricted)
        return self.exact_exchange_fraction * e_hf_x
