import os

import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import (
    FigureCanvasAgg as FigureCanvas,
)
import matplotlib.animation as animation
import tempfile
import cv2
from collections import defaultdict

import global_context
from garagei.replay_buffer.path_buffer_ex import PathBufferEx
from garage import TrajectoryBatch
from garagei import log_performance_ex
from iod import sac_utils
from iod.iod import IOD
import copy

from iod.utils import (
    get_torch_concat_obs,
    FigManager,
    get_option_colors,
    record_video,
    draw_2d_gaussians,
    plot_visualize_rewards,
)

from iod import apt_utils
from sklearn.manifold import TSNE
import pickle


class AgentWrapper(object):
    """Bundle of the policies (and optional encoder) that act in the environment.

    ``policies`` must be a dict containing ``"default_policy"``. It may also
    contain ``"exploration_policy"``, ``"start_default_policy"``,
    ``"hierarchical_policy"`` and ``"encoder"``; missing optional entries are
    stored as ``None`` and skipped by every method below.

    Parameters are (de)serialized as one flat dict whose keys are the
    component's own parameter names prefixed with the component prefix listed
    in ``_COMPONENTS`` (e.g. ``"default_<name>"``, ``"encoder_<name>"``).
    """

    # (attribute name, key prefix, uses torch state_dict/load_state_dict API).
    # The order fixes the key order of ``get_param_values``, the
    # prefix-matching order of ``set_param_values`` and the order in which
    # components are loaded / switched to eval or train mode.
    # When adding a new component, extend this table and ``__init__``.
    _COMPONENTS = (
        ("default_policy", "default_", False),
        ("exploration_policy", "exploration_", False),
        ("start_default_policy", "start_", False),
        ("hierarchical_policy", "hierarchical_", False),
        ("encoder", "encoder_", True),
    )

    def __init__(self, policies):
        assert isinstance(policies, dict) and "default_policy" in policies
        self.default_policy = policies["default_policy"]
        self.exploration_policy = policies.get("exploration_policy", None)
        self.start_default_policy = policies.get("start_default_policy", None)
        self.hierarchical_policy = policies.get("hierarchical_policy", None)
        self.encoder = policies.get("encoder", None)

    def _present_components(self):
        """Yield ``(prefix, component, uses_state_dict)`` for each set component.

        ``default_policy`` is mandatory and always yielded; optional components
        are yielded only when they are not ``None`` (an explicit ``None`` check,
        so a component that happens to define ``__len__`` is never skipped).
        """
        for attr, prefix, uses_state_dict in self._COMPONENTS:
            component = getattr(self, attr)
            if attr == "default_policy" or component is not None:
                yield prefix, component, uses_state_dict

    def get_actions(self, obs, use_exploration_policy=False):
        """Not implemented for the wrapper.

        Raises:
            NotImplementedError: always.
        """
        # The previous body had unreachable code after this raise (with unused
        # locals and a return that ignored ``use_exploration_policy``); it has
        # been removed.
        raise NotImplementedError

    def get_param_values(self):
        """Return the detached parameters of all components as one flat dict.

        Policies contribute ``get_param_values()``; the encoder contributes its
        ``state_dict()``. Every key is prefixed with its component's prefix.
        """
        param_dict = {}
        for prefix, component, uses_state_dict in self._present_components():
            params = (
                component.state_dict()
                if uses_state_dict
                else component.get_param_values()
            )
            for name, value in params.items():
                param_dict[f"{prefix}{name}"] = value.detach()
        return param_dict

    def set_param_values(self, state_dict):
        """Inverse of :meth:`get_param_values`.

        Splits ``state_dict`` by key prefix and loads each slice into its
        component. Slices addressed to components that are ``None`` are
        ignored.

        Raises:
            ValueError: if a key matches none of the known prefixes. All keys
                are checked before any component is loaded.
        """
        slices = {prefix: {} for _, prefix, _ in self._COMPONENTS}
        for key, value in state_dict.items():
            for prefix, bucket in slices.items():
                if key.startswith(prefix):
                    bucket[key[len(prefix):]] = value
                    break
            else:
                raise ValueError(f"Unknown key: {key}")

        for prefix, component, uses_state_dict in self._present_components():
            if uses_state_dict:
                component.load_state_dict(slices[prefix])
            else:
                component.set_param_values(slices[prefix])

    def eval(self):
        """Switch every present component to evaluation mode."""
        for _, component, _ in self._present_components():
            component.eval()

    def train(self):
        """Switch every present component to training mode."""
        for _, component, _ in self._present_components():
            component.train()

    def reset(self):
        """Reset the default, exploration and start policies (when present).

        The hierarchical policy and the encoder are not reset.
        """
        self.default_policy.reset()
        if self.exploration_policy is not None:
            self.exploration_policy.reset()
        if self.start_default_policy is not None:
            self.start_default_policy.reset()


