from numpy.core.numeric import zeros_like
import torch
import torch as ch
import copy
from torch.nn.functional import mse_loss
import tqdm
import sys
import time
import dill
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
from scipy.special import lambertw
import random
from copy import deepcopy
import gym
from auto_LiRPA import BoundedModule, PerturbationSynonym
from auto_LiRPA.eps_scheduler import LinearScheduler
from auto_LiRPA.eps_scheduler import SmoothedScheduler
from auto_LiRPA.eps_scheduler import BaseScheduler
from auto_LiRPA.bounded_tensor import BoundedTensor
from auto_LiRPA.perturbations import PerturbationLpNorm
from .models import *
from .torch_utils import *
from .steps import value_step, step_with_mode, pack_history
from .logging import *

from .convex_relaxation import get_kl_bound as get_state_kl_bound
from collections import Counter
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from stable_baselines3.common.running_mean_std import RunningMeanStd
from policy_gradients.custom_env import make_env, VecNormalize_MultiAgent

from PIL import Image
import omegaconf
import mbrl.models as models
import mbrl.util.common as common_util
class ExpoScheduler(BaseScheduler):
    """Epsilon scheduler that ramps up exponentially first, then linearly.

    Options parsed from ``opt_str`` by ``BaseScheduler`` into ``self.params``:

    - ``start``: epoch at which the ramp begins (required).
    - ``length``: number of epochs the ramp lasts (required).
    - ``mid``: fraction of the ramp spent in the exponential phase
      (default 0.25; must satisfy 0 <= mid < 1).
    - ``beta``: stored as ``self.beta`` and checked to be >= 2, but not used
      by the schedule itself (default 4.0).
    - ``init``: eps used before the ramp starts (default 1e-10; must be > 0
      because the exponential phase divides by it).

    For 0-based batch index ``s``, with
    ``init_step = (start - 1) * epoch_length`` and
    ``final_step = (start + length - 1) * epoch_length``:

    - ``s <= init_step``: eps = ``init``.
    - ``init_step < s < mid_step``: eps grows geometrically from ``init`` toward
      ``init * mid_ratio``.
    - ``s >= mid_step``: eps moves linearly from ``init * mid_ratio`` at
      ``mid_step`` to ``max_eps`` at ``final_step`` (clipped to that range).
    """

    def __init__(self, max_eps, opt_str):
        super(ExpoScheduler, self).__init__(max_eps, opt_str)
        # Epoch number to start schedule
        self.schedule_start = int(self.params['start'])
        # Epoch length for completing the schedule
        self.schedule_length = int(self.params['length'])
        # Fraction of the ramp after which the exponential schedule turns linear
        self.mid_point = float(self.params.get('mid', 0.25))
        # Exponential (stored and validated; not used by step_batch)
        self.beta = float(self.params.get('beta', 4.0))
        # Starting eps; the schedule divides by it, so it must be positive
        self.init_eps = float(self.params.get('init', 1e-10))
        assert self.beta >= 2.
        # mid == 1 would divide by zero when computing mid_ratio
        assert 0. <= self.mid_point < 1.
        assert self.init_eps > 0.
        self.batch = 0

    # Set how many batches in an epoch
    def set_epoch_length(self, epoch_length):
        """Set the number of batches per epoch; later calls must not change it.

        Raises:
            ValueError: if called again with a different ``epoch_length``.
        """
        # `x != x` is true only for NaN, i.e. the length has not been set yet
        # (presumably BaseScheduler initialises it to NaN — confirm).
        if self.epoch_length != self.epoch_length:
            self.epoch_length = epoch_length
        else:
            if self.epoch_length != epoch_length:
                raise ValueError("epoch_length must stay the same for ExpoScheduler")

    def step_epoch(self, verbose=True):
        """Advance one epoch via ``BaseScheduler.step_epoch``.

        When ``verbose`` is False, also calls ``step_batch`` ``epoch_length``
        times, so eps advances a whole epoch's worth of batches at once.
        """
        super(ExpoScheduler, self).step_epoch()
        # FIXME: presumably a shortcut for callers that never call step_batch
        # themselves — confirm the intended use of `verbose` here.
        if verbose == False:
            for _ in range(self.epoch_length):
                self.step_batch()

    def step_batch(self, verbose=False):
        """Advance one batch and, while training, update ``self.eps``.

        ``verbose`` is accepted for interface compatibility and ignored.

        Raises:
            ValueError: if ``set_epoch_length`` has not been called yet.
        """
        if not self.is_training:
            return
        if self.epoch_length != self.epoch_length:  # NaN: never set
            raise ValueError("set_epoch_length must be called before step_batch for ExpoScheduler")
        self.batch += 1
        step = self.batch - 1
        init_value = self.init_eps
        final_value = self.max_eps
        # Batch numbers at which the ramp starts and ends
        init_step = (self.schedule_start - 1) * self.epoch_length
        final_step = (self.schedule_start + self.schedule_length - 1) * self.epoch_length
        # Batch number for switching from exponential to linear schedule
        mid_step = int((final_step - init_step) * self.mid_point) + init_step

        if step <= init_step:
            self.eps = init_value
            return

        mid_ratio = self._mid_ratio(init_value, final_value)
        mid_value = mid_ratio * init_value
        if step >= mid_step:
            # Also covers an empty exponential phase (mid_step == init_step) and
            # an empty ramp (final_step == init_step) without dividing by zero.
            self.eps = self.linear_schedule(step, mid_step, final_step, mid_value, final_value)
        else:
            # Here init_step < step < mid_step, so the denominator is positive.
            self.eps = init_value * mid_ratio ** ((step - init_step) / (mid_step - init_step))

    def _mid_ratio(self, init_value, final_value):
        """Return ``r = exp(W(c))`` with ``c = mid / (1 - mid) * final / init``.

        ``W`` is the principal branch of the Lambert W function, so ``r`` solves
        ``r * ln(r) = c``. This makes the slope of the exponential phase at
        ``mid_step`` approximately equal to the slope of the linear phase
        (exact equality would need ``c = mid / (1 - mid) * (final / init - r)``).
        """
        c = self.mid_point / (1 - self.mid_point) * (final_value / init_value)
        # lambertw returns a complex number; for c >= 0 it is real-valued.
        return float(np.exp(lambertw(c).real))

    def linear_schedule(self, step, init_step, final_step, init_value, final_value):
        """Linearly interpolate from ``init_value`` at ``init_step`` to
        ``final_value`` at ``final_step``, clipped to the range of the two values.
        Returns ``final_value`` when the two steps coincide.
        """
        assert final_step >= init_step
        if init_step == final_step:
            return final_value
        rate = float(step - init_step) / float(final_step - init_step)
        linear_value = rate * (final_value - init_value) + init_value
        return np.clip(linear_value, min(init_value, final_value), max(init_value, final_value))

# def exp_schedule(self, step, init_step, final_step, init_value, final_value, mid_point=.25):
#         """Exponential schedule that slowly morphs into a linear schedule."""
#         assert init_value > 0
#         assert final_value >= init_value
#         assert final_step >= init_step
#         assert mid_point >= 0. and mid_point <= 1.
#         mid_step = int((final_step - init_step) * mid_point) + init_step
        
#         #find point where derivatives are approximately equal
#         c = mid_point/(1-mid_point)*(final_value/init_value)
#         mid_ratio = float(math.e**(lambertw(c)))
#         mid_value = mid_ratio*init_value
        
#         is_ramp = float(step > init_step)
#         is_linear = float(step >= mid_step)
        
#         if not is_ramp:
#             return init_value
#         elif is_linear:
#             return self.linear_schedule(step, mid_step, final_step, mid_value, final_value)
#         else:
#             return init_value*mid_ratio**((step-init_step)/(mid_step-init_step))

