import numpy as np
import torch
import pickle
import os
import time
from PIL import Image

import sys
sys.path.append("../EinsumNetworks/src")

from EinsumNetwork import Graph, EinsumNetwork
import eiutils
import datasets


def make_shuffled_batch(N, batch_size):
    """Split a random permutation of ``range(N)`` into mini-batches of indices.

    Args:
        N: Number of items to index (non-negative int).
        batch_size: Maximum number of indices per batch; must be >= 1.

    Returns:
        A list of 1-D integer numpy arrays that together contain every index
        in ``0..N-1`` exactly once, in shuffled order. Every batch has exactly
        ``batch_size`` elements except possibly the last, which holds the
        remaining ``N % batch_size`` indices. Returns an empty list when
        ``N == 0``.

    Raises:
        ValueError: If ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    idx = np.random.permutation(N)
    # Slice the permutation directly rather than using np.array_split over
    # the full batches. When N < batch_size (or N == 0), array_split would be
    # asked for zero sections and raise.
    return [idx[start:start + batch_size] for start in range(0, N, batch_size)]