from functools import partial
from typing import Callable, Tuple

import einops
import jax
import jax.numpy as jnp

from egxc.utils import linalg
from egxc.utils.constants import ANGSTROM_TO_BOHR
from egxc.utils.typing import (
    Float3,
    FloatAx3,
    FloatAxN,
    FloatAxNx3,
    FloatG,
    FloatN,
    FloatNx3,
    FloatNxAxM_SPH,
    FloatNxB,
    FloatNxBx3,
    FloatNxRP,
    FloatRP,
    NpFloatAx3,
    UIntRP,
)

from .constants import CART_SPH_CONTRACTIONS, L_TO_LXLYLZ
from .containers import GTOCompileStatics, RadialPrimitives


def displacements_and_distances(n: FloatNx3, a: FloatAx3) -> Tuple[FloatAxNx3, FloatAxN]:
    """
    Pairwise displacement vectors from every point in `a` to every point in `n`,
    together with the norms of those vectors.

    Args:
        n: Coordinates of the first point set (N x 3).
        a: Coordinates of the second point set (A x 3).

    Returns:
        displacements: `n[j] - a[i]` stored at index `[i, j]` (A x N x 3).
        distances: Result of `linalg.safe_norm` applied to the displacements
            (A x N according to the return annotation).
    """
    targets = n[None, :, :]  # 1 x N x 3
    origins = a[:, None, :]  # A x 1 x 3
    deltas = targets - origins  # broadcasts to A x N x 3
    return deltas, linalg.safe_norm(deltas)


@partial(jax.vmap, in_axes=(0, 0, 0, 0))  # basis
def radial_components(
    distance: FloatN, norm: FloatG, ctr_coeffs: FloatG, basis_exponents: FloatG
) -> FloatN:
    """
    Evaluate the contracted radial part of each basis function.

    The body handles a single basis function; `jax.vmap` maps it over the leading
    (basis) axis of all four arguments.

    Args:
        distance: Distances at which the basis function is evaluated (N).
        norm: Normalisation factor of each primitive Gaussian (G).
        ctr_coeffs: Contraction coefficient of each primitive Gaussian (G).
        basis_exponents: Exponent of each primitive Gaussian (G).

    Returns:
        `sum_g ctr_coeffs[g] * norm[g] * exp(-basis_exponents[g] * distance**2)`,
        one value per entry of `distance` (N).
    """
    # The trailing axis lets every distance pair with every primitive. Without it the
    # (N,) distances and the (1, G) exponents would be broadcast along the same axis.
    sq_dist = (distance**2)[..., None]  # N x 1
    exps = norm[None] * jnp.exp(-sq_dist * basis_exponents[None])  # N x G
    return (ctr_coeffs * exps).sum(-1)


def compute_radial_primitives(
    distances: FloatAxN,
    coeffs: FloatRP,
    exponents: FloatRP,
    atom_indices: UIntRP,
) -> FloatNxRP:
    """
    Evaluate every radial Gaussian primitive at every grid point.

    Args:
        distances: Atom-to-grid-point distances (A x N).
        coeffs: Prefactor of each primitive (T).
        exponents: Gaussian exponent of each primitive (T).
        atom_indices: For each primitive, the row of `distances` it is evaluated with (T).

    Returns:
        `coeffs[t] * exp(-exponents[t] * distances[atom_indices[t], n] ** 2)`
        stored at index `[n, t]` (N x T).
    """
    squared = (distances**2).T  # N x A
    # For each primitive, pick the squared distance to the atom selected by atom_indices
    per_primitive = squared.at[:, atom_indices].get()  # N x T
    gaussians = jnp.exp(-exponents[None] * per_primitive)  # N x T
    return coeffs[None] * gaussians


@partial(jax.jit, donate_argnames=('primitives',))
def aggregate_radial_shells(primitives: FloatNxRP, shell_indices: UIntRP) -> FloatNxRP:
    """
    Sum radial primitives into the shells selected by `shell_indices`.

    Column `t` of `primitives` is added to column `shell_indices[t]` of a zero array
    with the same shape as `primitives`; columns that no index points to stay zero.

    `primitives` is donated to the jitted call, which allows its buffer to be reused
    for the result, so the caller should not read `primitives` again afterwards.
    """
    accumulator = jnp.zeros_like(primitives)
    return accumulator.at[:, shell_indices].add(primitives)


