import jax
import jax.numpy as jnp
import jax.scipy as jsp


def to_valid_prob(w):
    """Coerce an arbitrary weight array into a valid probability distribution.

    The whole array is treated as ONE distribution (all entries are summed
    jointly), whatever its rank; the result has the same shape as ``w``.

    Steps:
      1. NaN and +/-inf entries become 0, then negative entries are clipped to 0.
      2. The weights are divided by their sum. If that sum is not positive or
         is not finite, a uniform distribution is used instead.
      3. For more than one entry, the last (flattened) entry is reset to
         ``max(0, 1 - sum(other entries))`` and the array is renormalised by
         ``sum + 1e-12``, so the total is 1 up to rounding.

    Args:
        w: array-like of weights, any shape.

    Returns:
        A non-negative array with the same shape as ``w`` whose entries sum to
        (approximately) 1.
    """
    w = jnp.asarray(w)
    shape = w.shape
    # Work on a flat view: the step-3 correction must touch exactly one element.
    # On an N-d array, ``w[-1]`` / ``w[:-1]`` would address whole rows and break
    # the normalisation.
    flat = jnp.clip(jnp.nan_to_num(w.ravel(), nan=0.0, posinf=0.0, neginf=0.0), 0.0, None)
    total = jnp.sum(flat)
    ok = (total > 0.0) & jnp.isfinite(total)
    # Divide by a "safe" denominator so the branch not selected by jnp.where
    # stays finite (no NaN leaking into gradients or vmapped code).
    safe_total = jnp.where(ok, total, jnp.ones_like(total))
    uniform = jnp.ones_like(flat) / max(flat.size, 1)
    flat = jnp.where(ok, flat / safe_total, uniform)
    if flat.size > 1:
        # Absorb rounding error into the last entry so the total is exactly ~1.
        flat = flat.at[-1].set(jnp.maximum(0.0, 1.0 - jnp.sum(flat[:-1])))
        flat = flat / (jnp.sum(flat) + 1e-12)
    return flat.reshape(shape)


def kl_hat_from_values(Vn, eta):
    """Gibbs weights over the last axis of ``Vn`` and their KL divergence from uniform.

    Computes ``w = softmax(-eta * Vn)`` along the last axis and
    ``KL(w || uniform) = sum_i w_i * (log w_i + log m)``, where ``m = Vn.shape[-1]``.

    Args:
        Vn: array of shape ``(..., m)``.
        eta: scalar (Python or JAX) or array; a trailing axis is appended before
            it is multiplied with ``Vn``, so it broadcasts against ``Vn.shape[:-1]``.

    Returns:
        ``(kl, w)``: ``w`` has the broadcast shape of ``eta[..., None] * Vn`` and
        ``kl`` is ``w`` reduced over its last axis.
    """
    Vn = jnp.asarray(Vn)
    eta = jnp.asarray(eta)  # allows Python scalars, which cannot be indexed
    logits = -eta[..., None] * Vn
    w = jax.nn.softmax(logits, axis=-1)
    # log_softmax gives log(w) directly, with no additive epsilon biasing it.
    # Entries whose weight is exactly 0 contribute 0 (avoids 0 * -inf = NaN).
    log_w = jax.nn.log_softmax(logits, axis=-1)
    log_w = jnp.where(w > 0, log_w, 0.0)
    m = Vn.shape[-1]
    kl = jnp.sum(w * (log_w + jnp.log(jnp.asarray(m, log_w.dtype))), axis=-1)
    return kl, w


def robust_expectation_dual(Vn, eta):
    """Log-mean-exp ("soft-min") of ``Vn`` over its last axis at temperature ``eta``.

    Returns ``-(1 / eta) * log(mean_i exp(-eta * Vn[..., i]))``. At ``eta == 0``
    the continuous limit, ``mean(Vn, axis=-1)``, is returned.

    Numerics: the value is evaluated as
    ``mean(Vn) - (1 / eta) * log(mean(exp(z)))`` with ``z = -eta * (Vn - mean(Vn))``.
    Because ``z`` has zero mean, ``log(mean(exp(z)))`` is O(eta**2) for small
    ``eta``. It is computed as ``log1p(mean(expm1(z)))`` so that dividing by a tiny
    ``eta`` does not amplify rounding error. When ``z`` is large, a ``logsumexp``
    evaluation is used instead to avoid overflow.

    Args:
        Vn: array of shape ``(..., m)``.
        eta: scalar (Python or JAX) or array broadcastable against
            ``Vn.shape[:-1]``.

    Returns:
        Array with the broadcast batch shape of ``eta`` and ``Vn.shape[:-1]``.
    """
    Vn = jnp.asarray(Vn)
    eta = jnp.asarray(eta)
    m = Vn.shape[-1]
    v_mean = jnp.mean(Vn, axis=-1)
    nonzero = eta != 0
    # Substitute a harmless eta where it is 0 so neither branch below produces
    # inf/NaN (jnp.where would otherwise leak NaN into gradients).
    safe_eta = jnp.where(nonzero, eta, 1.0)
    z = -safe_eta[..., None] * (Vn - v_mean[..., None])
    z_max = jnp.max(z, axis=-1)
    # expm1/log1p path: accurate for small z, and cannot overflow because z < 1.
    use_expm1 = z_max < 1.0
    z_small = jnp.where(use_expm1[..., None], z, 0.0)
    log_mean_small = jnp.log1p(jnp.mean(jnp.expm1(z_small), axis=-1))
    log_mean_large = jsp.special.logsumexp(z, axis=-1) - jnp.log(jnp.asarray(m, z.dtype))
    log_mean = jnp.where(use_expm1, log_mean_small, log_mean_large)
    return jnp.where(nonzero, v_mean - log_mean / safe_eta, v_mean)


