from dataclasses import dataclass
from functools import partial
from typing import Generic, Literal, TypeAlias, TypeVar

import jax
import jax.numpy as jnp

Measure = TypeVar('Measure', covariant=True)
TensorMeasures: TypeAlias = Literal[
    'mae',  # L1 alias
    'mse',  # L2 alias
    'rmse',
    'L1',
    'L2',
]


@dataclass(slots=True, frozen=True)
class TensorLossConfig(Generic[Measure]):
    """Settings that select how `tensor_loss` reduces a tensor.

    The dataclass is frozen (and keeps the default ``eq=True``), so instances
    are hashable. That is what lets `tensor_loss` take a config as a static
    argument to ``jax.jit``.
    """

    # Name of the pointwise measure to apply, e.g. a member of `TensorMeasures`.
    measure: Measure
    # True -> average over all entries; False -> sum over all entries.
    scale_per_entry: bool


@partial(jax.jit, static_argnames=['config'])
def tensor_loss(
    delta_tensor: jax.Array,
    config: TensorLossConfig[TensorMeasures],
) -> jax.Array:
    """Reduce every entry of ``delta_tensor`` to a single scalar loss.

    Args:
        delta_tensor: Array of any shape whose entries are penalized. Real and
            complex dtypes are both supported.
        config: Static (hashable) loss settings.
            ``config.measure`` selects the pointwise measure:

            - ``'mae'`` / ``'L1'``: ``|x|``
            - ``'mse'`` / ``'L2'`` / ``'rmse'``: ``|x|**2``

            ``config.scale_per_entry`` picks the reduction: mean over all
            entries when true, sum when false. For ``'rmse'``, the square root
            is applied after the reduction.

    Returns:
        A 0-d array. It is never complex, even for complex input.

    Raises:
        ValueError: If ``config.measure`` is not one of the names above.
            Because ``config`` is static, this is raised at trace time.
    """
    measure = config.measure
    if measure in ('mae', 'L1'):
        pointwise = jnp.abs(delta_tensor)
    elif measure in ('mse', 'L2', 'rmse'):
        if jnp.iscomplexobj(delta_tensor):
            # jnp.square(z) would give z**2 (complex-valued); a squared-error
            # loss needs the real modulus |z|**2 = z * conj(z).
            pointwise = jnp.real(delta_tensor * jnp.conj(delta_tensor))
        else:
            pointwise = jnp.square(delta_tensor)
    else:
        raise ValueError(f'Invalid integral norm: {config.measure}')

    # `config` is static, so this Python-level branch is resolved while tracing.
    out = jnp.mean(pointwise) if config.scale_per_entry else jnp.sum(pointwise)

    if measure == 'rmse':
        out = jnp.sqrt(out)
    return out
