import jax.numpy as jnp
import jax
from functools import partial
from typing import Callable, Optional
from scipy.interpolate import interp1d
import numpy as np
import optax


@partial(jax.jit, static_argnums=(1,))
def cvar(x: jax.Array, eta: float):
    """Apply the CVaR distortion: scale quantile levels ``x`` by ``eta``.

    For levels in [0, 1] the result lies in [0, eta], so sampling with this
    map only visits the lowest ``eta`` fraction of the distribution.

    ``eta`` is a static argument under ``jax.jit``; every distinct value
    compiles a new version of the function.
    """
    scaled = eta * x
    return scaled


@partial(jax.jit, static_argnums=(1,))
def wang(x: jax.Array, eta: float):
    """Apply the Wang distortion ``Phi(Phi^-1(x) + eta)`` to levels ``x``.

    ``Phi`` is the standard normal CDF. Before inverting, ``x`` is squashed by
    ``(x + 1e-6) * (1 - 1e-6)`` so that an input of exactly 0 does not map to
    ``-inf``. Because ``Phi`` is increasing, a negative ``eta`` moves every
    level toward 0 and a positive ``eta`` moves it toward 1.

    ``eta`` is a static argument under ``jax.jit``.
    """
    normal = jax.scipy.stats.norm
    squashed = (x + 1e-6) * (1 - 1e-6)
    shifted_z = normal.ppf(squashed) + eta
    return normal.cdf(shifted_z)


@partial(jax.jit, static_argnums=(1,))
def pow(x: jax.Array, eta: float):
    """Apply the power distortion ``1 - (1 - x) ** (1 / (1 + |eta|))``.

    Only the magnitude of ``eta`` is used, so ``eta`` and ``-eta`` give the
    same result. For ``x`` in [0, 1] the exponent is at most 1, so the output
    never exceeds ``x``: levels are pulled toward 0. The endpoints 0 and 1 are
    fixed.

    ``eta`` is a static argument under ``jax.jit``.

    NOTE(review): this name shadows the builtin ``pow`` for the whole module.
    Also, the IQN paper's Pow distortion uses ``x ** (1 / (1 + eta))`` for
    ``eta >= 0``; confirm that always using the ``1 - (1 - x) ** ...`` branch
    is intended.
    """
    inv_power = 1 / (1 + jnp.abs(eta))
    survival = 1 - x
    return 1 - jnp.power(survival, inv_power)


@jax.jit
def quanitle_regression_loss(target, predict, taus):
    """Element-wise quantile-regression (pinball) loss for every pair.

    Pairs are formed by broadcasting: entry ``[..., i, j]`` uses the residual
    ``u = target[..., j] - predict[..., i]``. The loss is ``taus * |u|`` when
    ``u >= 0`` and ``(1 - taus) * |u|`` when ``u < 0``. ``taus`` is aligned
    with the prediction index ``i``. No reduction is applied.

    NOTE(review): the name is misspelled ("quanitle") and the argument order
    (target, predict) differs from ``quantile_huber_loss``. Both are kept
    as-is for caller compatibility.
    """
    residual = target[..., None, :] - predict[..., None]
    level = taus[..., None]
    # Under-estimation (u >= 0) is weighted by tau, over-estimation by 1 - tau.
    asymmetric_weight = jnp.where(residual < 0, 1 - level, level)
    return asymmetric_weight * jnp.abs(residual)


@jax.jit
def quantile_huber_loss(predict, target, taus, weight):
    """Weighted quantile Huber loss, summed over the target axis.

    For each pair ``u = target[..., j] - predict[..., i]``:
      * the Huber loss of ``u`` is computed with ``optax.huber_loss`` at its
        default delta;
      * it is scaled by ``taus`` when ``u >= 0`` and by ``1 - taus`` when
        ``u < 0``;
      * it is multiplied by ``weight[..., j]``;
      * the result is summed over ``j``.

    ``taus`` is aligned with the prediction index ``i``.

    Returns:
        One loss value per prediction; the target axis is reduced away.
    """
    residual = target[..., None, :] - predict[..., None]
    huber = optax.huber_loss(residual)
    level = taus[..., None]
    asymmetric_weight = jnp.where(residual < 0, 1 - level, level)
    weighted = asymmetric_weight * huber * weight[..., None, :]
    return jnp.sum(weighted, axis=-1)


