import torch
import numpy as np
import torch.nn.functional as F
from utils.util import get_shape_from_obs_space, get_shape_from_act_space, get_dim_from_act_space


def _flatten(T, N, x):
    return x.reshape(T * N, *x.shape[2:])


def _cast(x):
    return x.transpose(1, 2, 0, 3).reshape(-1, *x.shape[3:])


def _shuffle_agent_grid(x, y, shuffle=False):
    rows = np.indices((x, y))[0]
    # cols = np.stack([np.random.permutation(y) for _ in range(x)]) \
    #     if shuffle else np.stack([np.arange(y) for _ in range(x)])
    cols = np.stack([np.arange(y) for _ in range(x)])
    return rows, cols


class SharedReplayBuffer(object):
    """
    Buffer to store training data shared by all agents.

    Arrays are laid out as (step, n_rollout_threads, num_agents, ...). Per-state data (observations,
    RNN states, masks, value predictions, available actions) has episode_length + 1 steps, where slot 0
    holds the data carried over from the previous window by ``after_update``. Per-transition data
    (actions, probabilities, rewards, advantages) has episode_length steps.

    :param args: (argparse.Namespace) arguments containing relevant model, policy, and env information.
    :param num_agents: (int) number of agents in the env.
    :param obs_space: (gym.Space) observation space of agents.
    :param cent_obs_space: (gym.Space) centralized observation space of agents.
    :param act_space: (gym.Space) action space for agents.
    :param env_name: (str) environment name; "MPE" enables a timeout patch in ``compute_returns``.
    """

    def __init__(self, args, num_agents, obs_space, cent_obs_space, act_space, env_name):
        self.episode_length = args.episode_length
        self.n_rollout_threads = args.n_rollout_threads
        self.hidden_size = args.hidden_size
        self.recurrent_N = args.recurrent_N
        self.gamma = args.gamma
        self.gae_lambda = args.gae_lambda
        # NOTE(review): not read by any method of this class; compute_returns always uses GAE.
        self._use_gae = args.use_gae
        self._use_popart = args.use_popart
        self._use_valuenorm = args.use_valuenorm
        # NOTE(review): not read by any method of this class.
        self._use_proper_time_limits = args.use_proper_time_limits
        self.algo = args.algorithm_name
        self.num_agents = num_agents
        self.env_name = env_name
        self._use_gail = args.use_gail
        # history-observation settings for the GRU variant of MAT
        self._mat_use_history = args.mat_use_history
        self.history_obs_len = args.history_obs_len
        # classifier-reward settings
        self._use_classifier_reward = args.use_classifier_reward
        self._classifier_use_gru = args.classifier_use_gru
        self.classifier_gru_his_len = args.classifier_gru_his_len
        self.classifier_reward_rate = args.classifier_reward_rate
        # request shuffling of agent order inside mini-batches
        self._shuffle_buffer_agents_data = args.shuffle_buffer_agents_data

        obs_shape = get_shape_from_obs_space(obs_space)
        share_obs_shape = get_shape_from_obs_space(cent_obs_space)
        act_shape = get_shape_from_act_space(act_space)
        act_dim = get_dim_from_act_space(act_space)

        # If the last shape entry is itself a list (nested shape), keep only the leading dimension.
        if isinstance(obs_shape[-1], list):
            obs_shape = obs_shape[:1]
        if isinstance(share_obs_shape[-1], list):
            share_obs_shape = share_obs_shape[:1]

        self.share_obs = np.zeros((self.episode_length + 1, self.n_rollout_threads, num_agents, *share_obs_shape),
                                  dtype=np.float32)
        self.obs = np.zeros((self.episode_length + 1, self.n_rollout_threads, num_agents, *obs_shape), dtype=np.float32)
        self.his_obs = np.zeros((self.episode_length + 1, self.n_rollout_threads, num_agents, self.history_obs_len, *obs_shape), dtype=np.float32)
        # rolling windows of the most recent observations / actions per thread and agent
        self.history_obs_record = np.zeros((
            self.n_rollout_threads, num_agents, self.history_obs_len, *obs_shape), dtype=np.float32)
        self.history_act_record = np.zeros((
            self.n_rollout_threads, num_agents, self.classifier_gru_his_len, act_shape), dtype=np.float32)
        # done flags of the previous step, used to reset the rolling windows
        self.last_dones = np.zeros((self.n_rollout_threads, self.num_agents))

        self.rnn_states = np.zeros(
            (self.episode_length + 1, self.n_rollout_threads, num_agents, self.recurrent_N, self.hidden_size),
            dtype=np.float32)
        self.rnn_states_critic = np.zeros_like(self.rnn_states)

        self.value_preds = np.zeros(
            (self.episode_length + 1, self.n_rollout_threads, num_agents, 1), dtype=np.float32)
        self.returns = np.zeros_like(self.value_preds)
        self.advantages = np.zeros(
            (self.episode_length, self.n_rollout_threads, num_agents, 1), dtype=np.float32)

        if act_space.__class__.__name__ == 'Discrete':
            self.available_actions = np.ones((self.episode_length + 1, self.n_rollout_threads, num_agents, act_space.n), dtype=np.float32)
        else:
            self.available_actions = None

        self.actions = np.zeros(
            (self.episode_length, self.n_rollout_threads, num_agents, act_shape), dtype=np.float32)
        self.action_log_probs = np.zeros(
            (self.episode_length, self.n_rollout_threads, num_agents, act_shape), dtype=np.float32)
        self.action_probs = np.zeros(
            (self.episode_length, self.n_rollout_threads, num_agents, act_dim), dtype=np.float32)
        self.rewards = np.zeros(
            (self.episode_length, self.n_rollout_threads, num_agents, 1), dtype=np.float32)
        self.disc_values = np.zeros(
            (self.episode_length, self.n_rollout_threads, num_agents, 1), dtype=np.float32)
        self.classifier_values = np.zeros(
            (self.episode_length, self.n_rollout_threads, num_agents, 1), dtype=np.float32)

        self.masks = np.ones((self.episode_length + 1, self.n_rollout_threads, num_agents, 1), dtype=np.float32)
        self.bad_masks = np.ones_like(self.masks)
        self.active_masks = np.ones_like(self.masks)
        self.step = 0

    @staticmethod
    def _clear_finished_threads(record, last_dones):
        """
        Zero the rolling history of every rollout thread whose agents were all done at the previous step.

        :param record: (np.ndarray) history array whose first axis is the rollout thread.
        :param last_dones: (np.ndarray) (n_rollout_threads, num_agents) done flags of the previous step.
        :return: (np.ndarray) ``record`` multiplied by a per-thread keep factor (0 for finished threads).
        """
        keep = 1 - np.all(last_dones, axis=1)
        # broadcast the per-thread factor over all remaining axes, whatever the observation rank
        return record * keep.reshape(-1, *([1] * (record.ndim - 1)))

    def insert(self, share_obs, obs, rnn_states_actor, rnn_states_critic, actions, action_log_probs, action_probs,
               value_preds, rewards, dones, masks, bad_masks=None, active_masks=None,
               available_actions=None, disc_values=None, classifier_rewards=None):
        """
        Insert one environment step into the buffer and advance the write pointer.

        State-like inputs (observations, RNN states, masks, available actions) are written at ``step + 1``;
        transition data (actions, probabilities, values, rewards) is written at ``step``.
        Inputs have leading shape (n_rollout_threads, num_agents); ``dones`` is (n_rollout_threads, num_agents).

        :param share_obs: (np.ndarray) centralized observations.
        :param obs: (np.ndarray) local agent observations.
        :param rnn_states_actor: (np.ndarray) RNN states for actor network.
        :param rnn_states_critic: (np.ndarray) RNN states for critic network.
        :param actions:(np.ndarray) actions taken by agents.
        :param action_log_probs:(np.ndarray) log probs of actions taken by agents
        :param action_probs:(np.ndarray) probs of actions taken by agents
        :param value_preds: (np.ndarray) value function prediction at each step.
        :param rewards: (np.ndarray) reward collected at each step.
        :param dones: (np.ndarray) whether the current step ended the episode, per agent.
        :param masks: (np.ndarray) denotes whether the environment has terminated or not.
        :param bad_masks: (np.ndarray) denotes whether a termination is a true terminal state or an episode limit.
        :param active_masks: (np.ndarray) denotes whether an agent is active or dead in the env.
        :param available_actions: (np.ndarray) actions available to each agent. If None, all actions are available.
        :param disc_values: (np.ndarray) discriminator rewards for this step (used as rewards when GAIL is
                            enabled). Not stored if None.
        :param classifier_rewards: (np.ndarray) classifier rewards; stored only when use_classifier_reward is set.
        """
        self.share_obs[self.step + 1] = share_obs.copy()
        self.obs[self.step + 1] = obs.copy()
        if self._mat_use_history:
            # start a fresh history for threads whose previous step ended the episode for all agents
            self.history_obs_record = self._clear_finished_threads(self.history_obs_record, self.last_dones)
            # shift the window one slot to the left and append the newest observation at the end
            self.history_obs_record[:, :, 0: self.history_obs_len - 1] = self.history_obs_record[:, :, 1: self.history_obs_len]
            self.history_obs_record[:, :, self.history_obs_len - 1] = obs.copy()
            self.his_obs[self.step + 1] = self.history_obs_record.copy()
            # NOTE(review): last_dones is also read and overwritten by update_history_act_record; if both
            # mat_use_history and classifier_use_gru are enabled, whichever update runs second sees this
            # step's dones instead of the previous step's. Confirm the intended call order.
            self.last_dones = dones.astype(self.last_dones.dtype)
        self.rnn_states[self.step + 1] = rnn_states_actor.copy()
        self.rnn_states_critic[self.step + 1] = rnn_states_critic.copy()
        self.actions[self.step] = actions.copy()
        self.action_log_probs[self.step] = action_log_probs.copy()
        self.action_probs[self.step] = action_probs.copy()
        self.value_preds[self.step] = value_preds.copy()
        self.rewards[self.step] = rewards.copy()
        # disc_values is optional (defaults to None); previously None crashed on .copy()
        if disc_values is not None:
            self.disc_values[self.step] = disc_values.copy()
        if self._use_classifier_reward:
            self.classifier_values[self.step] = classifier_rewards.copy()

        self.masks[self.step + 1] = masks.copy()
        if bad_masks is not None:
            self.bad_masks[self.step + 1] = bad_masks.copy()
        if active_masks is not None:
            self.active_masks[self.step + 1] = active_masks.copy()
        if available_actions is not None:
            self.available_actions[self.step + 1] = available_actions.copy()

        self.step = (self.step + 1) % self.episode_length

    def chooseinsert(self, share_obs, obs, rnn_states, rnn_states_critic, actions, action_log_probs,
                     value_preds, rewards, masks, bad_masks=None, active_masks=None, available_actions=None):
        """
        Insert data into the buffer. This insert function is used specifically for Hanabi, which is turn based.
        Unlike ``insert``, observations, active masks and available actions are written at ``step``.

        :param share_obs: (np.ndarray) centralized observations.
        :param obs: (np.ndarray) local agent observations.
        :param rnn_states: (np.ndarray) RNN states for actor network.
        :param rnn_states_critic: (np.ndarray) RNN states for critic network.
        :param actions:(np.ndarray) actions taken by agents.
        :param action_log_probs:(np.ndarray) log probs of actions taken by agents
        :param value_preds: (np.ndarray) value function prediction at each step.
        :param rewards: (np.ndarray) reward collected at each step.
        :param masks: (np.ndarray) denotes whether the environment has terminated or not.
        :param bad_masks: (np.ndarray) denotes whether a termination is a true terminal state or an episode limit.
        :param active_masks: (np.ndarray) denotes whether an agent is active or dead in the env.
        :param available_actions: (np.ndarray) actions available to each agent. If None, all actions are available.
        """
        self.share_obs[self.step] = share_obs.copy()
        self.obs[self.step] = obs.copy()
        self.rnn_states[self.step + 1] = rnn_states.copy()
        self.rnn_states_critic[self.step + 1] = rnn_states_critic.copy()
        self.actions[self.step] = actions.copy()
        self.action_log_probs[self.step] = action_log_probs.copy()
        self.value_preds[self.step] = value_preds.copy()
        self.rewards[self.step] = rewards.copy()
        self.masks[self.step + 1] = masks.copy()
        if bad_masks is not None:
            self.bad_masks[self.step + 1] = bad_masks.copy()
        if active_masks is not None:
            self.active_masks[self.step] = active_masks.copy()
        if available_actions is not None:
            self.available_actions[self.step] = available_actions.copy()

        self.step = (self.step + 1) % self.episode_length

    def get_history_obs_record(self):
        """Return the rolling observation history array (not a copy)."""
        return self.history_obs_record

    def get_history_act_record(self):
        """Return the rolling action history array (not a copy)."""
        return self.history_act_record

    def update_history_act_record(self, actions, dones):
        """
        Push ``actions`` into the rolling action history when the classifier uses a GRU.

        :param actions: (np.ndarray) (n_rollout_threads, num_agents, act_dim) actions of the current step.
        :param dones: (np.ndarray) (n_rollout_threads, num_agents) done flags of the current step.
        """
        if self._classifier_use_gru:
            # start a fresh history for threads whose previous step ended the episode for all agents
            self.history_act_record = self._clear_finished_threads(self.history_act_record, self.last_dones)
            # shift the window one slot to the left and append the newest actions at the end
            self.history_act_record[:, :, 0: self.classifier_gru_his_len - 1] = self.history_act_record[:, :, 1: self.classifier_gru_his_len]
            self.history_act_record[:, :, self.classifier_gru_his_len - 1] = actions.copy()
            # NOTE(review): shares last_dones with the observation history in insert(); see note there.
            self.last_dones = dones.astype(self.last_dones.dtype)

    def after_update(self):
        """Copy last timestep data to first index. Called after update to model."""
        self.share_obs[0] = self.share_obs[-1].copy()
        self.obs[0] = self.obs[-1].copy()
        self.his_obs[0] = self.his_obs[-1].copy()
        self.rnn_states[0] = self.rnn_states[-1].copy()
        self.rnn_states_critic[0] = self.rnn_states_critic[-1].copy()
        self.masks[0] = self.masks[-1].copy()
        self.bad_masks[0] = self.bad_masks[-1].copy()
        self.active_masks[0] = self.active_masks[-1].copy()
        if self.available_actions is not None:
            self.available_actions[0] = self.available_actions[-1].copy()

    def chooseafter_update(self):
        """Copy last timestep data to first index. This method is used for Hanabi."""
        self.rnn_states[0] = self.rnn_states[-1].copy()
        self.rnn_states_critic[0] = self.rnn_states_critic[-1].copy()
        self.masks[0] = self.masks[-1].copy()
        self.bad_masks[0] = self.bad_masks[-1].copy()

    def _training_rewards(self):
        """
        Return the per-step reward array used for return/advantage computation.

        Environment rewards by default; discriminator values when GAIL is enabled, plus
        ``classifier_reward_rate`` times the classifier values when the classifier reward is also enabled.
        """
        if not self._use_gail:
            return self.rewards
        if not self._use_classifier_reward:
            return self.disc_values
        return self.disc_values + self.classifier_reward_rate * self.classifier_values

    def compute_returns(self, next_value, value_normalizer=None):
        """
        Compute advantages and returns with GAE, writing ``self.advantages`` and ``self.returns``.

        :param next_value: (np.ndarray) value predictions for the step after the last episode step.
        :param value_normalizer: (PopArt/ValueNorm) normalizer used to denormalize value predictions;
                                 required when use_popart or use_valuenorm is set.
        """
        self.value_preds[-1] = next_value
        use_normalizer = self._use_popart or self._use_valuenorm
        step_rewards = self._training_rewards()
        last_step = self.rewards.shape[0] - 1
        gae = 0
        for step in reversed(range(self.rewards.shape[0])):
            if use_normalizer:
                next_v = value_normalizer.denormalize(self.value_preds[step + 1])
                curr_v = value_normalizer.denormalize(self.value_preds[step])
            else:
                next_v = self.value_preds[step + 1]
                curr_v = self.value_preds[step]
            delta = step_rewards[step] + self.gamma * next_v * self.masks[step + 1] - curr_v
            gae = delta + self.gamma * self.gae_lambda * self.masks[step + 1] * gae

            # patch for MPE, whose last step is a timeout rather than a termination
            if self.env_name == "MPE" and step == last_step:
                gae = 0
            self.advantages[step] = gae
            self.returns[step] = gae + curr_v

    def feed_forward_generator_transformer(self, advantages, num_mini_batch=None, mini_batch_size=None, force_not_shuffle=False):
        """
        Yield mini-batches of training data for transformer (MAT) policies.

        Samples are whole (step, thread) pairs, so all agents of one sample stay together; every yielded
        array is flattened to (samples_in_batch * num_agents, ...).

        :param advantages: (np.ndarray or None) advantage estimates shaped like ``self.advantages``.
                           If None, ``adv_targ`` is yielded as None.
        :param num_mini_batch: (int) number of minibatches to split the batch into.
                               Derived from ``mini_batch_size`` when omitted.
        :param mini_batch_size: (int) number of samples in each minibatch.
                                Derived from ``num_mini_batch`` when omitted.
        :param force_not_shuffle: (bool) if True, request no agent-order shuffling regardless of
                                  shuffle_buffer_agents_data.
        :return: generator of tuples (share_obs, obs, his_obs, rnn_states, rnn_states_critic, actions,
                 value_preds, returns, masks, active_masks, old_action_log_probs, old_action_probs,
                 adv_targ, available_actions).
        """
        episode_length, n_rollout_threads, num_agents = self.rewards.shape[0:3]
        batch_size = n_rollout_threads * episode_length

        if mini_batch_size is None:
            assert batch_size >= num_mini_batch, (
                "PPO requires the number of processes ({}) "
                "* number of steps ({}) = {} "
                "to be greater than or equal to the number of PPO mini batches ({})."
                "".format(n_rollout_threads, episode_length,
                          n_rollout_threads * episode_length,
                          num_mini_batch))
            mini_batch_size = batch_size // num_mini_batch
        elif num_mini_batch is None:
            # only the mini-batch size was given: use as many full mini-batches as fit
            num_mini_batch = batch_size // mini_batch_size
        rand = torch.randperm(batch_size).numpy()
        sampler = [rand[i * mini_batch_size:(i + 1) * mini_batch_size] for i in range(num_mini_batch)]
        rows, cols = _shuffle_agent_grid(
            batch_size, num_agents, shuffle=(not force_not_shuffle) and self._shuffle_buffer_agents_data)

        def _gather(arr):
            # [T, threads, N, ...] -> [T * threads, N, ...], with agents re-ordered by (rows, cols)
            return arr.reshape(-1, *arr.shape[2:])[rows, cols]

        def _take(arr, indices):
            # [T * threads, N, ...] -> [len(indices), N, ...] -> [len(indices) * N, ...]
            return arr[indices].reshape(-1, *arr.shape[2:])

        share_obs = _gather(self.share_obs[:-1])
        obs = _gather(self.obs[:-1])
        his_obs = _gather(self.his_obs[:-1])
        rnn_states = _gather(self.rnn_states[:-1])
        rnn_states_critic = _gather(self.rnn_states_critic[:-1])
        actions = _gather(self.actions)
        available_actions = None if self.available_actions is None else _gather(self.available_actions[:-1])
        value_preds = _gather(self.value_preds[:-1])
        returns = _gather(self.returns[:-1])
        masks = _gather(self.masks[:-1])
        active_masks = _gather(self.active_masks[:-1])
        action_log_probs = _gather(self.action_log_probs)
        action_probs = _gather(self.action_probs)
        # advantages may be None; previously it was reshaped unconditionally and crashed
        advantages = None if advantages is None else _gather(advantages)

        for indices in sampler:
            available_actions_batch = None if available_actions is None else _take(available_actions, indices)
            adv_targ = None if advantages is None else _take(advantages, indices)

            yield _take(share_obs, indices), _take(obs, indices), _take(his_obs, indices), \
                  _take(rnn_states, indices), _take(rnn_states_critic, indices), _take(actions, indices), \
                  _take(value_preds, indices), _take(returns, indices), _take(masks, indices), \
                  _take(active_masks, indices), _take(action_log_probs, indices), _take(action_probs, indices), \
                  adv_targ, available_actions_batch


