import jax
import jax.numpy as np
from jax import grad, jit, vmap
from jax.scipy.special import logsumexp

# makes a pass through a dataset to update batchnorm statistics
def collect_batchnorm_running_stats(params, state, np_ds, net_apply):
    """Run ``net_apply`` once per batch of ``np_ds``, threading ``state`` through.

    Each element of ``np_ds`` is unpacked as a 2-item pair. The first item is
    converted with ``jax.numpy.array`` and fed to the network; the second item
    (presumably labels; verify against callers) is ignored. The network
    output is discarded. Only the updated state returned by each call is kept
    and passed to the next call.

    Per the function name, the intent is to refresh batch-norm running
    statistics stored in ``state``. Whether ``net_apply`` actually updates
    them depends on how the model was built, which is not visible here.

    Args:
        params: Model parameters, passed unchanged to every ``net_apply`` call.
        state: Initial mutable model state.
        np_ds: Iterable of ``(inputs, other)`` pairs.
        net_apply: Callable invoked as ``net_apply(params, state, None, x)``
            and returning ``(output, new_state)``. The ``None`` third argument
            is assumed to be the RNG slot (e.g. a Haiku apply fn); confirm.

    Returns:
        The state after the final batch, or ``state`` itself if ``np_ds`` is
        empty.
    """
    running_state = state
    for inputs, _unused in np_ds:
        batch = np.array(inputs)
        # Only the state half of the result matters; predictions are dropped.
        _outputs, running_state = net_apply(params, running_state, None, batch)
    return running_state