# Signature of the evaluators returned by `get_gto_grid_eval_fn`:
# (grid coordinates, nuclear positions, radial primitives, compile-time statics)
# -> atomic-orbital values, or a tuple of (values, gradients) when derivatives
# are requested.
GTOGridEvalFn = Callable[
    [FloatNx3, FloatAx3 | NpFloatAx3, RadialPrimitives, GTOCompileStatics],
    FloatNxB | Tuple[FloatNxB, FloatNxBx3],
]


@jax.jit
def radials_and_displacements(
    grid_coords: FloatNx3,
    masked_nuc_pos: FloatAx3,
    primitives: RadialPrimitives,
) -> Tuple[FloatNxRP, FloatAxNx3]:
    """
    Compute the aggregated radial shell values on the grid and the atom-to-grid
    displacement vectors.

    Args:
        grid_coords: Grid point coordinates (N x 3). They are used without rescaling.
        masked_nuc_pos: Nuclear positions (A x 3); multiplied by `ANGSTROM_TO_BOHR`
            before any distance is formed.
        primitives: Container providing `coeffs`, `exponents`, `atom_indices` and
            `shell_indices` of the radial primitives.

    Returns:
        radial_shells: Output of `aggregate_radial_shells` for the evaluated primitives.
        displacements: Grid point minus rescaled nuclear position (A x N x 3).
    """
    # pyscf constants are in bohr, so the nuclear positions are rescaled first
    nuc_pos_bohr = masked_nuc_pos * ANGSTROM_TO_BOHR
    atom_to_grid, atom_grid_dist = displacements_and_distances(grid_coords, nuc_pos_bohr)

    per_primitive = compute_radial_primitives(
        atom_grid_dist,
        primitives.coeffs,
        primitives.exponents,
        primitives.atom_indices,
    )
    per_shell = aggregate_radial_shells(per_primitive, primitives.shell_indices)
    return per_shell, atom_to_grid


@partial(jax.jit, static_argnames=('angular_momentum',))
def compute_angular_components(
    displacements: FloatAxNx3, angular_momentum: int
) -> FloatNxAxM_SPH:
    """
    Evaluate the angular factors of all basis functions that share one angular momentum.

    Args:
        displacements: `displacements[a, n, i]` is the displacement along direction `i`
            between atom `a` and grid point `n`.
        angular_momentum: Angular momentum of the basis functions; static under jit
            because it selects entries of module-level lookup tables.

    Returns:
        Angular components for every grid point and atom (N x A x M_sph).
    """
    # Per-axis exponents of every Cartesian monomial with this angular momentum
    powers = jnp.array(L_TO_LXLYLZ[angular_momentum], dtype=jnp.uint8)  # M_cartesian x 3
    # Raise each displacement component to its exponent, then multiply the three axes
    axis_factors = jnp.power(
        displacements[None], powers[:, None, None]
    )  # M_cartesian x A x N x 3
    monomials = axis_factors.prod(axis=-1)  # M_cartesian x A x N
    cart_to_sph = CART_SPH_CONTRACTIONS[angular_momentum]  # M_sph x M_cartesian
    return jnp.einsum('sc,can -> nas', cart_to_sph, monomials)  # N x A x M_sph


