import os
from re import L
import numpy as np
from PIL import Image
from gym.spaces.discrete import Discrete
from gym.spaces.box import Box as Continuous
import gym
import random

from pytest import param
from .torch_utils import ZFilter, Identity, StateWithTime, RewardFilter, get_zoo_path, load_from_file, state_dict_from_numpy, get_rms_vars
from .models import CtsPolicy
import torch as ch
from .scheduling import ConditionalAnnealer, ConstantAnnealer, LinearAnnealer, Scheduler
import gym_compete
from stable_baselines3.common.running_mean_std import RunningMeanStd
from gym import Wrapper
import pickle

import pickle
import warnings
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union

from stable_baselines3.common import utils
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper, VecEnvIndices



class Env:
    '''
    A wrapper around the OpenAI gym environment that adds support for the following:
    - Rewards normalization
    - State normalization
    - Adding timestep as a feature with a particular horizon T
    - Optional on-screen rendering and saving of rendered frames to disk
    Also provides utility functions/properties for:
    - Whether the env is discrete or continuous
    - Size of feature space
    - Size of action space
    Provides the same API (init, step, reset) as the OpenAI gym
    '''

    # A frame is saved whenever the per-episode step counter is a multiple of
    # this value; 1 means every step is saved.
    FRAME_SKIP = 1

    def __init__(self, game, norm_states, norm_rewards, params, add_t_with_horizon=None, clip_obs=None, clip_rew=None, 
            show_env=False, save_frames=False, save_frames_path=""):
        '''
        Build the gym environment `game` and the state/reward filters.

        - norm_states: truthy to normalize states with a ZFilter
        - norm_rewards: "rewards" (ZFilter without centering), "returns"
          (RewardFilter using params.GAMMA), anything else disables it
        - add_t_with_horizon: if not None, wraps the state filter in
          StateWithTime with this horizon
        - clip_obs / clip_rew: clip value handed to the filters; None or a
          negative number disables clipping
        - show_env / save_frames / save_frames_path: see setup_visualization
        '''
        self.env = gym.make(game)
        clip_obs = self._clip_or_none(clip_obs)
        clip_rew = self._clip_or_none(clip_rew)

        # Environment type
        self.is_discrete = isinstance(self.env.action_space, Discrete)
        assert self.is_discrete or isinstance(self.env.action_space, Continuous)

        # Number of actions
        action_shape = self.env.action_space.shape
        assert len(action_shape) <= 1 # scalar or vector actions
        if self.is_discrete:
            self.num_actions = self.env.action_space.n
        elif len(action_shape) == 0:
            self.num_actions = 0
        else:
            self.num_actions = action_shape[0]

        # Number of features
        assert len(self.env.observation_space.shape) == 1
        self.num_features = self.env.reset().shape[0]

        # Support for state normalization or using time as a feature
        self.state_filter = Identity()
        if norm_states:
            self.state_filter = ZFilter(self.state_filter, shape=[self.num_features],
                                        clip=clip_obs)
        if add_t_with_horizon is not None:
            self.state_filter = StateWithTime(self.state_filter, horizon=add_t_with_horizon)

        # Support for rewards normalization
        self.reward_filter = Identity()
        if norm_rewards == "rewards":
            self.reward_filter = ZFilter(self.reward_filter, shape=(), center=False, clip=clip_rew)
        elif norm_rewards == "returns":
            self.reward_filter = RewardFilter(self.reward_filter, shape=(), gamma=params.GAMMA, clip=clip_rew)

        # Running total reward and step counter (both set to 0.0 at resets).
        # The counter is created here as well so that step() does not fail
        # with an AttributeError if it is called before the first reset().
        self.total_true_reward = 0.0
        self.counter = 0.0

        # Set normalizers to read-write mode by default.
        self._read_only = False

        self.setup_visualization(show_env, save_frames, save_frames_path)

    @staticmethod
    def _clip_or_none(clip):
        '''
        Map "no clipping" requests (None or a negative number) to None and
        return any other clip value unchanged.
        '''
        if clip is None or clip < 0:
            return None
        return clip

    # Kept separate from __init__ so it can also be called on environments
    # that are created from a pickled object.
    def setup_visualization(self, show_env, save_frames, save_frames_path):
        '''
        Store the rendering options and reset the episode/frame counters.
        When save_frames is set, the directory <save_frames_path>/000 is created.
        '''
        self.save_frames = save_frames
        self.show_env = show_env
        self.save_frames_path = save_frames_path
        self.episode_counter = 0
        self.frame_counter = 0
        if self.save_frames:
            print(f'We will save frames to {self.save_frames_path}!')
            os.makedirs(os.path.join(self.save_frames_path, "000"), exist_ok=True)

    def _propagate_read_only(self, filt, label, requested):
        '''
        Copy the read-only flag onto `filt` if it exposes `read_only`;
        otherwise warn when read-only mode was requested.
        '''
        if hasattr(filt, 'read_only'):
            filt.read_only = self._read_only
        elif requested:
            print(f'Warning: requested to set {label}.read_only=True but the underlying ZFilter does not support it.')

    @property
    def normalizer_read_only(self):
        '''Whether the state/reward normalizers are in read-only mode.'''
        return self._read_only

    @normalizer_read_only.setter
    def normalizer_read_only(self, value):
        self._read_only = bool(value)
        if isinstance(self.state_filter, ZFilter):
            self._propagate_read_only(self.state_filter, 'state_filter', value)
        if isinstance(self.reward_filter, (ZFilter, RewardFilter)):
            self._propagate_read_only(self.reward_filter, 'reward_filter', value)

    def reset(self):
        '''
        Start a new episode: reseed and reset the env, clear the per-episode
        counters and filters, and return the filtered start state.
        '''
        # Draw a fresh seed from Python's global RNG; runs are only
        # reproducible if that RNG has itself been seeded by the caller.
        self.env.seed(random.getrandbits(31))
        # Reset the state, and the running total reward
        start_state = self.env.reset()
        self.total_true_reward = 0.0
        self.counter = 0.0
        self.episode_counter += 1
        if self.save_frames:
            # One sub-directory of frames per episode.
            os.makedirs(os.path.join(self.save_frames_path, f"{self.episode_counter:03d}"), exist_ok=True)
            self.frame_counter = 0
        self.state_filter.reset()
        self.reward_filter.reset()
        return self.state_filter(start_state, reset=True)

    def step(self, action):
        '''
        Step the env with `action` and return
        (filtered state, filtered reward, is_done, info).
        On the final step, info['done'] holds (step count, unfiltered total reward).
        '''
        state, reward, is_done, info = self.env.step(action)
        if self.show_env:
            self.env.render()
        if self.save_frames and int(self.counter) % self.FRAME_SKIP == 0:
            image = self.env.render(mode='rgb_array')
            path = os.path.join(self.save_frames_path, f"{self.episode_counter:03d}", f"{self.frame_counter+1:04d}.bmp")
            Image.fromarray(image).save(path)
            self.frame_counter += 1
        state = self.state_filter(state)
        # The running total uses the raw reward, not the normalized one.
        self.total_true_reward += reward
        self.counter += 1
        _reward = self.reward_filter(reward)
        if is_done:
            info['done'] = (self.counter, self.total_true_reward)
        return state, _reward, is_done, info


