import jax
import jax.numpy as jnp


def simulate_episode(env, params, apply_fn, key, eps, render=False):
    """Roll out one epsilon-greedy episode and return the summed reward.

    Args:
        env: Environment exposing ``reset(key=...)``, ``step(action)`` returning
            ``(next_state, reward, terminated, truncated)``, an ``n_actions``
            attribute and, when ``render`` is true, ``render()``.
        params: Parameters forwarded unchanged to ``apply_fn``.
        apply_fn: Called as ``apply_fn(params, state)``; the argmax of its
            output is the greedy action (presumably per-action values).
        key: JAX PRNG key driving both the reset and the exploration draws.
        eps: Probability of taking a uniformly random action at each step.
        render: If true, call ``env.render()`` before every step.

    Returns:
        The sum of the rewards returned by ``env.step`` until the episode is
        terminated or truncated.
    """
    # Give the reset its own key: reusing a consumed key for the exploration
    # stream would correlate the two sources of randomness.
    key, reset_key = jax.random.split(key)
    state = env.reset(key=reset_key)
    episode_return = 0
    terminated = truncated = False
    while not (terminated or truncated):
        if render:
            env.render()
        key, explore_key, action_key = jax.random.split(key, 3)
        if jax.random.uniform(explore_key) < eps:
            # Convert to a Python int so both branches hand env.step the same type.
            action = int(jax.random.choice(action_key, env.n_actions))
        else:
            action = int(jnp.argmax(apply_fn(params, state)))
        state, reward, terminated, truncated = env.step(action)
        episode_return += reward
    return episode_return


def simulate_n_episodes(env, params, apply_fn, key, n_episodes, eps, render=False):
    """Run ``n_episodes`` episodes and collect their returns.

    A fresh subkey is split off the carried key for each episode, so every
    rollout receives its own randomness.

    Args:
        env, params, apply_fn, eps, render: Forwarded unchanged to
            ``simulate_episode``.
        key: JAX PRNG key from which the per-episode keys are derived.
        n_episodes: Number of episodes to run.

    Returns:
        A ``jnp.array`` of the per-episode returns, in the order they ran.
    """

    def _episode_keys(rng):
        # Yield one subkey per episode, advancing the carried key each time.
        for _ in range(n_episodes):
            rng, episode_key = jax.random.split(rng)
            yield episode_key

    returns = [
        simulate_episode(env, params, apply_fn, episode_key, eps, render)
        for episode_key in _episode_keys(key)
    ]
    return jnp.array(returns)
