from functools import partial
from typing import Tuple

import jax
import jax.numpy as jnp

from egxc.utils.typing import PRECISION, FloatBxB, FloatN, FloatNx3, FloatNxB, FloatNxBx3

# VV10 (b, C) parameters from Vydrov & Van Voorhis (https://doi.org/10.1063/1.3521275).
# Used as the default in pyscf, also for wB97M-V.
VV10_PARAMS = (
    5.9,
    0.0093,
)
# VV10 (b, C) parameters of wB97M-V (https://doi.org/10.1063/1.4952647).
VV10_wB97M_V_PARAMS = (6.0, 0.01)


@jax.jit
def vv10_kernel(
    density: FloatN,
    abs_grad_density: FloatN,
    grid_coordinates: FloatNx3,
    grid_weights: FloatN,
    vv10_parameters: tuple[float, float] = VV10_PARAMS,
    min_density_threshold: float = 1e-8,  # default threshold for VV10 kernel used in pyscf
) -> Tuple[FloatN, jax.Array]:
    """
    Differentiable VV10 non-local correlation kernel on a single grid.

    (1) Vydrov et al. Nonlocal van Der Waals Density Functional: The Simpler the Better.
        Journal of Chemical Physics 2010, 133 (24), 244103. https://doi.org/10.1063/1.3521275.
    (2) Vydrov et al. Impl. and Asses. of a Simple Nonlocal van Der Waals Density Functional.
        Journal of Chemical Physics 2010, 132 (16), 164113. https://doi.org/10.1063/1.3398840.

    The same grid is used for the outer (r) and inner (r') integration, and
    all pairwise interactions are formed explicitly. Memory therefore scales
    as O(n_grid**2).

    Parameters
    ----------
    density : FloatN
        Electron density on the grid, shape ``(n_grid,)``.
    abs_grad_density : FloatN
        Absolute gradient of the electron density, shape ``(n_grid,)``.
    grid_coordinates : FloatNx3
        Cartesian coordinates of the grid points, shape ``(n_grid, 3)``.
    grid_weights : FloatN
        Integration weights associated with ``grid_coordinates``,
        shape ``(n_grid,)``.
    vv10_parameters : tuple of float
        ``(b_vv, c_vv)`` parameters defining the VV10 kernel.
    min_density_threshold : float
        Grid points with ``density`` below this value are dropped from the
        inner sum. They also receive zero ``exc`` and zero ``vxc``.

    Returns
    -------
    exc : FloatN
        Non-local correlation energy per grid point divided by the
        density, shape ``(n_grid,)``. The total energy is
        ``sum(exc * density * grid_weights)``.
    vxc : jax.Array
        Potential derivatives w.r.t. the density (index 0) and the squared
        gradient magnitude ``sigma = |grad n|**2`` (index 1),
        shape ``(2, n_grid)``. Both are finite at points with zero gradient.
    """
    b_vv, c_vv = vv10_parameters

    mask = density >= min_density_threshold  # (n_grid,)

    # Sub-threshold densities are replaced by 1.0. This keeps the divisions and
    # fractional powers below (and their gradients) finite; those points are
    # zeroed out again at the end.
    electron_density = jnp.where(mask, density, 1.0)  # (n_grid,)
    gradient_squared = abs_grad_density**2  # sigma, (n_grid,)
    # Quadrature-weighted inner-grid density; masked points contribute nothing.
    weighted_density = jnp.where(mask, electron_density * grid_weights, 0.0)  # (n_grid,)

    pi = jnp.pi
    four_thirds_pi = 4.0 * pi / 3.0
    k_coefficient = b_vv * 1.5 * pi * (9.0 * pi) ** (-1.0 / 6.0)
    beta = (3.0 / (b_vv * b_vv)) ** 0.75 / 32.0

    # omega_0 = sqrt(C * (sigma / n**2)**2 + 4*pi*n/3) and kappa = K * n**(1/6).
    # The derivative factors below carry an extra factor of the density n:
    # dw0_drho = n * dw0/dn, dw0_dgrad = n * dw0/dsigma, dk_drho = n * dk/dn.
    w0_argument = c_vv * (gradient_squared / electron_density**2) ** 2
    w0 = jnp.sqrt(w0_argument + four_thirds_pi * electron_density)
    dw0_drho = (0.5 * four_thirds_pi * electron_density - 2.0 * w0_argument) / w0
    # This equals ``w0_argument * n / (sigma * w0)``, but it avoids the 0/0 (NaN)
    # at sigma == 0, where the analytic value is 0.
    dw0_dgrad = c_vv * gradient_squared / (electron_density**3 * w0)

    k = k_coefficient * electron_density ** (1.0 / 6.0)
    dk_drho = k / 6.0

    # Pairwise squared distances between grid points, shape ``(n_grid, n_grid)``
    displacement = grid_coordinates[:, None, :] - grid_coordinates[None, :, :]
    r_sq = jnp.sum(displacement**2, axis=-1)

    # Kernel denominators, all ``(n_grid, n_grid)``
    g_outer = r_sq * w0[:, None] + k[:, None]
    g_inner = r_sq * w0[None, :] + k[None, :]
    g_total = g_outer + g_inner

    # Integration factors ``(n_grid, n_grid)``; rows index the outer point.
    t_val = weighted_density[None, :] / (g_outer * g_inner * g_total)
    # Shared by the U and W sums; computed once.
    t_deriv = t_val * (1.0 / g_outer + 1.0 / g_total)
    f_term = -1.5 * jnp.sum(t_val, axis=1)
    u_term = jnp.sum(t_deriv, axis=1)
    w_term = jnp.sum(t_deriv * r_sq, axis=1)

    exc = beta + 0.5 * f_term
    v_rho = beta + f_term + 1.5 * (u_term * dk_drho + w_term * dw0_drho)
    v_sigma = 1.5 * w_term * dw0_dgrad

    exc = jnp.where(mask, exc, 0.0)
    v_rho = jnp.where(mask, v_rho, 0.0)
    v_sigma = jnp.where(mask, v_sigma, 0.0)

    return exc, jnp.stack([v_rho, v_sigma])  # type: ignore


