from functools import partial
from typing import Tuple

import jax
import jax.numpy as jnp

from egxc.utils.typing import (
    Bool2xB,
    Float2xBxE,
    FloatB,
    FloatBxB,
    FloatBxE,
    NpFloatBxB,
    UIntB,
)


def safe_sqrt(x: jax.Array) -> jax.Array:
    """
    Elementwise square root whose gradient stays finite at exact zeros.

    The derivative of sqrt, 1 / (2 sqrt(x)), is infinite at x == 0. To avoid
    propagating inf/NaN through autodiff, entries equal to zero are replaced
    by 1.0 before taking the root, and the output is then forced back to 0.0
    at those positions. The resulting gradient at a zero entry is 0.

    Negative inputs are not masked and produce NaN, as with ``jnp.sqrt``.

    Args:
        x: Input array.

    Returns:
        ``sqrt(x)`` elementwise, with exact zeros mapped to 0.0.
    """
    is_zero = x == 0
    # Evaluate sqrt on a dummy value (1.0) at zeros so its derivative is finite;
    # the outer where discards that dummy result.
    root = jnp.sqrt(jnp.where(is_zero, 1.0, x))  # type: ignore
    return jnp.where(is_zero, 0.0, root)


@partial(jax.jit, static_argnames=('axis',))  # one-element tuple of names
def safe_norm(
    x: jax.Array, axis: int = -1
) -> jax.Array:
    """
    Euclidean norm of ``x`` along ``axis`` with a finite gradient at zero.

    The squared norm is reduced first and the root is taken with
    ``safe_sqrt``. Where the squared norm is exactly zero, the result is
    therefore 0 and the gradient is 0 rather than NaN.

    ``axis`` is a static (compile-time) argument because it determines the
    output shape. Each distinct value triggers a separate compilation.

    Args:
        x: Input array.
        axis: Axis along which the norm is computed (default: last axis).

    Returns:
        ``sqrt(sum(x * x, axis=axis))`` with ``axis`` reduced away.
    """
    squared_norm = jnp.sum(x * x, axis=axis)
    return safe_sqrt(squared_norm)


@jax.jit
def coeffs_to_density_matrix(
    coeff: FloatBxE | Float2xBxE, occupancy: UIntB | Bool2xB
) -> FloatBxB:
    """
    Calculates the density matrix from the coefficient matrix.

        P[..., p, q] = sum_i coeff[..., p, i] * coeff[..., q, i] * occupancy[..., i]

    Any leading axes (e.g. a spin axis of size 2) are broadcast between
    ``coeff`` and ``occupancy`` and preserved in the output. A 2-D ``coeff``
    therefore yields a single (B, B) matrix, and a stacked (2, B, E) ``coeff``
    yields a (2, B, B) stack.

    Args:
        coeff: Coefficient matrix with columns indexed by the last axis.
        occupancy: Per-column occupation weights, one per column of ``coeff``.

    Returns:
        The density matrix (or stack of density matrices).
    """
    # '...' lets the same contraction serve both the single-matrix and the
    # spin-resolved (leading-axis) inputs advertised in the annotations.
    return jnp.einsum('...pi,...qi,...i->...pq', coeff, coeff, occupancy)


@jax.jit
def transformation_matrix(S: FloatBxB | NpFloatBxB) -> FloatBxB:
    """
    Löwdin (symmetric) orthogonalizer X = S^{-1/2}.

    With the eigendecomposition S = V diag(lambda) V^T, the function returns
    X = V diag(lambda^{-1/2}) V^T. This X is symmetric and satisfies
    X.T @ S @ X = I.

    No regularization is applied. Near-zero or negative eigenvalues of S
    (e.g. from a nearly linearly dependent basis) give inf/NaN entries.

    Args:
        S: Symmetric overlap matrix. It is symmetrized before diagonalization.

    Returns:
        The matrix S^{-1/2}.
    """
    eigvals, eigvecs = jnp.linalg.eigh(S, symmetrize_input=True)
    inv_sqrt_eigvals = 1.0 / jnp.sqrt(eigvals)
    # Scale each eigenvector column by lambda^{-1/2}, then rotate back.
    scaled = eigvecs * inv_sqrt_eigvals
    return scaled @ eigvecs.T