@jax.jit
def quantile_huber_loss_without_weights(target, predict, taus):
    """Element-wise quantile Huber loss for every prediction/target pair.

    Entry ``[..., i, j]`` is the Huber loss of
    ``u = target[..., j] - predict[..., i]`` (``optax.huber_loss`` at its
    default delta). It is scaled by ``taus`` when ``u >= 0`` and by
    ``1 - taus`` when ``u < 0``. ``taus`` is aligned with the prediction
    index ``i``. No weighting or reduction is applied.
    """
    residual = target[..., None, :] - predict[..., None]
    huber = optax.huber_loss(residual)
    level = taus[..., None]
    asymmetric_weight = jnp.where(residual < 0, 1 - level, level)
    return asymmetric_weight * huber


@partial(jax.jit, static_argnums=(1,))
def get_tau(key, shape):
    """Sample random, strictly increasing quantile fractions.

    The procedure:
      * draw increments ``uniform(key, shape) + 0.1``, so every increment is
        positive;
      * normalize them to sum to 1 along the last axis;
      * accumulate them with a cumulative sum.

    Args:
        key: JAX PRNG key.
        shape: Hashable shape of the sample (e.g. a tuple). It is static under
            ``jax.jit``. The last axis indexes the fractions.

    Returns:
        A tuple ``(tau, tau_hat, presum_tau)``, all with shape ``shape``:
          * ``tau``: the cumulative fractions tau_1..tau_N; the last one equals
            1 up to round-off.
          * ``tau_hat``: the midpoint between each fraction and its
            predecessor, taking tau_0 = 0.
          * ``presum_tau``: the normalized increments.
    """
    increments = jax.random.uniform(key, shape) + 0.1
    increments = increments / jnp.sum(increments, axis=-1, keepdims=True)
    tau = jnp.cumsum(increments, axis=-1)
    # Left neighbour of each fraction, with an implicit tau_0 = 0 in front.
    lower = jnp.concatenate([jnp.zeros_like(tau[..., :1]), tau[..., :-1]], axis=-1)
    tau_hat = (lower + tau) / 2
    return tau, tau_hat, increments


@partial(jax.jit, static_argnums=(1,))
def cvar_density(x: jax.Array, eta):
    """Return the density of CVaR-distorted levels: ``1/eta`` where ``x <= eta``, else 0.

    This is the derivative of ``min(x / eta, 1)``, the inverse of the
    ``x * eta`` map applied by ``cvar``. It therefore re-weights uniform
    quantile levels to match sampling with that map. Entries where
    ``x <= eta`` is false, including NaNs, get weight 0.

    ``eta`` is a static argument under ``jax.jit``.
    """
    inside = x <= eta
    indicator = jnp.where(inside, jnp.ones_like(x), jnp.zeros_like(x))
    return indicator / eta


@partial(jax.jit, static_argnums=(1,))
def wang_density(x: jax.Array, eta):
    """Return the density of Wang-distorted levels, evaluated at ``x``.

    The function computes ``phi(z + eta) / phi(z)`` with ``z = Phi^-1(1 - x)``.
    Here ``phi`` and ``Phi`` are the standard normal pdf and CDF. The ratio is
    formed in log space for numerical stability.

    Since ``Phi^-1(1 - x) = -Phi^-1(x)`` and ``phi`` is symmetric, this equals
    ``phi(Phi^-1(x) - eta) / phi(Phi^-1(x))``. That is the derivative of
    ``Phi(Phi^-1(x) - eta)``, the inverse of the map applied by ``wang``.

    ``x`` is clipped to ``[1e-6, 1 - 1e-6]`` to keep the inverse CDF finite.
    ``eta`` is a static argument under ``jax.jit``.
    """
    normal = jax.scipy.stats.norm
    z = normal.ppf(1 - x.clip(1e-6, 1 - 1e-6))
    # The original locals named these "denominator"/"numerator" the wrong way
    # round: the shifted pdf is the numerator of the ratio.
    log_numerator = normal.logpdf(z + eta)
    log_denominator = normal.logpdf(z)
    return jnp.exp(log_numerator - log_denominator)


def _default_value_fn(quantiles):
    return quantiles.mean(axis=-2)


