import copy
import logging
import os

import h5py
import jax
import jax.numpy as jnp
import numpy as np
from tqdm.auto import trange


def batch_rollout(
    rng,
    data_aug_rng,
    env,
    policy_fn,
    transform_obs_fn,
    transform_action_fn,
    episode_length=600,
    log_interval=None,
    num_episodes=1,
):
    """Roll out ``policy_fn`` in ``env`` and report averaged metrics.

    Args:
        rng: Passed unchanged as ``rngs`` to ``policy_fn`` at every step.
            NOTE(review): the same value is reused for every step and
            episode; confirm that the policy does not rely on fresh
            randomness per call.
        data_aug_rng: RNG state threaded through ``transform_obs_fn``, which
            returns the updated state alongside the transformed value.
        env: Environment exposing ``reset() -> obs`` and
            ``step(action) -> (next_obs, reward, done, info)``.
        policy_fn: Callable invoked as ``policy_fn(inputs=..., rngs=...)``;
            element ``[0]`` of its (host-fetched) result is used as the action.
        transform_obs_fn: ``None`` or a callable ``(value, rng) -> (value, rng)``
            applied to every entry of ``obs["image"]`` on a deep copy of the
            observation, so the environment's own observation is not mutated.
        transform_action_fn: Callable applied to the action before it is
            handed to ``env.step``.
        episode_length: Each episode runs for at most ``episode_length + 1``
            environment steps.
        log_interval: If truthy, log step/done/reward every ``log_interval``
            steps.
        num_episodes: Number of episodes to roll out.

    Returns:
        A tuple ``(metric, info, videos)`` where ``metric`` holds the mean
        ``"return"`` and mean ``"episode_length"`` over the episodes, ``info``
        is the last ``info`` returned by ``env.step`` (an empty dict if no
        step was taken), and ``videos`` collects every non-``None``
        ``info["vid"]`` found at the end of an episode.
    """
    # Accumulators are summed over all episodes and averaged at the end.
    reward = jnp.zeros(1, dtype=jnp.float32)
    ep_lens = jnp.zeros(1, dtype=jnp.float32)

    # Pre-bind so the function still returns cleanly when no step is taken
    # (e.g. num_episodes == 0) instead of raising NameError.
    info = {}
    videos = []
    for ep in trange(num_episodes, desc="rollout", ncols=0):
        done = jnp.zeros(1, dtype=jnp.int32)
        obs = env.reset()
        steps_taken = 0
        terminated = False

        for t in trange(episode_length + 1, desc=f"episode {ep}", ncols=0, leave=False):
            done_prev = done

            if transform_obs_fn is not None:
                # Work on a copy so the augmentation never leaks back into
                # the observation owned by the environment.
                input_obs = copy.deepcopy(obs)
                for key, val in input_obs["image"].items():
                    input_obs["image"][key], data_aug_rng = transform_obs_fn(val, data_aug_rng)
            else:
                input_obs = obs

            action = jax.device_get(policy_fn(inputs=input_obs, rngs=rng))[0]
            action = transform_action_fn(action)

            next_obs, _reward, done, info = env.step(action)
            steps_taken = t + 1
            # Only count reward collected before the episode was flagged done.
            reward = reward + _reward * (1 - done_prev)
            # Once done, stay done for the remainder of the episode.
            done = jnp.logical_or(done, done_prev).astype(jnp.int32)
            if log_interval and t % log_interval == 0:
                logging.info("step: %d done: %s reward: %s", t, done, reward)

            if jnp.all(done):
                ep_lens = ep_lens + info["episode_len"]
                terminated = True
                break

            obs = next_obs

        if not terminated:
            # The step budget ran out before the env reported done; count the
            # steps actually executed so the mean length is not biased to 0.
            ep_lens = ep_lens + steps_taken

        # Tolerate environments that do not record a video.
        vid = info.get("vid")
        if vid is not None:
            videos.append(vid)

    # Avoid dividing by zero when no episode was requested.
    denom = max(num_episodes, 1)
    metric = {
        "return": reward.astype(jnp.float32) / denom,
        "episode_length": ep_lens.astype(jnp.float32) / denom,
    }
    return metric, info, videos
