from functools import partial
from typing import Any, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp

from deixc.orbital_transforms import (
    dm_gradient_to_orbital_rotation_gradient,
    first_order_orbital_rotation,
)
from egxc.systems import Grid
from egxc.utils.linalg import coeffs_to_density_matrix
from egxc.utils.typing import (
    BoolB,
    Float1,
    Float2xBxB,
    FloatBxB,
    FloatOV,
    FloatOxV,
    FloatTx2xBxB,
    FloatTx2xOxV,
    FloatTxBxB,
    FloatTxOxV,
    UIntB,
)
from egxc.xc_energy.features import DensityFeatures
from egxc.xc_energy.functionals.base import BaseEnergyFunctional
from egxc.xc_energy.functionals.classical import BaseRangeSeparatedHybrid, Hybrid


def _mask_potential(potential: jax.Array, basis_mask: jax.Array) -> jax.Array:
    """Zero the entries of ``potential`` that are not selected by ``basis_mask``.

    The pairwise mask is the outer product of ``basis_mask`` with itself, so an
    entry is kept only when both its row and its column are selected. The mask
    is broadcast against ``potential`` by ``jnp.where``.
    """
    pair_mask = basis_mask[:, None] * basis_mask[None, :]
    return jnp.where(pair_mask, potential, 0.0)  # type: ignore


