def split_batch_pmap(batch, num_devices):  # ADDED
    """Split the leading axis of every array in ``batch`` across devices.

    Each value of shape ``(B, *rest)`` is reshaped to
    ``(num_devices, B // num_devices, *rest)``.

    The mapping is modified in place (each key is re-assigned) and the same
    object is returned. Both ``split_batch_pmap(batch, n)`` and
    ``batch = split_batch_pmap(batch, n)`` therefore work.

    Args:
        batch: Mutable mapping of name -> array-like. Each value must expose
            ``.shape`` and ``.reshape``, e.g. NumPy arrays.
        num_devices: Number of slices to split the leading axis into.

    Returns:
        The same ``batch`` object, with every value reshaped.

    Raises:
        Whatever the array type's ``reshape`` raises when the leading axis
        is not divisible by ``num_devices``. For NumPy this is ``ValueError``.
    """
    # Snapshot the keys so re-assigning values never interacts with iteration.
    for key in list(batch.keys()):
        value = batch[key]
        shape = value.shape
        if shape and num_devices > 0 and shape[0] % num_devices == 0:
            # Give the per-device size explicitly. Letting reshape infer it
            # with -1 fails for zero-size arrays such as (B, 0), because the
            # missing dimension is then ambiguous.
            per_device = shape[0] // num_devices
        else:
            # Keep the original inference. reshape either succeeds (e.g. a 0-d
            # value with num_devices == 1) or raises its usual error.
            per_device = -1
        batch[key] = value.reshape((num_devices, per_device, *shape[1:]))
    return batch
