import frontier_maze
import gym
import imageio
import ipdb
import pathlib
import os
from tqdm import tqdm
import numpy as np

import dreamerv2.api as dv2
from dreamerv2.common.replay import load_episodes

# Name of the training run; it determines both the episode directory that is
# read and the folder the GIFs are written to.
RUN_NAME = "dv1_p2e_3models8bs_seed0"

# Environment instance used to render the stored observations.
env = gym.make('normalizeddistractormaze2d-v1')
# Helper taken from the unwrapped env; presumably inverts the observation
# normalization so stored observations become raw simulator state.
denorm_ob = env.unwrapped._denorm_ob

# NOTE: `dir` shadows the builtin of the same name; the name is kept because
# later code in this script reads it.
dir = pathlib.Path(f"~/logdir/{RUN_NAME}/train_episodes/").expanduser()
# Explicit check instead of `assert`, which is stripped under `python -O`.
if not dir.is_dir():
    raise NotADirectoryError(f"episode directory not found: {dir}")
gif_dir = f"{RUN_NAME}_gifs"
os.makedirs(gif_dir, exist_ok=True)

# Load stored episodes from the run directory; `capacity` bounds how much
# load_episodes returns. Omit `capacity` to use load_episodes' default instead.
episodes = load_episodes(dir, capacity=100 * 10)

# enumerate() supplies the GIF index; tqdm wraps the dict view directly so the
# progress bar still knows the total number of episodes.
for ep_idx, ep_dict in enumerate(tqdm(episodes.values())):
    # Per the original note, each stored observation is (x, y, dx, dy), even
    # though it lives under the 'image' key.
    all_obs = ep_dict['image']
    # To render only high-return episodes, skip here when
    # np.sum(ep_dict['reward']) is below the desired threshold.

    # Re-create the rendering frame by frame from the stored observations.
    env.reset()
    all_img = []
    for ob in all_obs:
        raw_ob = denorm_ob(ob)
        # The raw observation is just the agent's state, so it can be written
        # straight into the simulator: first two entries, then the remainder.
        env.unwrapped.set_state(raw_ob[:2], raw_ob[2:])
        all_img.append(env.render("rgb_array"))
    imageio.mimwrite(os.path.join(gif_dir, f"{ep_idx}.gif"), all_img)