from collections import OrderedDict
import gzip
import pickle
from time import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import torch.distributions
from torch.optim.lr_scheduler import StepLR

from .gail_agent import GAILAgent
from .dataset import DataReplayBuffer, RandomSampler
from .expert_dataset import *
from networks.discriminator import DiscriminatorEnsemble
from .expert_dataset import ExpertDataset
from utils.info_dict import Info
from utils.visual_tool import *
from utils.logger import logger
from utils.mpi import mpi_average
from utils.pytorch import (
    optimizer_cuda,
    count_parameters,
    sync_networks,
    sync_grads,
    to_tensor,
    weights_init,
    sample_from_dataloader,
    process_transition,
)
import copy
# from contrastive_learning import Encoder
# from sklearn.manifold import TSNE
import numpy as np
import os

def get_demo_files(demo_file_path):
    """Return the demo pickle files matched by ``demo_file_path``.

    Args:
        demo_file_path: Either a path/glob pattern that already ends in
            ``.pkl``, or a prefix. A prefix gets ``"*.pkl"`` appended, so
            ``"demos/"`` matches every pickle in the directory and
            ``"demos/task"`` matches ``demos/task*.pkl``.

    Returns:
        A list of the matching paths that are regular files, in the order
        ``glob.glob`` yields them (not sorted). Empty if nothing matches.
    """
    # Imported here because no explicit ``import glob`` is visible at the top
    # of this module; without it the name only resolves if a wildcard import
    # happens to re-export it.
    import glob

    if not demo_file_path.endswith(".pkl"):
        demo_file_path = demo_file_path + "*.pkl"
    # Directories whose name ends in ".pkl" are filtered out.
    return [path for path in glob.glob(demo_file_path) if os.path.isfile(path)]

