import jax.numpy as jnp


def process_state(state, v_size=0, p_size=0):
    """
    Convert each variable in ``state`` to a JAX array, optionally tiling it
    along up to two new leading batch axes.

    :param state: iterable of variables; each element is passed to
        ``jnp.array``
    :param v_size: if > 0, prepend a new axis 0 of length ``v_size`` by
        repeating the variable ``v_size`` times; ignored when <= 0
    :param p_size: if > 0, prepend a further new axis 0 of length ``p_size``
        (applied after ``v_size``, so it becomes the outermost axis);
        ignored when <= 0
    :return: tuple of arrays, one per input variable, each with shape
        ``(p_size, v_size, *orig_shape)`` where each leading axis is present
        only if its size is > 0
    """

    def _expand(var):
        # Convert one variable and add the requested leading batch axes.
        arr = jnp.array(var)
        # Order matters: v_size is applied first so that p_size ends up as
        # the outermost axis.
        for size in (v_size, p_size):
            if size > 0:
                arr = arr[None, ...].repeat(size, 0)
        return arr

    # Iterate directly so any iterable works, and build the tuple once
    # instead of by repeated concatenation.
    return tuple(_expand(var) for var in state)
