from typing import Annotated, Literal, Tuple, overload

import flax.linen as nn
import jax.numpy as jnp

from egxc.utils import linalg
from egxc.utils.typing import (
    BoolN,
    Float2xBxB,
    FloatBxB,
    FloatN,
    FloatNx3,
    FloatNxB,
    FloatNxBx3,
)


def ueg_e_x(n: FloatN) -> FloatN:
    """
    Exchange energy per particle of the spin-unpolarized uniform electron gas
    (atomic units):

        e_x(n) = -(3/4) * (3/pi)^(1/3) * n^(1/3)
    """
    prefactor = -(3 / 4) * (3 / jnp.pi) ** (1 / 3)
    return prefactor * n ** (1 / 3)


def ueg_spin_pol_e_x_factor(zeta: FloatN) -> FloatN:
    """
    Spin-scaling factor of the uniform-electron-gas exchange energy:

        f(zeta) = [(1 + zeta)^(4/3) + (1 - zeta)^(4/3)] / 2

    It equals 1 for an unpolarized gas (zeta = 0) and 2^(1/3) for a fully
    polarized one (zeta = +/-1).

    Oliver, G. L.; Perdew, J. P.
    "Spin-Density Gradient Expansion for the Kinetic Energy."
    Phys. Rev. A 1979, 20 (2), 397–403.
    https://doi.org/10.1103/PhysRevA.20.397.
    """
    plus_term = (1 + zeta) ** (4 / 3)
    minus_term = (1 - zeta) ** (4 / 3)
    return 0.5 * (plus_term + minus_term)


def ueg_tau(n: FloatN) -> FloatN:
    """
    Kinetic energy density of the spin-unpolarized uniform electron gas:

        tau_unif(n) = (3/10) * (3 pi^2)^(2/3) * n^(5/3)
    """
    c_f = (3 / 10) * (3 * jnp.pi**2) ** (2 / 3)
    return c_f * n ** (5 / 3)


def wigner_seitz_radius(n: FloatN, epsilon: float = 0) -> FloatN:
    """
    Wigner-Seitz radius r_s = (3 / (4 pi n + epsilon))^(1/3).

    ``epsilon`` is added to the denominator and can be used to keep the result
    finite where n = 0 (with the default of 0 no regularization is applied).
    """
    denominator = 4 * jnp.pi * n + epsilon
    return (3 / denominator) ** (1 / 3)


def fermi_wave_vector(n: FloatN) -> FloatN:
    """
    Fermi wave vector k_F = (3 pi^2 n)^(1/3).
    """
    k_f_cubed = 3 * jnp.pi**2 * n
    return k_f_cubed ** (1 / 3)


def transform_abs_grad_n_to_s(n: FloatN, abs_grad_n: FloatN) -> FloatN:
    """
    Reduced density gradient

        s = |grad n| / (2 k_F n) = |grad n| / (2 (3 pi^2)^(1/3) n^(4/3)).

    Note: an earlier implementation clipped s at 0. For n > 0 and
    abs_grad_n >= 0 the result is already non-negative, so no clipping is
    applied here; callers are expected to pass a positive density.
    """
    denominator = 2 * (3 * jnp.pi**2) ** (1 / 3) * n ** (4 / 3)
    return abs_grad_n / denominator


def transform_s_to_abs_grad_n(n: FloatN, s: FloatN) -> FloatN:
    """
    Magnitude of the density gradient recovered from the reduced gradient:

        |grad n| = 2 (3 pi^2)^(1/3) n^(4/3) s
    """
    prefactor = 2 * (3 * jnp.pi**2) ** (1 / 3)
    return prefactor * n ** (4 / 3) * s


