import jax.numpy as jnp
import jax
from typing import Callable
from functools import partial
import numpy as np

# Module-level alias for the trapezoidal-rule integrator used by the tolerance
# computations in this file. Presumably kept as a single alias so the import
# path can be adjusted in one place across JAX versions -- TODO confirm.
jax_trapezoid = jax.scipy.integrate.trapezoid  # version consistency


@jax.custom_jvp
def safe_sqrt(x):
    """Square root whose forward-mode derivative is given by `safe_sqrt_jvp`.

    The primal value is exactly ``jnp.sqrt(x)``; only the derivative rule is
    customised.
    """
    return jnp.sqrt(x)


@safe_sqrt.defjvp
def safe_sqrt_jvp(primals, tangents):
    """JVP rule for `safe_sqrt`.

    The tangent is ``0.5 / sqrt(|x| + 1e-8) * x_dot`` where ``x > 0`` and is
    multiplied by zero where ``sign(x) <= 0``, so the derivative stays finite
    at ``x == 0``.

    :param primals: one-element tuple holding the input ``x``
    :param tangents: one-element tuple holding the input tangent
    :return: ``(primal_out, tangent_out)``
    """
    (x,) = primals
    (x_dot,) = tangents
    # relu(sign(x)) evaluates to 1 for positive x and 0 for zero / negative x.
    positive_mask = jax.nn.relu(jnp.sign(x))
    # The 1e-8 offset bounds the slope near zero.
    damped_slope = 0.5 / jnp.sqrt(jnp.abs(x) + 1e-8)
    return safe_sqrt(x), damped_slope * x_dot * positive_mask


@partial(jax.jit, static_argnums=(1, 2))
def safe_std(x, axis=-1, keepdims=False):
    """Standard deviation along ``axis`` that returns exactly 0 for constant slices.

    Slices whose variance is zero are replaced by fixed-seed normal noise
    before ``jnp.std`` is taken, and their result is then overwritten with 0.
    Presumably this keeps the gradient of ``std`` finite for constant inputs
    -- TODO confirm.

    :param x: input array
    :param axis: axis (or axes) to reduce over; static under ``jit``
    :param keepdims: keep the reduced axes with size one; static under ``jit``
    :return: standard deviation, with the shape ``jnp.std`` would give for the
        same ``axis`` / ``keepdims``
    """
    # Mask with the reduced axes kept, so it broadcasts against ``x`` itself.
    has_spread = jnp.var(x, axis=axis, keepdims=True) > 0
    noise = jax.random.normal(jax.random.PRNGKey(42), shape=x.shape)
    surrogate = jnp.where(has_spread, x, noise)
    std_surrogate = jnp.std(surrogate, axis=axis, keepdims=keepdims)
    # The output mask must follow ``keepdims``; a mask with the axes dropped
    # would broadcast against a keepdims=True result and change its shape.
    out_mask = jnp.var(x, axis=axis, keepdims=keepdims) > 0
    return jnp.where(out_mask, std_surrogate, jnp.zeros_like(std_surrogate))


@jax.jit
def safe_corr(x, y):
    """Correlation coefficient of ``x`` and ``y`` that is 0 when their covariance is 0.

    An input whose ``safe_std`` is not positive is swapped for fixed-seed
    normal noise before ``jnp.corrcoef`` is evaluated; the returned value is
    forced to 0 whenever the covariance of the original inputs is exactly 0.

    :param x: first sample (assumed 1-D -- TODO confirm against callers)
    :param y: second sample, same shape as ``x``
    :return: off-diagonal entry of the correlation matrix, or 0
    """
    noise_x = jax.random.normal(key=jax.random.PRNGKey(42), shape=x.shape)
    noise_y = jax.random.normal(key=jax.random.PRNGKey(24), shape=y.shape)
    x_safe = jnp.where(safe_std(x) > 0, x, noise_x)
    y_safe = jnp.where(safe_std(y) > 0, y, noise_y)
    corr_matrix = jnp.corrcoef(x_safe, y_safe)

    # The decision to report a correlation is made on the untouched inputs.
    cross_cov = jnp.cov(x, y)[0, 1]
    return jnp.where(cross_cov != 0, corr_matrix[0, 1], 0)