class Multi2Single_Env(Wrapper):
    '''
    Exposes a two-agent gym environment as a single-agent environment.

    One agent is driven by a fixed "zoo" policy created with make_zoo_agent;
    the caller controls the other agent through the usual reset/step API.
    - Observations are normalized with statistics loaded from the zoo
      parameter files (they are not updated online).
    - The reward returned to the caller is the shaped reward computed by
      apply_reward_shapping from the caller's info dict.
    Also provides:
    - is_discrete (always False), num_features, num_actions
    '''

    # Clip range applied to the first observation of an episode.
    # NOTE(review): reset() has always used this fixed value while step()
    # uses clip_obs; kept as-is to preserve behaviour.
    _RESET_CLIP = 5

    def __init__(self, env, game, params, total_step, add_t_with_horizon=None, clip_obs=None, clip_rew=None, 
            tag=1, version=1, agent_idx=0, epsilon=1e-8):
        '''
        - env: two-agent env whose observation/action spaces expose `.spaces`
        - game: environment name, used for zoo lookup and reward shaping
        - params: must provide VAL_LR and GAMMA
        - total_step: horizon used to compute the annealing fraction in step()
        - add_t_with_horizon: accepted for signature compatibility; unused here
        - clip_obs / clip_rew: None or a negative number disables clipping
        - tag, version: identify the zoo policy to load
        - agent_idx: index of the agent that is NOT being trained
        - epsilon: stored on the instance; not used by this class itself
        '''
        Wrapper.__init__(self, env)
        self.env = env
        zoo_agent = make_zoo_agent(game, self.env.observation_space.spaces[1], self.env.action_space.spaces[1],
                               tag=tag, version=version)
        self.agent = zoo_agent
        # None is the default for both clip values, so it must be handled
        # before comparing against 0.
        self.clip_obs = None if clip_obs is None or clip_obs < 0 else clip_obs
        self.clip_rew = None if clip_rew is None or clip_rew < 0 else clip_rew
        self.epsilon = epsilon

        self.observation_space = self.env.observation_space.spaces[0]
        # action dimensionality
        self.action_space = self.env.action_space.spaces[0]

        self.counter = 0.0
        self.agent_idx = agent_idx # not trained agent index
        self.total_step = total_step
        self.scheduler = Scheduler(annealer_dict={'lr': ConstantAnnealer(params.VAL_LR)})
        self.gamma = params.GAMMA
        self.env_name = game
        self.is_discrete = False

        # Reward shaping configuration consumed by apply_reward_shapping.
        # shaping_params stays None for games without a known configuration;
        # step() reports that explicitly.
        self.rew_types = {'sparse', 'dense'}
        self.shaping_params = self._default_shaping_params(game)

        # Number of features
        self.num_features = self.env.observation_space.spaces[1].shape[0]
        self.num_actions = self.env.action_space.spaces[1].shape[0]

        # Support for state normalization and rewards normalization.
        # obs_rms / ret_rms belong to the fixed zoo agent, the policy_* pair
        # to the agent controlled by the caller.
        self.obs_rms = RunningMeanStd(shape=self.env.observation_space.spaces[1].shape)
        self.policy_obs_rms = RunningMeanStd(shape=self.env.observation_space.spaces[1].shape)
        self.ret_rms = RunningMeanStd(shape=())
        self.policy_ret_rms = RunningMeanStd(shape=())

        # Running total reward (set to 0.0 at resets)
        self.total_true_reward = 0.0

        # Set normalizers to read-write mode by default.
        self._read_only = False

        # Load the running statistics stored with the zoo parameters of both
        # agents (tags 1 and 2).
        for zoo_tag in range(1, 3):
            agent_path = get_zoo_path(self.env_name, tag=zoo_tag, version=version)
            zoo_params = load_from_file(param_pkl_path=agent_path)
            obs_mean, obs_var, ret_mean, ret_var = get_rms_vars(zoo_params, ob_shapes=self.num_features)
            if zoo_tag == self.agent_idx + 1:
                self.set_envagent_rms(obs_mean, obs_var, ret_mean, ret_var)
            else:
                self.set_policy_rms(obs_mean, obs_var, ret_mean, ret_var)

    @staticmethod
    def _default_shaping_params(game):
        '''Return the reward shaping configuration for `game`, or None if unknown.'''
        if game in ("multicomp/YouShallNotPassHumans-v0", "multicomp/SumoHumans-v0",
                    "multicomp/RunToGoalAnts-v0", "multicomp/RunToGoalHumans-v0"):
            return {'weights': {'dense': {'reward_move': 0.1}, 'sparse': {'reward_remaining': 0.01}},
                    'anneal_frac': 0}
        if game == "multicomp/SumoAnts-v0":
            return {'weights': {'dense': {'reward_move': 1}, 'sparse': {'reward_remaining': 0.01}},
                    'anneal_frac': 0.1}
        if game == "multicomp/KickAndDefend-v0":
            return {'weights': {'dense': {'reward_move': 0.5, 'reward_contact': 1, 'reward_survive': 0.5},
                                'sparse': {'reward_remaining': 0.01}},
                    'anneal_frac': 0.01, 'anneal_type': 0}
        return None

    @staticmethod
    def _normalize(ob, rms, clip):
        '''
        Standardize `ob` with the mean/variance in `rms` (variance floored at
        1e-2) and clip to [-clip, clip]; clip=None skips the clipping.
        '''
        ob = (ob - rms.mean) / np.sqrt(np.maximum(rms.var, 1e-2))
        if clip is None:
            return ob
        return np.clip(ob, -clip, clip)

    def set_policy_rms(self, ob_mean, ob_var, ret_mean, ret_var):
        '''Overwrite the statistics used for the caller-controlled agent.'''
        self.policy_obs_rms.mean = ob_mean
        self.policy_obs_rms.var = ob_var
        self.policy_ret_rms.mean = ret_mean
        self.policy_ret_rms.var = ret_var

    def set_envagent_rms(self, ob_mean, ob_var, ret_mean, ret_var):
        '''Overwrite the statistics used for the fixed zoo agent.'''
        self.obs_rms.mean = ob_mean
        self.obs_rms.var = ob_var
        self.ret_rms.mean = ret_mean
        self.ret_rms.var = ret_var

    @property
    def normalizer_read_only(self):
        '''Whether the normalizers are in read-only mode.'''
        return self._read_only

    @normalizer_read_only.setter
    def normalizer_read_only(self, value):
        self._read_only = bool(value)
        # This class does not create state_filter / reward_filter itself, so
        # look them up defensively instead of raising AttributeError.
        state_filter = getattr(self, 'state_filter', None)
        if isinstance(state_filter, ZFilter):
            if hasattr(state_filter, 'read_only'):
                state_filter.read_only = self._read_only
            elif value:
                print('Warning: requested to set state_filter.read_only=True but the underlying ZFilter does not support it.')
        reward_filter = getattr(self, 'reward_filter', None)
        if isinstance(reward_filter, (ZFilter, RewardFilter)):
            if hasattr(reward_filter, 'read_only'):
                reward_filter.read_only = self._read_only
            elif value:
                print('Warning: requested to set reward_filter.read_only=True but the underlying ZFilter does not support it.')

    def reset(self):
        '''
        Reset the wrapped env and the per-episode bookkeeping.
        Returns the normalized first observation of the caller-controlled agent;
        the zoo agent's raw observation is kept in self.ob.
        '''
        # Reset the state, and the running total reward
        self.total_true_reward = 0.0
        self.counter = 0.0
        self.reward = 0
        self.done = False

        if self.agent_idx == 1:
            ob, self.ob = self.env.reset()
        else:
            self.ob, ob = self.env.reset()

        return self._normalize(ob, self.policy_obs_rms, self._RESET_CLIP)

    def step(self, action):
        '''
        Let the zoo agent act on its last observation, step the wrapped env
        with both actions, and return (ob, reward, done, info) for the
        caller-controlled agent. `reward` is the shaped reward; on the final
        step info['done'] holds (step count, unshaped total reward) and
        info['loser'] is set when the zoo agent's info contains 'winner'.
        '''
        shaping_params = self.shaping_params
        if shaping_params is None:
            raise NotImplementedError(
                'No reward shaping parameters are defined for %s' % self.env_name)

        self.counter += 1
        self.oppo_ob = self._normalize(self.ob.copy(), self.obs_rms, self.clip_obs)
        self_action = self.agent.act(observation=self.oppo_ob)

        if ch.is_tensor(self_action):
            self_action = self_action.cpu().detach().numpy()[0]

        # combine agents' actions
        if self.agent_idx == 0:
            actions = (self_action, action)
        else:
            actions = (action, self_action)

        # obtain needed information from the environment.
        obs, rewards, dones, infos = self.env.step(actions)

        if dones and 'Ant' in self.env_name:
            # In the Ant games a remaining reward of exactly 0 at the end of
            # the episode is replaced by a -1000 penalty.
            if infos[0]['reward_remaining']==0:
                infos[0]['reward_remaining'] = -1000
            if infos[1]['reward_remaining']==0:
                infos[1]['reward_remaining'] = -1000

        # separate the zoo agent's (self.*) and the caller's information.
        if self.agent_idx == 0:
            self.ob, ob = obs
            self.reward, reward = rewards
            self.info, info = infos.values()
        else:
            ob, self.ob = obs
            reward, self.reward = rewards
            info, self.info = infos.values()
        done = dones
        self.done = dones

        self.total_true_reward += reward
        # Replace the raw rewards by the shaped ones.
        frac_remaining = max(1 - self.counter / self.total_step, 0)
        self.oppo_reward = apply_reward_shapping(self.info, shaping_params, self.rew_types,
                                                 self.scheduler, frac_remaining)
        reward = apply_reward_shapping(info, shaping_params, self.rew_types,
                                       self.scheduler, frac_remaining)

        ob = self._normalize(ob, self.policy_obs_rms, self.clip_obs)

        if done:
            if 'winner' in self.info: # opponent (the agent that is not being trained) win.
                info['loser'] = True
            info['done'] = (self.counter, self.total_true_reward)

        return ob, reward, done, info

