from typing import Any, Callable

import flax.linen as nn

from egxc.utils.typing import (
    PRECISION,
    Float1,
    FloatN,
    FloatNx3,
)


class BaseEnergyFunctional(nn.Module):
    """
    Base class providing numerical quadrature helpers for (semi-)local
    exchange-correlation energy densities.

    Subclasses supply the energy density by overriding `xc_energy_density`
    (and optionally `xc_energy_density_spin_resolved`); the integration
    routines here turn that density into a scalar energy.
    """

    def __call__(
        self, weights: FloatN, *feats: FloatN, **non_local_kwargs: Any
    ) -> Float1:
        """
        Evaluate the functional by integrating the energy density on the grid.

        Delegates to `integrate_energy_density`, i.e. the canonical / non
        spin-resolved feature layout is assumed by default.

        Args:
            weights: Quadrature weights.
            *feats: Local density features, forwarded unchanged.
            **non_local_kwargs: Accepted for interface compatibility; this
                base implementation does not use them.

        Returns:
            The integrated exchange-correlation energy.
        """
        return self.integrate_energy_density(weights, *feats)

    def integrate_energy_density(self, weights: FloatN, *feats: FloatN) -> Float1:
        """
        Integrate `weights * n * e_xc` over all grid points.

        Args:
            weights: Quadrature weights.
            *feats: Local electron density features. Only the first entry is
                read here (as the electron density `n`); all entries are
                forwarded to `xc_energy_density`. Documented order:
                    n: electron density
                    xi: spin polarization
                    s: related to |grad n|
                    tau: listed as "reduced density gradient" in the
                        original documentation
                        NOTE(review): that wording may be a copy/paste slip;
                        confirm the intended meaning of tau.

        Returns:
            The exchange-correlation energy; its dtype is asserted to equal
            `PRECISION.xc_energy`.
        """
        density = feats[0]
        eps_xc = self.xc_energy_density(*feats)
        integrand = weights * density * eps_xc
        energy = integrand.sum()
        # Guard against silent precision changes in the accumulated energy.
        assert energy.dtype == PRECISION.xc_energy
        return energy

    def integrate_spin_resolved_energy_density(
        self, weights: FloatN, *feats: FloatN | FloatNx3
    ) -> Float1:
        """
        Integrate `weights * (n_up + n_down) * e_xc` over all grid points.

        Args:
            weights: Quadrature weights.
            *feats: Spin-resolved local features. Only the first two entries
                are read here (as `n_up` and `n_down`); all entries are
                forwarded to `xc_energy_density_spin_resolved`. Documented
                order:
                    n_up, n_down: electron density per spin channel
                    grad_n_up, grad_n_down: gradient of the electron density
                    tau_up, tau_down: listed as "reduced density gradient"
                        in the original documentation
                        NOTE(review): that wording may be a copy/paste slip;
                        confirm the intended meaning of tau.

        Returns:
            The exchange-correlation energy; its dtype is asserted to equal
            `PRECISION.xc_energy`.
        """
        spin_up, spin_down = feats[:2]
        eps_xc = self.xc_energy_density_spin_resolved(*feats)
        # The energy density is weighted by the total (up + down) density.
        total_density = spin_up + spin_down
        integrand = weights * total_density * eps_xc
        energy = integrand.sum()
        # Guard against silent precision changes in the accumulated energy.
        assert energy.dtype == PRECISION.xc_energy
        return energy

    def xc_energy_density(self, *feats: FloatN) -> FloatN:
        """
        Return the exchange-correlation energy density on the grid.

        Must be overridden by subclasses; the base version always raises
        `NotImplementedError`.
        """
        raise NotImplementedError

    def xc_energy_density_spin_resolved(self, *feats: FloatN | FloatNx3) -> FloatN:
        """
        Return the energy density for spin-resolved features.

        Optional hook: subclasses override it only if they support the
        spin-resolved path; the base version always raises
        `NotImplementedError`.
        """
        raise NotImplementedError


class EnergyDensityToFunctionalWrapper(BaseEnergyFunctional):
    """
    Adapter that turns a plain energy-density callable into a
    `BaseEnergyFunctional`.

    The wrapped callable is used as the implementation of
    `xc_energy_density`, so the inherited quadrature routines can
    integrate it.
    """

    # Callable evaluated on the positional density features; its result is
    # returned unchanged as the energy density.
    _xc_energy_density: Callable[..., FloatN]

    def xc_energy_density(self, *feats: FloatN) -> FloatN:
        """
        Evaluate the wrapped energy-density callable.

        Args:
            *feats: Local density features, passed positionally and
                unmodified to the wrapped callable.

        Returns:
            Whatever the wrapped callable returns for the given features.
        """
        density_fn = self._xc_energy_density
        return density_fn(*feats)