@jax.jit
def modified_generalized_eigenvalue_problem(
    F: FloatBxB, X: FloatBxB, mask=None
) -> Tuple[FloatB, FloatBxB]:
    """
    Solves the generalized eigenvalue problem F C = S C e via an orthogonalizer.

    In SCF calculations F is the Fock matrix. X is an orthogonalizing
    transformation of the overlap matrix S, i.e. X.T @ S @ X = I (for example
    the output of ``transformation_matrix``). The problem is reduced to the
    ordinary symmetric eigenproblem F' C' = C' e with F' = X.T @ F @ X. That
    problem is solved, and the eigenvectors are mapped back with C = X @ C'.

    Args:
        F: Symmetric matrix to diagonalize (e.g. the Fock matrix).
        X: Orthogonalizing transformation with X.T @ S @ X = I.
        mask: Optional array multiplied elementwise into F' before
            diagonalization. When ``None``, F' is used unchanged.

    Returns:
        ``(e, C)``: eigenvalues in ascending order (as returned by
        ``jnp.linalg.eigh``) and the corresponding eigenvectors as the columns
        of C, expressed in the original (non-orthogonal) basis.

    TODO: make more robust to degeneracies due to large basis sets
          (see section 7 in https://doi.org/10.3390/molecules25051218)
    """
    F_ortho = X.T @ F @ X
    # `mask is None` is resolved at trace time, so the masked and unmasked
    # variants compile separately.
    if mask is not None:
        F_ortho = F_ortho * mask
    eigvals, C_ortho = jnp.linalg.eigh(F_ortho, symmetrize_input=True)
    return eigvals, X @ C_ortho


@jax.jit
def symmetrize(x: jax.Array) -> jax.Array:
    """
    Returns the symmetric part of ``x`` with respect to its last two axes.

    Computes (x + x^T) / 2, where ^T swaps the last two axes. Any leading
    (batch) axes are left untouched.

    Args:
        x: Array of shape (..., T, T).

    Returns:
        Array of the same shape whose last two axes form symmetric matrices.
    """
    x_transposed = x.swapaxes(-2, -1)
    return (x + x_transposed) * 0.5


@jax.jit
def remove_trace(x: jax.Array) -> jax.Array:
    """
    Projects out the trace of ``x`` over its last two axes.

    The mean diagonal value, trace / T, is subtracted from every diagonal
    entry, so the result has zero trace while all off-diagonal entries are
    unchanged. This is useful, for example, to enforce particle-number
    conservation.

    Args:
        x: Array of shape (..., T, T).

    Returns:
        Array of the same shape with zero trace over the last two axes.
    """
    size = x.shape[-1]
    identity = jnp.eye(size, dtype=x.dtype)
    diag_sum = jnp.trace(x, axis1=-2, axis2=-1)  # shape (...)
    # Broadcast the per-matrix trace over the trailing (T, T) axes.
    diag_sum = diag_sum[..., jnp.newaxis, jnp.newaxis]
    return x - diag_sum * identity / size


# NOTE: as of January 2025, jax.scipy.linalg.eigh only supports the standard eigenproblem (b=None), so the generalized call eigh(F, S) below cannot be used yet.
# from jax import scipy as jsp
# def direct_generalized_eigenvalue_problem(
#     F: FloatBxB, S: FloatBxB, mask=None
# ) -> Tuple[FloatB, FloatBxB]:
#     """
#     Solves the generalized eigenvalue problem F C = S C e directly.
#     In the context of SCF calculations F is the Fock matrix and S is the overlap matrix.
#     """
#     return jsp.linalg.eigh(F, S)
