from dataclasses import dataclass
from functools import partial
from typing import Generic, Literal, TypeAlias, TypeVar

import jax
import jax.numpy as jnp

from egxc.systems import System
from egxc.utils.typing import (
    Float1,
    Float2xBxB,
    FloatBxB,
    FloatN,
    FloatNxB,
)

# Type parameter for the `measure` field of `BaseFieldLossConfig`; declared
# covariant.
Measure = TypeVar('Measure', covariant=True)
# String identifiers of the measures that `field_loss` branches on.
# 'mae'/'L1' and 'mse'/'L2' are pairs of names for the same grid integral.
FieldMeasures: TypeAlias = Literal[
    'mae',  # L1 alias
    'mse',  # L2 alias
    'rmse',
    'L1',
    'L2',
    'overlap_based_mae_surrogate',
    'overlap_based_mse_surrogate',
    'spillage',
]


@dataclass(frozen=True, slots=True)
class BaseFieldLossConfig(Generic[Measure]):
    """Immutable settings selecting and scaling a field loss.

    The dataclass is frozen (and therefore hashable), which is what allows
    `field_loss` to receive it as a static `jax.jit` argument.
    """

    # Identifier of the measure to evaluate; `field_loss` branches on it.
    measure: Measure
    # When true, `field_loss` divides its result by `sys.n_electrons`.
    scale_per_electron: bool


def field_fn(matrix: FloatBxB | Float2xBxB, grid_aos: FloatNxB) -> FloatN:
    return jnp.einsum('uv,iu,iv->i', matrix, grid_aos, grid_aos)


def _delta_field_fn_same_basis(
    target: FloatBxB | Float2xBxB, reference: FloatBxB | Float2xBxB, grid_aos: FloatNxB
) -> FloatN:
    """Return the grid field of ``target - reference``.

    Both matrices are subtracted element-wise, so they must be expressed in
    the same basis; the difference is then evaluated with `field_fn`.
    """
    return field_fn(target - reference, grid_aos)


def _delta_field_fn_different_basis(
    target: FloatBxB | Float2xBxB, reference: FloatBxB | Float2xBxB, grid_aos: FloatNxB
) -> FloatN:
    # TODO: implement optimal projections between density matrices of different datasets
    raise NotImplementedError('Different basis sets not yet implemented.')


@partial(jax.jit, static_argnames=['reference_basis_is_same'])
def delta_field_fn(
    target: FloatBxB | Float2xBxB,
    reference: FloatBxB | Float2xBxB,
    grid_aos: FloatNxB,
    reference_basis_is_same: bool,
) -> FloatN:
    """Evaluate the difference field between `target` and `reference` on the grid.

    Args:
        target: First matrix.
        reference: Matrix subtracted from `target`.
        grid_aos: Basis-function values on the grid.
        reference_basis_is_same: Static flag selecting the implementation;
            the different-basis path currently raises `NotImplementedError`.

    Returns:
        The difference field produced by the selected implementation.
    """
    # The flag is a static jit argument, so this choice is made at trace time.
    implementation = (
        _delta_field_fn_same_basis
        if reference_basis_is_same
        else _delta_field_fn_different_basis
    )
    return implementation(target, reference, grid_aos)


def _field_L1_norm(
    field: FloatN,
    grid_weights: FloatN,
) -> Float1:
    """Return the weighted sum of ``|field|`` over all grid points."""
    weighted_magnitude = jnp.abs(field) * grid_weights
    return jnp.sum(weighted_magnitude)


def _field_L2_norm(
    field: FloatN,
    grid_weights: FloatN,
) -> Float1:
    """Return the weighted sum of ``field ** 2`` over all grid points.

    No square root is taken, i.e. this is the squared L2 norm.
    """
    weighted_square = jnp.square(field) * grid_weights
    return jnp.sum(weighted_square)