def dual_primal_gap(Vn, eta):
    """Dual value minus the Gibbs-weighted mean of ``Vn``.

    The weights come from ``kl_hat_from_values`` at the same ``eta``, and the
    dual value comes from ``robust_expectation_dual``. Both reductions run over
    the last axis of ``Vn``.
    """
    weights = kl_hat_from_values(Vn, eta)[1]
    dual_value = robust_expectation_dual(Vn, eta)
    weighted_mean = jnp.sum(weights * Vn, axis=-1)
    return dual_value - weighted_mean


def project_eta_to_delta(Vn, eta_t, delta, eps=1e-3, iters=25, expand_iters=5):
    """Bisect for the temperature at which the Gibbs-weight KL crosses ``delta``.

    For each batch element, search for ``eta`` such that the KL value returned
    by ``kl_hat_from_values(Vn, eta)`` reaches ``delta``. The bisection assumes
    this KL is non-decreasing in ``eta`` on the search interval.

    Search procedure:
      * The bracket starts at ``[1e-8, max(eta_t, 1)]``.
      * The upper end is doubled up to ``expand_iters`` times while the KL there
        is still below ``delta``.
      * ``iters`` bisection steps then shrink the bracket.

    The UPPER end of the bracket is returned. If the KL at the expanded upper
    end is still below ``delta``, the result stays at that expanded upper end,
    ``max(eta_t, 1) * 2**expand_iters``. No gradient flows through the result.

    Args:
        Vn: values, shape ``(..., m)``.
        eta_t: current temperature; scalar or array broadcastable against
            ``Vn.shape[:-1]``.
        delta: KL target; scalar or array broadcastable against the batch shape.
        eps: currently unused by the search; kept for API compatibility.
        iters: number of bisection steps.
        expand_iters: maximum number of doublings of the upper bracket.

    Returns:
        Array of shape ``broadcast(eta_t.shape, Vn.shape[:-1], shape(delta))``.
    """
    del eps  # NOTE(review): accepted but never used by the search.
    eta_t = jax.lax.stop_gradient(jnp.asarray(eta_t))
    Vn = jax.lax.stop_gradient(jnp.asarray(Vn))
    # fori_loop requires a carry of fixed shape. The KL (and its comparison with
    # delta) broadcasts over Vn's batch dims and delta, so broadcast the bracket
    # up front. This lets a scalar eta_t be used with batched Vn/delta.
    batch_shape = jnp.broadcast_shapes(eta_t.shape, Vn.shape[:-1], jnp.shape(delta))
    eta_t = jnp.broadcast_to(eta_t, batch_shape)
    lo = jnp.zeros_like(eta_t) + 1e-8
    hi = jnp.maximum(eta_t, 1.0)

    def expand(_, h):
        # Double the upper end only where it does not yet reach the KL target.
        kl, _ = kl_hat_from_values(Vn, h)
        return jnp.where(kl < delta, h * 2.0, h)

    hi = jax.lax.fori_loop(0, expand_iters, expand, hi)

    def bisect(_, bracket):
        lo_, hi_ = bracket
        mid = 0.5 * (lo_ + hi_)
        kl_mid, _ = kl_hat_from_values(Vn, mid)
        below_target = kl_mid < delta  # mid is too small -> move lower end up
        return jnp.where(below_target, mid, lo_), jnp.where(below_target, hi_, mid)

    _, hi = jax.lax.fori_loop(0, iters, bisect, (lo, hi))
    return jax.lax.stop_gradient(hi)


def straight_through_eta(eta_t, eta_star):
    """Straight-through estimator: forward value of ``eta_star``, gradient of ``eta_t``.

    ``eta_star`` is wrapped in ``stop_gradient``, and the term
    ``eta_t - stop_gradient(eta_t)`` is zero in the forward pass while carrying
    an identity gradient back to ``eta_t``.
    """
    frozen_target = jax.lax.stop_gradient(eta_star)
    zero_with_grad = eta_t - jax.lax.stop_gradient(eta_t)
    return frozen_target + zero_with_grad