import tensorflow_datasets as tfds
import jax
import jax.numpy as np
import numpy as onp
import numpyro.distributions as npdist


def load_dataset(dset, mode='train', binarize=True):
	"""Load one split of a TFDS image dataset and rescale pixels by 1/255.

	Args:
		dset: dataset name forwarded to ``tfds.load``.
		mode: which slice to load:
			'train' -> first 50000 examples of the TFDS 'train' split,
			'val'   -> last 10000 examples of the TFDS 'train' split,
			'test'  -> first 10000 examples of the TFDS 'test' split.
		binarize: currently unused; kept so existing callers keep working.
			Note that inside this function it shadows the module-level
			``binarize`` helper.

	Returns:
		The split's 'image' feature as a NumPy array divided by 255.

	Raises:
		ValueError: if ``mode`` is not one of 'train', 'val', 'test'.
	"""
	splits = {
		'train': 'train[:50000]',
		# Tail of the train split. It is disjoint from the 'train' slice only
		# when the TFDS train split has at least 60000 examples (e.g. MNIST);
		# confirm for other datasets.
		'val': 'train[-10000:]',
		'test': 'test[:10000]',
	}
	# Validate before touching TFDS, so a typo fails fast with a clear message
	# instead of an unbound-variable error further down.
	if mode not in splits:
		raise ValueError(f"mode must be one of {sorted(splits)}, got {mode!r}")

	# batch_size=-1 asks TFDS for the whole split as a single batch.
	ds = tfds.load(dset, split=splits[mode], batch_size=-1)
	data = tfds.as_numpy(ds)['image']
	print(data.shape)  # e.g. (N, 28, 28, 1) for MNIST-like datasets
	return data / 255.


@jax.jit
def binarize(rng_key, x):
	"""Stochastically binarize ``x`` by drawing from Bernoulli(probs=x).

	Args:
		rng_key: JAX PRNG key that drives the sampling.
		x: array of probabilities. Each entry is used as the chance of
			drawing a 1, so values should already lie in [0, 1]. Presumably
			image batches shaped (bs, 28, 28, 1); confirm against callers.

	Returns:
		The Bernoulli draw multiplied by ``1.``, giving a floating-point array
		of 0s and 1s (presumably one draw per entry of ``x``; confirm).
	"""
	bernoulli = npdist.Bernoulli(probs=x)
	draws = bernoulli.sample(rng_key)
	# Multiplying by a float literal promotes the draws to a floating dtype
	# (they presumably come back as integers; confirm).
	return draws * 1.






