from functools import partial
from typing import Tuple

import jax

from egxc.utils.typing import FloatBxB, FloatN, FloatNx3, FloatNxB, FloatNxBx3
from egxc.xc_energy.functionals.dispersion.vv10 import vv10_potential


@partial(jax.jit, static_argnames=['vv10_parameters'])
def vv10_linear_response(
    density_matrix_0: FloatBxB,
    delta_density_matrix: FloatBxB,
    grid_coords: FloatNx3,
    grid_weights: FloatN,
    grid_aos: FloatNxB,
    grid_grad_aos: FloatNxBx3,
    vv10_parameters: Tuple[float, float],
) -> FloatBxB:
    """
    Computes the linear response of the VV10 potential to a perturbation in
    the density matrix.

    The result is the Jacobian-vector product (forward-mode directional
    derivative) of ``vv10_potential`` with respect to the density matrix,
    evaluated at ``density_matrix_0`` along ``delta_density_matrix``. All
    grid quantities and the VV10 parameters are held fixed.

    Parameters
    ----------
    density_matrix_0 : FloatBxB
        The density matrix at which the linear response is computed.
    delta_density_matrix : FloatBxB
        The perturbation in the density matrix. Passed to ``jax.jvp`` as the
        tangent, so it must have the same shape and dtype as
        ``density_matrix_0``.
    grid_coords : FloatNx3
        Cartesian coordinates of the quadrature grid.
    grid_weights : FloatN
        Quadrature weights associated with the grid points.
    grid_aos : FloatNxB
        Atomic orbitals evaluated on the quadrature grid.
    grid_grad_aos : FloatNxBx3
        Gradients of atomic orbitals evaluated on the quadrature grid.
    vv10_parameters : Tuple[float, float]
        ``(b_vv, c_vv)`` parameters defining the VV10 kernel. This argument
        is static under ``jax.jit``. It must therefore be hashable (e.g. a
        tuple, not a list), and each distinct value triggers a recompilation.

    Returns
    -------
    FloatBxB
        The first-order change of the VV10 potential induced by
        ``delta_density_matrix``, shape ``(n_basis, n_basis)``.
    """
    # Bind every argument except the density matrix, so the JVP is taken
    # with respect to the density matrix alone.
    potential_function = partial(
        vv10_potential,
        grid_coords=grid_coords,
        grid_weights=grid_weights,
        grid_aos=grid_aos,
        grid_grad_aos=grid_grad_aos,
        vv10_parameters=vv10_parameters,
    )
    # jax.jvp also returns the potential evaluated at density_matrix_0 (the
    # primal output). Only the tangent output is needed here, so discard it.
    _, linear_response = jax.jvp(
        potential_function,
        primals=(density_matrix_0,),
        tangents=(delta_density_matrix,),
    )
    return linear_response  # type: ignore
