"""
Base code for RL/IL training.
Collects rollouts and updates policy networks.
"""

import os
import gzip
import pickle
import copy
from time import time
# from collections import defaultdict, OrderedDict

import torch
import wandb
import h5py
import gym
import numpy as np
import moviepy.editor as mpy
from tqdm import tqdm, trange

from algorithms import RL_ALGOS, IL_ALGOS, get_agent_by_name
from algorithms.rollouts import RolloutRunner
from utils.info_dict import Info
from utils.logger import logger
from utils.pytorch import get_ckpt_path, count_parameters, to_tensor
from utils.mpi import mpi_sum, mpi_average, mpi_gather_average
from environments import make_env
# import environments.metaworld as metaworld
from gym.utils import seeding
import matplotlib.pyplot as plt
from utils.visual_tool import make_gif, visualize_heatmap_8sections, heatmap_arrows, sigmoid
import seaborn as sn

class Trainer(object):
    """
    Trainer class for SAC, PPO, DDPG, BC, and GAIL in PyTorch.
    """

    def __init__(self, config):
        """
        Initializes the trainer from a configuration object.

        Creates and seeds the training environment. On the chef process it also
        creates a separately seeded evaluation environment. It then builds the
        agent and rollout runner. On the chef process during training it also
        sets up wandb logging.

        Args:
            config: experiment configuration. Attributes read here include
                is_chef, algo, average_info, env, seed, encoder_type, ob_norm,
                encoder_image_size, device, is_train, wandb, run_name,
                wandb_project, log_dir, wandb_entity, notes and, optionally, port.
        """
        self._config = config
        self._is_chef = config.is_chef
        self._is_rl = config.algo in RL_ALGOS
        self._average_info = config.average_info

        # create environment
        self._env = make_env(config.env, config)
        self._env.seed(config.seed)
        self._env.action_space.seed(config.seed)

        ob_space = env_ob_space = self._env.observation_space
        ac_space = self._env.action_space
        logger.info("Observation space: " + str(ob_space))
        logger.info("Action space: " + str(ac_space))

        # The evaluation env uses a copied config with an offset seed (and port).
        self.config_eval = copy.copy(config)
        if hasattr(self.config_eval, "port"):
            self.config_eval.port += 1
        self.config_eval.seed = config.seed + 10000
        # Only the chef process gets an eval env. Seed it only where it exists;
        # seeding None on the other workers would raise AttributeError.
        self._env_eval = None
        if self._is_chef:
            self._env_eval = make_env(config.env, self.config_eval)
            self._env_eval.seed(self.config_eval.seed)
            self._env_eval.action_space.seed(self.config_eval.seed)

        # create a new observation space after data augmentation (random crop)
        if config.encoder_type == "cnn":
            assert (
                not config.ob_norm
            ), "Turn off the observation norm (--ob_norm False) for pixel inputs"
            ob_space = gym.spaces.Dict(spaces=dict(ob_space.spaces))
            for k in ob_space.spaces.keys():
                if len(ob_space.spaces[k].shape) == 3:
                    shape = [
                        ob_space.spaces[k].shape[0],
                        config.encoder_image_size,
                        config.encoder_image_size,
                    ]
                    ob_space.spaces[k] = gym.spaces.Box(
                        low=0, high=255, shape=shape, dtype=np.uint8
                    )

        # The maze layout is currently not passed to the agent.
        maze_str = None

        # build agent and networks for algorithm
        self._agent = get_agent_by_name(config.algo)(
            config, ob_space, ac_space, env_ob_space, layout=maze_str
        )

        # build rollout runner
        self._runner = RolloutRunner(config, self._env, self._env_eval, self._agent)

        # setup log (only the chef process talks to wandb)
        if self._is_chef and config.is_train:
            exclude = ["device"]
            if not config.wandb:
                os.environ["WANDB_MODE"] = "dryrun"

            os.environ["WANDB__SERVICE_WAIT"] = "300"

            from wandb_credentials import WANDB_API_KEY
            wandb.login(key=WANDB_API_KEY)

            wandb.init(
                name=config.run_name,
                resume=True,
                project=config.wandb_project,
                config={k: v for k, v in config.__dict__.items() if k not in exclude},
                dir=config.log_dir,
                entity=config.wandb_entity,
                notes=config.notes,
            )

    def _save_ckpt(self, ckpt_num, update_iter):
        """
        Save a checkpoint (and, for off-policy agents, the replay buffer) to the
        log directory, then prune older files.

        Only the two most recent checkpoints and the two most recent replay
        buffers are kept.

        Args:
            ckpt_num: number appended to checkpoint name. The number of
                environment steps is used in this code.
            update_iter: number of policy updates. It is used for resuming training.
        """
        log_dir = self._config.log_dir
        ckpt_path = os.path.join(log_dir, "ckpt_%09d.pt" % ckpt_num)
        state_dict = {"step": ckpt_num, "update_iter": update_iter}
        # TODO: load this for target demos (_load_ckpt currently ignores it)
        state_dict["sampled_few_demo_index"] = self._agent.sampled_few_demo_index
        state_dict["agent"] = self._agent.state_dict()
        torch.save(state_dict, ckpt_path)
        logger.warn("Save checkpoint: %s", ckpt_path)

        if self._agent.is_off_policy():
            replay_path = os.path.join(log_dir, "replay_%08d.pkl" % ckpt_num)
            # Build the payload before opening the file, so a failure does not
            # leave an empty/partial gzip file behind.
            replay_buffers = {"replay": self._agent.replay_buffer()}
            if self._config.algo in ["reachable_gail", "reachable_gail-v0"]:
                # HACK: reachable-GAIL variants also persist their auxiliary buffers.
                reachable = self._agent._reachable_buffer
                replay_buffers["reachable"] = reachable.state_dict()
                replay_buffers["unreachable"] = self._agent._unreachable_buffer.state_dict()
                replay_buffers["expert_demo_size"] = reachable._expert_demo_size
            with gzip.open(replay_path, "wb") as f:
                pickle.dump(replay_buffers, f)
            logger.warn("Save replay buffer: %s", replay_path)

        def _files_by_step(prefix, suffix):
            """Return '<prefix>_<digits><suffix>' file names in log_dir, oldest step first."""
            found = []
            for name in os.listdir(log_dir):
                if not (name.startswith(prefix + "_") and name.endswith(suffix)):
                    continue
                digits = name[len(prefix) + 1 : len(name) - len(suffix)]
                if digits.isascii() and digits.isdigit():
                    found.append((int(digits), name))
            return [name for _, name in sorted(found)]

        # Delete old checkpoints and replay buffers, keeping the last two of each.
        # Sort by the numeric step, not lexically: the zero-padding widths differ
        # (9 vs 8 digits), and lexical order breaks once a step outgrows its
        # padding, which would delete the newest file.
        for prefix, suffix, message in (
            ("ckpt", ".pt", "Delete checkpoint: %s"),
            ("replay", ".pkl", "Delete replay buffer: %s"),
        ):
            for name in _files_by_step(prefix, suffix)[:-2]:
                os.remove(os.path.join(log_dir, name))
                logger.warn(message, name)

    def _load_ckpt(self, ckpt_path, ckpt_num):
        """
        Restore the agent from a checkpoint.

        If @ckpt_path is given, the checkpoint number is parsed from its file
        name. Otherwise the checkpoint is looked up in the log directory by
        @ckpt_num (the helper picks the latest one when @ckpt_num is None).
        During training of an off-policy agent, the matching replay buffer
        (plus the reachable-GAIL auxiliary buffers) is restored as well.

        Returns:
            (step, update_iter) to resume from. This is (0, 0) when nothing was
            loaded or when the init checkpoint path contains "bc".
        """
        if ckpt_path is not None:
            suffix = ckpt_path.rsplit("_", 1)[-1]
            ckpt_num = int(suffix.split(".")[0])
        else:
            ckpt_path, ckpt_num = get_ckpt_path(self._config.log_dir, ckpt_num)

        if ckpt_path is None:
            # Nothing to restore: start from scratch.
            self.sampled_few_demo_index = None
            logger.warn("Randomly initialize models")
            return 0, 0

        config = self._config
        agent = self._agent

        logger.warn("Load checkpoint %s", ckpt_path)
        ckpt = torch.load(ckpt_path, map_location=config.device)
        agent.load_state_dict(ckpt["agent"])
        # Restoring the demo subset from the checkpoint is intentionally disabled.
        self.sampled_few_demo_index = None

        if config.is_train and agent.is_off_policy():
            replay_path = os.path.join(config.log_dir, "replay_%08d.pkl" % ckpt_num)
            logger.warn("Load replay_buffer %s", replay_path)
            if not os.path.exists(replay_path):
                logger.warn("Replay buffer not exists at %s", replay_path)
            else:
                with gzip.open(replay_path, "rb") as f:
                    buffers = pickle.load(f)
                agent.load_replay_buffer(buffers["replay"])

                if config.algo in ["reachable_gail", "reachable_gail-v0"]:
                    reachable = agent._reachable_buffer
                    reachable.load_state_dict(buffers["reachable"])
                    agent._unreachable_buffer.load_state_dict(buffers["unreachable"])
                    reachable._expert_demo_size = buffers["expert_demo_size"]
                    reachable._expert_demo = True

        # A BC-pretrained init checkpoint only supplies weights; restart counters.
        init_from_bc = (
            config.init_ckpt_path is not None and "bc" in config.init_ckpt_path
        )
        return (0, 0) if init_from_bc else (ckpt["step"], ckpt["update_iter"])

    def _log_pretrain(self, step, train_info, ep_info):
        """
        Same as _log_train, but uses the "pretrain" prefixes.

        Args:
            step: the pretraining step used as the wandb step.
            train_info: training information to log, such as loss, gradient.
                Scalars / single-element arrays, wandb.Image / wandb.Video
                objects and non-empty lists or tuples of them are logged;
                other values are skipped.
            ep_info: episode information (or None). For each key, the mean is
                logged under "pretrain_ep/" and the max under "pretrain_ep_max/".
        """
        for k, v in train_info.items():
            if np.isscalar(v) or (hasattr(v, "shape") and np.prod(v.shape) == 1):
                wandb.log({"pretrain/%s" % k: v}, step=step)
            # Check single media objects before indexing: wandb.Image / wandb.Video
            # are not sequences, so testing v[0] first would fail for them.
            elif isinstance(v, (wandb.Image, wandb.Video)):
                wandb.log({"pretrain/%s" % k: v}, step=step)
            elif (
                isinstance(v, (list, tuple))
                and len(v) > 0
                and isinstance(v[0], (wandb.Image, wandb.Video))
            ):
                for i, media in enumerate(v):
                    wandb.log({"pretrain/%s_%d" % (k, i): media}, step=step)
        if ep_info is not None:
            for k, v in ep_info.items():
                wandb.log({"pretrain_ep/%s" % k: np.mean(v)}, step=step)
                wandb.log({"pretrain_ep_max/%s" % k: np.max(v)}, step=step)

    def _log_train(self, step, train_info, ep_info):
        """
        Logs training and episode information to wandb.

        Args:
            step: the number of environment steps.
            train_info: training information to log, such as loss, gradient.
                Scalars / single-element arrays, wandb.Image / wandb.Video
                objects and non-empty lists or tuples of them are logged under
                "train_rl/"; other values are skipped.
            ep_info: episode information to log, such as reward, episode time.
                A single image is logged as-is; for a list of images only the
                first is logged; other values are logged as their mean
                ("train_ep/") and max ("train_ep_max/"). Empty lists are skipped.
        """
        for k, v in train_info.items():
            if np.isscalar(v) or (hasattr(v, "shape") and np.prod(v.shape) == 1):
                wandb.log({"train_rl/%s" % k: v}, step=step)
            # Check single media objects before indexing: wandb.Image / wandb.Video
            # are not sequences, so testing v[0] first would fail for them.
            elif isinstance(v, (wandb.Image, wandb.Video)):
                wandb.log({"train_rl/%s" % k: v}, step=step)
            elif (
                isinstance(v, (list, tuple))
                and len(v) > 0
                and isinstance(v[0], (wandb.Image, wandb.Video))
            ):
                for i, media in enumerate(v):
                    wandb.log({"train_rl/%s_%d" % (k, i): media}, step=step)

        for k, v in ep_info.items():
            if isinstance(v, (list, tuple)) and len(v) == 0:
                # Nothing recorded for this key; indexing or np.max would fail.
                continue
            if isinstance(v, wandb.Image):
                wandb.log({"train_ep/%s" % k: v}, step=step)
            elif isinstance(v, list) and isinstance(v[0], wandb.Image):
                # Only the first image of the logging interval is uploaded.
                wandb.log({"train_ep/%s_%d" % (k, 0): v[0]}, step=step)
            else:
                wandb.log({"train_ep/%s" % k: np.mean(v)}, step=step)
                wandb.log({"train_ep_max/%s" % k: np.max(v)}, step=step)

    def _log_test(self, step, ep_info):
        """
        Send evaluation-episode information to wandb (only while training).

        Videos and lists of videos are uploaded as media; every other value is
        reduced with np.mean before logging under the "test_ep/" prefix.

        Args:
            step: the number of environment steps.
            ep_info: episode information to log, such as reward, episode time.
        """
        if not self._config.is_train:
            return

        for key, value in ep_info.items():
            tag = "test_ep/%s" % key
            if isinstance(value, wandb.Video):
                wandb.log({tag: value}, step=step)
                continue
            if isinstance(value, list) and isinstance(value[0], wandb.Video):
                for idx, clip in enumerate(value):
                    wandb.log({"test_ep/%s_%d" % (key, idx): clip}, step=step)
                continue
            wandb.log({tag: np.mean(value)}, step=step)

    def train(self):
        """
        Trains the agent until ``config.max_global_step`` environment steps.

        The method:
        1. Resumes from a checkpoint if one is available.
        2. Loads demonstrations (except for sqil/gail).
        3. Runs algorithm-specific pretraining (BC, discriminator, proximity
           function).
        4. Creates a rollout generator with the algorithm's collection schedule
           and runs the warm-up steps.
        5. Alternates rollout collection and agent updates.

        The chef process logs and evaluates periodically. A checkpoint is saved
        at the end.
        """
        config = self._config

        # load checkpoint
        step, update_iter = self._load_ckpt(config.init_ckpt_path, config.ckpt_num)

        # load training data
        if config.algo not in ["sqil", "gail"]:
            self._agent.load_training_data(
                sampled_few_demo_index=self.sampled_few_demo_index
            )

        # sync the networks across the cpus
        self._agent.sync_networks()

        logger.info("Start training at step=%d", step)
        if self._is_chef:
            # The bar is advanced by environment steps (step_per_batch) towards
            # max_global_step, so when resuming it must start at `step`
            # (not at the number of updates).
            pbar = tqdm(initial=step, total=config.max_global_step, desc=config.run_name)
            ep_info = Info()
            train_info = Info()

        runner = None
        # decide how many episodes or how long rollout to collect
        if config.algo in ["bc", "mt-bc"]:
            # BC trains from demonstrations only; no environment rollouts.
            if config.pretrain_BC:
                if config.ob_norm:
                    self._agent.update_normalizer_pretrain()
                step_bc = 0
                while step_bc < config.pretrain_bc_max_step:
                    pretrain_info = self._agent.pretrain_bc()
                    # wandb.init is only called on the chef process.
                    if self._is_chef:
                        self._log_pretrain(step_bc, pretrain_info, None)
                    step_bc += 1

        elif config.algo in ["gail", "airl", "gail-v2"]:
            # pretrain discriminator
            if config.pretrain_discriminator:
                self._agent.pretrain_dis()

            assert config.gail_rl_algo in ["ppo", "sac", "ddpg", "td3", "dqn"]
            runner = self._runner.run(
                every_steps=config.rollout_length
                if config.gail_rl_algo == "ppo"
                else config.num_env_steps_per_update,
                step=step,
            )
        elif config.algo in ["reachable_gail", "reachable_gail-v0", "prox"]:
            if config.pretrain_discriminator:
                self._agent.pretrain_dis()
            if not config.single_task_training and config.pretrain_prox:
                pre_step = 0
                while True:
                    pretrain_info, done = self._agent.pretrain()
                    # wandb.init is only called on the chef process.
                    if self._is_chef:
                        self._log_pretrain(pre_step, pretrain_info, Info().get_dict())
                    pre_step += 1
                    if done:
                        break
                logger.info("Finish Prox. pretraining")

            assert config.gail_rl_algo in ["ppo", "sac", "ddpg", "td3", "dqn"]
            runner = self._runner.run(
                every_steps=config.rollout_length
                if config.gail_rl_algo == "ppo"
                else config.num_env_steps_per_update,
                step=step,
            )
        elif config.algo == "ppo":
            runner = self._runner.run(every_steps=config.rollout_length, step=step)
        elif config.algo in ["sac", "ddpg", "td3", "dqn"]:
            runner = self._runner.run(every_steps=config.num_env_steps_per_update, step=step)
        elif config.algo in ["iqlearn", "sqil"]:
            runner = self._runner.run(every_steps=1, step=step)

        st_time = time()
        st_step = step

        # Warm-up: store rollouts (and update the normalizer) without agent updates.
        while runner and step < config.warm_up_steps:
            rollout, _ = next(runner)
            self._agent.store_episode(rollout, step=step)
            step_per_batch = mpi_sum(len(rollout["ac"]))
            step += step_per_batch
            if step < config.max_ob_norm_step:
                self._update_normalizer(rollout)
            if self._is_chef:
                pbar.update(step_per_batch)

        if config.algo == "bc" and config.ob_norm:
            self._agent.update_normalizer()

        while step < config.max_global_step:
            # collect rollouts
            if runner:
                rollout, info = next(runner)
                if self._average_info:
                    info = mpi_gather_average(info)

                info_store = self._agent.store_episode(rollout, step=step)
                if config.algo in ["reachable_gail"] and info_store is not None:
                    info.update(info_store)
                step_per_batch = mpi_sum(len(rollout["ac"]))
            else:
                # No environment interaction: count one "step" per update.
                step_per_batch = mpi_sum(1)
                info = {}

            _train_info = self._agent.train(step)

            if runner and step < config.max_ob_norm_step:
                self._update_normalizer(rollout)

            step += step_per_batch
            update_iter += 1

            # log training and episode information or evaluate
            if self._is_chef:
                pbar.update(step_per_batch)
                ep_info.add(info)
                train_info.add(_train_info)

                if update_iter % config.log_interval == 0:
                    train_info.add(
                        {
                            "sec": (time() - st_time) / config.log_interval,
                            "steps_per_sec": (step - st_step) / (time() - st_time),
                            "update_iter": update_iter,
                        }
                    )
                    st_time = time()
                    st_step = step
                    self._log_train(step, train_info.get_dict(), ep_info.get_dict())
                    ep_info = Info()
                    train_info = Info()

                if config.num_eval and update_iter % config.evaluate_interval == 0:
                    logger.info("Evaluate at %d", update_iter)
                    _, eval_info = self._evaluate(step=step, record_video=config.record_video)
                    self._log_test(step, eval_info)

            # NOTE: periodic checkpointing and reward-heatmap snapshots are
            # currently disabled; only the final checkpoint below is saved.

        if self._is_chef:
            pbar.close()
        self._save_ckpt(step, update_iter)

        logger.info("Reached %s steps. worker %d stopped.", step, config.rank)

    def _update_normalizer(self, rollout):
        """Feed the observations of @rollout to the agent's normalizer when ob_norm is on."""
        if not self._config.ob_norm:
            return
        observations = rollout["ob"]
        self._agent.update_normalizer(observations)
    
    def _evaluate(self, step=None, record_video=False):
        """
        Runs ``config.num_eval`` evaluation episodes with the current policy.

        The runner is asked to record frames for the last episode when
        @record_video is set, and for every episode when
        ``config.record_ood_video`` is set. When @record_video is set, the last
        episode's frames are saved as an mp4. During training the video is also
        attached to that episode's info as a wandb.Video.

        Args:
            step: the number of environment steps (used in the log message and
                the video file name).
            record_video: whether to record a video of the last episode.

        Returns:
            rollouts: list with one rollout per evaluation episode.
            info_history: Info accumulating each episode's info dict.
        """
        logger.info("Run %d evaluations at step=%d", self._config.num_eval, step)
        rollouts = []
        info_history = Info()
        for i in range(self._config.num_eval):
            logger.warn("Evaluate run %d", i + 1)
            if hasattr(self._agent, "reset"):
                self._agent.reset()
            is_last = (i + 1) == self._config.num_eval
            rollout, info, frames = self._runner.run_episode(
                is_train=False,
                record_video=((record_video and is_last) or self._config.record_ood_video),
            )
            rollouts.append(rollout)

            if record_video and is_last:
                ep_rew = info["rew"]
                ep_success = (
                    "s"
                    if "episode_success" in info and info["episode_success"]
                    else "f"
                )
                fname = "{}_step_{:011d}_{}_r_{}_{}.mp4".format(
                    self._config.env, step, i, ep_rew, ep_success,
                )
                video_path = self._save_video(fname, frames)
                if self._config.is_train:
                    info["video"] = wandb.Video(video_path, fps=15, format="mp4")

            info_history.add(info)
            # NOTE: saving OOD-failure trajectories/videos (config.is_save_ood_traj)
            # is currently disabled.

        print(f"Success rate: {np.mean(info_history['success_rate'])}")
        return rollouts, info_history

    def save_rewards(self, algo):
        """
        Evaluates the learned reward at every free cell of the evaluation maze.

        For each location in ``env.empty_and_goal_locations``, the first two
        entries of a reset observation are replaced by that location plus a
        fixed noise sample. ``_predict_reward`` is then averaged over
        ``repeat_time`` samples paired with fixed probe actions.

        Args:
            algo: algorithm name. For "reachable_gail" the 'gail_rew' and
                'reach_rew' outputs are recorded separately.

        Returns:
            (maze, maze1, maze2) grids of mean rewards. For "reachable_gail":
            mean 'gail_rew', 'reach_rew' and 'rew'. Otherwise maze and maze2
            hold the mean 'rew' and maze1 stays zero.
        """
        env = self._env_eval
        pi = self._agent
        repeat_time = 200

        # Maze rows are separated by a backslash in the spec string.
        maze_str = env.str_maze_spec
        lines = maze_str.strip().split('\\')
        width, height = len(lines), len(lines[0])

        available_loc = env.empty_and_goal_locations
        ob = env.reset()

        # Probe actions and position noise are drawn once from the eval seed and
        # cached, so every call scores the same samples. __init__ does not
        # create them, so build them on first use instead of raising
        # AttributeError.
        if getattr(self, "ac", None) is None or getattr(self, "ob_noise", None) is None:
            self.np_random, _ = seeding.np_random(self.config_eval.seed)
            ac_list = np.stack(
                [env.action_space.sample()['ac'] for _ in range(repeat_time)], axis=0
            )
            self.ac = to_tensor(ac_list, pi._config.device)
            self.ob_noise = [
                self.np_random.uniform(low=-.1, high=.1, size=env.model.nq)
                for _ in range(repeat_time)
            ]

        maze = np.zeros((width, height))
        maze1 = np.zeros((width, height))
        maze2 = np.zeros((width, height))

        ob['ob'] = np.repeat(ob['ob'].reshape(1, -1), repeat_time, axis=0)
        noise = np.array(self.ob_noise)  # loop-invariant
        # predict rewards
        for loc in available_loc:
            row, col = int(loc[0]), int(loc[1])
            loc_repeat = np.repeat(np.array(loc).reshape(1, -1), repeat_time, axis=0)
            new_pos = loc_repeat + noise
            new_ob = {'ob': np.concatenate([new_pos, ob['ob'][:, 2:]], axis=1)}
            new_ob = pi.normalize(new_ob)
            new_ob = to_tensor(new_ob, pi._config.device)

            prox_rew = pi._predict_reward(new_ob, self.ac)

            if algo == "reachable_gail":
                maze[row, col] = torch.mean(prox_rew['gail_rew']).cpu().item()
                maze1[row, col] = torch.mean(prox_rew['reach_rew']).cpu().item()
            else:
                maze[row, col] = torch.mean(prox_rew['rew']).cpu().item()

            maze2[row, col] = torch.mean(prox_rew['rew']).cpu().item()
        return maze, maze1, maze2
        
    def visualize_average(self, algo, is_sigmoid=False):
        """
        Plots a reward heatmap per saved snapshot and turns them into a GIF.

        Where the reward snapshots come from:
        - In evaluation mode, from the replay pickle that matches the
          checkpoint found via ``config.ckpt_num`` (keys "dvd_rew", "reach_rew",
          "total_rew").
        - In training mode, from ``self.dvd_rew`` / ``self.reach_rew`` /
          ``self.total_rew``.

        If no snapshots are available, a warning is logged and nothing is
        plotted. The target task's expert positions and goal are overlaid on
        each heatmap.

        Args:
            algo: algorithm name. "reachable_gail" also plots the reachability
                and total rewards and masks the wall cells.
            is_sigmoid: pass rewards (and the color limits) through a sigmoid.
        """
        config = self._config
        un_available_loc = self._env_eval.wall_locations

        dvd_rew = reach_rew = total_rew = None
        if not config.is_train:
            ckpt_path, ckpt_num = get_ckpt_path(config.log_dir, config.ckpt_num)

            if ckpt_path is not None:
                logger.warn("Load checkpoint %s", ckpt_path)
                ckpt = torch.load(ckpt_path, map_location=config.device)
                self.sampled_few_demo_index = ckpt["sampled_few_demo_index"]
                self._agent.load_training_data(sampled_few_demo_index=self.sampled_few_demo_index)

                # The replay name derives from ckpt_num, which is only valid when
                # a checkpoint was found ("%08d" % None would raise).
                replay_path = os.path.join(config.log_dir, "replay_%08d.pkl" % ckpt_num)
                logger.warn("Load replay_buffer %s", replay_path)
                if os.path.exists(replay_path):
                    with gzip.open(replay_path, "rb") as f:
                        buffers = pickle.load(f)
                    dvd_rew = buffers.get("dvd_rew")
                    reach_rew = buffers.get("reach_rew")
                    total_rew = buffers.get("total_rew")
                else:
                    logger.warn("Replay buffer not exists at %s", replay_path)
        else:
            dvd_rew = getattr(self, "dvd_rew", None)
            reach_rew = getattr(self, "reach_rew", None)
            total_rew = getattr(self, "total_rew", None)

        if any(r is None or len(r) == 0 for r in (dvd_rew, reach_rew, total_rew)):
            logger.warn("No reward history to visualize; skipping heatmaps")
            return

        # Color limits of the DVD-reward heatmap, needed for every algorithm.
        min_rew, max_rew = np.min(dvd_rew), np.max(dvd_rew)

        if algo == "reachable_gail":
            min_rew1, min_rew2 = np.min(reach_rew), np.min(total_rew)
            max_rew1, max_rew2 = np.max(reach_rew), np.max(total_rew)

            # Push wall cells just below each map's color-scale minimum.
            for item in un_available_loc:
                for snap in range(len(dvd_rew)):
                    dvd_rew[snap][item] = min_rew - 0.1
                    reach_rew[snap][item] = min_rew1 - 0.1
                    total_rew[snap][item] = min_rew2 - 0.1

        # Expert positions and the goal of the target task (overlay points).
        expert_traj = self._agent._data_dataset[config.target_task_index_in_demo_path]._data
        color = config.env.split('-')[2]
        goal_offsets = {'red': 4, 'blue': 6, 'magenta': 8, 'yellow': 10}
        goal_idx = goal_offsets[color]
        goal_pos = expert_traj[0]['ob']['ob'][goal_idx:goal_idx + 2]
        expert = np.stack([item['ob']['ob'][0:2] for item in expert_traj], axis=0)

        # NOTE(review): the 0.65 / 0.75 offsets presumably align maze coordinates
        # with the heatmap cells — confirm.
        x = np.concatenate((expert[:, 0:1], np.expand_dims(goal_pos[0:1], axis=0)), axis=0) + 0.65
        y = np.concatenate((expert[:, 1:2], np.expand_dims(goal_pos[1:2], axis=0)), axis=0) + 0.75
        colors = np.array([1] * len(expert) + [12])
        size = np.array([5] * len(expert) + [100])

        def _orient(grid):
            # flip + rot90 (net effect: a transpose), so grid axis 0 runs along x
            return np.rot90(np.flip(grid, axis=1), k=1, axes=(0, 1))

        def _save_heatmap(grid, lo, hi, fname):
            # Clear the current figure; otherwise colorbars/legends accumulate.
            plt.clf()
            if is_sigmoid:
                lo, hi = sigmoid(lo), sigmoid(hi)
            hm = sn.heatmap(data=grid, vmin=lo, vmax=hi, cmap='viridis')
            figure = hm.get_figure()
            plt.scatter(x, y, c=colors, s=size, cmap='Set3')
            figure.savefig(os.path.join(config.visual_dir, fname))

        squash = sigmoid if is_sigmoid else (lambda a: a)
        for i in range(len(dvd_rew)):
            _save_heatmap(_orient(squash(dvd_rew[i])), min_rew, max_rew,
                          f"heatmap_dvd_s{i}_repeat.png")
            if algo == "reachable_gail":
                _save_heatmap(_orient(squash(reach_rew[i])), min_rew1, max_rew1,
                              f"heatmap_reach_s{i}_repeat.png")
                _save_heatmap(_orient(squash(total_rew[i])), min_rew2, max_rew2,
                              f"heatmap_total_s{i}_repeat.png")

        # NOTE(review): algo is hard-coded to "reachable_gail" here regardless of
        # @algo — confirm against make_gif.
        make_gif(config.visual_dir, video_p=config.log_dir, algo="reachable_gail", is_sigmoid=is_sigmoid)
        print("Visualized the heatmap of the rewards")
    
    def visulize_8directons(self, algo, is_sigmoid=False, step=None):
        """
        Plots direction-resolved reward heatmaps over every free maze cell.

        For each empty/goal cell, the reset observation is rewritten as:
        the cell position shifted by +0.5, zero velocity, then the remaining
        reset observation entries. The reward model
        ``self._agent._predict_reward`` is then queried once per action in a
        fixed set of 8 directions. Walls are set just below each grid's
        minimum and the first goal cell just above its maximum, so both stand
        out in the plot.

        Args:
            algo: algorithm name. Every algorithm gets the ``gail_rew``
                ("dvd") heatmap. ``"reachable_gail"`` additionally plots the
                ``rew`` ("total") heatmap and the ``reach`` reward averaged
                over directions ("reach", drawn with ``heatmap_arrows``).
            is_sigmoid: if True, every plotted value is passed through
                ``sigmoid`` and the file names get a ``_sig`` suffix.
            step: step used in the output file names. When not training, it is
                replaced by the step of the loaded checkpoint.

        Raises:
            ValueError: if the configured number of directions is not 4 or 8.
        """
        if not self._config.is_train:
            step, update_iter = self._load_ckpt(
                self._config.init_ckpt_path, self._config.ckpt_num
            )
            self._agent.load_training_data(sampled_few_demo_index=self.sampled_few_demo_index)

            logger.info(
                "Run visulation at step=%d, update_iter=%d",
                step,
                update_iter,
            )

        env = self._env_eval
        pi = self._agent
        is_reachable = algo == "reachable_gail"

        # Rows of the maze spec string are separated by backslashes.
        lines = env.str_maze_spec.strip().split('\\')
        width, height = len(lines), len(lines[0])

        available_loc = env.empty_and_goal_locations
        un_available_loc = env.wall_locations
        target_pos = env.goal_locations

        ob = env.reset()

        # One reward query per direction; each action is a 2-D move.
        directions = 8
        if directions == 4:
            ac_list = [[0.0, 1.0], [1.0, 0.0], [0.0, -1.0], [-1.0, 0.0]]
        elif directions == 8:
            ac_list = [[0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [1.0, -1.0],
                       [0.0, -1.0], [-1.0, -1.0], [-1.0, 0.0], [-1.0, 1.0]]
        else:
            raise ValueError(f"Unsupported number of directions: {directions}")
        repeat_time = directions
        ac = to_tensor(np.stack(ac_list, axis=0), pi._config.device)

        # (width, height, directions) grids: one reward per cell and direction.
        maze = np.zeros((width, height, directions))
        if is_reachable:
            maze1 = np.zeros((width, height, directions))
            maze2 = np.zeros((width, height, directions))

        # predict rewards
        for loc in available_loc:
            row, col = int(loc[0]), int(loc[1])
            new_pos = [loc[0] + 0.5, loc[1] + 0.5]
            # position, zero velocity, then the rest of the reset observation
            new_ob = np.concatenate([new_pos, [0., 0.], ob['ob'][4:]]).reshape(1, -1)
            new_ob = {'ob': np.repeat(new_ob, repeat_time, axis=0)}
            new_ob = to_tensor(pi.normalize(new_ob), pi._config.device)
            rew = pi._predict_reward(new_ob, ac)
            # reshape accepts both (directions,) and (directions, 1) outputs
            maze[row, col, :] = rew['gail_rew'].cpu().numpy().reshape(directions)
            if is_reachable:
                maze1[row, col, :] = rew['reach'].cpu().numpy().reshape(directions)
                maze2[row, col, :] = rew['rew'].cpu().numpy().reshape(directions)

        def _mark_walls(grid):
            # Set every wall cell (all directions) just below the grid minimum.
            floor = np.min(grid) - 0.1
            for item in un_available_loc:
                grid[item[0], item[1]] = floor

        def _mark_goal(grid):
            # Set the first goal cell (all directions) just above the grid maximum.
            top = np.max(grid) + 0.1
            grid[target_pos[0][0], target_pos[0][1]] = top

        _mark_walls(maze)
        if is_reachable:
            _mark_walls(maze1)
            _mark_walls(maze2)
            # Reachability is drawn as one value per cell: average over directions.
            maze1 = np.mean(maze1, axis=2)

        # double check to make sure the maze plot is correct
        _mark_goal(maze)
        if is_reachable:
            _mark_goal(maze1)
            _mark_goal(maze2)

        if is_sigmoid:
            # Element-wise, so it works whether or not ``sigmoid`` accepts
            # arrays. Covers the real grid shape instead of a fixed 21x21x8 window.
            squash = np.vectorize(sigmoid, otypes=[float])
            maze = squash(maze)
            if is_reachable:
                maze1 = squash(maze1)
                maze2 = squash(maze2)

        # get expert demos positions of the target task
        expert_traj = pi._data_dataset[self._config.target_task_index_in_demo_path]._data
        color = self._config.env.split('-')[2]
        # observation offset of the 2-D goal slice for each goal color
        dict_ = {'red': 4, 'blue': 6, 'magenta': 8, 'yellow': 10}
        goal_pos = expert_traj[0]['ob']['ob'][dict_[color]:dict_[color] + 2]
        expert = np.stack([item['ob']['ob'][0:2] for item in expert_traj], axis=0)

        # Expert positions followed by the goal, with a fixed offset for the plot.
        x = np.concatenate((expert[:, 0:1], np.expand_dims(goal_pos[0:1], axis=0)), axis=0) + 0.65
        y = np.concatenate((expert[:, 1:2], np.expand_dims(goal_pos[1:2], axis=0)), axis=0) + 0.75
        colors = np.concatenate(([1] * len(expert), [12]), axis=0)
        size = np.concatenate(([5] * len(expert), [50]), axis=0)

        suffix = "_sig" if is_sigmoid else ""

        def _out_path(kind):
            return os.path.join(
                self._config.log_dir,
                f"heatmap_{kind}_s{step}_repeat{repeat_time}{suffix}.png",
            )

        # plot the heatmap
        visualize_heatmap_8sections(repeat_time, maze, _out_path("dvd"), x, y, colors, size, available_loc)
        if is_reachable:
            visualize_heatmap_8sections(repeat_time, maze2, _out_path("total"), x, y, colors, size, available_loc)
            heatmap_arrows(maze1, _out_path("reach"), x, y, colors, size, available_loc)

        print("Visualized the heatmap of the rewards")

    def visulize_critic(self, algo, step=None):
        """
        Plots a heatmap of the RL critic's mean value estimate over the maze.

        For every empty/goal cell, ``repeat_time`` observations are built.
        Each uses the cell location plus a fixed per-sample jitter drawn
        uniformly from [-0.1, 0.1], followed by the reset observation from
        index 2 onward. The mean of ``self._agent._rl_agent._critic`` over
        those samples becomes that cell's value. Walls are set just below the
        minimum and the first goal cell just above the maximum.

        Args:
            algo: unused; kept so the signature matches the other visualizers.
            step: step used in the output file name. When not training, it is
                replaced by the step of the loaded checkpoint.
        """
        if not self._config.is_train:
            step, update_iter = self._load_ckpt(
                self._config.init_ckpt_path, self._config.ckpt_num
            )
            self._agent.load_training_data(sampled_few_demo_index=self.sampled_few_demo_index)

            logger.info(
                "Run visulation at step=%d, update_iter=%d",
                step,
                update_iter,
            )

        env = self._env_eval
        pi = self._agent

        # Rows of the maze spec string are separated by backslashes.
        lines = env.str_maze_spec.strip().split('\\')
        width, height = len(lines), len(lines[0])

        available_loc = env.empty_and_goal_locations
        un_available_loc = env.wall_locations
        target_pos = env.goal_locations

        ob = env.reset()
        maze1 = np.zeros((width, height))
        repeat_time = 200

        # The same jitter is reused for every cell; built once, outside the loop.
        # NOTE(review): it is added to a 2-D location, so this assumes
        # env.model.nq == 2 -- confirm.
        ob_noise = np.array(
            [self.np_random.uniform(low=-.1, high=.1, size=env.model.nq) for _ in range(repeat_time)]
        )
        # Everything after the position, repeated once per sample (loop-invariant).
        ob_tail = np.repeat(ob['ob'].reshape(1, -1), repeat_time, axis=0)[:, 2:]

        # predict values
        # NOTE(review): unlike visulize_8directons, positions here are not shifted
        # by +0.5 -- confirm which convention matches the env.
        for i in available_loc:
            i_repeat = np.repeat(np.array(i).reshape(1, -1), repeat_time, axis=0)
            new_pos = i_repeat + ob_noise
            new_ob = {'ob': np.concatenate([new_pos, ob_tail], axis=1)}
            new_ob = to_tensor(pi.normalize(new_ob), pi._config.device)

            # Inference only: avoid building an autograd graph for the batch.
            with torch.no_grad():
                values = pi._rl_agent._critic(new_ob)
            maze1[int(i[0]), int(i[1])] = torch.mean(values).item()

        # make the walls' value smaller than the minimum value
        min_rew1 = np.min(maze1)
        for item in un_available_loc:
            maze1[item[0], item[1]] = min_rew1 - 0.1

        # double check to make sure the maze plot is correct
        max_rew1 = np.max(maze1)
        maze1[target_pos[0][0], target_pos[0][1]] = max_rew1 + 0.1

        # get expert demos positions of the target task
        expert_traj = pi._data_dataset[self._config.target_task_index_in_demo_path]._data
        color = self._config.env.split('-')[2]
        # observation offset of the 2-D goal slice for each goal color
        dict_ = {'red': 4, 'blue': 6, 'magenta': 8, 'yellow': 10}
        goal_pos = expert_traj[0]['ob']['ob'][dict_[color]:dict_[color] + 2]
        expert = np.stack([item['ob']['ob'][0:2] for item in expert_traj], axis=0)

        # Expert positions followed by the goal, with a fixed offset for the plot.
        x = np.concatenate((expert[:, 0:1], np.expand_dims(goal_pos[0:1], axis=0)), axis=0) + 0.65
        y = np.concatenate((expert[:, 1:2], np.expand_dims(goal_pos[1:2], axis=0)), axis=0) + 0.75
        colors = np.concatenate(([1] * len(expert), [12]), axis=0)
        size = np.concatenate(([5] * len(expert), [50]), axis=0)

        # plot the heatmap
        heatmap_arrows(
            maze1,
            os.path.join(self._config.log_dir, f"heatmap_V_s{step}_repeat{repeat_time}.png"),
            x, y, colors, size, available_loc,
        )

        print("Visualized the heatmap of the rewards")

    def evaluate(self):
        """
        Evaluates the checkpoint selected by ``self._config.ckpt_num``.

        Writes:
          * ``result/<run_name>.hdf5``: one dataset per key of the evaluation info.
          * ``result/<run_name>.txt``: one ``<key>\\t<a> $\\pm$ <b>`` line per
            entry of ``info.get_stat()``. The two values are presumably
            (mean, std); TODO confirm.
          * when ``record_demo`` is set,
            ``<demo_dir>/<run_name>_step_<step:011d>_<num_eval>_trajs.pkl``: the
            rollouts, re-keyed as obs/actions/rewards/dones, pickled.
        """
        ckpt_step, ckpt_iter = self._load_ckpt(
            self._config.init_ckpt_path, self._config.ckpt_num
        )

        logger.info(
            "Run %d evaluations at step=%d, update_iter=%d",
            self._config.num_eval,
            ckpt_step,
            ckpt_iter,
        )

        rollouts, info = self._evaluate(
            step=ckpt_step, record_video=self._config.record_video
        )
        stats = info.get_stat()

        os.makedirs("result", exist_ok=True)
        with h5py.File(f"result/{self._config.run_name}.hdf5", "w") as hf:
            for key, _unused in info.items():
                hf.create_dataset(key, data=info[key])
        with open(f"result/{self._config.run_name}.txt", "w") as out:
            for key, stat in stats.items():
                out.write("{}\t{:.03f} $\\pm$ {:.03f}\n".format(key, stat[0], stat[1]))

        if not self._config.record_demo:
            return

        # Rename rollout keys to the demo-file convention.
        demos = [
            {
                "obs": r["ob"],
                "actions": r["ac"],
                "rewards": r["rew"],
                "dones": r["done"],
            }
            for r in rollouts
        ]

        demo_name = f"{self._config.run_name}_step_{ckpt_step:011d}_{self._config.num_eval}_trajs.pkl"
        demo_path = os.path.join(self._config.demo_dir, demo_name)
        logger.warn(f"[*] Generating demo: {demo_path}")
        with open(demo_path, "wb") as out:
            pickle.dump(demos, out)
    
    """
    def save_rollout_to_pkl_file(self, rollouts, step, update_iter):
        # Saves a ood case to a pickle file.
        file_name = os.path.join(self._config.ood_traj_dir, "ood_{:011d}_{}_trajs.pkl".format(step, len(rollouts)))
        if len(rollouts) == 0:
            return
        new_rollouts = []
        for rollout in rollouts:
            new_rollout = {
                "obs": rollout["ob"],
                "actions": rollout["ac"],
                "dones": rollout["done"],
            }
            new_rollouts.append(new_rollout)

        logger.warn("[*] Generating demo: {}".format(file_name))
        with open(file_name, "wb") as f:
            pickle.dump(new_rollouts, f) # a list of dic-lists
        
        self._ood_trajectories = []
    """

    def _save_video(self, fname, frames, is_ood=False, fps=15.0):
        """
        Renders ``frames`` into a video file named ``fname`` and returns its path.

        The file goes under ``self._config.ood_record_dir`` when ``is_ood`` is
        true, otherwise under ``self._config.record_dir``. The clip lasts
        ``len(frames) / fps + 2`` seconds. Frames are sampled at the rate
        ``1 / (1/fps + 1/len(frames))``, and the last frame is held once the
        index runs past the end.
        """
        target_dir = self._config.ood_record_dir if is_ood else self._config.record_dir
        path = os.path.join(target_dir, fname)

        logger.warn("[*] Generating video: {}".format(path))

        def make_frame(t):
            n_frames = len(frames)
            sample_rate = 1.0 / (1.0 / fps + 1.0 / n_frames)
            # clamp so the tail of the clip repeats the final frame
            return frames[min(int(t * sample_rate), n_frames - 1)]

        clip = mpy.VideoClip(make_frame, duration=len(frames) / fps + 2)
        clip.write_videofile(path, fps, verbose=False)

        logger.warn("[*] Video saved: {}".format(path))
        return path