@jax.jit
def distortion_lambda(mu_f, mu, sigma_f, sigma,
                      gamma, covar, tolerance, clip_condition):
    """Compute the multiplier ``lambd`` from moments of a reference and of ``gamma``.

    With ``K = ((mu_f - mu)^2 + sigma_f^2 + sigma^2 - tolerance) / 2`` and
    ``V = var(gamma)`` the result is::

        K / s^2 * sqrt(|covar^2 - V * sigma_f^2| / |K^2 - sigma^2 * sigma_f^2|) - covar / s^2

    where ``s = sigma_f + 1e-5``. The result is 0 wherever ``sigma_f <= 0`` or
    ``clip_condition`` is false.

    :param mu_f: mean of the reference
    :param mu: target mean
    :param sigma_f: standard deviation of the reference
    :param sigma: target standard deviation
    :param gamma: array whose variance enters the numerator
    :param covar: covariance term used in the numerator and the offset
    :param tolerance: tolerance subtracted inside ``K``
    :param clip_condition: boolean gate; where false the result is 0
    :return: ``lambd``
    """
    half_gap = ((mu_f - mu) ** 2 + sigma_f ** 2 + sigma ** 2 - tolerance) / 2
    gamma_var = jnp.var(gamma)

    # The square-root ratio is evaluated in log space, with both magnitudes
    # floored at 1e-12 so the logarithms stay finite.
    log_num = jnp.log(jnp.abs(covar ** 2 - gamma_var * sigma_f ** 2).clip(1e-12))
    log_den = jnp.log(jnp.clip(jnp.abs(half_gap ** 2 - (sigma ** 2 * sigma_f ** 2)), 1e-12))
    root = jnp.exp(0.5 * (log_num - log_den))

    # Offset only the divisor; the gate below still tests the raw sigma_f.
    sigma_f_eps = sigma_f + 1e-5
    lambd = half_gap / (sigma_f_eps ** 2) * root - covar / (sigma_f_eps ** 2)

    active = (sigma_f > 0) & clip_condition
    return jnp.where(active, lambd, jnp.zeros_like(lambd))


@jax.jit
def model_risk_fn(f_ref, gamma, mu, sigma, tolerance):
    """Evaluate ``mu + sigma * std(gamma) * corr(gamma, gamma + lambd * f_ref)``.

    ``lambd`` comes from `distortion_lambda` and is zeroed where the
    (gradient-stopped) clip condition is false.

    :param f_ref: reference sample (callers in this file vmap over a leading
        axis and pass sorted values; assumed 1-D here -- TODO confirm)
    :param gamma: weights paired element-wise with ``f_ref``
    :param mu: target mean
    :param sigma: target standard deviation
    :param tolerance: tolerance compared against the squared moment distance
    :return: tuple ``(risk, std(gamma) * corr, apply_condition)`` where
        ``apply_condition`` is ``l2 < tolerance``.
        NOTE(review): callers in this file unpack the second element as
        ``lambd``, but it is ``std(gamma) * corr``, not the multiplier itself
        -- confirm this is intended.
    """
    ref_mean = jnp.mean(f_ref)
    ref_std = safe_std(f_ref)
    ref_gamma_cov = jnp.cov(f_ref, gamma, bias=True)[0, 1]
    ref_gamma_corr = safe_corr(f_ref, gamma)

    l2 = (ref_mean - mu) ** 2 + (ref_std - sigma) ** 2
    apply_condition = (l2 < tolerance)
    # The gate is treated as a constant under differentiation.
    clip_condition = jax.lax.stop_gradient(
        tolerance < l2 + 2 * sigma * ref_std * (1 - ref_gamma_corr))

    raw_lambd = distortion_lambda(ref_mean, mu, ref_std, sigma, gamma,
                                  ref_gamma_cov, tolerance, clip_condition)
    lambd = jnp.where(clip_condition, raw_lambd, jnp.zeros_like(raw_lambd))

    corr = safe_corr(gamma, gamma + lambd * f_ref)
    gamma_std = jnp.std(gamma)
    risk = mu + sigma * gamma_std * corr
    return risk, gamma_std * corr, apply_condition


