from functools import partial

import chex
import jax
import jax.numpy as jnp


def huber_loss(x: jnp.ndarray, delta: float = 1.0) -> jnp.ndarray:
    """Return the element-wise Huber loss of ``x`` with threshold ``delta``.

    The loss is quadratic near zero and linear in the tails::

        0.5 * x^2                          if |x| <= delta
        0.5 * delta^2 + delta * (|x| - delta)  otherwise

    Args:
        x: Residuals; checked against ``float`` with ``chex.assert_type``.
        delta: Magnitude at which the loss switches from quadratic to linear.

    Returns:
        An array of per-element losses computed from ``x``.
    """
    chex.assert_type(x, float)

    magnitude = jnp.abs(x)
    # Part of |x| that sits inside the quadratic region (capped at delta).
    inside = jnp.minimum(magnitude, delta)
    # Remainder beyond delta. Written as a subtraction rather than
    # max(|x| - delta, 0) so the gradient is not potentially doubled.
    outside = magnitude - inside

    return delta * outside + 0.5 * inside**2


def _quantile_huber_loss(
    dist_src: jnp.ndarray,  # (num_quantiles, )
    tau_src: jnp.ndarray,  # (num_quantiles, )
    dist_target: jnp.ndarray,  # (num_quantiles, )
    huber_param: float = 0,
    stop_target_gradients: bool = True,
) -> jnp.ndarray:
    """Quantile regression loss between source and target sample sets.

    Every source sample is compared with every target sample. Each pairwise
    error is weighted by ``|tau - 1[error < 0]|`` and penalised with either
    the absolute value or, when ``huber_param > 0``, the Huber loss.

    Args:
        dist_src: Source quantile values.
        tau_src: Quantile fractions associated with ``dist_src``.
        dist_target: Target sample values.
        huber_param: Huber threshold; values ``<= 0`` select the plain
            absolute error. Must be a concrete Python number because it
            drives a Python-level branch.
        stop_target_gradients: If true, no gradient is propagated into
            ``dist_target``.

    Returns:
        Scalar loss: mean over the target axis, summed over the source axis.
    """
    # Block gradients at the target itself. Applying the stop to the
    # indicator below would have no effect: a comparison cast to float
    # already has zero gradient, so the target would still receive
    # gradients through `delta`.
    dist_target = jax.lax.select(
        stop_target_gradients, jax.lax.stop_gradient(dist_target), dist_target
    )

    # Pairwise quantile error, indexed as [source, target].
    delta = dist_target[None, :] - dist_src[:, None]
    # Indicator of negative error; treated as a constant for differentiation.
    delta_neg = jax.lax.stop_gradient((delta < 0.0).astype(jnp.float32))
    weight = jnp.abs(tau_src[:, None] - delta_neg)

    # Huber penalty if requested, otherwise plain absolute error.
    if huber_param > 0.0:
        loss = huber_loss(delta, huber_param)
    else:
        loss = jnp.abs(delta)
    loss *= weight

    # Average over target-samples dimension, sum over src-samples dimension.
    return jnp.sum(jnp.mean(loss, axis=-1))


@partial(jax.jit, static_argnums=(3, 4))
def batched_quantile_huber_loss(
    dist_src: jnp.ndarray,  # (b, ...) - mapped over the leading axis
    tau_src: jnp.ndarray,  # (num_quantiles, ) - shared by every example
    dist_target: jnp.ndarray,  # (b, ...) - mapped over the leading axis
    huber_param: float = 0,
    stop_target_gradients: bool = True,
) -> jnp.ndarray:
    """Apply ``_quantile_huber_loss`` independently to each batch element.

    ``dist_src`` and ``dist_target`` are mapped over their leading axis;
    ``tau_src`` is broadcast unchanged. ``huber_param`` and
    ``stop_target_gradients`` are static arguments of the jitted function.

    NOTE(review): the per-example function is annotated for 1-D
    ``(num_quantiles,)`` inputs, which implies ``(b, num_quantiles)`` here.
    Confirm the intended behaviour before passing arrays with a trailing
    ``num_actions`` axis.

    Returns:
        One loss value per batch element, stacked along axis 0.
    """
    # Bind the two configuration values up front so that only the array
    # arguments take part in the vmap.
    per_example_loss = partial(
        _quantile_huber_loss,
        huber_param=huber_param,
        stop_target_gradients=stop_target_gradients,
    )
    over_batch = jax.vmap(per_example_loss, in_axes=(0, None, 0), out_axes=0)
    return over_batch(dist_src, tau_src, dist_target)