def __density_features(
    density_matrix: FloatBxB, aos: FloatNxB, grad_aos: FloatNxBx3
) -> Tuple[FloatN, FloatN]:
    """
    Evaluate the density and the magnitude of its gradient on the grid.

    Parameters
    ----------
    density_matrix : FloatBxB
        Density matrix ``D``, shape ``(n_basis, n_basis)``.
    aos : FloatNxB
        AO values on the grid, shape ``(n_grid, n_basis)``.
    grad_aos : FloatNxBx3
        Cartesian AO gradients on the grid, shape ``(n_grid, n_basis, 3)``.

    Returns
    -------
    n : FloatN
        ``n_i = sum_uv D_uv phi_u(r_i) phi_v(r_i)``, shape ``(n_grid,)``.
    abs_grad_n : FloatN
        ``|grad n|`` at each grid point, shape ``(n_grid,)``. Its derivative
        is finite (zero) at points where the gradient vanishes.
    """
    n = jnp.einsum('uv,iu,iv->i', density_matrix, aos, aos)
    # Product rule: the spatial derivative acts on either AO factor.
    _n_grad = jnp.einsum('uv,iuj,iv -> ij', density_matrix, grad_aos, aos) + jnp.einsum(
        'uv,iu,ivj -> ij', density_matrix, aos, grad_aos
    )
    grad_n_sq = jnp.sum(_n_grad * _n_grad, axis=-1)
    # sqrt has an infinite derivative at 0, so ``jnp.linalg.norm`` produces NaN
    # gradients for a zero vector. That would poison ``jax.grad`` of the energy
    # with NaN whenever one grid point has a vanishing density gradient. The
    # double ``where`` keeps the forward value identical and the gradient finite.
    nonzero = grad_n_sq > 0.0
    abs_grad_n = jnp.where(nonzero, jnp.sqrt(jnp.where(nonzero, grad_n_sq, 1.0)), 0.0)
    return n, abs_grad_n