class MultiAgent_Env(Wrapper):
    '''
    Wrapper around a two-agent gym environment that keeps both agents visible.

    - step() replaces the raw rewards with shaped rewards computed by
      apply_reward_shapping from each agent's info dict.
    - Observations are passed through without normalization.
    - Tracks the unshaped total reward of each agent per episode.
    '''

    def __init__(self, env, game, params, total_step, clip_obs=None, clip_rew=None, epsilon=1e-8):
        '''
        - env: two-agent env whose observation/action spaces expose `.spaces`
        - game: environment name; selects the reward shaping parameters
        - params: must provide VAL_LR and GAMMA
        - total_step: horizon used to compute the annealing fraction in step()
        - clip_obs / clip_rew: None or a negative number disables clipping
        - epsilon: stored on the instance; not used by this class itself

        Raises NotImplementedError for games without shaping parameters.
        '''
        Wrapper.__init__(self, env)
        self.env = env
        # None is the default for both clip values, so it must be handled
        # before comparing against 0.
        self.clip_obs = None if clip_obs is None or clip_obs < 0 else clip_obs
        self.clip_rew = None if clip_rew is None or clip_rew < 0 else clip_rew
        self.epsilon = epsilon

        self.observation_space = self.env.observation_space.spaces[0]
        self.action_space = self.env.action_space.spaces[0]

        self.counter = 0.0
        self.total_step = total_step
        self.scheduler = Scheduler(annealer_dict={'lr': ConstantAnnealer(params.VAL_LR)})
        self.gamma = params.GAMMA
        self.env_name = game
        self.is_discrete = False

        # Number of features
        self.num_features = self.observation_space.shape[0]
        self.num_actions = self.action_space.shape[0]

        # Running total reward of each agent (set to 0.0 at resets)
        self.total_true_reward_0 = 0.0
        self.total_true_reward_1 = 0.0

        # Set normalizers to read-write mode by default.
        self._read_only = False
        self.rew_types = {'sparse', 'dense'}

        if self.env_name in ("multicomp/YouShallNotPassHumans-v0", "multicomp/SumoHumans-v0"):
            self.shaping_params = {'weights': {'dense': {'reward_move': 0.1}, 'sparse': {'reward_remaining': 0.01}},
                        'anneal_frac': 0}

        elif self.env_name == "multicomp/SumoAnts-v0":
            self.shaping_params = {'weights': {'dense': {'reward_move': 1}, 'sparse': {'reward_remaining': 0.01}},
                        'anneal_frac': 0.1}

        elif self.env_name == "multicomp/KickAndDefend-v0":
            self.shaping_params = {'weights': {'dense': {'reward_move': 0.5, 'reward_contact': 1, 'reward_survive': 0.5,},
                                'sparse': {'reward_remaining': 0.01}},
                        'anneal_frac': 0.01, 'anneal_type': 0}

        elif self.env_name in ("multicomp/RunToGoalAnts-v0", "multicomp/RunToGoalHumans-v0"):
            self.shaping_params = {'weights': {'dense': {'reward_move': 0.1}, 'sparse': {'reward_remaining': 0.01}},
                        'anneal_frac': 0}
        else:
            raise NotImplementedError

    @property
    def normalizer_read_only(self):
        '''Whether the normalizers are in read-only mode.'''
        return self._read_only

    @normalizer_read_only.setter
    def normalizer_read_only(self, value):
        self._read_only = bool(value)
        # This class does not create state_filter / reward_filter itself, so
        # look them up defensively instead of raising AttributeError.
        state_filter = getattr(self, 'state_filter', None)
        if isinstance(state_filter, ZFilter):
            if hasattr(state_filter, 'read_only'):
                state_filter.read_only = self._read_only
            elif value:
                print('Warning: requested to set state_filter.read_only=True but the underlying ZFilter does not support it.')
        reward_filter = getattr(self, 'reward_filter', None)
        if isinstance(reward_filter, (ZFilter, RewardFilter)):
            if hasattr(reward_filter, 'read_only'):
                reward_filter.read_only = self._read_only
            elif value:
                print('Warning: requested to set reward_filter.read_only=True but the underlying ZFilter does not support it.')

    def reset(self):
        '''
        Reset the wrapped env and the per-episode bookkeeping.
        Returns the (unnormalized) observations as a list, one entry per agent.
        '''
        self.total_true_reward_0 = 0.0
        self.total_true_reward_1 = 0.0
        self.counter = 0.0
        self.reward = 0
        self.done = False

        obs = self.env.reset()
        self.before_norm_obs = obs
        # No per-agent normalization is applied here; the observations are
        # only repackaged into a list.
        return list(obs)

    def step(self, actions):
        '''
        Step the wrapped env with the joint `actions`.
        Returns (obs, rewards, dones, infos) where `rewards` is the pair of
        shaped rewards. On the final step each info dict gets a 'done' entry
        (step count, total raw reward of agent 0, of agent 1) and the
        opponent of a 'winner' gets 'loser' = True.
        '''
        if ch.is_tensor(actions):
            actions = actions.cpu().detach().numpy()[0]

        self.counter += 1
        self.before_norm_obs, rewards, dones, infos = self.env.step(actions)
        obs = self.before_norm_obs

        if dones and 'Ant' in self.env_name:
            # In the Ant games a remaining reward of exactly 0 at the end of
            # the episode is replaced by a -1000 penalty.
            if infos[0]['reward_remaining']==0:
                infos[0]['reward_remaining'] = -1000
            if infos[1]['reward_remaining']==0:
                infos[1]['reward_remaining'] = -1000

        # Accumulate the raw rewards before they are replaced by shaped ones.
        reward_0, reward_1 = rewards
        self.total_true_reward_0 += reward_0
        self.total_true_reward_1 += reward_1
        info_0, info_1 = infos.values()

        frac_remaining = max(1 - self.counter / self.total_step, 0)
        reward_0 = apply_reward_shapping(info_0, self.shaping_params, self.rew_types, self.scheduler, frac_remaining)
        reward_1 = apply_reward_shapping(info_1, self.shaping_params, self.rew_types, self.scheduler, frac_remaining)
        rewards = (reward_0, reward_1)

        if dones:
            for i in range(2):
                if 'winner' in infos[i]: # opponent win.
                    infos[1-i]['loser'] = True
                infos[i]['done'] = (self.counter, self.total_true_reward_0, self.total_true_reward_1)

        return obs, rewards, dones, infos

    def set_pos_vel(self, data, attacker_id):
        """
        Overwrite (or, for KickAndDefend, perturb) the qpos slice of one agent
        in the simulator state and return the resulting observation.
        qvel is written back unchanged.

        Raises NotImplementedError for unsupported environments.
        """
        qpos = self.env.env_scene.model.data.qpos.flatten().copy()
        qvel = self.env.env_scene.model.data.qvel.flatten().copy()
        if self.env_name in {"multicomp/YouShallNotPassHumans-v0", "multicomp/SumoHumans-v0", "multicomp/RunToGoalHumans-v0"}:
            if attacker_id == 1:
                qpos[24:] = data
            else:
                qpos[:24] = data
        elif self.env_name == 'multicomp/KickAndDefend-v0': # the observation of the agent is relative, so we only add the perturbation.
            if attacker_id == 1:
                qpos[31:55] += data
            else:
                qpos[7:31] += data
        elif self.env_name in {'multicomp/SumoAnts-v0', 'multicomp/RunToGoalAnts-v0'}:
            if attacker_id == 1:
                qpos[15:] = data
            else:
                qpos[:15] = data
        else:
            raise NotImplementedError
        self.env.env_scene.set_state(qpos, qvel)
        observation = self.env._get_obs()

        return observation

    def save_running_average(self, path, step):
        '''
        Pickle the two observation statistics objects to
        <path>/obs_rms_<i>_step<step>.pkl.
        '''
        # NOTE(review): obs_rms_0_dummy / obs_rms_1_dummy are not created
        # anywhere in this class; confirm where they are expected to come from.
        for rms, name in zip([self.obs_rms_0_dummy, self.obs_rms_1_dummy], ['obs_rms_0_step%s' %(step), 'obs_rms_1_step%s' %(step)]):
            with open("{}/{}.pkl".format(path, name), 'wb') as file_handler:
                pickle.dump(rms, file_handler)


