import os
import tensorflow as tf
import dreamerv2.api as dv2
import common
from common import Config
import numpy as np
import matplotlib.pyplot as plt
import sys
import collections
from collections import defaultdict
from envs.sibrivalry.ant_maze import AntMazeEnvFullDownscale
from tqdm import tqdm, trange
import pickle
import imageio
import pathlib
from functools import partial
from run_goal_cond import make_env
import ruamel.yaml as yaml
import gc_agent


# Defaults reproduce the locations this script was originally hard-wired to.
_DEFAULT_CONFIG_PATH = '/home/anonymous/logdir/umazefulldownscale_ar1_dyndist_sgp_32_32_mlp_eval_env_20/config.yaml'
_DEFAULT_VIDEO_DIR = '/home/anonymous/Videos'


def _render_episode(env, ep):
  """Render the goal pose and every visited state of one episode.

  The simulator state is overwritten from the recorded vectors and the
  environment is rendered off-screen after each overwrite.

  Args:
    env: environment exposing `maze.wrapped_env` (with `set_state` and `sim`)
      and `render(mode='rgb_array')`.
    ep: dict of per-step arrays; reads `ep['goal']` and `ep['observation']`.

  Returns:
    `(goals, executions)`: two one-element lists holding the goal image with
    a leading singleton axis and the stacked frames with a leading singleton
    axis, respectively.
  """
  sim_env = env.maze.wrapped_env
  goal_state = ep['goal'][0]
  print('Goal: ', goal_state[:2])
  # Use zero velocities for the goal pose, exactly as for the per-step frames
  # below; the original passed a slice of the goal vector as velocities.
  # NOTE(review): presumably velocities do not influence the rendered image.
  sim_env.set_state(goal_state[:15], np.zeros_like(goal_state[:14]))
  sim_env.sim.forward()
  print('After forward: ', sim_env.sim.data.qpos[:2])
  goal_img = env.render(mode='rgb_array')

  frames = []
  for obs in ep['observation']:
    # First 15 entries are used as positions; velocities are zeroed.
    sim_env.set_state(obs[:15], np.zeros_like(obs[:14]))
    sim_env.sim.forward()
    frames.append(env.render(mode='rgb_array'))

  goals = [goal_img[None]]  # 1 x H x W x C (assuming render gives H x W x C)
  executions = [np.stack(frames, 0)[None]]  # 1 x T x H x W x C
  return goals, executions


def _make_eval_fn(config, video_dir):
  """Build the per-goal evaluation routine for the configured task.

  Args:
    config: configuration object; only `config.task` is read here.
    video_dir: directory that receives one `new_goal_video{idx}.mp4` per goal.

  Returns:
    A callable `evaluate_all_goals(driver, eval_policy)`.

  Raises:
    NotImplementedError: if `config.task` is not one of the supported tasks.
  """
  if config.task not in ('umazefull', 'umazefulldownscale'):
    # Previously this fell through to an unbound local and raised an opaque
    # UnboundLocalError.
    raise NotImplementedError(
        f'No goal-video evaluation is defined for task {config.task!r}.')
  video_dir = pathlib.Path(video_dir).expanduser()

  def evaluate_all_goals(driver, eval_policy):
    """Roll out one episode per goal and write a goal/rollout video for each."""
    env = driver._envs[0]
    # Fall back to 5 goal indices when the env reports no goals.
    num_goals = len(env.get_goals()) or 5
    video_dir.mkdir(parents=True, exist_ok=True)
    for idx in range(num_goals):
      env.set_goal_idx(idx)
      driver(eval_policy, episodes=1)
      # Episode data of the first env, converted to a dict of stacked arrays.
      ep = driver._eps[0]
      ep = {k: driver._convert([t[k] for t in ep]) for k in ep[0]}
      goals, executions = _render_episode(env, ep)
      execution = np.concatenate(executions, 0)  # 1 x T x H x W x C
      goal = np.stack(goals, 0)  # 1 x 1 x H x W x C
      goal = np.repeat(goal, execution.shape[1], 1)  # 1 x T x H x W x C
      # Stack goal above rollout along the image-height axis.
      gc_video = np.concatenate([goal, execution], -3)
      # Index the leading singleton axis explicitly; a bare squeeze() would
      # also drop the time axis for a one-step episode.
      video = list(gc_video[0])
      imageio.mimwrite(
          os.path.join(str(video_dir), f'new_goal_video{idx}.mp4'),
          video, macro_block_size=1)

  return evaluate_all_goals


def main(config_path=_DEFAULT_CONFIG_PATH, video_dir=_DEFAULT_VIDEO_DIR):
  """Load a trained goal-conditioned agent and record one video per goal.

  General workflow:
  1. Point `config_path` at the `config.yaml` of the run to evaluate and
     `video_dir` at the directory that should receive the videos.
  2. Go to configs.yaml and change hyperparameters.

  Args:
    config_path: path of the YAML config to load; `config.logdir` from it
      determines where replay data and `variables.pkl` are read from.
    video_dir: output directory for the rendered videos; created if missing.

  Raises:
    AssertionError: if no GPU is visible or `config.precision` is not 16/32.
    NotImplementedError: if `config.task` has no evaluation routine.
  """
  configs = yaml.safe_load(pathlib.Path(config_path).read_text())
  config = common.Config(configs)

  # ========= SETUP ENVIRONMENTS ========
  # Only the observation/action spaces of this env are used (for the agent).
  env = make_env(config, use_goal_idx=False, log_per_goal=True)

  # ========= SETUP TF2 and GPU ========
  tf.config.run_functions_eagerly(not config.jit)
  message = 'No GPU found. To actually train on CPU remove this assert.'
  assert tf.config.experimental.list_physical_devices('GPU'), message
  for gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)
  assert config.precision in (16, 32), config.precision
  if config.precision == 16:
    from tensorflow.keras.mixed_precision import experimental as prec
    prec.set_policy(prec.Policy('mixed_float16'))
  logdir = pathlib.Path(config.logdir).expanduser()
  logdir.mkdir(parents=True, exist_ok=True)
  config.save(logdir / 'config.yaml')
  print(config, '\n')
  print('Logdir', logdir)

  # ========= SETUP REPLAY AND AGENT ========
  replay = common.Replay(logdir / 'train_episodes', **config.replay)
  step = common.Counter(replay.stats['total_steps'])
  dataset = iter(replay.dataset(**config.dataset))

  agnt = gc_agent.GCAgent(
      config, env.obs_space, env.act_space, step, obs2goal=None)
  train_agent = common.CarryOverState(agnt.train)
  # One training call before loading the checkpoint.
  # NOTE(review): presumably this creates the variables that load() restores.
  train_agent(next(dataset))
  agnt.load(logdir / 'variables.pkl')

  # ========= EVALUATE ========
  eval_fn = _make_eval_fn(config, video_dir)
  eval_env = make_env(config, use_goal_idx=True, log_per_goal=False, eval=True)
  eval_driver = common.GCDriver([eval_env], config.goal_key)
  eval_gc_policy = partial(agnt.policy, mode='eval')
  eval_fn(eval_driver, eval_gc_policy)

# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
  main()