@partial(jax.jit, static_argnames=['norm'])
def field_integral_measures(
    field: FloatN,
    grid_weights: FloatN,
    norm: Literal['L1', 'L2'],
) -> Float1:
    """Integrate `field` on the grid with the requested norm.

    Args:
        field: Field values on the grid.
        grid_weights: Weights multiplied onto each grid point's contribution.
        norm: Static selector; 'L1' sums weighted absolute values, 'L2' sums
            weighted squares.

    Returns:
        The scalar weighted sum.

    Raises:
        ValueError: If `norm` is neither 'L1' nor 'L2'.
    """
    integrators = {'L1': _field_L1_norm, 'L2': _field_L2_norm}
    integrate = integrators.get(norm)
    if integrate is None:
        raise ValueError(f'Invalid integral norm: {norm}')
    return integrate(field, grid_weights)


def overlap_based_mae_surrogate(
    target: FloatBxB | Float2xBxB, prediction: FloatBxB | Float2xBxB, overlap: FloatBxB
) -> Float1:
    """
    Computes an overlap-weighted L1 surrogate (MAE-like) loss between two density matrices
    without requiring integration on a real-space grid.

    This metric approximates the real-space L1 distance between densities by weighting
    the elementwise absolute difference of the density matrices with the magnitude of
    the AO overlap matrix elements. It provides a cheap, grid-free proxy for the
    integral of |rho_pred(r) - rho_ref(r)| over space.

    Mathematically:
        L1 is approximated by sum_{mu,nu} |delta_P[mu,nu]| * |S[mu,nu]|

    Args:
        target: Reference (ground-truth) density matrix, shape (B, B) or (2, B, B)
        prediction: Predicted density matrix, same shape as `target`
        overlap: AO overlap matrix S, shape (B, B)

    Returns:
        Scalar loss: overlap-weighted mean absolute deviation between the two densities.

    Assumptions under which this surrogate would be exact:
        - The atomic orbital (AO) basis functions {chi_mu} form an orthonormal and complete
          basis, meaning the overlap matrix S equals the identity (S = I).
        - In that limit, the electron density rho(r) = sum_{mu,nu} P[mu,nu] * chi_mu(r) * chi_nu(r)
          reproduces the continuous density exactly.
        - Then the expression sum_{mu,nu} |delta_P[mu,nu]| * |S[mu,nu]| reduces to
          sum_{mu} |delta_P[mu,mu]|, which corresponds exactly to the real-space L1 norm
          of the density difference.

    Notes:
        - In practical finite, non-orthogonal AO bases, this is only an approximation,
          but it correlates well with the true grid-based L1 density error.
        - The absolute-value weighting by |S| emphasizes AO pairs with significant spatial
          overlap, partially compensating for basis non-orthogonality.
    """
    abs_delta_matrix = jnp.abs(target - prediction)
    return (abs_delta_matrix * jnp.abs(overlap)).sum()


@jax.jit
def overlap_based_mse_surrogate(
    target: FloatBxB | Float2xBxB, prediction: FloatBxB | Float2xBxB, overlap: FloatBxB
):
    """
    Computes an overlap-weighted L2 surrogate (MSE-like) loss between two density matrices
    without requiring integration on a real-space grid.

    This metric approximates the real-space L2 norm of the density difference by using
    the AO overlap matrix as a discrete representation of spatial integration. It computes
    Tr[(delta_P S) (delta_P S)] = Tr(delta_P S delta_P S), which can be viewed as an
    overlap-weighted Frobenius norm of the density-matrix difference.

    Mathematically:
        L2_overlap = Tr(delta_P S delta_P S)
                   = sum_{i,j,k,l} delta_P[i,j] * S[j,k] * delta_P[k,l] * S[l,i]

    Args:
        target: Reference density matrix, shape (B, B) or (2, B, B)
        prediction: Predicted density matrix, same shape as `target`
        overlap: AO overlap matrix S, shape (B, B)

    Returns:
        Scalar overlap-weighted Frobenius norm (approximate L2 density error).

    Quality of approximation:
        - The expression would be exact if the four-center overlap integrals would factorize as:
          int chi_i(r) * chi_j(r) * chi_k(r) * chi_l(r) dr ≈ S[j,k] * S[l,i].
        - This factorization is a good approximation when the AO basis functions are
          localized and slowly varying, such that products of AO pairs are nearly
          separable in space.

    Notes:
        - The metric remains approximate in general finite non-orthogonal AO bases but
          correlates well with the true grid-based L2 density error.
        - The overlap weighting accounts for non-orthogonality between basis functions
          and provides a physically meaningful, grid-free error measure.
    """
    dP = target - prediction
    A = dP @ overlap
    return jnp.einsum('ik,ki->', A, A)


