import matplotlib.pyplot as plt


def batch_dataset(key, dataset, batch_size, sshape, ashape, rshape):
    """Split ``dataset`` into consecutive batches of at most ``batch_size`` items.

    Items keep their original order (no shuffling is performed). The final
    batch holds the remainder and may be shorter than ``batch_size``.

    Args:
        key: Currently unused. NOTE(review): presumably intended as a PRNG key
            for shuffling before batching (shuffling is disabled) — confirm
            before relying on it.
        dataset: Any sliceable object supporting ``len`` (e.g. a list or an
            array); each batch is ``dataset[i:i + batch_size]``.
        batch_size: Maximum number of items per batch; must be >= 1.
        sshape, ashape, rshape: Currently unused; kept for caller compatibility.

    Returns:
        A list of slices of ``dataset``; empty when ``dataset`` is empty.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        # range() already rejects a step of 0, but a negative step would
        # silently produce no batches at all; fail loudly in both cases.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    return [dataset[i:i + batch_size] for i in range(0, len(dataset), batch_size)]

def plot_dataset(dataset, plot_range, reward_threshold, bins=100):
    """Show a histogram of column 0 of ``dataset[0]`` with a threshold marker.

    Args:
        dataset: Indexable whose first element is a 2-D array; column 0 of
            that array is histogrammed. NOTE(review): column 0 presumably
            holds the state ``s`` (per the x-axis label) — confirm with caller.
        plot_range: Passed directly to ``plt.xlim`` (e.g. a ``(left, right)``
            pair).
        reward_threshold: x-position of the dashed vertical reference line.
        bins: Number of histogram bins (default 100, the previous fixed value).
    """
    values = dataset[0][:, 0]
    # plt.hist returns (counts, edges, patches). With an integer ``bins`` the
    # edges are evenly spaced over the data range, so one adjacent pair gives
    # the width of the bins actually drawn.
    _, edges, _ = plt.hist(values, bins=bins)
    bin_width = edges[1] - edges[0]
    plt.axvline(reward_threshold, linestyle='--', color='black', linewidth=1)
    plt.xlim(plot_range)
    plt.title(f'Dataset distribution (bin size: {bin_width:.3f})')
    plt.xlabel('s')
    plt.ylabel('count')
    plt.show()
