import jax.numpy as jnp
import numpy as np
from jax import jit, random
from jax.scipy.special import logsumexp


@jit
def resample_is(log_w, key, ess_threshold=0.5):
    """Conditionally resample importance-sampling particles based on the ESS.

    Uses multinomial resampling (indices are drawn with replacement, in
    proportion to the normalized weights) rather than systematic resampling.

    Args:
        log_w: Unnormalized log importance weights, one per particle. Assumed
            1-D of length ``B`` (``jax.random.choice`` requires ``p`` to have
            shape ``(B,)``).
        key: ``jax.random`` PRNG key used to draw the resampled indices.
        ess_threshold: Resampling is triggered when the effective sample size
            falls strictly below ``ess_threshold * B``. Defaults to 0.5.

    Returns:
        A tuple ``(log_w_new, ind_new, ESS)``:
          * ``log_w_new``: all zeros (equal weights) if resampling was
            triggered, otherwise ``log_w`` unchanged.
          * ``ind_new``: particle indices to gather with. These are the
            resampled indices if triggered, otherwise ``arange(B)``.
          * ``ESS``: effective sample size ``1 / sum(p**2)`` of the normalized
            weights ``p``, measured before any resampling.

    Note:
        If every entry of ``log_w`` is ``-inf``, the normalized weights (and
        hence ``ESS``) are NaN.
    """
    B = jnp.shape(log_w)[0]
    # Normalize in log space via logsumexp to avoid overflow/underflow.
    p = jnp.exp(log_w - logsumexp(log_w))

    # Effective sample size; already a boolean comparison, no where() needed.
    ESS = 1 / jnp.sum(p**2)
    I_resample = ESS < ess_threshold * B

    # Under jit we cannot branch in Python on a traced value, so both
    # candidates are computed (the key is consumed either way) and selected
    # with jnp.where. Arithmetic masking (I * a + (1 - I) * b) would produce
    # 0 * -inf = nan for particles whose log-weight is -inf.
    ind_old = jnp.arange(B)
    ind_resampled = random.choice(key, B, shape=(B,), p=p)
    ind_new = jnp.where(I_resample, ind_resampled, ind_old)
    log_w_new = jnp.where(I_resample, jnp.zeros_like(log_w), log_w)

    return log_w_new, ind_new, ESS