@jax.jit
def model_risk_quantile_fn(f_ref, gamma, mu, sigma, tolerance):
    """Return ``gamma + lambd * f_ref`` rescaled to mean ``mu`` and std ``sigma``.

    ``lambd`` is obtained from `distortion_lambda`, zeroed where the clip
    condition is false and detached from the gradient. The standardised
    combination is returned only where ``l2 < tolerance``; elsewhere ``f_ref``
    is returned unchanged.

    :param f_ref: reference sample (assumed 1-D; callers vmap -- TODO confirm)
    :param gamma: weights paired element-wise with ``f_ref``
    :param mu: target mean
    :param sigma: target standard deviation
    :param tolerance: tolerance compared against the squared moment distance
    :return: array with the same shape as ``f_ref``
    """
    ref_mean = jnp.mean(f_ref)
    ref_std = safe_std(f_ref)
    ref_gamma_cov = jnp.cov(f_ref, gamma, bias=True)[0, 1]
    ref_gamma_corr = safe_corr(f_ref, gamma)

    l2 = (ref_mean - mu) ** 2 + (ref_std - sigma) ** 2
    apply_condition = (l2 < tolerance)
    clip_condition = jax.lax.stop_gradient(
        tolerance < l2 + 2 * sigma * ref_std * (1 - ref_gamma_corr))

    raw_lambd = distortion_lambda(ref_mean, mu, ref_std, sigma, gamma,
                                  ref_gamma_cov, tolerance, clip_condition)
    # No gradient flows through the multiplier.
    lambd = jax.lax.stop_gradient(
        jnp.where(clip_condition, raw_lambd, jnp.zeros_like(raw_lambd)))

    blended = gamma + lambd * f_ref
    centre = jnp.mean(blended)
    # Floor the scale so the standardisation never divides by ~0.
    scale = jnp.std(blended).clip(1e-6)
    rescaled = mu + sigma * (blended - centre) / scale
    return jnp.where(apply_condition, rescaled, f_ref)


@jax.jit
def calculate_lambda(f_ref, gamma, mu, sigma, tolerance):
    """Return the gated multiplier ``lambd`` and the apply condition.

    :param f_ref: reference sample (assumed 1-D; callers vmap -- TODO confirm)
    :param gamma: weights paired element-wise with ``f_ref``
    :param mu: target mean
    :param sigma: target standard deviation
    :param tolerance: tolerance compared against the squared moment distance
    :return: ``(lambd, apply_condition)``; ``lambd`` is 0 where the
        (gradient-stopped) clip condition is false and ``apply_condition`` is
        ``l2 < tolerance``
    """
    ref_mean = jnp.mean(f_ref)
    ref_std = safe_std(f_ref)
    ref_gamma_cov = jnp.cov(f_ref, gamma, bias=True)[0, 1]
    ref_gamma_corr = safe_corr(f_ref, gamma)

    l2 = (ref_mean - mu) ** 2 + (ref_std - sigma) ** 2
    apply_condition = (l2 < tolerance)
    clip_condition = jax.lax.stop_gradient(
        tolerance < l2 + 2 * sigma * ref_std * (1 - ref_gamma_corr))

    raw_lambd = distortion_lambda(ref_mean, mu, ref_std, sigma, gamma,
                                  ref_gamma_cov, tolerance, clip_condition)
    gated_lambd = jnp.where(clip_condition, raw_lambd, jnp.zeros_like(raw_lambd))
    return gated_lambd, apply_condition