def ueg_spin_pol_e_kin_factor(zeta: FloatN) -> FloatN:
    """
    Spin-scaling factor d_s(zeta) of the uniform-electron-gas kinetic energy
    density:

        d_s(zeta) = [(1 + zeta)^(5/3) + (1 - zeta)^(5/3)] / 2

    so that tau_unif(n, zeta) = tau_unif(n) * d_s(zeta). It depends only on
    zeta: 1 for an unpolarized gas and 2^(2/3) for a fully polarized one.
    """
    plus_term = (1 + zeta) ** (5 / 3)
    minus_term = (1 - zeta) ** (5 / 3)
    return (1 / 2) * (plus_term + minus_term)


def weizsacker_kinetic_energy_density(n: FloatN, s: FloatN) -> FloatN:
    """
    Weizsäcker kinetic energy density tau_W = |grad n|^2 / (8 n), written in
    terms of the reduced gradient s:

        tau_W = (1/2) (3 pi^2)^(2/3) n^(5/3) s^2
    """
    prefactor = (1 / 2) * (3 * jnp.pi**2) ** (2 / 3)
    return prefactor * n ** (5 / 3) * s**2


def transform_tau_to_alpha(n: FloatN, zeta: FloatN, s: FloatN, tau: FloatN) -> FloatN:
    """
    Kinetic energy density parameter

        alpha = (tau - tau_W) / (tau_unif(n) * d_s(zeta))

    where tau_W is the Weizsäcker kinetic energy density (from n and s) and
    tau_unif(n) * d_s(zeta) is the spin-polarized uniform-gas value.
    """
    tau_unif_polarized = ueg_tau(n) * ueg_spin_pol_e_kin_factor(zeta)
    excess_tau = tau - weizsacker_kinetic_energy_density(n, s)
    return excess_tau / tau_unif_polarized


def combine_from_spin_resolved(
    n_up, n_dn, grad_n_up, grad_n_dn, tau_up, tau_dn, eps=1e-30
):
    """
    Combine spin-resolved density ingredients into spin-summed totals.

    The total density, density gradient and kinetic energy density are plain
    sums over the two spin channels. The reduced gradient s is evaluated from
    the magnitude of the summed gradient (``linalg.safe_norm``); the density
    used in its denominator is clipped from below at ``eps`` to avoid dividing
    by zero, while the returned ``n`` itself is not clipped.

    Returns:
        (n, s, tau)
    """
    n_total = n_up + n_dn
    tau_total = tau_up + tau_dn
    grad_total = grad_n_up + grad_n_dn
    s = transform_abs_grad_n_to_s(
        jnp.clip(n_total, eps), linalg.safe_norm(grad_total)
    )
    return n_total, s, tau_total


def _mask_density(threshold: float, n: FloatN) -> Tuple[BoolN, FloatN]:
    """
    Flag densities strictly above ``threshold`` and floor the others.

    Returns ``(mask, n_safe)``: ``mask`` is True where ``n > threshold``;
    ``n_safe`` equals ``n`` there and ``threshold`` elsewhere, so it can be
    used as a divisor without dividing by zero.
    """
    above = n > threshold
    return above, jnp.where(above, n, threshold)


# Per-spin (up, down) mask pair shared by both spin-resolved output types.
_SpinMaskPair = Tuple[Annotated[BoolN, 'mask_up'], Annotated[BoolN, 'mask_down']]

# Output of DensityFeatures (not spin-resolved) without AO gradients:
#   mask, (n, zeta)
DensityFeatsNoGrad = Tuple[
    Annotated[BoolN, 'mask'],
    Tuple[Annotated[FloatN, 'n'], Annotated[FloatN, 'zeta']],
]

# Output of DensityFeatures (not spin-resolved) with AO gradients:
#   mask, (n, zeta, s, tau)
DensityFeatsWithGrad = Tuple[
    Annotated[BoolN, 'mask'],
    Tuple[
        Annotated[FloatN, 'n'],
        Annotated[FloatN, 'zeta'],
        Annotated[FloatN, 's'],
        Annotated[FloatN, 'tau'],
    ],
]