def vv10_energy(
    density_matrix: FloatBxB,
    grid_coords: FloatNx3,
    grid_weights: FloatN,
    grid_aos: FloatNxB,
    grid_grad_aos: FloatNxBx3,
    vv10_parameters: Tuple[float, float] = VV10_PARAMS,
) -> jax.Array:
    """
    Total VV10 non-local correlation energy of a density matrix.

    The density and its gradient magnitude are built from the AOs on the grid.
    They are passed to ``vv10_kernel`` and the resulting energy density is
    integrated with the grid weights.

    Parameters
    ----------
    density_matrix : FloatBxB
        Density matrix, shape ``(n_basis, n_basis)``.
    grid_coords : FloatNx3
        Grid point coordinates, shape ``(n_grid, 3)``.
    grid_weights : FloatN
        Quadrature weights, shape ``(n_grid,)``.
    grid_aos : FloatNxB
        AO values on the grid, shape ``(n_grid, n_basis)``.
    grid_grad_aos : FloatNxBx3
        AO gradients on the grid, shape ``(n_grid, n_basis, 3)``.
    vv10_parameters : tuple of float
        ``(b_vv, c_vv)`` VV10 parameters.

    Returns
    -------
    energy : jax.Array
        Scalar ``sum_i exc_i * n_i * w_i``.
    """
    n, abs_n_grad = __density_features(density_matrix, grid_aos, grid_grad_aos)
    # Only the energy density is needed; the potential output is discarded.
    energy_density, _ = vv10_kernel(n, abs_n_grad, grid_coords, grid_weights, vv10_parameters)
    energy = (energy_density * n * grid_weights).sum()
    # Check the result carries the dtype configured in PRECISION.xc_energy.
    assert energy.dtype == PRECISION.xc_energy
    return energy


@partial(jax.jit, static_argnames=['vv10_parameters'])
def vv10_potential(
    density_matrix: FloatBxB,
    grid_coords: FloatNx3,
    grid_weights: FloatN,
    grid_aos: FloatNxB,
    grid_grad_aos: FloatNxBx3,
    vv10_parameters: Tuple[float, float] = VV10_PARAMS,
) -> FloatBxB:
    """
    VV10 potential matrix obtained by automatic differentiation.

    Differentiates ``vv10_energy`` with respect to every entry of
    ``density_matrix``. The grid data and ``vv10_parameters`` are held fixed;
    the parameters are a static (hashable) jit argument.

    Returns
    -------
    FloatBxB
        ``dE_VV10 / dD_uv``, same shape as ``density_matrix``.
    """

    def _energy_of(dm: FloatBxB) -> jax.Array:
        # All inputs except the density matrix are closed over as constants.
        return vv10_energy(
            dm,
            grid_coords,
            grid_weights,
            grid_aos,
            grid_grad_aos,
            vv10_parameters,
        )

    return jax.grad(_energy_of)(density_matrix)


# def __density_features_with_density_matrix_jacobian(
#     density_matrix: FloatBxB,
#     aos: FloatNxB,
#     grad_aos: FloatNxBx3,
#     eps: float = 1e-12,
# ) -> tuple[FloatN, FloatN, FloatBxB, FloatNxBxB]:
#     """
#     Analytic Jacobians of n and |grad n| w.r.t. the density matrix.

#     Returns
#     -------
#     dn_dD        : jnp.ndarray, shape (N, B, B)
#                    dn_dD[i, u, v] = phi_u(r_i) * phi_v(r_i)

