import jax
import jax.numpy as np

from jax import random
from jax import grad, jit, vmap, value_and_grad
from jax.experimental import optimizers
from jax.util import partial, safe_zip, safe_map, unzip2
from jax.tree_util import tree_map, tree_multimap, tree_flatten, tree_unflatten, tree_reduce
from jax.lax import fori_loop

from jax.scipy import stats

import operator

# Deliberately shadow the builtins `map`/`zip` with JAX's "safe" variants
# (imported from jax.util above). These are expected to reject arguments of
# unequal length and to return lists rather than lazy iterators, so every
# later use of `map`/`zip` in this module gets that behavior.
# NOTE(review): jax.util is internal API; confirm safe_map/safe_zip still
# exist in the installed JAX version.
map = safe_map
zip = safe_zip

@jit
def sample_weights_diag(rng, bnn_params, scale=1.0):
    """Draw one sample of the network weights from a diagonal Gaussian.

    Every leaf of the parameter pytree is sampled independently as
    ``mean + scale * sqrt(var) * eps``. Here ``eps`` is standard-normal noise
    with the leaf's shape. Each leaf uses its own subkey split from ``rng``.

    Args:
      rng: JAX PRNG key.
      bnn_params: pair ``(mean_params, var_params)`` of pytrees with the same
        structure. ``var_params`` holds per-weight variances: their square
        root scales the noise, so they are assumed to be non-negative.
      scale: multiplier applied to the standard deviation. With ``0.0`` (and
        finite, non-negative variances) the means are returned.

    Returns:
      A pytree with the structure of ``mean_params`` holding the sampled
      weights.

    Raises:
      ValueError: if ``mean_params`` and ``var_params`` have different pytree
        structures. Under ``jit`` this is raised at trace time.
    """
    mean_params, var_params = bnn_params
    flat_means, tree = tree_flatten(mean_params)
    flat_vars, var_tree = tree_flatten(var_params)
    # Use an explicit check instead of `assert`. An assert is stripped under
    # `python -O`, which would let mismatched trees be paired leaf-by-leaf
    # silently. Tree structures are static under `jit`, so this check runs
    # once, at trace time.
    if tree != var_tree:
        raise ValueError(
            "mean and variance parameters must share a pytree structure; "
            f"got {tree} and {var_tree}"
        )
    # One independent subkey per leaf so each leaf gets its own noise draw.
    rngs = random.split(rng, len(flat_means))
    samples = [
        mean + scale * np.sqrt(var) * random.normal(key, shape=mean.shape)
        for key, mean, var in zip(rngs, flat_means, flat_vars)
    ]
    return tree_unflatten(tree, samples)
