from typing import Callable, Literal

import jax.numpy as jnp
from flax.struct import dataclass

from egxc.utils.typing import Float1, FloatSCF, UInt1

# Error measures understood by `scalar_loss` (it dispatches on
# `ScalarLossConfig.measure` with a `match` over exactly these names).
ScalarMeasures = Literal['mae', 'mse', 'huber', 'asinh']


@dataclass
class ScalarLossConfig:
    """
    Configuration for `scalar_loss`.

    NOTE(review): with `flax.struct.dataclass`, fields are pytree leaves unless
    declared with `field(pytree_node=False)`. `scalar_loss` branches on these
    values in plain Python, so this config presumably must stay static (closed
    over or a static jit argument) rather than traced -- confirm with callers.
    """

    # Error measure applied to the residual `prediction - target`.
    measure: ScalarMeasures
    # If True, the residual is divided by `n_electrons` before the measure is
    # applied.
    scale_per_electron: bool
    # Characteristic error scale for 'huber' and 'asinh' (required by those
    # two, ignored by 'mae'/'mse'). It is applied to the residual AFTER any
    # per-electron scaling, so it shares that residual's units: Hartree for
    # energies, or Hartree per electron when `scale_per_electron` is True.
    scale_parameter: float | None = None


def _sqrt_1_plus_x_squared(x: Float1) -> Float1:
    """Return sqrt(1 + x**2) elementwise; `jnp.hypot` never forms x**2, so it cannot overflow."""
    return jnp.hypot(1.0, x)


def _sqrt_1_plus_x_squared_minus_1(x: Float1) -> Float1:
    """
    Return sqrt(1 + x**2) - 1 elementwise without catastrophic cancellation.

    Subtracting 1 from sqrt(1 + x**2) loses precision as |x| shrinks. It yields
    exactly 0 once x**2 falls below the float epsilon (|x| ≲ 3e-4 in float32),
    because sqrt(1 + x**2) then rounds to 1. The algebraically identical form
    x**2 / (sqrt(1 + x**2) + 1) has no cancellation. Writing it as
    x * (x / (s + 1)) keeps the intermediate bounded (|x / (s + 1)| < 1), so
    for finite x it cannot overflow where x**2 would. Note: ±inf input yields
    NaN (inf / inf).
    """
    s = _sqrt_1_plus_x_squared(x)
    return x * (x / (s + 1))


def huber_loss_fn(scale: float) -> Callable[[Float1], Float1]:
    """
    Pseudo-Huber loss: a smooth interpolation between L2 (MSE) and L1 (MAE).

    For a residual d the loss is scale**2 * (sqrt(1 + (d / scale)**2) - 1),
    which behaves like d**2 / 2 for |d| << scale and like scale * |d| for
    |d| >> scale.

    Args:
        scale: The characteristic input scale of the error, in the same units
            as the residual. A zero scale divides by zero.
    Returns:
        A function that maps residuals to elementwise pseudo-Huber losses.
    """

    def huber_loss_fn(diff: Float1) -> Float1:
        dU = diff / scale
        # sqrt(1 + dU**2) >= 1, so there is no sqrt-at-zero gradient problem
        # and no masking is required. The stable helper computes s - 1 without
        # cancellation, so small residuals are not rounded to a zero loss.
        return scale**2 * _sqrt_1_plus_x_squared_minus_1(dU)

    return huber_loss_fn


def asinh_loss_fn(a: float) -> Callable[[Float1], Float1]:
    """
    Asinh loss for a residual d with characteristic scale a:

        a**2 * (1 - sqrt(1 + (d / a)**2) + (d / a) * asinh(d / a))

    Its derivative with respect to d is a * asinh(d / a). The loss is therefore
    ≈ d**2 / 2 for |d| << a, while for |d| >> a its gradient grows only
    logarithmically.

    Reference:
        DelloStritto, M.; Klein, M. L.
        "Improved Loss Functions for Machine-Learned Atomic Potentials"
        The Journal of Chemical Physics 2025, 163 (13), 134108.
        https://doi.org/10.1063/5.0280032.

    Args:
        a: Characteristic error scale, in the same units as the residual.
            A zero scale divides by zero.
    Returns:
        A function that maps residuals to elementwise asinh losses.
    """

    def asinh_loss_fn(diff: Float1) -> Float1:
        dU = diff / a
        # 1 - sqrt(1 + dU**2) == -(sqrt(1 + dU**2) - 1). Computing it with the
        # stable helper stops it from rounding to 0 for small dU, which would
        # otherwise leave dU * asinh(dU) ≈ dU**2, i.e. twice the true loss.
        return a**2 * (dU * jnp.asinh(dU) - _sqrt_1_plus_x_squared_minus_1(dU))

    return asinh_loss_fn


def scalar_loss(
    target: Float1,
    prediction: Float1 | FloatSCF,
    n_electrons: UInt1,
    config: ScalarLossConfig,
) -> Float1:
    diff = prediction - target
    if config.scale_per_electron:
        diff /= n_electrons
    match config.measure:
        case 'mae':
            out = jnp.abs(diff)
        case 'mse':
            out = jnp.square(diff)
        case 'huber':
            assert config.scale_parameter is not None, (
                'Scale parameter is required for Huber loss'
            )
            loss_fn = huber_loss_fn(config.scale_parameter)
            out = loss_fn(diff)
        case 'asinh':
            assert config.scale_parameter is not None, (
                'Scale parameter is required for asinh loss'
            )
            loss_fn = asinh_loss_fn(config.scale_parameter)
            out = loss_fn(diff)
        case _:
            raise ValueError(f'Unknown energy loss measure: {config.measure}')
    return out