@partial(jax.jit, static_argnums=(1,))
def take_min_value_quantile(quantiles: jax.Array, value_fn: Callable = _default_value_fn):
    """Select the last-axis slice of ``quantiles`` whose value is lowest.

    ``value_fn(quantiles)`` must yield one score per entry of the last axis of
    ``quantiles``; the default averages over axis -2. The position of the
    smallest score (``argmin`` over the last axis) picks the slice of
    ``quantiles`` along its last axis that is returned. With the default
    ``value_fn`` the result has shape ``quantiles.shape[:-1]``.

    ``value_fn`` is static under ``jax.jit``. It must be hashable, and each
    distinct function triggers a recompilation.

    NOTE(review): assumes ``quantiles`` is ``(..., n_quantiles, n_critics)``,
    i.e. this picks the most pessimistic critic; confirm with callers.
    """
    scores = value_fn(quantiles)
    # Two trailing singleton axes so the index broadcasts over the quantile axis.
    chosen = jnp.argmin(scores, axis=-1)[..., None, None]
    picked = jnp.take_along_axis(quantiles, chosen, axis=-1)
    return picked[..., 0]


def evaluate_risk_measure(scores: np.ndarray,
                          risk_measure: str,
                          risk_eta: Optional[float] = None,
                          *,
                          num_samples: int = 10000,
                          rng: Optional[np.random.Generator] = None,
                          ):
    """Estimate a distortion risk measure from an empirical sample of scores.

    The sorted scores serve as an empirical quantile function. Consecutive
    sorted scores sit on an evenly spaced grid over [0, 1] and are linearly
    interpolated between grid points.

    Supported measures:
      * ``'neutral'``: the plain mean of the scores.
      * ``'cvar'``: the mean of the lowest ``risk_eta`` fraction. At least one
        score is always used.
      * ``'wang'`` / ``'pow'``: a Monte Carlo estimate of the mean of the
        quantile function at ``num_samples`` uniform levels. The levels are
        first passed through ``wang`` / ``pow`` with ``risk_eta``.

    Args:
        scores: 1-D array of scores. It is not modified.
        risk_measure: One of ``'neutral'``, ``'cvar'``, ``'wang'``, ``'pow'``.
        risk_eta: Distortion parameter. Required for every measure except
            ``'neutral'``.
        num_samples: Number of uniform levels drawn for ``'wang'`` / ``'pow'``.
        rng: Optional NumPy ``Generator`` for those draws. When ``None``, the
            global ``np.random`` state is used, as before.

    Returns:
        The estimate as a NumPy scalar.

    Raises:
        ValueError: If ``risk_measure`` is unknown, or if ``risk_eta`` is
            missing for a measure that requires it.
    """
    # np.sort returns a copy; the previous in-place ``scores.sort()`` silently
    # reordered the caller's array.
    sorted_scores = np.sort(np.asarray(scores))

    if risk_measure == 'neutral':
        return np.mean(sorted_scores)

    if risk_measure not in ('cvar', 'wang', 'pow'):
        # Previously an unknown name fell through and returned None.
        raise ValueError(f"Unknown risk_measure {risk_measure!r}; expected "
                         "'neutral', 'cvar', 'wang' or 'pow'.")
    if risk_eta is None:
        raise ValueError(f"risk_eta is required for risk_measure={risk_measure!r}.")

    if risk_measure == 'cvar':
        # Keep at least one score: when risk_eta * len(scores) < 1 the old slice
        # was empty and its mean was NaN.
        count = max(1, int(risk_eta * len(sorted_scores)))
        return sorted_scores[:count].mean()

    # 'wang' / 'pow': sample uniform levels, distort them, and average the
    # interpolated quantile function at the distorted levels.
    draw = rng.uniform if rng is not None else np.random.uniform
    levels = draw(0, 1, num_samples)
    distortion = wang if risk_measure == 'wang' else pow
    # The distortions run in JAX (float32 by default). Clip so round-off can
    # never leave [0, 1].
    distorted = np.clip(np.asarray(distortion(levels, risk_eta)), 0.0, 1.0)
    # Piecewise-linear interpolation like the former interp1d. It also works
    # for a single score, and it keeps float64 precision and a NumPy return
    # type for both branches.
    grid = np.linspace(0, 1, len(sorted_scores))
    return np.mean(np.interp(distorted, grid, sorted_scores))