class VecNormalize_MultiAgent(VecEnvWrapper):
    """
    A moving average, normalizing wrapper for a vectorized two-agent environment.
    Keeps separate observation and return statistics for agent 0 and agent 1
    and has support for saving the moving averages.

    Batched observations are indexed as ``obs[:, agent, :]`` and batched
    rewards as ``reward[:, agent]``.

    :param venv: the vectorized environment to wrap
    :param training: Whether to update or not the moving average
    :param norm_obs: Whether to normalize observation or not (default: True)
    :param norm_reward: Whether to normalize rewards or not (default: True)
    :param clip_obs: Max absolute value for observation
    :param clip_reward: Max value absolute for discounted reward
    :param gamma: discount factor
    :param epsilon: To avoid division by zero
    :param norm_obs_keys: Which keys from observation dict to normalize.
        If not specified, all keys will be normalized.
    """

    def __init__(
        self,
        venv: VecEnv,
        training: bool = True,
        norm_obs: bool = True,
        norm_reward: bool = True,
        clip_obs: float = 10.0,
        clip_reward: float = 10.0,
        gamma: float = 0.99,
        epsilon: float = 1e-8,
        norm_obs_keys: Optional[List[str]] = None,
    ):
        VecEnvWrapper.__init__(self, venv)

        self.norm_obs = norm_obs
        self.norm_obs_keys = norm_obs_keys
        # Check observation spaces
        if self.norm_obs:
            self._sanity_checks()

        # One set of running statistics per agent.
        self.obs_rms_0 = RunningMeanStd(shape=self.observation_space.shape)
        self.obs_rms_1 = RunningMeanStd(shape=self.observation_space.shape)
        self.ret_rms_0 = RunningMeanStd(shape=())
        self.ret_rms_1 = RunningMeanStd(shape=())

        self.clip_obs = clip_obs
        self.clip_reward = clip_reward
        # Returns: discounted rewards
        self.returns_0 = np.zeros(self.num_envs)
        self.returns_1 = np.zeros(self.num_envs)
        self.gamma = gamma
        self.epsilon = epsilon
        self.training = training
        self.norm_reward = norm_reward
        self.old_obs = np.array([])
        self.old_reward = np.array([])

    def _sanity_checks(self) -> None:
        """
        Check the observations that are going to be normalized are of the correct type (spaces.Box).

        :raises ValueError: for unsupported observation spaces or an
            inconsistent ``norm_obs_keys`` setting.
        """
        if isinstance(self.observation_space, gym.spaces.Dict):
            # By default, we normalize all keys
            if self.norm_obs_keys is None:
                self.norm_obs_keys = list(self.observation_space.spaces.keys())
            # Check that all keys are of type Box
            for obs_key in self.norm_obs_keys:
                if not isinstance(self.observation_space.spaces[obs_key], gym.spaces.Box):
                    raise ValueError(
                        f"VecNormalize only supports `gym.spaces.Box` observation spaces but {obs_key} "
                        f"is of type {self.observation_space.spaces[obs_key]}. "
                        "You should probably explicitely pass the observation keys "
                        " that should be normalized via the `norm_obs_keys` parameter."
                    )

        elif isinstance(self.observation_space, gym.spaces.Box):
            if self.norm_obs_keys is not None:
                raise ValueError("`norm_obs_keys` param is applicable only with `gym.spaces.Dict` observation spaces")

        else:
            raise ValueError(
                "VecNormalize only supports `gym.spaces.Box` and `gym.spaces.Dict` observation spaces, "
                f"not {self.observation_space}"
            )

    def __getstate__(self) -> Dict[str, Any]:
        """
        Gets state for pickling.

        Excludes self.venv, as in general VecEnv's may not be pickleable."""
        state = self.__dict__.copy()
        # these attributes are not pickleable
        del state["venv"]
        del state["class_attributes"]
        # these attributes depend on the above and so we would prefer not to pickle
        del state["returns_0"]
        del state["returns_1"]
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """
        Restores pickled state.

        The wrapped venv is not part of the pickled state and is set to None;
        it has to be re-attached before use. The per-env return buffers are
        not restored either; ``reset()`` recreates them.

        :param state:"""
        # Backward compatibility
        if "norm_obs_keys" not in state and isinstance(state["observation_space"], gym.spaces.Dict):
            state["norm_obs_keys"] = list(state["observation_space"].spaces.keys())
        self.__dict__.update(state)
        assert "venv" not in state
        self.venv = None

    def step_wait(self) -> VecEnvStepReturn:
        """
        Apply sequence of actions to sequence of environments
        actions -> (observations, rewards, dones)

        where ``dones`` is a boolean vector indicating whether each element is new.
        The raw observations and rewards are kept and can be read back with
        ``get_original_obs`` / ``get_original_reward``.
        """
        obs, rewards, dones, infos = self.venv.step_wait()
        self.old_obs = obs
        self.old_reward = rewards
        if self.training and self.norm_obs:
            self.obs_rms_0.update(obs[:,0,:])
            self.obs_rms_1.update(obs[:,1,:])
        obs = self.normalize_obs(obs)

        if self.training:
            self._update_reward(rewards)
        rewards = self.normalize_reward(rewards)

        # NOTE: infos[idx]["terminal_observation"] is intentionally left
        # unnormalized here.

        # Restart the discounted return of every env that just finished.
        self.returns_0[dones] = 0
        self.returns_1[dones] = 0

        return obs, rewards, dones, infos

    def _update_reward(self, reward: np.ndarray) -> None:
        """Update reward normalization statistics."""
        self.returns_0 = self.returns_0 * self.gamma + reward[:,0]
        self.returns_1 = self.returns_1 * self.gamma + reward[:,1]
        self.ret_rms_0.update(self.returns_0)
        self.ret_rms_1.update(self.returns_1)

    def _normalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray:
        """
        Helper to normalize observation.
        The variance is floored at 1e-2 before taking the square root.
        :param obs:
        :param obs_rms: associated statistics
        :return: normalized observation
        """
        return np.clip((obs - obs_rms.mean) / np.sqrt(np.maximum(obs_rms.var, 1e-2)), -self.clip_obs, self.clip_obs)

    def _unnormalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray:
        """
        Helper to unnormalize observation (clipping cannot be undone).
        :param obs:
        :param obs_rms: associated statistics
        :return: unnormalized observation
        """
        return (obs * np.sqrt(np.maximum(obs_rms.var, 1e-2))) + obs_rms.mean

    def normalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """
        Normalize observations using this VecNormalize's observations statistics.
        Calling this method does not update statistics.

        3-d input is indexed as ``obs[:, agent, :]`` and 2-d input as
        ``obs[agent, :]``.

        :raises NotImplementedError: if normalization is enabled and ``obs``
            has any other number of dimensions.
        """
        # Avoid modifying by reference the original object
        obs_ = deepcopy(obs)
        if self.norm_obs:
            if obs_.ndim == 3:
                obs_0 = self._normalize_obs(obs[:,0,:], self.obs_rms_0)
                obs_1 = self._normalize_obs(obs[:,1,:], self.obs_rms_1)
                obs_ = np.stack((obs_0, obs_1), axis=1)
            elif obs_.ndim == 2:
                obs_0 = self._normalize_obs(obs[0,:], self.obs_rms_0)
                obs_1 = self._normalize_obs(obs[1,:], self.obs_rms_1)
                obs_ = np.stack((obs_0, obs_1), axis=0)
            else:
                raise NotImplementedError
        return obs_

    def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
        """
        Normalize rewards using this VecNormalize's rewards statistics.
        Calling this method does not update statistics.
        Returns ``reward`` unchanged when reward normalization is disabled.
        """
        if not self.norm_reward:
            return reward
        reward_0 = np.clip(reward[:,0] / np.sqrt(self.ret_rms_0.var + self.epsilon), -self.clip_reward, self.clip_reward)
        reward_1 = np.clip(reward[:,1] / np.sqrt(self.ret_rms_1.var + self.epsilon), -self.clip_reward, self.clip_reward)
        return np.stack((reward_0, reward_1), axis=1)

    def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """
        Inverse of ``normalize_obs`` (up to clipping). Input with a number of
        dimensions other than 2 or 3 is returned as an unmodified copy.
        """
        # Avoid modifying by reference the original object
        obs_ = deepcopy(obs)
        if self.norm_obs:
            if obs_.ndim == 3:
                obs_0 = self._unnormalize_obs(obs[:,0,:], self.obs_rms_0)
                obs_1 = self._unnormalize_obs(obs[:,1,:], self.obs_rms_1)
                obs_ = np.stack((obs_0, obs_1), axis=1)
            elif obs_.ndim == 2:
                obs_0 = self._unnormalize_obs(obs[0,:], self.obs_rms_0)
                obs_1 = self._unnormalize_obs(obs[1,:], self.obs_rms_1)
                obs_ = np.stack((obs_0, obs_1), axis=0)
        return obs_

    def unnormalize_reward(self, reward: np.ndarray) -> np.ndarray:
        """
        Inverse of ``normalize_reward`` (up to clipping): scales the reward of
        each agent (last axis, size 2) by that agent's return std.
        """
        if self.norm_reward:
            # One scale per agent; broadcasts over the last axis of `reward`.
            scale = np.sqrt(np.array([self.ret_rms_0.var, self.ret_rms_1.var]) + self.epsilon)
            return reward * scale
        return reward

    def get_original_obs(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """
        Returns an unnormalized version of the observations from the most recent
        step or reset.
        """
        return deepcopy(self.old_obs)

    def get_original_reward(self) -> np.ndarray:
        """
        Returns an unnormalized version of the rewards from the most recent step.
        """
        return self.old_reward.copy()

    def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """
        Reset all environments
        :return: first observation of the episode
        """
        obs = self.venv.reset()
        self.old_obs = obs
        self.returns_0 = np.zeros(self.num_envs)
        self.returns_1 = np.zeros(self.num_envs)
        if self.training and self.norm_obs:
            self.obs_rms_0.update(obs[:,0,:])
            self.obs_rms_1.update(obs[:,1,:])
        return self.normalize_obs(obs)

    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
        """
        Call ``method_name`` on the wrapped envs and interpret the result of
        the FIRST env as an observation.

        :return: tuple ``(raw_obs, normalized_obs)``; the second element is a
            plain copy of the first when observation normalization is off.
            The running statistics are not updated.
        """
        unnormalize_obs = super().env_method(method_name, *method_args, indices=indices, **method_kwargs)[0]
        unnormalize_obs = np.asarray(unnormalize_obs)
        # normalize_obs() returns an unmodified copy when norm_obs is False,
        # so the second element is always defined.
        normalized_obs = self.normalize_obs(unnormalize_obs)
        return unnormalize_obs, normalized_obs

    def save_running_average(self, path, step):
        """
        Pickle the four statistics objects to
        ``<path>/{obs,ret}_rms_{0,1}_step<step>.pkl``.
        """
        for rms, name in zip([self.obs_rms_0, self.ret_rms_0], ['obs_rms_0_step%s' %(step), 'ret_rms_0_step%s' %(step)]):
            with open("{}/{}.pkl".format(path, name), 'wb') as file_handler:
                pickle.dump(rms, file_handler)
        for rms, name in zip([self.obs_rms_1, self.ret_rms_1], ['obs_rms_1_step%s' %(step), 'ret_rms_1_step%s' %(step)]):
            with open("{}/{}.pkl".format(path, name), 'wb') as file_handler:
                pickle.dump(rms, file_handler)



def make_zoo_agent(env_name, ob_space, action_space, tag=2, version=1, scope=""):
    """Factory for ZooAgent: forwards every argument to its constructor unchanged."""
    zoo_agent = ZooAgent(
        env_name,
        ob_space,
        action_space,
        tag,
        version,
        scope,
    )
    return zoo_agent

class ZooAgent(object):
    """
    Fixed policy loaded from the agent zoo.

    For the supported environments a CtsPolicy is built and its weights are
    loaded from the parameter file returned by get_zoo_path. For any other
    environment the object is still constructed, but `agent` is None and
    using the policy raises NotImplementedError.
    """

    # Environments for which a policy is built and loaded.
    _SUPPORTED_ENVS = (
        'multicomp/YouShallNotPassHumans-v0',
        "multicomp/RunToGoalAnts-v0",
        "multicomp/RunToGoalHumans-v0",
    )

    def __init__(self, env_name, ob_space, action_space, tag, version, scope):
        """
        :param env_name: name of the multi-agent environment
        :param ob_space: observation space of the agent (only `.shape[0]` is read)
        :param action_space: action space of the agent (only `.shape[0]` is read)
        :param tag: 1-based index of the zoo agent to load
        :param version: zoo version, forwarded to get_zoo_path
        :param scope: unused; kept for interface compatibility
        """
        self.agent = None
        self.vic_id = tag - 1
        self._env_name = env_name
        if env_name in self._SUPPORTED_ENVS:
            self.agent = CtsPolicy(state_dim=ob_space.shape[0], action_dim=action_space.shape[0],
                                   init='orthogonal', activation='tanh')
            # Forward `version` so the weights come from the same zoo entry
            # the caller asked for.
            env_path = get_zoo_path(env_name, tag=tag, version=version)
            zoo_params = load_from_file(param_pkl_path=env_path)
            # Total number of scalar parameters in the policy.
            shapes = sum(map(ch.numel, self.agent.parameters()))
            self_agent_pdict = state_dict_from_numpy(self.agent.state_dict(), zoo_params, shapes)
            self.agent.load_state_dict(self_agent_pdict)

    def _require_agent(self):
        """Return the loaded policy or fail with an explicit error."""
        if self.agent is None:
            raise NotImplementedError(
                'No zoo policy is available for environment %r' % (self._env_name,))
        return self.agent

    def reset(self):
        """Reset the underlying policy and return its result."""
        return self._require_agent().reset()

    # return the needed state

    def get_state(self):
        """Return the `state` attribute of the underlying policy."""
        return self._require_agent().state

    def act(self, observation):
        """Sample an action from the policy's output for `observation`."""
        policy = self._require_agent()
        action_pds = policy(observation)
        next_actions = policy.sample(action_pds)
        return next_actions



def apply_reward_shapping(infos, shaping_params, rew_types, scheduler, frac_remaining):
    """ Compute the shaped reward of one agent from its info dict.
    :param: infos: mapping from reward term name to value, as returned by the environment.
    :param: shaping_params: reward shaping parameters; must contain 'weights'
            (reward type -> {reward term -> weight}).
    :param: rew_types: set of reward types; must equal the keys of 'weights'.
    :param: scheduler: scheduler on which the 'rew_shape' annealer is registered.
    :param: frac_remaining: fraction of training remaining, passed to the annealer.
    :return: shaped reward.
    """
    # Choose the annealer that mixes the reward types.
    if 'metric' not in shaping_params:
        frac = shaping_params.get('anneal_frac')
        if shaping_params.get('anneal_type') == 0:
            annealer = ConstantAnnealer(frac)
        else:
            annealer = LinearAnnealer(1, 0, frac)
    else:
        annealer = ConditionalAnnealer.from_dict(shaping_params, get_logs=None)
        scheduler.set_conditional('rew_shape')

    scheduler.set_annealer('rew_shape', annealer)
    active_annealer = scheduler.get_annealer('rew_shape')

    weights_by_type = shaping_params['weights']
    assert weights_by_type.keys() == rew_types

    # Flatten to: reward term -> (reward type, weight).
    term_lookup = {
        term: (kind, weight)
        for kind, terms in weights_by_type.items()
        for term, weight in terms.items()
    }

    # Weighted sum per reward type; terms without a weight are ignored.
    totals = dict.fromkeys(rew_types, 0)
    for term, value in infos.items():
        entry = term_lookup.get(term)
        if entry is None:
            continue
        kind, weight = entry
        totals[kind] += weight * value

    # Compute total shaped reward, optionally annealing
    return _anneal(totals, active_annealer, frac_remaining)


def _anneal(reward_dict, reward_annealer, frac_remaining):
    c = reward_annealer(frac_remaining)
    assert 0 <= c <= 1
    sparse_weight = 1 - c
    dense_weight = c

    return (reward_dict['sparse'] * sparse_weight
            + reward_dict['dense'] * dense_weight)

def make_env(game, params, total_step, clip_obs, clip_rew):
    """Create the gym environment `game` and wrap it in MultiAgent_Env."""
    base_env = gym.make(game)
    wrapped = MultiAgent_Env(base_env, game, params, total_step, clip_obs, clip_rew)
    return wrapped
    
