import gym
import dreamerv2.api as dv2
import tensorflow as tf
from common import Config
import envs
import numpy as np
import pathlib
from dreamerv2.common.replay import load_episodes
from examples.run_goal_cond import make_env 
import os
from tqdm import tqdm
import imageio


def _render_dmc_state(env, state):
    """Render one frame after loading ``state`` into the DMC physics state.

    ``state`` fills the leading entries of ``physics.get_state()``; the
    remaining entries are zero-filled. A single all-zeros action is stepped
    before rendering.

    Args:
        env: Wrapped environment; the physics object is reached through
            ``env._env._env.physics``.
        state: 1-D array no longer than the full physics state vector.

    Returns:
        Whatever ``env.render()`` returns (presumably an HxWxC image — TODO
        confirm).
    """
    physics = env._env._env.physics
    pad = physics.get_state().shape[0] - state.shape[0]
    physics.set_state(np.concatenate((state, np.zeros([pad]))))
    # NOTE(review): presumably the zero-action step lets the env refresh
    # derived quantities before rendering — confirm.
    env.step(np.zeros_like(env.action_space.sample()))
    return env.render()


def _render_mtmw_state(env, state):
    """Render one frame with hand/object initial positions taken from ``state``.

    ``state`` is split into three equal chunks with ``np.split``. The first
    chunk is the hand position and the second is the object position. The
    third chunk is ignored. ``np.split`` raises if the length is not
    divisible by 3.

    The env's previous ``hand_init_pos`` and ``init_config['obj_init_pos']``
    are restored, and the env is reset after rendering.
    """
    inner = env._env
    saved_hand_pos = inner.hand_init_pos
    saved_obj_pos = inner.init_config['obj_init_pos']

    hand_pos, obj_pos, _ = np.split(state, 3)
    inner.hand_init_pos = hand_pos
    inner.init_config['obj_init_pos'] = obj_pos
    inner.reset_model()
    img = env.render_offscreen()

    # Revert the environment so the next render starts from the original setup.
    inner.hand_init_pos = saved_hand_pos
    inner.init_config['obj_init_pos'] = saved_obj_pos
    inner.reset()
    return img


def _render_robobin_state(env, state):
    """Render one frame with the hand at ``state[:3]`` and object at ``state[3:]``.

    A single all-zeros action is stepped on the innermost env before
    rendering. Afterwards the following are restored, and the inner env is
    reset:
      - ``init_config['obj_init_pos']`` and ``obj_init_pos`` (from a copy),
      - ``hand_init_pos`` (reset to ``init_config['hand_init_pos']``).
    """
    inner = env._env._env._env._env
    saved_obj_pos = inner.init_config['obj_init_pos'].copy()

    inner.init_config['obj_init_pos'] = state[3:]
    inner.obj_init_pos = state[3:]
    inner.hand_init_pos = state[:3]
    inner.reset_model()
    inner.step(np.zeros(inner.action_space.low.shape))
    img = env.render_offscreen()

    # Revert the environment so the next render starts from the original setup.
    inner.hand_init_pos = inner.init_config['hand_init_pos']
    inner.init_config['obj_init_pos'] = saved_obj_pos
    inner.obj_init_pos = inner.init_config['obj_init_pos']
    inner.reset()
    return img


# Substring of ``config.task`` -> single-state renderer. The order matters:
# the first matching substring wins, mirroring the original if/elif chain.
_STATE_RENDERERS = (
    ('dmc', _render_dmc_state),
    ('mtmw', _render_mtmw_state),
    ('robobin', _render_robobin_state),
)


def _select_renderer(task):
    """Return the single-state renderer for ``task``.

    Raises:
        ValueError: if ``task`` contains none of the supported substrings.
    """
    for key, renderer in _STATE_RENDERERS:
        if key in task:
            return renderer
    supported = ', '.join(key for key, _ in _STATE_RENDERERS)
    raise ValueError(
        f"Unsupported task for rendering: {task!r} (expected one of: {supported})")


def main():
    """Render every stored training episode of one run to an MP4 video.

    Each video frame places the episode's goal image on the left. The image
    of the episode state at that step is on the right. Videos are written to
    ``gifs/<run name>/<index>.mp4``. The newest-loaded episode (last in load
    order) gets index 0.

    Raises:
        ValueError: if ``config.task`` is not a supported task family.
        FileNotFoundError: if the run's ``train_episodes`` directory is missing.
    """
    name = 'lexa_robobin_45bs_dyndist_10p2e_1e-5lr_f32'
    seed = 0
    config = dv2.defaults.update({
        'logdir': f'~/logdir/{name}',
        'seed': seed,
    })
    # Use 'lexa_dmc_proprio' here instead to render DMC runs.
    proprio_defaults = Config(dv2.configs.pop('lexa_robobin_proprio'))
    config = config.update(proprio_defaults)
    config.pop("log_keys_video")
    config = config.parse_flags()

    env = make_env(config)
    message = 'No GPU found. To actually train on CPU remove this assert.'
    assert tf.config.experimental.list_physical_devices('GPU'), message

    # Fail fast instead of writing empty videos for an unknown task family.
    render_state = _select_renderer(config.task)

    episode_dir = pathlib.Path(f"~/logdir/{name}/train_episodes/").expanduser()
    if not episode_dir.is_dir():
        raise FileNotFoundError(f"Episode directory not found: {episode_dir}")
    # NOTE: the directory is named "gifs" but the files written are MP4s.
    gif_dir = f"gifs/{name}"
    os.makedirs(gif_dir, exist_ok=True)

    # NOTE(review): capacity=150*100 presumably caps how much replay data is
    # loaded — confirm against load_episodes.
    episodes = list(load_episodes(episode_dir, capacity=150 * 100).items())
    progress = tqdm(reversed(episodes), total=len(episodes))
    for index, (_, ep_dict) in enumerate(progress):
        all_qpos = ep_dict[config.state_key]
        goal = ep_dict[config.goal_key][0]
        env.reset()

        goal_img = render_state(env, goal)
        frames = []
        for qpos in all_qpos:
            img = render_state(env, qpos)
            # Goal on the left, current state on the right (concatenate along width).
            frames.append(np.concatenate([goal_img, img], 1))

        imageio.mimwrite(os.path.join(gif_dir, f"{index}.mp4"), frames)

if __name__ == '__main__':
    # Only run the rendering pipeline when executed as a script, not on import.
    main()