def run_policy_on_env(policy_fn,
                      env,
                      truncate_episode_at=None,
                      first_obs=None):
    if first_obs is None:
        obs = env.reset()
    else:
        obs = first_obs

    trajectory = []
    step_num = 0
    while True:
        act = policy_fn(obs)
        next_obs, rew, done, _ = env.step(act)
        trajectory.append((obs, act, rew, done))
        obs = next_obs
        step_num += 1
        if (done or (truncate_episode_at is not None
                     and step_num >= truncate_episode_at)):
            break
    return trajectory