class ReachableGAILAgentV0(GAILAgent):
    def __init__(self, config, ob_space, ac_space, env_ob_space, layout):
        """Set up the reachability discriminator, its optimizer and both replay buffers.

        All arguments are forwarded unchanged to ``GAILAgent.__init__``. On top of
        the parent state this constructor creates:

        * ``self._reachable`` / ``self._reachable_optim``: a
          ``DiscriminatorEnsemble`` with ``config.Dr_ensemble_size`` members and
          an Adam optimizer using ``config.discriminator_lr``. The action space
          is passed to the ensemble only when ``config.reach_discriminator_input``
          is truthy.
        * ``self._reachable_buffer`` / ``self._unreachable_buffer``: two
          ``DataReplayBuffer`` instances, seeded from demonstrations only when
          no ``*.pkl`` file exists in ``config.log_dir``.
        * pre-training counters and, when ``config.backwards_relabelling`` is
          set and the unreachable buffer is non-empty, ``self._matches`` (the
          buffer indices flagged as trajectory starts).
        """
        super().__init__(config, ob_space, ac_space, env_ob_space, layout)

        self._reachable = DiscriminatorEnsemble(config.Dr_ensemble_size, config, ob_space, ac_space if config.reach_discriminator_input else None)
        self._reachable_optim = optim.Adam(self._reachable.parameters(), lr=config.discriminator_lr)
        self._reachable.to(config.device)

        # Reachable buffer is seeded with target-task demos, unreachable buffer
        # with demos from the other tasks (see the seeding block below).
        sampler = RandomSampler(image_crop_size=config.encoder_image_size)
        buffer_keys = ["ob", "ob_next", "ac", "done", "rew", "traj_start_indicator"]
        self._reachable_buffer = DataReplayBuffer(
            buffer_keys, config.buffer_size, sampler.sample_func, is_reachable_buffer=True
        )
        self._unreachable_buffer = DataReplayBuffer(
            buffer_keys, config.buffer_size, sampler.sample_func
        )

        # The reachability discriminator is regressed (MSE) onto soft labels
        # stored in the buffers' 'rew' field.
        self._reach_discriminator_loss = nn.MSELoss()

        # Open question from the original author: should the buffers be sized by
        # rollout_length instead of buffer_size (2000 vs 1e6)?
        # NOTE(review): `glob` is not imported explicitly in this module; it is
        # presumably re-exported by one of the wildcard imports -- confirm.
        # Seed the buffers only when the log dir holds no pickle yet
        # (presumably pickles there mean a previous run's state is restored
        # elsewhere -- TODO confirm).
        paths = glob.glob(os.path.join(self._config.log_dir, "*.pkl"))
        if len(paths)==0:
            if not config.single_task_training:
                # Demos of the other tasks become the initial unreachable data.
                _other_dataset = ExpertDataset(
                    config.other_demo_path,
                    config.demo_subsample_interval,
                    ac_space,
                    use_low_level=config.demo_low_level,
                    sample_range_start=config.demo_sample_range_start,
                    sample_range_end=config.demo_sample_range_end,
                    num_task=config.num_task,
                    target_taskID=None,
                    with_taskID= config.with_taskID,
                    frame_stack=self._config.frame_stack if config.encoder_type == "cnn" else None,
                )
                self._unreachable_buffer.store_transitions(_other_dataset._data)
            
            if config.add_more_demos2reachbuffer is not None:
                # Optionally load a separate target-task dataset with
                # `add_more_demos2reachbuffer` demos for the reachable buffer.
                in_dataset = ExpertDataset(
                            self._config.target_demo_path,
                            self._config.demo_subsample_interval,
                            self._ac_space,
                            use_low_level=self._config.demo_low_level,
                            sample_range_start=self._config.demo_sample_range_start,
                            sample_range_end=self._config.demo_sample_range_end,
                            num_task=self._config.num_task,
                            target_taskID=self._config.target_taskID,
                            num_target_demos=self._config.add_more_demos2reachbuffer,
                            target_demo_path=self._config.target_demo_path,
                            is_sqil=(self._config.algo=="sqil"),
                            frame_stack=self._config.frame_stack if self._config.encoder_type == "cnn" else None,
                        )

            # Deep copy so relabelling 'rew' below does not touch the source
            # dataset. `self._dataset` is presumably created by the parent
            # constructor -- verify against GAILAgent.
            self._reachable_buffer.store_transitions(copy.deepcopy(self._dataset._data if config.add_more_demos2reachbuffer is None else in_dataset._data))
            # Label every reachable transition with reward 1 and every
            # unreachable transition with reward 0.
            self._reachable_buffer._buffer['rew'] = [np.ones_like(self._reachable_buffer._buffer['rew'][i]) for i in range(len(self._reachable_buffer._buffer['rew']))]
            if not config.single_task_training:
                self._unreachable_buffer._buffer['rew'] = [np.zeros_like(self._unreachable_buffer._buffer['rew'][i]) for i in range(len(self._unreachable_buffer._buffer['rew']))]
        
        # TODO: load out-of-distribution (OOD) trajectories from
        # config.ood_traj_path into the unreachable buffer with zero reward once
        # such data exists; for now the other-task expert demos play that role.
        
        if self._config.gail_rl_algo=='ppo':
            # Number of reachability-discriminator updates per train() call
            # (original author's note: "= 3", presumably for the default config).
            self._num_batches_reach = (
                        self._config.rollout_length
                        // self._config.batch_size
                        // self._config.reachability_discriminator_update_freq
                )
            assert self._num_batches_reach > 0

        # Pre-training progress counters, advanced by pretrain().
        self._pretrain_outer_loop = 0
        self._pretrain_inner_loop = 0

        # Accumulators for checking the correlation between reachability and
        # DVD rewards; nothing in the visible code fills them.
        self._correlation_dvd = []
        self._correlation_reach = []

        # Buffer indices whose 'traj_start_indicator' is 1, used by backward
        # relabelling to find the start of the trajectory containing an index.
        if self._config.backwards_relabelling and self._unreachable_buffer._current_size > 0:
            self._matches = []
            for i, item in enumerate(self._unreachable_buffer._buffer["traj_start_indicator"]):
                if item[0] == 1:
                    self._matches.append(i)
    
    def state_dict(self):
        return {
                "rl_agent": self._rl_agent.state_dict(),
                "discriminator_state_dict": self._discriminator.state_dict(),
                "discriminator_optim_state_dict": self._discriminator_optim.state_dict(),
                "ob_norm_state_dict": self._ob_norm.state_dict(),
                "reachable_state_dict": self._reachable.state_dict(),
                "reachable_optim_state_dict": self._reachable_optim.state_dict(),
            }

    def load_state_dict(self, ckpt):
        """Restore the agent from a checkpoint produced by ``state_dict``.

        The parent class restores its own components first. The reachability
        discriminator and its optimizer are restored only when the checkpoint
        contains the ``"rl_agent"`` key.

        Args:
            ckpt: Checkpoint mapping; must contain ``"reachable_state_dict"``
                and ``"reachable_optim_state_dict"`` whenever it contains
                ``"rl_agent"``.
        """
        super().load_state_dict(ckpt)
        if "rl_agent" not in ckpt:
            return

        device = self._config.device
        self._reachable.load_state_dict(ckpt["reachable_state_dict"])
        self._reachable.to(device)
        self._reachable_optim.load_state_dict(ckpt["reachable_optim_state_dict"])
        # Move the restored optimizer state onto the same device as the network.
        optimizer_cuda(self._reachable_optim, device)

    def sync_networks(self):
        """Synchronise the parent's networks, then the reachability discriminator."""
        super().sync_networks()
        # The bare name resolves to the module-level helper, not this method.
        reachability_net = self._reachable
        sync_networks(reachability_net)


    def _sample_expert_data(self, task_id, num_samples=None):
        """Draw expert transitions for ``task_id`` from ``self._data_loader``.

        Args:
            task_id: Index/key into ``self._data_loader``.
            num_samples: Number of samples requested; any falsy value
                (``None`` or ``0``) falls back to ``config.batch_size``.
        """
        loader = self._data_loader[task_id]
        wanted = num_samples or self._config.batch_size
        return sample_from_dataloader(loader, wanted, self._config.batch_size)

    def _sample_expert_traj(self, task_id, num_samples=None):
        """Draw expert samples for ``task_id`` from ``self._traj_loader``.

        Args:
            task_id: Index/key into ``self._traj_loader``.
            num_samples: Number of samples requested; any falsy value
                (``None`` or ``0``) falls back to ``config.batch_size``.
        """
        loader = self._traj_loader[task_id]
        wanted = num_samples or self._config.batch_size
        return sample_from_dataloader(loader, wanted, self._config.batch_size)


    def _preprocess_data(self, data, key="ob"):
        """Extract ``data[key]`` and convert it to a tensor on the configured device.

        Entries whose key contains the substring ``"ob"`` (e.g. ``"ob"`` and
        ``"ob_next"``) are passed through ``self.normalize`` first; all other
        entries are converted as-is.
        """
        selected = data[key]
        if "ob" in key:
            selected = self.normalize(selected)
        return to_tensor(selected, self._config.device)

    def pretrain(self):
        """Run one pre-training step of the reachability discriminator.

        Each call performs a single discriminator update on one batch from each
        buffer. Every ``config.reachability_update_frequency`` inner steps the
        rewards of the unreachable buffer are re-labelled with the current
        discriminator and, if enabled, OOD-state reachability is visualised.
        The inner counter rolls over into the outer (epoch) counter once it
        reaches the warm-up length (first epoch) or the regular epoch length.

        Returns:
            ``(info, done)``: the MPI-averaged info dict and a flag that is
            True once the outer counter exceeds ``config.pretrain_n_epochs``.
        """
        cfg = self._config
        info = Info()

        # One discriminator update: reachable batch vs. unreachable (all other tasks) batch.
        reachable_data, _, _ = self._reachable_buffer.sample(cfg.batch_size)
        unreachable_data, _, _ = self._unreachable_buffer.sample(cfg.batch_size)
        info.add(self.update_reachability_discriminator(reachable_data, unreachable_data))
        self._pretrain_inner_loop += 1

        if self._pretrain_inner_loop % cfg.reachability_update_frequency == 0:
            # Re-evaluate the soft labels of the unreachable data.
            relabel_steps = cfg.reachability_update_steps
            if relabel_steps == -1:
                # -1 means "roughly one pass over the unreachable buffer".
                relabel_steps = self._unreachable_buffer._current_size // cfg.batch_size
            for _ in range(relabel_steps):
                batch, ep_ids, t_ids = self._unreachable_buffer.sample(cfg.batch_size)
                self.relabel_reachability_dense(batch, ep_ids, t_ids)

            # Optional evaluation on out-of-distribution states.
            if cfg.eval_ood_states_reachability:
                global_step = (self._pretrain_outer_loop*cfg.pretrain_inner_n_epochs + self._pretrain_inner_loop)
                ood_im = visual_ood_states_reachability(self, step=global_step, mode='video', suffix='re-label')
                if ood_im:
                    info["visualized_ood"] = ood_im

        # The first outer epoch uses the warm-up length, later ones the regular length.
        outer = self._pretrain_outer_loop
        if outer == 0:
            epoch_finished = self._pretrain_inner_loop == cfg.pretrain_warmup_inner_n_epochs
        else:
            epoch_finished = outer > 0 and self._pretrain_inner_loop == cfg.pretrain_inner_n_epochs
        if epoch_finished:
            self._pretrain_outer_loop += 1
            self._pretrain_inner_loop = 0

        done = self._pretrain_outer_loop > cfg.pretrain_n_epochs
        return mpi_average(info.get_dict(only_scalar=False)), done

    def relabel_reachability_dense(self, unreachable_data, episode_indx, t_samples):
        """Overwrite stored rewards of sampled unreachable transitions with the
        current reachability prediction.

        Args:
            unreachable_data: Batch sampled from ``self._unreachable_buffer``.
            episode_indx: Buffer (first-level) index of every sample in the batch.
            t_samples: Second-level index of every sample in the batch.

        For each sample the sigmoid output of ``self._reachable`` on ``ob_next``
        is written into ``self._unreachable_buffer._buffer['rew']``. When
        ``config.backwards_relabelling`` is set, a linearly decayed version of
        that value is additionally propagated to earlier buffer entries of the
        same trajectory.
        """
        # Sanity check on one random sample: the batch entry must be the same
        # observation as the buffer entry addressed by its indices.
        index = np.random.randint(len(episode_indx), size=1).item()
        ob = self._unreachable_buffer._buffer['ob_next'][episode_indx[index]][t_samples[index]]['ob']
        ob_verify = unreachable_data['ob_next']['ob'][index]
        assert np.all(ob == ob_verify), "ob != ob_verify: ep_index and t_index are not correct! unreachable_data mismatch with unreachable_buffer!" 
        
        o_next = self._preprocess_data(unreachable_data, key="ob_next")
        ac = None
        if self._config.reach_discriminator_input==2:
            # Mode 2: pair ob_next with the action stored in the same transition.
            ac = self._preprocess_data(unreachable_data, key="ac")
        elif self._config.reach_discriminator_input==1:
            # Mode 1: pair ob_next with the action taken *from* that state, i.e.
            # the action stored in the following buffer entry.
            _ac_list = []
            for e_idx, t_idx in zip(episode_indx, t_samples):
                data1 = self._unreachable_buffer._buffer['ob_next'][e_idx][t_idx]['ob']
                if e_idx == self._unreachable_buffer._current_size - 1:
                    # Last stored entry: there is no following entry, use a zero action.
                    _ac_list.append(np.zeros_like(self._unreachable_buffer._buffer['ac'][e_idx][t_idx]['ac']))
                    continue
                data2 = self._unreachable_buffer._buffer['ob'][e_idx + 1][t_idx]['ob']
                
                if (data1 == data2).all():
                    # The following entry continues this trajectory.
                    _ac = self._unreachable_buffer._buffer['ac'][e_idx+1][t_idx]['ac']
                    _ac_list.append(_ac)
                else: # this is the last state of its trajectory, so the action is 0
                    _ac_list.append(np.zeros_like(self._unreachable_buffer._buffer['ac'][e_idx][t_idx]['ac']))
            ac = {"ac": np.array(_ac_list)}
            ac = self._preprocess_data(ac, key="ac")

        if ac is not None:
            logit, _ = self._reachable(o_next, ac)
        else:
            logit, _ = self._reachable(o_next)
        # NOTE(review): squeeze() followed by output[idx] below assumes that
        # only the batch dimension has size > 1 (single ensemble member and
        # batch size > 1) -- confirm against DiscriminatorEnsemble's output shape.
        output = torch.sigmoid(logit).squeeze().detach().cpu().numpy()

        for idx, (ep_indx, t_indx) in enumerate(zip(episode_indx, t_samples)):
            new_rew = output[idx]
            self._unreachable_buffer._buffer['rew'][ep_indx][t_indx] = new_rew
            
            if self._config.backwards_relabelling:
                if ep_indx in self._matches: # the entry starts a trajectory: nothing precedes it
                    continue
                # Last recorded trajectory start that is <= ep_indx. If ep_indx
                # is smaller than every recorded start, match_indx is -1 and the
                # last recorded start is selected (handled by the wrap-around
                # branch below).
                match_indx = np.searchsorted(self._matches, ep_indx, side='right') - 1
                ep_start_indx = self._matches[match_indx]

                def update_reward_range(start, end, step=1):
                    """Walk buffer indices start, start-1, ..., end+1 and write
                    decayed rewards; return the step counter after the walk."""
                    for i in range(start, end, -1):
                        # Linear decay with distance (in entries) from the relabelled sample.
                        new_rew = output[idx] - self._config.dense_reward_scale * step
                        step += 1
                        # Entries already above the skip threshold keep their reward.
                        if self._unreachable_buffer._buffer['rew'][i][t_indx] > self._config.relabel_skip_threshold:
                            continue
                        # Stop once the decayed value drops below the relabel threshold.
                        if new_rew < self._config.backward_relabel_threshold:
                            break
                        self._unreachable_buffer._buffer['rew'][i][t_indx] = new_rew
                        
                    return step

                if ep_indx < ep_start_indx:
                    # The selected start lies at a higher index than the sample
                    # (presumably the trajectory wrapped around the end of the
                    # circular buffer -- TODO confirm). First walk down to index 0 ...
                    step = update_reward_range(ep_indx - 1, -1, 1)
                   
                    # ... and continue from the end of the buffer only if the
                    # counter advanced through all ep_indx entries.
                    if step == (ep_indx + 1):
                        step = update_reward_range(self._config.buffer_size - 1, ep_start_indx - 1, step)
                else:
                    step = update_reward_range(ep_indx - 1, ep_start_indx - 1, step=1)

    def update_reachability_discriminator(self, reachable_data, unreachable_data):
        """Take one gradient step on the reachability discriminator.

        The sigmoid output on reachable samples is regressed onto 1 and the
        output on unreachable samples onto the soft labels stored in their
        ``"rew"`` field, using ``self._reach_discriminator_loss``. A Bernoulli
        entropy bonus weighted by ``config.gail_entropy_loss_coeff`` is
        subtracted from the loss.

        Args:
            reachable_data: Batch sampled from the reachable buffer.
            unreachable_data: Batch sampled from the unreachable buffer.

        Returns:
            MPI-averaged scalar info dict (mean outputs and total loss).
        """
        cfg = self._config
        info = Info()

        pos_obs = self._preprocess_data(reachable_data)
        neg_obs = self._preprocess_data(unreachable_data)

        # Actions are fed to the discriminator only when configured.
        pos_ac = None
        neg_ac = None
        if cfg.reach_discriminator_input:
            pos_ac = self._preprocess_data(reachable_data, key="ac")
            neg_ac = self._preprocess_data(unreachable_data, key="ac")

        pos_logit, _ = self._reachable(pos_obs, pos_ac)
        neg_logit, _ = self._reachable(neg_obs, neg_ac)
        pos_output = torch.sigmoid(pos_logit)
        neg_output = torch.sigmoid(neg_logit)

        # Reshape the stored labels to line up with neg_logit; the original
        # author's note gives that shape as (1, batch_size, 1).
        neg_target = self._preprocess_data(unreachable_data, key="rew")
        neg_target = neg_target.unsqueeze(dim=0).unsqueeze(dim=-1)
        neg_loss = self._reach_discriminator_loss(neg_output, neg_target)
        pos_target = torch.ones_like(pos_output).to(cfg.device)
        pos_loss = self._reach_discriminator_loss(pos_output, pos_target)

        all_logits = torch.cat([neg_logit, pos_logit], dim=0)
        entropy = torch.distributions.Bernoulli(logits=all_logits).entropy().mean()
        entropy_loss = -cfg.gail_entropy_loss_coeff * entropy

        reachable_loss = neg_loss + pos_loss + entropy_loss

        self._reachable.zero_grad()
        reachable_loss.backward()
        sync_grads(self._reachable)
        self._reachable_optim.step()

        info["reachable_pos_output"] = pos_output.mean().detach().cpu().item()
        info["reachable_neg_output"] = neg_output.mean().detach().cpu().item()
        info["reachable_loss"] = reachable_loss.detach().cpu().item()

        # Per-index outputs after averaging over the leading dimension.
        pos_by_index = pos_output.mean(dim=0)
        neg_by_index = neg_output.mean(dim=0)
        for k in range(pos_by_index.shape[0]):
            info[f"reachable{k}_pos_output"] = pos_by_index[k]
            info[f"reachable{k}_neg_output"] = neg_by_index[k]

        return mpi_average(info.get_dict(only_scalar=True))

    def store_episode(self, rollouts, step=0):
        """Store new rollouts in the parent's buffer and in the reachability buffers.

        Args:
            rollouts: Rollout dict; it is passed to the parent unchanged and a
                deep copy is used for everything done here.
            step: Current training step, used to gate DVD relabelling and
                pre-sorting.

        Returns:
            ``{'DVD_relabel': fraction}`` when DVD relabelling is active for
            this step, where ``fraction`` is the share of transitions moved to
            the reachable buffer; ``None`` otherwise.
        """

        super().store_episode(rollouts)
        new_data = copy.deepcopy(rollouts)

        # DVD relabelling: transitions that the GAIL discriminator scores above
        # `dvd_relabel_threshold` are moved into the reachable buffer. Only
        # active once step exceeds warm_up_steps + 1000 (hard-coded offset).
        if self._config.is_DVD_relabel and step>1000+self._config.warm_up_steps:
            from utils.pytorch import obs2tensor
            bs = len(new_data['ob']) if isinstance(new_data['ob'], list) else new_data['ob'].shape[0]

            expert_data_pos = self._sample_expert_traj(self._config.target_task_index_in_demo_path, num_samples=bs)
            e_o_pos = self._preprocess_data(expert_data_pos)
            ob = self.normalize(new_data['ob'])
            ob = obs2tensor(ob, self._config.device)
            e_ac_pos = self._preprocess_data(expert_data_pos, key="ac") if not self._config.gail_no_action else None
            ac = obs2tensor(new_data["ac"], self._config.device) if not self._config.gail_no_action else None

            with torch.no_grad():
                # NOTE(review): here the discriminator result is fed to sigmoid
                # directly, while _predict_reward unpacks a tuple from the same
                # object -- confirm the return type for this 4-argument call.
                ret = torch.sigmoid(self._discriminator(ob, ac, e_o_pos, e_ac_pos))
                reachable = ret.squeeze(dim=-1) > self._config.dvd_relabel_threshold
                reachable_idx = np.where(reachable.cpu().numpy())[0]
                percent_reachable = len(reachable_idx) / len(reachable)
                if len(reachable_idx) != 0:
                    # Copy the selected transitions out, then remove them from new_data.
                    reachable_data = process_transition(reachable_idx, new_data)

                    # Pop from the highest index down so earlier pops do not
                    # shift the positions still to be removed.
                    reachable_idx = list(reachable_idx)
                    reachable_idx.sort(reverse=True)
                    for idx in reachable_idx:
                        # NOTE(review): .pop() assumes every value in new_data
                        # is a list -- confirm for array-valued rollouts.
                        for k, v in new_data.items():
                                new_data[k].pop(idx)
                    self._reachable_buffer.store_transitions(reachable_data)
                    if len(reachable)==len(reachable_idx):
                        # Everything was moved: nothing is left for the unreachable buffer.
                        return {'DVD_relabel': percent_reachable}

        if self._config.presort_policy_samples and (self._config.pretrain_prox or (self._config.pretrain_prox==False and step>0)):
            # Pre-sort: label the new transitions with the current reachability
            # prediction instead of zero. First stack the per-step lists into arrays.
            new_transitions = {}
            for k, v in new_data.items():
                if isinstance(v[0], dict):
                    sub_keys = v[0].keys()
                    new_transitions[k] = {
                        sub_key: np.stack([v_[sub_key] for v_ in v]) for sub_key in sub_keys
                    }
                else:
                    new_transitions[k] = np.stack(v)

            obs = self._preprocess_data(new_transitions, key="ob")
            if self._config.reach_discriminator_input:
                ac = self._preprocess_data(new_transitions, key="ac")
                logit, _ = self._reachable(obs, ac)
            else:
                logit, _ = self._reachable(obs)
            output = torch.sigmoid(logit)

            # TODO: add backwards trajectory looking
            # Use the flattened prediction as the stored reward.
            output = output.reshape(-1)
            output = output.detach().cpu().numpy()
            new_transitions['rew'] = output

            _data = process_transition([i for i in range(output.shape[0])], new_transitions)
            self._unreachable_buffer.store_transitions(_data)

        else:
            # No pre-sorting: new policy data starts with reward 0.
            new_data['rew'] = np.zeros_like(new_data['rew'])
            
            self._unreachable_buffer.store_transitions(new_data)
        
        # The buffer contents changed, so rebuild the list of trajectory-start indices.
        if self._config.backwards_relabelling and self._unreachable_buffer._current_size > 0:
            self._matches = []
            for i, item in enumerate(self._unreachable_buffer._buffer["traj_start_indicator"]):
                if item[0] == 1:
                    self._matches.append(i)
            
        # Same condition as the DVD block above, so percent_reachable is bound here.
        if self._config.is_DVD_relabel and step>1000+self._config.warm_up_steps:
            return {'DVD_relabel': percent_reachable}

    def train(self, step=0):
        """Run one training iteration: discriminators, relabelling, then the RL agent.

        Order of work:
          1. step the GAIL discriminator learning-rate scheduler;
          2. update the GAIL discriminator ``self._num_batches`` times;
          3. update the reachability discriminator ``self._num_batches_reach`` times;
          4. periodically re-label the unreachable buffer;
          5. train the wrapped RL agent;
          6. refresh the observation normalizer with expert observations.

        For non-PPO algorithms both batch counts are recomputed here as booleans
        (``range(True)`` runs once, ``range(False)`` not at all).

        Args:
            step: Global training step, used for the relabelling schedule and
                forwarded to ``_update_discriminator``.

        Returns:
            The collected training info as a dict (non-scalar entries included).
        """
        cfg = self._config
        train_info = Info()

        def next_expert_batch():
            # Pull the next expert batch, restarting the iterator when exhausted.
            try:
                return next(self._data_iter)
            except StopIteration:
                self._data_iter = iter(self._data_loader)
                return next(self._data_iter)

        # train GAIL discriminator
        self._discriminator_lr_scheduler.step()

        # TODO: re-factor so it fits with the new SAC code
        if cfg.gail_rl_algo != 'ppo':
            self._num_batches = (self._update_iter % cfg.discriminator_update_freq == 0)
            self._num_batches_reach = (self._update_iter % cfg.reachability_discriminator_update_freq == 0)
            self._update_iter += 1

        for _ in range(self._num_batches):
            policy_data = self._buffer.sample(cfg.batch_size)
            expert_data = next_expert_batch()
            train_info.add(self._update_discriminator(policy_data, expert_data, step=step))

        # train reachability discriminator
        for _ in range(self._num_batches_reach):
            reachable_data, _, _ = self._reachable_buffer.sample(cfg.batch_size)
            unreachable_data, _, _ = self._unreachable_buffer.sample(cfg.batch_size)  # all other tasks
            train_info.add(self.update_reachability_discriminator(reachable_data, unreachable_data))

        # Re-label the unreachable data on the configured schedule.
        relabel_time = time()
        if self._num_batches_reach and step % cfg.reachability_online_update_frequency == 0:
            relabel_steps = cfg.reachability_online_update_steps
            if relabel_steps == -1:
                # -1 means "roughly one pass over the unreachable buffer".
                relabel_steps = self._unreachable_buffer._current_size // cfg.batch_size
            for _ in range(relabel_steps):
                batch, ep_ids, t_ids = self._unreachable_buffer.sample(cfg.batch_size)
                self.relabel_reachability_dense(batch, ep_ids, t_ids)
        train_info.add({"relabel_time": time() - relabel_time})

        # train RL agent
        train_info.add(self._rl_agent.train())

        # Keep the observation normalizer up to date with expert observations.
        for _ in range(self._num_batches):
            self.update_normalizer(next_expert_batch()["ob"])

        return train_info.get_dict(only_scalar=False)

    def _predict_reward(self, ob, ac, ID=None, *argv):
        if self._config.gail_no_action:
            ac = None

        with torch.no_grad():
            ### GAIL Discriminator reward
            ret, _ = self._discriminator(ob, ac) # pre-sigmoid
            eps = 1e-10
            s = torch.sigmoid(ret)
            if self._config.gail_reward == "vanilla":
                gail_reward = -(1 - s + eps).log()
            elif self._config.gail_reward == "gan":
                gail_reward = (s + eps).log() - (1 - s + eps).log()
            elif self._config.gail_reward == "d":
                gail_reward = ret

            reward = gail_reward
            ### Reachability Discriminator
            if self._config.reach_discriminator_input:
                reach, _ = self._reachable(ob, ac)
            else:
                reach, _ = self._reachable(ob)
            mean_reach = torch.mean(reach, dim=0)
            median_reach = torch.median(reach, dim=0).values
            reach = median_reach
            ### TODO: how to combine reachability reward?
            assert self._config.gail_reward == "d" # reach rewards would also need to be scaled appropriately for other types of gail rewards
            if self._config.reach_reward == "sum":
                reach_reward = reach * self._config.reach_reward_scale # 1 by default
                reward = reward + reach_reward
            elif self._config.reach_reward == "sum_negative":
                reach -= self._config.reach_reward_constant
                reach_reward = torch.minimum(reach, torch.zeros_like(reach)) * self._config.reach_reward_scale
                reward = reward + reach_reward # only penalize unreachable states
            elif self._config.reach_reward == "tier_sum":
                reach_reward = torch.multiply(torch.sigmoid(reach) - 1, s < self._config.gail_discriminator_threshold) * self._config.reach_reward_scale
                # sets reach_reward to 0 if s < self._config.gail_discriminator_threshold
                gail_reward = s
                reward = s + reach_reward
            elif self._config.reach_reward == "reach_only":
                reach -= self._config.reach_reward_constant
                reach_reward = reach * self._config.reach_reward_scale
                reward = reach

            # check the correlation bewteen reachability and dvd rewards
            # if gail_reward.shape[0] == 1:
            #     self._correlation_dvd.extend(gail_reward.cpu().numpy())
            #     self._correlation_reach.extend(reach_reward.cpu().numpy())
            # corr = None
            # if len(self._correlation_dvd) > 10000:
            #     _to_tensor = lambda x: to_tensor(x, self._config.device)
            #     dvd = np.array(self._correlation_dvd)
            #     reach = np.array(self._correlation_reach)
            #     corr = np.corrcoef(dvd, reach)[0, 1]
            #     corr = _to_tensor(corr)
            #     self._correlation_dvd = []
            #     self._correlation_reach = []

        reward_dict = {"rew": reward, "gail_rew": gail_reward, "sig_gail_rew": torch.sigmoid(gail_reward), "reach_rew": reach_reward, "mean_reach": mean_reach, "median_reach": median_reach, "reach": reach, "sig_reach": torch.sigmoid(reach)}
        # if corr is not None:
        #     reward_dict["corr"] = corr
        return reward_dict

