"""
Runs rollouts (RolloutRunner class) and collects transitions using Rollout class.
"""

from collections import defaultdict
import gym
import numpy as np
import cv2, torch

from utils.info_dict import Info
from utils.gym_env import get_non_absorbing_state, zero_value
from utils.general import cat_dict_numpy_rollout
import copy
from environments.metaworld.policies import *
from utils.pytorch import center_crop_images
# import random, pickle, wandb, os
# import matplotlib.pyplot as plt
# from time import time
# from utils.logger import logger

class Rollout(object):
    """
    Per-transition buffer used while collecting rollouts.

    ``add`` appends each value of a transition dict to a per-key list;
    ``get`` hands those lists back as a batch and starts a fresh, empty buffer.
    """

    def __init__(self):
        """ Start with an empty buffer. """
        self._history = defaultdict(list)

    def add(self, data):
        """ Append every value of the transition dict @data under its own key. """
        for name, val in data.items():
            self._history[name].append(val)

    def get(self):
        """
        Return the buffered transitions as a dict of lists and reset the buffer.

        The standard keys are always present (as empty lists when nothing was
        recorded for them); "log_pis" and "id" are included only when non-empty.
        """
        history = self._history
        always_present = (
            "ob",
            "ob_next",
            "ac",
            "ac_before_activation",
            "done",
            "done_mask",
            "rew",
            "success",
            "traj_start_indicator",
        )
        batch = {name: history[name] for name in always_present}
        for name in ("log_pis", "id"):
            values = history[name]
            if values:
                batch[name] = values

        # hand the lists over to the caller and start an empty buffer
        self._history = defaultdict(list)
        return batch


