import jax
import jax.numpy as np

from jax import grad, jit, vmap, pmap, value_and_grad
from jax import random

from jax.tree_util import tree_multimap, tree_map

def get_single_copy(inputs):
    """Select index ``0`` along the leading axis of every leaf in ``inputs``.

    Described originally as converting "from batch to a single device":
    presumably the leading axis of each leaf is a device/replica axis (e.g. as
    built by ``manual_pmap_tree``) -- confirm against callers.

    Args:
        inputs: a pytree whose leaves support ``leaf[0]`` indexing.

    Returns:
        A pytree with the same structure where each leaf is ``leaf[0]``.
    """
    def _first_replica(leaf):
        return leaf[0]

    return tree_map(_first_replica, inputs)

def manual_pmap_tree(inputs, n_devices):
    """Stack ``n_devices`` identical copies of every leaf along a new axis 0.

    Each leaf of shape ``S`` becomes an array of shape ``(n_devices, *S)``.
    Going by the name, this presumably builds replicated inputs for ``pmap``
    by hand -- confirm against callers.

    Args:
        inputs: a pytree whose leaves are arrays or array-like scalars.
        n_devices: number of copies to stack along the new leading axis.

    Returns:
        A pytree with the same structure whose leaves are the stacked copies.
    """
    def _replicate(leaf):
        # Convert first so Python/NumPy scalar leaves (e.g. a float
        # hyperparameter kept in the tree) can be replicated as well;
        # indexing the raw scalar with ``[None]`` raises TypeError.
        leaf = np.asarray(leaf)
        return np.repeat(leaf[None], n_devices, axis=0)

    return tree_map(_replicate, inputs)