class GailReplayBuffer(object):
    """
    Episode-based buffer for GAIL-style training.

    Steps from every rollout thread are collected per thread until that thread's episode ends. The finished
    episode is then scored with the policy's discriminator (costs) and critic (values) to produce discounted
    returns, GAE advantages and value targets. Once at least ``args.num_steps_per_epochs`` steps have been
    gathered, the episodes are concatenated and truncated to exactly that many steps for training.

    :param args: (argparse.Namespace) config; reads num_steps_per_epochs, n_rollout_threads, gamma,
                 gae_lambda and normalize_advantage.
    :param num_agents: (int) number of agents in the env; sizes the per-agent discount weights.
    :param env_name: (str) environment name (only stored by this class).
    :param policy: policy providing ``get_discriminator_reward``, ``get_critic_values`` and ``transformer``.
    """

    def __init__(self, args, num_agents, env_name, policy):
        # basic config
        self.args = args
        self.num_steps_per_epochs = args.num_steps_per_epochs
        self.n_rollout_threads = args.n_rollout_threads
        self.num_agents = num_agents
        self.env_name = env_name
        self.policy = policy
        # prepared batch handed to the model for training
        self.mean_episode_reward = 0.0
        self.batch_share_obs = None
        self.batch_obs = None
        self.batch_acts = None
        self.batch_log_probs = None
        self.batch_rets = None
        self.batch_advs = None
        self.batch_gms = None
        self.batch_targets = None
        # finished-episode data accumulated for the current epoch
        self.all_episode_reward = []
        self.now_steps = 0  # number of steps accumulated so far
        self.share_obs = []
        self.obs = []
        self.acts = []
        self.log_probs = []
        self.rets = []
        self.advs = []
        self.gms = []
        self.targets = []
        # in-progress episode data, one list per rollout thread
        self.ep_share_obs = [[] for _ in range(self.n_rollout_threads)]
        self.ep_obs = [[] for _ in range(self.n_rollout_threads)]
        self.ep_acts = [[] for _ in range(self.n_rollout_threads)]
        self.ep_rwds = [[] for _ in range(self.n_rollout_threads)]
        self.ep_dones = [[] for _ in range(self.n_rollout_threads)]
        self.ep_log_probs = [[] for _ in range(self.n_rollout_threads)]
        self.ep_gms = [[] for _ in range(self.n_rollout_threads)]
        self.ep_lmbs = [[] for _ in range(self.n_rollout_threads)]

    def add_multi_threads_step_data(self, share_ob, ob, act, rwd, done, log_prob):
        """
        Append one step from every rollout thread and flush episodes that have just finished.

        :param share_ob: (n_rollout_threads, num_agents, share_obs_dim)
        :param ob: (n_rollout_threads, num_agents, obs_dim)
        :param act: (n_rollout_threads, num_agents, act_dim)
        :param rwd: (n_rollout_threads, num_agents, 1)
        :param done: (n_rollout_threads, )
        :param log_prob: (n_rollout_threads, num_agents)
        :return: (bool) True if enough steps were collected during this call and the training batch was prepared.
        """
        finish_sample = False
        for thread_i in range(self.n_rollout_threads):
            now_thread_t = len(self.ep_share_obs[thread_i])
            self.ep_share_obs[thread_i].append(share_ob[thread_i])
            self.ep_obs[thread_i].append(ob[thread_i])
            self.ep_acts[thread_i].append(act[thread_i])
            self.ep_rwds[thread_i].append(rwd[thread_i])
            self.ep_dones[thread_i].append(done[thread_i])
            self.ep_log_probs[thread_i].append(log_prob[thread_i])
            # per-agent gamma^t and lambda^t for the in-episode timestep t (previously hard-coded to 2 agents)
            gm = self.args.gamma ** now_thread_t
            lmb = self.args.gae_lambda ** now_thread_t
            self.ep_gms[thread_i].append(np.array([gm] * self.num_agents))
            self.ep_lmbs[thread_i].append([lmb] * self.num_agents)
            # move the episode to the epoch buffers once it finishes
            if self.ep_dones[thread_i][-1]:
                self.add_single_thread_episode_data(thread_i)
                self.clear_single_thread_episode_data(thread_i)
                if self.now_steps >= self.num_steps_per_epochs:
                    self.prepare_data_for_train()
                    finish_sample = True
        return finish_sample

    def add_single_thread_episode_data(self, thread_i):
        """
        Score the finished episode of ``thread_i`` and append it to the epoch buffers.

        Documented shapes:
            ep_share_obs (ep_len, agent_num, share_obs_dim); ep_obs (ep_len, agent_num, obs_dim);
            ep_acts (ep_len, agent_num, act_dim); ep_rwds (ep_len, agent_num, 1);
            ep_log_probs, ep_gms, ep_lmbs, ep_costs, ep_rets, curr_vals, next_vals, ep_deltas,
            ep_advs, ep_targets: (ep_len, agent_num).

        :param thread_i: (int) index of the rollout thread whose episode just ended.
        """
        now_thread_t = len(self.ep_share_obs[thread_i])
        ep_share_obs = torch.FloatTensor(np.array(self.ep_share_obs[thread_i]))
        ep_obs = torch.FloatTensor(np.array(self.ep_obs[thread_i]))
        ep_acts = torch.FloatTensor(np.array(self.ep_acts[thread_i]))
        ep_rwds = torch.FloatTensor(np.array(self.ep_rwds[thread_i]))
        ep_log_probs = torch.FloatTensor(np.array(self.ep_log_probs[thread_i]))
        ep_gms = torch.FloatTensor(np.array(self.ep_gms[thread_i]))
        ep_lmbs = torch.FloatTensor(np.array(self.ep_lmbs[thread_i]))

        self.share_obs.append(ep_share_obs)
        self.obs.append(ep_obs)
        self.acts.append(ep_acts)
        self.log_probs.append(ep_log_probs)

        # The discriminator output is used directly as the per-step cost (an earlier variant used -log(D)).
        ep_costs = self.policy.get_discriminator_reward(ep_share_obs, ep_obs, ep_acts).squeeze().detach().cpu()
        # Reward-to-go: rets[t] = sum_{k>=t} gamma^k * cost[k] / gamma^t.
        ep_disc_costs = ep_gms * ep_costs
        ep_disc_rets = torch.stack(
            [torch.sum(ep_disc_costs[i:], dim=0) for i in range(now_thread_t)]
        )
        ep_rets = ep_disc_rets / ep_gms
        self.rets.append(ep_rets)

        # NOTE(review): the transformer is switched to eval mode here and never switched back to train mode
        # in this class — confirm callers do so before updating.
        self.policy.transformer.eval()
        curr_vals = self.policy.get_critic_values(ep_share_obs, ep_obs).detach().cpu().squeeze(-1)
        # The value after the episode's last step is taken as 0 (previously a hard-coded 2-agent zero row).
        next_vals = torch.cat((curr_vals[1:], torch.zeros_like(curr_vals[:1])))
        ep_deltas = ep_costs + self.args.gamma * next_vals - curr_vals
        # GAE: adv[j] = sum_k (gamma * lambda)^k * delta[j + k]; gae_weights[k] = (gamma * lambda)^k.
        gae_weights = ep_gms * ep_lmbs
        ep_advs = torch.stack([
            torch.sum(gae_weights[:now_thread_t - j] * ep_deltas[j:], dim=0)
            for j in range(now_thread_t)
        ])
        ep_targets = ep_advs + curr_vals
        self.advs.append(ep_advs)
        self.gms.append(ep_gms)
        self.targets.append(ep_targets)

        # episode reward: mean over agents, summed over time
        self.all_episode_reward.append(torch.sum(torch.mean(ep_rwds, dim=1)).item())
        self.now_steps += now_thread_t

    def clear_single_thread_episode_data(self, thread_i):
        """Reset the in-progress episode lists of ``thread_i``."""
        self.ep_share_obs[thread_i] = []
        self.ep_obs[thread_i] = []
        self.ep_acts[thread_i] = []
        self.ep_rwds[thread_i] = []
        self.ep_dones[thread_i] = []
        self.ep_log_probs[thread_i] = []
        self.ep_gms[thread_i] = []
        self.ep_lmbs[thread_i] = []

    def prepare_data_for_train(self):
        """
        Concatenate the collected episodes into training batches truncated to num_steps_per_epochs steps,
        then reset the epoch buffers.

        batch_share_obs (num_steps_per_epochs, agent_num, share_ob_dim)
        batch_obs (num_steps_per_epochs, agent_num, obs_dim)
        batch_acts (num_steps_per_epochs, agent_num, act_dim)
        batch_log_probs, batch_rets, batch_advs, batch_gms, batch_targets (num_steps_per_epochs, agent_num)
        """
        self.batch_share_obs = torch.cat(self.share_obs)[:self.num_steps_per_epochs]
        self.batch_obs = torch.cat(self.obs)[:self.num_steps_per_epochs]
        self.batch_acts = torch.cat(self.acts)[:self.num_steps_per_epochs]
        self.batch_log_probs = torch.cat(self.log_probs)[:self.num_steps_per_epochs]
        self.batch_rets = torch.cat(self.rets)[:self.num_steps_per_epochs]
        self.batch_advs = torch.cat(self.advs)[:self.num_steps_per_epochs]
        self.batch_gms = torch.cat(self.gms)[:self.num_steps_per_epochs]
        self.batch_targets = torch.cat(self.targets)[:self.num_steps_per_epochs]
        if self.args.normalize_advantage:
            self.batch_advs = (self.batch_advs - self.batch_advs.mean()) / (self.batch_advs.std() + 1e-8)
        # reset the epoch buffers
        self.share_obs = []
        self.obs = []
        self.acts = []
        self.log_probs = []
        self.rets = []
        self.advs = []
        self.gms = []
        self.targets = []
        self.mean_episode_reward = np.mean(self.all_episode_reward)
        self.all_episode_reward = []
        self.now_steps = 0

    def get_all_data_for_train(self):
        """
        Hand out the prepared batch and mean episode reward, then clear them.

        :return: (batch_share_obs, batch_obs, batch_acts, batch_log_probs, batch_rets, batch_advs,
                 batch_gms, batch_targets, mean_episode_reward)
        """
        result = (self.batch_share_obs, self.batch_obs, self.batch_acts, self.batch_log_probs,
                  self.batch_rets, self.batch_advs, self.batch_gms, self.batch_targets,
                  self.mean_episode_reward)
        self.mean_episode_reward = 0.0
        self.batch_share_obs = None
        self.batch_obs = None
        self.batch_acts = None
        self.batch_log_probs = None
        self.batch_rets = None
        self.batch_advs = None
        self.batch_gms = None
        self.batch_targets = None
        return result

    def sample_batch_for_train(self, batch_size):
        """
        Sample ``batch_size`` distinct steps from the prepared batch (must be called before
        ``get_all_data_for_train`` clears it).

        :return: (batch_share_obs, batch_obs, batch_acts) for the sampled steps.
        """
        indexes = np.random.choice(np.arange(self.num_steps_per_epochs), size=batch_size, replace=False)
        batch_share_obs = self.batch_share_obs[indexes]
        batch_obs = self.batch_obs[indexes]
        batch_acts = self.batch_acts[indexes]

        return batch_share_obs, batch_obs, batch_acts
