import jax
import jax.numpy as jnp
import numpy as np
import math


def build_dataloader(images, labels, batch_size,
                     rng=None, shuffle=False, transform=None):
    """Yield fixed-size batches that cover the whole dataset exactly once.

    The dataset is zero-padded up to a multiple of ``batch_size`` so every
    yielded batch has the same leading dimension; a boolean ``'marker'``
    array flags which rows are real examples (True) and which are padding
    (False).

    Args:
        images: Array (numpy or jax) indexed along its first axis.
        labels: Array with the same leading length as ``images``.
        batch_size: Number of rows per yielded batch; must be >= 1.
        rng: Optional ``jax.random`` key. Required when ``shuffle`` is True.
            When given together with ``transform``, fresh per-example keys
            are derived from it for every batch.
        shuffle: If True, permute the dataset once (using ``rng``) before
            batching.
        transform: Optional callable ``transform(keys, images) -> images``
            applied to each batch's images. ``keys`` holds ``batch_size``
            keys split from ``rng``, or is None when ``rng`` is None.

    Yields:
        dict with keys ``'images'``, ``'labels'`` and ``'marker'``, each a
        ``jnp`` array whose leading dimension is ``batch_size``.

    Raises:
        ValueError: If ``batch_size`` < 1 or ``images`` and ``labels`` differ
            in length.
    """
    num_examples = len(images)
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if len(labels) != num_examples:
        raise ValueError(
            "images and labels differ in length: "
            f"{num_examples} vs {len(labels)}")

    if shuffle:
        # Split before permuting so the permutation key is never reused to
        # derive augmentation keys (a JAX key must be consumed only once).
        rng, perm_rng = jax.random.split(rng)
        order = np.asarray(jax.random.permutation(perm_rng, num_examples))
        images = images[order]
        labels = labels[order]

    num_batches = (num_examples + batch_size - 1) // batch_size  # ceil div
    total = num_batches * batch_size
    pad = total - num_examples

    def _pad(x):
        # Append `pad` all-zero rows so `x` splits evenly into batches.
        x = np.asarray(x)
        return np.concatenate([x, np.zeros((pad, *x.shape[1:]), x.dtype)])

    padded_images = _pad(images)
    padded_labels = _pad(labels)
    padded_marker = np.arange(total) < num_examples  # True for real rows

    # Batches are contiguous, so plain slices replace fancy-index gathers.
    for start in range(0, total, batch_size):
        stop = start + batch_size
        batch = {'images': jnp.array(padded_images[start:stop]),
                 'labels': jnp.array(padded_labels[start:stop]),
                 'marker': jnp.array(padded_marker[start:stop]),}
        if transform is not None:
            sub_rng = None
            if rng is not None:
                # Advance the carried key and derive this batch's keys from a
                # separate child, so no key is used for two purposes.
                rng, batch_rng = jax.random.split(rng)
                sub_rng = jax.random.split(batch_rng, batch_size)
            batch['images'] = transform(sub_rng, batch['images'])
        yield batch
