def get_batch(x, step, batch_size):
    """
    Simple batching mechanism.

    Slices a contiguous window of up to ``batch_size`` points out of ``x``,
    starting at offset ``(step * batch_size) % len(x)``. Only the start offset
    wraps around modulo the number of points, so ``step`` may keep increasing
    across epochs.

    The window itself does NOT wrap. When the start offset lies within
    ``batch_size`` of the end of ``x`` (or ``batch_size > len(x)``), the
    returned batch is truncated and holds fewer than ``batch_size`` points.
    If ``len(x)`` is not a multiple of ``batch_size``, successive passes over
    the data therefore begin at shifted offsets.

    Args:
        x           input tensor/array containing all the data points, e.g. of
                    shape (n_points, ...); must support ``len`` and
                    ``x[start:stop, ...]`` slicing
        step        gradient-descent-based optimization loop step index
        batch_size  desired batch size (positive integer)
    Returns:
        batch       of shape (m, ...) with ``1 <= m <= batch_size``
    Raises:
        ValueError  if ``x`` is empty or ``batch_size`` is not positive
    """
    n_points = len(x)
    if n_points == 0:
        # Without this guard the modulo below fails with an opaque ZeroDivisionError.
        raise ValueError("get_batch: x must contain at least one data point")
    if batch_size < 1:
        # A zero or negative size would silently yield an empty/garbled slice.
        raise ValueError(f"get_batch: batch_size must be positive, got {batch_size}")
    # start is always in [0, n_points), so the slice below is never empty.
    start = (step * batch_size) % n_points
    return x[start : start + batch_size, ...]
