from typing import Callable
import jax
import jax.numpy as jnp
from rapo_math import to_valid_prob


def sample_snext_weighted_allk(step_allK: Callable, probs: jnp.ndarray, key, s_batch, a_batch, m: int):
    """Draw ``m`` next states per batch row from the per-component predictions.

    ``step_allK(s_batch, a_batch)`` must return a 3-D array, read here as
    ``(batch, K, state_dim)``. For each of the ``batch`` rows, ``m``
    component indices are drawn independently (with replacement) from the
    distribution ``to_valid_prob(probs)`` over the ``K`` components. The
    corresponding predicted states are then gathered.

    Returns:
        A tuple ``(next_states, key)``. ``next_states`` has shape
        ``(batch, m, state_dim)``. ``key`` is the advanced PRNG key.
    """
    candidates = step_allK(s_batch, a_batch)
    n_batch, n_components, state_dim = candidates.shape

    key, subkey = jax.random.split(key)
    choice = jax.random.choice(
        subkey,
        n_components,
        shape=(n_batch, m),
        p=to_valid_prob(probs),
    )

    # take_along_axis needs the index array to match the rank of candidates,
    # so the chosen component index is repeated across the state dimension.
    gather_idx = jnp.broadcast_to(jnp.expand_dims(choice, -1), (n_batch, m, state_dim))
    next_states = jnp.take_along_axis(candidates, gather_idx, axis=1)
    return next_states, key


def tilt_theta_weights_allk(step_allK, value_apply, w_prior, key, s, a, rho, subsample: int,
                            beta_max: float = 1e6, num_iters: int = 30):
    """Exponentially tilt mixture weights away from high-value components under a KL budget.

    ``subsample`` rows of ``(s, a)`` are drawn uniformly with replacement.
    Every component's predicted next state is scored with ``value_apply``.
    The scores are averaged over the subsample to give one score ``H[k]``
    per component. The tilted weights are

        w_beta ∝ w_prior * exp(-beta * H)

    A bisection over ``beta`` in ``[0, beta_max]`` (``num_iters`` steps)
    searches for the largest ``beta`` with ``KL(w_beta || w_prior) < rho``.

    Args:
        step_allK: Callable ``(s_sub, a_sub) -> S_all``. Its output is read
            as ``(subsample, K, D)``.
        value_apply: Callable applied to a ``(subsample*K, D)`` array. Its
            output must contain ``subsample*K`` values, because it is
            reshaped to ``(subsample, K)``.
        w_prior: Prior weights over the ``K`` components. Assumed
            non-negative with at least one positive entry (see make_w).
        key: JAX PRNG key.
        s, a: Batched inputs indexed along axis 0.
        rho: KL budget.
        subsample: Number of rows to draw from ``(s, a)``.
        beta_max: Upper end of the bisection bracket for ``beta``.
        num_iters: Number of bisection steps.

    Returns:
        ``(w_new, key)``. ``w_new`` is the weight vector at the largest
        feasible ``beta`` found. It falls back to the normalized prior
        (``beta = 0``) if no midpoint satisfied the budget. ``key`` is the
        advanced PRNG key.
    """
    w_prior = jnp.asarray(w_prior)
    N = s.shape[0]  # static Python int, even under jit
    key, sub = jax.random.split(key)
    # randint already yields indices in [0, N); max(N, 1) keeps the range
    # non-empty when N == 0 (all indices are then 0).
    idx = jax.random.randint(sub, (subsample,), 0, max(N, 1))
    s_sub, a_sub = s[idx], a[idx]
    S_all = step_allK(s_sub, a_sub)  # (S,K,D)
    V_all = value_apply(S_all.reshape(-1, S_all.shape[-1])).reshape(S_all.shape[0], S_all.shape[1])
    H = jnp.mean(V_all, axis=0)  # (K,)

    # Work in log space: w_prior * exp(-beta*H) underflows to all zeros (H > 0)
    # or overflows to inf/NaN (H < 0) at the large betas the bisection probes.
    # Zero-prior components get log-weight -inf, so they stay exactly zero.
    # The inner where avoids log(0), keeping gradients finite.
    positive = w_prior > 0
    log_prior = jnp.where(positive, jnp.log(jnp.where(positive, w_prior, 1.0)), -jnp.inf)

    def make_w(beta):
        # softmax subtracts the max logit, so this is numerically stable for
        # any beta. It requires at least one finite logit (a positive prior entry).
        w = jax.nn.softmax(log_prior - beta * H)
        return to_valid_prob(w)

    def scan_body(carry, _):
        lo, hi, w_best = carry
        mid = 0.5 * (lo + hi)
        w_mid = make_w(mid)
        kl = jnp.sum(w_mid * (jnp.log(w_mid + 1e-12) - jnp.log(w_prior + 1e-12)))
        # Inside the budget: remember these weights and try a stronger tilt.
        # Otherwise shrink the bracket from above.
        feasible = kl < rho
        return (jnp.where(feasible, mid, lo),
                jnp.where(feasible, hi, mid),
                jnp.where(feasible, w_mid, w_best)), None

    w0 = make_w(0.0)
    (_, _, w_new), _ = jax.lax.scan(scan_body, (0.0, beta_max, w0), None, length=num_iters)
    return w_new, key