@partial(jax.jit, static_argnums=(0, 1))
def calculate_model_risk(density_fn: Callable, risk_eta, std: jax.Array,
                         q_values: jax.Array, risk_q_values: jax.Array,
                         neutral_taus: jax.Array, risk_taus: jax.Array):
    """Model-risk adjusted value with an ensemble-based tolerance.

    The tolerance is the largest (over the last axis) trapezoid-integrated
    squared gap between each tau-sorted pooled quantile curve and the sorted
    flattened ``q_values``.

    :param density_fn: called as ``density_fn(neutral_taus, risk_eta)``; static under jit
    :param risk_eta: forwarded to ``density_fn``; static under jit
    :param std: not used by this function
    :param q_values: assumed rank-3 ``(batch, quantile, candidate)`` -- TODO confirm
    :param risk_q_values: same rank as ``q_values``; pooled with it along axis -2
    :param neutral_taus: taus paired with ``q_values``
    :param risk_taus: taus paired with ``risk_q_values``
    :return: ``(model_risk_adjust, interpolation_ration)``; the first falls
        back to the minimum mean risk quantile where the apply condition is false
    """
    # Select, per batch row, the last-axis entry with the smallest mean risk quantile.
    mean_risk_q = risk_q_values.mean(axis=-2)
    best = jnp.argmin(mean_risk_q, axis=-1, keepdims=True)[..., None, :]
    qf = jnp.take_along_axis(q_values, best, axis=-1).squeeze(axis=-1)

    # Numerically this is qf's mean; the extra terms only add a gradient path
    # through the minimum mean risk quantile.
    neg_risk = mean_risk_q.min(axis=-1)
    mu = neg_risk - jax.lax.stop_gradient(neg_risk) + qf.mean(axis=-1)
    sigma = qf.std(axis=-1)

    ensemble = q_values.reshape(q_values.shape[0], -1).sort(axis=-1)
    grid = jnp.linspace(0, 1, ensemble.shape[-1])

    # Pool both sets of taus / quantiles and order them by tau.
    merged_taus = jnp.concatenate([neutral_taus, risk_taus], axis=-1)
    order = jnp.argsort(merged_taus, axis=-1)
    merged_taus = jnp.take_along_axis(merged_taus, order, axis=-1)
    merged_q = jnp.concatenate([q_values, risk_q_values], axis=-2)
    merged_q = jax.vmap(lambda col: jnp.take_along_axis(col, order, axis=-1),
                        in_axes=-1, out_axes=-1)(merged_q)

    def trapz(f, x):
        # Resample each row onto the common grid, then integrate the squared gap.
        resampled = jax.vmap(jnp.interp, in_axes=(None, 0, 0), out_axes=0)(grid, x, f)
        integrate_rows = jax.vmap(partial(jax_trapezoid, x=grid, axis=-1))
        return integrate_rows((resampled - ensemble) ** 2)

    per_candidate = jax.vmap(trapz, in_axes=(2, None), out_axes=1)(merged_q, merged_taus)
    tolerance = per_candidate.max(axis=-1)
    density = density_fn(neutral_taus, risk_eta)

    adjusted, lambd, apply_cond = jax.vmap(model_risk_fn)(
        (-qf).sort(axis=-1), density.sort(axis=-1), -mu, sigma,
        jax.lax.stop_gradient(tolerance))
    # Undo the sign flip: -(-mu + sigma * c) = mu - sigma * c.
    adjusted = jnp.where(apply_cond, -adjusted, neg_risk)
    not_interpolated = (lambd == 0) | ~apply_cond
    interpolation_ration = 1. - not_interpolated.astype(jnp.float32)
    return adjusted, interpolation_ration


