"""TDDFT/TDA builders using the shared response function.

This module implements time-dependent density functional theory (TDDFT) and
Tamm-Dancoff approximation (TDA) matrix-vector products using JAX automatic
differentiation to compute Hessian-vector products of the total energy functional.

The implementation follows the Casida formulation:
    (A - B)(A + B) (X + Y) = ω² (X + Y)
    (A + B)(A - B) (X - Y) = ω² (X - Y)

where A and B are the TDDFT matrices and ω are excitation energies.

References:
    Casida, M. E. (1995). In Recent Advances in Density Functional Methods,
    Part I (pp. 155-192). World Scientific.
"""

from typing import Any, Callable, Tuple

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

from egxc.solver.fock import get_coulomb_matrix_fn
from egxc.systems import System
from egxc.utils.typing import Float1, FloatBxB, FloatTxBxB
from egxc.xc_energy.functionals.classical import BaseRangeSeparatedHybrid, Hybrid
from egxc.xc_energy.functionals.learnable.egxc import EGXC
from egxc.xc_energy.xc_module import XCModule

# TDDFT-specific jaxtyping aliases. Axis labels used below:
#   O  = occupied orbitals, V = virtual orbitals,
#   M  = number of trial vectors handled at once,
#   OV = flattened occupied-virtual (ia) index.
# B is the leading axis of the orbital-coefficient matrices (presumably the
# basis-function index — confirm against egxc.utils.typing).
FloatBxO = Float[Array, 'B O']  # Occupied orbital coefficients (one column per occupied orbital)
FloatBxV = Float[Array, 'B V']  # Virtual orbital coefficients (one column per virtual orbital)
# NOTE(review): build_cassida_mv contracts its e_ia argument with (M, O, V)
# tensors via einsum 'xia,ia->xia', i.e. the data is consumed in (O, V) order.
# The axis labels of this alias look swapped relative to that usage — confirm
# before enabling runtime shape checking.
FloatVxO = Float[Array, 'V O']  # Orbital energy differences
FloatOVxM = Float[Array, 'OV M']  # Flattened ia-space trial vectors, one column per trial vector


def build_total_energy_and_vresp(
    sys: System,
    xc_module: XCModule,
    params: Any,
    P_ref: FloatBxB,
    spin_restricted: bool,
    use_density_fitting: bool,
) -> Tuple[Callable[[FloatBxB], Float1], Callable[[FloatTxBxB], FloatTxBxB]]:
    """Build the energy functional E[P] and its batched Hessian-vector product.

    The functional assembled here is ``E[P] = E_xc[P] + E_H[P]``, where
    ``E_xc`` is whatever ``xc_module.apply`` returns and ``E_H`` is the
    Hartree energy ``0.5 * sum(J[P] * P)`` built from the Coulomb matrix.

    The response is obtained by pushing a tangent ``dP`` through the gradient
    of E with ``jax.jvp`` (a JVP of ``jax.grad(E)``):
        v_resp(dP) = ∂²E/∂P² |_{P=P_ref} · dP

    Args:
        sys: System object; its ``grid`` and ``fock_tensors.ert`` are read, and
            for EGXC functionals also ``_nuc_pos`` and ``atom_mask``.
        xc_module: Exchange-correlation module; ``xc_module.functional`` decides
            which extra keyword inputs are forwarded to ``xc_module.apply``.
        params: Parameters handed unchanged to ``xc_module.apply``.
        P_ref: Density matrix at which the Hessian is evaluated (B×B).
        spin_restricted: Forwarded to ``get_coulomb_matrix_fn``.
        use_density_fitting: Forwarded to ``get_coulomb_matrix_fn``.

    Returns:
        Tuple ``(energy_total, vresp)``:
            - energy_total: jitted function P → scalar energy.
            - vresp: function mapping a stack of perturbations with shape
              (T, B, B) to the stack of Hessian-vector products, same shape.
    """
    build_coulomb = get_coulomb_matrix_fn(spin_restricted, use_density_fitting)

    @jax.jit
    def energy_total(P: FloatBxB) -> Float1:
        # Collect the extra keyword inputs required by non-local functionals.
        extra_inputs = {}
        if isinstance(xc_module.functional, EGXC):
            extra_inputs.update(
                {
                    'nuc_pos': sys._nuc_pos,
                    'atom_mask': sys.atom_mask,
                    'grid_coords': sys.grid.coords,
                }
            )
        if isinstance(xc_module.functional, (Hybrid, BaseRangeSeparatedHybrid)):
            extra_inputs['eri_tensor'] = sys.fock_tensors.ert

        e_xc = xc_module.apply(params, P, sys.grid, **extra_inputs)

        # Hartree energy: E_H = 0.5 * sum_pq J_pq P_pq (elementwise contraction).
        j_matrix = build_coulomb(P, sys.fock_tensors.ert)
        e_hartree = 0.5 * jnp.sum(j_matrix * P)
        return e_xc + e_hartree

    ref_density = jnp.asarray(P_ref)
    # jax.grad only wraps the function; differentiation happens when it is called.
    energy_gradient = jax.grad(energy_total)

    def vresp(dms: FloatTxBxB) -> FloatTxBxB:
        """Apply the Hessian of E at the reference density to each matrix in dms."""

        def _hessian_times(tangent: FloatBxB) -> FloatBxB:
            # The tangent output of the JVP of grad(E) is the Hessian-vector product.
            _, response = jax.jvp(energy_gradient, (ref_density,), (tangent,))
            return response

        # Map over the leading (trial-vector) axis.
        return jax.vmap(_hessian_times)(dms)

    return energy_total, vresp


