import jax
import jax.numpy as jnp


def rollout(
    num_envs,
    num_steps,
    env,
    env_params,
    agent_params,
    rng,
    network,
    return_reward=False,
    without_restart=False,
):
    """Roll a policy out in a batch of environments for a fixed number of steps.

    The environments are reset once, then stepped ``num_steps`` times inside a
    ``jax.lax.scan``. At every step the policy is queried on the current
    observation, an action is sampled, and the batched environment is advanced.

    Args:
        num_envs: Number of parallel environments; one PRNG key per environment
            is generated for ``env.reset`` and for every ``env.step`` call.
        num_steps: Number of scan iterations (environment steps) to run.
        env: Batched environment. ``env.reset(keys, env_params)`` must return
            ``(obs, state)`` and ``env.step(keys, state, action, env_params)``
            must return a 5-tuple ``(obs, state, reward, done, info)``.
            Assumes every returned array has ``num_envs`` as its leading axis
            and that ``reward``/``done`` are per-environment vectors -- confirm
            against the environment wrapper in use.
        env_params: Passed through unchanged to ``env.reset`` and ``env.step``.
        agent_params: Policy parameters, passed to the network as
            ``{"params": agent_params}``.
        rng: PRNG key driving resets, action sampling and environment steps.
        network: Object whose ``apply(variables, obs)`` returns either an
            action distribution exposing ``sample(seed=...)`` or a tuple whose
            first element is such a distribution (the second is ignored).
        return_reward: Selects the return format, see *Returns*.
        without_restart: If true, every transition recorded *after* an
            environment has reported ``done`` for the first time is blanked:
            observations and actions are filled with ``-100`` and rewards with
            ``0``. The transition on which ``done`` is first reported is kept.

    Returns:
        If ``return_reward`` is true, a tuple
        ``(observations, actions, rewards, dones)`` with the environment axis
        first and the step axis second. Otherwise ``(observations, actions)``
        exactly as stacked by the scan, i.e. with the step axis first.
    """

    def _mask_finished(values, finished, fill):
        # Append singleton axes so the per-env flags broadcast over any
        # trailing feature dimensions of `values`.
        extra_dims = jnp.ndim(values) - jnp.ndim(finished)
        flags = finished.reshape(finished.shape + (1,) * extra_dims)
        return jnp.where(flags, fill, values)

    # INIT ENV
    rng, _rng = jax.random.split(rng)
    reset_rng = jax.random.split(_rng, num_envs)
    obsv, env_state = env.reset(reset_rng, env_params)
    ever_done = jnp.zeros(shape=(num_envs,), dtype=jnp.bool_)

    # Scalar boolean predicate for lax.select; also accepts a traced flag.
    mask_outputs = jnp.asarray(without_restart, dtype=jnp.bool_)

    # COLLECT TRAJECTORIES
    def _env_step(runner_state, _):
        env_state, last_obs, rng, finished_before = runner_state

        # SELECT ACTION
        rng, _rng = jax.random.split(rng)
        pi = network.apply({"params": agent_params}, last_obs)
        if isinstance(pi, tuple):
            # Networks that return (distribution, extra): keep the distribution.
            pi, _ = pi
        action = pi.sample(seed=_rng)

        # STEP ENV
        rng, _rng = jax.random.split(rng)
        rng_step = jax.random.split(_rng, num_envs)
        obsv, env_state, reward, done, _ = env.step(
            rng_step, env_state, action, env_params
        )

        # Mask with the flag from *before* this step so that the transition
        # which ends the first episode (including its reward) is still kept;
        # only transitions after the first termination are blanked.
        masked_obs = _mask_finished(last_obs, finished_before, -100)
        masked_action = _mask_finished(action, finished_before, -100)
        masked_reward = _mask_finished(
            reward, finished_before, jnp.zeros_like(reward)
        )

        obs_out = jax.lax.select(mask_outputs, masked_obs, last_obs)
        action_out = jax.lax.select(mask_outputs, masked_action, action)
        reward_out = jax.lax.select(mask_outputs, masked_reward, reward)

        ever_done = jnp.logical_or(finished_before, done)
        runner_state = (env_state, obsv, rng, ever_done)
        return runner_state, (done, obs_out, action_out, reward_out)

    rng, _rng = jax.random.split(rng)
    runner_state = (env_state, obsv, _rng, ever_done)
    runner_state, (all_dones, all_obsv, all_actions, all_rewards) = jax.lax.scan(
        _env_step, runner_state, None, length=num_steps
    )

    if return_reward:
        # scan stacks along a leading step axis; move the env axis to the front
        # regardless of how many trailing feature dimensions there are.
        return (
            jnp.swapaxes(all_obsv, 0, 1),
            jnp.swapaxes(all_actions, 0, 1),
            jnp.swapaxes(all_rewards, 0, 1),
            jnp.swapaxes(all_dones, 0, 1),
        )
    return all_obsv, all_actions
