import os
from collections import deque, defaultdict

import numpy as np
import torch
from baselines.common.running_mean_std import RunningMeanStd
from algos import RolloutStorage
from util import is_discrete_actions, get_obs_at_index, set_obs_at_index
import re
from scipy import stats as st

import time

import matplotlib as mpl
import matplotlib.pyplot as plt
# mpl.use("macOSX")
import wandb


class AdversarialRunner(object):
    """
    Performs rollouts of an adversarial environment, given
    protagonist (agent), antagonist (adversary_agent), and
    environment adversary (adversary_env).
    """
    def __init__(
        self,
        args,
        agent,
        venv,
        ued_venv=None,
        adversary_agent=None,
        adversary_env=None,
        train=False,
        level_store_max_size=1000,
        device='cpu'):
        """
        Wire up the environments, trainers and bookkeeping used by the runner.

        args: Run configuration; `num_steps`, `ued_algo` and
            `adv_normalize_returns` are read here.
        agent: Protagonist trainer.
        venv: Vectorized, adversarial gym env with agent-specific wrappers.
        ued_venv: Vectorized, adversarial gym env with adversary-env-specific
            wrappers. `venv` is reused when this is None.
        adversary_agent: Antagonist trainer.
        adversary_env: Environment adversary trainer.
        train: Start in training mode when True, in eval mode otherwise.
        level_store_max_size: Accepted but not read by this constructor.
        device: Stored on the runner as `self.device`.
        """
        self.args = args
        self.device = device

        # Running counters reported by the ZPD filtering diagnostics.
        self.total_zpd = 0
        self.total_zpd_replacements = 0
        self.num_replaced = 0
        self.num_samples = 0

        # The adversary env may sit behind different wrappers than the agents' env.
        self.venv = venv
        self.ued_venv = venv if ued_venv is None else ued_venv

        self.is_discrete_actions = is_discrete_actions(self.venv)
        self.is_discrete_adversary_env_actions = is_discrete_actions(self.venv, adversary=True)

        self.agents = dict(
            agent=agent,
            adversary_agent=adversary_agent,
            adversary_env=adversary_env,
        )

        # Rollout lengths: agents use the configured step count, while the env
        # adversary uses the upper bound of its `time_step` observation.
        self.agent_rollout_steps = args.num_steps
        time_step_space = self.venv.adversary_observation_space['time_step']
        self.adversary_env_rollout_steps = time_step_space.high[0]

        # Algorithm switches derived from the selected UED algorithm.
        ued_algo = args.ued_algo
        self.is_dr = ued_algo == 'domain_randomization'
        self.is_training_env = ued_algo in ['minimax', 'paired', 'flexible_paired']
        self.is_paired = ued_algo in ['paired', 'flexible_paired']

        # Running mean/std of env returns, only needed for return normalization.
        if args.adv_normalize_returns:
            self.env_return_rms = RunningMeanStd(shape=())

        if not train:
            self.eval()
        else:
            self.train()

        self.reset()

    def reset(self):
        """Zero the update/episode counters and start empty return windows."""
        self.num_updates = 0
        self.total_episodes_collected = 0

        # Bounded windows: each keeps at most the 10 most recent returns.
        window = 10
        self.agent_returns, self.adversary_agent_returns = (
            deque(maxlen=window),
            deque(maxlen=window),
        )

    def train(self):
        self.is_training = True
        [agent.train() if agent else agent for _,agent in self.agents.items()]

    def eval(self):
        self.is_training = False
        [agent.eval() if agent else agent for _,agent in self.agents.items()]

    def state_dict(self):
        agent_state_dict = {}
        optimizer_state_dict = {}
        for k, agent in self.agents.items():
            if agent:
                agent_state_dict[k] = agent.algo.actor_critic.state_dict()
                optimizer_state_dict[k] = agent.algo.optimizer.state_dict()

        return {
            'agent_state_dict': agent_state_dict,
            'optimizer_state_dict': optimizer_state_dict,
            'agent_returns': self.agent_returns,
            'adversary_agent_returns': self.adversary_agent_returns,
            'num_updates': self.num_updates,
            'total_episodes_collected': self.total_episodes_collected,
        }

    def load_state_dict(self, state_dict):
        agent_state_dict = state_dict.get('agent_state_dict')
        for k,state in agent_state_dict.items():
            self.agents[k].algo.actor_critic.load_state_dict(state)

        optimizer_state_dict = state_dict.get('optimizer_state_dict')
        for k,state in optimizer_state_dict.items():
            self.agents[k].algo.optimizer.load_state_dict(state)

        self.agent_returns = state_dict.get('agent_returns')
        self.adversary_agent_returns = state_dict.get('adversary_agent_returns')
        self.num_updates = state_dict.get('num_updates')
        self.total_episodes_collected = state_dict.get('total_episodes_collected')

    def _get_rollout_return_stats(self, rollout_returns):
        # @todo: need to record agent curricula-specific env metrics:
        # - shortest path length
        # - num blocks
        # - passable ratio
        mean_return = torch.zeros(self.args.num_processes, 1)
        max_return = torch.zeros(self.args.num_processes, 1)
        for b, returns in enumerate(rollout_returns):
            if len(returns) > 0:
                mean_return[b] = float(np.mean(returns))
                max_return[b] = float(np.max(returns))

        stats = {
            'mean_return': mean_return,
            'max_return': max_return,
            'returns': rollout_returns
        }

        return stats

    def _get_env_stats_multigrid(self, agent_info, adversary_agent_info):
        num_blocks = np.mean(self.venv.get_num_blocks())
        passable_ratio = np.mean(self.venv.get_passable())
        shortest_path_lengths = self.venv.get_shortest_path_length()
        shortest_path_length = np.mean(shortest_path_lengths)

        if 'max_returns' in adversary_agent_info:
            solved_idx = \
                (torch.max(agent_info['max_return'], \
                    adversary_agent_info['max_return']) > 0).numpy().squeeze()
        else:
            solved_idx = (agent_info['max_return'] > 0).numpy().squeeze()

        solved_path_lengths = np.array(shortest_path_lengths)[solved_idx]
        solved_path_length = np.mean(solved_path_lengths) if len(solved_path_lengths) > 0 else 0

        stats = {
            'num_blocks': num_blocks,
            'passable_ratio': passable_ratio,
            'shortest_path_length': shortest_path_length,
            'solved_path_length': solved_path_length
        }

        return stats

    def _get_env_stats_minihack(self, agent_info, adversary_agent_info):
        stats = self._get_env_stats_multigrid(agent_info, adversary_agent_info)

        num_monsters = np.mean(self.venv.get_num_monsters())
        num_lava = np.mean(self.venv.get_num_lava())
        num_walls = np.mean(self.venv.get_num_walls())
        num_doors = np.mean(self.venv.get_num_doors())

        stats.update({
            'num_lava': num_lava,
            'num_monsters': num_monsters,
            'num_doors': num_doors,
            'num_walls': num_walls,
        })

        return stats

    def _get_env_stats(self, agent_info, adversary_agent_info):
        env_name = self.args.env_name
        if env_name.startswith('MultiGrid'):
            stats = self._get_env_stats_multigrid(agent_info, adversary_agent_info)
        elif env_name.startswith('MiniHack'):
            stats = self._get_env_stats_minihack(agent_info, adversary_agent_info)
        else:
            raise ValueError(f'Unsupported environment, {self.args.env_name}')

        return stats

    def agent_evaluation_rollout(
        self,
        agent,
        num_steps,
        is_env=False,
        fixed_seeds=None
    ):
        """
        Roll out `agent` for `num_steps` steps without writing to its storage.

        Transitions go into a temporary RolloutStorage configured from
        `agent.storage`, and are also recorded step by step so that they can
        be replayed into the real storage later (see
        `update_agent_with_experiences`).

        agent: Trainer to roll out.
        num_steps: Number of environment steps to take.
        is_env: When True, actions are sent to `ued_venv.step_adversary`
            instead of `venv.step_env`.
        fixed_seeds: Unused.

        Returns `(rollout_info, experiences)`. `rollout_info` holds the
        return statistics. `experiences[0]` holds the initial `obs` and the
        per-process `rollout_returns`; `experiences[t + 1]` holds everything
        that was inserted into the temporary storage at step `t`.
        """
        args = self.args
        obs = self.venv.reset_agent()

        # Temporary storage mirroring the configuration of the agent's own.
        temp_storage = RolloutStorage(
            model=agent.storage.model,
            num_steps=agent.storage.num_steps,
            num_processes=agent.storage.num_processes,
            observation_space=agent.storage.observation_space,
            action_space=agent.storage.action_space,
            recurrent_hidden_state_size=agent.storage.recurrent_hidden_state_size,
            recurrent_arch=agent.storage.recurrent_arch,
            use_proper_time_limits=agent.storage.use_proper_time_limits,
            use_popart=agent.storage.use_popart,
            device=agent.storage.device,
        )

        # Initialize first observation
        temp_storage.copy_obs_to_index(obs, 0)
        experiences = [{'obs': obs}]
        rollout_returns = [[] for _ in range(args.num_processes)]
        for step in range(num_steps):
            if args.render:
                self.venv.render_to_screen()

            # Sample actions
            with torch.no_grad():
                obs_id = temp_storage.get_obs(step)
                value, action, action_log_dist, recurrent_hidden_states = agent.act(
                    obs_id,
                    temp_storage.get_recurrent_hidden_state(step),
                    temp_storage.masks[step].to(temp_storage.device))

                if self.is_discrete_actions:
                    action_log_prob = action_log_dist.gather(-1, action)
                else:
                    action_log_prob = action_log_dist

            # Observe reward and next obs
            _action = agent.process_action(action.cpu())
            if is_env:
                obs, reward, done, infos = self.ued_venv.step_adversary(_action)
            else:
                obs, reward, done, infos = self.venv.step_env(_action, reset_random=self.is_dr)
                if args.clip_reward:
                    reward = torch.clamp(reward, -args.clip_reward, args.clip_reward)

            # Handle early termination due to cliffhanger rollout: on the last
            # step every process is treated as done, and unfinished episodes
            # are tagged so the storage can tell truncation from termination.
            if not is_env and step >= num_steps - 1:
                if temp_storage.use_proper_time_limits:
                    for i, done_ in enumerate(done):
                        if not done_:
                            infos[i]['cliffhanger'] = True
                            infos[i]['truncated'] = True
                            infos[i]['truncated_obs'] = get_obs_at_index(obs, i)

                # The `np.float` alias was removed in NumPy 1.24; the builtin
                # `float` selects the same dtype on every NumPy version.
                done = np.ones_like(done, dtype=float)

            for i, info in enumerate(infos):
                if 'episode' in info.keys():
                    rollout_returns[i].append(info['episode']['r'])

                    if not is_env:
                        self.total_episodes_collected += 1

                        # Handle early termination
                        if temp_storage.use_proper_time_limits:
                            if 'truncated_obs' in info.keys():
                                truncated_obs = info['truncated_obs']
                                temp_storage.insert_truncated_obs(truncated_obs, index=i)

            # Masks are 0.0 where the flag is set and 1.0 elsewhere.
            masks = torch.FloatTensor(
                [[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'truncated' in info.keys() else [1.0]
                 for info in infos])
            cliffhanger_masks = torch.FloatTensor(
                [[0.0] if 'cliffhanger' in info.keys() else [1.0]
                 for info in infos])

            temp_storage.insert(
                obs, recurrent_hidden_states,
                action, action_log_prob, action_log_dist,
                value, reward, masks, bad_masks,
                cliffhanger_masks=cliffhanger_masks)

            # Keep the exact arguments of the insert for later replay.
            experiences.append({
                'obs': obs,
                'recurrent_hidden_states': recurrent_hidden_states,
                'action': action,
                'action_log_prob': action_log_prob,
                'action_log_dist': action_log_dist,
                'value': value,
                'reward': reward,
                'masks': masks,
                'bad_masks': bad_masks,
                'cliffhanger_masks': cliffhanger_masks,
            })

        rollout_info = self._get_rollout_return_stats(rollout_returns)
        experiences[0]['rollout_returns'] = rollout_returns

        return rollout_info, experiences

    def update_agent_with_experiences(self, agent, experiences, num_steps):
        args = self.args
        # Initialize first observation
        agent.storage.copy_obs_to_index(experiences[0]['obs'], 0)
        rollout_returns = experiences[0]['rollout_returns']

        for step in range(num_steps):
            experience_i = experiences[step+1]
            obs = experience_i['obs']
            recurrent_hidden_states = experience_i['recurrent_hidden_states']
            action = experience_i['action']
            action_log_prob = experience_i['action_log_prob']
            action_log_dist = experience_i['action_log_dist']
            value = experience_i['value']
            reward = experience_i['reward']
            masks = experience_i['masks']
            bad_masks = experience_i['bad_masks']
            cliffhanger_masks = experience_i['cliffhanger_masks']

            agent.insert(
                obs, recurrent_hidden_states,
                action, action_log_prob, action_log_dist,
                value, reward, masks, bad_masks,
                cliffhanger_masks=cliffhanger_masks)

        rollout_info = self._get_rollout_return_stats(rollout_returns)

        # Update non-env agent if required
        with torch.no_grad():
            obs_id = agent.storage.get_obs(-1)
            next_value = agent.get_value(
                obs_id, agent.storage.get_recurrent_hidden_state(-1),
                agent.storage.masks[-1]).detach()

        agent.storage.compute_returns(
            next_value, args.use_gae, args.gamma, args.gae_lambda)

        value_loss, action_loss, dist_entropy, info = agent.update()

        rollout_info.update({
            'value_loss': value_loss,
            'action_loss': action_loss,
            'dist_entropy': dist_entropy,
            'update_info': info,
        })

        # Compute LZ complexity of action trajectories
        if args.log_action_complexity:
            rollout_info.update({'action_complexity': agent.storage.get_action_complexity()})

        return rollout_info

    def agent_rollout(self,
                      agent,
                      num_steps,
                      update=False,
                      is_env=False,
                      fixed_seeds=None):
        args = self.args
        if is_env:
            if self.is_dr:
                obs = self.ued_venv.reset_random() # don't need obs here
                return
            else:
                obs = self.ued_venv.reset() # Prepare for constructive rollout
        else:
            obs = self.venv.reset_agent()

        # Initialize first observation
        agent.storage.copy_obs_to_index(obs,0)

        mean_return = 0

        actions = []
        rollout_returns = [[] for _ in range(args.num_processes)]
        for step in range(num_steps):
            if args.render:
                self.venv.render_to_screen()
            # Sample actions
            with torch.no_grad():
                obs_id = agent.storage.get_obs(step)
                fwd_start =time.time()
                value, action, action_log_dist, recurrent_hidden_states = agent.act(
                    obs_id, agent.storage.get_recurrent_hidden_state(step), agent.storage.masks[step])

                if self.is_discrete_actions:
                    action_log_prob = action_log_dist.gather(-1, action)
                else:
                    action_log_prob = action_log_dist

            # Observe reward and next obs
            reset_random = self.is_dr
            _action = agent.process_action(action.cpu())
            actions.append(_action)
            if is_env:
                obs, reward, done, infos = self.ued_venv.step_adversary(_action)
            else:
                obs, reward, done, infos = self.venv.step_env(_action, reset_random=reset_random)
                if args.clip_reward:
                    reward = torch.clamp(reward, -args.clip_reward, args.clip_reward)

            # Handle early termination due to cliffhanger rollout
            if not is_env and step >= num_steps - 1:
                if agent.storage.use_proper_time_limits:
                    for i, done_ in enumerate(done):
                        if not done_:
                            infos[i]['cliffhanger'] = True
                            infos[i]['truncated'] = True
                            infos[i]['truncated_obs'] = get_obs_at_index(obs, i)

                done = np.ones_like(done, dtype=np.float)

            for i, info in enumerate(infos):
                if 'episode' in info.keys():
                    rollout_returns[i].append(info['episode']['r'])

                    if not is_env:
                        self.total_episodes_collected += 1

                        # Handle early termination
                        if agent.storage.use_proper_time_limits:
                            if 'truncated_obs' in info.keys():
                                truncated_obs = info['truncated_obs']
                                agent.storage.insert_truncated_obs(truncated_obs, index=i)

            # If done then clean the history of observations.
            masks = torch.FloatTensor(
                [[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'truncated' in info.keys() else [1.0]
                 for info in infos])
            cliffhanger_masks = torch.FloatTensor(
                [[0.0] if 'cliffhanger' in info.keys() else [1.0]
                 for info in infos])

            agent.insert(
                obs, recurrent_hidden_states,
                action, action_log_prob, action_log_dist,
                value, reward, masks, bad_masks,
                cliffhanger_masks=cliffhanger_masks)

        rollout_info = self._get_rollout_return_stats(rollout_returns)

        # Update non-env agent if required
        if not is_env and update:
            with torch.no_grad():
                obs_id = agent.storage.get_obs(-1)
                next_value = agent.get_value(
                    obs_id, agent.storage.get_recurrent_hidden_state(-1),
                    agent.storage.masks[-1]).detach()

            agent.storage.compute_returns(
                next_value, args.use_gae, args.gamma, args.gae_lambda)

            value_loss, action_loss, dist_entropy, info = agent.update()
            # value_loss, action_loss, dist_entropy, info = 0,0,0,{}

            rollout_info.update({
                'value_loss': value_loss,
                'action_loss': action_loss,
                'dist_entropy': dist_entropy,
                'update_info': info,
            })

            # Compute LZ complexity of action trajectories
            if args.log_action_complexity:
                rollout_info.update({'action_complexity': agent.storage.get_action_complexity()})

        if is_env:
            rollout_info['actions'] = actions
        return rollout_info

    def _compute_env_return(self, agent_info, adversary_agent_info):
        args = self.args

        if args.ued_algo == 'paired':
            if self.args.reward == 'nominal':
                env_return = torch.max(adversary_agent_info['max_return'] - agent_info['mean_return'], \
                    torch.zeros_like(agent_info['mean_return']))
            elif self.args.reward == 'mean_regret':
                env_return = torch.max(
                    adversary_agent_info['mean_return'] - agent_info['mean_return'],
                    torch.zeros_like(agent_info['mean_return']))
            elif self.args.reward == 'maximize_return':
                env_return = agent_info['mean_return']
            elif self.args.reward == 'minimize_return':
                env_return = -agent_info['mean_return']
            else:
                raise ValueError('Invalid reward type: {}'.format(self.args.reward))

        elif args.ued_algo == 'flexible_paired':
            env_return = torch.zeros_like(agent_info['max_return'], dtype=torch.float, device=self.device)
            adversary_agent_max_idx = adversary_agent_info['max_return'] > agent_info['max_return']
            agent_max_idx = ~adversary_agent_max_idx

            env_return[adversary_agent_max_idx] = \
                adversary_agent_info['max_return'][adversary_agent_max_idx]
            env_return[agent_max_idx] = agent_info['max_return'][agent_max_idx]

            env_mean_return = torch.zeros_like(env_return, dtype=torch.float)
            env_mean_return[adversary_agent_max_idx] = \
                agent_info['mean_return'][adversary_agent_max_idx]
            env_mean_return[agent_max_idx] = \
                adversary_agent_info['mean_return'][agent_max_idx]

            env_return = torch.max(env_return - env_mean_return, torch.zeros_like(env_return))

        elif args.ued_algo == 'minimax':
            env_return = -agent_info['max_return']

        else:
            env_return = torch.zeros_like(agent_info['mean_return'])

        if args.adv_normalize_returns:
            self.env_return_rms.update(env_return.flatten().cpu().numpy())
            env_return /= np.sqrt(self.env_return_rms.var + 1e-8)

        if args.adv_clip_reward is not None:
            clip_max_abs = args.adv_clip_reward
            env_return = env_return.clamp(-clip_max_abs, clip_max_abs)

        return env_return

    def determine_zpd_indexes(self, regret):
        keep_mask = (regret > 0)[:, 0]
        if self.args.zpd_quantile:
            lower = torch.Tensor(np.quantile(regret, 0.25, 0, method='nearest'))
            upper = torch.Tensor(np.quantile(regret, 0.75, 0, method='nearest'))
            keep_mask = torch.logical_and(lower <= regret, regret <= upper)[:, 0]
        else:
            keep_mask = (regret > 0)[:, 0]

        # Get idxs for non-negative regrets
        non_negative_idxs = torch.nonzero(keep_mask).squeeze(1)
        self.total_zpd += keep_mask.numel()
        self.num_replaced = keep_mask.numel() - non_negative_idxs.numel()
        self.num_samples = keep_mask.numel()
        # If there are no non-negative regrets, then keep all indices
        if non_negative_idxs.numel() == 0:
            idxs = torch.arange(keep_mask.numel())
        # If not all indices are non-negative, then fill in the rest with random non-negative indices
        elif non_negative_idxs.numel() < keep_mask.numel():
            # Get difference in shape between keep_mask and non_negative_idxs
            diff = keep_mask.numel() - non_negative_idxs.numel()
            self.total_zpd_replacements += diff
            # Randomly sample #`diff` indices from the non-negative indices
            rand_idxs = torch.randint(0, non_negative_idxs.numel(), (diff,), device=self.device)
            rand_idxs = non_negative_idxs[rand_idxs]
            # Concatenate the non-negative indices with the randomly sampled indices
            idxs = torch.cat((non_negative_idxs, rand_idxs), dim=0)
        # If all indices are non-negative, then keep all indices
        else:
            self.total_zpd_replacements += non_negative_idxs.numel()
            idxs = non_negative_idxs
        return idxs

    def remap_experiences(self, experiences, indexes):
        for experience in experiences:
            for key, value in experience.items():
                if isinstance(value, dict):
                    for v_key, v_value in value.items():
                        value[v_key] = v_value[indexes]
                elif isinstance(value, tuple):
                    assert len(value) == 2, 'not h, c of hidden state'
                    h, c = value
                    h = h[indexes]
                    c = c[indexes]
                    experience[key] = (h, c)
                elif isinstance(value, list):
                    experience[key] = [value[idx] for idx in indexes]
                else:
                    experience[key] = value[indexes]
        return experiences

    def paired_zpd_rollout(self, adversary_env, agent, adversary_agent):
        """
        PAIRED rollout that filters levels before updating the two agents.

        The env adversary first generates a batch of levels. Both agents are
        then rolled out on it without touching their storages, a set of batch
        indices is chosen according to `args.filter_range`, the recorded
        experiences are re-indexed with it, and finally both agents are
        updated from the re-indexed experiences.

        `args.filter_range` selects the score used for filtering:
        - 'zpd': antagonist max return minus protagonist mean return.
        - 'antizpd': the negation of that regret.
        - 'mean_zpd': antagonist mean return minus protagonist mean return.
        - 'min_reward_zpd': the `args.num_zpd_indices` levels with the lowest
          protagonist mean return, padded with random repeats of them.

        Returns `(env_info, agent_info, adversary_agent_info)`.

        Raises ValueError for an unknown `args.filter_range`.
        """
        # Checked up front so that no environment is stepped when misconfigured.
        assert self.is_paired, 'Paired ZPD only works with having PAIRED on.'

        # Generate a batch of adversarial environments
        env_info = self.agent_rollout(
            agent=adversary_env,
            num_steps=self.adversary_env_rollout_steps,
            update=False,
            is_env=True)

        # Evaluate the antagonist, then the protagonist, on the generated levels.
        original_adversary_agent_info, adversary_experiences = self.agent_evaluation_rollout(
            agent=adversary_agent,
            num_steps=self.agent_rollout_steps)
        original_agent_info, experiences = self.agent_evaluation_rollout(
            agent=agent,
            num_steps=self.agent_rollout_steps)

        # Choose which batch rows to keep (and which to duplicate as filler).
        regret = original_adversary_agent_info['max_return'] - original_agent_info['mean_return']
        filter_range = self.args.filter_range
        if filter_range == 'zpd':
            indexes = self.determine_zpd_indexes(regret)
        elif filter_range == 'antizpd':
            indexes = self.determine_zpd_indexes(-regret)
        elif filter_range == 'mean_zpd':
            mean = original_adversary_agent_info['mean_return'] - original_agent_info['mean_return']
            indexes = self.determine_zpd_indexes(mean)
        elif filter_range == 'min_reward_zpd':
            indexes = torch.topk(
                -original_agent_info['mean_return'][:, 0], self.args.num_zpd_indices).indices
            diff = regret.numel() - indexes.numel()
            # Pad back to the batch size with random repeats of the selected
            # rows. Draw on the device of `indexes`; drawing on self.device
            # fails at indexing/concatenation when the two devices differ.
            rand_idxs = torch.randint(0, indexes.numel(), (diff,), device=indexes.device)
            rand_idxs = indexes[rand_idxs]
            indexes = torch.cat((indexes, rand_idxs), dim=0)
        else:
            raise ValueError('Invalid filter range {}'.format(self.args.filter_range))

        # Re-index the antagonist's experiences and update it from them.
        adversary_experiences = self.remap_experiences(adversary_experiences, indexes)
        adversary_agent_info = self.update_agent_with_experiences(
            adversary_agent, adversary_experiences, num_steps=self.agent_rollout_steps)

        # Same for the protagonist.
        experiences = self.remap_experiences(experiences, indexes)
        agent_info = self.update_agent_with_experiences(
            agent, experiences, num_steps=self.agent_rollout_steps)
        return env_info, agent_info, adversary_agent_info

    def paired_rollout(self, adversary_env, agent, adversary_agent):
        """
        Standard rollout: generate levels, then run the agents on them.

        The env adversary builds the levels without being updated here. The
        protagonist is rolled out (and updated when the runner is training).
        The antagonist info is, in order of precedence: an all-zero
        placeholder when `args.only_norm` is set, a real rollout when the
        algorithm is a PAIRED variant, or an empty `defaultdict(float)`.

        Returns `(env_info, agent_info, adversary_agent_info)`.
        """
        # Generate a batch of adversarial environments
        env_info = self.agent_rollout(
            agent=adversary_env,
            num_steps=self.adversary_env_rollout_steps,
            update=False,
            is_env=True)

        # Run protagonist agent episodes
        agent_info = self.agent_rollout(
            agent=agent,
            num_steps=self.agent_rollout_steps,
            update=self.is_training)

        # Run adversary agent episodes
        if self.args.only_norm:
            n_procs = self.args.num_processes
            adversary_agent_info = {
                'mean_return': torch.zeros(n_procs, 1),
                'max_return': torch.zeros(n_procs, 1),
                'returns': [[0.0] for _ in range(n_procs)],
                'value_loss': torch.zeros(1,1),
                'pg_loss': torch.zeros(1,1),
                'dist_entropy': torch.zeros(1,1),
                'action_loss': torch.zeros(1,1),
                'action_complexity': torch.zeros(1,1),
                'update_info': {'grad_norms': np.zeros(n_procs)},
            }
        elif self.is_paired:
            adversary_agent_info = self.agent_rollout(
                agent=adversary_agent,
                num_steps=self.agent_rollout_steps,
                update=self.is_training)
        else:
            adversary_agent_info = defaultdict(float)

        return env_info, agent_info, adversary_agent_info

    def run(self):
        """
        Run one iteration: collect rollouts, update trainers, log statistics.

        Rollouts come from `paired_zpd_rollout` when `args.filter_range` is
        set and from `paired_rollout` otherwise. When the runner is training
        an env adversary, its final return is replaced by the env return
        (optionally combined with, or replaced by, the protagonist's mean
        gradient norm) before it is updated. Histograms and ZPD counters are
        sent to wandb.

        Returns a dict of statistics for this iteration; `steps` is the
        integer `num_updates * num_processes * num_steps`.
        """
        args = self.args

        adversary_env = self.agents['adversary_env']
        agent = self.agents['agent']
        adversary_agent = self.agents['adversary_agent']

        if args.filter_range:
            env_info, agent_info, adversary_agent_info = self.paired_zpd_rollout(
                adversary_env=adversary_env,
                agent=agent,
                adversary_agent=adversary_agent)
        else:
            env_info, agent_info, adversary_agent_info = self.paired_rollout(
                adversary_env=adversary_env,
                agent=agent,
                adversary_agent=adversary_agent)

        if not self.args.only_norm:
            env_return = self._compute_env_return(agent_info, adversary_agent_info)
        else:
            env_return = torch.Tensor([0.0])

        adversary_env_info = defaultdict(float)
        if self.is_training and self.is_training_env:
            with torch.no_grad():
                obs_id = adversary_env.storage.get_obs(-1)
                next_value = adversary_env.get_value(
                    obs_id, adversary_env.storage.get_recurrent_hidden_state(-1),
                    adversary_env.storage.masks[-1]).detach()
            grad_norm = np.mean(agent_info['update_info']['grad_norms'])
            # Choose the signal written as the env adversary's final return.
            if args.only_norm:
                adversary_env.storage.replace_final_return(grad_norm)
            elif args.with_norm:
                if args.multiply_norm:
                    adversary_env.storage.replace_final_return(env_return*grad_norm)
                else:
                    adversary_env.storage.replace_final_return(env_return+grad_norm)
            else:
                adversary_env.storage.replace_final_return(env_return)
            adversary_env.storage.compute_returns(next_value, args.use_gae, args.gamma, args.gae_lambda)
            env_value_loss, env_action_loss, env_dist_entropy, info = adversary_env.update()
            adversary_env_info.update({
                'action_loss': env_action_loss,
                'value_loss': env_value_loss,
                'dist_entropy': env_dist_entropy,
                'update_info': info
            })

        if self.is_training:
            self.num_updates += 1

        # === LOGGING ===
        stats = self._get_env_stats(agent_info, adversary_agent_info)
        stats.update({
            'mean_env_return': env_return.mean().item(),
            'adversary_env_pg_loss': adversary_env_info['action_loss'],
            'adversary_env_value_loss': adversary_env_info['value_loss'],
            'adversary_env_dist_entropy': adversary_env_info['dist_entropy'],
        })

        # Feed the bounded return windows, per process in reversed episode order.
        self.agent_returns.extend(
            r for b in agent_info['returns'] for r in reversed(b))
        mean_agent_return = 0
        if len(self.agent_returns) > 0:
            mean_agent_return = np.mean(self.agent_returns)

        mean_adversary_agent_return = 0
        if self.is_paired:
            self.adversary_agent_returns.extend(
                r for b in adversary_agent_info['returns'] for r in reversed(b))
            if len(self.adversary_agent_returns) > 0:
                mean_adversary_agent_return = np.mean(self.adversary_agent_returns)

        # A stray trailing comma used to turn this into a 1-tuple.
        steps = self.num_updates * args.num_processes * args.num_steps
        stats.update({
            'steps': steps,
            'total_episodes': self.total_episodes_collected,

            'mean_agent_return': mean_agent_return,
            'agent_value_loss': agent_info['value_loss'],
            'agent_pg_loss': agent_info['action_loss'],
            'agent_dist_entropy': agent_info['dist_entropy'],

            'mean_adversary_agent_return': mean_adversary_agent_return,
            'adversary_value_loss': adversary_agent_info['value_loss'],
            'adversary_pg_loss': adversary_agent_info['action_loss'],
            'adversary_dist_entropy': adversary_agent_info['dist_entropy'],
        })

        if args.log_grad_norm:
            agent_grad_norm = np.mean(agent_info['update_info']['grad_norms'])
            adversary_grad_norm = 0
            adversary_env_grad_norm = 0
            if self.is_paired:
                adversary_grad_norm = np.mean(adversary_agent_info['update_info']['grad_norms'])
            if self.is_training_env:
                adversary_env_grad_norm = np.mean(adversary_env_info['update_info']['grad_norms'])
            stats.update({
                'agent_grad_norm': agent_grad_norm,
                'adversary_grad_norm': adversary_grad_norm,
                'adversary_env_grad_norm': adversary_env_grad_norm
            })

        if args.log_action_complexity:
            stats.update({
                'agent_action_complexity': agent_info['action_complexity'],
                'adversary_action_complexity': adversary_agent_info['action_complexity']
            })
        true_regret = adversary_agent_info["max_return"] - agent_info["mean_return"]

        # NOTE(review): this condition is false only when with_zpd, with_norm
        # and only_norm are all set, so the counters written by
        # determine_zpd_indexes are almost always overwritten here. `and` may
        # have been intended -- confirm before changing.
        if not args.with_zpd or not args.with_norm or not args.only_norm:
            keep_mask = (true_regret > 0)[:, 0]
            non_negative_idxs = torch.nonzero(keep_mask).squeeze(1)
            self.num_replaced = keep_mask.numel() - non_negative_idxs.numel()
            self.num_samples = keep_mask.numel()

        wandb.log({
            'steps': steps,
            'true_regret': wandb.Histogram(np.array(true_regret)),
            'adversary/max_return': wandb.Histogram(np.array(adversary_agent_info['max_return'])),
            'adversary/mean_return': wandb.Histogram(np.array(adversary_agent_info['mean_return'])),
            'protagonist/max_return': wandb.Histogram(np.array(agent_info['max_return'])),
            'protagonist/mean_return': wandb.Histogram(np.array(agent_info['mean_return'])),
            'zpd_analysis/total_counts':  self.total_zpd,
            'zpd_analysis/total_replacements': self.total_zpd_replacements,
            'zpd_analysis/num_replaced': self.num_replaced,
            'zpd_analysis/num_samples': self.num_samples,
        })
        return stats