class XCModule(nn.Module):
    """XC energy of a density matrix on a grid, together with its derivatives.

    The energy is obtained by feeding the output of ``feature_fn`` into
    ``functional``; the remaining methods differentiate that energy with JAX.
    """

    functional: BaseEnergyFunctional  # only this module contains trainable parameters
    feature_fn: DensityFeatures

    def __call__(
        self,
        density_matrix: FloatBxB | Float2xBxB,
        grid: Grid,
        **non_local_kwargs: Any,
    ) -> Float1:
        """Alias for :meth:`xc_energy`."""
        return self.xc_energy(density_matrix, grid, **non_local_kwargs)

    def xc_energy(
        self,
        density_matrix: FloatBxB | Float2xBxB,
        grid: Grid,
        **non_local_kwargs: Any,
    ) -> Float1:
        """Evaluate the functional on features of ``density_matrix`` on ``grid``.

        Args:
            density_matrix: Density matrix handed to ``feature_fn``.
            grid: Provides ``aos`` and ``grad_aos`` for the features and
                ``weights`` for the functional.
            **non_local_kwargs: Forwarded to ``functional``. For hybrid
                functionals the ``'density_matrix'`` entry is set (and any
                caller-supplied value overwritten) with ``density_matrix``.

        Returns:
            Whatever ``functional`` returns for the masked grid weights and features.
        """
        mask, feats = self.feature_fn(density_matrix, grid.aos, grid.grad_aos)
        # If features are spin-resolved, the mask is a tuple of (mask_up, mask_down)
        if isinstance(mask, tuple):
            mask = mask[0] & mask[1]
        # Inject density matrix for exact exchange contraction in hybrid functionals
        if isinstance(self.functional, (Hybrid, BaseRangeSeparatedHybrid)):
            non_local_kwargs['density_matrix'] = density_matrix
        masked_weights = grid.weights * mask
        return self.functional(masked_weights, *feats, **non_local_kwargs)

    def xc_potential(
        self,
        density_matrix: FloatBxB | Float2xBxB,
        grid: Grid,
        basis_mask: BoolB,
        **non_local_kwargs: Any,
    ) -> FloatBxB | Float2xBxB:
        """Gradient of :meth:`xc_energy` w.r.t. ``density_matrix``.

        Entries not selected by ``basis_mask`` (row and column) are set to zero.
        """
        energy_grad = jax.grad(self.xc_energy, argnums=0)
        potential = energy_grad(density_matrix, grid, **non_local_kwargs)
        return _mask_potential(potential, basis_mask)

    def xc_energy_and_potential(
        self,
        density_matrix: FloatBxB | Float2xBxB,
        grid: Grid,
        basis_mask: BoolB,
        **non_local_kwargs: Any,
    ) -> Tuple[Float1, FloatBxB | Float2xBxB]:
        """Return the XC energy and its masked gradient from one differentiation pass.

        The second element is masked exactly as in :meth:`xc_potential`.
        """
        energy_and_grad = jax.value_and_grad(self.xc_energy, argnums=0)
        energy, potential = energy_and_grad(density_matrix, grid, **non_local_kwargs)
        return energy, _mask_potential(potential, basis_mask)

    def xc_potential_and_linear_response(
        self,
        density_matrix: FloatBxB | Float2xBxB,
        perturbation: FloatBxB | Float2xBxB,
        grid: Grid,
        basis_mask: BoolB,
        **non_local_kwargs: Any,
    ) -> Tuple[FloatBxB, FloatBxB] | Tuple[Float2xBxB, Float2xBxB]:
        """
        Compute the XC potential and the linear response of the XC potential to a perturbation.
        This is essentially an HVP evaluated at 'density_matrix' contracted with 'perturbation'.

        Returns:
            ``(V, delta_V)``: the primal output of :meth:`xc_potential` and its
            forward-mode tangent along ``perturbation``.
        """
        potential_fn = partial(
            self.xc_potential, grid=grid, basis_mask=basis_mask, **non_local_kwargs
        )
        potential, response = jax.jvp(potential_fn, (density_matrix,), (perturbation,))
        return potential, response

    def xc_potential_linear_responses(
        self,
        density_matrix: FloatBxB | Float2xBxB,
        perturbations: FloatTxBxB | FloatTx2xBxB,
        grid: Grid,
        basis_mask: BoolB,
        **non_local_kwargs: Any,
    ) -> FloatTxBxB | FloatTx2xBxB:
        """
        Efficient computation of linear responses of the XC potential to a set of perturbations.
        This is essentially an HVP evaluated at 'density_matrix' contracted with 'perturbations'.

        The potential is linearized once at ``density_matrix`` and the resulting
        linear map is vmapped over the leading axis of ``perturbations``.
        """
        potential_fn = partial(
            self.xc_potential, grid=grid, basis_mask=basis_mask, **non_local_kwargs
        )
        # NOTE(review): xc_rotation_hvp prefers vmap-of-jvp over jax.linearize for
        # the same "all tangents known up front" situation; confirm whether
        # linearize is kept here on purpose.
        _, response_fn = jax.linearize(potential_fn, density_matrix)
        return jax.vmap(response_fn)(perturbations)

    def xc_rotation_hvp(
        self,
        mo_coeffs: FloatBxB | Float2xBxB,
        directions: FloatTxOxV | FloatTx2xOxV,
        grid: Grid,
        basis_mask: BoolB,
        occupancies: BoolB | UIntB,
        occupied_virtual_shape: Tuple[int, int],
        **non_local_kwargs: Any,
    ) -> FloatTxOxV | FloatTx2xOxV:
        """
        Efficient computation of Hessian-vector products (HVPs) for XC energy w.r.t. orbital rotations.

        This computes H(theta=0) @ theta_ov for each direction theta_ov, where H is the Hessian of
        E_xc w.r.t. orbital rotations at theta=0.
        Uses forward-mode differentiation (jax.jvp) of the orbital-rotation gradient to avoid
        materializing the full Hessian matrix.

        Args:
            mo_coeffs: Orbital coefficients that are rotated by theta.
            directions: Rotation directions; reshaped to ``(-1, O * V)`` before use.
            grid: Grid forwarded to :meth:`xc_potential`.
            basis_mask: Basis mask forwarded to :meth:`xc_potential`.
            occupancies: Passed to ``coeffs_to_density_matrix`` together with the
                rotated coefficients.
            occupied_virtual_shape: ``(O, V)``, used to flatten ``directions`` and to
                un-flatten theta.
            **non_local_kwargs: Forwarded to :meth:`xc_potential`.

        Returns:
            The tangents of the orbital-rotation gradient, stacked along axis 0 with
            one entry per flattened direction.
        """
        n_occ, n_virt = occupied_virtual_shape
        directions = directions.reshape(-1, n_occ * n_virt)  # flatten directions
        zero_theta = jnp.zeros(n_occ * n_virt, dtype=mo_coeffs.dtype)

        def gradient_theta(theta_flat: FloatOV) -> FloatOxV:
            """Gradient of E_xc w.r.t. theta (flattened)."""
            theta = theta_flat.reshape(n_occ, n_virt)
            C_dash = first_order_orbital_rotation(mo_coeffs, theta)
            P = coeffs_to_density_matrix(C_dash, occupancies)
            V_xc = self.xc_potential(P, grid, basis_mask, **non_local_kwargs)
            # The rotated coefficients C_dash (not mo_coeffs) are passed on, so their
            # dependence on theta is part of what jax.jvp differentiates below.
            return dm_gradient_to_orbital_rotation_gradient(V_xc, C_dash, n_occ)

        # We do not use jax.linearize here:
        #   _, hvp_fn = jax.linearize(gradient_theta, zero_theta)
        #   return jax.vmap(hvp_fn)(directions)
        # since its documentation explicitly states "This function is mainly useful to apply
        # f_jvp multiple times at the same linearization point. Moreover if all the input tangent
        # vectors are known at once, it can be more efficient to vectorize using vmap()"
        pushfwd = partial(jax.jvp, gradient_theta, (zero_theta,))
        # out_axes=(None, 0): the primal output is shared by all directions, only the
        # tangents carry the mapped axis.
        _, out_tangents = jax.vmap(pushfwd, out_axes=(None, 0))((directions,))
        return out_tangents