# Output of DensityFeatures (spin-resolved) without AO gradients:
#   (mask_up, mask_down), (n_up, n_down)
DensityFeatsSpinResolvedNoGrad = Tuple[
    _SpinMaskPair,
    Tuple[Annotated[FloatN, 'n_up'], Annotated[FloatN, 'n_down']],
]

# Output of DensityFeatures (spin-resolved) with AO gradients:
#   (mask_up, mask_down),
#   (n_up, n_down, n_grad_up, n_grad_down, tau_up, tau_down)
DensityFeatsSpinResolvedWithGrad = Tuple[
    _SpinMaskPair,
    Tuple[
        Annotated[FloatN, 'n_up'],
        Annotated[FloatN, 'n_down'],
        Annotated[FloatNx3, 'n_grad_up'],
        Annotated[FloatNx3, 'n_grad_down'],
        Annotated[FloatN, 'tau_up'],
        Annotated[FloatN, 'tau_down'],
    ],
]


class DensityFeatures[SR: Literal[True] | Literal[False]](nn.Module):
    spin_restricted: bool
    spin_resolved: SR = False  # type: ignore
    min_density_threshold: float = 1e-15

    @overload
    def __call__(
        self: 'DensityFeatures[Literal[True]]',
        density_matrix: FloatBxB | Float2xBxB,
        aos: FloatNxB,
        grad_aos: None,
    ) -> DensityFeatsSpinResolvedNoGrad: ...
    @overload
    def __call__(
        self: 'DensityFeatures[Literal[True]]',
        density_matrix: FloatBxB | Float2xBxB,
        aos: FloatNxB,
        grad_aos: FloatNxBx3,
    ) -> DensityFeatsSpinResolvedWithGrad: ...
    @overload
    def __call__(
        self: 'DensityFeatures[Literal[False]]',
        density_matrix: FloatBxB | Float2xBxB,
        aos: FloatNxB,
        grad_aos: None,
    ) -> DensityFeatsNoGrad: ...
    @overload
    def __call__(
        self: 'DensityFeatures[Literal[False]]',
        density_matrix: FloatBxB | Float2xBxB,
        aos: FloatNxB,
        grad_aos: FloatNxBx3,
    ) -> DensityFeatsWithGrad: ...
    def __call__(
        self,
        density_matrix: FloatBxB | Float2xBxB,
        aos: FloatNxB,
        grad_aos: FloatNxBx3 | None,
    ):
        """
        Computes the electron density features in atomic units (length in bohr, density in bohr^-3):

            n >= 0 (density)
            zeta in [-1,1] (spin-polarization)
            s >= 0 (reduced density gradient)
            tau >= 0 (kinetic energy density)

        If grad_aos is None:
            - For spin_restricted and not spin_resolved:
                returns mask, (n, zeta)
            - For spin_resolved:
                returns (mask_up, mask_down), (n_up, n_down)
        Else:
            - For spin_restricted and not spin_resolved:
                returns mask, (n, zeta, s, tau)
            - For spin_resolved:
                returns (mask_up, mask_down), (n_up, n_down, n_grad_up, n_grad_down, tau_up, tau_down)

        where mask or (mask_up, mask_down) are boolean arrays to mask out densities below the threshold
        to avoid divisions by zero.
        """
        if self.spin_restricted:
            assert density_matrix.ndim == 2
            if self.spin_resolved:
                density_matrix = jnp.repeat(density_matrix[None] / 2, 2, axis=0)
                return self._spin_unrestricted_feats_spin_resolved(
                    density_matrix, aos, grad_aos
                )
            else:
                return self._spin_restricted_feats(density_matrix, aos, grad_aos)
        else:
            assert density_matrix.ndim == 3
            if self.spin_resolved:
                return self._spin_unrestricted_feats_spin_resolved(
                    density_matrix, aos, grad_aos
                )
            else:
                return self._spin_unrestricted_feats(density_matrix, aos, grad_aos)

    def _spin_restricted_feats(
        self, density_matrix: FloatBxB, aos: FloatNxB, grad_aos: FloatNxBx3 | None
    ):
        n = jnp.einsum('uv,iu,iv->i', density_matrix, aos, aos)
        mask, n = _mask_density(self.min_density_threshold, n)
        zeta = jnp.zeros_like(n)
        if grad_aos is None:
            return mask, (n, zeta)
        # NOTE: figure out precision issue with: _n_grad = 2 * jnp.einsum('uv,iuj,iv -> ij', density_matrix, grad_aos, aos)
        _n_grad = jnp.einsum(
            'uv,iuj,iv -> ij', density_matrix, grad_aos, aos
        ) + jnp.einsum('uv,iu,ivj -> ij', density_matrix, aos, grad_aos)
        abs_n_grad = linalg.safe_norm(_n_grad)
        s = transform_abs_grad_n_to_s(n, abs_n_grad)
        tau = 0.5 * jnp.einsum('uv,iuj,ivj->i', density_matrix, grad_aos, grad_aos)
        return mask, (n, zeta, s, tau)

    def _spin_unrestricted_feats(
        self, density_matrix: Float2xBxB, aos: FloatNxB, grad_aos: FloatNxBx3 | None
    ):
        n_up = jnp.einsum('uv,iu,iv->i', density_matrix[0], aos, aos)
        n_down = jnp.einsum('uv,iu,iv->i', density_matrix[1], aos, aos)
        n = n_up + n_down
        mask, n = _mask_density(self.min_density_threshold, n)
        zeta = (n_up - n_down) / n
        if grad_aos is None:
            return mask, (n, zeta)
        # _n_grad = 2 * jnp.einsum('suv,iuj,iv -> ij', density_matrix, grad_aos, aos)
        _n_grad = jnp.einsum(
            'suv,iuj,iv -> ij', density_matrix, grad_aos, aos
        ) + jnp.einsum('suv,iu,ivj -> ij', density_matrix, aos, grad_aos)
        abs_n_grad = linalg.safe_norm(_n_grad)
        s = transform_abs_grad_n_to_s(n, abs_n_grad)
        tau = 0.5 * jnp.einsum('suv,iuj,ivj->i', density_matrix, grad_aos, grad_aos)
        return mask, (n, zeta, s, tau)

    def _spin_unrestricted_feats_spin_resolved(
        self, density_matrix: Float2xBxB, aos: FloatNxB, grad_aos: FloatNxBx3 | None
    ):
        n_up = jnp.einsum('uv,iu,iv->i', density_matrix[0], aos, aos)
        n_down = jnp.einsum('uv,iu,iv->i', density_matrix[1], aos, aos)
        mask_up, n_up = _mask_density(self.min_density_threshold, n_up)
        mask_down, n_down = _mask_density(self.min_density_threshold, n_down)
        if grad_aos is None:
            return (mask_up, mask_down), (n_up, n_down)
        # grad_n_up = 2 * jnp.einsum('uv,iuj,iv -> ij', density_matrix[0], grad_aos, aos)
        # grad_n_down = 2 * jnp.einsum('uv,iuj,iv -> ij', density_matrix[1], grad_aos, aos)
        grad_n_up = jnp.einsum(
            'uv,iuj,iv -> ij', density_matrix[0], grad_aos, aos
        ) + jnp.einsum('uv,iu,ivj -> ij', density_matrix[0], aos, grad_aos)
        grad_n_down = jnp.einsum(
            'uv,iuj,iv -> ij', density_matrix[1], grad_aos, aos
        ) + jnp.einsum('uv,iu,ivj -> ij', density_matrix[1], aos, grad_aos)
        tau_up = 0.5 * jnp.einsum('uv,iuj,ivj->i', density_matrix[0], grad_aos, grad_aos)
        tau_down = 0.5 * jnp.einsum(
            'uv,iuj,ivj->i', density_matrix[1], grad_aos, grad_aos
        )
        return (mask_up, mask_down), (
            n_up,
            n_down,
            grad_n_up,
            grad_n_down,
            tau_up,
            tau_down,
        )