class Trainer():
    '''
    This is a class representing a Policy Gradient trainer, which 
    trains both a deep Policy network and a deep Value network.
    Exposes functions:
    - advantage_and_return
    - multi_actor_step
    - reset_envs
    - run_trajectories
    - train_step
    Trainer also handles all logging, which is done via the "cox"
    library
    '''
    def __init__(self, policy_net_class, value_net_class, params,
                 store, advanced_logging=True, log_every=5):
        '''
        Initializes a new Trainer.
        Inputs:
        - policy_net_class, the class of policy network to use (inheriting from nn.Module)
        - value_net_class, the class of value network to use (inheriting from nn.Module)
        - params, a dictionary with all of the required hyperparameters. It is
          wrapped in Parameters, and __getattr__ exposes its entries as trainer
          attributes (e.g. self.T reads self.params.T).
        - store, the cox store used for logging, or None to skip saving results
        - advanced_logging, whether setup_stores also creates the
          paper-constraint, value-data and weight-update tables
        - log_every, stored as self.log_every (not used in this method)
        Besides the policy/value networks and their optimizers, this builds the
        vectorized multi-agent environment and a learned dynamics model (mbrl)
        together with its replay buffer and model trainer.
        '''
        # Parameter Loading
        self.params = Parameters(params)

        # Whether or not the value network uses the current timestep
        time_in_state = self.VALUE_CALC == "time"

        # Whether to use GPU (as opposed to CPU)
        if not self.CPU:
            # Makes new floating-point tensors (including network weights built
            # below) CUDA tensors by default.
            torch.set_default_tensor_type("torch.cuda.FloatTensor")

        # Environment Loading
        def env_constructor():
            # Builds NUM_ACTORS subprocess environments wrapped in
            # VecNormalize_MultiAgent, then overwrites both agents' observation
            # and return normalization statistics with those of the zoo agents.

            # Whether or not we should add the time to the state
            # NOTE(review): horizon_to_feed is only referenced by the
            # commented-out make_env call below; it is currently unused.
            horizon_to_feed = self.T if time_in_state else None

            # temp = make_env(self.MUJOCO_GAME, 
            #            params=self.params,
            #            total_step=self.TRAIN_STEPS,
            #            add_t_with_horizon=horizon_to_feed,
            #            clip_obs=self.CLIP_OBSERVATIONS,
            #            clip_rew=self.CLIP_REWARDS,
            #            tag=1+(1-self.VICTIM_ID), version=self.VERSION, agent_idx=(1-self.VICTIM_ID))
            # init_ob = temp.reset()
            # dummy_action = (np.random.random_sample((4,17)), np.random.random_sample((4,17)))
            # dummy_action = np.clip(dummy_action, -0.4, 0.4)
            # _ = temp.step(dummy_action)

            # The lambda does not use the loop variable, so late binding is
            # harmless: every subprocess builds an identically configured env.
            multi_env = SubprocVecEnv([lambda: make_env(self.MUJOCO_GAME, 
                       params=self.params,
                       total_step=self.TRAIN_STEPS * self.T,
                       clip_obs=self.CLIP_OBSERVATIONS,
                       clip_rew=self.CLIP_REWARDS) for _ in range(self.NUM_ACTORS)])
            venv = VecNormalize_MultiAgent(multi_env, clip_obs=self.CLIP_OBSERVATIONS, clip_reward=self.CLIP_REWARDS)
            # Zoo agent with tag i supplies the statistics stored in
            # obs_rms_{i-1} / ret_rms_{i-1} of the normalizing wrapper.
            for i in range(1,3):
                agent_path = get_zoo_path(self.MUJOCO_GAME, tag=i, version=self.VERSION)
                param = load_from_file(param_pkl_path=agent_path)
                obs_mean, obs_var, obs_count, ret_mean, ret_var, ret_count = get_rms_vars(param, ob_shapes=multi_env.get_attr('num_features')[0])
                self.set_rms(getattr(venv, 'obs_rms_%s'%(str(i-1))), obs_mean, obs_var, obs_count)
                self.set_rms(getattr(venv, 'ret_rms_%s'%(str(i-1))), ret_mean, ret_var, ret_count)
            print()
            print('use running average from zoo agent...')
            print()
            return venv


        self.envs = env_constructor()
        # Problem dimensions are read from the first sub-environment.
        self.params.AGENT_TYPE = "discrete" if self.envs.get_attr('is_discrete')[0] else "continuous"
        self.params.NUM_ACTIONS = self.envs.get_attr('num_actions')[0]
        self.params.NUM_FEATURES = self.envs.get_attr('num_features')[0]
        self.policy_step = step_with_mode(self.MODE, adversary=False)
        self.adversary_policy_step = step_with_mode(self.MODE, adversary=True)
        # TRAIN_STEPS of these increments take MAX_KL to MAX_KL_FINAL
        # (presumably applied once per training step elsewhere).
        self.params.MAX_KL_INCREMENT = (self.params.MAX_KL_FINAL - self.params.MAX_KL) / self.params.TRAIN_STEPS
        self.advanced_logging = advanced_logging
        self.n_steps = 0
        self.log_every = log_every
        self.policy_net_class = policy_net_class
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.horizon = self.HORIZON
        self.dyn_horizon = self.DYN_HORIZON

        # Instantiation
        self.policy_model = policy_net_class(self.NUM_FEATURES, self.NUM_ACTIONS,
                                             self.INITIALIZATION,
                                             time_in_state=time_in_state,
                                             activation=self.policy_activation,
                                             num_actors=self.NUM_ACTORS)

        # Instantiate the convex relaxation model (and the robust eps/beta
        # schedulers) for the robust / radial modes and their adversarial variants.
        if self.MODE in ['robust_ppo', 'radial_ppo', 'robust_ppo_stdv'] or self.MODE == 'adv_sa_ppo' or self.MODE == 'adv_radial':
            self.create_relaxed_model(time_in_state)

        # Minimax training
        if self.MODE == 'adv_ppo' or self.MODE == 'adv_trpo' or self.MODE == 'adv_sa_ppo' or self.MODE == 'adv_radial':
            # Copy parameters if they are set to "same".
            if self.params.ADV_PPO_LR_ADAM == "same":
                self.params.ADV_PPO_LR_ADAM = self.params.PPO_LR_ADAM
            if self.params.ADV_VAL_LR == "same":
                self.params.ADV_VAL_LR = self.params.VAL_LR
            if self.params.ADV_CLIP_EPS == "same":
                self.params.ADV_CLIP_EPS = self.params.CLIP_EPS
            if self.params.ADV_EPS == "same":
                self.params.ADV_EPS = self.params.ROBUST_PPO_EPS
            if self.params.ADV_ENTROPY_COEFF == "same":
                self.params.ADV_ENTROPY_COEFF = self.params.ENTROPY_COEFF
            # The adversary policy takes NUM_FEATURES inputs and produces
            # NUM_ACTIONS outputs, like the victim policy. With LOAD_ADV_ATTACKER
            # it is always a CtsPolicy; otherwise it uses policy_net_class.
            if self.LOAD_ADV_ATTACKER: # attacker is mlp policy 
                self.attacker_policy_model = CtsPolicy(self.NUM_FEATURES, self.NUM_ACTIONS,
                                                    self.INITIALIZATION,
                                                    time_in_state=time_in_state,
                                                    activation=self.policy_activation,
                                                    num_actors=self.NUM_ACTORS)
            else:
                self.attacker_policy_model = policy_net_class(self.NUM_FEATURES, self.NUM_ACTIONS,
                                                 self.INITIALIZATION,
                                                 time_in_state=time_in_state,
                                                 activation=self.policy_activation,
                                                 num_actors=self.NUM_ACTORS)
            # Optimizer for adversary
            self.params.ADV_POLICY_ADAM = optim.Adam(self.attacker_policy_model.parameters(), lr=self.ADV_PPO_LR_ADAM, eps=1e-5)

            # Adversary value function.
            if self.params.LOAD_ADV_ATTACKER:
                self.attacker_val_model = ValueDenseNet(self.NUM_FEATURES, self.INITIALIZATION, num_actors=self.NUM_ACTORS)
            else:
                self.attacker_val_model = value_net_class(self.NUM_FEATURES, self.INITIALIZATION, num_actors=self.NUM_ACTORS)
            self.attacker_val_opt = optim.Adam(self.attacker_val_model.parameters(), lr=self.ADV_VAL_LR, eps=1e-5)
            assert self.attacker_policy_model.discrete == (self.AGENT_TYPE == "discrete")

            # Learning rate annealing for adversary: LambdaLR scales the base
            # learning rate by 1 - f/TRAIN_STEPS after f scheduler steps.
            if self.ANNEAL_LR:
                adv_lam = lambda f: 1-f/self.TRAIN_STEPS
                adv_ps = optim.lr_scheduler.LambdaLR(self.ADV_POLICY_ADAM, 
                                                        lr_lambda=adv_lam)
                adv_vs = optim.lr_scheduler.LambdaLR(self.attacker_val_opt, lr_lambda=adv_lam)
                self.params.ADV_POLICY_SCHEDULER = adv_ps
                self.params.ADV_VALUE_SCHEDULER = adv_vs

        opts_ok = (self.PPO_LR == -1 or self.PPO_LR_ADAM == -1)
        assert opts_ok, "One of ppo_lr and ppo_lr_adam must be -1 (off)."
        # Whether we should use Adam or simple GD to optimize the policy parameters
        if self.PPO_LR_ADAM != -1:
            kwargs = {
                'lr':self.PPO_LR_ADAM,
            }

            # Only override Adam's default eps when a positive ADAM_EPS is configured.
            if self.params.ADAM_EPS > 0:
                kwargs['eps'] = self.ADAM_EPS

            self.params.POLICY_ADAM = optim.Adam(self.policy_model.parameters(),
                                                 **kwargs)
        else:
            self.params.POLICY_ADAM = optim.SGD(self.policy_model.parameters(), lr=self.PPO_LR)

        # If using a time dependent value function, add one extra feature
        # for the time ratio t/T. NOTE: the networks built above used the
        # original NUM_FEATURES; the value network and the dynamics model
        # below see the incremented value.
        if time_in_state:
            self.params.NUM_FEATURES = self.NUM_FEATURES + 1

        # Value function optimization
        self.val_model = value_net_class(self.NUM_FEATURES, self.INITIALIZATION, num_actors=self.NUM_ACTORS)
        self.val_opt = optim.Adam(self.val_model.parameters(), lr=self.VAL_LR, eps=1e-5) 
        assert self.policy_model.discrete == (self.AGENT_TYPE == "discrete")

        # Learning rate annealing
        # From OpenAI hyperparameters:
        # Set adam learning rate to 3e-4 * alpha, where alpha decays from 1 to 0 over training
        if self.ANNEAL_LR:
            lam = lambda f: 1-f/self.TRAIN_STEPS
            ps = optim.lr_scheduler.LambdaLR(self.POLICY_ADAM, 
                                                    lr_lambda=lam)
            vs = optim.lr_scheduler.LambdaLR(self.val_opt, lr_lambda=lam)
            self.params.POLICY_SCHEDULER = ps
            self.params.VALUE_SCHEDULER = vs

        if store is not None:
            self.setup_stores(store)
        else:
            print("Not saving results to cox store.")

        # Sanity check. Parsed as `not (cuda_available == self.CPU)`: the CPU
        # flag must be set exactly when CUDA is unavailable.
        assert not(torch.cuda.is_available()) == self.CPU
        # Learned dynamics model sizes. The factor 2 presumably covers both
        # agents' observations/actions (cf. the replay buffer shapes below) — confirm.
        dyn_in_size = (self.NUM_FEATURES + self.NUM_ACTIONS)*2
        dyn_out_size = self.NUM_FEATURES * 2
        cfg_dict = {
            # dynamics model configuration
            "dynamics_model": {
                "_target_": "mbrl.models.GaussianMLP",
                "device": self.device,
                "num_layers": 5,
                "ensemble_size": self.ENSEMBLE_SIZE,
                "hid_size": 1000,
                "in_size": dyn_in_size,
                "out_size": dyn_out_size,
                "deterministic": True,
                "propagation_method": "fixed_model",
                # can also configure activation function for GaussianMLP
                "activation_fn_cfg": {
                    "_target_": "torch.nn.LeakyReLU",
                    "negative_slope": 0.01
                }
            },
            # options for training the dynamics model
            "algorithm": {
                "learned_rewards": False,
                "target_is_delta": True,
                "normalize": True,
                "dataset_size": self.T * self.NUM_ACTORS
            },
            # these are experiment specific options
            "overrides": {
                "trial_length": None,
                "num_steps": None,
                "model_batch_size": self.DYN_BATCH_SIZE,
                "validation_ratio": 0.1
            }

        }
        self.cfg = omegaconf.OmegaConf.create(cfg_dict)
        # Fixed seed (0) for the generator handed to the replay buffer.
        rng = np.random.default_rng(seed=0)
        self.dynamics_model = common_util.create_one_dim_tr_model(self.cfg, None, None)
        self.replay_buffer = common_util.create_replay_buffer(self.cfg, (self.NUM_FEATURES*2,), (self.NUM_ACTIONS*2,), rng=rng)
        self.dyn_model_trainer = models.ModelTrainer(self.dynamics_model, optim_lr=self.DYN_LR, weight_decay=5e-5)
        # train_callback appends to these two lists.
        self.train_losses = []
        self.val_scores = []
        self.victim_id = self.VICTIM_ID

    
    def set_rms(self, ob_rms, ob_mean, ob_var, obs_count):
        ob_rms.mean = ob_mean
        ob_rms.var = ob_var
        if isinstance(obs_count, np.ndarray):
            ob_rms.count = np.float64(obs_count[0])-2
        else:
            ob_rms.count = obs_count

    def train_callback(self, _model, _total_calls, _epoch, tr_loss, val_score, _best_val):
        self.train_losses.append(tr_loss)
        self.val_scores.append(val_score.mean().item())

    def create_relaxed_model(self, time_in_state=False):
        """Create the state-perturbation machinery used by robust PPO variants.

        Sets ``self.relaxed_policy_model`` to an auto_LiRPA ``BoundedModule``
        that wraps a relaxed copy of the policy. It is ``None`` when
        ``ROBUST_PPO_METHOD`` needs no relaxation (e.g. SGLD). Also creates
        ``self.robust_eps_scheduler`` and ``self.robust_beta_scheduler``.

        Args:
            time_in_state: forwarded to the relaxed policy constructor.

        Raises:
            NotImplementedError: if the policy model is not a ``CtsPolicy``,
                ``CtsPolicy_stdv`` or ``CtsLSTMPolicy``.
        """
        policy = self.policy_model
        # NOTE(review): if CtsLSTMPolicy subclasses CtsPolicy (not visible here),
        # LSTM policies take the first branch and the LSTM branch is unreachable.
        if isinstance(policy, (CtsPolicy, CtsPolicy_stdv)):
            method = self.ROBUST_PPO_METHOD
            if method in ["convex-relax", "ibp", "ibp2"]:
                from .convex_relaxation import RelaxedCtsPolicyForState as relaxed_cls
            elif method in ["convex-relax_stdv"]:
                from .convex_relaxation import RelaxedCtsPolicyForState_stdv as relaxed_cls
            else:
                relaxed_cls = None

            if relaxed_cls is None:
                # For SGLD no need to create the relaxed model
                self.relaxed_policy_model = None
            else:
                relaxed_policy_model = relaxed_cls(
                        self.NUM_FEATURES, self.NUM_ACTIONS, time_in_state=time_in_state,
                        activation=self.policy_activation, policy_model=policy)
                # Single-sample dummy input used to trace the model.
                inputs = (torch.randn(1, self.NUM_FEATURES),)
                self.relaxed_policy_model = BoundedModule(relaxed_policy_model, inputs, device=self.device)
        elif isinstance(policy, CtsLSTMPolicy):
            if self.ROBUST_PPO_METHOD in ["convex-relax", "ibp", "ibp2"]:
                from .convex_relaxation import RelaxedCtsLSTMPolicyForState
                relaxed_policy_model = RelaxedCtsLSTMPolicyForState(
                        self.NUM_FEATURES, self.NUM_ACTIONS, time_in_state=time_in_state,
                        activation=self.policy_activation, policy_model=policy,
                        num_actors=self.params.NUM_ACTORS)
                # The LSTM variant is traced with one row per actor.
                inputs = (torch.randn(self.params.NUM_ACTORS, self.NUM_FEATURES),)
                self.relaxed_policy_model = BoundedModule(relaxed_policy_model, inputs, device=self.device)
            else:
                # For SGLD no need to create the relaxed model
                self.relaxed_policy_model = None
        else:
            raise NotImplementedError
        self._create_robust_schedulers()

    def _create_robust_schedulers(self):
        """Create ``robust_eps_scheduler`` and ``robust_beta_scheduler`` from params.

        The eps scheduler follows ``params["robust_scheduler"]``: "smooth" gives
        SmoothedScheduler, "expo" gives ExpoScheduler, anything else (or no
        entry) gives LinearScheduler. The beta scheduler is always linear and
        reuses the eps schedule options when ROBUST_PPO_BETA_SCHEDULER_OPTS is
        "same".
        """
        params = self.params
        kind = params["robust_scheduler"] if "robust_scheduler" in params else None
        if kind == "smooth":
            eps_cls = SmoothedScheduler
        elif kind == "expo":
            eps_cls = ExpoScheduler
        else:
            eps_cls = LinearScheduler
        self.robust_eps_scheduler = eps_cls(params.ROBUST_PPO_EPS, params.ROBUST_PPO_EPS_SCHEDULER_OPTS)

        beta_opts = params.ROBUST_PPO_BETA_SCHEDULER_OPTS
        if beta_opts == "same":
            beta_opts = params.ROBUST_PPO_EPS_SCHEDULER_OPTS
        self.robust_beta_scheduler = LinearScheduler(params.ROBUST_PPO_BETA, beta_opts)

    def setup_sarsa(self, lr_schedule, eps_scheduler, beta_scheduler):
        """Initialize SARSA training.

        Creates the Q-network ``self.sarsa_model``, whose input is the
        concatenated (state, action) vector. Also creates its Adam optimizer
        (lr=VAL_LR, eps=1e-5) and a LambdaLR scheduler, stores the given
        eps/beta schedulers, and wraps the network in an auto_LiRPA
        ``BoundedModule`` (``self.relaxed_sarsa_model``).

        Args:
            lr_schedule: ``lr_lambda`` for ``optim.lr_scheduler.LambdaLR``.
            eps_scheduler: stored as ``self.sarsa_eps_scheduler``.
            beta_scheduler: stored as ``self.sarsa_beta_scheduler``.
        """
        in_size = self.NUM_FEATURES + self.NUM_ACTIONS
        # Create the Sarsa model, with S and A as the input.
        self.sarsa_model = ValueDenseNet(in_size, self.INITIALIZATION)
        self.sarsa_opt = optim.Adam(self.sarsa_model.parameters(), lr=self.VAL_LR, eps=1e-5)
        self.sarsa_scheduler = optim.lr_scheduler.LambdaLR(self.sarsa_opt, lr_schedule)
        self.sarsa_eps_scheduler = eps_scheduler
        self.sarsa_beta_scheduler = beta_scheduler
        # Convert model with relaxation wrapper. The device is passed explicitly,
        # as in the other BoundedModule constructions of this class.
        dummy_input = torch.randn(1, in_size)
        self.relaxed_sarsa_model = BoundedModule(self.sarsa_model, dummy_input, device=self.device)
    
    """Initialize imitation (snooping) training."""
    def setup_imit(self, train=True, lr=1e-3):
        # Create a same policy network. 
        self.imit_network = self.policy_net_class(self.NUM_FEATURES, self.NUM_ACTIONS,
                                           self.INITIALIZATION, time_in_state=self.VALUE_CALC == "time",
                                           activation=self.policy_activation)
        if train:
            if self.PPO_LR_ADAM != -1:
                kwargs = {
                    'lr':lr,
                }

                if self.params.ADAM_EPS > 0:
                    kwargs['eps'] = self.ADAM_EPS
                    self.imit_opt = optim.Adam(self.imit_network.parameters(), **kwargs)
            else:
                self.imit_opt = optim.SGD(self.imit_network.parameters(), lr=lr)

    """Training imitation agent"""
    def imit_steps(self, all_actions, all_states, all_not_dones, num_epochs):
        assert len(all_actions) == len(all_states)
        for e in range(num_epochs):
            total_loss_val = 0.0
            if self.HISTORY_LENGTH > 0:
                loss = 0.0
                batches, alive_masks, time_masks, lengths = pack_history([all_states, all_actions], all_not_dones, max_length=self.HISTORY_LENGTH)
                self.imit_opt.zero_grad()
                hidden = None
                for i, batch in enumerate(batches):
                    batch_states, batch_actions = batch
                    mask = time_masks[i].unsqueeze(2)
                    
                    if hidden is not None:
                        hidden = [h[:, alive_masks[i], :].detach() for h in hidden]
                    mean, std, hidden = self.imit_network.multi_forward(batch_states, hidden=hidden)
                    batch_loss = torch.nn.MSELoss()(mean*mask, batch_actions*mask)
                    loss += batch_loss
                loss.backward()
                self.imit_opt.step()
                total_loss_val = loss.item()

            else:
                state_indices = np.arange(len(all_actions))
                # np.random.shuffle(state_indices)
                splits = np.array_split(state_indices, self.params.NUM_MINIBATCHES)
                np.random.shuffle(splits)
                for selected in splits:
                    def sel(*args):
                        return [v[selected] for v in args]
                    
                    self.imit_opt.zero_grad()
                    sel_states, sel_actions, sel_not_dones = sel(all_states, all_actions, all_not_dones)  
                    act, _ = self.imit_network(sel_states)
                    loss = torch.nn.MSELoss()(sel_actions, act)
            
                    loss.backward()
                    self.imit_opt.step()
                    total_loss_val += loss.item()
            
            print('Epoch [%d/%d] avg loss: %.8f' % (e+1, num_epochs, total_loss_val / len(all_actions)))


    def setup_stores(self, store):
        # Logging setup
        self.store = store
        if self.MODE == 'adv_ppo' or self.MODE == 'adv_trpo' or self.MODE == 'adv_sa_ppo' or self.MODE == 'adv_radial':
            adv_optimization_table = {
                'mean_reward_0':float,
                'mean_reward_1':float,
                'final_value_loss':float,
                'final_policy_loss':float,
                'final_surrogate_loss':float,
                'entropy_bonus':float,
                'mean_std':float
            }
            self.store.add_table('optimization_adv', adv_optimization_table)
            rollout_table = {
                'game_win0': float,
                'game_win1': float,
                'game_tie': float
            }
            self.store.add_table('rollout', rollout_table)
        optimization_table = {
            'mean_reward_0':float,
            'mean_reward_1':float,
            'final_value_loss':float,
            'final_policy_loss':float,
            'final_surrogate_loss':float,
            'entropy_bonus':float,
            'mean_std':float,
            'horizon':float,
        }
        self.store.add_table('optimization', optimization_table)

        if self.advanced_logging:
            paper_constraint_cols = {
                'avg_kl':float,
                'max_kl':float,
                'max_ratio':float,
                'opt_step':int
            }

            value_cols = {
                'heldout_gae_loss':float,
                'heldout_returns_loss':float,
                'train_gae_loss':float,
                'train_returns_loss':float
            }

            weight_cols = {}
            for name, _ in self.policy_model.named_parameters():
                name += "."
                for k in ["l1", "l2", "linf", "delta_l1", "delta_l2", "delta_linf"]:
                    weight_cols[name + k] = float

            self.store.add_table('paper_constraints_train',
                                        paper_constraint_cols)
            self.store.add_table('paper_constraints_heldout',
                                        paper_constraint_cols)
            self.store.add_table('value_data', value_cols)
            self.store.add_table('weight_updates', weight_cols)

        if self.params.MODE == 'robust_ppo' or self.params.MODE == 'adv_sa_ppo':
            robust_cols ={
                'eps': float,
                'beta': float,
                'kl': float,
                'surrogate': float,
                'entropy': float,
                'loss': float,
            }
            self.store.add_table('robust_ppo_data', robust_cols)


    def __getattr__(self, x):
        '''
        Allows accessing self.A instead of self.params.A
        '''
        if x == 'params':
            return {}
        try:
            return getattr(self.params, x)
        except KeyError:
            raise AttributeError(x)

    def advantage_and_return(self, rewards, values, not_dones):
        """
        Compute GAE advantages and discounted returns.

        Tensors are indexed [actor, timestep]. ``values`` is expected to have
        one more timestep column than ``rewards`` and ``not_dones``: it is
        sliced as ``values[:, 1:]`` (next-state values, zeroed where
        ``not_dones`` is 0) and ``values[:, :-1]`` (current-state values):

            delta_t = r_t + GAMMA * not_done_t * V(s_{t+1}) - V(s_t)

        For each (agent, start, end) segment returned by
        ``get_path_indices(not_dones)`` (presumably one per episode piece —
        confirm), ``discount_path`` turns deltas into advantages with factor
        LAMBDA * GAMMA and rewards into returns with factor GAMMA.

        Returns:
            (advantages, returns): detached copies, each shaped like ``rewards``.
        """
        next_values = values[:, 1:] * not_dones
        current_values = values[:, :-1]
        deltas = rewards + self.GAMMA * next_values - current_values

        advantages = ch.zeros_like(rewards)
        returns = ch.zeros_like(rewards)
        for agent, start, end in get_path_indices(not_dones):
            span = slice(start, end)
            advantages[agent, span] = discount_path(deltas[agent, span], self.LAMBDA * self.GAMMA)
            returns[agent, span] = discount_path(rewards[agent, span], self.GAMMA)

        return advantages.clone().detach(), returns.clone().detach()

    def reset_envs(self, envs):
        '''
        Reset every environment in ``envs`` and return their initial states
        as one tensor of shape (# actors, 1, ... state_shape).

        The tensor is built with cpu_tensorize when self.CPU is set and with
        cu_tensorize otherwise.
        '''
        on_cpu = self.CPU
        first_states = [env.reset() for env in envs]
        to_tensor = cpu_tensorize if on_cpu else cu_tensorize
        return to_tensor(first_states).unsqueeze(1)

    def multi_actor_step(self, actions, envs):
        '''
        Step each environment with the action at the same position.

        Inputs:
        - actions, per-actor actions; ``action[0]`` of each is moved to the CPU,
          converted to numpy and passed to ``env.step``
        - envs, the environments, paired with ``actions`` by position
        An environment that reports done is reset immediately, and its
        ``info['done']`` entry is collected.
        Returns a pair ``(data, win)``:
        - data = [completed_episode_info, rewards, states, not_dones], where
          completed_episode_info is the list of collected ``info['done']`` values
          and the other three are tensors (cpu_tensorize / cu_tensorize) built
          from one single-element list per actor. not_dones is 0 for an actor
          whose env finished this step, 1 otherwise; states hold the reset
          observation for finished actors.
        - win, None unless a finishing env's info contains 'winner' (True) or
          'loser' (False); 'loser' overrides 'winner' in the same info, and
          later environments override earlier ones.
        '''
        reward_rows, state_rows, not_done_rows = [], [], []
        finished_info = []
        win = None
        for act, env in zip(actions, envs):
            obs, reward, done, info = env.step(act[0].cpu().numpy())
            if done:
                finished_info.append(info['done'])
                obs = env.reset()
                if 'winner' in info:
                    win = True
                if 'loser' in info:
                    win = False

            reward_rows.append([reward])
            not_done_rows.append([int(not done)])
            state_rows.append([obs])

        to_tensor = cpu_tensorize if self.CPU else cu_tensorize
        rewards_t, states_t, not_dones_t = [to_tensor(rows) for rows in (reward_rows, state_rows, not_done_rows)]
        return [finished_info, rewards_t, states_t, not_dones_t], win

    def run_trajectories(self, num_saps, agent_id, return_rewards=False, should_tqdm=False,
            collect_adversary_trajectory=False):
        """
        Roll out both players in ``self.envs`` and record the trajectory of player ``agent_id``.

        The vectorized environment is reset and stepped ``num_saps // self.NUM_ACTORS``
        times; sub-environments that finish restart on their own and the terminal
        step is marked in ``not_dones``. At every step both ``self.policy_model`` and
        ``self.attacker_policy_model`` choose an action. The model being trained
        (the attacker when ``collect_adversary_trajectory`` is True, otherwise the
        victim policy) observes and acts for player ``agent_id``; the other model
        plays ``1 - agent_id``. When ``self.params.ADV_ADVERSARY_RATIO >= random.random()``
        and SARSA collection is disabled, the freshly observed states are perturbed
        using ``apply_mvd_attack`` before they are recorded and fed back to the policies.

        Side effects: when ``self.MODE == 'adv_ppo'`` and SARSA collection is off,
        every transition is pushed into ``self.replay_buffer``; the attack horizon and
        win rates are logged through ``self.store``; win rates are printed.

        Inputs:
        - num_saps: total number of state-action pairs to collect over all actors
        - agent_id: index (0 or 1) of the player whose trajectory is recorded
        - return_rewards: accepted for API compatibility; currently ignored
        - should_tqdm: show a tqdm progress bar over the time steps
        - collect_adversary_trajectory: record the attacker's actions / log-probs
          (and bootstrap with the attacker's value model) instead of the victim's
        Returns:
        (avg_episode_length, avg_episode_reward, trajs) where
        - avg_episode_length: mean of the first field of ``info['done']`` over episodes
          finished during the rollout, or -1 if none finished
        - avg_episode_reward: [mean of field 1, mean of field 2] of ``info['done']``
          (presumably one reward per player -- confirm in custom_env), or [-1, -1]
        - trajs: Trajectories with, for T = num_saps // NUM_ACTORS,
            rewards, not_dones, action_log_probs: (NUM_ACTORS, T); not_dones[:, t]
                is 0 when step t ended an episode and 1 otherwise
            actions, action_means: (NUM_ACTORS, T, NUM_ACTIONS)
            states: (NUM_ACTORS, T + 1, state_dim); states[:, t + 1] is observed after step t
            values: (NUM_ACTORS, T + 1); only filled when HISTORY_LENGTH > 0, zeros otherwise
            action_std: std of the distribution that produced ``actions`` (last step)
        Raises:
        - ValueError if num_saps gives fewer than one step per actor
        """
        table_name_suffix = "_adv" if collect_adversary_trajectory else ""
        traj_length = int(num_saps // self.NUM_ACTORS)
        if traj_length <= 0:
            raise ValueError(
                f"num_saps={num_saps} yields no steps for {self.NUM_ACTORS} actors")
        tensor_maker = cpu_tensorize if self.CPU else cu_tensorize

        initial_states_all = self.envs.reset()
        # Joint observations used as dynamics-model inputs for the replay buffer.
        last_obs_all_dyninput = initial_states_all
        initial_states = tensor_maker(initial_states_all[:, agent_id, :])
        initial_states_oppo = tensor_maker(initial_states_all[:, 1 - agent_id, :])

        self.policy_model.reset()
        self.val_model.reset()
        self.attacker_policy_model.reset()
        self.attacker_val_model.reset()

        # Holds info['done'] of every episode finished during this rollout.
        completed_episode_info = []

        shape = (self.NUM_ACTORS, traj_length)
        rewards, not_dones, action_log_probs = [ch.zeros(shape) for _ in range(3)]
        values = ch.zeros((self.NUM_ACTORS, traj_length + 1))
        actions_shape = shape + (self.NUM_ACTIONS,)
        actions = ch.zeros(actions_shape)
        # Mean of the action distribution.
        action_means = ch.zeros(actions_shape)
        states = ch.zeros((self.NUM_ACTORS, traj_length + 1, initial_states.shape[1]))
        iterator = tqdm.trange(traj_length) if should_tqdm else range(traj_length)

        # Observation slice perturbed by the attack, per game.
        # NOTE(review): game-specific index ranges; presumably the victim's view of the
        # opponent position -- confirm against the env observation layout.
        perturb_slices = {
            "multicomp/YouShallNotPassHumans-v0": slice(356, 380),
            "multicomp/SumoHumans-v0": slice(356, 380),
            "multicomp/RunToGoalHumans-v0": slice(356, 380),
            "multicomp/KickAndDefend-v0": slice(354, 378),
            "multicomp/SumoAnts-v0": slice(107, 122),
            "multicomp/RunToGoalAnts-v0": slice(107, 122),
        }
        perturb_slice = perturb_slices.get(self.MUJOCO_GAME)

        states[:, 0, :] = initial_states
        last_states = states[:, 0, :]
        last_states_oppo = initial_states_oppo

        outcomes = []
        next_not_dones = torch.ones(self.NUM_ACTORS)
        for t in iterator:
            # The model being trained always sees player agent_id's observation.
            if collect_adversary_trajectory:
                last_states_attacker, last_states_victim = last_states, last_states_oppo
            else:
                last_states_attacker, last_states_victim = last_states_oppo, last_states

            # Attacker policy.
            oppo_action_pds = self.attacker_policy_model(last_states_attacker, next_not_dones)
            oppo_action_means, oppo_action_stds = oppo_action_pds
            next_adv_action = self.attacker_policy_model.sample(oppo_action_pds)
            next_oppo_log_probs = self.attacker_policy_model.get_loglikelihood(oppo_action_pds, next_adv_action)
            next_adv_action = next_adv_action.unsqueeze(1)  # (num_envs, 1, action_dims)
            next_adv_action_clip = next_adv_action.detach().cpu().numpy()

            # Victim policy.
            action_pds = self.policy_model(last_states_victim, next_not_dones)
            next_action_means, next_action_stds = action_pds
            next_actions = self.policy_model.sample(action_pds)
            next_action_log_probs = self.policy_model.get_loglikelihood(action_pds, next_actions)
            next_actions = next_actions.unsqueeze(1)
            next_actions_clip = next_actions.detach().cpu().numpy()

            # Recurrent value models are evaluated online; otherwise values stay zero.
            if self.HISTORY_LENGTH > 0:
                if collect_adversary_trajectory:
                    train_agt_value = self.attacker_val_model(last_states_attacker, next_not_dones).squeeze(1)
                else:
                    train_agt_value = self.val_model(last_states_victim, next_not_dones).squeeze(1)
            else:
                train_agt_value = 0

            # The trained model's action goes into slot agent_id.
            if collect_adversary_trajectory != bool(agent_id):
                agents_actions = np.concatenate((next_adv_action_clip, next_actions_clip), axis=1)
            else:
                agents_actions = np.concatenate((next_actions_clip, next_adv_action_clip), axis=1)

            # NOTE(review): next_states_two may have been perturbed in place at the end of
            # the previous step, so the dynamics input is the perturbed observation.
            if t > 0:
                last_obs_all_dyninput = next_states_two
            next_states_two, next_rewards, next_dones, infos = self.envs.step(agents_actions)
            next_obs_all = self.envs.get_original_obs()
            next_obs_all_dyninput = next_states_two
            next_not_dones = np.logical_not(next_dones)

            # Add transitions to the dynamics-model replay buffer.
            if self.MODE == 'adv_ppo' and not self.params['sarsa_enable']:
                for j in range(self.NUM_ACTORS):
                    last_obs = last_obs_all_dyninput[j].reshape(2 * self.NUM_FEATURES)
                    next_obs = next_obs_all_dyninput[j].reshape(2 * self.NUM_FEATURES)
                    last_acts = agents_actions[j].reshape(2 * self.NUM_ACTIONS)
                    self.replay_buffer.add(last_obs, last_acts, next_obs, 0, next_dones[j])

            add_perturb = self.params.ADV_ADVERSARY_RATIO >= random.random() and not self.params['sarsa_enable']
            if add_perturb:
                atk_state = tensor_maker(next_states_two[:, 1 - agent_id])
                vic_state = tensor_maker(next_states_two[:, agent_id])
                update, used_horizon = self.apply_mvd_attack(
                    atk_state, vic_state, not_dones=tensor_maker(next_not_dones),
                    horizon=self.horizon, dyn_horizon=self.dyn_horizon)
                self.store.log_table_and_tb('optimization' + table_name_suffix, {
                        'horizon': used_horizon[0],
                    })
                perturbations = (update.sign() * self.MVD_EPS).detach().cpu().numpy()
                for j in range(self.NUM_ACTORS):
                    if perturb_slice is not None:
                        next_obs_all[j, agent_id, perturb_slice] += perturbations[j, perturb_slice]
                    # Re-normalize the raw observations and write them back.
                    normalized_obs = self.envs.normalize_obs(next_obs_all)
                    next_states_two[j, self.victim_id, :] = normalized_obs[j, self.victim_id, :]
                    next_states_two[j, 1 - self.victim_id, :] = normalized_obs[j, 1 - self.victim_id, :]
            elif not self.params['sarsa_enable']:  # no horizon logging when collecting sarsa saps
                self.store.log_table_and_tb('optimization' + table_name_suffix, {
                        'horizon': -1,
                    })

            last_states_oppo = tensor_maker(next_states_two[:, 1 - agent_id, :])
            next_states, next_rewards, next_not_dones = map(
                tensor_maker, [next_states_two[:, agent_id, :], next_rewards[:, agent_id], next_not_dones])

            done_info = []
            for not_done, info in zip(next_not_dones, infos):
                if not not_done:
                    done_info.append(info[agent_id]['done'])
                    if 'winner' in info[agent_id]:
                        outcomes.append(agent_id)
                    elif 'loser' in info[agent_id]:
                        outcomes.append(1 - agent_id)
                    else:
                        outcomes.append(None)

            # Keep episodes finished before the last step, or on the last step when
            # no other episode information exists.
            if len(done_info) > 0 and (t != traj_length - 1 or len(completed_episode_info) == 0):
                completed_episode_info.extend(done_info)

            if collect_adversary_trajectory:
                step_action = next_adv_action.squeeze(1)  # the attacker's sampled action
                step_mean = oppo_action_means
                step_log_prob = next_oppo_log_probs
            else:
                step_action = next_actions.squeeze(1)
                step_mean = next_action_means
                step_log_prob = next_action_log_probs

            rewards[:, t] = next_rewards
            not_dones[:, t] = next_not_dones
            actions[:, t] = step_action
            action_means[:, t] = step_mean
            action_log_probs[:, t] = step_log_prob
            # The observation following step t is stored one slot later.
            states[:, t + 1] = next_states
            values[:, t] = train_agt_value
            last_states = next_states

        # Bootstrap from the state reached after the final step (not the state the
        # final step started from) and pick the std of the recorded policy.
        if collect_adversary_trajectory:
            final_states_attacker, final_states_victim = last_states, last_states_oppo
            traj_action_std = oppo_action_stds
        else:
            final_states_attacker, final_states_victim = last_states_oppo, last_states
            traj_action_std = next_action_stds
        if self.HISTORY_LENGTH > 0:
            if collect_adversary_trajectory:
                train_agt_value = self.attacker_val_model(final_states_attacker, next_not_dones).squeeze(1)
            else:
                train_agt_value = self.val_model(final_states_victim, next_not_dones).squeeze(1)
        else:
            train_agt_value = 0
        values[:, -1] = train_agt_value

        c = Counter(outcomes)
        num_games = len(completed_episode_info) + 1e-8
        print('*********************************************')
        print('game_win0', c.get(0, 0) / num_games)
        print('game_win1', c.get(1, 0) / num_games)
        print('game_tie', c.get(None, 0) / num_games)
        print('total_games', num_games)
        print('*********************************************')
        print()
        if not self.params['sarsa_enable']:
            self.store.log_table_and_tb('rollout', {
                    'game_win0': c.get(0, 0) / num_games,
                    'game_win1': c.get(1, 0) / num_games,
                    'game_tie': c.get(None, 0) / num_games,
                })

        # Average episode length and rewards over all completed episodes.
        episode_stats = np.array(list(zip(*completed_episode_info)))
        if episode_stats.size > 0:
            avg_episode_length, ave_ep_reward_0, avg_ep_reward_1 = np.mean(episode_stats, axis=1)
            avg_episode_reward = [ave_ep_reward_0, avg_ep_reward_1]
        else:
            avg_episode_length = -1
            avg_episode_reward = [-1, -1]

        trajs = Trajectories(rewards=rewards,
            action_log_probs=action_log_probs, not_dones=not_dones,
            actions=actions, states=states, action_means=action_means,
            action_std=traj_action_std, values=values)

        return (avg_episode_length, avg_episode_reward, trajs)
    
    def run_whole_trajectories(self, num_saps, victim_id, return_rewards=False, should_tqdm=False):
        """
        Roll out both players and record a separate trajectory for each.

        ``self.policy_model`` acts for player ``victim_id`` and
        ``self.attacker_policy_model`` for ``1 - victim_id`` (actions clipped to
        ``self.envs.action_space``). When ``self.params.ADV_ADVERSARY_RATIO >= random.random()``
        the post-step observation is attacked: ``apply_mvd_attack`` yields a sign
        perturbation which is pushed into the simulator via the env's ``set_pos_vel``
        method (only implemented for the *Humans games; other games raise
        NotImplementedError at that point). In 'adv_ppo' mode every transition is
        also pushed into ``self.replay_buffer``.

        Inputs:
        - num_saps: total number of state-action pairs to collect over all actors
        - victim_id: index (0 or 1) of the player controlled by ``self.policy_model``
        - return_rewards: also return per-episode rewards
        - should_tqdm: show a tqdm progress bar over the time steps
        Returns:
        (to_ret, to_ret_oppo) where
        - to_ret = (avg_episode_length, avg_episode_reward, victim Trajectories[, ep_rewards])
        - to_ret_oppo = (avg_episode_length, avg_episode_reward, opponent Trajectories)
        Episode statistics come from ``info[victim_id]['done']``: field 0 is averaged
        into the length; with one reward field the reward is a scalar, with several
        it is a list of per-field means (ep_rewards likewise a row or a 2-D array).
        Raises:
        - ValueError if num_saps gives fewer than one step per actor
        """
        tensor_maker = cpu_tensorize if self.CPU else cu_tensorize
        traj_length = int(num_saps // self.NUM_ACTORS)
        if traj_length <= 0:
            raise ValueError(
                f"num_saps={num_saps} yields no steps for {self.NUM_ACTORS} actors")

        # Reset once: both players' initial observations must come from the same episodes.
        initial_states_all = self.envs.reset()
        initial_states = tensor_maker(initial_states_all[:, victim_id, :])
        initial_states_oppo = tensor_maker(initial_states_all[:, 1 - victim_id, :])
        self.policy_model.reset()
        self.val_model.reset()

        completed_episode_info = []
        shape = (self.NUM_ACTORS, traj_length)
        # Separate buffers per player; shared tensors would let the opponent's
        # writes overwrite the victim's.
        rewards, not_dones, action_log_probs = [ch.zeros(shape) for _ in range(3)]
        rewards_oppo, not_dones_oppo, action_log_probs_oppo = [ch.zeros(shape) for _ in range(3)]

        actions_shape = shape + (self.NUM_ACTIONS,)
        actions = ch.zeros(actions_shape)
        action_means = ch.zeros(actions_shape)
        actions_oppo = ch.zeros(actions_shape)
        action_means_oppo = ch.zeros(actions_shape)

        states_shape = (self.NUM_ACTORS, traj_length + 1, initial_states.shape[1])
        states = ch.zeros(states_shape)
        states_oppo = ch.zeros(states_shape)

        iterator = tqdm.trange(traj_length) if should_tqdm else range(traj_length)

        states[:, 0, :] = initial_states
        states_oppo[:, 0, :] = initial_states_oppo
        last_states = states[:, 0, :]
        last_states_oppo = states_oppo[:, 0, :]

        # Games whose env supports perturbing observation slice 356:380 via set_pos_vel.
        humans_games = {"multicomp/YouShallNotPassHumans-v0", "multicomp/SumoHumans-v0",
                        "multicomp/RunToGoalHumans-v0"}

        outcomes = []
        for t in iterator:
            self.policy_model.pause_history()
            self.val_model.pause_history()
            self.attacker_policy_model.pause_history()
            self.attacker_val_model.pause_history()

            oppo_action_pds = self.attacker_policy_model(last_states_oppo)
            oppo_action_means, oppo_action_stds = oppo_action_pds
            next_adv_action = self.attacker_policy_model.sample(oppo_action_pds)
            next_oppo_log_probs = self.attacker_policy_model.get_loglikelihood(oppo_action_pds, next_adv_action)
            next_adv_action = next_adv_action.unsqueeze(1)  # (num_envs, 1, action_dims)
            next_adv_action_clip = np.clip(next_adv_action.detach().cpu().numpy(), self.envs.action_space.low, self.envs.action_space.high)

            self.policy_model.continue_history()
            self.val_model.continue_history()
            self.attacker_policy_model.continue_history()
            self.attacker_val_model.continue_history()

            action_pds = self.policy_model(last_states)
            next_action_means, next_action_stds = action_pds
            next_actions = self.policy_model.sample(action_pds)
            next_action_log_probs = self.policy_model.get_loglikelihood(action_pds, next_actions)
            next_actions = next_actions.unsqueeze(1)
            next_actions_clip = np.clip(next_actions.detach().cpu().numpy(), self.envs.action_space.low, self.envs.action_space.high)

            if victim_id == 1:
                agents_actions = np.concatenate((next_adv_action_clip, next_actions_clip), axis=1)
            else:
                agents_actions = np.concatenate((next_actions_clip, next_adv_action_clip), axis=1)

            last_obs_all = np.asarray(self.envs.get_attr("before_norm_obs"))
            next_states_two, next_rewards_two, next_dones, infos = self.envs.step(agents_actions)
            next_obs_all = np.asarray(self.envs.get_attr("before_norm_obs"))

            add_perturb = self.params.ADV_ADVERSARY_RATIO >= random.random()
            if add_perturb:
                if self.MUJOCO_GAME not in humans_games:
                    raise NotImplementedError(
                        f"run_whole_trajectories cannot apply perturbations for {self.MUJOCO_GAME}")
                # Attack the observation just produced by the env (as run_trajectories does).
                update, _ = self.apply_mvd_attack(
                    tensor_maker(next_states_two[:, 1 - victim_id, :]),
                    tensor_maker(next_states_two[:, victim_id, :]),
                    horizon=1, dyn_horizon=self.dyn_horizon)
                perturbations = (update.sign() * self.MVD_EPS).detach().cpu().numpy()
                for j in range(self.NUM_ACTORS):
                    pos_data = perturbations[j, 356:380] + next_obs_all[j, self.victim_id, 356:380]
                    obs, normalized_obs = self.envs.env_method(
                        "set_pos_vel", pos_data, 1 - self.victim_id, indices=j)[0]
                    # Write back every env's updated raw and normalized observations.
                    next_obs_all[j][self.victim_id] = obs[self.victim_id]
                    next_obs_all[j][1 - self.victim_id] = obs[1 - self.victim_id]
                    next_states_two[j][self.victim_id] = normalized_obs[self.victim_id]
                    next_states_two[j][1 - self.victim_id] = normalized_obs[1 - self.victim_id]

            if self.MODE == 'adv_ppo':
                for j in range(self.NUM_ACTORS):
                    last_obs = last_obs_all[j].reshape(2 * self.NUM_FEATURES)
                    next_obs = next_obs_all[j].reshape(2 * self.NUM_FEATURES)
                    last_acts = agents_actions[j].reshape(2 * self.NUM_ACTIONS)
                    self.replay_buffer.add(last_obs, last_acts, next_obs, 0, next_dones[j])

            next_not_dones = np.logical_not(next_dones)
            next_states, next_rewards, next_not_dones = map(
                tensor_maker, [next_states_two[:, victim_id, :], next_rewards_two[:, victim_id], next_not_dones])
            next_states_oppo, next_rewards_oppo = map(
                tensor_maker, [next_states_two[:, 1 - victim_id, :], next_rewards_two[:, 1 - victim_id]])

            done_info = []
            for not_done, info in zip(next_not_dones, infos):
                if not not_done:
                    done_info.append(info[victim_id]['done'])
                    if 'winner' in info[victim_id]:
                        outcomes.append(victim_id)
                    elif 'loser' in info[victim_id]:
                        outcomes.append(1 - victim_id)
                    else:
                        outcomes.append(None)

            # Keep episodes finished before the last step, or on the last step when
            # no other episode information exists.
            if len(done_info) > 0 and (t != traj_length - 1 or len(completed_episode_info) == 0):
                completed_episode_info.extend(done_info)

            # Victim trajectory; the observation after step t goes to slot t + 1.
            rewards[:, t] = next_rewards
            not_dones[:, t] = next_not_dones
            actions[:, t] = next_actions.squeeze(1)
            action_means[:, t] = next_action_means
            action_log_probs[:, t] = next_action_log_probs
            states[:, t + 1] = next_states

            # Opponent trajectory.
            rewards_oppo[:, t] = next_rewards_oppo
            not_dones_oppo[:, t] = next_not_dones
            actions_oppo[:, t] = next_adv_action.squeeze(1)
            action_means_oppo[:, t] = oppo_action_means
            action_log_probs_oppo[:, t] = next_oppo_log_probs
            states_oppo[:, t + 1] = next_states_oppo

            last_states = next_states
            last_states_oppo = next_states_oppo

        c = Counter(outcomes)
        num_games = len(completed_episode_info) + 1e-8
        print('*********************************************')
        print('game_win0', c.get(0, 0) / num_games)
        print('game_win1', c.get(1, 0) / num_games)
        print('game_tie', c.get(None, 0) / num_games)
        print('*********************************************')
        print()

        episode_stats = np.array(list(zip(*completed_episode_info)))
        if episode_stats.size > 0:
            # Row 0: episode lengths; remaining row(s): episode rewards.
            mean_stats = np.mean(episode_stats, axis=1)
            avg_episode_length = mean_stats[0]
            if len(episode_stats) == 2:
                ep_rewards = episode_stats[1]
                avg_episode_reward = mean_stats[1]
            else:
                ep_rewards = episode_stats[1:]
                avg_episode_reward = list(mean_stats[1:])
        else:
            ep_rewards = [-1]
            avg_episode_length = -1
            avg_episode_reward = -1

        trajs = Trajectories(rewards=rewards,
            action_log_probs=action_log_probs, not_dones=not_dones,
            actions=actions, states=states, action_means=action_means, action_std=next_action_stds)

        oppo_trajs = Trajectories(rewards=rewards_oppo,
            action_log_probs=action_log_probs_oppo, not_dones=not_dones_oppo,
            actions=actions_oppo, states=states_oppo, action_means=action_means_oppo, action_std=oppo_action_stds)

        to_ret = (avg_episode_length, avg_episode_reward, trajs)
        to_ret_oppo = (avg_episode_length, avg_episode_reward, oppo_trajs)

        if return_rewards:
            to_ret += (ep_rewards,)

        return to_ret, to_ret_oppo


    
    def apply_mvd_attack(self, last_states_atk, last_states, not_dones=None, horizon=1, dyn_horizon=False):
        """
        Compute a gradient direction for perturbing the victim's observation.

        The objective is the value gap ``attacker_val_model(s_atk) - val_model(s_vic)``
        summed over the current step and ``horizon`` further steps predicted by
        ``self.dynamics_model``. Both players act during the lookahead with actions
        sampled from their current policies. The returned update is the gradient of
        that objective with respect to the *current* victim observation ``last_states``.
        That observation is what callers perturb (they use ``update.sign()`` times
        an epsilon), so the update is the direction that increases the gap.

        The recurrent histories of all four policy/value models are paused while
        this runs. They are resumed afterwards, also when an exception propagates.

        Inputs:
        - last_states_atk: attacker observations, indexed as (batch, features)
        - last_states: victim observations, indexed as (batch, features)
        - not_dones: forwarded unchanged to every policy/value model call
        - horizon: number of dynamics-model lookahead steps
        - dyn_horizon: if True, compute one gradient for every lookahead depth
          0..horizon. For each batch element, keep the depth whose gradient gives
          the largest current-step value gap when added (unscaled) to features
          356:380 of the victim observation.
        Returns:
        - (update, [horizon]) when dyn_horizon is False
        - (update, chosen depth per batch element as a numpy int array) otherwise
        """
        self.policy_model.pause_history()
        self.val_model.pause_history()
        self.attacker_policy_model.pause_history()
        self.attacker_val_model.pause_history()
        try:
            with torch.set_grad_enabled(True):
                atk_input = last_states_atk.clone().detach().requires_grad_()
                # Leaf we differentiate against. The lookahead rebinds the rolling
                # ``vic_states``, but the gradient must reach the real observation.
                vic_input = last_states.clone().detach().requires_grad_()
                atk_states, vic_states = atk_input, vic_input
                # TODO: add not dones to lstm input

                def value_gap(atk, vic):
                    return self.attacker_val_model(atk, not_dones) - self.val_model(vic, not_dones)

                def grad_wrt_victim(objective, keep_graph):
                    return torch.autograd.grad(
                        outputs=objective, inputs=vic_input,
                        grad_outputs=torch.ones_like(objective),
                        retain_graph=keep_graph)[0]

                def gap_after_nudge(update):
                    # Current-step value gap once ``update`` is added to features
                    # 356:380 of the victim observation.
                    # NOTE(review): slice hard-coded for the *Humans games.
                    with torch.no_grad():
                        probe = last_states.clone().detach()
                        probe[:, 356:380] = probe[:, 356:380] + update[:, 356:380]
                        return value_gap(atk_input, probe)

                vdiff = value_gap(atk_states, vic_states)
                if dyn_horizon:
                    batch_size, num_features = last_states_atk.shape[0], last_states_atk.shape[1]
                    vdiff_list = torch.zeros((horizon + 1, batch_size, num_features))
                    max_h = torch.zeros(batch_size, dtype=torch.long)
                    batch_index = list(range(batch_size))
                    # Depth 0, without the dynamics model. Keep the graph alive because
                    # later depths differentiate through this term again.
                    update_0 = grad_wrt_victim(vdiff, keep_graph=True)
                    vdiff_list[0] = update_0
                    cur_max_vdiff = gap_after_nudge(update_0)

                for h in range(horizon):
                    atk_action_pds = self.attacker_policy_model(atk_states, not_dones)
                    next_actions_atk = self.attacker_policy_model.sample(atk_action_pds)
                    actions_pd = self.policy_model(vic_states, not_dones)
                    next_actions = self.policy_model.sample(actions_pd)
                    # The dynamics model expects the joint observation in player order.
                    if self.victim_id == 1:
                        sapr = ch.cat((atk_states, vic_states, next_actions_atk, next_actions), dim=1)
                        next_pred_obs = self.dynamics_model.forward(sapr, use_propagation=False)[0][0] + ch.cat((atk_states, vic_states), dim=1)
                        half = next_pred_obs.shape[1] // 2
                        atk_states, vic_states = next_pred_obs[:, :half], next_pred_obs[:, half:]
                    else:
                        sapr = ch.cat((vic_states, atk_states, next_actions, next_actions_atk), dim=1)
                        next_pred_obs = self.dynamics_model.forward(sapr, use_propagation=False)[0][0] + ch.cat((vic_states, atk_states), dim=1)
                        half = next_pred_obs.shape[1] // 2
                        vic_states, atk_states = next_pred_obs[:, :half], next_pred_obs[:, half:]
                    vdiff = vdiff + value_gap(atk_states, vic_states)
                    if dyn_horizon:
                        update = grad_wrt_victim(vdiff, keep_graph=True)
                        vdiff_list[h + 1] = update
                        candidate = gap_after_nudge(update)
                        # Record depth h + 1 wherever it beats the best gap so far.
                        improved = (candidate > cur_max_vdiff).reshape(-1).cpu()
                        max_h[improved] = h + 1
                        cur_max_vdiff = torch.max(cur_max_vdiff, candidate)

                if dyn_horizon:
                    return vdiff_list[max_h, batch_index, :], max_h.detach().cpu().numpy()
                return grad_wrt_victim(vdiff, keep_graph=False), [horizon]
        finally:
            self.policy_model.continue_history()
            self.val_model.continue_history()
            self.attacker_policy_model.continue_history()
            self.attacker_val_model.continue_history()
    '''
    def normalize_tensor_obs(self, env, tensor_obs):
        obs_mean0_tensor = torch.from_numpy(env.obs_rms_0.mean)
        obs_var0_tensor = torch.from_numpy(env.obs_rms_0.var)
        obs_mean1_tensor = torch.from_numpy(env.obs_rms_1.mean)
        obs_var1_tensor = torch.from_numpy(env.obs_rms_1.var)
        if tensor_obs.dim() == 3:
            np.clip((obs - obs_rms.mean) / np.sqrt(np.maximum(obs_rms.var, 1e-2)), -self.clip_obs, self.clip_obs)
            tensor_obs[:,0,:] = (tensor_obs[:,0,:] - obs_mean0_tensor)/torch.sqrt(torch.maximum(obs_var0_tensor, 1e-2))
            tensor_obs[:,1,:] = self._normalize_obs(tensor_obs[:,1,:], self.obs_rms_1)
            obs_ = np.stack((obs_0, obs_1), axis=1)
        elif tensor_obs.dim() == 2:
            obs_0 = self._normalize_obs(tensor_obs[0,:], self.obs_rms_0)
            obs_1 = self._normalize_obs(tensor_obs[1,:], self.obs_rms_1)
            obs_ = np.stack((obs_0, obs_1), axis=0)
        else:
            raise NotImplementedError
        return obs_
    '''
    """Conduct adversarial attack using value network."""
    def apply_attack(self, last_states):
        """Perturb ``last_states`` with the configured single-agent attack.

        The attack is selected by ``self.params.ATTACK_METHOD``:

        * ``"critic"``: random start, then signed-gradient steps that decrease
          ``self.val_model``.
        * ``"random"``: uniform noise in ``[-eps, eps]``.
        * ``"action"`` / ``"action+imit"``: noisy signed-gradient steps that
          increase the stdev-normalised squared change of the mean action of
          ``self.policy_model`` (or of an imitation network loaded from
          ``params.imit_model_path``).
        * ``"sarsa"`` / ``"sarsa+action"``: signed-gradient steps against a
          SARSA Q-network loaded from ``params.ATTACK_SARSA_NETWORK``,
          optionally mixed with an action-change term.
        * ``"advpolicy"``: deterministic perturbation from a loaded adversary
          policy, squashed with ``hardtanh`` and scaled by ``eps``.
        * ``"none"``: no perturbation.

        Iterative methods run ``ATTACK_STEPS`` steps and clamp every iterate to
        the L-inf ball of radius ``eps`` around ``last_states``. If
        ``ATTACK_STEPS <= 0``, they return ``last_states`` unchanged.
        A step is attacked only with probability ``ATTACK_RATIO``.

        Args:
            last_states: observation tensor to perturb. The iterative methods
                call ``.backward()`` on a per-row objective without an explicit
                gradient, which PyTorch only allows for one element, so a batch
                of one row is presumably expected. TODO confirm with callers.

        Returns:
            The perturbed states, detached from the autograd graph, or
            ``last_states`` itself when no perturbation is applied.

        Raises:
            ValueError: if ``ATTACK_METHOD`` is unknown.
        """
        if self.params.ATTACK_RATIO < random.random():
            # Only attack a portion of steps.
            return last_states
        eps = self.params.ATTACK_EPS
        if eps == "same":
            eps = self.params.ROBUST_PPO_EPS
        else:
            eps = float(eps)
        steps = self.params.ATTACK_STEPS
        method = self.params.ATTACK_METHOD

        def step_size():
            # "auto" splits the eps budget evenly across the attack steps.
            if self.params.ATTACK_STEP_EPS == "auto":
                return eps / steps
            return float(self.params.ATTACK_STEP_EPS)

        if method == "critic":
            # Find a state close to last_states that decreases the value most.
            if steps <= 0:
                return last_states
            step_eps = step_size()
            clamp_min = last_states - eps
            clamp_max = last_states + eps
            # Random start within one step of the clean state.
            noise = torch.empty_like(last_states).uniform_(-step_eps, step_eps)
            states = last_states + noise
            with torch.enable_grad():
                for _ in range(steps):
                    states = states.clone().detach().requires_grad_()
                    value = self.val_model(states).mean(dim=1)
                    value.backward()
                    update = states.grad.sign() * step_eps
                    # Descend on the value, then clamp to +/- eps.
                    states.data = torch.min(torch.max(states.data - update, clamp_min), clamp_max)
                self.val_model.zero_grad()
            return states.detach()

        if method == "random":
            # Apply a uniform random noise.
            noise = torch.empty_like(last_states).uniform_(-eps, eps)
            return (last_states + noise).detach()

        if method in ("action", "action+imit"):
            if steps <= 0:
                return last_states
            step_eps = step_size()
            clamp_min = last_states - eps
            clamp_max = last_states + eps
            # SGLD noise factor. We simply set beta=1.
            noise_factor = np.sqrt(2 * step_eps)
            noise = torch.randn_like(last_states) * noise_factor
            # The first step has gradient zero, so add the noise and projection directly.
            states = last_states + noise.sign() * step_eps
            use_imit = method == "action+imit"
            if use_imit:
                if getattr(self, "imit_network", None) is None:
                    assert self.params.imit_model_path is not None
                    print('\nLoading imitation network for attack: ', self.params.imit_model_path)
                    # Setup imitation network
                    self.setup_imit(train=False)
                    imit_ckpt = torch.load(self.params.imit_model_path)
                    self.imit_network.load_state_dict(imit_ckpt['state_dict'])
                    self.imit_network.reset()
                    self.imit_network.pause_history()
                target_net = self.imit_network
            else:
                target_net = self.policy_model
            # Action at the clean state.
            old_action, old_stdev = target_net(last_states)
            # Normalise the stdev to avoid numerical issues. Do it out of place
            # so the tensor returned by the model is not mutated, and detach it
            # so the loop never back-propagates through it.
            old_stdev = (old_stdev / old_stdev.mean()).detach()
            old_action = old_action.detach()
            with torch.enable_grad():
                for i in range(steps):
                    states = states.clone().detach().requires_grad_()
                    action_change = (target_net(states)[0] - old_action) / old_stdev
                    action_change = (action_change * action_change).sum(dim=1)
                    action_change.backward()
                    # Reduce noise at every step.
                    noise_factor = np.sqrt(2 * step_eps) / (i + 2)
                    # Project the noisy gradient to the step boundary.
                    update = (states.grad + noise_factor * torch.randn_like(last_states)).sign() * step_eps
                    # Ascend on the action change, then clamp to +/- eps.
                    states.data = torch.min(torch.max(states.data + update, clamp_min), clamp_max)
                if use_imit:
                    self.imit_network.zero_grad()
                self.policy_model.zero_grad()
            return states.detach()

        if method in ("sarsa", "sarsa+action"):
            # Attack using a learned value network.
            assert self.params.ATTACK_SARSA_NETWORK is not None
            action_ratio = self.params.ATTACK_SARSA_ACTION_RATIO
            use_action = action_ratio > 0 and method == "sarsa+action"
            assert 0 <= action_ratio <= 1
            if not hasattr(self, "sarsa_network"):
                self.sarsa_network = ValueDenseNet(state_dim=self.NUM_FEATURES+self.NUM_ACTIONS, init="normal")
                print("Loading sarsa network", self.params.ATTACK_SARSA_NETWORK)
                sarsa_ckpt = torch.load(self.params.ATTACK_SARSA_NETWORK)
                sarsa_meta = sarsa_ckpt['metadata']
                sarsa_eps = sarsa_meta.get('sarsa_eps', "unknown")
                sarsa_reg = sarsa_meta.get('sarsa_reg', "unknown")
                sarsa_steps = sarsa_meta.get('sarsa_steps', "unknown")
                print(f"Sarsa network was trained with eps={sarsa_eps}, reg={sarsa_reg}, steps={sarsa_steps}")
                if use_action:
                    print(f"objective: {1.0 - action_ratio} * sarsa + {action_ratio} * action_change")
                else:
                    print("Not adding action change objective.")
                self.sarsa_network.load_state_dict(sarsa_ckpt['state_dict'])
            if steps <= 0:
                return last_states
            step_eps = step_size()
            clamp_min = last_states - eps
            clamp_max = last_states + eps
            # Random start.
            noise = torch.empty_like(last_states).uniform_(-step_eps, step_eps)
            states = last_states + noise
            if use_action:
                # Current action at the clean state (stdev normalised out of place).
                old_action, old_stdev = self.policy_model(last_states)
                old_stdev = (old_stdev / old_stdev.mean()).detach()
                old_action = old_action.detach()
            with torch.enable_grad():
                for _ in range(steps):
                    states = states.clone().detach().requires_grad_()
                    # Mean action at the perturbed state.
                    actions = self.policy_model(states)[0]
                    # Q is evaluated at the clean state with the perturbed-state action.
                    value = self.sarsa_network(torch.cat((last_states, actions), dim=1)).mean(dim=1)
                    if use_action:
                        action_change = (actions - old_action) / old_stdev
                        # We want to maximize the action change, thus the minus sign.
                        action_change = -(action_change * action_change).mean(dim=1)
                        loss = action_ratio * action_change + (1.0 - action_ratio) * value
                    else:
                        loss = value
                    loss.backward()
                    update = states.grad.sign() * step_eps
                    # Descend on the loss, then clamp to +/- eps.
                    states.data = torch.min(torch.max(states.data - update, clamp_min), clamp_max)
                # Drop the gradients this attack accumulated in the networks it
                # back-propagated through (val_model is cleared as before).
                self.policy_model.zero_grad()
                self.sarsa_network.zero_grad()
                self.val_model.zero_grad()
            return states.detach()

        if method == "advpolicy":
            # Attack using a learned policy network.
            assert self.params.ATTACK_ADVPOLICY_NETWORK is not None
            if not hasattr(self, "attack_policy_network"):
                self.attack_policy_network = self.policy_net_class(self.NUM_FEATURES, self.NUM_FEATURES,
                                                 self.INITIALIZATION,
                                                 time_in_state=self.VALUE_CALC == "time",
                                                 activation=self.policy_activation)
                print("Loading adversary policy network", self.params.ATTACK_ADVPOLICY_NETWORK)
                advpolicy_ckpt = torch.load(self.params.ATTACK_ADVPOLICY_NETWORK)
                self.attack_policy_network.load_state_dict(advpolicy_ckpt['attacker_policy_model'])
            # Unlike the other attacks, steps are not used here. The adversary's
            # mean output is used deterministically (no sampling). hardtanh
            # clips it to [-1, 1] before it is scaled by eps.
            perturbations_mean, _ = self.attack_policy_network(last_states)
            perturbed_states = last_states + ch.nn.functional.hardtanh(perturbations_mean) * eps
            return perturbed_states.detach()

        if method == "none":
            return last_states
        raise ValueError(f'Unknown attack method {self.params.ATTACK_METHOD}')
    
    def apply_multi_attack(self, last_states, last_states_oppo, not_dones):
        '''Perturb the victim's observation in a two-agent ``multicomp`` game.

        Apart from ``"advpolicy"``, which perturbs the whole observation,
        attacks only modify the victim-observation columns that depend on the
        game:

        * 356:380 for the Humans games,
        * 354:378 for KickAndDefend-v0,
        * 107:122 for the Ants games.

        Each iterate is then clamped to the L-inf ball of radius ``eps``
        around ``last_states``. The method comes from
        ``self.params.ATTACK_METHOD`` and is one of ``"critic"``, ``"random"``,
        ``"action"``/``"action+imit"``, ``"sarsa"``/``"sarsa+action"``,
        ``"advpolicy"``, ``"pgd"``/``"tactics"``, ``"mvd"`` or ``"none"``.
        A step is attacked only with probability ``ATTACK_RATIO``. If
        ``ATTACK_STEPS <= 0``, the iterative methods and ``"random"`` return
        ``last_states`` unchanged.

        Args:
            last_states: victim observation. It is indexed as
                ``[:, column]``, so it is assumed to be (batch, obs_dim).
                The gradient methods call ``.backward()`` on a per-row
                objective, which presumably requires a batch of one.
                TODO confirm.
            last_states_oppo: attacker observation. Only used by ``"mvd"``.
            not_dones: forwarded to the policy / value models.

        Returns:
            The perturbed victim observation (detached), or ``last_states``
            itself when no perturbation is applied.

        Raises:
            NotImplementedError: if the game has no known attackable column
                range and the chosen attack needs one.
            ValueError: if ``ATTACK_METHOD`` is unknown.
        '''
        if self.params.ATTACK_RATIO < random.random():
            # Only attack a portion of steps.
            return last_states
        eps = self.params.ATTACK_EPS
        if eps == "same":
            eps = self.params.ROBUST_PPO_EPS
        else:
            eps = float(eps)
        steps = self.params.ATTACK_STEPS
        method = self.params.ATTACK_METHOD

        humans_games = {"multicomp/YouShallNotPassHumans-v0", "multicomp/SumoHumans-v0", "multicomp/RunToGoalHumans-v0"}
        ants_games = {"multicomp/SumoAnts-v0", "multicomp/RunToGoalAnts-v0"}

        def step_size():
            # "auto" splits the eps budget evenly across the attack steps.
            if self.params.ATTACK_STEP_EPS == "auto":
                return eps / steps
            return float(self.params.ATTACK_STEP_EPS)

        def opponent_cols():
            # Columns of the victim observation the attacker may perturb.
            if self.MUJOCO_GAME in humans_games:
                return slice(356, 380)
            if self.MUJOCO_GAME == "multicomp/KickAndDefend-v0":
                return slice(354, 378)
            if self.MUJOCO_GAME in ants_games:
                return slice(107, 122)
            raise NotImplementedError

        if method == "critic":
            # Find a state close to last_states that decreases the value most.
            if steps <= 0:
                return last_states
            step_eps = step_size()
            cols = opponent_cols()
            clamp_min = last_states - eps
            clamp_max = last_states + eps
            states = last_states.clone().detach()
            with torch.enable_grad():
                for _ in range(steps):
                    states = states.clone().detach().requires_grad_()
                    value = self.val_model(states, not_dones).mean(dim=1)
                    value.backward()
                    update = states.grad.sign() * step_eps
                    # Descend on the value in the opponent columns only.
                    states.data[:, cols] -= update[:, cols]
                    # Clamp to +/- eps.
                    states.data = torch.min(torch.max(states.data, clamp_min), clamp_max)
                self.val_model.zero_grad()
            return states.detach()

        if method == "random":
            # Uniform noise on the opponent columns. The bound is eps / steps
            # (one step's worth). With no steps there is no bound to apply,
            # and the division would fail.
            if steps <= 0:
                return last_states
            states = last_states.clone().detach()
            noise = torch.empty_like(states).uniform_(-eps / steps, eps / steps)
            cols = opponent_cols()
            states.data[:, cols] += noise[:, cols]
            return states.detach()

        if method in ("action", "action+imit"):
            if steps <= 0:
                return last_states
            step_eps = step_size()
            clamp_min = last_states - eps
            clamp_max = last_states + eps
            use_imit = method == "action+imit"
            # Start from the clean state (no random start in the multi-agent version).
            states = last_states
            if use_imit:
                if getattr(self, "imit_network", None) is None:
                    assert self.params.imit_model_path is not None
                    print('\nLoading imitation network for attack: ', self.params.imit_model_path)
                    # Setup imitation network
                    self.setup_imit(train=False)
                    imit_ckpt = torch.load(self.params.imit_model_path)
                    self.imit_network.load_state_dict(imit_ckpt['state_dict'])
                    self.imit_network.reset()
                    self.imit_network.pause_history()
                # NOTE(review): this call omits not_dones while the loop below
                # passes it; confirm which signature imit_network expects.
                old_action, old_stdev = self.imit_network(last_states)
            else:
                old_action, old_stdev = self.policy_model(last_states, not_dones)
            # Normalise the stdev out of place (the model's tensor is not
            # mutated) and detach it so the loop never back-propagates through it.
            old_stdev = (old_stdev / old_stdev.mean()).detach()
            old_action = old_action.detach()
            cols = opponent_cols()
            with torch.enable_grad():
                for i in range(steps):
                    states = states.clone().detach().requires_grad_()
                    if use_imit:
                        new_action = self.imit_network(states, not_dones)[0]
                    else:
                        new_action = self.policy_model(states, not_dones)[0]
                    action_change = (new_action - old_action) / old_stdev
                    action_change = (action_change * action_change).sum(dim=1)
                    action_change.backward()
                    # Reduce noise at every step (SGLD-style, beta=1).
                    noise_factor = np.sqrt(2 * step_eps) / (i + 2)
                    # Project the noisy gradient to the step boundary.
                    update = (states.grad + noise_factor * torch.randn_like(last_states)).sign() * step_eps
                    # Ascend on the action change in the opponent columns, then clamp.
                    states.data[:, cols] += update[:, cols]
                    states.data = torch.min(torch.max(states.data, clamp_min), clamp_max)
                if use_imit:
                    self.imit_network.zero_grad()
                self.policy_model.zero_grad()
            return states.detach()

        if method in ("sarsa", "sarsa+action"):
            # Attack using a learned value network.
            assert self.params.ATTACK_SARSA_NETWORK is not None
            action_ratio = self.params.ATTACK_SARSA_ACTION_RATIO
            use_action = action_ratio > 0 and method == "sarsa+action"
            assert 0 <= action_ratio <= 1
            if not hasattr(self, "sarsa_network"):
                self.sarsa_network = ValueDenseNet(state_dim=self.NUM_FEATURES+self.NUM_ACTIONS, init="normal")
                print("Loading sarsa network", self.params.ATTACK_SARSA_NETWORK)
                sarsa_ckpt = torch.load(self.params.ATTACK_SARSA_NETWORK)
                sarsa_meta = sarsa_ckpt['metadata']
                sarsa_eps = sarsa_meta.get('sarsa_eps', "unknown")
                sarsa_reg = sarsa_meta.get('sarsa_reg', "unknown")
                sarsa_steps = sarsa_meta.get('sarsa_steps', "unknown")
                print(f"Sarsa network was trained with eps={sarsa_eps}, reg={sarsa_reg}, steps={sarsa_steps}")
                if use_action:
                    print(f"objective: {1.0 - action_ratio} * sarsa + {action_ratio} * action_change")
                else:
                    print("Not adding action change objective.")
                self.sarsa_network.load_state_dict(sarsa_ckpt['state_dict'])
            if steps <= 0:
                return last_states
            step_eps = step_size()
            clamp_min = last_states - eps
            clamp_max = last_states + eps
            states = last_states
            if use_action:
                # Current action at the clean state (stdev normalised out of place).
                old_action, old_stdev = self.policy_model(last_states, not_dones)
                old_stdev = (old_stdev / old_stdev.mean()).detach()
                old_action = old_action.detach()
            cols = opponent_cols()
            with torch.enable_grad():
                for _ in range(steps):
                    states = states.clone().detach().requires_grad_()
                    actions = self.policy_model(states, not_dones)[0]
                    # Q is evaluated at the clean state with the perturbed-state action.
                    value = self.sarsa_network(torch.cat((last_states, actions), dim=1)).mean(dim=1)
                    if use_action:
                        action_change = (actions - old_action) / old_stdev
                        # We want to maximize the action change, thus the minus sign.
                        action_change = -(action_change * action_change).mean(dim=1)
                        loss = action_ratio * action_change + (1.0 - action_ratio) * value
                    else:
                        loss = value
                    loss.backward()
                    update = states.grad.sign() * step_eps
                    # Descend on the loss in the opponent columns, then clamp.
                    states.data[:, cols] -= update[:, cols]
                    states.data = torch.min(torch.max(states.data, clamp_min), clamp_max)
                # Drop the gradients accumulated in the networks this attack
                # back-propagated through (val_model is cleared as before).
                self.policy_model.zero_grad()
                self.sarsa_network.zero_grad()
                self.val_model.zero_grad()
            return states.detach()

        if method == "advpolicy":
            # Attack using a learned policy network.
            assert self.params.ATTACK_ADVPOLICY_NETWORK is not None
            if not hasattr(self, "attack_policy_network"):
                self.attack_policy_network = self.policy_net_class(self.NUM_FEATURES, self.NUM_FEATURES,
                                                 self.INITIALIZATION,
                                                 time_in_state=self.VALUE_CALC == "time",
                                                 activation=self.policy_activation)
                print("Loading adversary policy network", self.params.ATTACK_ADVPOLICY_NETWORK)
                advpolicy_ckpt = torch.load(self.params.ATTACK_ADVPOLICY_NETWORK)
                self.attack_policy_network.load_state_dict(advpolicy_ckpt['attacker_policy_model'])
            # Steps are not used here. The adversary's mean output is used
            # deterministically (no sampling). hardtanh clips it to [-1, 1]
            # before it is scaled by eps.
            perturbations_mean, _ = self.attack_policy_network(last_states)
            perturbed_states = last_states + ch.nn.functional.hardtanh(perturbations_mean) * eps
            return perturbed_states.detach()

        if method in ("pgd", "tactics"):
            if steps <= 0:
                return last_states
            step_eps = step_size()
            # "tactics" additionally skips updates whose log-likelihood gap c
            # is below 200 (Humans games only).
            compute_c = method == "tactics"
            cols = opponent_cols()
            clamp_min = last_states - eps
            clamp_max = last_states + eps
            states = last_states
            with torch.enable_grad():
                for _ in range(steps):
                    states = states.clone().detach().requires_grad_()
                    action_pds = self.policy_model(states, not_dones)
                    actions, std = action_pds
                    # Target: -0.4 where the action is positive, 0.4 elsewhere,
                    # with a near-zero stdev.
                    tar_actions = torch.where((actions > 0), -0.4, 0.4)
                    tar_std = torch.ones_like(std) * 0.01
                    if compute_c:
                        a = self.policy_model.get_loglikelihood(action_pds, actions)
                        b = self.policy_model.get_loglikelihood(action_pds, tar_actions)
                        c = abs(a - b)
                    ce_loss = (self.policy_model.calc_kl(action_pds, (tar_actions, tar_std)) + self.policy_model.entropies(action_pds)).mean()
                    ce_loss.backward()
                    # NOTE(review): this ascends on ce_loss; confirm the sign is
                    # intended given calc_kl's argument order.
                    update = states.grad.sign() * step_eps
                    if compute_c and self.MUJOCO_GAME in humans_games and c < 200:
                        continue
                    states.data[:, cols] += update[:, cols]
                    # Clamp to +/- eps.
                    states.data = torch.min(torch.max(states.data, clamp_min), clamp_max)
                self.val_model.zero_grad()
                self.policy_model.zero_grad()
            return states.detach()

        if method == "mvd":
            if steps <= 0:
                return last_states
            # Only the Humans column layout is supported: victim columns
            # 356:380 mirror the attacker's own columns 0:24.
            if self.MUJOCO_GAME not in humans_games:
                raise NotImplementedError
            clamp_min = last_states - eps
            clamp_max = last_states + eps
            # The attacker observation is clamped around its own clean value.
            oppo_clamp_min = last_states_oppo - eps
            oppo_clamp_max = last_states_oppo + eps
            states = last_states
            oppo_states = last_states_oppo
            with torch.enable_grad():
                for _ in range(steps):
                    states = states.clone().detach().requires_grad_()
                    oppo_states = oppo_states.clone().detach().requires_grad_()
                    # apply_mvd_attack returns an (update, horizon) pair.
                    update, _ = self.apply_mvd_attack(oppo_states, states, self.VICTIM_ID, False, horizon=1)
                    # NOTE(review): perturbations must broadcast to a
                    # (batch, 24) column slice; if apply_mvd_attack returns a
                    # full-observation gradient, slice it before this point.
                    perturbations = update.sign() * self.MVD_EPS
                    # Write through .data: autograd rejects in-place ops on
                    # leaves that require grad.
                    states.data[:, 356:380] += perturbations
                    oppo_states.data[:, 0:24] += perturbations
                    states.data = torch.min(torch.max(states.data, clamp_min), clamp_max)
                    oppo_states.data = torch.min(torch.max(oppo_states.data, oppo_clamp_min), oppo_clamp_max)
                self.val_model.zero_grad()
            return states.detach()

        if method == "none":
            return last_states
        raise ValueError(f'Unknown attack method {self.params.ATTACK_METHOD}')

    def get_oppo_new_states(self, states, oppo_states):
        """Copy the victim's leading 24 features into the opponent observation.

        For the three Humans games, columns 356:380 of ``oppo_states`` are
        overwritten in place with columns 0:24 of ``states``. For any other
        game, ``oppo_states`` is left untouched. The same tensor object is
        returned in both cases.
        """
        humans_games = {
            "multicomp/YouShallNotPassHumans-v0",
            "multicomp/SumoHumans-v0",
            "multicomp/RunToGoalHumans-v0",
        }
        if self.MUJOCO_GAME not in humans_games:
            return oppo_states
        oppo_states[:, 356:380] = states[:, :24]
        return oppo_states


    """Run trajectories and return saps and values for each state."""
    def collect_saps(self, num_saps, agent_id, should_log=True, return_rewards=False,
                     should_tqdm=False, test=False, collect_adversary_trajectory=False, simul_train=False, saps = None):
        """Gather rollouts and attach value, advantage and return estimates.

        Everything runs under ``torch.no_grad()``.

        Args:
            num_saps: number of state-action pairs requested from
                ``self.run_trajectories``.
            agent_id: forwarded to ``self.run_trajectories``.
            should_log: print the mean rewards and, outside test mode, log
                them to the ``optimization`` table.
            return_rewards: also return the per-episode rewards from the
                rollout.
            should_tqdm: forwarded to ``self.run_trajectories``.
            test: skip value/advantage estimation and table logging.
            collect_adversary_trajectory: use the attacker's models and log
                to the ``optimization_adv`` table.
            simul_train: use the pre-collected ``saps`` instead of running
                new trajectories.
            saps: rollout output in the same tuple layout that
                ``run_trajectories`` returns. Required when ``simul_train``
                is true.

        Returns:
            ``(unrolled_trajs, avg_ep_reward, avg_ep_length)``, with
            ``ep_rewards`` appended when ``return_rewards`` is true.
        """
        table_name_suffix = "_adv" if collect_adversary_trajectory else ""
        with torch.no_grad():
            if simul_train:
                # Rollouts were collected elsewhere; reuse them as-is.
                assert saps is not None
                output = saps
            else:
                output = self.run_trajectories(num_saps, agent_id=agent_id,
                                               return_rewards=return_rewards,
                                               should_tqdm=should_tqdm,
                                               collect_adversary_trajectory=collect_adversary_trajectory)

            if return_rewards:
                avg_ep_length, avg_ep_reward, trajs, ep_rewards = output
            else:
                avg_ep_length, avg_ep_reward, trajs = output

            # No need to compute advantage function for testing.
            if not test:
                if self.SHARE_WEIGHTS:
                    # Shared policy/value weights: take the value from the
                    # policy's get_value. History models are not supported here.
                    assert self.HISTORY_LENGTH < 1
                    policy = self.attacker_policy_model if collect_adversary_trajectory else self.policy_model
                    values = policy.get_value(trajs.states).squeeze(-1)
                elif self.HISTORY_LENGTH > 0 and self.USE_LSTM_VAL:
                    # Recurrent value: reuse the estimates already stored on
                    # trajs (presumably filled in during the rollout; confirm
                    # in run_trajectories).
                    values = trajs.values
                else:
                    critic = self.attacker_val_model if collect_adversary_trajectory else self.val_model
                    values = critic(trajs.states).squeeze(-1)

                # Calculate advantages and returns.
                advantages, returns = self.advantage_and_return(trajs.rewards,
                                                                values, trajs.not_dones)

                trajs.advantages = advantages
                trajs.returns = returns
                # Drop the final time step. The assert below checks that the
                # trimmed values match advantages/returns in shape.
                trajs.values = values[:, :-1]
                trajs.states = trajs.states[:, :-1]

                assert shape_equal_cmp(trajs.advantages,
                                       trajs.returns, trajs.values)

            # Logging
            if should_log:
                msg = "Current mean reward: [%.4f, %.4f] | mean episode length: %.2f"
                print(msg % (avg_ep_reward[0], avg_ep_reward[1], avg_ep_length))
                if not test:
                    self.store.log_table_and_tb('optimization'+table_name_suffix, {
                        'mean_reward_0': avg_ep_reward[0],
                        'mean_reward_1': avg_ep_reward[1],
                    })

            # Unroll the trajectories (actors, T, ...) -> (actors*T, ...)
            saps = trajs.unroll()

        to_ret = (saps, avg_ep_reward, avg_ep_length)
        if return_rewards:
            to_ret += (ep_rewards,)

        return to_ret


    def sarsa_steps(self, saps):
        '''
        Fit the SARSA Q-network (self.relaxed_sarsa_model) on unrolled rollouts.

        Runs params.VAL_EPOCHS epochs of minibatch updates with a Smooth-L1
        TD loss. When the current sarsa eps and params.SARSA_REG are both
        positive, a regularizer is added that penalizes the squared gap between
        the Q-value and its lower/upper bounds under an L-inf perturbation of
        the concatenated (state, action) input. The bounds mix IBP bounds and
        backward-mode bounds with weight ``beta`` taken from
        self.sarsa_beta_scheduler (beta >= 1 uses IBP only).

        Inputs:
        - saps, unrolled trajectories providing states, actions, rewards and
          not_dones.
        Returns:
        - (q_loss, mean Q) of the last minibatch, both as tensors.
        '''
        # Begin advanced logging code
        assert saps.unrolled
        loss = torch.nn.SmoothL1Loss()
        # NOTE(review): action_std is not used below (only by the commented-out
        # action-perturbation code); kept as-is.
        action_std = torch.exp(self.policy_model.log_stdev).detach().requires_grad_(False)  # Avoid backprop twice.
        # We treat all value epochs as one epoch.
        self.sarsa_eps_scheduler.set_epoch_length(self.params.VAL_EPOCHS * self.params.NUM_MINIBATCHES)
        self.sarsa_beta_scheduler.set_epoch_length(self.params.VAL_EPOCHS * self.params.NUM_MINIBATCHES)
        # We count from 1.
        self.sarsa_eps_scheduler.step_epoch()
        self.sarsa_beta_scheduler.step_epoch()
        # saps contains state->action->reward and not_done.
        for i in range(self.params.VAL_EPOCHS):
            # Create minibatches with shuffling
            state_indices = np.arange(saps.rewards.nelement())
            np.random.shuffle(state_indices)
            splits = np.array_split(state_indices, self.params.NUM_MINIBATCHES)

            # Minibatch SGD
            for selected in splits:
                def sel(*args):
                    return [v[selected] for v in args]

                self.sarsa_opt.zero_grad()
                sel_states, sel_actions, sel_rewards, sel_not_dones = sel(saps.states, saps.actions, saps.rewards, saps.not_dones)
                # Schedulers advance once per minibatch.
                self.sarsa_eps_scheduler.step_batch()
                self.sarsa_beta_scheduler.step_batch()

                # The Q-network takes the concatenation (state, action) as input.
                inputs = torch.cat((sel_states, sel_actions), dim=1)
                # action_diff = self.sarsa_eps_scheduler.get_eps() * action_std
                # inputs_lb = torch.cat((sel_states, sel_actions - action_diff), dim=1).detach().requires_grad_(False)
                # inputs_ub = torch.cat((sel_states, sel_actions + action_diff), dim=1).detach().requires_grad_(False)
                # bounded_inputs = BoundedTensor(inputs, ptb=PerturbationLpNorm(norm=np.inf, eps=None, x_L=inputs_lb, x_U=inputs_ub))
                bounded_inputs = BoundedTensor(inputs, ptb=PerturbationLpNorm(norm=np.inf, eps=self.sarsa_eps_scheduler.get_eps()))

                # This forward pass is also the one compute_bounds() below refers to.
                q = self.relaxed_sarsa_model(bounded_inputs).squeeze(-1)

                # TD target: r + gamma * not_done * Q(next row).
                # NOTE(review): `selected` is a shuffled index set, so q[1:] is
                # not the Q-value of the successor transition of q[:-1]; the
                # target pairs unrelated samples. Confirm whether this is intended.
                q_old = q[:-1]
                q_next = q[1:] * self.GAMMA * sel_not_dones[:-1] + sel_rewards[:-1]
                q_next = q_next.detach()
                # q_loss = (q_old - q_next).pow(2).sum(dim=-1).mean()
                q_loss = loss(q_old, q_next)
                # Compute the robustness regularization.
                if self.sarsa_eps_scheduler.get_eps() > 0 and self.params.SARSA_REG > 0:
                    beta = self.sarsa_beta_scheduler.get_eps()
                    ilb, iub = self.relaxed_sarsa_model.compute_bounds(IBP=True, method=None)
                    if beta < 1:
                        # Interpolate between IBP bounds and backward-mode bounds.
                        clb, cub = self.relaxed_sarsa_model.compute_bounds(IBP=False, method='backward')
                        lb = beta * ilb + (1 - beta) * clb
                        ub = beta * iub + (1 - beta) * cub
                    else:
                        lb = ilb
                        ub = iub
                    # Output dimension is 1. Remove the extra dimension and keep only the batch dimension.
                    lb = lb.squeeze(-1)
                    ub = ub.squeeze(-1)
                    # Largest deviation of Q from either bound, squared and averaged.
                    diff = torch.max(ub - q, q - lb)
                    reg_loss = self.params.SARSA_REG * (diff * diff).mean()
                    sarsa_loss = q_loss + reg_loss
                    reg_loss = reg_loss.item()
                else:
                    reg_loss = 0.0
                    sarsa_loss = q_loss
                sarsa_loss.backward()
                self.sarsa_opt.step()
            # Reports the losses of the last minibatch of this epoch.
            print(f'q_loss={q_loss.item():.6g}, reg_loss={reg_loss:.6g}, sarsa_loss={sarsa_loss.item():.6g}')

        if self.ANNEAL_LR:
            self.sarsa_scheduler.step()
        # print('value:', self.val_model(saps.states).mean().item())

        return q_loss, q.mean()


    def take_steps(self, saps, logging=True, value_only=False, adversary_step=False, increment_scheduler=True):
        '''
        Run one round of value and policy optimization on unrolled rollouts.

        Inputs:
        - saps, unrolled trajectories (``saps.unrolled`` must be True).
        - logging, whether to write optimization statistics to self.store.
        - value_only, if True only the value network is updated and the value
          loss (a tensor) is returned immediately.
        - adversary_step, optimize the attacker's policy/value networks
          instead of the victim's; only allowed in the adv_* modes.
        - increment_scheduler, whether to step the LR schedulers (only has an
          effect when ANNEAL_LR is set).
        Returns:
        - val_loss (tensor) when value_only is True; otherwise the tuple
          (policy_loss, surr_loss, entropy_bonus, val_loss) with val_loss
          converted to a Python float.
        '''
        if adversary_step:
            # collect adversary trajectory is only valid in minimax training mode.
            assert self.MODE == "adv_ppo" or self.MODE == "adv_trpo" or self.MODE == "adv_sa_ppo" or self.MODE == "adv_radial"

        # Begin advanced logging code
        assert saps.unrolled
        should_adv_log = self.advanced_logging and \
                     self.n_steps % self.log_every == 0 and logging

        self.params.SHOULD_LOG_KL = self.advanced_logging and \
                        self.KL_APPROXIMATION_ITERS != -1 and \
                        self.n_steps % self.KL_APPROXIMATION_ITERS == 0
        store_to_pass = self.store if should_adv_log else None
        # End logging code

        # Select the networks, optimizer, schedulers and hyper-parameters of
        # the side being trained (attacker or victim).
        if adversary_step:
            policy_model = self.attacker_policy_model
            if self.ANNEAL_LR:
                policy_scheduler = self.ADV_POLICY_SCHEDULER
                val_scheduler = self.ADV_VALUE_SCHEDULER
            # Copy so the attacker-specific overrides below do not touch self.params.
            policy_params = Parameters(self.params.copy())
            # In policy_step(), some hard coded attributes will be accessed. We override them.
            policy_params.PPO_LR = self.ADV_PPO_LR_ADAM
            policy_params.PPO_LR_ADAM = self.ADV_PPO_LR_ADAM
            policy_params.POLICY_ADAM = self.ADV_POLICY_ADAM
            policy_params.CLIP_EPS = policy_params.ADV_CLIP_EPS
            policy_params.ENTROPY_COEFF = policy_params.ADV_ENTROPY_COEFF
            val_model = self.attacker_val_model
            val_opt = self.attacker_val_opt
            table_name_suffix = '_adv'
        else:
            policy_model = self.policy_model
            if self.ANNEAL_LR:
                policy_scheduler = self.POLICY_SCHEDULER
                val_scheduler = self.VALUE_SCHEDULER
            policy_params = self.params
            val_model = self.val_model
            val_opt = self.val_opt
            table_name_suffix = ''

        if should_adv_log:
            # collect some extra trajectory for validation of KL and max KL.
            num_saps = saps.advantages.shape[0]
            val_saps = self.collect_saps(num_saps, should_log=False, collect_adversary_trajectory=adversary_step)[0]

            out_train = policy_model(saps.states)
            out_val = policy_model(val_saps.states)

            # NOTE(review): old_pds is overwritten further below before it is used.
            old_pds = select_prob_dists(out_train, detach=True)
            val_old_pds = select_prob_dists(out_val, detach=True)

        # Update the value function before unrolling the trajectories
        # Pass the logging data into the function if applicable
        # val_loss stays 0 when policy and value share weights (no separate value step here).
        val_loss = ch.tensor(0.0)
        if not self.SHARE_WEIGHTS:
            val_loss = value_step(saps.states, saps.returns, 
                saps.advantages, saps.not_dones, val_model,
                val_opt, self.params, store_to_pass,
                old_vs=saps.values.detach()).mean()

        if self.ANNEAL_LR and increment_scheduler:
            val_scheduler.step()

        if value_only:
            # Run the value iteration only. Return now.
            return val_loss

        if logging:
            self.store.log_table_and_tb('optimization'+table_name_suffix, {
                'final_value_loss': val_loss
            })

        if (self.MODE in ['robust_ppo','radial_ppo'] or self.MODE == 'adv_sa_ppo' or self.MODE == 'adv_radial') and not adversary_step and logging:
            # Logging Robust PPO KL, entropy, etc.
            store_to_pass = self.store

        # Take optimizer steps
        args = [saps.states, saps.actions, saps.action_log_probs,
                saps.rewards, saps.returns, saps.not_dones, 
                saps.advantages, policy_model, policy_params, 
                store_to_pass, self.n_steps]

        # Robust variants with a continuous victim policy also pass the relaxed
        # policy model, the robust eps/beta schedulers and the stored action
        # means to policy_step (presumably for the robustness regularizer).
        if (self.MODE in ['robust_ppo','radial_ppo'] or 'robust_ppo' in self.MODE or self.MODE == 'adv_sa_ppo' or self.MODE == 'adv_radial') and \
            (isinstance(self.policy_model, CtsPolicy) or isinstance(self.policy_model, CtsPolicy_stdv) or isinstance(self.policy_model, CtsLSTMPolicy)) and not adversary_step:
            args += [self.relaxed_policy_model, self.robust_eps_scheduler, self.robust_beta_scheduler, saps.action_means]

        self.MAX_KL += self.MAX_KL_INCREMENT
        if should_adv_log:
            # Save old parameter to investigate weight updates.
            old_parameter = copy.deepcopy(self.policy_model.state_dict())

        # Policy optimization step
        if adversary_step:
            policy_loss, surr_loss, entropy_bonus = self.adversary_policy_step(*args)
        else:
            policy_loss, surr_loss, entropy_bonus = self.policy_step(*args)

        # If the anneal_lr option is set, then we decrease the 
        # learning rate at each training step
        if self.ANNEAL_LR and increment_scheduler:
            policy_scheduler.step()

        if should_adv_log and not adversary_step:
            log_value_losses(self, val_saps, 'heldout')
            log_value_losses(self, saps, 'train')
            old_pds = saps.action_means, saps.action_std
            paper_constraints_logging(self, saps, old_pds,
                            table='paper_constraints_train')
            paper_constraints_logging(self, val_saps, val_old_pds,
                            table='paper_constraints_heldout')
            log_weight_updates(self, old_parameter, self.policy_model.state_dict())

            self.store['paper_constraints_train'].flush_row()
            self.store['paper_constraints_heldout'].flush_row()
            self.store['value_data'].flush_row()
            self.store['weight_updates'].flush_row()
        # NOTE(review): radial_ppo / adv_radial are routed to self.store above but
        # not flushed here; confirm whether they write a robust_ppo_data row.
        if (self.params.MODE == 'robust_ppo' or self.params.MODE == 'adv_sa_ppo') and not adversary_step:
            self.store['robust_ppo_data'].flush_row()

        # policy_scheduler / val_scheduler are only bound when ANNEAL_LR is set.
        if self.ANNEAL_LR:
            print(f'val lr: {val_scheduler.get_last_lr()}, policy lr: {policy_scheduler.get_last_lr()}')
        val_loss = val_loss.mean().item()
        return policy_loss, surr_loss, entropy_bonus, val_loss

    def train_step(self):
        '''
        Run one outer training iteration and increment ``self.n_steps``.

        In the adversarial modes (adv_ppo, adv_trpo, adv_sa_ppo, adv_radial):
        - with a non-zero victim learning rate (PPO_LR_ADAM), the victim is
          trained for ADV_POLICY_STEPS steps, then the attacker for
          ADV_ADVERSARY_STEPS steps. Without SIMUL_TRAIN every step collects
          its own rollouts. With SIMUL_TRAIN one set of whole trajectories is
          collected up front, and its victim/attacker parts are reused by
          every step. ``increment_scheduler`` is True only on the last step of
          each phase.
        - with PPO_LR_ADAM == 0, only a single attacker step is taken.
        In all other modes a single victim step is taken.

        Returns:
        - The average episode reward reported by the last victim step, or by
          the attacker step when the victim is not trained.
        '''
        if self.MODE in ("adv_ppo", "adv_trpo", "adv_sa_ppo", "adv_radial"):
            avg_ep_reward = 0.0
            if self.PPO_LR_ADAM != 0.0:
                if not self.SIMUL_TRAIN:
                    for i in range(int(self.ADV_POLICY_STEPS)):
                        avg_ep_reward = self.train_step_impl(adversary_step=False, increment_scheduler=(i == self.ADV_POLICY_STEPS - 1), agent_id=self.VICTIM_ID)
                    for i in range(int(self.ADV_ADVERSARY_STEPS)):
                        self.train_step_impl(adversary_step=True, increment_scheduler=(i == self.ADV_ADVERSARY_STEPS - 1), agent_id=1 - self.VICTIM_ID)
                else:
                    saps, oppo_saps = self.run_whole_trajectories(num_saps=self.T * self.NUM_ACTORS, victim_id=self.VICTIM_ID)
                    for i in range(int(self.ADV_POLICY_STEPS)):
                        # Capture the victim's reward (as the non-simultaneous branch
                        # does) so the caller does not get the 0.0 placeholder.
                        avg_ep_reward = self.train_step_impl(adversary_step=False, increment_scheduler=(i == self.ADV_POLICY_STEPS - 1), agent_id=self.VICTIM_ID, simul_train=True, saps=saps)
                    for i in range(int(self.ADV_ADVERSARY_STEPS)):
                        self.train_step_impl(adversary_step=True, increment_scheduler=(i == self.ADV_ADVERSARY_STEPS - 1), agent_id=1 - self.VICTIM_ID, simul_train=True, saps=oppo_saps)
            else:
                print('skipping policy training because learning rate is 0. adv_policy_steps and adv_adversary_steps ignored.')
                avg_ep_reward = self.train_step_impl(adversary_step=True)
        else:
            avg_ep_reward = self.train_step_impl(adversary_step=False)

        self.n_steps += 1
        print()
        return avg_ep_reward

    def train_step_impl(self, adversary_step=False, increment_scheduler=True, agent_id = 0, simul_train=False, saps=None):
        '''
        Run one optimization round for either the victim or the attacker.

        Gets rollouts from ``collect_saps``. With ``simul_train``, the given
        pre-collected ``saps`` are passed to it instead of collecting new ones.
        Then takes policy/value steps via ``take_steps``, fits the dynamics
        model when MODE is 'adv_ppo', and logs the losses to the
        'optimization' (victim) or 'optimization_adv' (attacker) table.

        Inputs:
        - adversary_step, train the attacker instead of the victim.
        - increment_scheduler, forwarded to take_steps.
        - agent_id, forwarded to collect_saps.
        - simul_train, reuse the pre-collected ``saps``.
        - saps, pre-collected trajectories; required when simul_train is True.
        Returns:
        - The current reward from the policy (per actor)
        '''
        t_begin = time.time()
        table = 'optimization' + ("_adv" if adversary_step else "")

        if adversary_step:
            print('++++++++ Attacker training ++++++++++')
            model = self.attacker_policy_model
        else:
            print('++++++++ Victim training ++++++++++')
            model = self.policy_model

        collect_kwargs = dict(collect_adversary_trajectory=adversary_step, agent_id=agent_id)
        if simul_train:
            assert saps is not None
            collect_kwargs.update(simul_train=True, saps=saps)
        saps, avg_ep_reward, _ = self.collect_saps(self.T * self.NUM_ACTORS, **collect_kwargs)

        policy_loss, surr_loss, entropy_bonus, val_loss = self.take_steps(
            saps, adversary_step=adversary_step, increment_scheduler=increment_scheduler)
        fit_dynamics = self.MODE == 'adv_ppo'
        if fit_dynamics:
            training_losses, val_scores = self.dynamics_step()

        # Console report.
        print()
        print(f"Policy Loss: {policy_loss:.5g}, | Entropy Bonus: {entropy_bonus:.5g}, | Value Loss: {val_loss:.5g}")
        if fit_dynamics:
            print(f"Dynamics Model Training Loss: {training_losses}")
            print(f"Validation Loss: {val_scores}")
        print("Time elapsed (s):", time.time() - t_begin)

        losses = {
            'final_policy_loss' : policy_loss,
            'final_surrogate_loss': surr_loss,
            'entropy_bonus': entropy_bonus,
        }
        if model.discrete:
            # Discrete policies have no stdev.
            self.store[table].update_row({'mean_std': np.nan, **losses})
        else:
            mean_std = ch.exp(model.log_stdev).mean()
            print("Agent stdevs: %s" % mean_std.detach().cpu().numpy())
            self.store.log_table_and_tb(table, {'mean_std': mean_std, **losses})

        for name in (table, 'rollout'):
            self.store[name].flush_row()
        print("-" * 80)
        for stream in (sys.stdout, sys.stderr):
            stream.flush()

        return avg_ep_reward

    def dynamics_step(self):
        '''
        Fit the learned dynamics model on the current replay buffer.

        When cfg.algorithm.normalize is set, the model's normalizer is first
        updated with the whole buffer. The buffer is then split into
        training/validation iterators and the model trainer runs for 5 epochs.

        Returns:
        - (training_losses, val_scores) from self.dyn_model_trainer.train
        '''
        cfg = self.cfg
        if cfg.algorithm.normalize:
            self.dynamics_model.update_normalizer(self.replay_buffer.get_all())

        train_iter, val_iter = common_util.get_basic_buffer_iterators(
            self.replay_buffer,
            batch_size=cfg.overrides.model_batch_size,
            val_ratio=cfg.overrides.validation_ratio,
            ensemble_size=self.ENSEMBLE_SIZE,
            shuffle_each_epoch=True,
            # False: bootstrap datasets are built by sampling with replacement.
            bootstrap_permutes=False,
        )
        training_losses, val_scores = self.dyn_model_trainer.train(
            train_iter,
            dataset_val=val_iter,
            num_epochs=5,
            patience=50,
            callback=self.train_callback,
        )
        return training_losses, val_scores

    def sarsa_step(self):
        '''
        Take one SARSA training step: collect fresh rollouts with the victim
        (agent self.VICTIM_ID) and fit the Q-network on them.

        Inputs: None
        Returns:
        - The current reward from the policy (per actor)
        '''
        print("-" * 80)
        t0 = time.time()

        saps, avg_ep_reward, _ = self.collect_saps(
            self.T * self.NUM_ACTORS, agent_id=self.VICTIM_ID, should_log=True, test=True)

        q_loss, q_mean = self.sarsa_steps(saps)
        for label, value in (("Sarsa Loss:", q_loss), ("Q:", q_mean)):
            print(label, value.item())
        print("Time elapsed (s):", time.time() - t0)
        for stream in (sys.stdout, sys.stderr):
            stream.flush()

        self.n_steps += 1
        return avg_ep_reward

    def run_test(self, max_len=40960, compute_bounds=False, use_full_backward=False, original_stdev=None):
        '''
        Evaluate the victim against the opponent in one rollout of max_len steps.

        Inputs:
        - max_len, number of environment steps passed to run_test_trajectories.
        - compute_bounds, also compute the mean upper bound on the state KL
          divergence of the victim policy over the visited states.
        - use_full_backward, forwarded to get_state_kl_bound.
        - original_stdev, log-stdev used for the KL bound instead of
          self.policy_model.log_stdev.
        Returns:
        - (ep_length, ep_reward_0, ep_reward_1, actions, action_means, states,
          kl_upper_bound, win_0, win_1, tie); the trajectories are numpy
          arrays and kl_upper_bound is NaN when compute_bounds is False.
        '''
        assert self.NUM_ACTORS == 1
        if compute_bounds and not hasattr(self, "relaxed_policy_model"):
            self.create_relaxed_model()
        with torch.no_grad():
            (ep_length, ep_reward_0, ep_reward_1, actions, action_means,
             states, win_0, win_1, tie) = self.run_test_trajectories(max_len=max_len)
            print("agent 0 win rate: %f | agent 1 win rate: %f" % (win_0, win_1))
            kl_upper_bound = float("nan")
            if compute_bounds:
                log_std = self.policy_model.log_stdev if original_stdev is None else original_stdev
                kl_stdev = torch.exp(log_std)
                if self.params.ATTACK_EPS == "same":
                    eps = float(self.params.ROBUST_PPO_EPS)
                else:
                    eps = float(self.params.ATTACK_EPS)
                kl_upper_bound = get_state_kl_bound(
                    self.relaxed_policy_model, states, action_means,
                    eps=eps, beta=0.0,
                    stdev=kl_stdev, use_full_backward=use_full_backward).mean().item()
        return (ep_length, ep_reward_0, ep_reward_1,
                actions.cpu().numpy(), action_means.cpu().numpy(), states.cpu().numpy(),
                kl_upper_bound, win_0, win_1, tie)

    def run_test_trajectories(self, max_len, should_tqdm=False):
        '''
        Roll out the victim (self.policy_model) against the opponent
        (self.attacker_policy_model) for ``max_len`` environment steps and
        collect evaluation statistics.

        Both agents act deterministically: each uses the first element of its
        policy output (the mean, as ``next_action_means, _ = action_pds``
        shows for the victim) instead of sampling. The victim's observation
        goes through ``self.apply_multi_attack`` before the victim acts; the
        opponent acts on its unmodified observation.

        Inputs:
        - max_len, number of environment steps to run.
        - should_tqdm, show a tqdm progress bar.
        Returns a tuple
        (ep_length, ep_reward_0, ep_reward_1, actions, action_means, states,
         win_0, win_1, tie):
        - ep_length / ep_reward_0 / ep_reward_1: means over completed episodes,
          NaN when no episode finished.
        - actions / action_means / states: first actor's trajectories,
          truncated to the number of steps taken.
        - win_0 / win_1 / tie: fractions of completed games won by agent 0,
          won by agent 1, or tied; NaN when no game finished.
        '''
        initial_states_all = self.envs.reset()
        initial_states = initial_states_all[:, self.victim_id, :]
        initial_states_oppo = initial_states_all[:, 1-self.victim_id, :]

        tensor_maker = cpu_tensorize if self.CPU else cu_tensorize
        initial_states = tensor_maker(initial_states)
        initial_states_oppo = tensor_maker(initial_states_oppo)

        if hasattr(self, "imit_network"):
            self.imit_network.reset()

        self.policy_model.reset()
        self.val_model.reset()
        self.attacker_policy_model.reset()
        self.attacker_val_model.reset()

        # Per-episode 'done' info of every completed episode.
        completed_episode_info = []

        shape = (1, max_len)
        rewards = ch.zeros(shape)

        actions_shape = shape + (self.NUM_ACTIONS,)
        actions = ch.zeros(actions_shape)
        action_means = ch.zeros(actions_shape)

        # One extra slot: states[:, t+1] holds the observation after step t.
        states_shape = (self.NUM_ACTORS, max_len+1, initial_states.shape[1])
        states = ch.zeros(states_shape)
        next_not_dones = torch.ones(1)

        iterator = range(max_len) if not should_tqdm else tqdm.trange(max_len)

        states[:, 0, :] = initial_states
        last_states = states[:, 0, :]
        last_states_oppo = initial_states_oppo

        # Result of each finished game: id of the winning agent, or None for a tie.
        outcomes = []
        # Index of the last executed step. It stays -1 when max_len == 0, so the
        # truncation below yields empty trajectories instead of a NameError.
        t = -1
        for t in iterator:
            # Pause hidden-state updates: the attack may run inference on the
            # models, and that must not advance their recurrent state.
            self.policy_model.pause_history()
            self.val_model.pause_history()
            self.attacker_policy_model.pause_history()
            self.attacker_val_model.pause_history()
            if hasattr(self, "imit_network"):
                self.imit_network.pause_history()
            before_act_victim = self.policy_model(last_states, next_not_dones)
            before_act_atk = self.attacker_policy_model(last_states_oppo, next_not_dones)
            maybe_attacked_last_states = self.apply_multi_attack(last_states, last_states_oppo, next_not_dones)
            # The attack must leave the first 24 observation dimensions untouched.
            assert torch.equal(maybe_attacked_last_states[:, 0:24], last_states[:, 0:24])

            # The opponent always observes its unperturbed state.
            maybe_attacked_last_states_oppo = last_states_oppo

            self.policy_model.continue_history()
            self.val_model.continue_history()
            self.attacker_policy_model.continue_history()
            self.attacker_val_model.continue_history()

            if hasattr(self, "imit_network"):
                self.imit_network.continue_history()

            action_pds = self.policy_model(maybe_attacked_last_states, next_not_dones)
            oppo_action_pds = self.attacker_policy_model(maybe_attacked_last_states_oppo, next_not_dones)

            if hasattr(self, "imit_network"):
                _ = self.imit_network(maybe_attacked_last_states)

            # Deterministic actions (no sampling, no clipping to the action space).
            next_oppo_action = oppo_action_pds[0]
            next_oppo_action = next_oppo_action.unsqueeze(1) #(num_envs, 1, action_dims)
            next_oppo_action_clip = next_oppo_action.detach().cpu().numpy()

            next_action_means, _ = action_pds
            next_actions = next_action_means
            next_actions = next_actions.unsqueeze(1)
            next_actions_clip = next_actions.detach().cpu().numpy()

            # The env expects actions ordered by agent id.
            if self.victim_id == 1:
                agents_actions = np.concatenate((next_oppo_action_clip, next_actions_clip), axis = 1)
            else:
                agents_actions = np.concatenate((next_actions_clip, next_oppo_action_clip), axis = 1)

            next_states_two, next_rewards, next_dones, infos = self.envs.step(agents_actions)

            # Double check if the attack is within eps range.
            if self.params.ATTACK_METHOD != "none":
                max_eps = (maybe_attacked_last_states - last_states).abs().max()
                attack_eps = float(self.params.ROBUST_PPO_EPS) if self.params.ATTACK_EPS == "same" else float(self.params.ATTACK_EPS)
                if max_eps > attack_eps + 1e-5:
                    raise RuntimeError(f"{max_eps} > {attack_eps}. Attack implementation has bug and eps is not correctly handled.")

            next_not_dones = np.logical_not(next_dones)
            last_states_oppo = tensor_maker(next_states_two[:, 1-self.victim_id, :])
            data = list(map(tensor_maker, [next_states_two[:, self.victim_id, :], next_rewards[:, self.victim_id], next_not_dones]))
            next_states, next_rewards, next_not_dones = data
            done_info = []

            for not_done, info in zip(next_not_dones, infos):
                if not not_done:
                    done_info.append(info[self.victim_id]['done'])
                    if 'winner' in info[self.victim_id]:
                        outcomes.append(self.victim_id)
                    elif 'loser' in info[self.victim_id]:
                        outcomes.append(1-self.victim_id)
                    else:
                        outcomes.append(None)
            # Reset the (possibly recurrent) policies when any actor finished.
            if torch.numel(next_not_dones) - next_not_dones.count_nonzero() > 0:
                self.policy_model.reset()
                self.val_model.reset()
                self.attacker_policy_model.reset()
                self.attacker_val_model.reset()

            # Update histories
            # each shape: (nact, t, ...) -> (nact, t + 1, ...)
            pairs = [
                (rewards, next_rewards),
                (actions, next_actions), # The actions taken.
                (action_means, next_action_means), # The policy means.
                (states, next_states),
            ]

            last_states = next_states
            for total, v in pairs:
                if total is states:
                    # Next states, stores in the next position.
                    total[:, t+1] = v
                else:
                    # The current action taken, and reward received.
                    total[:, t] = v

            if len(done_info) > 0:
                completed_episode_info.extend(done_info)

        c = Counter(outcomes)
        num_games = len(completed_episode_info)
        if num_games > 0:
            win_0 = c.get(0,0) / num_games
            win_1 = c.get(1,0) / num_games
            tie = c.get(None,0) / num_games
        else:
            # No game finished within max_len: report NaN instead of dividing by zero.
            win_0 = win_1 = tie = float('nan')
        print('*********************************************')
        print('game_win0', win_0)
        print('game_win1', win_1)
        print('game_tie', tie)
        print('total_games', num_games)
        print('*********************************************')
        print()

        episode_infos = np.array(list(zip(*completed_episode_info)))
        if episode_infos.size > 0:
            ep_length, ep_reward_0, ep_reward_1 = np.mean(episode_infos, axis=1)
        else:
            # All three values are returned, so all three must be defined.
            ep_length = ep_reward_0 = ep_reward_1 = np.nan

        actions = actions[0][:t+1]
        action_means = action_means[0][:t+1]
        states = states[0][:t+1]

        return (ep_length, ep_reward_0, ep_reward_1, actions, action_means, states, win_0, win_1, tie)

    def load_venv_rms(self, path0, path1, step0=None, step1=None):
        '''
        Load saved observation/return normalization statistics of both agents
        into ``self.envs`` and switch the env out of training mode.

        For agent i (0 or 1), reads ``{path_i}/obs_rms_i_step{step_i}.pkl`` and
        ``{path_i}/ret_rms_i_step{step_i}.pkl`` and stores them as
        ``self.envs.obs_rms_i`` / ``self.envs.ret_rms_i``. A step of None (the
        default), '' or 'none' skips loading for that agent.
        '''
        # `pickle` is not among the imports at the top of this file; import it
        # locally rather than relying on a star import.
        # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
        import pickle

        def _load(path, agent, step):
            # Load obs_rms_<agent> and ret_rms_<agent> saved at `step` from `path`.
            for stat in ('obs_rms', 'ret_rms'):
                name = f"{stat}_{agent}"
                rms_name = f"{name}_step{step}"
                with open("{}/{}.pkl".format(path, rms_name), 'rb') as file_handler:
                    setattr(self.envs, name, pickle.load(file_handler))

        # use "none" in run_attack.sh; None is the default and also means "skip".
        skip = (None, '', 'none')
        if step0 not in skip:
            _load(path0, 0, step0)
            print("Loading running mean statistics of agent 0 at: ", path0)
        if step1 not in skip:
            _load(path1, 1, step1)
            print("Loading running mean statistics of agent 1 at: ", path1)
        self.envs.training = False

    def load_adv_rms(self, path, attacker_id):
        '''
        Load the attacker's observation/return normalization statistics.

        Reads ``{path}/obs_rms.pkl`` and ``{path}/ret_rms.pkl`` and stores them on
        ``self.envs`` as ``obs_rms_{attacker_id}`` and ``ret_rms_{attacker_id}``.
        '''
        # `pickle` is not among the imports at the top of this file; import it
        # locally rather than relying on a star import.
        # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
        import pickle
        for stat in ('obs_rms', 'ret_rms'):
            with open("{}/{}.pkl".format(path, stat), 'rb') as file_handler:
                setattr(self.envs, '%s_%s' % (stat, attacker_id), pickle.load(file_handler))
        


    @staticmethod
    def agent_from_data(store, row, cpu, extra_params=None, override_params=None, excluded_params=None):
        '''
        Initializes an agent from serialized data (via cox)
        Inputs:
        - store, the name of the store where everything is logged
        - row, the exact row containing the desired data for this agent
        - cpu, True/False whether to use the CPU (otherwise sends to GPU)
        - extra_params, a dictionary of extra agent parameters. Only used
          when a key does not exist from the loaded cox store.
        - override_params, a dictionary of agent parameters that will override
          current agent parameters. None values are ignored.
        - excluded_params, a collection of parameter names that are neither
          added nor overridden (None means nothing is excluded).
        Outputs:
        - agent, a constructed agent with the desired initialization and
              parameters
        - agent_params, the parameters that the agent was constructed with
        '''

        ckpts = store['final_results']

        def get_item(key):
            # row[key] is presumably a one-element column (e.g. a pandas Series); take its value.
            return list(row[key])[0]

        items = ['val_model', 'policy_model', 'val_opt', 'policy_opt']
        names = {i: get_item(i) for i in items}

        metadata = store['metadata'].df
        param_keys = list(metadata.columns)
        param_values = list(metadata.iloc[0,:])

        def process_item(v):
            # Unwrap scalar containers (numpy/torch) via .item(); keep anything else as-is.
            try:
                return v.item()
            except Exception:
                return v

        agent_params = {k: process_item(v) for k, v in zip(param_keys, param_values)}

        agent_params.setdefault('adam_eps', 1e-5)
        agent_params.setdefault('cpu', cpu)

        # `in None` would raise TypeError, so treat a missing exclusion list as empty.
        excluded = excluded_params if excluded_params is not None else ()

        # Update extra params if they do not exist in current parameters.
        if extra_params is not None:
            for k, v in extra_params.items():
                if k not in agent_params and k not in excluded:
                    print(f'adding key {k}={v}')
                    agent_params[k] = v
        if override_params is not None:
            for k, new in override_params.items():
                if k in excluded or new is None:
                    continue
                # .get(): an override for a key absent from the store used to raise KeyError.
                old = agent_params.get(k)
                if new != old:
                    print(f'overwriting key {k}: old={old}, new={new}')
                    agent_params[k] = new

        agent = Trainer.agent_from_params(agent_params)

        mapper = ch.device('cpu') if cpu else ch.device('cuda:0')

        def load_state_dict(model, ckpt_name):
            # Load the named checkpoint from the store onto the chosen device.
            state_dict = ckpts.get_state_dict(ckpt_name, map_location=mapper)
            model.load_state_dict(state_dict)

        load_state_dict(agent.policy_model, names['policy_model'])
        load_state_dict(agent.val_model, names['val_model'])
        if agent.ANNEAL_LR:
            agent.POLICY_SCHEDULER.last_epoch = get_item('iteration')
            agent.VALUE_SCHEDULER.last_epoch = get_item('iteration')
        load_state_dict(agent.POLICY_ADAM, names['policy_opt'])
        load_state_dict(agent.val_opt, names['val_opt'])
        # NOTE: the env is restored from a pickle in the store; only load trusted stores.
        agent.envs = ckpts.get_pickle(get_item('envs'))

        return agent, agent_params

    @staticmethod
    def agent_from_params(params, store=None):
        '''
        Build a Trainer (PPO/TRPO agent) from a hyperparameter dictionary.

        The policy and value network classes are looked up by the names in
        ``params['policy_net_type']`` and ``params['value_net_type']``. Logging
        options are only honoured when a store is supplied.
        Inputs:
        - params, dictionary of required hyperparameters
        - store, a cox.Store object if logging is enabled
        Outputs:
        - A Trainer object for training a PPO/TRPO agent
        '''
        # NOTE: an earlier variant picked LSTM policy/value networks when
        # history_length > 0; network classes now come only from the names below.
        policy_cls = policy_net_with_name(params['policy_net_type'])
        value_cls = value_net_with_name(params['value_net_type'])

        has_store = store is not None
        use_advanced_logging = params['advanced_logging'] and has_store
        logging_interval = 0
        if has_store:
            logging_interval = params['log_every']

        # Restrict torch to a single intra-op thread for CPU runs.
        if params['cpu']:
            torch.set_num_threads(1)

        return Trainer(
            policy_cls,
            value_cls,
            params,
            store,
            log_every=logging_interval,
            advanced_logging=use_advanced_logging,
        )