@partial(jax.jit, static_argnums=(0, 1))
def calculate_model_risk_v2(density_fn: Callable, risk_eta, bc_qf,
                            q_values: jax.Array, risk_q_values: jax.Array,
                            neutral_taus: jax.Array, risk_taus: jax.Array):
    """Model-risk adjusted value with moments and tolerance taken from ``bc_qf``.

    The tolerance is the largest (over the last axis) trapezoid-integrated
    squared gap between ``bc_qf`` and ``q_values`` along ``neutral_taus``.

    :param density_fn: called as ``density_fn(neutral_taus, risk_eta)``; static under jit
    :param risk_eta: forwarded to ``density_fn``; static under jit
    :param bc_qf: reference quantiles, same rank as ``q_values``
        (NOTE(review): meaning of "bc" is not visible here)
    :param q_values: assumed rank-3 ``(batch, quantile, candidate)`` -- TODO confirm
    :param risk_q_values: quantiles used to pick the candidate and the fallback
    :param neutral_taus: integration variable for the tolerance
    :param risk_taus: not used by this function
    :return: ``(model_risk_adjust, interpolation_ration)``
    """
    mean_risk_q = risk_q_values.mean(axis=-2)
    best = jnp.argmin(mean_risk_q, axis=-1, keepdims=True)[..., None, :]
    qf = jnp.take_along_axis(q_values, best, axis=-1).squeeze(axis=-1)

    mu = bc_qf.mean(axis=-2).min(axis=-1)
    sigma = bc_qf.std(axis=-2).max(axis=-1)

    def w2(xp, yp1, yp2):
        # Integrated squared difference of two curves sampled at xp.
        return jax_trapezoid((yp1 - yp2) ** 2, xp, axis=-1)

    per_candidate = jax.vmap(w2, in_axes=(None, -1, -1), out_axes=-1)(neutral_taus, bc_qf, q_values)
    tolerance = per_candidate.max(axis=-1)

    density = density_fn(neutral_taus, risk_eta)

    adjusted, lambd, apply_cond = jax.vmap(model_risk_fn)(
        (-qf).sort(axis=-1), density.sort(axis=-1), -mu, sigma,
        jax.lax.stop_gradient(tolerance))
    # Undo the sign flip: -(-mu + sigma * c) = mu - sigma * c.
    fallback = mean_risk_q.min(axis=-1)
    adjusted = jnp.where(apply_cond, -adjusted, fallback)
    not_interpolated = (lambd == 0) | ~apply_cond
    interpolation_ration = 1. - not_interpolated.astype(jnp.float32)
    return adjusted, interpolation_ration


@partial(jax.jit, static_argnums=(0, 1))
def calculate_worst_case_quantilev2(density_fn: Callable, risk_eta,
                                    q_values: jax.Array, risk_q_values: jax.Array,
                                    neutral_taus: jax.Array, bc_qf: jax.Array):
    """Worst-case quantile curve with moments and tolerance taken from ``bc_qf``.

    :param density_fn: called as ``density_fn(neutral_taus, risk_eta)``; static under jit
    :param risk_eta: forwarded to ``density_fn``; static under jit
    :param q_values: assumed rank-3 ``(batch, quantile, candidate)`` -- TODO confirm
    :param risk_q_values: quantiles used to pick the candidate
    :param neutral_taus: integration variable for the tolerance
    :param bc_qf: reference quantiles, same rank as ``q_values``
    :return: negated output of `model_risk_quantile_fn`, one row per batch entry
    """
    mean_risk_q = risk_q_values.mean(axis=-2)
    best = jnp.argmin(mean_risk_q, axis=-1, keepdims=True)[..., None, :]
    qf = jnp.take_along_axis(q_values, best, axis=-1).squeeze(axis=-1)

    mu = bc_qf.mean(axis=-2).min(axis=-1)
    sigma = bc_qf.std(axis=-2).max(axis=-1)

    def w2(xp, yp1, yp2):
        # Integrated squared difference of two curves sampled at xp.
        return jax_trapezoid((yp1 - yp2) ** 2, xp, axis=-1)

    per_candidate = jax.vmap(w2, in_axes=(None, -1, -1), out_axes=-1)(neutral_taus, bc_qf, q_values)
    tolerance = per_candidate.max(axis=-1)

    density = density_fn(neutral_taus, risk_eta)
    batched_quantile_fn = jax.vmap(model_risk_quantile_fn, in_axes=(0, 0, 0, 0, 0), out_axes=0)
    # sigma is floored at 1e-2 before it is used as the target scale.
    flipped = batched_quantile_fn((-qf).sort(axis=-1), density.sort(axis=-1), -mu,
                                  sigma.clip(1e-2), jax.lax.stop_gradient(tolerance))
    return -flipped