class METRA(IOD):

    def __init__(
        self,
        *,
        qf1,
        qf2,
        log_alpha,
        tau,
        scale_reward,
        target_coef,
        replay_buffer,
        min_buffer_size,
        inner,
        num_alt_samples,
        split_group,
        dual_reg,
        dual_slack,
        dual_dist,
        pixel_shape=None,
        use_pure_rewards=False,  # added
        model=None,  # added
        use_random_options_for_exploration=0,  # added
        dot_penalty=None,
        option_freq=0,
        exploration_type=0,
        perpendicular=0,
        apt=0,
        icm=None,
        exploration_policy=None,
        exploration_qf1=None,
        exploration_qf2=None,
        exploration_log_alpha=None,
        use_traj_for_apt_rep=0,
        use_start_policy=0,
        hierarchical_qf1=None,
        hierarchical_dim=0,
        prevupd=0,
        prevupd_dim_exp_option=0,
        prevupd_traj_encoder=None,
        prevupd_dual_lam=None,
        prevupd_dual_lam_constraint=None,
        prevupd_freq=-1,
        prevupd_use_same_policy=0,
        debug_noconst=0,
        debug_val=-1,
        no_plot=0,
        hilp=0,
        hilp_expectile=0.95,
        hilp_traj_encoder_tau=0.005,
        hilp_qrl=0,
        hilp_p_trajgoal=0.625,
        hilp_discount=0.99,
        use_double_encoder=0,
        traj_encoder_2=None,
        dual_lam_2=None,
        goal_reaching=0,
        use_goal_checker=0,
        frame_stack=None,
        visualize_rewards=0,
        exploration_sac_discount=0.99,
        use_target_traj=0,
        use_repelling=0,
        knn_k=12,
        rnd=None,
        disagreement=None,
        **kwargs,
    ):
        """Build the METRA algorithm on top of ``IOD``.

        All arguments are keyword-only. Anything not named in the signature is
        forwarded to ``IOD.__init__`` through ``**kwargs``.

        Optional features are switched on by flags:
          * ``apt``: an exploration SAC agent (``exploration_policy``,
            ``exploration_qf1/2``, ``exploration_log_alpha``). Its intrinsic
            reward source is chosen by ``use_traj_for_apt_rep``, ``icm``,
            ``disagreement`` and ``rnd``, checked in that order.
          * ``hierarchical_qf1``: a hierarchical Q-network over
            ``hierarchical_dim`` choices.
          * ``use_start_policy``: keeps a deep copy of ``option_policy``.
          * ``prevupd``: a separate exploration agent with its own trajectory
            encoder, dual variables and split agent/exploration buffers.
          * ``hilp``: HILP settings, optionally with a second trajectory
            encoder (``traj_encoder_2`` / ``dual_lam_2``).

        The policies that act in the environment are bundled at the end into
        ``self.policy_for_agent`` (an ``AgentWrapper``).

        ``use_target_traj`` and ``use_repelling`` are accepted but not used in
        this constructor.
        """
        super().__init__(**kwargs)
        self.debug_noconst = debug_noconst
        self.debug_val = debug_val

        self.qf1 = qf1.to(self.device)
        self.qf2 = qf2.to(self.device)

        # Target critics start as exact copies of the online critics.
        self.target_qf1 = copy.deepcopy(self.qf1)
        self.target_qf2 = copy.deepcopy(self.qf2)

        self.log_alpha = log_alpha.to(self.device)

        self.param_modules.update(
            qf1=self.qf1,
            qf2=self.qf2,
            log_alpha=self.log_alpha,
        )

        self.tau = tau

        self.replay_buffer: PathBufferEx = replay_buffer
        self.min_buffer_size = min_buffer_size
        self.inner = inner

        self.dual_reg = dual_reg
        self.dual_slack = dual_slack
        self.dual_dist = dual_dist

        self.num_alt_samples = num_alt_samples
        self.split_group = split_group

        self._reward_scale_factor = scale_reward
        # SAC target entropy: -prod(action_shape) / 2 * target_coef.
        self._target_entropy = (
            -np.prod(self._env_spec.action_space.shape).item() / 2.0 * target_coef
        )

        self.pixel_shape = pixel_shape

        assert self._trans_optimization_epochs is not None

        self.use_pure_rewards = use_pure_rewards
        self.model = model
        self.use_random_options_for_exploration = use_random_options_for_exploration
        self.dot_penalty = dot_penalty
        self.option_freq = option_freq
        self.exploration_type = exploration_type
        self.perpendicular = perpendicular

        self.apt = apt
        self.icm = icm
        self.rnd = rnd
        self.disagreement = disagreement
        # Exploration agent. Note: self.use_traj_for_apt_rep and self.pbe are
        # only defined when apt is enabled.
        if self.apt:
            assert exploration_policy is not None
            assert exploration_qf1 is not None
            assert exploration_qf2 is not None
            assert exploration_log_alpha is not None

            self.exploration_policy = exploration_policy.to(self.device)
            self.exploration_qf1 = exploration_qf1.to(self.device)
            self.exploration_qf2 = exploration_qf2.to(self.device)
            self.target_exploration_qf1 = copy.deepcopy(self.exploration_qf1)
            self.target_exploration_qf2 = copy.deepcopy(self.exploration_qf2)
            self.exploration_log_alpha = exploration_log_alpha.to(self.device)
            self.use_traj_for_apt_rep = use_traj_for_apt_rep

            if self.use_traj_for_apt_rep or self.icm:
                # particle-based entropy
                rms = apt_utils.RMS(self.device)
                if self.use_traj_for_apt_rep:
                    knn_clip = 0.0001
                    knn_k = knn_k  # adjusting this controls temperature of goals
                    knn_avg = True
                    knn_rms = False  # since rms should be only used for calculating apt reward
                    self.pbe = apt_utils.METRAPBE(
                        rms, knn_clip, knn_k, knn_avg, knn_rms, self.device
                    )
                else:
                    knn_clip = 0.0
                    knn_k = knn_k
                    knn_avg = True
                    knn_rms = False
                    self.pbe = apt_utils.PBE(
                        rms, knn_clip, knn_k, knn_avg, knn_rms, self.device
                    )

            # Register the exploration modules together with whichever
            # intrinsic-reward module is in use.
            if self.use_traj_for_apt_rep:
                assert self.icm is None and self.rnd is None
                self.param_modules.update(
                    exploration_policy=self.exploration_policy,
                    exploration_qf1=self.exploration_qf1,
                    exploration_qf2=self.exploration_qf2,
                    exploration_log_alpha=self.exploration_log_alpha,
                )
            elif self.icm:
                assert self.icm is not None
                self.param_modules.update(
                    icm=self.icm,
                    exploration_policy=self.exploration_policy,
                    exploration_qf1=self.exploration_qf1,
                    exploration_qf2=self.exploration_qf2,
                    exploration_log_alpha=self.exploration_log_alpha,
                )

                self.icm.train()
            elif self.disagreement:
                assert self.disagreement is not None
                self.param_modules.update(
                    disagreement=self.disagreement,
                    exploration_policy=self.exploration_policy,
                    exploration_qf1=self.exploration_qf1,
                    exploration_qf2=self.exploration_qf2,
                    exploration_log_alpha=self.exploration_log_alpha,
                )

                self.disagreement.train()
            else:
                # RND branch.
                # NOTE(review): the asserts and module setup below repeat what
                # was already done at the top of `if self.apt:`. The target
                # exploration Q copies are re-created from the online critics.
                assert self.rnd is not None

                self.rnd = rnd

                assert exploration_policy is not None
                assert exploration_qf1 is not None
                assert exploration_qf2 is not None
                assert exploration_log_alpha is not None

                self.exploration_policy = exploration_policy.to(self.device)
                self.exploration_qf1 = exploration_qf1.to(self.device)
                self.exploration_qf2 = exploration_qf2.to(self.device)
                self.target_exploration_qf1 = copy.deepcopy(self.exploration_qf1)
                self.target_exploration_qf2 = copy.deepcopy(self.exploration_qf2)
                self.exploration_log_alpha = exploration_log_alpha.to(self.device)

                self.param_modules.update(
                    rnd=self.rnd,
                    exploration_policy=self.exploration_policy,
                    exploration_qf1=self.exploration_qf1,
                    exploration_qf2=self.exploration_qf2,
                    exploration_log_alpha=self.exploration_log_alpha,
                )

                self.rnd.train()
                self.rnd_scale = 1.0

                # Running statistics whose variance normalizes the RND
                # prediction error (used by get_random_goals).

                self.intrinsic_reward_rms = apt_utils.RMS(self.device)

        self.hierarchical = hierarchical_qf1 is not None
        if self.hierarchical:
            self.hierarchical_qf1 = hierarchical_qf1

            self.hierarchical_dim = hierarchical_dim
            # The Q-network is told its own output size and the option size.
            self.hierarchical_qf1.hierarchical_dim = hierarchical_dim
            self.hierarchical_qf1.dim_option = self.dim_option

            self.hierarchical_qf1 = hierarchical_qf1.to(self.device)
            self.target_hierarchical_qf1 = copy.deepcopy(self.hierarchical_qf1)

            self.param_modules.update(
                hierarchical_qf1=self.hierarchical_qf1,
            )

        self.use_start_policy = use_start_policy
        if self.use_start_policy > 0:
            # Separate copy of the option policy, initialized from it.
            self.start_option_policy = copy.deepcopy(self.option_policy)
            self.upd = 0

        # Policies handed to the sampler via AgentWrapper (built at the end).
        policy_for_agent = {
            "default_policy": self.option_policy,
        }

        if self.use_start_policy:
            policy_for_agent.update(
                {
                    "start_default_policy": self.start_option_policy,
                }
            )
        if self.apt:
            policy_for_agent.update(
                {
                    "exploration_policy": self.exploration_policy,
                }
            )
        if self.hierarchical:
            policy_for_agent.update(
                {
                    "hierarchical_policy": self.hierarchical_qf1,
                }
            )

        self.prevupd = prevupd

        if self.prevupd:
            assert exploration_policy is not None
            assert exploration_qf1 is not None
            assert exploration_qf2 is not None
            assert exploration_log_alpha is not None

            self.exploration_policy = exploration_policy.to(self.device)
            self.exploration_qf1 = exploration_qf1.to(self.device)
            self.exploration_qf2 = exploration_qf2.to(self.device)
            self.target_exploration_qf1 = copy.deepcopy(self.exploration_qf1)
            self.target_exploration_qf2 = copy.deepcopy(self.exploration_qf2)
            self.exploration_log_alpha = exploration_log_alpha.to(self.device)
            self.param_modules.update(
                exploration_policy=self.exploration_policy,
                exploration_qf1=self.exploration_qf1,
                exploration_qf2=self.exploration_qf2,
                exploration_log_alpha=self.exploration_log_alpha,
            )

            self.prevupd_traj_encoder = prevupd_traj_encoder.to(self.device)
            self.prevupd_dual_lam = prevupd_dual_lam.to(self.device)
            self.prevupd_dual_lam_constraint = prevupd_dual_lam_constraint.to(
                self.device
            )
            self.param_modules.update(
                {
                    "prevupd_traj_encoder": self.prevupd_traj_encoder,
                    "prevupd_dual_lam": self.prevupd_dual_lam,
                    "prevupd_dual_lam_constraint": self.prevupd_dual_lam_constraint,
                }
            )

            self.epoch_cnt = 0
            self.prevupd_freq = prevupd_freq
            self.prevupd_dim_exp_option = prevupd_dim_exp_option
            assert self.prevupd_freq > 0 and self.prevupd_dim_exp_option > 0

            self.prevupd_use_same_policy = prevupd_use_same_policy

            # prevupd and apt both register an exploration policy; they are
            # mutually exclusive.
            assert "exploration_policy" not in policy_for_agent
            policy_for_agent.update(
                {
                    "exploration_policy": self.exploration_policy,
                }
            )
            # Separate buffers filled by _update_replay_buffer according to
            # each path's exploration_type.
            self.agent_replay_buffer = copy.deepcopy(replay_buffer)
            self.exploration_replay_buffer = copy.deepcopy(replay_buffer)
        else:
            self.agent_replay_buffer = None
            self.exploration_replay_buffer = None

        # NOTE(review): agent_replay_buffer is None unless prevupd is set, so
        # debug_noconst == 5 without prevupd raises AttributeError here.
        if self.debug_noconst == 5:
            self.agent_replay_buffer._capacity = 12000
        self.no_plot = no_plot

        self.hilp = hilp
        if self.hilp:
            self.hilp_expectile = hilp_expectile
            self.hilp_traj_encoder_tau = hilp_traj_encoder_tau
            self.hilp_qrl = hilp_qrl
            self.hilp_p_trajgoal = hilp_p_trajgoal
            self.hilp_discount = hilp_discount
            self.use_double_encoder = use_double_encoder

            if traj_encoder_2 is not None:
                self.traj_encoder_2 = traj_encoder_2.to(self.device)
                self.dual_lam_2 = dual_lam_2.to(self.device)
                assert "traj_encoder" in self.param_modules
                self.param_modules.update(
                    {
                        "traj_encoder_2": self.traj_encoder_2,
                        "dual_lam_2": self.dual_lam_2,
                    }
                )

            # NOTE(review): self.traj_encoder_2 is only assigned above when
            # traj_encoder_2 is not None. Unless the base class defines it,
            # hilp with hilp_qrl == 0 and no traj_encoder_2 fails here.
            if not self.hilp_qrl:
                self.target_traj_encoder = copy.deepcopy(self.traj_encoder)
                self.target_traj_encoder_2 = copy.deepcopy(self.traj_encoder_2)

        self.goal_reaching = goal_reaching
        if self.goal_reaching and use_goal_checker:
            policy_for_agent.update(
                {
                    "encoder": self.traj_encoder,
                }
            )

        self.policy_for_agent = AgentWrapper(
            policies=policy_for_agent
        )  # this should be at the end

        self.frame_stack = frame_stack

        self.visualize_rewards = visualize_rewards
        self.exploration_sac_discount = exploration_sac_discount
        self.plot_first_2dims = True

        # NOTE(review): redundant; self.disagreement was already assigned
        # earlier in this constructor.
        self.disagreement = disagreement

    @property
    def policy(self):
        return {"option_policy": self.policy_for_agent}

    def _get_concat_obs(self, obs, option):
        """Combine ``obs`` and ``option`` into one policy input tensor.

        Thin delegate to ``get_torch_concat_obs``.
        """
        combined = get_torch_concat_obs(obs, option)
        return combined

    def _generate_option_extras_list(self, options):
        """
        >>> a = np.zeros((3,4))
        >>> list(a)
        [array([0., 0., 0., 0.]), array([0., 0., 0., 0.]), array([0., 0., 0., 0.])]
        """
        return [{"option": list(option)} for option in options]

    def _get_train_trajectories_kwargs(self, runner, options=None):
        """Build the keyword arguments for collecting training trajectories.

        Args:
            runner: training runner. ``runner._train_args.batch_size`` sets
                how many trajectories' extras are sampled when ``options`` is
                None.
            options: optional options, one per trajectory. When given, extras
                are built directly from them with ``_generate_option_extras``.

        Returns:
            dict with ``extras`` (a list with one dict per trajectory) and
            ``sampler_key="option_policy"``. Depending on the configuration,
            this method also sets ``"option"`` (goal reaching),
            ``"exploration_type"``, ``"exp_option"`` (prevupd) and
            ``"use_start_default_policy"`` in each extras dict.

        Raises:
            NotImplementedError: if ``options`` is given while
                ``option_freq > 0``, or for continuous options combined with
                a hierarchical policy.

        NOTE: the np.random draws below happen in a fixed order; keep it
        unchanged to preserve seeded reproducibility.
        """
        if self.option_freq > 0 and options is not None:
            raise NotImplementedError
        if options is None:
            if self.discrete:
                # Discrete options: sample one-hot vectors uniformly.
                if self.hierarchical:
                    extras = self._generate_option_extras(
                        np.eye(self.hierarchical_dim)[
                            np.random.randint(
                                0, self.hierarchical_dim, runner._train_args.batch_size
                            )
                        ]
                    )
                elif self.prevupd:
                    extras = self._generate_option_extras(
                        np.eye(self.dim_option)[
                            np.random.randint(
                                0, self.dim_option, runner._train_args.batch_size
                            )
                        ]
                    )
                elif self.option_freq > 0:
                    # Draw batch_size * max_path_length // option_freq one-hot
                    # options, then repeat each option_freq times along time,
                    # so every option is held for option_freq steps.
                    options = np.eye(self.dim_option)[
                        np.random.randint(
                            0,
                            self.dim_option,
                            runner._train_args.batch_size
                            * self.max_path_length
                            // self.option_freq,
                        )
                    ]
                    options = options.reshape(
                        runner._train_args.batch_size, -1, self.dim_option
                    )
                    options = options.repeat(self.option_freq, axis=1)
                    assert options.shape[1] == self.max_path_length
                    assert options.ndim == 3
                    assert options.shape[0] == runner._train_args.batch_size
                    assert options.shape[2] == self.dim_option
                    extras = self._generate_option_extras_list(options)
                elif self.option_freq < 0:  # randomly switching skills with probability
                    batch_size = runner._train_args.batch_size
                    T = self.max_path_length
                    # One candidate one-hot option per (trajectory, step).
                    options = np.eye(self.dim_option)[
                        np.random.randint(
                            0,
                            self.dim_option,
                            batch_size * T,
                        )
                    ]
                    options = options.reshape(batch_size, T, self.dim_option)
                    # prob[t] = 1 / (T - t); the last step has probability 1.
                    prob = 1 / np.arange(T, 0, -1)

                    switching_point_1 = np.random.random((batch_size, T)) < prob
                    switching_point_2 = np.random.random((batch_size, T)) < prob

                    # mask is True at steps before the first switching_point_2.
                    mask = switching_point_2.cumsum(axis=-1) == 0  # false points
                    assert mask.ndim == 2  # batch_size, T
                    # While mask is True, the option index is the number of
                    # switching_point_1 events so far; afterwards it is T - 1.
                    random_skills_indices = (
                        switching_point_1.cumsum(axis=-1) * mask + (T - 1) * ~mask
                    )

                    options = options[
                        np.arange(batch_size)[:, None], random_skills_indices
                    ]
                    assert options.ndim == 3
                    assert options.shape[0] == runner._train_args.batch_size
                    assert options.shape[1] == self.max_path_length
                    assert options.shape[2] == self.dim_option
                    extras = self._generate_option_extras_list(options)

                    # Per-step flag: True on the steps before the first
                    # switching_point_2.
                    if self.use_start_policy:
                        for i, extra in enumerate(extras):
                            extra["use_start_default_policy"] = mask[i].tolist()
                else:
                    # NOTE(review): unreachable; option_freq > 0 is handled
                    # by an earlier branch.
                    if self.option_freq > 0:
                        raise NotImplementedError
                    extras = self._generate_option_extras(
                        np.eye(self.dim_option)[
                            np.random.randint(
                                0, self.dim_option, runner._train_args.batch_size
                            )
                        ]
                    )

            else:
                # Continuous options: standard normal, optionally unit-normalized.
                if self.hierarchical:
                    raise NotImplementedError
                random_options = np.random.randn(
                    runner._train_args.batch_size, self.dim_option
                )
                if self.unit_length:
                    random_options /= np.linalg.norm(
                        random_options, axis=-1, keepdims=True
                    )
                extras = self._generate_option_extras(random_options)

            # Goal reaching: overwrite each option with a goal picked by
            # get_random_goals.
            if self.goal_reaching:
                goals = self.get_random_goals(runner._train_args.batch_size)
                assert goals.ndim == 2  # for state spaces
                for i, extra in enumerate(extras):
                    extra["option"] = goals[i]
        else:  # options is not None
            extras = self._generate_option_extras(options)

        if self.exploration_type > 0:
            if self.exploration_type == 10:
                # Every trajectory gets exploration_type 10.
                for i, extra in enumerate(extras):
                    extra["exploration_type"] = 10
            else:
                if (
                    self.replay_buffer.n_transitions_stored
                    < runner._train_args.batch_size
                ):
                    # Buffer still small: all trajectories get type 15.
                    for i, extra in enumerate(extras):
                        extra["exploration_type"] = (
                            15  # use exploration policy all the time
                        )
                else:
                    # Even-indexed trajectories get the configured
                    # exploration_type; odd-indexed ones get 0.
                    for i, extra in enumerate(extras):
                        if i % 2 == 0:
                            extra["exploration_type"] = self.exploration_type
                        else:
                            extra["exploration_type"] = 0

        if self.prevupd:
            # One random one-hot exploration option per trajectory.
            exp_options = np.eye(self.prevupd_dim_exp_option)[
                np.random.randint(
                    0,
                    self.prevupd_dim_exp_option,
                    len(extras),  # should not be runner._train_args.batch_size,
                )
            ]
            for i, extra in enumerate(extras):
                extra["exp_option"] = exp_options[i]

        return dict(
            extras=extras,
            sampler_key="option_policy",
        )

    def get_random_goals(self, size, num_batch=None):
        """Get `size` number of goals from buffer.

        Samples ``num_batch`` transitions, scores each one with the configured
        intrinsic reward (ICM representation + PBE, RND, disagreement, or PBE
        on trajectory-encoder embeddings), and returns the ``next_obs`` of the
        ``size`` highest-scoring transitions, best first.

        Args:
            size: number of goals to return. It must not exceed ``num_batch``;
                the shape assert on ``indices`` fails otherwise.
            num_batch: number of transitions to score. Defaults to
                ``self._trans_minibatch_size``.

        Returns:
            numpy array of ``size`` goals. While the buffer holds fewer than
            ``num_batch`` transitions, returns ``self.initial_state`` repeated
            ``size`` times instead.

        Raises:
            NotImplementedError: if none of the reward branches applies.
                Note that ``self.use_traj_for_apt_rep`` is only set in
                ``__init__`` when ``apt`` is enabled; otherwise reaching that
                check raises AttributeError instead.
        """
        if num_batch is None:
            num_batch = self._trans_minibatch_size
        if self.replay_buffer.n_transitions_stored < num_batch:
            goals = self.initial_state[None, :].repeat(size, axis=0)
            assert goals.ndim == 2 and goals.shape[0] == size
            return goals

        # get samples as tensors
        samples = self.replay_buffer.sample_transitions(num_batch)
        data = {}
        for key, value in samples.items():
            data[key] = torch.from_numpy(value).float().to(self.device)
        with torch.no_grad():
            # calculate reward of each goals
            if self.icm:
                obs = data["obs"]
                actions = data["actions"]
                rep = self.icm.get_rep(obs, actions)
                reward = self.pbe(rep).squeeze(-1)
                assert reward.ndim == 1

            elif self.rnd:
                obs = data["obs"]
                next_obs = data["next_obs"]  # NOTE(review): unused in this branch
                # NOTE(review): this result is overwritten below, so this forward
                # pass looks redundant. Only remove it after confirming
                # self.rnd's forward has no side effects (e.g. normalization
                # statistics updated in train mode).
                reward = self.rnd(obs).squeeze(-1)
                prediction_error = self.rnd(obs)
                # Normalize the prediction error by the running std kept in
                # intrinsic_reward_rms (this call also updates its statistics).
                _, intr_reward_var = self.intrinsic_reward_rms(prediction_error)
                reward = (
                    self.rnd_scale
                    * prediction_error
                    / (torch.sqrt(intr_reward_var) + 1e-8)
                )

                reward = reward.flatten()
                assert reward.ndim == 1 and reward.shape == (num_batch,)
            elif self.disagreement:
                obs = data["obs"]
                next_obs = data["next_obs"]
                action = data["actions"]

                reward = self.disagreement.get_disagreement(obs, action, next_obs)
                reward = reward.flatten()
                assert reward.ndim == 1

            elif self.use_traj_for_apt_rep:
                # PBE reward of each next_obs embedding against the same batch.
                z1 = self.traj_encoder(data["next_obs"]).mean
                reward = self.pbe.get_reward(z1, z1).squeeze(-1)
            else:
                raise NotImplementedError
        assert reward.ndim == 1
        # select top `size` goals
        indices = torch.argsort(reward, descending=True)[:size].cpu().numpy()
        assert indices.shape == (size,)
        goals = samples["next_obs"][indices]

        # TODO: consider diversity of the goals
        return goals

    def _flatten_data(self, data):
        epoch_data = {}
        for key, value in data.items():
            epoch_data[key] = torch.tensor(
                np.concatenate(value, axis=0), dtype=torch.float32, device=self.device
            )
        return epoch_data

    def _update_replay_buffer(self, data):
        # refactor data, key, i to data, i, key
        assert self.replay_buffer is not None
        paths = []
        # Add paths to the replay buffer
        for i in range(len(data["actions"])):
            path = {}
            for key in data.keys():
                cur_list = data[key][i]
                if cur_list.ndim == 1:
                    cur_list = cur_list[..., np.newaxis]
                path[key] = cur_list
            paths.append(path)

        for path in paths:
            exploration_type = path.pop("exploration_type")
            if self.agent_replay_buffer is not None:
                if self.debug_noconst == 4:
                    self.exploration_replay_buffer.add_path(path)
                else:
                    if exploration_type > 0:
                        if self.exploration_replay_buffer is not None:
                            self.exploration_replay_buffer.add_path(path)
                    elif exploration_type == 0:
                        self.agent_replay_buffer.add_path(path)

            self.replay_buffer.add_path(path)

    def _sample_replay_buffer(self, replay_buffer, only_first_and_last=False):
        if self.hilp:
            if self.goal_reaching:
                samples = replay_buffer.sample_transitions_with_goals(
                    self._trans_minibatch_size,
                    p_trajgoal=self.hilp_p_trajgoal,
                    discount=self.hilp_discount,
                )
                trajgoals = samples["trajgoals"]

                if self.debug_noconst == 11:
                    # ===== trajgoals ===== (quite slow)
                    N, K, dim = trajgoals.shape
                    trajgoals = trajgoals.reshape(N * K, -1)
                    # now we are comparing between this trajgoals and normal samples
                    samples_unif = replay_buffer.sample_transitions(
                        self._trans_minibatch_size
                    )["obs"]

                    trajgoals = torch.from_numpy(trajgoals).float().to(self.device)
                    samples_unif = (
                        torch.from_numpy(samples_unif).float().to(self.device)
                    )

                    # use top reward goals
                    with torch.no_grad():
                        z1 = self.traj_encoder(trajgoals).mean
                        z2 = self.traj_encoder(samples_unif).mean
                        rewards = self.pbe.get_reward(z1, z2)
                        assert rewards.shape == (N * K, 1)
                        rewards = rewards.reshape(N, K)
                    trajgoals = samples["trajgoals"][
                        np.arange(len(rewards)),
                        rewards.argmax(axis=1).cpu().numpy(),
                    ]
                    assert trajgoals.shape == (N, dim)
                    # ==================

                # we are using "goals" for global push loss,
                # and "trajgoals" for goal relabelling.
                samples["options"] = trajgoals
                samples["next_options"] = trajgoals
                del samples["trajgoals"]

                # even if current is cur_exploration, they can be used with future options, so this is wrong.
                # what's importnat is "dones" for exploration policy. We will not have any dones for that.
                # it's okay because it will have less bias (rms)
                samples["dones_exp"] = samples["dones"].copy()
                # note: this is added to deal with gradient explosion
                samples["dones"] = np.maximum(
                    samples["success_rewards"], samples["dones"]
                )
            else:
                samples = replay_buffer.sample_transitions_with_goals(
                    self._trans_minibatch_size, p_trajgoal=self.hilp_p_trajgoal
                )
        else:
            samples = replay_buffer.sample_transitions(
                self._trans_minibatch_size, only_first_and_last=only_first_and_last
            )

        data = {}
        for key, value in samples.items():
            if value.shape[1] == 1 and "option" not in key:
                value = np.squeeze(value, axis=1)
            data[key] = torch.from_numpy(value).float().to(self.device)

        return data

    def _train_once_inner(self, path_data):
        if self.debug_noconst != 6 and self.debug_noconst != 40:
            self._update_replay_buffer(path_data)

        epoch_data = self._flatten_data(path_data)
        assert "obs" in epoch_data
        tensors = self._train_components(epoch_data)

        return tensors

    def _generate_rollout_once(self, v):
        """Augment a sampled minibatch in place with 5-step model rollouts.

        Resets ``self.model`` to the simulator states in ``v["sim_states"]``
        and steps it 5 times starting from ``v["obs"]``. When
        ``use_random_options_for_exploration`` is set, actions come from the
        option policy conditioned on one random one-hot option per batch
        element. Otherwise they come from ``self.model.action_space.sample()``.

        The rollout is written back into ``v``:
        - ``exp_obs``, ``exp_actions``, ``exp_next_obs``: per-step tensors
          concatenated along dim 0 (5x the batch size).
        - ``exp_options`` / ``exp_next_options``: ``v["options"]`` tiled 5 times.
        - ``exp_random_options`` / ``exp_next_random_options``: the sampled
          one-hot options tiled 5 times.
        - ``exp_dones``: zeros, 5x the length of ``v["dones"]``.

        Args:
            v: minibatch dict of tensors. Must contain ``sim_states``,
                ``obs``, ``options``, ``actions`` and a 1-D ``dones``.
        """
        assert self.model
        sim_states = v["sim_states"]

        self.model.reset()
        success = self.model.sim_set_state(sim_states.cpu().numpy())
        assert success

        cur_obs = v["obs"]
        obs_list = []
        next_obs_list = []
        actions_list = []
        used_options: torch.Tensor
        # One random one-hot option per batch element (NumPy's global RNG).
        options_for_exploration = np.eye(self.dim_option)[
            np.random.randint(0, self.dim_option, v["options"].shape[0])
        ]
        options_for_exploration = v["options"].new_tensor(options_for_exploration)
        assert options_for_exploration.shape == v["options"].shape

        # NOTE(review): these options are recorded as exp_random_options even in
        # the random-action branch below, where they never influence the
        # actions -- confirm this is intended.
        used_options = options_for_exploration
        for i in range(5):  # rollout for 5 steps
            if self.use_random_options_for_exploration:
                processed_cat_obs = self._get_concat_obs(
                    self.option_policy.process_observations(cur_obs),
                    options_for_exploration,
                )

                assert processed_cat_obs.shape[0] == cur_obs.shape[0]
                assert processed_cat_obs.shape[1] == cur_obs.shape[1] + self.dim_option
                actions, agent_infos = self.option_policy.get_actions(processed_cat_obs)
            else:  # random actions
                # Presumably the model's action_space is batched, so sample()
                # yields one action per environment -- TODO confirm.
                random_actions = np.stack(self.model.action_space.sample(), axis=0)

                actions = random_actions
            assert actions.shape[0] == cur_obs.shape[0]
            assert actions.ndim == 2

            next_obs, _, _, _ = self.model.step(actions)
            # Convert back to a tensor with cur_obs's dtype/device.
            next_obs = cur_obs.new_tensor(next_obs)
            assert next_obs.shape == cur_obs.shape

            obs_list.append(cur_obs)
            actions_list.append(v["actions"].new_tensor(actions))
            next_obs_list.append(next_obs)

            cur_obs = next_obs

        # Steps are stacked block-wise (step 0 batch, step 1 batch, ...), which
        # matches the block-wise tiling produced by .repeat(5, 1) below.
        v["exp_obs"] = torch.concat(obs_list, dim=0)
        assert v["exp_obs"].shape[0] == len(v["options"]) * 5
        v["exp_actions"] = torch.concat(actions_list, dim=0)
        v["exp_next_obs"] = torch.concat(next_obs_list, dim=0)
        v["exp_options"] = v["options"].repeat(5, 1)
        v["exp_next_options"] = v["options"].repeat(5, 1)
        v["exp_random_options"] = used_options.repeat(5, 1)
        v["exp_next_random_options"] = used_options.repeat(5, 1)
        assert v["dones"].ndim == 1
        # Rollout transitions are never marked terminal.
        v["exp_dones"] = torch.zeros_like(v["dones"]).repeat(5)

    def _train_components_hilp(self, epoch_data):
        assert self.hilp

        tensors = {}
        for _ in range(self._trans_optimization_epochs):
            v = self._sample_replay_buffer(self.replay_buffer)
            self._optimize_te_hilp(tensors, v)
            if self.apt and self.icm:
                self._update_icm(tensors, v)
            elif self.apt and self.rnd:
                self._update_rnd(tensors, v)
            elif self.apt and self.disagreement:
                self._update_disagreement(tensors, v)
            with torch.no_grad():
                self._update_rewards(tensors, v)
                if self.use_traj_for_apt_rep or self.icm:
                    self._update_apt_rewards(tensors, v)
                elif self.rnd:
                    self._update_rnd_rewards(tensors, v)
                elif self.disagreement:
                    self._update_disagreement_rewards(tensors, v)
            if not self.debug_noconst == 20:
                self._optimize_exploration_policy(tensors, v)

            self._optimize_op(tensors, v)

        return tensors

    def _update_disagreement(self, tensors, internal_vars):
        obs = internal_vars["obs"]
        next_obs = internal_vars["next_obs"]
        action = internal_vars["actions"]

        error = self.disagreement(obs, action, next_obs)

        loss = error.mean()

        tensors.update(
            {
                "disagreement_loss": loss,
            }
        )

        self._gradient_descent(
            tensors["disagreement_loss"],
            optimizer_keys=["disagreement"],
        )

    def _update_disagreement_rewards(self, tensors, v):
        obs = v["obs"]
        next_obs = v["next_obs"]
        actions = v["actions"]

        rewards = self.disagreement.get_disagreement(obs, actions, next_obs).unsqueeze(
            1
        )

        rewards = rewards.flatten()
        assert rewards.ndim == 1 and rewards.shape[0] == obs.shape[0]

        tensors.update(
            {
                "RndRewardMean": rewards.mean(),
                "RndRewardStd": rewards.std(),
            }
        )

        v["exploration_rewards"] = rewards

        return rewards

    def _update_rnd(self, tensors, internal_vars):
        obs = internal_vars["obs"]
        next_obs = internal_vars["next_obs"]
        action = internal_vars["actions"]

        prediction_error = self.rnd(obs)

        loss = prediction_error.mean()

        tensors.update(
            {
                "rnd_loss": loss,
            }
        )

        self._gradient_descent(
            tensors["rnd_loss"],
            optimizer_keys=["rnd"],
        )

    def _update_rnd_rewards(self, tensors, v):
        obs = v["obs"]
        next_obs = v["next_obs"]
        actions = v["actions"]

        prediction_error = self.rnd(obs)
        _, intr_reward_var = self.intrinsic_reward_rms(prediction_error)
        rewards = (
            self.rnd_scale * prediction_error / (torch.sqrt(intr_reward_var) + 1e-8)
        )

        rewards = rewards.flatten()
        assert rewards.ndim == 1 and rewards.shape[0] == obs.shape[0]

        tensors.update(
            {
                "RndRewardMean": rewards.mean(),
                "RndRewardStd": rewards.std(),
            }
        )

        v["exploration_rewards"] = rewards

        return rewards

    def _train_components(self, epoch_data):
        if (
            self.replay_buffer is not None
            and self.replay_buffer.n_transitions_stored < self.min_buffer_size
        ):
            return {}
        if self.hilp:
            return self._train_components_hilp(epoch_data)

        for _ in range(self._trans_optimization_epochs):
            tensors = {}  # to keep track of losses

            if self.replay_buffer is None:
                v = self._get_mini_tensors(epoch_data)
            else:
                v = self._sample_replay_buffer(self.replay_buffer)

            if self.model:
                self._generate_rollout_once(v)

            if self.hilp:
                # test if HILP representation learning works
                self._optimize_te_hilp(tensors, v)

                if self.debug_noconst == 6:
                    continue
                self._update_rewards(tensors, v)
                if self.apt:
                    if self.rnd is None:
                        with torch.no_grad():
                            self._update_apt_rewards(tensors, v)
                    else:
                        with torch.no_grad():
                            self._update_rnd_rewards(tensors, v)

                    assert v["exploration_rewards"].shape == v["rewards"].shape, (
                        v["exploration_rewards"].shape,
                        v["rewards"].shape,
                    )
                    self._optimize_exploration_policy(tensors, v)

            elif self.use_pure_rewards == 1:
                pass
            elif self.use_pure_rewards == 2:
                self._update_rewards(tensors, v)
                self._optimize_te(tensors, v)
                ext_rewards = v["rewards"].clone()
                self._update_rewards(tensors, v)
                if ext_rewards is v["rewards"]:
                    print(
                        f"{ext_rewards.mean()}, {ext_rewards.sum()}, {ext_rewards.shape}"
                    )  # assert two are different
                    print("maybe something's wrong")
                v["rewards"] += ext_rewards
            elif self.icm:
                self._update_rewards(tensors, v)
                self._optimize_te(tensors, v)
                self._update_rewards(tensors, v)
                self._update_icm(tensors, v)
                with torch.no_grad():
                    self._update_apt_rewards(tensors, v)
                assert v["exploration_rewards"].shape == v["rewards"].shape, (
                    v["exploration_rewards"].shape,
                    v["rewards"].shape,
                )
                self._optimize_exploration_policy(tensors, v)
            elif self.hierarchical:
                self._update_rewards(tensors, v)
                self._optimize_te(tensors, v)
                self._update_rewards(
                    tensors, v
                )  # why is this duplicated? probably redundant
                self._optimize_hierarchical(tensors, v)
                v["options"] = v["low_options"]
                v["next_options"] = v["next_low_options"]
                assert (
                    v["options"].ndim == 2 and v["options"].shape[1] == self.dim_option
                )
                # check if options are used well. Seems good.
            elif self.prevupd:
                assert self.dual_reg
                use_agent_v = (
                    self.agent_replay_buffer.n_transitions_stored
                    >= self.min_buffer_size
                )
                use_exp_v = (
                    self.exploration_replay_buffer.n_transitions_stored
                    >= self.min_buffer_size
                )
                if self.debug_noconst == 1:
                    use_agent_v = False

                do_prevupd_update = use_agent_v or use_exp_v
                if do_prevupd_update:
                    # for updating the constraints to zero mapping, we do not need rewards
                    if (
                        self.debug_noconst == 3
                    ):  # debug_noconst == 3 => multiple agent update
                        repeat_cnt = self.debug_val
                    else:
                        repeat_cnt = 1
                    for _ in range(repeat_cnt):
                        if use_agent_v:
                            # we do not use options, so we do not care option dims
                            if self.debug_noconst == 2:  # only first and last
                                agent_v = self._sample_replay_buffer(
                                    self.agent_replay_buffer, only_first_and_last=True
                                )
                                self._update_loss_te_prevupd(
                                    tensors, agent_v, constraint_only=True
                                )
                                self._update_loss_prevupd_dual_lam(
                                    tensors, agent_v, constraint_only=True
                                )
                                tensors["ConstraintPrevupdLossTe"] = (
                                    tensors["ConstraintPrevupdLossTe"] * 0.1 * 0.1
                                )
                                tensors["ConstraintPrevupdLossTe"] = (
                                    tensors["ConstraintPrevupdLossTe"] * 0.1 * 0.1
                                )
                            else:
                                agent_v = self._sample_replay_buffer(
                                    self.agent_replay_buffer
                                )
                                self._update_loss_te_prevupd(
                                    tensors, agent_v, constraint_only=True
                                )
                                self._update_loss_prevupd_dual_lam(
                                    tensors, agent_v, constraint_only=True
                                )
                        if use_exp_v:  # half-half.
                            exp_v = self._sample_replay_buffer(
                                self.exploration_replay_buffer
                            )
                            self._update_rewards(tensors, v, exp_v=exp_v)
                            assert "prevupd_rewards" in exp_v, exp_v.keys()
                            self._update_loss_te_prevupd(
                                tensors, exp_v, constraint_only=False
                            )
                            self._update_loss_prevupd_dual_lam(
                                tensors, exp_v, constraint_only=False
                            )
                        else:
                            self._update_rewards(tensors, v)

                        prevupd_loss_te = None
                        if use_agent_v and not use_exp_v:
                            prevupd_loss_te = tensors["PrevupdLossTe"]
                        elif use_exp_v and not use_agent_v:
                            prevupd_loss_te = tensors["PrevupdLossTe"]
                        elif use_exp_v and use_agent_v:
                            prevupd_loss_te = (
                                tensors["PrevupdLossTe"]
                                + tensors["ConstraintPrevupdLossTe"]
                            )

                        self._gradient_descent(
                            prevupd_loss_te,
                            optimizer_keys=["prevupd_traj_encoder"],
                            clip_grad=True,
                        )
                        if use_agent_v:  # agent_v: old buffer
                            self._gradient_descent(
                                tensors["ConstraintPrevupdLossDualLam"],
                                optimizer_keys=["prevupd_dual_lam_constraint"],
                                clip_grad=False,
                            )
                        if use_exp_v:  # exp_v: new buffer
                            self._gradient_descent(
                                tensors["PrevupdLossDualLam"],
                                optimizer_keys=["prevupd_dual_lam"],
                                clip_grad=False,
                            )

                            assert "exploration_rewards" in exp_v
                            self._optimize_exploration_policy(
                                tensors,
                                exp_v,  # TODO: this may be "v"
                            )  # optimizing exploration policy should happen with v
                    assert "rewards" in v
                    self._optimize_te(tensors, v)
                    self._update_rewards(tensors, v)  # i think we don't need this?

            else:
                self._update_rewards(tensors, v)
                self._optimize_te(tensors, v)
                self._update_rewards(
                    tensors, v
                )  # i think we don't need this? probably to use optimized te?
            self._optimize_op(tensors, v)

            if self.use_start_policy:
                self._update_start_option_policy()

        if self.prevupd and (
            self.exploration_type == 10
            or self.debug_noconst == 4
            or self.debug_noconst == 5
        ):
            self._update_prev_replay_buffer()

        return tensors

    def get_dist_from_encoder(self, encoder, obs, goals):
        """Return the per-row Euclidean distance between encoder embeddings.

        ``obs`` and ``goals`` are each encoded via ``encoder(x).mean``. The
        squared distance is floored at 1e-6 before the square root, so
        (near-)identical embeddings give distance 1e-3 and zero gradient
        instead of an infinite sqrt gradient. With ``debug_noconst == 32``
        the squared distance is returned instead. The result is 1-D.
        """
        embed_diff = encoder(obs).mean - encoder(goals).mean
        sq = embed_diff.pow(2).sum(axis=-1)
        assert sq.ndim == 1 and sq.shape[0] == obs.shape[0]

        if self.debug_noconst == 32:
            dist = sq
        else:
            dist = torch.sqrt(torch.maximum(sq, torch.full_like(sq, 1e-6)))

        assert dist.ndim == 1
        return dist

    def _update_target_traj_encoders(self):
        """Update parameters in the target q-functions."""
        target_traj_encoders = [self.target_traj_encoder, self.target_traj_encoder_2]
        traj_encoders = [self.traj_encoder, self.traj_encoder_2]
        for target_traj_encoder, traj_encoder in zip(
            target_traj_encoders, traj_encoders
        ):
            for t_param, param in zip(
                target_traj_encoder.parameters(), traj_encoder.parameters()
            ):
                t_param.data.copy_(
                    t_param.data * (1.0 - self.hilp_traj_encoder_tau)
                    + param.data * self.hilp_traj_encoder_tau
                )

    def _update_loss_te_qrl(self, tensors, internal_vals, traj_encoder, dual_lam):
        obs = internal_vals["obs"]
        next_obs = internal_vals["next_obs"]
        goals = internal_vals["goals"]
        masks = internal_vals["masks"]
        success_rewards = internal_vals["success_rewards"]
        assert (
            success_rewards.ndim == masks.ndim == 1
            and success_rewards.shape[0] == obs.shape[0]
        ), success_rewards.shape

        phi_x, phi_y, phi_g = torch.split(
            traj_encoder(torch.cat([obs, next_obs, goals], dim=0)).mean, len(obs)
        )
        squared_dist = ((phi_x - phi_g) ** 2).sum(axis=-1)  # double V network is used
        dist = torch.sqrt(
            torch.maximum(squared_dist, torch.full_like(squared_dist, 1e-6))
        )  # this makes same next current state have gradient zero

        # masks are 0 if terminal, 1 otherwise
        masks = 1.0 - success_rewards
        # rewards are 0 if terminal, -1 otherwise
        success_rewards = success_rewards - 1.0  # actually we do not use this

        cst_dist = torch.ones_like(success_rewards)
        cst_penalty = cst_dist - torch.square(phi_y - phi_x).mean(dim=1)
        cst_penalty = torch.clamp(cst_penalty, max=self.dual_slack)

        dual_lam = dual_lam.param.exp()

        if self.debug_noconst == 32:
            te_obj = -dist.mean() + (dual_lam.detach() * cst_penalty).mean()
        elif self.debug_noconst == 35:
            dist_cur = self.get_dist_from_encoder(traj_encoder, next_obs, goals)
            te_obj = (dist - dist_cur).mean() - (dual_lam.detach() * cst_penalty).mean()
        elif self.debug_noconst == 70:
            te_obj = (
                -torch.nn.functional.softplus(500 - dist, beta=0.01).mean()
                + (dual_lam.detach() * cst_penalty).mean()
                - (
                    (traj_encoder.get_rep(obs) - traj_encoder.get_rep(next_obs)) ** 2
                ).mean()
            )

        else:
            te_obj = (
                -torch.nn.functional.softplus(500 - dist, beta=0.01).mean()
                + (dual_lam.detach() * cst_penalty).mean()
            )  # rewards + cst_penalty

        internal_vals.update({"cst_penalty": cst_penalty})
        tensors.update(
            {
                "DualCstPenalty": cst_penalty.mean(),
            }
        )

        loss_te = -te_obj

        tensors.update(
            {
                "TeObjMean": te_obj.mean(),
                "LossTe": loss_te,
            }
        )

    def _update_loss_te_hilp(self, tensors, internal_vals):
        """Build the HILP expectile value loss for the two trajectory encoders.

        The goal-conditioned value is the negative embedding distance
        ``-||phi(s) - phi(g)||``, with the squared distance floored at 1e-6.
        It is evaluated with ``traj_encoder`` / ``traj_encoder_2`` and their
        target copies.

        Rewards are ``success_rewards - 1`` and masks are
        ``1 - success_rewards``; assuming ``success_rewards`` is 0/1, these
        give reward 0 / mask 0 on success and reward -1 / mask 1 otherwise.

        Each online value ``v_i`` is regressed onto its own target-encoder
        bootstrap ``q_i = r + discount * mask * next_v_i`` using an expectile
        loss. The loss is weighted by the sign of a shared advantage:
        ``min(next_v1, next_v2)``-based Q minus the mean target value.

        Writes ``tensors["LossTe"]`` (sum of both expectile losses) and
        diagnostic scalars into ``tensors``; returns nothing.
        """
        obs = internal_vals["obs"]
        next_obs = internal_vals["next_obs"]
        goals = internal_vals["goals"]
        masks = internal_vals["masks"]
        success_rewards = internal_vals["success_rewards"]
        assert (
            success_rewards.ndim == masks.ndim == 1
            and success_rewards.shape[0] == obs.shape[0]
        ), success_rewards.shape

        # code from HILP
        # The incoming masks are only shape-checked above; they are recomputed here.
        # masks are 0 if terminal, 1 otherwise
        masks = 1.0 - success_rewards
        # rewards are 0 if terminal, -1 otherwise
        success_rewards = success_rewards - 1.0

        def get_v_from_encoder(encoder, obs, goals):
            # Value = negative Euclidean embedding distance (1-D, one per row).
            phi_s = encoder(obs).mean
            phi_g = encoder(goals).mean
            squared_dist = ((phi_s - phi_g) ** 2).sum(
                axis=-1
            )  # double V network is used
            assert squared_dist.ndim == 1 and squared_dist.shape[0] == obs.shape[0]
            # Floor before the sqrt to avoid an infinite gradient at zero distance.
            v = -torch.sqrt(
                torch.maximum(squared_dist, torch.full_like(squared_dist, 1e-6))
            )
            assert v.ndim == 1
            return v

        def expectile_loss(adv, diff, expectile=0.7):
            # Asymmetric squared loss: weight `expectile` where adv >= 0,
            # `1 - expectile` otherwise.
            # 0.95 is used for Hilbert representation expectile, 0.7 for visual kitchen
            weight = torch.where(adv >= 0, expectile, 1 - expectile)
            assert weight.ndim == 1
            assert diff.shape == weight.shape
            return weight * (diff**2)

        with torch.no_grad():
            # Bootstrap targets come from the target encoders (soft-updated
            # elsewhere with hilp_traj_encoder_tau) and carry no gradient.
            # note that target smoothing coefficient is 0.005
            next_v1 = get_v_from_encoder(self.target_traj_encoder, next_obs, goals)
            next_v2 = get_v_from_encoder(self.target_traj_encoder_2, next_obs, goals)
            next_v = torch.minimum(next_v1, next_v2)
            q = success_rewards + self.discount * masks * next_v

            v1_t = get_v_from_encoder(self.target_traj_encoder, obs, goals)
            v2_t = get_v_from_encoder(self.target_traj_encoder_2, obs, goals)
            v_t = (v1_t + v2_t) / 2
            adv = q - v_t

        # Each online encoder regresses onto its own target encoder's bootstrap.
        q1 = success_rewards + self.discount * masks * next_v1
        q2 = success_rewards + self.discount * masks * next_v2
        assert q1.ndim == q2.ndim == 1
        v1 = get_v_from_encoder(self.traj_encoder, obs, goals)
        v2 = get_v_from_encoder(self.traj_encoder_2, obs, goals)
        assert v1.ndim == v2.ndim == 1
        v = (v1 + v2) / 2

        value_loss1 = expectile_loss(adv, q1 - v1, self.hilp_expectile).mean()
        value_loss2 = expectile_loss(adv, q2 - v2, self.hilp_expectile).mean()

        value_loss = value_loss1 + value_loss2
        loss_te = value_loss

        tensors.update(
            {
                "masks": masks.mean(),
                "success_rewards": success_rewards.mean(),
                "next_v1": next_v1.mean(),
                "next_v2": next_v2.mean(),
                "next_v": next_v.mean(),
                "q": q.mean(),
                "v1_t": v1_t.mean(),
                "v2_t": v2_t.mean(),
                "v_t": v_t.mean(),
                "q1": q1.mean(),
                "q2": q2.mean(),
                "value_loss": value_loss.mean(),
                "v max": v.max(),
                "v min": v.min(),
                "v mean": v.mean(),
                "abs adv mean": torch.abs(adv).mean(),
                "adv mean": adv.mean(),
                "adv max": adv.max(),
                "adv min": adv.min(),
                "accept prob": (adv >= 0).float().mean(),
                "LossTe": loss_te,
            }
        )

    def _optimize_te_hilp(self, tensors, internal_vars):
        if self.hilp_qrl:
            self._update_loss_te_qrl(
                tensors, internal_vars, self.traj_encoder, self.dual_lam
            )
            self._gradient_descent(
                tensors["LossTe"], optimizer_keys=["traj_encoder"], clip_grad=False
            )
            self._update_loss_dual_lam(tensors, internal_vars, self.dual_lam)
            self._gradient_descent(
                tensors["LossDualLam"],
                optimizer_keys=["dual_lam"],
                clip_grad=False,
            )
            if self.use_double_encoder:
                self._update_loss_te_qrl(
                    tensors, internal_vars, self.traj_encoder_2, self.dual_lam_2
                )
                self._gradient_descent(
                    tensors["LossTe"],
                    optimizer_keys=["traj_encoder_2"],
                    clip_grad=False,
                )
                self._update_loss_dual_lam(tensors, internal_vars, self.dual_lam_2)
                self._gradient_descent(
                    tensors["LossDualLam"],
                    optimizer_keys=["dual_lam_2"],
                    clip_grad=False,
                )

        else:
            self._update_loss_te_hilp(tensors, internal_vars)

            self._gradient_descent(
                tensors["LossTe"],
                optimizer_keys=["traj_encoder"],
                clip_grad=False,
            )
            self._update_target_traj_encoders()

    def _update_prev_replay_buffer(self):
        """Update parameters in the target q-functions."""
        self.epoch_cnt += 1
        if self.epoch_cnt % self.prevupd_freq == 0:  # for `prevupd_freq` iterations:
            print(
                f"updating prev_replay_buffer: current size: {self.replay_buffer.n_transitions_stored}"
            )
            print(
                f"updating exploration_replay_buffer: current size: {self.exploration_replay_buffer.n_transitions_stored}"
            )
            print(
                f"updating agent_replay_buffer: current size: {self.agent_replay_buffer.n_transitions_stored}"
            )
            # or we can do something like tau, later.
            if self.debug_noconst == 5:
                # self.exploration_replay_buffer =
                pass
            else:
                self.agent_replay_buffer = copy.deepcopy(self.exploration_replay_buffer)

    def _optimize_hierarchical(self, tensors, internal_vars):
        self._update_loss_qf(tensors, internal_vars, hierarchical=True)

        self._gradient_descent(
            tensors["hierarchical_LossQf1"],
            optimizer_keys=["hierarchical_qf"],
        )

    def _update_start_option_policy(self):
        """Update parameters in the target q-functions."""
        self.upd += 1
        tau_update = self.use_start_policy == 2
        if tau_update:
            target_option_policy = [self.start_option_policy]
            option_policy = [self.option_policy]
            for target_policy, policy in zip(target_option_policy, option_policy):
                for t_param, param in zip(
                    target_policy.parameters(), policy.parameters()
                ):
                    t_param.data.copy_(
                        t_param.data * (1.0 - self.tau) + param.data * self.tau
                    )

        else:
            if self.upd % (50 * 500) == 0:
                target_option_policy = [self.start_option_policy]
                option_policy = [self.option_policy]
                for target_policy, policy in zip(target_option_policy, option_policy):
                    for t_param, param in zip(
                        target_policy.parameters(), policy.parameters()
                    ):
                        t_param.data.copy_(param.data)

    def _update_icm(self, tensors, internal_vars):
        obs = internal_vars["obs"]
        next_obs = internal_vars["next_obs"]
        action = internal_vars["actions"]

        forward_error, backward_error = self.icm(obs, action, next_obs)

        loss = forward_error.mean() + backward_error.mean()

        tensors.update(
            {
                "icm_loss": loss,
            }
        )

        self._gradient_descent(
            tensors["icm_loss"],
            optimizer_keys=["icm"],
        )

    def _optimize_te(self, tensors, internal_vars):
        self._update_loss_te(tensors, internal_vars)

        self._gradient_descent(
            tensors["LossTe"], optimizer_keys=["traj_encoder"], clip_grad=False
        )

        if self.dual_reg:
            self._update_loss_dual_lam(tensors, internal_vars)
            self._gradient_descent(
                tensors["LossDualLam"],
                optimizer_keys=["dual_lam"],
            )
            if self.dual_dist == "s2_from_s":
                self._gradient_descent(
                    tensors["LossDp"],
                    optimizer_keys=["dist_predictor"],
                )

    def _optimize_op(self, tensors, internal_vars):
        """Run one SAC update of the option policy.

        In order:
        1. both Q-functions (``LossQf1 + LossQf2`` on ``"qf"``);
        2. the policy (``LossSacp`` on ``"option_policy"``);
        3. the temperature (``LossAlpha`` on ``"log_alpha"``);
        4. the target Q-networks via ``sac_utils.update_targets``.
        """
        self._update_loss_qf(tensors, internal_vars)
        critic_loss = tensors["LossQf1"] + tensors["LossQf2"]
        self._gradient_descent(critic_loss, optimizer_keys=["qf"])

        self._update_loss_op(tensors, internal_vars)
        actor_loss = tensors["LossSacp"]
        self._gradient_descent(actor_loss, optimizer_keys=["option_policy"])

        self._update_loss_alpha(tensors, internal_vars)
        alpha_loss = tensors["LossAlpha"]
        self._gradient_descent(alpha_loss, optimizer_keys=["log_alpha"])

        sac_utils.update_targets(
            self, self.qf1, self.qf2, self.target_qf1, self.target_qf2
        )

    def _optimize_exploration_policy(self, tensors, internal_vars):
        """Run one SAC update of the exploration policy (no gradient clipping).

        The same sequence as the option-policy update, applied to the
        ``exploration_*`` losses, optimizers and networks:
        1. Q-functions;
        2. policy;
        3. temperature;
        4. target Q-networks via ``sac_utils.update_targets``.
        """
        self._update_loss_qf(tensors, internal_vars, exploration=True)
        critic_loss = tensors["exploration_LossQf1"] + tensors["exploration_LossQf2"]
        self._gradient_descent(
            critic_loss, optimizer_keys=["exploration_qf"], clip_grad=False
        )

        self._update_loss_op(tensors, internal_vars, exploration=True)
        actor_loss = tensors["exploration_LossSacp"]
        self._gradient_descent(
            actor_loss, optimizer_keys=["exploration_policy"], clip_grad=False
        )

        self._update_loss_alpha(tensors, internal_vars, exploration=True)
        alpha_loss = tensors["exploration_LossAlpha"]
        self._gradient_descent(
            alpha_loss, optimizer_keys=["exploration_log_alpha"], clip_grad=False
        )

        sac_utils.update_targets(
            self,
            self.exploration_qf1,
            self.exploration_qf2,
            self.target_exploration_qf1,
            self.target_exploration_qf2,
        )

    def _update_apt_rewards(self, tensors, v):
        obs = v["obs"]
        next_obs = v["next_obs"]
        actions = v["actions"]

        if self.use_traj_for_apt_rep:
            rep, next_rep = v["cur_z"], v["next_z"]
            rewards = self.pbe(rep, next_rep, use_rms=True)
            if self.use_double_encoder:
                if not self.use_traj_for_apt_rep:
                    raise NotImplementedError
                rep_2, next_rep_2 = torch.split(
                    self.traj_encoder_2(torch.cat([obs, next_obs], dim=0)).mean,
                    len(obs),
                )
                rewards_2 = self.pbe(rep_2, next_rep=next_rep_2)
                rewards = (rewards + rewards_2) / 2

        else:
            assert self.icm
            rep = self.icm.get_rep(obs, actions)
            rewards = self.pbe(rep)

        rewards = rewards.flatten()
        assert rewards.ndim == 1 and rewards.shape[0] == obs.shape[0]

        tensors.update(
            {
                "AptRewardMean": rewards.mean(),
                "AptRewardStd": rewards.std(),
            }
        )

        v["exploration_rewards"] = rewards

    def _update_rewards(self, tensors, v, exp_v=None):
        obs = v["obs"]
        next_obs = v["next_obs"]
        options = v["options"]

        if "exp_obs" in v:
            obs = torch.cat([obs, v["exp_obs"]], dim=0)
            next_obs = torch.cat([next_obs, v["exp_next_obs"]], dim=0)
            options = torch.cat([options, v["exp_options"]], dim=0)
            if self.use_random_options_for_exploration:
                obs = torch.cat([obs, v["exp_obs"]], dim=0)
                next_obs = torch.cat([next_obs, v["exp_next_obs"]], dim=0)
                options = torch.cat([options, v["exp_random_options"]], dim=0)

        if self.inner:

            def get_rewards(traj_encoder, obs, next_obs, options):
                assert options.ndim == 2
                ### calc target rewards
                if self.goal_reaching:
                    cur_z, next_z, goal_z = torch.split(
                        traj_encoder(torch.cat([obs, next_obs, options], dim=0)).mean,
                        len(obs),
                    )
                    rew = torch.norm(goal_z - cur_z, dim=1) - torch.norm(
                        goal_z - next_z, dim=1
                    )
                    return rew, cur_z, next_z

                cur_z = traj_encoder(obs).mean
                next_z = traj_encoder(next_obs).mean

                target_z = next_z - cur_z

                if self.discrete:
                    dim_option = (
                        options.shape[1]
                        if not self.hierarchical
                        else self.hierarchical_dim
                    )
                    if self.perpendicular:
                        assert dim_option % 2 == 0 and options.ndim == 2
                        masks = torch.zeros_like(options)
                        masks[:, : dim_option // 2] = (
                            options[:, ::2].contiguous() - options[:, 1::2].contiguous()
                        )
                        rewards = (target_z * masks).sum(dim=1)
                        # assert masks.sum(dim=1).allclose(torch.ones_like(masks[:, 0]))
                    else:
                        masks = (
                            (
                                options
                                - (
                                    options.mean(dim=1, keepdim=True)
                                    if dim_option != 1
                                    else 0
                                )
                            )
                            * dim_option
                            / (dim_option - 1 if dim_option != 1 else 1)
                        )
                        rewards = (target_z * masks).sum(dim=1)
                else:
                    if self.perpendicular:
                        raise NotImplementedError
                    inner = (target_z * options).sum(dim=1)
                    rewards = inner

                return rewards, cur_z, next_z

            rewards, cur_z, next_z = get_rewards(
                self.traj_encoder, obs, next_obs, options
            )

            if self.debug_noconst == 123:
                rewards = v["success_rewards"] - 1
                v["dones"] = v["success_rewards"]

            if self.hilp and self.use_double_encoder:
                rewards_2, cur_z, next_z = get_rewards(
                    self.traj_encoder_2, obs, next_obs, options
                )
                rewards = (rewards + rewards_2) / 2

            # For dual objectives
            v.update(
                {
                    "cur_z": cur_z,
                    "next_z": next_z,
                }
            )

            if exp_v is not None:
                obs = exp_v["obs"]
                next_obs = exp_v["next_obs"]
                exp_options = exp_v["exp_options"]

                prevupd_rewards, prevupd_cur_z, prevupd_next_z = get_rewards(
                    self.prevupd_traj_encoder, obs, next_obs, exp_options
                )
                # we do not use prevupd_cur_z for now, since we need
                # two updates (constraint / no constraint)
                # # since options are options from the true policies,
                # v.update(
                #     {
                #         "prevupd_cur_z": prevupd_cur_z,
                #         "prevupd_next_z": prevupd_next_z,
                #     }
                # )
        else:
            target_dists = self.traj_encoder(next_obs)

            if self.discrete:
                logits = target_dists.mean
                rewards = -torch.nn.functional.cross_entropy(
                    logits, v["options"].argmax(dim=1), reduction="none"
                )
            else:
                rewards = target_dists.log_prob(v["options"])

        tensors.update(
            {
                "PureRewardMean": rewards.mean(),
                "PureRewardStd": rewards.std(),
                "PureRewardMax": rewards.max(),
                "PureRewardMin": rewards.min(),
            }
        )

        v["rewards"] = rewards

        if exp_v is not None:
            tensors.update(
                {
                    "PrevupdPureRewardMean": prevupd_rewards.mean(),
                    "PrevupdPureRewardStd": prevupd_rewards.std(),
                }
            )
            exp_v["prevupd_rewards"] = prevupd_rewards
            exp_v["exploration_rewards"] = prevupd_rewards

    def _update_loss_te(self, tensors, v):
        rewards = v["rewards"]

        obs = v["obs"]
        next_obs = v["next_obs"]
        options = v["options"]
        if "exp_obs" in v:
            obs = torch.cat([obs, v["exp_obs"]], dim=0)
            next_obs = torch.cat([next_obs, v["exp_next_obs"]], dim=0)
            options = torch.cat([options, v["exp_options"]], dim=0)
            if self.use_random_options_for_exploration:
                obs = torch.cat([obs, v["exp_obs"]], dim=0)
                next_obs = torch.cat([next_obs, v["exp_next_obs"]], dim=0)
                options = torch.cat([options, v["exp_random_options"]], dim=0)

        if self.dual_dist == "s2_from_s":
            s2_dist = self.dist_predictor(obs)
            loss_dp = -s2_dist.log_prob(next_obs - obs).mean()
            tensors.update(
                {
                    "LossDp": loss_dp,
                }
            )
        elif self.dual_dist == "quasimetric":
            raise NotImplementedError
        if self.dual_reg:
            dual_lam = self.dual_lam.param.exp()
            x = obs
            y = next_obs
            phi_x = v["cur_z"]
            phi_y = v["next_z"]

            if self.dual_dist == "l2":
                cst_dist = torch.square(y - x).mean(dim=1)
            elif self.dual_dist == "one":
                cst_dist = torch.ones_like(x[:, 0])
            elif self.dual_dist == "s2_from_s":
                s2_dist = self.dist_predictor(obs)
                s2_dist_mean = s2_dist.mean
                s2_dist_std = s2_dist.stddev
                scaling_factor = 1.0 / s2_dist_std
                geo_mean = torch.exp(
                    torch.log(scaling_factor).mean(dim=1, keepdim=True)
                )
                normalized_scaling_factor = (scaling_factor / geo_mean) ** 2
                cst_dist = torch.mean(
                    torch.square((y - x) - s2_dist_mean) * normalized_scaling_factor,
                    dim=1,
                )

                tensors.update(
                    {
                        "ScalingFactor": scaling_factor.mean(dim=0),
                        "NormalizedScalingFactor": normalized_scaling_factor.mean(
                            dim=0
                        ),
                    }
                )
            elif self.dual_dist == "gt":
                initial_state_repeated = (
                    torch.from_numpy(self.initial_state)
                    .to(obs.device)
                    .repeat(obs.shape[0], 1)
                )
                cst_dist = self.dist_predictor(initial_state_repeated, obs) ** 2
            elif self.dual_dist == "quasimetric":
                qrl_agent, qrl_losses = self.dist_predictor
                initial_state = v["initial_obs"]
                results = []
                for critic in qrl_agent.critics:
                    res = critic(initial_state, obs)
                    results.append(res)
                cst_dist = torch.stack(results, dim=1).min(dim=1).values
            else:
                raise NotImplementedError

            assert cst_dist.ndim == 1 and cst_dist.shape[0] == options.shape[0]
            if self.dot_penalty:
                z_dir_dist = ((phi_y - phi_x) * options).sum(dim=1)
                z_dir_dist_sq = torch.clamp(z_dir_dist, min=0.0) ** 2
                assert (
                    z_dir_dist_sq.ndim == 1
                    and z_dir_dist_sq.shape[0] == options.shape[0]
                )
                perpen_point = ((phi_y - phi_x) * options).sum(dim=-1, keepdim=True)
                perpen_point = perpen_point / (
                    (torch.norm(phi_y - phi_x, dim=-1, keepdim=True) + 1e-6)
                    * (torch.norm(options, dim=-1, keepdim=True))
                )
                perpen_point = perpen_point * options
                perpen_dist_sq = torch.square((phi_y - phi_x) - perpen_point).sum(
                    dim=-1
                )
                assert (
                    perpen_dist_sq.ndim == 1
                    and perpen_dist_sq.shape[0] == options.shape[0]
                )

                cst_penalty = cst_dist - (z_dir_dist_sq + perpen_dist_sq)
            else:
                cst_penalty = cst_dist - torch.square(phi_y - phi_x).mean(dim=1)
            cst_penalty = torch.clamp(cst_penalty, max=self.dual_slack)

            te_obj = rewards.mean() + (dual_lam.detach() * cst_penalty).mean()

            v.update({"cst_penalty": cst_penalty})
            tensors.update(
                {
                    "DualCstPenalty": cst_penalty.mean(),
                }
            )
        else:
            te_obj = rewards.mean()

        loss_te = -te_obj

        tensors.update(
            {
                "TeObjMean": te_obj.mean(),
                "LossTe": loss_te,
            }
        )

    def _update_loss_dual_lam(self, tensors, v, dual_lam=None):
        if dual_lam is None:
            dual_lam = self.dual_lam
        log_dual_lam = dual_lam.param
        dual_lam = log_dual_lam.exp()
        loss_dual_lam = log_dual_lam * (v["cst_penalty"].detach()).mean()

        tensors.update(
            {
                "DualLam": dual_lam,
                "LossDualLam": loss_dual_lam,
            }
        )

    def _update_loss_te_prevupd(self, tensors, v, constraint_only=False):
        obs = v["obs"]
        next_obs = v["next_obs"]
        # prevupd always has this
        # we will use only obs here

        if not constraint_only:
            rewards = v["prevupd_rewards"]  # assert this is updated
        else:
            assert "prevupd_rewards" not in v
            rewards = torch.zeros_like(v["rewards"])

        # assert self.traj_encoder not used here
        if self.dual_reg:
            if constraint_only:
                dual_lam = self.prevupd_dual_lam_constraint.param.exp()
            else:
                dual_lam = self.prevupd_dual_lam.param.exp()

            x = obs
            # y = next_obs
            cur_z = self.prevupd_traj_encoder(obs).mean
            next_z = self.prevupd_traj_encoder(next_obs).mean

            phi_x = cur_z
            phi_y = next_z
            phi_zero = torch.zeros_like(phi_x)

            if self.dual_dist == "one":
                cst_dist = torch.ones_like(x[:, 0])
            else:
                raise NotImplementedError

            assert cst_dist.ndim == 1 and cst_dist.shape[0] == obs.shape[0]
            if constraint_only:
                cst_penalty = cst_dist * 0.1 - torch.square(phi_zero - phi_x).mean(
                    dim=1
                )  # probably we do we need cst_dist clipping at all (theoretically)
            else:
                cst_penalty = cst_dist - torch.square(phi_y - phi_x).mean(dim=1)
            cst_penalty = torch.clamp(cst_penalty, max=self.dual_slack)

            te_obj = rewards.mean() + (dual_lam.detach() * cst_penalty).mean()

            v.update(
                {
                    f"{'constraint_' if constraint_only else ''}prevupd_cst_penalty": cst_penalty
                }
            )
            tensors.update(
                {
                    f"{'Constraint' if constraint_only else ''}PrevupdDualCstPenalty": cst_penalty.mean(),
                }
            )
        else:
            raise NotImplementedError

        loss_te = -te_obj

        tensors.update(
            {
                f"{'Constraint' if constraint_only else ''}PrevupdTeObjMean": te_obj.mean(),
                f"{'Constraint' if constraint_only else ''}PrevupdLossTe": loss_te,
            }
        )

    def _update_loss_prevupd_dual_lam(self, tensors, v, constraint_only=False):
        if constraint_only:
            log_prevupd_dual_lam = self.prevupd_dual_lam_constraint.param
        else:
            log_prevupd_dual_lam = self.prevupd_dual_lam.param
        prevupd_dual_lam = log_prevupd_dual_lam.exp()
        if constraint_only:
            loss_prevupd_dual_lam = (
                log_prevupd_dual_lam
                * (v["constraint_prevupd_cst_penalty"].detach()).mean()
            )
        else:
            loss_prevupd_dual_lam = (
                log_prevupd_dual_lam * (v["prevupd_cst_penalty"].detach()).mean()
            )

        tensors.update(
            {
                f"{'Constraint' if constraint_only else ''}PrevupdDualLam": prevupd_dual_lam,
                f"{'Constraint' if constraint_only else ''}PrevupdLossDualLam": loss_prevupd_dual_lam,
            }
        )

    def _update_loss_qf(self, tensors, v, exploration=False, hierarchical=False):
        """Compute SAC critic losses for one minibatch via ``sac_utils``.

        Three mutually exclusive modes:

        * ``exploration``: trains ``exploration_qf1/qf2`` with
          ``sac_utils.update_loss_qf`` on ``v["exploration_rewards"]``, using
          ``v["dones_exp"]`` and ``self.exploration_sac_discount``.
        * ``hierarchical``: trains ``hierarchical_qf1`` with
          ``sac_utils.update_loss_dqf``, using ``v["low_options"]`` as actions
          and ``v["target_rewards"]`` as rewards.
        * default: trains the option-conditioned ``qf1/qf2`` with
          ``sac_utils.update_loss_qf`` on ``v["rewards"]`` and ``self.discount``.

        In every mode, rewards are multiplied by ``self._reward_scale_factor``.

        Side effects:
            Writes ``processed_cat_obs`` / ``next_processed_cat_obs`` into ``v``.
            ``sac_utils`` receives ``tensors`` and ``v`` (presumably to log the
            losses; not visible here). Debug modes ``debug_noconst`` 7 and 8
            also overwrite ``v["options"]``, ``v["next_options"]`` and
            ``v["rewards"]``.
        """
        obs = v["obs"]
        next_obs = v["next_obs"]
        options = v["options"]
        next_options = v["next_options"]
        actions = v["actions"]
        dones = v["dones"]

        if self.debug_noconst == 7:
            # update options and corresponding rewards
            # should update options, next_options
            # Debug mode 7: relabel every transition with a fresh random
            # unit-norm option and recompute the inner-product reward from
            # v["cur_z"] / v["next_z"].
            assert self.discrete == 0
            options = torch.randn_like(options)
            options = options / torch.norm(options, dim=1, keepdim=True)
            next_options = options

            v["options"] = options
            v["next_options"] = next_options
            rewards = ((v["next_z"] - v["cur_z"]) * options).sum(dim=1)
            v["rewards"] = rewards
        elif self.debug_noconst == 8:
            # update options according to the initial and final state
            # Debug mode 8: with probability 0.5 (numpy RNG), relabel the
            # option as the normalized direction phi(obs) - phi(initial_obs).
            assert self.discrete == 0
            if np.random.random() < 0.5:
                initial_obs = v["initial_obs"]
                obs = v["obs"]
                with torch.no_grad():
                    cur_z = self.traj_encoder(initial_obs).mean
                    next_z = self.traj_encoder(obs).mean
                diff = next_z - cur_z
                options = diff / (torch.norm(diff, dim=1, keepdim=True) + 1e-6)

                next_options = options

                v["options"] = (
                    options  # is options used later? yes, for updating the policy.
                )
                v["next_options"] = next_options
                rewards = (diff * options).sum(dim=1)
                v["rewards"] = rewards

        if exploration:
            # The exploration critic does not consume the skill options.
            options = None
            next_options = None
            dones = v["dones_exp"]
            if self.prevupd:
                # With prevupd, the exploration inputs are built like the
                # option policy's: observations concatenated with exp_options.
                exp_options = v["exp_options"]
                next_exp_options = v["next_exp_options"]
                processed_cat_obs = self._get_concat_obs(
                    self.option_policy.process_observations(obs), exp_options
                )
                next_processed_cat_obs = self._get_concat_obs(
                    self.option_policy.process_observations(next_obs), next_exp_options
                )
            else:
                processed_cat_obs = self.exploration_policy.process_observations(obs)
                next_processed_cat_obs = self.exploration_policy.process_observations(
                    next_obs
                )

            sac_utils.update_loss_qf(
                self,
                tensors,
                v,
                obs=processed_cat_obs,
                actions=actions,
                next_obs=next_processed_cat_obs,
                dones=dones,
                rewards=v["exploration_rewards"] * self._reward_scale_factor,
                policy=self.exploration_policy,
                qf1=self.exploration_qf1,
                qf2=self.exploration_qf2,
                log_alpha=self.exploration_log_alpha,
                target_qf1=self.target_exploration_qf1,
                target_qf2=self.target_exploration_qf2,
                description_prefix="exploration_",
                discount=self.exploration_sac_discount,
            )
        elif hierarchical:
            processed_cat_obs = self._get_concat_obs(
                self.option_policy.process_observations(obs), options
            )
            next_processed_cat_obs = self._get_concat_obs(
                self.option_policy.process_observations(next_obs), next_options
            )

            # Actions for the high-level critic are the chosen low-level options.
            sac_utils.update_loss_dqf(
                self,
                tensors,
                v,
                obs=processed_cat_obs,
                actions=v["low_options"],
                next_obs=next_processed_cat_obs,
                dones=dones,
                rewards=v["target_rewards"] * self._reward_scale_factor,
                qf1=self.hierarchical_qf1,
                target_qf1=self.target_hierarchical_qf1,
                description_prefix="hierarchical_",
            )

        else:
            processed_cat_obs = self._get_concat_obs(
                self.option_policy.process_observations(obs), options
            )
            # NOTE(review): v["next_options"] is ignored here; the option is
            # treated as fixed across the transition (see the commented-out
            # dones handling below).
            next_options = options
            next_processed_cat_obs = self._get_concat_obs(
                self.option_policy.process_observations(next_obs), next_options
            )

            # assert dones.ndim == 1, dones.shape
            # dones = dones.bool() | ((options - next_options).abs().sum(dim=-1) > 0.001)
            # dones = dones.float()

            sac_utils.update_loss_qf(
                self,
                tensors,
                v,
                obs=processed_cat_obs,
                actions=actions,
                next_obs=next_processed_cat_obs,
                dones=dones,
                rewards=v["rewards"] * self._reward_scale_factor,
                policy=self.option_policy,
                qf1=self.qf1,
                qf2=self.qf2,
                log_alpha=self.log_alpha,
                target_qf1=self.target_qf1,
                target_qf2=self.target_qf2,
                description_prefix="",
                discount=self.discount,
            )

        # Stored so later updates in the same step can reuse the processed inputs.
        v.update(
            {
                "processed_cat_obs": processed_cat_obs,
                "next_processed_cat_obs": next_processed_cat_obs,
            }
        )

    def _update_loss_op(self, tensors, v, exploration=False):
        """Compute the SAC actor loss through ``sac_utils.update_loss_sacp``.

        When ``v`` contains ``exp_obs``, the exploration observations are
        appended to ``v["obs"]`` and paired with ``exp_options``. If
        ``self.use_random_options_for_exploration`` is set, they are appended
        again and paired with ``exp_random_options``.

        Args:
            tensors: Logging dict forwarded to ``sac_utils``.
            v: Minibatch dict.
            exploration: If true, update ``self.exploration_policy``, whose
                input is the observations concatenated with ``v["exp_options"]``
                when ``self.prevupd`` is set. Otherwise update
                ``self.option_policy`` on the concatenated option batch.
        """
        obs_chunks = [v["obs"]]
        option_chunks = [v["options"]]
        if "exp_obs" in v:
            obs_chunks.append(v["exp_obs"])
            option_chunks.append(v["exp_options"])
            if self.use_random_options_for_exploration:
                obs_chunks.append(v["exp_obs"])
                option_chunks.append(v["exp_random_options"])

        if len(obs_chunks) == 1:
            obs, options = obs_chunks[0], option_chunks[0]
        else:
            obs = torch.cat(obs_chunks, dim=0)
            options = torch.cat(option_chunks, dim=0)

        # why don't we simply use v["processed_cat_obs"]?
        if exploration:
            if self.prevupd:
                policy_input = self._get_concat_obs(
                    self.option_policy.process_observations(obs), v["exp_options"]
                )
            else:
                policy_input = self.exploration_policy.process_observations(obs)
            actor = self.exploration_policy
            alpha_param = self.exploration_log_alpha
            critic_1, critic_2 = self.exploration_qf1, self.exploration_qf2
            prefix = "exploration_"
        else:
            policy_input = self._get_concat_obs(
                self.option_policy.process_observations(obs), options
            )
            actor = self.option_policy
            alpha_param = self.log_alpha
            critic_1, critic_2 = self.qf1, self.qf2
            prefix = ""

        sac_utils.update_loss_sacp(
            self,
            tensors,
            v,
            obs=policy_input,
            policy=actor,
            log_alpha=alpha_param,
            qf1=critic_1,
            qf2=critic_2,
            description_prefix=prefix,
        )

    def _update_loss_alpha(self, tensors, v, exploration=False):
        """Compute the SAC entropy-temperature loss via ``sac_utils``.

        Uses ``self.exploration_log_alpha`` with the ``"exploration_"`` log
        prefix when ``exploration`` is true, otherwise ``self.log_alpha`` with
        no prefix.
        """
        if exploration:
            alpha_param, prefix = self.exploration_log_alpha, "exploration_"
        else:
            alpha_param, prefix = self.log_alpha, ""
        sac_utils.update_loss_alpha(
            self,
            tensors,
            v,
            log_alpha=alpha_param,
            description_prefix=prefix,
        )

    def _evaluate_policy(self, runner):
        if self.visualize_rewards:
            if self.replay_buffer.n_transitions_stored >= self.min_buffer_size:

                def reward_fn(self, obs, actions, next_obs):
                    base_obs = self.replay_buffer.sample_transitions(
                        self._trans_minibatch_size
                    )["obs"]
                    base_obs = torch.from_numpy(base_obs).float().to(self.device)
                    next_obs = torch.from_numpy(next_obs).float().to(self.device)
                    obs = torch.from_numpy(obs).float().to(self.device)
                    actions = torch.from_numpy(actions).float().to(self.device)

                    with torch.no_grad():
                        if self.use_traj_for_apt_rep:
                            base_rep, rep = torch.split(
                                self.traj_encoder(
                                    torch.cat([base_obs, obs], dim=0)
                                ).mean,
                                [len(base_obs), len(obs)],
                            )
                            rewards = self.pbe.get_reward(rep, base_rep, use_rms=False)

                        else:
                            assert self.icm
                            rep = self.icm.get_rep(obs, actions)
                            rewards = self.pbe(rep)

                    return rewards.cpu().numpy()

                plot_visualize_rewards(self, runner, self.replay_buffer, reward_fn)
        if (
            self.goal_reaching
            and hasattr(runner._env, "get_goals")
            and self.env_name == "kitchen"
        ):  # currently only supported for kitchen
            # self._evaluate_goal_reaching_policy(runner)
            pass
        if self.hilp:
            if self.replay_buffer.n_transitions_stored >= self.min_buffer_size:
                for i in range(1):
                    samples = self.replay_buffer.sample_transitions_with_goals(
                        self._trans_minibatch_size, p_trajgoal=self.hilp_p_trajgoal
                    )
                    # plot observations and corresponding phis

                    self.draw_one_plots(
                        runner,
                        description_prefix=f"Hilp{i}_",
                        data={
                            "obs": [samples["obs"]],
                            "coordinates": [samples["obs"][:, :2]],
                        },
                    )
                # self.draw_one_plots_with_video(runner, deterministic_policy=True)

        if self.no_plot:
            return
        self._evaluate_policy_inner(
            runner, deterministic_policy=True, description_prefix=""
        )
        self._evaluate_policy_inner(
            runner, deterministic_policy=False, description_prefix="train_"
        )
        if self.prevupd:
            self._evaluate_policy_inner(
                runner,
                deterministic_policy=False,
                prevupd=True,
                description_prefix="prevupd_",
            )

    def _restrict_te_obs(self, obs):
        # if self.te_restrict_obs_idxs is not None:
        #     return obs[:, self.te_restrict_obs_idxs]
        return obs

    def _evaluate_policy_inner(
        self, runner, deterministic_policy=True, prevupd=False, description_prefix=""
    ):
        if prevupd:
            assert self.prevupd
            self.dummy_traj_encoder = self.traj_encoder
            self.traj_encoder = self.prevupd_traj_encoder
        if self.hierarchical:
            self.dummy = self.dim_option
            self.dim_option = self.hierarchical_dim
        if self.discrete:
            eye_options = np.eye(self.dim_option)
            random_options = []
            colors = []
            for i in range(self.dim_option):
                num_trajs_per_option = (
                    self.num_random_trajectories // self.dim_option
                    + (i < self.num_random_trajectories % self.dim_option)
                )
                for _ in range(num_trajs_per_option):
                    random_options.append(eye_options[i])
                    colors.append(i)
            random_options = np.array(random_options)
            colors = np.array(colors)
            num_evals = len(random_options)
            from matplotlib import cm

            cmap = "tab10" if self.dim_option <= 10 else "tab20"
            random_option_colors = []
            for i in range(num_evals):
                random_option_colors.extend([cm.get_cmap(cmap)(colors[i])[:3]])
            random_option_colors = np.array(random_option_colors)
        else:
            random_options = np.random.randn(
                self.num_random_trajectories, self.dim_option
            )
            if self.unit_length:
                random_options = random_options / np.linalg.norm(
                    random_options, axis=1, keepdims=True
                )
            random_option_colors = get_option_colors(random_options * 4)

        if self.goal_reaching:
            random_options = self.get_random_goals(self.num_random_trajectories)

        if not deterministic_policy and self.exploration_type > 0:
            trajectories_kwargs = self._get_train_trajectories_kwargs(
                runner, options=random_options
            )
        else:
            trajectories_kwargs = {
                "sampler_key": "option_policy",
                "extras": self._generate_option_extras(random_options),
            }

        random_trajectories = self._get_trajectories(
            runner,
            **trajectories_kwargs,
            worker_update=dict(
                _render=False,
                _deterministic_policy=deterministic_policy,
            ),
            env_update=dict(_action_noise_std=None),
        )

        data = self.process_samples(random_trajectories)
        last_obs = torch.stack(
            [torch.from_numpy(ob[-1]).to(self.device) for ob in data["obs"]]
        )
        option_dists = self.traj_encoder(last_obs)

        if self.inner:
            option_stddevs = torch.ones_like(option_dists.stddev.detach().cpu()).numpy()
        else:
            option_stddevs = option_dists.stddev.detach().cpu().numpy()
        option_samples = option_dists.mean.detach().cpu().numpy()

        option_colors = random_option_colors

        with FigManager(
            runner, f"{description_prefix}TrajPhiPlot_RandomZ", subplot_spec=(1, 2)
        ) as fm:
            runner._env.render_trajectories(
                random_trajectories, random_option_colors, self.eval_plot_axis, fm.ax[0]
            )
            if self.goal_reaching:
                # draw goals on top of this if state-based obs
                goals = runner._env._apply_unnormalize_obs(
                    random_options.reshape(-1, *runner._env.observation_space.shape)
                )

                # render the goals
                if len(runner._env.observation_space.shape) == 1:
                    # state-space
                    if any(x in self.env_name for x in ["half_cheetah", "walker"]):
                        # get_ylim
                        ymin, ymax = fm.ax[0].get_ylim()
                        for i, (goal, color) in enumerate(zip(goals, option_colors)):
                            fm.ax[0].plot(
                                goal[0],
                                (i - len(goals) / 2) / 1.25,
                                "*",
                                color=color,
                                markersize=3,
                            )
                    else:
                        for goal, color in zip(goals, option_colors):
                            fm.ax[0].plot(
                                goal[0], goal[1], "*", color=color, markersize=10
                            )
                            # TODO: render the goal image

                if not deterministic_policy:
                    # save trajectories of training time (can we do it just with replay buffer? but we may want to see what goals are proposed.)
                    file_name = os.path.join(
                        runner._snapshotter._snapshot_dir,
                        f"{runner.step_itr}_train_trajectories.pkl",
                    )
                    coordinates_and_goals = {
                        "coordinates": [
                            traj["env_infos"]["coordinates"]
                            for traj in random_trajectories
                        ],  # [np.array (L * 2), ...]
                        "goals": [g for g in goals],  # [np.array(2), ...]
                    }
                    with open(file_name, "wb") as f:
                        pickle.dump(coordinates_and_goals, f)

            # with FigManager(runner, f"{description_prefix}PhiPlot") as fm:
            if option_samples.max() > 20:
                # 1d auxiliary drawing is needed for when the samples have too high values

                draw_2d_gaussians(
                    option_samples, option_stddevs, option_colors, fm.ax[1], alpha=1.0
                )

            draw_2d_gaussians(
                option_samples,
                [[0.03, 0.03]] * len(option_samples),
                option_colors,
                fm.ax[1],
                alpha=1.0,
                fill=True,
                use_adaptive_axis=True,
            )

        if self.goal_reaching and len(runner._env.observation_space.shape) == 3:
            with FigManager(runner, f"{description_prefix}GoalImages") as fm:
                # pixel-space
                num_goals = len(random_trajectories)
                grid_width = 6
                grid_height = num_goals // 6 + (num_goals % 6 > 0)

                # Create a separate figure for rendering the goals
                fig, axs = plt.subplots(
                    grid_height, grid_width, figsize=(15, grid_height * 2.5)
                )
                axs = axs.flatten()
                goal_images = goals[:, :, :, :3].astype(np.uint8)
                i = 0
                for i, goal_image in enumerate(goal_images):
                    axs[i].imshow(goal_image)
                    axs[i].axis("off")

                # Turn off the remaining empty subplots
                for j in range(i + 1, len(axs)):
                    axs[j].axis("off")

                plt.tight_layout()

                # Render the figure to a canvas

                canvas = FigureCanvas(fig)
                canvas.draw()
                width, height = fig.get_size_inches() * fig.get_dpi()
                image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8").reshape(
                    int(height), int(width), 3
                )

                plt.close(fig)  # Close the figure to free memory

                # Display the rendered image in fm.ax[2]
                fm.ax.imshow(image)
                fm.ax.axis("off")

        # =====unused====
        # with FigManager(runner, f"{description_prefix}PhiPlotSampled") as fm:
        #     for i in range(5):
        #         obs = torch.stack(
        #             [torch.from_numpy(ob).to(self.device) for ob in data["obs"][i]]
        #         )
        #         option_dists = self.traj_encoder(obs)
        #         option_means = option_dists.mean.detach().cpu().numpy()
        #         if self.inner:
        #             option_stddevs = torch.ones_like(
        #                 option_dists.stddev.detach().cpu()
        #             ).numpy()
        #         else:
        #             option_stddevs = option_dists.stddev.detach().cpu().numpy()
        #         colors = np.tile(option_colors[i], (len(option_means), 1))

        #         draw_2d_gaussians(option_means, option_stddevs, colors, fm.ax)
        #         draw_2d_gaussians(
        #             option_samples,
        #             [[0.03, 0.03]] * len(option_samples),
        #             colors,
        #             fm.ax,
        #             fill=True,
        #             use_adaptive_axis=True,
        #         )
        # ===============

        eval_option_metrics = {}

        # Videos
        if self.eval_record_video:
            video_trajectories, extras = self._get_video_trajectories(
                runner, deterministic_policy=deterministic_policy, return_extra=True
            )
            if self.goal_reaching and len(runner._env.observation_space.shape) == 3:
                assert extras[0]["option"].ndim == 1
                options = np.stack([extra["option"] for extra in extras], axis=0)
                goal_images = (
                    options.reshape(-1, *runner._env.observation_space.shape)[
                        :, :, :, :3
                    ]
                    .transpose(0, 3, 1, 2)
                    .astype(np.uint8)
                )
            else:
                extras = None
                goal_images = None

            record_video(
                runner,
                f"{description_prefix}Video_RandomZ",
                video_trajectories,
                skip_frames=self.video_skip_frames,
                goal_images=goal_images,
            )

            # if option_freq > 0, record additional video
            if self.option_freq > 0:
                video_option_freq = 20
                if self.discrete:
                    num_videos = self.num_video_repeats * self.dim_option
                    video_options = np.eye(self.dim_option)[
                        np.random.randint(
                            0,
                            self.dim_option,
                            num_videos * self.max_path_length // video_option_freq,
                        )
                    ]
                    video_options = video_options.reshape(
                        num_videos, -1, self.dim_option
                    )
                    video_options = video_options.repeat(video_option_freq, axis=1)
                else:
                    if self.dim_option == 2:
                        radius = 1.0 if self.unit_length else 1.5
                        video_options = []
                        for angle in [3, 2, 1, 4]:
                            video_options.append(
                                [
                                    radius * np.cos(angle * np.pi / 4),
                                    radius * np.sin(angle * np.pi / 4),
                                ]
                            )
                        video_options.append([0, 0])
                        for angle in [0, 5, 6, 7]:
                            video_options.append(
                                [
                                    radius * np.cos(angle * np.pi / 4),
                                    radius * np.sin(angle * np.pi / 4),
                                ]
                            )
                        video_options = np.array(video_options)
                    else:
                        video_options = np.random.randn(9, self.dim_option)
                        if self.unit_length:
                            video_options = video_options / np.linalg.norm(
                                video_options, axis=1, keepdims=True
                            )
                    video_options = video_options.repeat(num_videos, axis=0)
                    video_options = video_options.reshape(
                        -1, 1, self.dim_option
                    ).repeat(self.option_freq, axis=1)
                video_trajectories = self._get_trajectories(
                    runner,
                    sampler_key="local_option_policy",
                    extras=self._generate_option_extras_list(video_options),
                    worker_update=dict(
                        _render=True,
                        _deterministic_policy=True,
                    ),
                )
                record_video(
                    runner,
                    f"{description_prefix}Video_RandomZ_OptionFreq20",
                    video_trajectories,
                    skip_frames=self.video_skip_frames,
                )
        if self.eval_record_video:
            self.draw_one_plots(
                runner,
                deterministic_policy=deterministic_policy,
                prevupd=prevupd,
                description_prefix=description_prefix,
            )

        eval_option_metrics.update(
            runner._env.calc_eval_metrics(
                random_trajectories, is_option_trajectories=True
            )
        )
        if deterministic_policy:
            eval_option_metrics.update(self._get_goal_conditioned_metrics(runner))
            with global_context.GlobalContext({"phase": "eval", "policy": "option"}):
                log_performance_ex(
                    runner.step_itr,
                    TrajectoryBatch.from_trajectory_list(
                        self._env_spec, random_trajectories
                    ),
                    discount=self.discount,
                    additional_records=eval_option_metrics,
                )

            if not prevupd:  # why is this here?
                self._log_eval_metrics(runner)

        if self.hierarchical:
            self.dim_option = self.dummy
        if prevupd:
            assert self.prevupd
            self.traj_encoder = self.dummy_traj_encoder
            self.dummy_traj_encoder = None

    def _get_goal_conditioned_metrics(self, runner):
        eval_option_metrics = {}

        self.goal_range = None
        if self.env_name in [
            "half_cheetah",
            "ant",
            "dmc_quadruped",
            "dmc_humanoid",
            "dmc_humanoid_state",
        ]:
            if self.env_name == "half_cheetah":
                self.goal_range = 100
            elif self.env_name == "ant":
                self.goal_range = 50
            elif self.env_name == "dmc_quadruped":
                self.goal_range = 15
            elif self.env_name == "dmc_humanoid":
                self.goal_range = 10
            elif self.env_name == "dmc_humanoid_state":
                self.goal_range = 40

        # Goal-conditioned metrics
        env = runner._env
        goals = []  # list of (goal_obs, goal_info)
        goal_metrics = defaultdict(list)
        if self.env_name == "kitchen":
            # goal_names = [
            #     "BottomBurner",
            #     "LightSwitch",
            #     "SlideCabinet",
            #     "HingeCabinet",
            #     "Microwave",
            #     "Kettle",
            # ]
            goal_names = [
                "_".join(c) for c in env.goal_configs
            ]  # [['bottom_burner'], ['light_switch'], ['slide_cabinet'], ['hinge_cabinet'], ['microwave'], ['kettle'], ['light_switch', 'slide_cabinet'], ['light_switch', 'hinge_cabinet'], ['light_switch', 'kettle'], ['slide_cabinet', 'hinge_cabinet'], ['slide_cabinet', 'kettle'], ['hinge_cabinet', 'kettle']]
            for i, goal_name in enumerate(goal_names):
                goal_obs = env.render_goal(goal_idx=i).copy().astype(np.float32)
                goal_obs = np.tile(goal_obs, self.frame_stack or 1).flatten()
                goals.append((goal_obs, {"goal_idx": i, "goal_name": goal_name}))
        elif self.env_name in ["antmaze-large-play", "antmaze-ultra-play"]:
            env.reset()

            base_observation = env.unwrapped._get_obs().astype(np.float32)
            print("env goals:", env.unwrapped.goals)
            for i, env_goal in enumerate(env.unwrapped.goals):
                obs_goal = base_observation.copy()
                obs_goal[:2] = env_goal
                goals.append((obs_goal, {"env_goal": env_goal}))

        elif self.env_name in ["dmc_cheetah", "dmc_quadruped", "dmc_humanoid"]:
            for i in range(20):
                env.reset()
                state = env.physics.get_state().copy()
                if self.env_name == "dmc_cheetah":
                    goal_loc = (np.random.rand(1) * 2 - 1) * self.goal_range
                    state[:1] = goal_loc
                else:
                    goal_loc = (np.random.rand(2) * 2 - 1) * self.goal_range
                    state[:2] = goal_loc
                env.physics.set_state(state)
                if self.env_name == "dmc_humanoid":
                    for _ in range(50):
                        env.step(np.zeros_like(env.action_space.sample()))
                else:
                    env.step(np.zeros_like(env.action_space.sample()))
                goal_obs = (
                    env.render(mode="rgb_array", width=64, height=64)
                    .copy()
                    .astype(np.float32)
                )
                goal_obs = np.tile(goal_obs, self.frame_stack or 1).flatten()
                goals.append((goal_obs, {"goal_loc": goal_loc}))
        elif self.env_name in ["dmc_humanoid_state"]:
            for i in range(20):
                ob = env.reset()
                state = env.physics.get_state().copy()
                goal_loc = (np.random.rand(2) * 2 - 1) * self.goal_range
                ob[:2] = goal_loc
                for _ in range(5):
                    env.step(np.zeros_like(env.action_space.sample()))
                goal_obs = ob.copy().astype(np.float32)
                goals.append((goal_obs, {"goal_loc": goal_loc}))
        elif self.env_name in [
            "ant",
            "ant_pixel",
            "half_cheetah",
        ]:
            for i in range(20):
                ob = env.reset()
                state = env.unwrapped._get_obs().copy()
                if self.env_name in ["half_cheetah"]:
                    goal_loc = (np.random.rand(1) * 2 - 1) * self.goal_range
                    state[:1] = goal_loc
                    env.set_state(state[:9], state[9:])
                elif self.env_name in ["dmc_humanoid_state"]:
                    goal_loc = (np.random.rand(2) * 2 - 1) * self.goal_range
                    ob[:2] = goal_loc
                else:
                    goal_loc = (np.random.rand(2) * 2 - 1) * self.goal_range
                    state[:2] = goal_loc
                    env.set_state(state[:15], state[15:])
                for _ in range(5):
                    env.step(np.zeros_like(env.action_space.sample()))
                if self.env_name == "ant_pixel":
                    goal_obs = (
                        env.render(mode="rgb_array", width=64, height=64)
                        .copy()
                        .astype(np.float32)
                    )
                    goal_obs = np.tile(goal_obs, self.frame_stack or 1).flatten()
                else:
                    if self.env_name == "dmc_humanoid_state":
                        goal_obs = ob.copy().astype(np.float32)
                    else:
                        goal_obs = env._apply_normalize_obs(state).astype(np.float32)
                goals.append((goal_obs, {"goal_loc": goal_loc}))

        renders = []
        coordinates = []

        for method in ["Adaptive"] if self.discrete else [""]:
            self.option_policy._force_use_mode_actions = True
            if len(goals) == 0:
                break
            for goal_obs, goal_info in goals:
                if self.env_name in ["antmaze-large-play", "antmaze-ultra-play"]:
                    env.unwrapped.set_target(goal_info["env_goal"])
                obs = env.reset()
                step = 0
                done = False
                success = 0
                option = None
                render = []
                coordinate = []
                while step < self.max_path_length and not done:
                    if self.inner:  # always the case
                        if self.goal_reaching:
                            option = goal_obs
                        else:
                            te_input = torch.from_numpy(np.stack([obs, goal_obs])).to(
                                self.device
                            )  # 2, d
                            phi_s, phi_g = self.traj_encoder(
                                self._restrict_te_obs(te_input)
                            ).mean
                            phi_s, phi_g = (
                                phi_s.detach().cpu().numpy(),
                                phi_g.detach().cpu().numpy(),
                            )
                            if self.discrete:
                                if method == "Adaptive":
                                    option = np.eye(self.dim_option)[
                                        (phi_g - phi_s).argmax()
                                    ]
                                else:
                                    if option is None:
                                        option = np.eye(self.dim_option)[
                                            (phi_g - phi_s).argmax()
                                        ]
                            else:
                                option = (phi_g - phi_s) / np.linalg.norm(phi_g - phi_s)
                    else:
                        te_input = torch.from_numpy(goal_obs[None, ...]).to(self.device)
                        phi = self.traj_encoder(self._restrict_te_obs(te_input)).mean[0]
                        phi = phi.detach().cpu().numpy()
                        if self.discrete:
                            option = np.eye(self.dim_option)[phi.argmax()]
                        else:
                            option = phi
                    action, agent_info = self.option_policy.get_action(
                        np.concatenate([obs, option])
                    )
                    next_obs, reward, done, info = env.step(action, render=True)
                    obs = next_obs

                    if self.env_name == "kitchen":
                        success = max(
                            success, env.compute_success(goal_info["goal_idx"])[0]
                        )
                    elif self.env_name in ["antmaze-large-play", "antmaze-ultra-play"]:
                        success = max(success, reward)  # assume sparse reward

                    step += 1

                    coordinate.append(info["coordinates"])

                    if True and step % 3 == 0:  # original code: if i >= num_episode
                        # https://github.com/seohongpark/HIQL/blob/b3e8366ccaec99113778bc360b19894e7a63317c/jaxrl_m/evaluation.py#L55
                        env_name = self.env_name
                        # if "antmaze" in env_name:
                        #     size = 200
                        #     cur_frame = (
                        #         env.render(
                        #             mode="rgb_array",
                        #             width=size,
                        #             height=size,
                        #             camera_id=0,
                        #         )
                        #         .transpose(2, 0, 1)
                        #         .copy()
                        #     )
                        #     render.append(cur_frame)
                        # elif (
                        #     "pointmaze" in env_name
                        #     or env_name == "ant"
                        #     or env_name == "half_cheetah"
                        #     or env_name == "dmc_humanoid_state"
                        # ):
                        #     size = 64
                        #     cur_frame = (
                        #         env.render(
                        #             mode="rgb_array",
                        #             width=size,
                        #             height=size,
                        #             camera_id=0,
                        #         )
                        #         .transpose(2, 0, 1)
                        #         .copy()
                        #     )
                        #     render.append(cur_frame)

                        # elif env_name in [
                        #     "kitchen",
                        #     "dmc_quadruped",
                        #     "dmc_humanoid",
                        # ]:
                        #     # # try kitchen pixel only
                        #     # assert (
                        #     #     next_obs.ndim == 3 and next_obs.shape[-1] % 3 == 0
                        #     # ), next_obs.shape

                        #     render.append(
                        #         next_obs.reshape(64, 64, -1)[:, :, :3]
                        #         .transpose(2, 0, 1)
                        #         .astype(np.uint8)
                        #     )

                        render_image = info["render"]
                        assert render_image.ndim == 3
                        if render_image.shape[-1] == 3:  # 64, 64, 3
                            render_image = render_image.transpose(2, 0, 1)
                        render.append(render_image)

                # render video
                renders.append(np.stack(render, axis=0))
                coordinates.append((np.stack(coordinate, axis=0), goal_info))

                # update metrics
                if self.env_name == "kitchen":
                    goal_metrics[f'Kitchen{method}Goal{goal_info["goal_name"]}'].append(
                        success
                    )
                    goal_metrics[f"Kitchen{method}GoalOverall"].append(
                        success * len(goal_names)
                    )  # we calculate the mean: so we multiply by the number of goals
                    if goal_info["goal_idx"] < 6:
                        goal_metrics[f"Kitchen{method}GoalOverall6"].append(
                            success * len(goal_names[:6])
                        )
                elif self.env_name in ["antmaze-large-play", "antmaze-ultra-play"]:
                    goal_metrics[f'Antmaze{method}Goal{goal_info["env_goal"]}'].append(
                        success
                    )
                    goal_metrics[f"Antmaze{method}GoalOverall"].append(
                        success * len(goals)
                    )
                elif self.env_name in [
                    "dmc_cheetah",
                    "dmc_quadruped",
                    "dmc_humanoid",
                    "ant",
                    "ant_pixel",
                    "half_cheetah",
                    "dmc_humanoid_state",
                ]:
                    if self.env_name in ["dmc_cheetah"]:
                        cur_loc = env.physics.get_state()[:1]
                    elif self.env_name in [
                        "dmc_quadruped",
                        "dmc_humanoid",
                        "dmc_humanoid_state",
                    ]:
                        cur_loc = env.physics.get_state()[:2]
                    elif self.env_name in ["half_cheetah"]:
                        cur_loc = env.unwrapped._get_obs()[:1]
                    else:
                        cur_loc = env.unwrapped._get_obs()[:2]
                    distance = np.linalg.norm(cur_loc - goal_info["goal_loc"])
                    squared_distance = distance**2
                    goal_metrics[f"Goal{method}Distance"].append(distance)
                    goal_metrics[f"Goal{method}SquaredDistance"].append(
                        squared_distance
                    )

            # pack "renders" as trajectory: List[{"env_infos": {"render": np.ndarray}]
            trajectories = [
                {
                    "env_infos": {"render": render},
                    "agent_infos": {"cur_exploration": np.zeros(len(render))},
                }
                for render in renders
            ]

            # set goal_images if possible
            if (
                "antmaze-large-play" in self.env_name
                or "antmaze-ultra-play" in self.env_name
            ):
                goal_images = []
                for goal_obs, goal_info in goals:
                    goal_coords = goal_info["env_goal"]  # tuple(x, y)
                    fig = Figure()
                    canvas = FigureCanvas(fig)
                    ax = fig.add_subplot()

                    env.scatter_trajectory(
                        np.array(goal_coords)[None], color="r", ax=ax
                    )
                    ax.scatter(
                        goal_coords[0], goal_coords[1], color="r", marker="*", s=100
                    )
                    # resize this figure to have the same wh as renders

                    canvas.draw()
                    image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8")
                    image = image.reshape(canvas.get_width_height()[::-1] + (3,))
                    goal_images.append(image)

                goal_images = np.stack(goal_images, axis=0)  # N, H, W, C
                goal_images = goal_images.transpose(0, 3, 1, 2)

            elif self.env_name in [
                "dmc_cheetah",
                "dmc_quadruped",
                "kitchen",
                "dmc_humanoid",
                "dmc_humanoid_state",
            ]:
                goal_images = None
            else:
                goal_images = None

            record_video(
                runner,
                f"Video_{method}_GoalConditioned",
                trajectories,
                skip_frames=self.video_skip_frames,
                goal_images=goal_images,
            )

        goal_metrics = {key: np.mean(value) for key, value in goal_metrics.items()}
        eval_option_metrics.update(goal_metrics)

        # Train coverage metric
        if len(self.coverage_queue) > 0:
            coverage_data = np.array(self.coverage_queue)
            if self.env_name == "kitchen":
                assert (
                    coverage_data.ndim == 2 and coverage_data.shape[-1] == 6
                ), coverage_data.shape
                coverage = coverage_data.max(axis=0)
                goal_names = [
                    "BottomBurner",
                    "LightSwitch",
                    "SlideCabinet",
                    "HingeCabinet",
                    "Microwave",
                    "Kettle",
                ]

                for i, goal_name in enumerate(goal_names):
                    eval_option_metrics[f"TrainKitchenTask{goal_name}"] = coverage[i]
                    eval_option_metrics[f"CoincidentalRate{goal_name}"] = (
                        self.coincidental_goal_success[i]
                    )
                eval_option_metrics[f"TrainKitchenOverall"] = coverage.sum()

            else:
                total_coverage_data = np.array(self.coverage_log)
                assert coverage_data.ndim == 3
                assert (
                    coverage_data.shape[-1] == 2
                    if not self.env_name in ["half_cheetah", "walker"]
                    else 1
                ), coverage_data.shape
                coverage_data = coverage_data.reshape(-1, coverage_data.shape[-1])
                total_coverage_data = total_coverage_data.reshape(
                    -1, total_coverage_data.shape[-1]
                )
                uniq_coords = np.unique(np.floor(coverage_data), axis=0)
                total_uniq_coords = np.unique(np.floor(total_coverage_data), axis=0)
                eval_option_metrics["TrainNumUniqueCoords"] = len(uniq_coords)
                eval_option_metrics["TrainTotalNumUniqueCoords"] = len(
                    total_uniq_coords
                )
                eval_option_metrics["MaxDistFromOrigin"] = np.linalg.norm(
                    coverage_data, axis=-1
                ).max(axis=0)
        else:
            if self.env_name == "kitchen":
                goal_names = [
                    "BottomBurner",
                    "LightSwitch",
                    "SlideCabinet",
                    "HingeCabinet",
                    "Microwave",
                    "Kettle",
                ]
                for i, goal_name in enumerate(goal_names):
                    eval_option_metrics[f"TrainKitchenTask{goal_name}"] = 0
                eval_option_metrics[f"TrainKitchenOverall"] = 0 
            else:
                eval_option_metrics["TrainNumUniqueCoords"] = 0
                eval_option_metrics["TrainTotalNumUniqueCoords"] = 0

        # save all goal-conditioned and paper relevant data
        # coordinates during evaluation
        # proposed goals
        # coordinates during training

        file_name = os.path.join(
            runner._snapshotter._snapshot_dir, f"{runner.step_itr}_coords.pkl"
        )
        save_goals = None
        if len(goals) == 0:
            save_goals = []
        elif "env_goal" in goal_info:
            save_goals = [goal_info["env_goal"] for (goal_obs, goal_info) in goals]
        elif "goal_loc" in goal_info:
            save_goals = [goal_info["goal_loc"] for (goal_obs, goal_info) in goals]
        elif "goal_name" in goal_info:
            save_goals = [goal_obs for (goal_obs, goal_info) in goals]

        coordinates_and_goals = {
            "coordinates": coordinates,  # [np.array (L * 2), ...]
            "goals": save_goals,
        }
        with open(file_name, "wb") as f:
            pickle.dump(coordinates_and_goals, f)

        if len(self.coverage_log) > 0:
            file_name = os.path.join(
                runner._snapshotter._snapshot_dir, f"totalcoords.pkl"
            )
            coordinates_and_goals = {
                "coordinates": np.array(self.coverage_log)  # this is already correct
            }
            with open(file_name, "wb") as f:
                pickle.dump(coordinates_and_goals, f)

        return eval_option_metrics

    def _evaluate_goal_reaching_policy(self, runner):
        """Evaluate the option policy when conditioned on the env's goal images.

        Goals are rendered with ``runner._env.render_goal`` for each index from
        ``runner._env.get_goals()`` and must be 64x64x3 images. Each goal is
        tiled along its last (channel) axis to match the observation size, to
        account for frame stacking. It is then flattened into one option vector
        per goal.

        The method rolls out the option policy, plots trajectories and
        final-state encoder outputs, and optionally records videos. It logs
        ``calc_eval_metrics`` results (prefix ``"Goal"``) through
        ``log_performance_ex``.

        Args:
            runner: Trainer exposing ``_env`` and ``step_itr``.
        """
        deterministic_policy = True
        description_prefix = "Goal_"

        goal_indices = runner._env.get_goals()
        goals = [runner._env.render_goal(goal_idx=idx) for idx in goal_indices]
        assert len(goals) > 0 and goals[0].shape == (64, 64, 3)
        goals = np.stack(goals, axis=0)  # N, 64, 64, 3

        obs_dim = np.prod(runner._env.observation_space.shape)
        goal_dim = np.prod(goals.shape[1:])
        assert obs_dim % goal_dim == 0  # frame stack
        # np.tile with an int repeats along the last (channel) axis:
        # (N, 64, 64, 3 * k) -> (N, obs_dim).
        options = np.tile(goals, obs_dim // goal_dim).reshape(len(goals), -1)
        # Compare the option width (an int) with obs_dim. `options[-1].shape` is
        # a tuple and would never compare equal to an int.
        assert options.shape[-1] == obs_dim, options.shape

        option_colors = get_option_colors(options * 4)

        trajectories_kwargs = {
            "sampler_key": "option_policy",
            "extras": self._generate_option_extras(options),
        }

        random_trajectories = self._get_trajectories(
            runner,
            **trajectories_kwargs,
            worker_update=dict(
                _render=False,
                _deterministic_policy=deterministic_policy,
            ),
            env_update=dict(_action_noise_std=None),
        )

        data = self.process_samples(random_trajectories)
        # Encode only the final observation of each trajectory.
        last_obs = torch.stack(
            [torch.from_numpy(ob[-1]).to(self.device) for ob in data["obs"]]
        )
        option_dists = self.traj_encoder(last_obs)

        option_means = option_dists.mean.detach().cpu().numpy()
        # Unit std-devs are drawn regardless of the encoder's stddev.
        option_stddevs = torch.ones_like(option_dists.stddev.detach().cpu()).numpy()
        option_samples = option_dists.mean.detach().cpu().numpy()

        with FigManager(
            runner, f"{description_prefix}TrajPhiPlot_RandomZ", subplot_spec=(1, 2)
        ) as fm:
            runner._env.render_trajectories(
                random_trajectories, option_colors, self.eval_plot_axis, fm.ax[0]
            )
            if self.goal_reaching:
                # draw options on top of this
                goal_points = runner._env._apply_unnormalize_obs(options)
                for goal_point, color in zip(goal_points, option_colors):
                    fm.ax[0].plot(
                        goal_point[0], goal_point[1], "*", color=color, markersize=10
                    )

            if option_samples.max() > 20:
                # 1d auxiliary drawing is needed for when the samples have too high values
                draw_2d_gaussians(
                    option_means, option_stddevs, option_colors, fm.ax[1], alpha=1.0
                )

            draw_2d_gaussians(
                option_samples,
                [[0.03, 0.03]] * len(option_samples),
                option_colors,
                fm.ax[1],
                alpha=1.0,
                fill=True,
                use_adaptive_axis=True,
            )

        eval_option_metrics = {}

        # Videos
        if self.eval_record_video:
            video_trajectories = self._get_video_trajectories(
                runner, deterministic_policy=deterministic_policy, options=options
            )
            # First 3 channels of each goal (i.e. the first stacked frame) as
            # N, 3, H, W uint8 images.
            goal_images = (
                options.reshape(-1, *runner._env.observation_space.shape)[:, :, :, :3]
                .transpose(0, 3, 1, 2)
                .astype(np.uint8)
            )

            record_video(
                runner,
                f"{description_prefix}Video_RandomZ",
                video_trajectories,
                skip_frames=self.video_skip_frames,
                goal_images=goal_images,
            )

        eval_option_metrics.update(
            runner._env.calc_eval_metrics(
                random_trajectories, is_option_trajectories=True, prefix="Goal"
            )
        )

        with global_context.GlobalContext({"phase": "eval", "policy": "option"}):
            log_performance_ex(
                runner.step_itr,
                TrajectoryBatch.from_trajectory_list(
                    self._env_spec, random_trajectories
                ),
                discount=self.discount,
                additional_records=eval_option_metrics,
            )

        self._log_eval_metrics(runner)

    def _get_video_trajectories(
        self, runner, deterministic_policy=True, return_extra=False, options=None
    ):
        # Videos
        assert self.eval_record_video
        if options is None:
            if self.discrete:
                video_options = np.eye(self.dim_option)
                video_options = video_options.repeat(self.num_video_repeats, axis=0)
            else:
                if self.dim_option == 2:
                    radius = 1.0 if self.unit_length else 1.5
                    video_options = []
                    for angle in [3, 2, 1, 4]:
                        video_options.append(
                            [
                                radius * np.cos(angle * np.pi / 4),
                                radius * np.sin(angle * np.pi / 4),
                            ]
                        )
                    video_options.append([0, 0])
                    for angle in [0, 5, 6, 7]:
                        video_options.append(
                            [
                                radius * np.cos(angle * np.pi / 4),
                                radius * np.sin(angle * np.pi / 4),
                            ]
                        )
                    video_options = np.array(video_options)
                else:
                    video_options = np.random.randn(9, self.dim_option)
                    if self.unit_length:
                        video_options = video_options / np.linalg.norm(
                            video_options, axis=1, keepdims=True
                        )
                video_options = video_options.repeat(self.num_video_repeats, axis=0)

            if self.goal_reaching:
                video_options = self.get_random_goals(self.num_video_repeats)
        else:
            video_options = options
        if not deterministic_policy:
            extras = self._get_train_trajectories_kwargs(runner, options=video_options)[
                "extras"
            ]  # we want exploration policy visualization & reward visualization.
        else:
            extras = self._generate_option_extras(video_options)
        video_trajectories = self._get_trajectories(
            runner,
            sampler_key="local_option_policy",
            extras=extras,
            worker_update=dict(
                _render=True,
                _deterministic_policy=deterministic_policy,
            ),
        )

        if return_extra:
            return video_trajectories, extras
        return video_trajectories

    def draw_one_plots(
        self,
        runner,
        deterministic_policy=True,
        prevupd=False,
        description_prefix="",
        data: dict = None,
    ):
        """Plot phi(s) embeddings of states next to their (x, y) coordinates.

        Every state gets an RGBA color from its normalized position inside the
        bounding box of all plotted coordinates. This lets the two panels of
        the ``CoordsPhiPlot`` figure be matched visually: phi on the left
        (t-SNE-projected when phi has more than 2 dims), coordinates on the
        right.

        Args:
            runner: Experiment runner, used for figure logging. When ``data``
                is None it is also used to collect video rollouts.
            deterministic_policy: Forwarded to ``_get_video_trajectories``
                when rollouts have to be collected.
            prevupd: Experimental coloring by ||phi||. Unfinished: it raises
                ``NotImplementedError`` while drawing ``CoordsPhiPlot``.
            description_prefix: Prefix prepended to the logged figure names.
            data: Optional dict with per-trajectory sequences under ``"obs"``
                and ``"coordinates"``. When None, rollouts are collected and
                passed through ``process_samples``.
        """
        # Hierarchical agents are plotted with the high-level option size.
        # The original dim_option is restored in `finally`, so an exception
        # (e.g. the prevupd NotImplementedError) cannot leave it clobbered.
        if self.hierarchical:
            self.dummy = self.dim_option
            self.dim_option = self.hierarchical_dim
        try:
            if data is None:
                video_trajectories = self._get_video_trajectories(
                    runner, deterministic_policy=deterministic_policy
                )
                data = self.process_samples(video_trajectories)

            coords_list = []
            option_means_list = []
            is_exploration_list = []
            for i, traj_obs in enumerate(data["obs"]):
                obs = torch.stack(
                    [torch.from_numpy(ob).to(self.device) for ob in traj_obs]
                )
                traj_option_means = self.traj_encoder(obs).mean.detach().cpu().numpy()
                coords_list.append(data["coordinates"][i])
                option_means_list.append(traj_option_means)
                # Even-indexed trajectories are flagged as "exploration"; the
                # flags are only consumed by the (unfinished) prevupd branch.
                is_exploration_list.append(
                    np.full(traj_option_means.shape[0], i % 2 == 0, dtype=bool)
                )

            coords = np.concatenate(coords_list, axis=0)
            option_means = np.concatenate(option_means_list, axis=0)
            is_exploration = np.concatenate(is_exploration_list, axis=0)

            # Normalize coordinates to [0, 1) relative to the bounding box;
            # the 1e-6 avoids division by zero for a degenerate box.
            base_x, base_y = coords[:, 0].min(), coords[:, 1].min()
            max_x, max_y = coords[:, 0].max(), coords[:, 1].max()
            print("base x y:", base_x, base_y)
            print("max x y:", max_x, max_y)
            dists_x = (coords[:, 0] - base_x) / (max_x - base_x + 1e-6)
            dists_y = (coords[:, 1] - base_y) / (max_y - base_y + 1e-6)
            assert dists_x.ndim == 1 and dists_y.ndim == 1

            # RGBA: R = x, G = y, B = |x - y| / 2, A = 1.
            option_colors = np.stack(
                [
                    dists_x,
                    dists_y,
                    np.abs(dists_x - dists_y) / 2,
                    np.ones_like(dists_x),
                ],
                axis=1,
            )

            if prevupd:
                # Color by the (min-max normalized) norm of phi instead.
                option_means_norm = np.linalg.norm(option_means, axis=1)
                largest_val = option_means_norm.max()
                smallest_val = option_means_norm.min()
                normalized_option_means_norm = (option_means_norm - smallest_val) / (
                    largest_val - smallest_val + 0.2
                )
                assert normalized_option_means_norm.ndim == 1

                option_colors = np.stack(
                    [
                        normalized_option_means_norm,
                        np.zeros_like(normalized_option_means_norm),
                        np.ones_like(normalized_option_means_norm)
                        - normalized_option_means_norm,
                    ],
                    axis=1,
                )
                option_colors_green = np.stack(
                    [
                        np.zeros_like(dists_x),
                        np.ones_like(dists_x),
                        np.zeros_like(dists_x),
                        np.ones_like(dists_x)
                        * 0.01,  # actually this doesn't affect the alpha
                    ],
                    axis=1,
                )

            if self.plot_first_2dims:
                with FigManager(runner, f"{description_prefix}Phi2dimPlot") as fm:
                    # Large embeddings additionally get unit-size blobs so the
                    # points stay visible after adaptive axis scaling.
                    if (option_means.max() - option_means.min()) > 20:
                        draw_2d_gaussians(
                            option_means,
                            [[1, 1]] * len(option_means),
                            option_colors,
                            fm.ax,
                            alpha=1.0,
                            fill=True,
                            use_adaptive_axis=True,
                        )
                    draw_2d_gaussians(
                        option_means,
                        [[0.03, 0.03]] * len(option_means),
                        option_colors,
                        fm.ax,
                        alpha=1.0,
                        fill=True,
                        use_adaptive_axis=True,
                    )

            with FigManager(
                runner, f"{description_prefix}CoordsPhiPlot", subplot_spec=(1, 2)
            ) as fm:
                if option_means.shape[-1] > 2:
                    # Project higher-dimensional phi to 2-D for plotting.
                    tsne_option_means = TSNE(n_components=2).fit_transform(option_means)
                    assert tsne_option_means.shape == (option_means.shape[0], 2)
                    option_means = tsne_option_means

                draw_2d_gaussians(
                    option_means,
                    [[0.03, 0.03]] * len(option_means),
                    option_colors,
                    fm.ax[0],
                    alpha=1.0,
                    fill=True,
                    use_adaptive_axis=True,
                )
                if prevupd:
                    # TODO: check whether the exploration overlay overlaps; the
                    # call below is intentionally unreachable until then.
                    raise NotImplementedError
                    draw_2d_gaussians(
                        coords[is_exploration],
                        ([[0.1, 0.1]] * len(option_means[is_exploration])),
                        option_colors_green[is_exploration],
                        fm.ax[1],
                        use_adaptive_axis=True,
                        alpha=0.1,
                    )
                if hasattr(runner._env, "scatter_trajectory"):
                    runner._env.scatter_trajectory(
                        coords,
                        option_colors,
                        fm.ax[1],
                    )
                else:
                    draw_2d_gaussians(
                        coords,
                        [[0.03, 0.03]] * len(option_means),
                        option_colors,
                        fm.ax[1],
                        alpha=1.0,
                        fill=True,
                        use_adaptive_axis=True,
                    )
        finally:
            if self.hierarchical:
                self.dim_option = self.dummy

    def draw_one_plots_with_video(
        self,
        runner,
        deterministic_policy=True,
    ):
        """Log rollout videos next to animations of their phi trajectories.

        For every video trajectory an animation is rendered in the first two
        phi dimensions. It shows phi(s) of replay-buffer samples as a blue
        background, the trajectory's phi(s_t) accumulating over time in red,
        and phi of that trajectory's goal in green. The animations are passed
        to ``record_video`` as ``goal_images``.

        NOTE(review): options are fed through ``self.traj_encoder`` as if they
        were observations, so this presumably only makes sense when
        ``self.goal_reaching`` is set. Requires an ffmpeg matplotlib writer.
        """

        def encode(obs_rows):
            """Return the phi means (numpy) of a sequence of observations."""
            obs = torch.stack(
                [torch.from_numpy(ob).to(self.device) for ob in obs_rows]
            )
            return self.traj_encoder(obs).mean.detach().cpu().numpy()

        samples = self.replay_buffer.sample_transitions_with_goals(
            self._trans_minibatch_size, p_trajgoal=self.hilp_p_trajgoal
        )
        base_phis = encode(samples["obs"])

        video_trajectories, extras = self._get_video_trajectories(
            runner, deterministic_policy=deterministic_policy, return_extra=True
        )
        data = self.process_samples(video_trajectories)
        # (num_videos, T, dim_phi); requires equal-length trajectories.
        video_phis = np.stack([encode(traj_obs) for traj_obs in data["obs"]], axis=0)

        # One goal per rollout. Previously only extras[0]'s goal was drawn for
        # every video. NOTE(review): assumes extras are ordered like the
        # trajectories; otherwise fall back to the first goal as before.
        goal_phis = encode(np.stack([np.asarray(e["option"]) for e in extras], axis=0))
        per_video_goals = len(goal_phis) == len(video_phis)

        def render_phi_animation(video_phi, goal_phi):
            """Return a (T, 3, 256, 256) uint8 RGB animation of one phi path."""
            fig, ax = plt.subplots()
            try:
                ax.scatter(base_phis[:, 0], base_phis[:, 1], c="b", s=5)
                ax.scatter(video_phi[0, 0], video_phi[0, 1], c="r", s=10)
                ax.scatter(goal_phi[0], goal_phi[1], c="g", s=10)

                def update(frame):
                    # Points are never removed, so the red trail shows the
                    # path travelled so far.
                    ax.scatter(video_phi[frame][0], video_phi[frame][1], c="r", s=10)

                ani = animation.FuncAnimation(
                    fig=fig, func=update, frames=len(video_phi), interval=30
                )

                # Encode to a temporary mp4 with ffmpeg, then decode it back.
                frames = []
                with tempfile.NamedTemporaryFile(suffix=".mp4") as f:
                    ani.save(f.name, writer="ffmpeg", fps=30)
                    cap = cv2.VideoCapture(f.name)
                    try:
                        while True:
                            ok, frame = cap.read()
                            if not ok:
                                break
                            # OpenCV decodes to BGR; the other goal images
                            # passed to record_video are RGB.
                            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                            frames.append(cv2.resize(frame, (256, 256)))
                    finally:
                        cap.release()
            finally:
                # Without this every call leaked a pyplot figure.
                plt.close(fig)
            return np.array(frames, dtype=np.uint8).transpose(0, 3, 1, 2)

        video = np.stack(
            [
                render_phi_animation(
                    video_phi, goal_phis[i] if per_video_goals else goal_phis[0]
                )
                for i, video_phi in enumerate(video_phis)
            ],
            axis=0,
        )
        assert video.ndim == 5
        record_video(
            runner,
            f"PhiVideo_RandomZ",
            video_trajectories,
            skip_frames=self.video_skip_frames,
            goal_images=video,
        )

    def _get_goal_conditioned_metrics_custom(self, runner):
        """Evaluate goal-conditioned behaviour of the option policy.

        Builds a per-environment set of evaluation goals and rolls the option
        policy out towards each of them. This happens once per option-selection
        method: "Single"/"Adaptive" for discrete options, "" otherwise. A video
        of the rollouts is logged and success / distance metrics are
        aggregated. Training-coverage statistics are appended at the end.

        Args:
            runner: Experiment runner. ``runner._env`` is the evaluation env,
                and ``runner`` is forwarded to ``record_video``.

        Returns:
            dict mapping metric names to scalars (per-goal lists are averaged).
        """
        eval_option_metrics = {}

        # Half-width of the box that random goal locations are sampled from.
        goal_ranges = {
            "half_cheetah": 100,
            "ant": 50,
            "dmc_quadruped": 15,
            "dmc_humanoid": 10,
            "dmc_humanoid_state": 40,
        }
        # dmc_cheetah and ant_pixel also sample goal locations but have no
        # hard-coded range. Resetting goal_range to None used to make their
        # goal sampling crash, so keep any previously configured value.
        self.goal_range = goal_ranges.get(
            self.env_name, getattr(self, "goal_range", None)
        )

        env = runner._env
        goals = self._gc_build_goals(env)  # list of (goal_obs, goal_info)
        goal_metrics = defaultdict(list)

        # Rollouts use the policy's mode actions. Restore the previous setting
        # afterwards so the flag does not leak into later uses of the policy.
        prev_force_mode = getattr(self.option_policy, "_force_use_mode_actions", False)
        self.option_policy._force_use_mode_actions = True
        try:
            if goals:
                for method in ["Single", "Adaptive"] if self.discrete else [""]:
                    renders = []
                    for goal_obs, goal_info in goals:
                        render, success = self._gc_rollout(env, method, goal_obs, goal_info)
                        renders.append(render)
                        # Must run before the next reset: reads the final position.
                        self._gc_record_goal_metrics(
                            env, method, goal_info, success, len(goals), goal_metrics
                        )

                    # record_video expects trajectory-like dicts.
                    trajectories = [
                        {
                            "env_infos": {"render": render},
                            "agent_infos": {"cur_exploration": np.zeros(len(render))},
                        }
                        for render in renders
                    ]
                    record_video(
                        runner,
                        f"Video_{method}_GoalConditioned",
                        trajectories,
                        skip_frames=self.video_skip_frames,
                        goal_images=self._gc_goal_images(env, goals),
                    )
        finally:
            self.option_policy._force_use_mode_actions = prev_force_mode

        eval_option_metrics.update(
            {key: np.mean(value) for key, value in goal_metrics.items()}
        )
        eval_option_metrics.update(self._gc_coverage_metrics())
        return eval_option_metrics

    def _gc_sample_goal_loc(self, dim):
        """Sample a goal location uniformly from [-goal_range, goal_range]^dim."""
        if self.goal_range is None:
            raise ValueError(
                f"goal_range is not set for env {self.env_name!r}; "
                "configure self.goal_range before goal-conditioned evaluation"
            )
        return (np.random.rand(dim) * 2 - 1) * self.goal_range

    def _gc_build_goals(self, env):
        """Build the evaluation goals as a list of ``(goal_obs, goal_info)``.

        Unsupported environments yield an empty list.
        """
        goals = []
        if self.env_name == "kitchen":
            # One goal per task combination in env.goal_configs.
            goal_names = ["_".join(c) for c in env.goal_configs]
            for i, goal_name in enumerate(goal_names):
                goal_obs = env.render_goal(goal_idx=i).copy().astype(np.float32)
                goal_obs = np.tile(goal_obs, self.frame_stack or 1).flatten()
                goals.append((goal_obs, {"goal_idx": i, "goal_name": goal_name}))
        elif self.env_name in ["antmaze-large-play", "antmaze-ultra-play"]:
            env.reset()
            base_observation = env.unwrapped._get_obs().astype(np.float32)
            print("env goals:", env.unwrapped.goals)
            for env_goal in env.unwrapped.goals:
                # Goal observation = reset observation moved to the goal (x, y).
                obs_goal = base_observation.copy()
                obs_goal[:2] = env_goal
                goals.append((obs_goal, {"env_goal": env_goal}))
        elif self.env_name in ["dmc_cheetah", "dmc_quadruped", "dmc_humanoid"]:
            for _ in range(20):
                env.reset()
                state = env.physics.get_state().copy()
                if self.env_name == "dmc_cheetah":
                    goal_loc = self._gc_sample_goal_loc(1)
                    state[:1] = goal_loc
                else:
                    goal_loc = self._gc_sample_goal_loc(2)
                    state[:2] = goal_loc
                env.physics.set_state(state)
                # Let the body settle with zero actions before rendering.
                num_settle_steps = 50 if self.env_name == "dmc_humanoid" else 1
                for _ in range(num_settle_steps):
                    env.step(np.zeros_like(env.action_space.sample()))
                goal_obs = (
                    env.render(mode="rgb_array", width=64, height=64)
                    .copy()
                    .astype(np.float32)
                )
                goal_obs = np.tile(goal_obs, self.frame_stack or 1).flatten()
                goals.append((goal_obs, {"goal_loc": goal_loc}))
        elif self.env_name == "dmc_humanoid_state":
            for _ in range(20):
                ob = env.reset()
                goal_loc = self._gc_sample_goal_loc(2)
                # Only the observation's (x, y) is overwritten; the simulator
                # state itself is not moved.
                ob[:2] = goal_loc
                for _ in range(5):
                    env.step(np.zeros_like(env.action_space.sample()))
                goals.append((ob.copy().astype(np.float32), {"goal_loc": goal_loc}))
        elif self.env_name in ["ant", "ant_pixel", "half_cheetah"]:
            for _ in range(20):
                env.reset()
                state = env.unwrapped._get_obs().copy()
                if self.env_name == "half_cheetah":
                    goal_loc = self._gc_sample_goal_loc(1)
                    state[:1] = goal_loc
                    env.set_state(state[:9], state[9:])
                else:
                    goal_loc = self._gc_sample_goal_loc(2)
                    state[:2] = goal_loc
                    env.set_state(state[:15], state[15:])
                for _ in range(5):
                    env.step(np.zeros_like(env.action_space.sample()))
                if self.env_name == "ant_pixel":
                    goal_obs = (
                        env.render(mode="rgb_array", width=64, height=64)
                        .copy()
                        .astype(np.float32)
                    )
                    goal_obs = np.tile(goal_obs, self.frame_stack or 1).flatten()
                else:
                    goal_obs = env._apply_normalize_obs(state).astype(np.float32)
                goals.append((goal_obs, {"goal_loc": goal_loc}))
        return goals

    def _gc_select_option(self, obs, goal_obs, method, option):
        """Return the option used to condition the policy at the current step.

        ``option`` is the option from the previous step (None at the start).
        With discrete options, "Single" keeps the first chosen option while
        "Adaptive" re-selects it every step.
        """
        if self.inner:
            if self.goal_reaching:
                return goal_obs
            if self.discrete and method != "Adaptive" and option is not None:
                return option
            te_input = torch.from_numpy(np.stack([obs, goal_obs])).to(self.device)
            with torch.no_grad():
                phi_s, phi_g = self.traj_encoder(self._restrict_te_obs(te_input)).mean
            phi_s = phi_s.detach().cpu().numpy()
            phi_g = phi_g.detach().cpu().numpy()
            if self.discrete:
                # One-hot on the dimension of largest phi progress.
                return np.eye(self.dim_option)[(phi_g - phi_s).argmax()]
            # Unit direction from phi(s) towards phi(g).
            return (phi_g - phi_s) / np.linalg.norm(phi_g - phi_s)

        te_input = torch.from_numpy(goal_obs[None, ...]).to(self.device)
        with torch.no_grad():
            phi = self.traj_encoder(self._restrict_te_obs(te_input)).mean[0]
        phi = phi.detach().cpu().numpy()
        if self.discrete:
            return np.eye(self.dim_option)[phi.argmax()]
        return phi

    def _gc_rollout(self, env, method, goal_obs, goal_info):
        """Run one goal-conditioned episode and return ``(frames, success)``.

        ``frames`` stacks ``info["render"]`` of every step. HWC renders with
        3 channels are transposed to CHW. ``success`` is only tracked for
        kitchen (task completion) and antmaze (max reward); it stays 0
        otherwise.
        """
        if self.env_name in ["antmaze-large-play", "antmaze-ultra-play"]:
            env.unwrapped.set_target(goal_info["env_goal"])
        obs = env.reset()
        step = 0
        done = False
        success = 0
        option = None
        frames = []
        while step < self.max_path_length and not done:
            option = self._gc_select_option(obs, goal_obs, method, option)
            action, _ = self.option_policy.get_action(np.concatenate([obs, option]))
            obs, reward, done, info = env.step(action, render=True)

            if self.env_name == "kitchen":
                success = max(success, env.compute_success(goal_info["goal_idx"])[0])
            elif self.env_name in ["antmaze-large-play", "antmaze-ultra-play"]:
                success = max(success, reward)  # assumes a sparse reward
            step += 1

            render_image = info["render"]
            assert render_image.ndim == 3
            if render_image.shape[-1] == 3:  # H, W, 3 -> 3, H, W
                render_image = render_image.transpose(2, 0, 1)
            frames.append(render_image)
        return np.stack(frames, axis=0), success

    def _gc_record_goal_metrics(
        self, env, method, goal_info, success, num_goals, goal_metrics
    ):
        """Append the metrics of one finished goal rollout to ``goal_metrics``.

        Location metrics read the agent's final position from ``env``.
        "Overall" entries are multiplied by the number of contributing goals,
        so their later mean equals the summed success.
        """
        if self.env_name == "kitchen":
            goal_metrics[f'Kitchen{method}Goal{goal_info["goal_name"]}'].append(success)
            goal_metrics[f"Kitchen{method}GoalOverall"].append(success * num_goals)
            if goal_info["goal_idx"] < 6:
                goal_metrics[f"Kitchen{method}GoalOverall6"].append(
                    success * min(6, num_goals)
                )
        elif self.env_name in ["antmaze-large-play", "antmaze-ultra-play"]:
            goal_metrics[f'Antmaze{method}Goal{goal_info["env_goal"]}'].append(success)
            goal_metrics[f"Antmaze{method}GoalOverall"].append(success * num_goals)
        elif self.env_name in [
            "dmc_cheetah",
            "dmc_quadruped",
            "dmc_humanoid",
            "ant",
            "ant_pixel",
            "half_cheetah",
            "dmc_humanoid_state",
        ]:
            if self.env_name == "dmc_cheetah":
                cur_loc = env.physics.get_state()[:1]
            elif self.env_name in ["dmc_quadruped", "dmc_humanoid", "dmc_humanoid_state"]:
                cur_loc = env.physics.get_state()[:2]
            elif self.env_name == "half_cheetah":
                cur_loc = env.unwrapped._get_obs()[:1]
            else:
                cur_loc = env.unwrapped._get_obs()[:2]
            distance = np.linalg.norm(cur_loc - goal_info["goal_loc"])
            goal_metrics[f"Goal{method}Distance"].append(distance)
            goal_metrics[f"Goal{method}SquaredDistance"].append(distance**2)

    def _gc_goal_images(self, env, goals):
        """Return (N, 3, H, W) uint8 plots marking each antmaze goal, else None."""
        if not (
            "antmaze-large-play" in self.env_name
            or "antmaze-ultra-play" in self.env_name
        ):
            return None
        goal_images = []
        for _, goal_info in goals:
            goal_coords = goal_info["env_goal"]  # (x, y)
            fig = Figure()
            canvas = FigureCanvas(fig)
            ax = fig.add_subplot()
            env.scatter_trajectory(np.array(goal_coords)[None], color="r", ax=ax)
            ax.scatter(goal_coords[0], goal_coords[1], color="r", marker="*", s=100)
            canvas.draw()
            # buffer_rgba() (H, W, 4) replaces tostring_rgb(), which newer
            # Matplotlib releases removed; drop alpha to get RGB.
            goal_images.append(np.asarray(canvas.buffer_rgba())[:, :, :3].copy())
        return np.stack(goal_images, axis=0).transpose(0, 3, 1, 2)

    def _gc_coverage_metrics(self):
        """Summarize training coverage from ``coverage_queue`` / ``coverage_log``."""
        kitchen_tasks = [
            "BottomBurner",
            "LightSwitch",
            "SlideCabinet",
            "HingeCabinet",
            "Microwave",
            "Kettle",
        ]
        metrics = {}
        if len(self.coverage_queue) == 0:
            if self.env_name == "kitchen":
                for task in kitchen_tasks:
                    metrics[f"TrainKitchenTask{task}"] = 0
                metrics["TrainKitchenOverall"] = 0
            else:
                metrics["TrainNumUniqueCoords"] = 0
                metrics["TrainTotalNumUniqueCoords"] = 0
            return metrics

        coverage_data = np.array(self.coverage_queue)
        if self.env_name == "kitchen":
            assert (
                coverage_data.ndim == 2 and coverage_data.shape[-1] == 6
            ), coverage_data.shape
            coverage = coverage_data.max(axis=0)  # task ever solved
            for i, task in enumerate(kitchen_tasks):
                metrics[f"TrainKitchenTask{task}"] = coverage[i]
                metrics[f"CoincidentalRate{task}"] = self.coincidental_goal_success[i]
            metrics["TrainKitchenOverall"] = coverage.sum()
            return metrics

        total_coverage_data = np.array(self.coverage_log)
        assert coverage_data.ndim == 3
        # The old check `shape[-1] == 2 if cond else 1` parsed as
        # `(shape[-1] == 2) if cond else 1` and never checked 1-D envs.
        # NOTE(review): those envs may report padded 2-D coordinates, so both
        # 1 and 2 are accepted for them.
        if self.env_name in ["half_cheetah", "walker"]:
            allowed_dims = (1, 2)
        else:
            allowed_dims = (2,)
        assert coverage_data.shape[-1] in allowed_dims, coverage_data.shape
        coverage_data = coverage_data.reshape(-1, coverage_data.shape[-1])
        total_coverage_data = total_coverage_data.reshape(
            -1, total_coverage_data.shape[-1]
        )
        # Coverage = number of distinct unit grid cells visited.
        metrics["TrainNumUniqueCoords"] = len(
            np.unique(np.floor(coverage_data), axis=0)
        )
        metrics["TrainTotalNumUniqueCoords"] = len(
            np.unique(np.floor(total_coverage_data), axis=0)
        )
        metrics["MaxDistFromOrigin"] = np.linalg.norm(coverage_data, axis=-1).max(axis=0)
        return metrics
