from typing import Callable

import e3nn_jax as e3nn
import jax.numpy as jnp
import numpy as np

from egxc.utils.typing import (
    FloatAxAx1,
    FloatAxAxRBF,
    FloatAxNx1,
    FloatAxNxRBF,
    RBFType,
)


def quintic_cutoff(
    r: FloatAxNx1, r_cut: float, r_switch_fraction: float = 0.9
) -> FloatAxNx1:
    """
    Quintic (C²) cutoff function with a smooth switching band.

    Inside the band the value is ``1 - (10u³ - 15u⁴ + 6u⁵)``, where
    ``u = (r - r_switch) / (r_cut - r_switch)`` is clipped to [0, 1]. The first
    and second derivatives of this polynomial vanish at u = 0 and u = 1.

    Args:
        r: Distances at which to evaluate the cutoff.
        r_cut: Cutoff radius. The function is exactly zero for r >= r_cut.
        r_switch_fraction: Fraction of r_cut at which the cutoff starts to decay,
            in [0, 1]. The default is 0.9, meaning the function is equal to 1 up to
            0.9 * r_cut and then smoothly decays to 0 at r_cut. A value of exactly
            1 leaves an empty switching band. In that case a hard step is returned
            (1 for r < r_cut, 0 otherwise) instead of dividing by zero. That step
            is not smooth.

    Returns:
        Values of the cutoff function at r, ranging smoothly from 1 (for
        r <= r_switch_fraction * r_cut) to 0 (for r >= r_cut).

    Raises:
        ValueError: If r_switch_fraction lies outside [0, 1].
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not 0.0 <= r_switch_fraction <= 1.0:
        raise ValueError('r_switch_fraction must be between 0 and 1')
    if r_switch_fraction == 1.0:
        # Zero-width band: (r - r_cut) / 0 would give NaN at r == r_cut (0/0),
        # so return the limiting hard step directly.
        return jnp.where(r < r_cut, jnp.ones_like(r), jnp.zeros_like(r))
    r_switch = r_switch_fraction * r_cut
    u = jnp.clip((r - r_switch) / (r_cut - r_switch), 0.0, 1.0)
    return 1.0 - (10.0 * u**3 - 15.0 * u**4 + 6.0 * u**5)


def coulomb_like_envelope(r: FloatAxNx1, r_cut: float, eps: float = 1e-6) -> FloatAxNx1:
    """
    Soft-Coulomb envelope multiplied by a C² quintic cutoff.

    The core factor ``eps / sqrt(r**2 + eps**2)`` is exactly 1 at r = 0, stays
    finite everywhere and behaves like ``eps / r`` once r >> eps. It is multiplied
    by ``quintic_cutoff(r, r_cut)`` with its default switching fraction. The
    envelope therefore keeps the soft-Coulomb shape up to 0.9 * r_cut and goes
    smoothly to 0 at r_cut.

    Args:
        r: Distances at which to evaluate the envelope.
        r_cut: Radius at which the envelope reaches zero.
        eps: Softening length of the Coulomb core.

    Returns:
        Envelope values with the same shape as ``r``.
    """
    denominator = jnp.sqrt(r * r + eps * eps)
    return eps / denominator * quintic_cutoff(r, r_cut)


def _trigonometric_rbf(
    r: FloatAxNx1,
    n: int,
    envelope: Callable,
    add_constant: bool = True,
) -> FloatAxNxRBF:
    """
    Trigonometric radial basis function.
    Polynomial envelope equal to 0 at x=0 and x=1 respectively.
    https://e3nn-jax.readthedocs.io/en/latest/api/radial.html
    https://mariogeiger.ch/polynomial_envelope_for_gnn.pdf
    """
    out = []
    if add_constant:
        out.append(jnp.ones_like(r) * 0.1)
        n -= 1
    ns = jnp.arange(0, (n + 1) // 2, dtype=r.dtype) + 1
    out.append(jnp.sin(ns * jnp.pi * r))
    if n % 2 == 0:
        out.append(jnp.cos(ns * jnp.pi * r))
    else:
        out.append(jnp.cos(ns[:-1] * jnp.pi * r))

    out = jnp.concatenate(out, axis=-1)
    return jnp.sqrt(2.0) * out * envelope(r)


def _bessel_rbf(r: FloatAxNx1, n: int, envelope: Callable) -> FloatAxNxRBF:
    bessel_weights = jnp.arange(1, n + 1, dtype=r.dtype)
    return jnp.sinc(bessel_weights * r) * envelope(r)


def _polynomial_rbf(r: FloatAxNx1, n: int, envelope: Callable) -> FloatAxNxRBF:
    """
    Bartók et al. "On Representing Chemical Environments"
    https://doi.org/10.1103/PhysRevB.87.184115.
    """

    def _compute_matrix_inverse_sqrt(S):
        """
        Compute the inverse square root of a symmetric positive-definite matrix S
        using eigenvalue decomposition.
        """
        # Eigenvalue decomposition
        eigvals, eigvecs = jnp.linalg.eigh(S)  # eigh is for symmetric/Hermitian matrices
        # Compute Λ^{-1/2}
        D_inv_sqrt = jnp.diag(1.0 / jnp.sqrt(eigvals))
        # Compute S^{-1/2} = U Λ^{-1/2} U^T
        S_inv_sqrt = eigvecs @ D_inv_sqrt @ eigvecs.T
        return S_inv_sqrt

    ns = jnp.arange(1, n + 1, dtype=r.dtype)
    normalization = jnp.sqrt(2 * ns + 5)
    phi = (1 - r) ** (ns + 2) * normalization
    temp = jnp.sqrt(5 + 2 * ns)
    S = jnp.outer(temp, temp) / (5 + ns[None] + ns[:, None])
    W = _compute_matrix_inverse_sqrt(S)
    return jnp.einsum('anr,mr->anm', phi, W) * envelope(r)  # FloatAxNxRBF


def radial_basis_values(
    r: FloatAxNx1 | FloatAxAx1,
    r_cutoff: float,
    n: int,
    key: RBFType,
) -> FloatAxNxRBF | FloatAxAxRBF:
    """Return radial basis values of the requested family for distances ``r``.

    Args:
        r: Pairwise distances normalised by r_cutoff.
        r_cutoff: Cutoff radius.
        n: Number of basis functions.
        key: Radial family to use, either ``'trigonometric'`` or ``'polynomial'``.

    Returns:
        Tensor of shape ``FloatAxNxRBF`` containing the basis evaluations.
    """
    r = r[..., None] / r_cutoff
    match key:
        case 'trigonometric':
            return _trigonometric_rbf(r, n, envelope=e3nn.poly_envelope(5, 2))
        case 'polynomial':
            assert n <= 11, 'Orthogonolization becomes numerically unstable for n > 11'
            return _polynomial_rbf(r, n, envelope=e3nn.poly_envelope(5, 2))
        case 'bessel':
            return _bessel_rbf(r, n, envelope=e3nn.poly_envelope(5, 2))
        case 'smooth_finite':
            return e3nn.soft_one_hot_linspace(  # TODO: fix
                r,
                start=0.0,
                end=1.0,
                number=n,
                basis='smooth_finite',
                start_zero=False,
                end_zero=True,
            )