@partial(jax.jit, static_argnums=(0, 1))
def calculate_worst_case_quantile(density_fn: Callable, risk_eta,
                                  q_values: jax.Array, risk_q_values: jax.Array,
                                  neutral_taus: jax.Array, risk_taus: jax.Array):
    """Worst-case quantile curve with an ensemble-based tolerance.

    ``sigma`` is the square root of the mean per-candidate variance plus the
    variance of the per-candidate means; ``mu`` is the smallest per-candidate
    mean.

    :param density_fn: called as ``density_fn(neutral_taus, risk_eta)``; static under jit
    :param risk_eta: forwarded to ``density_fn``; static under jit
    :param q_values: assumed rank-3 ``(batch, quantile, candidate)`` -- TODO confirm
    :param risk_q_values: same rank as ``q_values``; pooled with it along axis -2
    :param neutral_taus: taus paired with ``q_values``
    :param risk_taus: taus paired with ``risk_q_values``
    :return: negated output of `model_risk_quantile_fn`, one row per batch entry
    """
    best = jnp.argmin(risk_q_values.mean(axis=-2), axis=-1, keepdims=True)
    qf = jnp.take_along_axis(q_values, best[..., None, :], axis=-1).squeeze(axis=-1)

    ensemble = q_values.reshape(q_values.shape[0], -1).sort(axis=-1)
    grid = jnp.linspace(0, 1, ensemble.shape[-1])

    # Pool both sets of taus / quantiles and order them by tau.
    merged_taus = jnp.concatenate([neutral_taus, risk_taus], axis=-1)
    order = jnp.argsort(merged_taus, axis=-1)
    merged_taus = jnp.take_along_axis(merged_taus, order, axis=-1)
    merged_q = jnp.concatenate([q_values, risk_q_values], axis=-2)
    merged_q = jax.vmap(lambda col: jnp.take_along_axis(col, order, axis=-1),
                        in_axes=-1, out_axes=-1)(merged_q)

    candidate_means = q_values.mean(axis=-2)
    mu = candidate_means.min(axis=-1)
    within_var = jnp.var(q_values, axis=-2).mean(axis=-1)
    between_var = jnp.var(candidate_means, axis=-1)
    sigma = jnp.sqrt(within_var + between_var)

    def trapz(f, x):
        # Resample each row onto the common grid, then integrate the squared gap.
        resampled = jax.vmap(jnp.interp, in_axes=(None, 0, 0), out_axes=0)(grid, x, f)
        integrate_rows = jax.vmap(partial(jax_trapezoid, x=grid, axis=-1))
        return integrate_rows((resampled - ensemble) ** 2)

    per_candidate = jax.vmap(trapz, in_axes=(2, None), out_axes=1)(merged_q, merged_taus)
    tolerance = per_candidate.max(axis=-1)

    density = density_fn(neutral_taus, risk_eta)
    batched_quantile_fn = jax.vmap(model_risk_quantile_fn, in_axes=(0, 0, 0, 0, 0), out_axes=0)
    # sigma is floored at 1e-2 before it is used as the target scale.
    flipped = batched_quantile_fn((-qf).sort(axis=-1), density.sort(axis=-1), -mu,
                                  sigma.clip(1e-2), jax.lax.stop_gradient(tolerance))
    return -flipped


@jax.jit
def calculate_model_risk_var_limit(q_values: jax.Array, risk_taus: jax.Array):
    """Evaluate ``mu - sigma * sqrt((1 - tau) / tau)`` at ``risk_taus``.

    ``mu`` and ``sigma`` are the mean and `safe_std` of ``q_values`` flattened
    per leading-axis entry.

    Both arguments are arrays, so neither may be marked static for ``jit``:
    static arguments must be hashable and arrays are not.

    :param q_values: array whose leading axis is kept and whose remaining axes
        are flattened
    :param risk_taus: taus to evaluate at; clipped below at 1e-12 to avoid
        division by zero
    :return: array broadcast from ``mu[..., None]`` against ``risk_taus``
    """
    # Sorting does not change the moments; kept from the original formulation.
    ensemble = q_values.reshape(q_values.shape[0], -1).sort(axis=-1)

    mu = ensemble.mean(axis=-1)
    sigma = safe_std(ensemble, axis=-1)

    taus = risk_taus.clip(1e-12)
    return mu[..., None] - sigma[..., None] * jnp.sqrt((1 - taus) / taus)