class RolloutRunner(object):
    """
    Run rollout given environment and policy.
    """

    def __init__(self, config, env, env_eval, pi):
        """
        Args:
            config: configurations for the environment.
            env: environment used for training rollouts.
            env_eval: environment used for evaluation rollouts.
            pi: policy.
        """

        self._config = config
        self._env = env
        self._env_eval = env_eval
        self._pi = pi

    def run(
        self,
        is_train=True,
        every_steps=None,
        every_episodes=None,
        log_prefix="",
        step=0,
    ):
        """
        Collects trajectories and yields every @every_steps/@every_episodes.

        This is an endless generator: it keeps resetting the environment and
        running episodes, yielding ``(rollout_batch, ep_info_dict)`` whenever the
        step or episode counter reaches a multiple of the requested period.

        Args:
            is_train: whether rollout is for training or evaluation; selects the
                training or evaluation environment and is forwarded to ``pi.act``.
            every_steps: if not None, yields rollouts every @every_steps env steps.
            every_episodes: if not None, yields rollouts every @every_episodes episodes.
            log_prefix: currently unused; kept for interface compatibility.
            step: initial global step counter. Actions are sampled from
                ``env.action_space`` while ``step < config.warm_up_steps``.

        Raises:
            ValueError: if both @every_steps and @every_episodes are None
                (raised on the first ``next()``, since this is a generator).
        """
        if every_steps is None and every_episodes is None:
            raise ValueError("Both every_steps and every_episodes cannot be None")

        config = self._config
        env = self._env if is_train else self._env_eval
        pi = self._pi
        # Learned (IL) reward prediction is disabled for these rollouts. The
        # `il` branches are kept so it can be re-enabled, e.g. with
        # hasattr(pi, "predict_reward").
        il = False

        # initialize rollout buffer
        rollout = Rollout()
        reward_info = Info()
        ep_info = Info()
        episode = 0
        gail_rew, reach_rew, sig_reach = [], [], []
        target_taskID, dis_target_taskID = None, None
        if config.with_taskID:
            target_taskID = np.array(config.target_taskID).reshape(1, -1)

        while True:
            done = False
            ep_len = 0
            ep_rew = 0
            ep_rew_rl = 0
            ep_rew_il = 0

            ob = env.reset()
            # run rollout
            while not done:
                # sample action: random during warm-up, otherwise from the policy
                if step < config.warm_up_steps:
                    ac, ac_before_activation = env.action_space.sample(), 0
                else:
                    ac, ac_before_activation, _ = pi.act(
                        ob, is_train=is_train, target_taskID=target_taskID, return_log_prob=True
                    )

                rollout.add(
                    {"ob": ob, "ac": ac, "ac_before_activation": ac_before_activation}
                )
                # 1 marks the first transition of each episode, 0 otherwise
                rollout.add({"traj_start_indicator": 1 if ep_len == 0 else 0})

                # take a step
                ob, reward, done, info = env.step(ac)

                # for SQIL the environment reward is discarded (stored as 0)
                if config.algo == 'sqil':
                    reward = 0.0
                rollout.add({"ob_next": ob, "success": info["success"]})

                # replace reward
                if il:
                    reward_dict = pi.predict_reward(ob, ac, dis_target_taskID)

                    reward_il = reward_dict["rew"]
                    reward_rl = (1 - config.gail_env_reward) * reward_il + config.gail_env_reward * reward
                    if config.algo in ["reachable_gail", "reachable_gail-v0"]:
                        reach_rew.append(reward_dict["reach"])
                        sig_reach.append(reward_dict["sig_reach"])
                    if config.algo in ["reachable_gail", "reachable_gail-v0", "gail"]:
                        gail_rew.append(reward_dict["gail_rew"])
                else:
                    reward_rl = reward

                if config.target_taskID:
                    rollout.add({"id": config.target_taskID})

                rollout.add({"done": done, "rew": reward})
                step += 1
                ep_len += 1
                ep_rew += reward
                ep_rew_rl += reward_rl
                if il:
                    ep_rew_il += reward_il

                # -1 absorbing, 0 done, 1 not done. An episode that ends at
                # env.max_episode_steps keeps mask 1.
                if done and ep_len < env.max_episode_steps:
                    done_mask = 0
                else:
                    done_mask = 1
                rollout.add({"done_mask": done_mask})

                reward_info.add(info)
                if il:
                    del reward_dict["rew"]
                    reward_info.add(reward_dict)

                if config.absorbing_state and done_mask == 0:
                    absorbing_state = env.get_absorbing_state()
                    absorbing_action = zero_value(env.action_space)
                    # the terminal transition now leads into the absorbing state
                    rollout._history["ob_next"][-1] = absorbing_state
                    absorbing_transition = {
                        "ob": absorbing_state,
                        "ob_next": absorbing_state,
                        "ac": absorbing_action,
                        "ac_before_activation": absorbing_action,
                        "rew": 0.0,
                        "done": 0,
                        "done_mask": -1,  # -1 absorbing, 0 done, 1 not done
                        # Fill every per-transition key so all lists returned by
                        # Rollout.get() stay the same length and index-aligned.
                        "success": info["success"],
                        "traj_start_indicator": 0,
                    }
                    if config.target_taskID:
                        absorbing_transition["id"] = config.target_taskID
                    rollout.add(absorbing_transition)

                if every_steps is not None and step % every_steps == 0:
                    yield rollout.get(), ep_info.get_dict()

            # compute average/sum of information
            ep_info.add({"len": ep_len, "rew": ep_rew, "rew_rl": ep_rew_rl})
            if il:
                ep_info.add({"rew_il": ep_rew_il})
            ep_info.add(reward_info.get_dict(reduction="sum", only_scalar=True))

            episode += 1

            if every_episodes is not None and episode % every_episodes == 0:
                yield rollout.get(), ep_info.get_dict(only_scalar=True)

    def run_episode(self, max_step=10000, is_train=True, record_video=False):
        """
        Runs one episode on the evaluation environment and returns the rollout.

        Args:
            max_step: maximum number of steps of the rollout.
            is_train: forwarded to ``pi.act``. When False and the policy has
                ``predict_reward``, episode IL/RL returns are computed in one
                batched call after the episode.
            record_video: record video frames of the rollout if True.

        Returns:
            Tuple of (rollout batch, episode info dict, list of recorded frames).
        """
        config = self._config
        env = self._env_eval
        pi = self._pi
        il = hasattr(pi, "predict_reward")

        # initialize rollout buffer
        rollout = Rollout()
        reward_info = Info()

        done = False
        target_taskID, dis_target_taskID = None, None
        ep_len = 0
        ep_rew = 0
        ep_rew_rl = 0
        # NOTE: not accumulated per step; the episode IL return is computed
        # after the loop. It is shown as-is in recorded frame captions.
        ep_rew_il = 0
        all_obs = []
        all_acs = []
        if config.with_taskID:
            target_taskID = np.array(config.target_taskID).reshape(1, -1)

        ob = env.reset()

        self._record_frames = []
        if record_video:
            self._store_frame(env, ep_len, ep_rew)

        # run rollout
        while not done and ep_len < max_step:
            # sample action from policy
            ac, ac_before_activation, _ = pi.act(
                ob, is_train=is_train, target_taskID=target_taskID, return_log_prob=False
            )

            rollout.add(
                {"ob": ob, "ac": ac, "ac_before_activation": ac_before_activation}
            )
            if il:
                all_obs.append(ob)
                all_acs.append(ac)

            # take a step
            ob, reward, done, info = env.step(ac)

            rollout.add({"done": done, "rew": reward})
            ep_len += 1
            ep_rew += reward
            ep_rew_rl += reward

            reward_info.add(info)

            should_record = (
                record_video
                and ep_len > config.record_video_start_step
                and (ep_len % config.record_video_every_step == 0 or done or ep_len == max_step)
            )
            if should_record:
                frame_info = info.copy()
                if il:
                    reward_dict = pi.predict_reward(ob, ac, dis_target_taskID, None)
                    reward_il = reward_dict["rew"]
                    reward_rl = (1 - config.gail_env_reward) * reward_il + config.gail_env_reward * reward
                    frame_info.update(
                        {
                            "ep_rew_il": ep_rew_il,
                            "rew_il": reward_il,
                            "rew_rl": reward_rl,
                        }
                    )
                    if config.algo in ["reachable_gail", "reachable_gail-v0"]:
                        rew_info = {
                            "gail_rew": reward_dict["gail_rew"],
                            "sig_gail_rew": reward_dict["sig_gail_rew"],
                            "reach_rew": reward_dict["reach"],
                            "sig_reach": reward_dict["sig_reach"],
                        }
                        frame_info = {**rew_info, **frame_info}

                self._store_frame(env, ep_len, ep_rew, frame_info)

        # add last observation
        rollout.add({"ob": ob})

        # compute average/sum of information
        ep_info = {"len": ep_len, "rew": ep_rew, "rew_rl": ep_rew_rl}
        if not is_train and il:
            temp_all_obs = cat_dict_numpy_rollout(all_obs)
            if config.encoder_type == "cnn" and len(temp_all_obs["ob"].shape) == 4:
                temp_all_obs["ob"] = center_crop_images(temp_all_obs["ob"], config.encoder_image_size)

            reward_dict = pi.predict_reward(temp_all_obs, cat_dict_numpy_rollout(all_acs), dis_target_taskID, None)
            ep_rew_il = np.sum(reward_dict.pop("rew"))
            # sum_t[(1 - w) * il_t + w * env_t] == (1 - w) * sum(il) + w * sum(env).
            # Use the episode env return (ep_rew) rather than pairing every IL
            # reward with only the final step's env reward.
            ep_info["rew_rl"] = (1 - config.gail_env_reward) * ep_rew_il + config.gail_env_reward * ep_rew
            ep_info["rew_il"] = ep_rew_il

            reward_info.add(reward_dict)
        ep_info.update(reward_info.get_dict(reduction="sum", only_scalar=True))
        ep_info["success_rate"] = info["success"]

        return rollout.get(), ep_info, self._record_frames

    def _store_frame(self, env, ep_len, ep_rew, info=None):
        """
        Renders a frame and appends it to @self._record_frames.

        A black panel the same size as the rendered image is appended below it.
        When ``config.record_video_caption`` is set, the step index, the episode
        return and every @info entry are drawn into that panel.

        Args:
            env: environment rendered with ``env.render(mode='rgb_array')``.
            ep_len: current episode length, shown in the caption.
            ep_rew: current episode return, shown in the caption.
            info: optional dict of extra values to print in the caption.
        """
        if info is None:
            info = {}

        # render video frame
        frame = env.render(mode='rgb_array')

        # a 4-D render output is treated as a batch; keep the first frame
        if len(frame.shape) == 4:
            frame = frame[0]
        # Rescale frames in [0, 1] to [0, 255]. This is done out of place so an
        # integer frame (e.g. an all-black uint8 image) does not fail on in-place
        # float multiplication, and the renderer's buffer is left untouched.
        if np.max(frame) <= 1.0:
            frame = frame * 255.0

        h, w = frame.shape[:2]
        # caption panel below the image (assumes a 3-channel frame)
        frame = np.concatenate([frame, np.zeros((h, w, 3))], 0)
        scale = h / 400

        # add caption to video frame
        if self._config.record_video_caption:
            text = "{:4} {}".format(ep_len, ep_rew)
            font_size = 0.4 * scale
            thickness = 1
            offset = int(12 * scale)
            x, y = int(5 * scale), h + int(10 * scale)
            cv2.putText(
                frame,
                text,
                (x, y),
                cv2.FONT_HERSHEY_SIMPLEX,
                font_size,
                (255, 255, 0),
                thickness,
                cv2.LINE_AA,
            )
            for i, (k, v) in enumerate(info.items()):
                # float precision
                if isinstance(v, float):
                    v = "{:.3f}".format(v)
                key_text = "{}: ".format(k)
                (key_width, _), _ = cv2.getTextSize(
                    key_text, cv2.FONT_HERSHEY_SIMPLEX, font_size, thickness
                )
                line_y = y + offset * (i + 2)

                cv2.putText(
                    frame,
                    key_text,
                    (x, line_y),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    font_size,
                    (66, 133, 244),
                    thickness,
                    cv2.LINE_AA,
                )

                cv2.putText(
                    frame,
                    str(v),
                    (x + key_width, line_y),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    font_size,
                    (255, 255, 255),
                    thickness,
                    cv2.LINE_AA,
                )

        self._record_frames.append(frame)
