from dataclasses import dataclass
from functools import partial
from typing import Literal

import jax
import jax.numpy as jnp

from egxc.systems import System
from egxc.training.loss.field import (
    BaseFieldLossConfig,
    FieldMeasures,
    field_loss,
)
from egxc.utils.typing import (
    Float1,
    Float2xBxB,
    FloatBxB,
)


@dataclass(frozen=True, slots=True)
class XCPotentialFieldLossConfig(BaseFieldLossConfig[FieldMeasures]):
    """Field-loss configuration for xc_potential_loss, adding a per-sample gauge option."""
    # Selects the gauge shift xc_potential_loss applies to (target - prediction)
    # before handing the result to field_loss:
    #   'L1'   -> _L1_optimal_gauge_xc_potential
    #   'L2'   -> _L2_optimal_gauge_xc_potential
    #   'none' -> no shift; target and prediction are passed to field_loss as given.
    per_sample_optimal_gauge: Literal['L1', 'L2', 'none'] = 'none'


@dataclass(frozen=True, slots=True)
class XCLinearResponseFieldLossConfig(BaseFieldLossConfig[FieldMeasures]):
    """Field-loss configuration carrying two extra linear-response settings."""
    # NOTE(review): neither field is read anywhere in this module; presumably
    # they are consumed by the code that builds the linear-response
    # target/prediction pair -- verify against the caller.
    # Presumably the number of perturbations used per sample -- TODO confirm.
    n_perturbations: int
    # Presumably controls whether gradients propagate through the ground-state
    # calculation -- TODO confirm.
    differentiate_through_ground_state: bool


def _L1_optimal_gauge_xc_potential(
    delta_xc_matrix: FloatBxB, overlap: FloatBxB, overlap_cutoff: float = 1e-12
) -> FloatBxB:
    """Shift delta_xc_matrix by the optimal L1 gauge C * S (weighted median solution)."""
    # TODO: currently spin-restricted only
    assert delta_xc_matrix.ndim == 2
    dV, S = delta_xc_matrix, overlap
    # Safe ratios and weights (no division where S is close to 0)
    r = jnp.divide(dV, S)
    # Valid where |S|>cutoff AND ratio is finite
    mask = (jnp.abs(S) > overlap_cutoff) & jnp.isfinite(r)
    r = jnp.where(mask, r, 0.0)  # same shape as dV
    w = jnp.where(mask, S * S, 0.0)  # |S|^2  (S is real-valued)
    # Flatten, sort by r, compute weighted median
    r_flat = r.reshape(-1)
    w_flat = w.reshape(-1)

    idx = jnp.argsort(r_flat)
    r_sorted = r_flat[idx]
    w_sorted = w_flat[idx]

    cumw = jnp.cumsum(w_sorted)
    half = 0.5 * jnp.sum(w_sorted)
    k = jnp.searchsorted(cumw, half, side='left')
    C = r_sorted[k]
    # jax.debug.print('L1 optimal gauge C: {C:.3e}', C=C)
    return dV - jax.lax.stop_gradient(C) * S


def _L2_optimal_gauge_xc_potential(
    delta_xc_matrix: FloatBxB, overlap: FloatBxB, overlap_cutoff: float = 1e-12
) -> FloatBxB:
    """
    Shift delta_xc_matrix by the L2-optimal gauge C*S, where
        C* = <S, dV> / <S, S>
    computed over entries with |S| > overlap_cutoff.
    Returns: dV - stop_gradient(C*) * S, same shape as delta_xc_matrix.
    """
    # TODO: currently spin-restricted only
    assert delta_xc_matrix.ndim == 2
    dV, S = delta_xc_matrix, overlap
    mask = jnp.abs(S) > overlap_cutoff
    numerator = jnp.sum(jnp.where(mask, dV * S, 0.0), dtype=dV.dtype)
    normalization = jnp.sum(jnp.where(mask, S * S, 0.0), dtype=dV.dtype)
    C = numerator / normalization
    # jax.debug.print('L2 optimal gauge C: {C:.3e}', C=C)
    return dV - jax.lax.stop_gradient(C) * S


@partial(jax.jit, static_argnames=['config', 'reference_basis_is_same'])
def xc_potential_loss(
    target: FloatBxB | Float2xBxB,
    prediction: FloatBxB | Float2xBxB,
    sys: System,
    config: XCPotentialFieldLossConfig,
    reference_basis_is_same: bool,
):
    """Field loss between target and prediction, optionally gauge-fixed per sample.

    With config.per_sample_optimal_gauge == 'none', target and prediction are
    passed to field_loss unchanged. With 'L1' or 'L2', the residual
    (target - prediction) is shifted by the corresponding gauge term computed
    from sys.fock_tensors.overlap, and field_loss is evaluated between an
    all-zero target and that shifted residual.

    Args:
        target: reference matrix.
        prediction: predicted matrix.
        sys: system; only sys.fock_tensors.overlap is read here, the object is
            otherwise forwarded to field_loss.
        config: loss configuration (static under jit).
        reference_basis_is_same: forwarded to field_loss (static under jit).

    Returns:
        Whatever field_loss returns for the (possibly gauge-fixed) pair.

    Raises:
        ValueError: if config.per_sample_optimal_gauge is not one of
            'L1', 'L2' or 'none'. Since config is a static argument, this is
            raised while tracing.
    """
    gauge = config.per_sample_optimal_gauge
    if gauge == 'none':
        return field_loss(target, prediction, sys, config, reference_basis_is_same)

    if gauge == 'L1':
        fix_gauge = _L1_optimal_gauge_xc_potential
    elif gauge == 'L2':
        fix_gauge = _L2_optimal_gauge_xc_potential
    else:
        # Previously an unrecognised value zeroed the target but left the
        # prediction untouched, silently producing a meaningless loss.
        raise ValueError(
            f'Unknown per_sample_optimal_gauge {gauge!r}; '
            "expected 'L1', 'L2' or 'none'."
        )

    # Compare the gauge-fixed residual against zero.
    residual = fix_gauge(target - prediction, sys.fock_tensors.overlap)
    return field_loss(
        jnp.zeros_like(target), residual, sys, config, reference_basis_is_same
    )


@partial(jax.jit, static_argnames=['config', 'reference_basis_is_same'])
def xc_potential_linear_response_loss(
    target: FloatBxB | Float2xBxB,
    prediction: FloatBxB | Float2xBxB,
    sys: System,
    config: BaseFieldLossConfig[FieldMeasures],
    reference_basis_is_same: bool,
) -> Float1:
    """Jitted pass-through to field_loss for a target/prediction pair.

    No gauge adjustment or other preprocessing is applied: all five arguments
    are forwarded to field_loss as given. config and reference_basis_is_same
    are static arguments under jit.
    """
    loss = field_loss(target, prediction, sys, config, reference_basis_is_same)
    return loss


##### NOTE: density_hessian_diagonal_loss is deprecated; it is retained below only as commented-out code.
# @partial(jax.jit, static_argnames=['reference_basis_is_same'])
# def density_hessian_diagonal_loss(
#     target: FloatOxV,
#     prediction: FloatOxV,
#     reference_basis_is_same: bool,
# ) -> Float1:
#     assert reference_basis_is_same
#     return jnp.sum(jnp.square(target - prediction))