def get_gto_grid_eval_fn(
    deriv: int,
    max_angular_momentum: int,
) -> GTOGridEvalFn:
    """
    Build a jitted evaluator for Gaussian-type atomic orbitals on a grid.

    Args:
        deriv: 0 to get orbital values only, 1 to additionally get their first
            derivatives with respect to the grid coordinates.
        max_angular_momentum: Highest angular momentum for which orbitals are assembled.

    Returns:
        The evaluator matching `deriv`.

    Raises:
        ValueError: If `deriv` is neither 0 nor 1.
    """
    # Reject unsupported orders before building any jitted closure
    if deriv not in (0, 1):
        raise ValueError(f'Derivative order {deriv} not supported')

    @partial(jax.jit, static_argnames=('compile_statics',))
    def atomic_orbitals(
        grid_coords: FloatNx3,
        nuc_pos: FloatAx3 | NpFloatAx3,
        primitives: RadialPrimitives,
        compile_statics: GTOCompileStatics,
    ) -> FloatNxB:
        """
        Evaluate atomic orbitals at grid points.

        Args:
            grid_coords: Grid coordinates (N x 3). They are combined, without rescaling,
                with nuclear positions that have been multiplied by ANGSTROM_TO_BOHR.
                NOTE(review): this implies the grid must already be in Bohr - confirm.
            nuc_pos: Nuclei positions (A x 3)
            primitives: Radial primitive data (coefficients, exponents, atom and shell indices)
            compile_statics: Static index tables and sizes; hashed as a static jit argument

        Returns:
            Atomic orbital values at grid points (N x B)

        Notes:
            - b denotes the number of basis function shells with distinct (L, radial) quantum numbers
            - B denotes the total number of basis functions (including all magnetic quantum numbers)
            - Nuclear positions are multiplied by ANGSTROM_TO_BOHR in `radials_and_displacements`
        """
        radial_shells, displacements = radials_and_displacements(
            grid_coords, nuc_pos, primitives
        )
        num_points = grid_coords.shape[0]
        aos = jnp.empty((num_points, compile_statics.num_basis_fns))

        # l = 0: the angular factor is 1, so the orbital equals its radial part
        s_shells = compile_statics.angular_shell_indices[0]
        s_values = radial_shells.at[:, s_shells].get()  # [N, b_0] == [N, B_0]
        aos = aos.at[:, compile_statics.angular_basis_indices[0]].set(s_values)

        # l > 0: radial part times the angular factors around the shell's atom
        for l_angular in range(1, max_angular_momentum + 1):
            shells_l = compile_statics.angular_shell_indices[l_angular]
            atoms_of_shells = compile_statics.shell_atom_indices[shells_l]
            angulars = compute_angular_components(displacements, l_angular)  # N x A x M_sph
            radial_l = radial_shells.at[:, shells_l].get()[..., None]  # N x b_l x 1
            angular_l = angulars.at[:, atoms_of_shells, :].get()  # N x b_l x M_sph
            ao_l = einops.rearrange(radial_l * angular_l, 'n b_l m -> n (b_l m)')  # N x B_l
            aos = aos.at[:, compile_statics.angular_basis_indices[l_angular]].set(ao_l)

        return aos

    @partial(jax.jit, static_argnames=('compile_statics',))
    def atomic_orbitals_and_gradients(
        grid: FloatNx3,
        nuc_pos: FloatAx3 | NpFloatAx3,
        primitives: RadialPrimitives,
        compile_statics: GTOCompileStatics,
    ) -> Tuple[FloatNxB, FloatNxBx3]:
        """
        Evaluate atomic orbitals and their derivatives along x, y and z of the grid
        coordinates, obtained with one forward-mode JVP per direction.

        Returns:
            Orbital values (N x B) and derivatives with the direction as last axis (N x B x 3).
        """

        def _values(grid_coords: FloatNx3) -> FloatNxB:
            return atomic_orbitals(grid_coords, nuc_pos, primitives, compile_statics)

        def _values_and_directional(
            dr: Float3, grid_coords: FloatNx3
        ) -> Tuple[FloatNxB, FloatNxB]:
            # Shift every grid point along the same direction
            tangent = jnp.zeros_like(grid_coords) + dr[None]
            return jax.jvp(_values, (grid_coords,), (tangent,))

        unit_directions = jnp.eye(3, dtype=grid.dtype)
        # Map over the three directions; the values are shared, the derivatives are
        # stacked along a new last axis
        per_direction = jax.vmap(
            _values_and_directional, in_axes=(0, None), out_axes=(None, -1)
        )
        aos, grad_aos = per_direction(unit_directions, grid)
        return aos, grad_aos

    return atomic_orbitals if deriv == 0 else atomic_orbitals_and_gradients