#     dabs_grad_n_dD : jnp.ndarray, shape (N, B, B)
#                    d|∇n|_i / dD_{uv}
#     """
#     # --- compute raw n and grad_n ---
#     # n[i]
#     n = jnp.einsum('uv,iu,iv->i', density_matrix, aos, aos)  # (N,)
#     # ∂_j n[i]
#     grad_n = jnp.einsum('uv,iuj,iv->ij', density_matrix, grad_aos, aos) + jnp.einsum(
#         'uv,iu,ivj->ij', density_matrix, aos, grad_aos
#     )  # (N,3)

#     # --- |∇n| and its safe reciprocal ---
#     abs_grad_n = jnp.linalg.norm(grad_n, axis=-1)  # (N,)
#     inv_norm = 1.0 / jnp.where(abs_grad_n > eps, abs_grad_n, 1.0)  # (N,)  # type: ignore

#     # --- dn/dD is independent of D ---
#     # dn_dD[i,u,v] = aos[i,u] * aos[i,v]
#     dn_dD = jnp.einsum('iu,iv->iuv', aos, aos)  # (N,B,B)

#     # --- d(∂_j n)[i,j] / dD_{uv} ---
#     d_gradn_dD = jnp.einsum('iuj,iv   -> ijuv', grad_aos, aos) + jnp.einsum(
#         'iu,ivj   -> ijuv', aos, grad_aos
#     )  # (N,3,B,B)

#     # --- chain‐rule for |∇n| ---
#     # weights[i,j] = ∂_j n[i] / |∇n[i]|
#     weights = grad_n * inv_norm[:, None]  # (N,3)
#     weights = jnp.where(abs_grad_n[:, None] > eps, weights, 0.0)  # mask zero-gradient

#     # d|∇n|/dD = sum_j weights[i,j] * d_gradn_dD[i,j,u,v]
#     dabs_grad_n_dD = jnp.einsum('ij,ijuv->iuv', weights, d_gradn_dD)  # (N,B,B)

#     return n, abs_grad_n, dn_dD, dabs_grad_n_dD


# @partial(jax.jit, static_argnames=['vv10_parameters'])
# def vv10_potential(
#     density_matrix: FloatBxB,
#     grid_coords: FloatNx3,
#     grid_weights: FloatN,
#     grid_aos: FloatNxB,
#     grid_grad_aos: FloatNxBx3,
#     vv10_parameters: Tuple[float, float] = VV10_PARAMS,
# ) -> FloatBxB:

#     # TODO: check if this could improve the accuracy when comparing to pyscf
#     primals = (density_matrix, grid_aos, grid_grad_aos)
#     tangents = (
#         jnp.ones_like(density_matrix),
#         jnp.zeros_like(grid_aos),
#         jnp.zeros_like(grid_grad_aos),
#     )
#     (n, abs_grad_n), (dn_dp, d_abs_grad_n_dp) = jax.jvp(
#         __density_features, primals, tangents
#     )

#     n, abs_grad_n, dn_dD, dabs_grad_n_dD = (
#         __density_features_with_density_matrix_jacobian(
#             density_matrix, grid_aos, grid_grad_aos
#         )
#     )

#     _, (dE_dn, dE_d_abs_grad_n) = vv10_kernel(
#         n,
#         abs_grad_n,
#         grid_coords,
#         grid_weights,
#         vv10_parameters,
#     )

#     print(n.shape)
#     print(abs_grad_n.shape)
#     print(dE_dn.shape)
#     print(dE_d_abs_grad_n.shape)
#     print(dn_dD.shape)
#     print(dabs_grad_n_dD.shape)

#     potential = jnp.einsum('i,iuv,i->uv', dE_dn, dn_dD, grid_weights)
#     potential += jnp.einsum(
#         'i,iuv,i -> uv', dE_d_abs_grad_n, dabs_grad_n_dD, grid_weights
#     )

#     return potential