def build_cassida_mv(
    sys: System,
    xc_module: XCModule,
    params: Any,
    occupied_orbs: FloatBxO,
    virtual_orbs: FloatBxV,
    e_ia: FloatVxO,
    P_ref: FloatBxB,
    spin_restricted: bool,
    use_density_fitting: bool,
    tda_approx: bool = False,
) -> Callable:
    """Construct TDDFT/TDA matrix-vector product functions using HVP response.

    Builds the matrix-vector product needed by an iterative eigensolver for
    either the Tamm-Dancoff approximation (TDA) or the full TDDFT (Casida)
    problem in the occupied-virtual (ia) orbital space.

    For TDA:
        A·X = ω·X
    where A_ia,jb = δ_ij δ_ab ε_ia + K_ia,jb

    For full TDDFT (Casida formulation):
        [A  B] [X]     [1   0] [X]
        [B  A] [Y] = ω [0  -1] [Y]

    The returned TDDFT function evaluates only the left-hand side; applying
    the diag(1, -1) metric is left to the caller.

    The coupling K is never formed explicitly: trial vectors are turned into
    AO-basis density perturbations, pushed through the Hessian-vector product
    of the energy functional, and projected back onto the ia space.

    Args:
        sys: Physical system, forwarded to ``build_total_energy_and_vresp``.
        xc_module: Exchange-correlation functional module (forwarded).
        params: Parameters for the XC functional (forwarded).
        occupied_orbs: Occupied molecular orbital coefficients (B×O).
        virtual_orbs: Virtual molecular orbital coefficients (B×V).
        e_ia: Orbital energy differences. NOTE: this array is contracted with
            (M, O, V) tensors via ``'xia,ia->xia'``, so it must be laid out as
            (O, V) — occupied index first — regardless of the ``FloatVxO``
            alias name.
        P_ref: Reference density matrix for HVP evaluation (B×B).
        spin_restricted: Forwarded to the response builder only; see Notes.
        use_density_fitting: Forwarded to the response builder.
        tda_approx: If True, return the TDA matrix-vector product; if False,
            return the full TDDFT matrix-vector product.

    Returns:
        For TDA (tda_approx=True):
            Function (X: OV×M) → (AX: OV×M) computing A·X for M trial vectors.

        For TDDFT (tda_approx=False):
            Function (X: OV×M, Y: OV×M) → (U1: OV×M, U2: OV×M) where
            U1 = A·X + B·Y and U2 = B·X + A·Y.

    Notes:
        - The OV axis is the row-major flattening of an (O, V) block, i.e.
          row index = i * V + a.
        - The factor 2.0 on the trial densities is applied unconditionally;
          it is not switched by ``spin_restricted``.
          NOTE(review): confirm the intended behaviour for
          ``spin_restricted=False``.
    """
    _, vresp = build_total_energy_and_vresp(
        sys, xc_module, params, P_ref, spin_restricted, use_density_fitting
    )

    O = occupied_orbs.shape[1]
    V = virtual_orbs.shape[1]

    # Orbital factors that do not depend on the trial vectors; computing them
    # once avoids rebuilding them on every matvec of the eigensolver.
    occ_conj_x2 = occupied_orbs.conj() * 2.0
    virt_conj = virtual_orbs.conj()

    def _to_blocks(vecs: FloatOVxM) -> Array:
        """Reshape column-stacked trial vectors (OV, M) into (M, O, V) blocks."""
        return vecs.T.reshape(vecs.shape[1], O, V)

    def _to_columns(blocks: Array) -> FloatOVxM:
        """Flatten (M, O, V) blocks back into column-stacked (OV, M) vectors."""
        return blocks.reshape(blocks.shape[0], -1).T

    def _orbital_energy_term(amps: Array) -> Array:
        """Diagonal contribution ε_ia · amps_ia for every trial vector."""
        return jnp.einsum('xia,ia->xia', amps, e_ia)

    def _tda_mv(X: FloatOVxM) -> FloatOVxM:
        xs = _to_blocks(X)
        # AO-basis density perturbation generated by the excitation amplitudes.
        dms = jnp.einsum('xov,pv,qo->xpq', xs, virtual_orbs, occ_conj_x2)
        v1ao = vresp(dms)
        # Project the induced potential back onto the ia space.
        v1mo = jnp.einsum('xpq,qo,pv->xov', v1ao, occupied_orbs, virt_conj)
        v1mo = v1mo + _orbital_energy_term(xs)
        return _to_columns(v1mo)

    if tda_approx:
        return _tda_mv

    # Only needed by the full TDDFT product (de-excitation part of the density).
    occ_x2 = occupied_orbs * 2.0
    occ_conj = occupied_orbs.conj()

    def _tddft_mv(X: FloatOVxM, Y: FloatOVxM) -> Tuple[FloatOVxM, FloatOVxM]:
        xs = _to_blocks(X)
        ys = _to_blocks(Y)

        # A single response evaluation covers both X and Y: the excitation and
        # de-excitation densities are summed before calling the HVP.
        dms = jnp.einsum('xov,pv,qo->xpq', xs, virtual_orbs, occ_conj_x2)
        dms = dms + jnp.einsum('xov,qv,po->xpq', ys, virt_conj, occ_x2)
        v1ao = vresp(dms)

        # The two projections of the same AO response give the upper (X) and
        # lower (Y) halves of the product.
        v1_top = jnp.einsum('xpq,qo,pv->xov', v1ao, occupied_orbs, virt_conj)
        v1_bot = jnp.einsum('xpq,po,qv->xov', v1ao, occ_conj, virtual_orbs)
        v1_top = v1_top + _orbital_energy_term(xs)
        v1_bot = v1_bot + _orbital_energy_term(ys)
        return _to_columns(v1_top), _to_columns(v1_bot)

    return _tddft_mv