@partial(jax.jit, static_argnums=(0, 1))
def calculate_model_risk_consist(density_fn: Callable, risk_eta,
                                 target_q: jax.Array,
                                 q_values: jax.Array, risk_q_values: jax.Array,
                                 neutral_taus: jax.Array):
    """Model-risk adjusted value with moments and tolerance taken from ``target_q``.

    Unlike the other ``calculate_model_risk*`` functions in this file, the
    inputs to `model_risk_fn` are not sorted here and the interpolation ratio
    ignores the apply condition (NOTE(review): confirm both are intended).

    :param density_fn: called as ``density_fn(neutral_taus, risk_eta)``; static under jit
    :param risk_eta: forwarded to ``density_fn``; static under jit
    :param target_q: assumed rank-2 ``(batch, quantile)`` -- TODO confirm
    :param q_values: assumed rank-3 ``(batch, quantile, candidate)`` -- TODO confirm
    :param risk_q_values: quantiles used to pick the candidate and the fallback
    :param neutral_taus: taus paired with ``q_values``
    :return: ``(model_risk_adjust, interpolation_ration)``
    """
    mean_risk_q = risk_q_values.mean(axis=-2)
    best = jnp.argmin(mean_risk_q, axis=-1, keepdims=True)[..., None, :]
    qf = jnp.take_along_axis(q_values, best, axis=-1).squeeze(axis=-1)

    mu = target_q.mean(axis=-1)
    sigma = target_q.std(axis=-1)
    # Alternative moments kept for reference:
    # mu = qf.mean(axis=-1)
    # sigma = safe_std(qf, axis=-1)

    grid = jnp.linspace(0, 1, target_q.shape[-1])

    def trapz(f, x):
        # Resample each row onto the common grid, then integrate the squared gap.
        resampled = jax.vmap(jnp.interp, in_axes=(None, 0, 0), out_axes=0)(grid, x, f)
        integrate_rows = jax.vmap(partial(jax_trapezoid, x=grid, axis=-1))
        return integrate_rows((resampled - target_q) ** 2)

    per_candidate = jax.vmap(trapz, in_axes=(2, None), out_axes=1)(q_values, neutral_taus)
    tolerance = per_candidate.max(axis=-1)

    density = density_fn(neutral_taus, risk_eta)
    adjusted, lambd, apply_cond = jax.vmap(model_risk_fn)(
        -qf, density, -mu, sigma, jax.lax.stop_gradient(tolerance))

    fallback = mean_risk_q.min(axis=-1)
    adjusted = jnp.where(apply_cond, -adjusted, fallback)
    is_zero = (lambd == 0).astype(jnp.float32)
    interpolation_ration = 1. - is_zero
    return adjusted, interpolation_ration