@partial(jax.jit, static_argnames=['eps'])
def spillage_loss(
    target: FloatBxB | Float2xBxB,
    prediction: FloatBxB | Float2xBxB,
    overlap: FloatBxB,
    eps: float = 1e-15,
    cutoff: float = 1e-3,
):
    # TODO: test
    # Eigen-decomp
    lam, V = jnp.linalg.eigh(overlap, symmetrize_input=True)

    if cutoff is not None:
        keep = lam >= cutoff
        lam = lam[keep]
        V = V[:, keep]

    lam_clipped = jnp.clip(lam, eps)
    inv_sqrt = 1.0 / jnp.sqrt(lam_clipped)

    # X = V * diag(lam^{-1/2}) * V.T  (note: returns BxB' if cutoff used)
    X = V @ (inv_sqrt[:, None] * V.T)

    target = jnp.einsum('ab,...bc,cd->...ad', X.T, target, X)
    prediction = jnp.einsum('ab,...bc,cd->...ad', X.T, prediction, X)

    num = jnp.einsum('...ij,...ji->...', target, prediction)
    den = jnp.einsum('...ii->...', target)

    spill = 1.0 - (num / jnp.clip(den, eps))
    return jnp.mean(spill)


@partial(jax.jit, static_argnames=['config', 'reference_basis_is_same'])
def field_loss(
    target: FloatBxB | Float2xBxB,
    prediction: FloatBxB | Float2xBxB,
    sys: System,
    config: BaseFieldLossConfig,
    reference_basis_is_same: bool,
) -> Float1:
    """Evaluate the loss selected by `config.measure` between two matrices.

    Args:
        target: Reference matrix.
        prediction: Predicted matrix.
        sys: System providing `fock_tensors.overlap`, `grid.aos`,
            `grid.weights` and `n_electrons`.
        config: Static configuration; `measure` picks the loss and
            `scale_per_electron` toggles division by `sys.n_electrons`.
        reference_basis_is_same: Static flag forwarded to `delta_field_fn`
            for the grid-based measures.

    Returns:
        The scalar loss. For 'rmse' the square root is taken before the
        optional per-electron scaling.

    Raises:
        ValueError: If `config.measure` is not a known measure.
    """
    if config.measure == 'overlap_based_mae_surrogate':
        out = overlap_based_mae_surrogate(target, prediction, sys.fock_tensors.overlap)
    elif config.measure == 'overlap_based_mse_surrogate':
        out = overlap_based_mse_surrogate(target, prediction, sys.fock_tensors.overlap)
    elif config.measure == 'spillage':
        out = spillage_loss(target, prediction, sys.fock_tensors.overlap)
    else:
        delta_field = delta_field_fn(
            target, prediction, sys.grid.aos, reference_basis_is_same
        )
        if config.measure in ('mae', 'L1'):
            norm = 'L1'
        elif config.measure in ('mse', 'L2', 'rmse'):
            # 'rmse' integrates the squared field like 'mse'; without this
            # entry it fell through to the ValueError below and the sqrt
            # branch could never run.
            norm = 'L2'
        else:
            raise ValueError(f'Invalid integral norm: {config.measure}')
        out = field_integral_measures(delta_field, sys.grid.weights, norm)
        if config.measure == 'rmse':
            out = jnp.sqrt(out)
    if config.scale_per_electron:
        out /= sys.n_electrons
    return out