# @partial(jax.jit, static_argnums=(0, 1))
def calculate_cvar_model_risk(density_fn: Callable, risk_eta,
                              q_values: jax.Array, risk_q_values: jax.Array,
                              neutral_taus: jax.Array):
    """Closed-form model-risk adjustment built from `calculate_lambda`.

    With ``alpha = (1 - risk_eta) / risk_eta`` and
    ``diff = |mean(qf) - min mean risk quantile|`` the adjusted value is::

        mu - sigma * (alpha + lambd * diff)
                   / sqrt(alpha + 2 * lambd * diff + (lambd * std(qf))^2)

    and falls back to the minimum mean risk quantile where the apply
    condition from `calculate_lambda` is false.

    :param density_fn: called as ``density_fn(neutral_taus, risk_eta)``; its
        output is the ``gamma`` argument of `calculate_lambda`
    :param risk_eta: risk level; must be non-zero
    :param q_values: assumed rank-3 ``(batch, quantile, candidate)`` -- TODO confirm
    :param risk_q_values: quantiles used to pick the candidate and the fallback
    :param neutral_taus: taus paired with ``q_values``
    :return: ``(model_risk_adjust, qf.mean(axis=-1))``
    """
    risk_q = risk_q_values.mean(axis=-2)
    index = jnp.argmin(risk_q, axis=-1, keepdims=True)[..., None, :]

    qf = jnp.take_along_axis(q_values, index, axis=-1).squeeze(axis=-1)
    ensemble = q_values.reshape(q_values.shape[0], -1).sort(axis=-1)

    mu = ensemble.mean(axis=-1)
    sigma = safe_std(q_values, axis=-2).mean(axis=-1)

    def trapz(f, x):
        # Resample each row onto a common grid, then integrate the squared
        # gap to the sorted ensemble.
        x_pt = jnp.linspace(0, 1, ensemble.shape[-1])
        f_interpolation = jax.vmap(jnp.interp, in_axes=(None, 0, 0), out_axes=0)(x_pt, x, f)
        trap_z_fn = jax.vmap(lambda y, x: jax_trapezoid(y, x=x, axis=-1), in_axes=(0, None), out_axes=0)
        return trap_z_fn((f_interpolation - ensemble) ** 2, x_pt)

    tolerance_fn = jax.vmap(trapz, in_axes=(2, None), out_axes=1)
    tolerance = tolerance_fn(q_values, neutral_taus).max(axis=-1)

    density = density_fn(neutral_taus, risk_eta)
    lambd, apply_cond = jax.vmap(calculate_lambda)(-qf, density, -mu, sigma, jax.lax.stop_gradient(tolerance))
    alpha = (1 - risk_eta) / risk_eta

    risk_q = risk_q.min(axis=-1)
    qf_mean = qf.mean(axis=-1)
    diff = jnp.abs(qf_mean - risk_q)

    # The ratio is evaluated in log space: numerator / sqrt(denominator).
    denominator = jnp.log((alpha + 2 * lambd * diff + (lambd * safe_std(qf, axis=-1)) ** 2))
    numerator = jnp.log((alpha + lambd * diff))
    coef = jnp.exp(numerator - 0.5 * denominator)
    model_risk_adjust = mu - sigma * coef
    # The fallback must be assigned; a bare jnp.where call has no effect.
    model_risk_adjust = jnp.where(apply_cond, model_risk_adjust, risk_q)

    return model_risk_adjust, qf_mean


if __name__ == '__main__':
    # Smoke test: run calculate_worst_case_quantile on synthetic Gaussian
    # quantile curves with a step density (1 / alpha below alpha, else 0).
    batch_size, num_taus = 2, 30
    risk_eta = 0.3

    neutral_taus = jnp.broadcast_to(jnp.linspace(0.05, 1 - 0.05, num_taus),
                                    (batch_size, num_taus))
    risk_taus = risk_eta * neutral_taus

    def gaussian_candidates(taus):
        # Three candidate quantile curves stacked on a new last axis:
        # standard normal, shifted mean, and doubled scale.
        q1 = jax.scipy.stats.norm.ppf(taus)
        q2 = jax.scipy.stats.norm.ppf(taus, loc=0.5)
        q3 = jax.scipy.stats.norm.ppf(taus, scale=2.0)
        return jnp.stack([q1, q2, q3], axis=-1)

    q_values = gaussian_candidates(neutral_taus)      # (batch, num_taus, 3)
    risk_q_values = gaussian_candidates(risk_taus)    # (batch, num_taus, 3)

    worst_case = calculate_worst_case_quantile(lambda x, alpha: jnp.where(x < alpha, 1 / alpha, 0),
                                               risk_eta,
                                               q_values, risk_q_values,
                                               neutral_taus, risk_taus)
    print(worst_case.shape)
    print(worst_case)
