import torch
import numpy as np
from collections import defaultdict

from bta.utils.util import check, get_shape_from_obs_space, get_shape_from_act_space, flatten

def _flatten(T, N, x):
    return x.reshape(T * N, *x.shape[2:])

def _cast(x):
    if type(x) == np.ndarray:
        return x.transpose(1,0,2).reshape(-1, *x.shape[2:])
    else:
        return x.permute(1,0,2).reshape(-1, *x.shape[2:])

class SeparatedReplayBuffer(object):
    def __init__(self, args, obs_space, share_obs_space, act_space, agent_id):
        self.args = args
        self.episode_length = args.episode_length
        self.n_rollout_threads = args.n_rollout_threads
        self.rnn_hidden_size = args.hidden_size
        self.recurrent_N = args.recurrent_N
        self.gamma = args.gamma
        self.gae_lambda = args.gae_lambda
        self._use_gae = args.use_gae
        self._use_popart = args.use_popart
        self._use_valuenorm = args.use_valuenorm
        self._use_proper_time_limits = args.use_proper_time_limits
        self.use_graph = args.use_graph
        self.use_action_attention = args.use_action_attention
        self.args = args
        self.agent_idx = agent_id

        obs_shape = get_shape_from_obs_space(obs_space)
        share_obs_shape = get_shape_from_obs_space(share_obs_space)

        print(obs_shape)
        print(share_obs_shape)

        if args.env_name == "GoBigger":
            self.obs = [[] for _ in range(self.episode_length + 1)]
        else:
            if type(obs_shape[-1]) == list:
                obs_shape = obs_shape[:1]
            self.obs = np.zeros((self.episode_length + 1, self.n_rollout_threads, *obs_shape), dtype=np.float32)

        if args.env_name == "GoBigger":
            self.share_obs = [[] for _ in range(self.episode_length + 1)]
        else:
            if type(share_obs_shape[-1]) == list:
                share_obs_shape = share_obs_shape[:1]
            self.share_obs = np.zeros((self.episode_length + 1, self.n_rollout_threads, *share_obs_shape), dtype=np.float32)

        self.rnn_states = np.zeros((self.episode_length + 1, self.n_rollout_threads, self.recurrent_N, self.rnn_hidden_size), dtype=np.float32)
        self.rnn_states_critic = np.zeros_like(self.rnn_states)
        self.rnn_states_joint = np.zeros_like(self.rnn_states)

        self.value_preds = np.zeros((self.episode_length + 1, self.n_rollout_threads, 1), dtype=np.float32)
        self.joint_value_preds = np.zeros((self.episode_length + 1, self.n_rollout_threads, 1), dtype=np.float32)
        self.returns = np.zeros((self.episode_length + 1, self.n_rollout_threads, 1), dtype=np.float32)
        self.ce_gaes = np.zeros_like(self.returns)
        
        self.mix_action = False
        self.act_shape = get_shape_from_act_space(act_space)
        if act_space.__class__.__name__ == "Discrete":
            self.action_dim = act_space.n
        elif act_space.__class__.__name__ == "Box":
            self.action_dim = act_space.shape[0]
        else:
            self.mix_action = True
            continous_dim = act_space[0].shape[0]
            discrete_dim = act_space[1].n
            self.action_dim = continous_dim + discrete_dim

        if act_space.__class__.__name__ == 'Discrete':
            self.available_actions = np.ones((self.episode_length + 1, self.n_rollout_threads, act_space.n), dtype=np.float32)
        elif self.mix_action:
            self.available_actions = np.ones((self.episode_length + 1, self.n_rollout_threads, act_space[1].n), dtype=np.float32)
        else:
            self.available_actions = None

        self.actions = np.zeros((self.episode_length, self.n_rollout_threads, self.act_shape))
        self.one_hot_actions = np.zeros((self.episode_length, self.n_rollout_threads, args.num_agents, self.action_dim), dtype=np.float32)
        self.action_log_probs = np.zeros((self.episode_length, self.n_rollout_threads, self.act_shape), dtype=np.float32)
        self.rewards = np.zeros((self.episode_length, self.n_rollout_threads, 1), dtype=np.float32)

        self.joint_actions = np.zeros((self.episode_length, self.n_rollout_threads, self.act_shape), dtype=np.float32)
        self.joint_action_log_probs = np.zeros((self.episode_length, self.n_rollout_threads, self.act_shape), dtype=np.float32)
        self.thresholds = np.zeros((self.episode_length, self.n_rollout_threads, args.num_agents, 1), dtype=np.float32)
        self.bias = np.zeros((self.episode_length, self.n_rollout_threads, self.action_dim), dtype=np.float32)
        self.logits = np.zeros((self.episode_length, self.n_rollout_threads, self.action_dim), dtype=np.float32)
    
        self.masks = np.ones((self.episode_length + 1, self.n_rollout_threads, 1), dtype=np.float32)
        self.bad_masks = np.ones_like(self.masks)
        self.active_masks = np.ones_like(self.masks)

        self.factor = None
        self.action_grad = None

        self.step = 0

    def insert(self, share_obs, obs, rnn_states, rnn_states_critic, actions, hard_actions, action_log_probs,
               value_preds, rewards, masks, bad_masks=None, active_masks=None, available_actions=None, 
               joint_actions=None, joint_action_log_probs=None, joint_value_preds=None, rnn_states_joint=None,
               thresholds=None, bias=None, logits=None):
        self.share_obs[self.step + 1] = share_obs.copy()
        self.obs[self.step + 1] = obs.copy()
        self.rnn_states[self.step + 1] = rnn_states.copy()
        self.rnn_states_critic[self.step + 1] = rnn_states_critic.copy()
        self.one_hot_actions[self.step] = actions.copy()
        self.actions[self.step] = hard_actions.copy()
        self.action_log_probs[self.step] = action_log_probs.copy()
        self.value_preds[self.step] = value_preds.copy()
        self.rewards[self.step] = rewards.copy()
        self.masks[self.step + 1] = masks.copy()
        if joint_value_preds is not None:
            self.joint_value_preds[self.step] = joint_value_preds.copy()
        if rnn_states_joint is not None:
            self.rnn_states_joint[self.step + 1] = rnn_states_joint.copy()
        if bad_masks is not None:
            self.bad_masks[self.step + 1] = bad_masks.copy()
        if active_masks is not None:
            self.active_masks[self.step + 1] = active_masks.copy()
        if available_actions is not None:
            self.available_actions[self.step + 1] = available_actions.copy()
        if joint_actions is not None:
            self.joint_actions[self.step] = joint_actions.copy()
        if joint_action_log_probs is not None:
            self.joint_action_log_probs[self.step] = joint_action_log_probs.copy()
        if thresholds is not None:
            self.thresholds[self.step] = thresholds if type(thresholds)==float else thresholds.copy()
        if bias is not None:
            self.bias[self.step] = bias.copy()
        if logits is not None:
            self.logits[self.step] = logits.copy()
        
        self.step = (self.step + 1) % self.episode_length
        
    def chooseinsert(self, share_obs, obs, rnn_states, rnn_states_critic, actions, one_hot_actions, action_log_probs,
                     value_preds, rewards, masks, bad_masks=None, active_masks=None, available_actions=None, joint_actions=None, joint_action_log_probs=None):
        self.share_obs[self.step] = share_obs.copy()
        self.obs[self.step] = obs.copy()
        self.rnn_states[self.step + 1] = rnn_states.copy()
        self.rnn_states_critic[self.step + 1] = rnn_states_critic.copy()
        self.actions[self.step] = actions.copy()
        self.one_hot_actions[self.step] = one_hot_actions.copy()
        self.action_log_probs[self.step] = action_log_probs.copy()
        self.value_preds[self.step] = value_preds.copy()
        self.rewards[self.step] = rewards.copy()
        self.masks[self.step + 1] = masks.copy()
        if bad_masks is not None:
            self.bad_masks[self.step + 1] = bad_masks.copy()
        if active_masks is not None:
            self.active_masks[self.step] = active_masks.copy()
        if available_actions is not None:
            self.available_actions[self.step] = available_actions.copy()
        if joint_actions is not None:
            self.joint_actions[self.step + 1] = joint_actions.copy()
        if joint_action_log_probs is not None:
            self.joint_action_log_probs[self.step + 1] = joint_action_log_probs.copy()

        self.step = (self.step + 1) % self.episode_length
    
    def update_factor(self, factor):
        self.factor = factor.copy()

    def update_action_grad(self, action_grad):
        self.action_grad = action_grad.copy()

    def after_update(self):
        self.share_obs[0] = self.share_obs[-1].copy()
        self.obs[0] = self.obs[-1].copy()
        self.rnn_states[0] = self.rnn_states[-1].copy()
        self.rnn_states_critic[0] = self.rnn_states_critic[-1].copy()
        self.masks[0] = self.masks[-1].copy()
        self.bad_masks[0] = self.bad_masks[-1].copy()
        self.active_masks[0] = self.active_masks[-1].copy()
        if self.available_actions is not None:
            self.available_actions[0] = self.available_actions[-1].copy()
        
    def chooseafter_update(self):
        self.rnn_states[0] = self.rnn_states[-1].copy()
        self.rnn_states_critic[0] = self.rnn_states_critic[-1].copy()
        self.masks[0] = self.masks[-1].copy()
        self.bad_masks[0] = self.bad_masks[-1].copy()

    def compute_returns(self, next_value, value_normalizer=None):
        if self._use_proper_time_limits:
            if self._use_gae:
                self.value_preds[-1] = next_value
                gae = 0
                for step in reversed(range(self.rewards.shape[0])):
                    if self._use_popart or self._use_valuenorm:
                        delta = self.rewards[step] + self.gamma * value_normalizer.denormalize(self.value_preds[
                            step + 1]) * self.masks[step + 1] - value_normalizer.denormalize(self.value_preds[step])
                        gae = delta + self.gamma * self.gae_lambda * self.masks[step + 1] * gae
                        gae = gae * self.bad_masks[step + 1]
                        self.returns[step] = gae + value_normalizer.denormalize(self.value_preds[step])
                    else:
                        delta = self.rewards[step] + self.gamma * self.value_preds[step + 1] * self.masks[step + 1] - self.value_preds[step]
                        gae = delta + self.gamma * self.gae_lambda * self.masks[step + 1] * gae
                        gae = gae * self.bad_masks[step + 1]
                        self.returns[step] = gae + self.value_preds[step]
            else:
                self.returns[-1] = next_value
                for step in reversed(range(self.rewards.shape[0])):
                    if self._use_popart:
                        self.returns[step] = (self.returns[step + 1] * self.gamma * self.masks[step + 1] + self.rewards[step]) * self.bad_masks[step + 1] \
                            + (1 - self.bad_masks[step + 1]) * value_normalizer.denormalize(self.value_preds[step])
                    else:
                        self.returns[step] = (self.returns[step + 1] * self.gamma * self.masks[step + 1] + self.rewards[step]) * self.bad_masks[step + 1] \
                            + (1 - self.bad_masks[step + 1]) * self.value_preds[step]
        else:
            if self._use_gae:
                self.value_preds[-1] = next_value
                # if self.use_action_attention:
                #     ce_gae = 0
                #     imp_weights = np.prod(np.exp(self.joint_action_log_probs - self.action_log_probs), -1, keepdims=True)
                #     clipped_weights = np.clip(imp_weights, a_max=1.0, a_min=None)
                #     truncated_weights = np.minimum(imp_weights, clipped_weights)
                #     for step in reversed(range(self.rewards.shape[0])):
                #         rho = truncated_weights[step + 1] if step < self.rewards.shape[0] - 1 else 1
                #         if self._use_popart or self._use_valuenorm:
                #             ce_delta = self.rewards[step] + self.gamma * value_normalizer.denormalize(self.value_preds[step + 1]) * self.masks[step + 1] - value_normalizer.denormalize(self.value_preds[step])
                #             ce_gae = ce_delta + rho * self.gamma * self.gae_lambda * self.masks[step + 1] * ce_gae
                #             self.ce_gaes[step] = ce_gae
                #         else:
                #             ce_delta = self.rewards[step] + self.gamma * self.value_preds[step + 1] * self.masks[step + 1] - self.value_preds[step]
                #             ce_gae = ce_delta + rho * self.gamma * self.gae_lambda * self.masks[step + 1] * ce_gae
                #             self.ce_gaes[step] = ce_gae
                # else:
                gae = 0
                for step in reversed(range(self.rewards.shape[0])):
                    if self._use_popart or self._use_valuenorm:
                        delta = self.rewards[step] + self.gamma * value_normalizer.denormalize(self.value_preds[step + 1]) * self.masks[step + 1] - value_normalizer.denormalize(self.value_preds[step])
                        gae = delta + self.gamma * self.gae_lambda * self.masks[step + 1] * gae
                        self.returns[step] = gae + value_normalizer.denormalize(self.value_preds[step])
                    else:
                        delta = self.rewards[step] + self.gamma * self.value_preds[step + 1] * self.masks[step + 1] - self.value_preds[step]
                        gae = delta + self.gamma * self.gae_lambda * self.masks[step + 1] * gae
                        self.returns[step] = gae + self.value_preds[step]
            else:
                self.returns[-1] = next_value
                for step in reversed(range(self.rewards.shape[0])):
                    self.returns[step] = self.returns[step + 1] * self.gamma * self.masks[step + 1] + self.rewards[step]
            
    def feed_forward_generator(self, advantages, num_mini_batch=None, mini_batch_size=None, sampler=None):
        episode_length, n_rollout_threads = self.rewards.shape[0:2]
        batch_size = n_rollout_threads * episode_length

        if mini_batch_size is None:
            assert batch_size >= num_mini_batch, (
                "PPO requires the number of processes ({}) "
                "* number of steps ({}) = {} "
                "to be greater than or equal to the number of PPO mini batches ({})."
                "".format(n_rollout_threads, episode_length, n_rollout_threads * episode_length,
                          num_mini_batch))
            mini_batch_size = batch_size // num_mini_batch

        if sampler == None:
            rand = torch.randperm(batch_size).numpy()
            sampler = [rand[i*mini_batch_size:(i+1)*mini_batch_size] for i in range(num_mini_batch)]
        self.advg = advantages
        one_hot_actions = self.one_hot_actions.reshape(-1, *self.one_hot_actions.shape[2:])
        if self.args.env_name == "GoBigger":
            share_obs = flatten(self.share_obs[:-1])
            obs = flatten(self.obs[:-1])
        else:
            share_obs = self.share_obs[:-1].reshape(-1, *self.share_obs.shape[2:])
            obs = self.obs[:-1].reshape(-1, *self.obs.shape[2:])
        rnn_states = self.rnn_states[:-1].reshape(-1, *self.rnn_states.shape[2:])
        rnn_states_critic = self.rnn_states_critic[:-1].reshape(-1, *self.rnn_states_critic.shape[2:])
        actions = self.actions.reshape(-1, *self.actions.shape[2:])
        if self.available_actions is not None:
            available_actions = self.available_actions[:-1].reshape(-1, self.available_actions.shape[-1])
        if self.factor is not None:
            # factor = self.factor.reshape(-1,1)
            factor = self.factor.reshape(-1, self.factor.shape[-1])
        if self.action_grad is not None:
            # factor = self.factor.reshape(-1,1)
            action_grad = self.action_grad.reshape(-1, self.action_grad.shape[-1])
        value_preds = self.value_preds[:-1].reshape(-1, 1)
        returns = self.returns[:-1].reshape(-1, 1)
        ce_gaes = self.ce_gaes[:-1].reshape(-1, 1)
        masks = self.masks[:-1].reshape(-1, 1)
        active_masks = self.active_masks[:-1].reshape(-1, 1)
        action_log_probs = self.action_log_probs.reshape(-1, *self.action_log_probs.shape[2:])
        advantages = advantages.reshape(-1, 1)
        joint_actions = self.joint_actions.reshape(-1, *self.joint_actions.shape[2:])
        joint_action_log_probs = self.joint_action_log_probs.reshape(-1, *self.joint_action_log_probs.shape[2:])
        thresholds = self.thresholds.reshape(-1, *self.thresholds.shape[2:])
        bias = self.bias.reshape(-1, *self.bias.shape[2:])
        logits = self.logits.reshape(-1, *self.logits.shape[2:])

        for indices in sampler:
            # obs size [T+1 N Dim]-->[T N Dim]-->[T*N,Dim]-->[index,Dim]
            share_obs_batch = share_obs[indices]
            obs_batch = obs[indices]
            rnn_states_batch = rnn_states[indices]
            rnn_states_critic_batch = rnn_states_critic[indices]
            actions_batch = actions[indices]
            one_hot_actions_batch = one_hot_actions[indices]
            if self.available_actions is not None:
                available_actions_batch = available_actions[indices]
            else:
                available_actions_batch = None
            value_preds_batch = value_preds[indices]
            return_batch = returns[indices]
            ce_gaes_batch = ce_gaes[indices]
            masks_batch = masks[indices]
            active_masks_batch = active_masks[indices]
            old_action_log_probs_batch = action_log_probs[indices]
            joint_actions_batch = joint_actions[indices]
            joint_action_log_probs_batch = joint_action_log_probs[indices]
            thresholds_batch = thresholds[indices]
            bias_batch = bias[indices]
            logits_batch = logits[indices]
            if advantages is None:
                adv_targ = None
            else:
                adv_targ = advantages[indices]
            if self.factor is None:
                factor_batch = None
            else:
                factor_batch = factor[indices]
            if self.action_grad is None:
                action_grad_batch = None
            else:
                action_grad_batch = action_grad[indices]

            yield share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, \
                actions_batch, one_hot_actions_batch, value_preds_batch, return_batch, \
                masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, \
                available_actions_batch, factor_batch, action_grad_batch, joint_actions_batch, \
                joint_action_log_probs_batch, thresholds_batch, ce_gaes_batch, bias_batch, logits_batch

    def recurrent_generator(self, advantages, num_mini_batch, data_chunk_length, sampler=None):
        episode_length, n_rollout_threads = self.rewards.shape[0:2]
        batch_size = n_rollout_threads * episode_length
        data_chunks = batch_size // data_chunk_length  # [C=r*T/L]
        mini_batch_size = data_chunks // num_mini_batch

        assert episode_length * n_rollout_threads >= data_chunk_length, (
            "PPO requires the number of processes ({}) * episode length ({}) "
            "to be greater than or equal to the number of "
            "data chunk length ({}).".format(n_rollout_threads, episode_length, data_chunk_length))
        assert data_chunks >= 2, ("need larger batch size")
        self.advg = advantages
        if sampler == None:
            rand = torch.randperm(data_chunks).numpy()
            sampler = [rand[i*mini_batch_size:(i+1)*mini_batch_size] for i in range(num_mini_batch)]
        
        if len(self.share_obs.shape) > 3:
            share_obs = self.share_obs[:-1].transpose(1, 0, 2, 3, 4).reshape(-1, *self.share_obs.shape[2:])
            obs = self.obs[:-1].transpose(1, 0, 2, 3, 4).reshape(-1, *self.obs.shape[2:])
        else:
            share_obs = _cast(self.share_obs[:-1])
            obs = _cast(self.obs[:-1])

        one_hot_actions = self.one_hot_actions.transpose(1,0,2,3).reshape(-1, *self.one_hot_actions.shape[2:])
        actions = _cast(self.actions)
        action_log_probs = _cast(self.action_log_probs)
        advantages = _cast(advantages)
        if self.factor is not None:
            factor = _cast(self.factor)
        if self.action_grad is not None:
            action_grad = _cast(self.action_grad)
        value_preds = _cast(self.value_preds[:-1])
        returns = _cast(self.returns[:-1])
        ce_gaes = _cast(self.ce_gaes[:-1])
        masks = _cast(self.masks[:-1])
        active_masks = _cast(self.active_masks[:-1])
        rnn_states_joint = self.rnn_states_joint[:-1].transpose(1, 0, 2, 3).reshape(-1, *self.rnn_states_joint.shape[2:])
        rnn_states = self.rnn_states[:-1].transpose(1, 0, 2, 3).reshape(-1, *self.rnn_states.shape[2:])
        rnn_states_critic = self.rnn_states_critic[:-1].transpose(1, 0, 2, 3).reshape(-1, *self.rnn_states_critic.shape[2:])
        joint_actions = _cast(self.joint_actions)
        joint_action_log_probs = _cast(self.joint_action_log_probs)
        thresholds = self.thresholds.transpose(1, 0, 2, 3).reshape(-1, *self.thresholds.shape[2:])
        bias = _cast(self.bias)
        logits = _cast(self.logits)

        if self.available_actions is not None:
            available_actions = _cast(self.available_actions[:-1])

        for indices in sampler:
            share_obs_batch = []
            obs_batch = []
            rnn_states_batch = []
            rnn_states_critic_batch = []
            rnn_states_joint_batch = []
            actions_batch = []
            one_hot_actions_batch = []
            available_actions_batch = []
            value_preds_batch = []
            return_batch = []
            ce_gae_batch = []
            masks_batch = []
            active_masks_batch = []
            old_action_log_probs_batch = []
            adv_targ = []
            factor_batch = []
            action_grad_batch = []
            joint_actions_batch = []
            joint_action_log_probs_batch = []
            thresholds_batch = []
            bias_batch = []
            logits_batch = []

            for index in indices:
                ind = index * data_chunk_length
                # size [T+1 N M Dim]-->[T N Dim]-->[N T Dim]-->[T*N,Dim]-->[L,Dim]
                share_obs_batch.append(share_obs[ind:ind+data_chunk_length])
                obs_batch.append(obs[ind:ind+data_chunk_length])
                actions_batch.append(actions[ind:ind+data_chunk_length])
                one_hot_actions_batch.append(one_hot_actions[ind:ind+data_chunk_length])
                if self.available_actions is not None:
                    available_actions_batch.append(available_actions[ind:ind+data_chunk_length])
                if self.factor is None:
                    factor_batch = None
                else:
                    factor_batch.append(factor[ind:ind+data_chunk_length])
                if self.action_grad is None:
                    action_grad_batch = None
                else:
                    action_grad_batch.append(action_grad[ind:ind+data_chunk_length])
                value_preds_batch.append(value_preds[ind:ind+data_chunk_length])
                return_batch.append(returns[ind:ind+data_chunk_length])
                ce_gae_batch.append(ce_gaes[ind:ind+data_chunk_length])
                masks_batch.append(masks[ind:ind+data_chunk_length])
                active_masks_batch.append(active_masks[ind:ind+data_chunk_length])
                old_action_log_probs_batch.append(action_log_probs[ind:ind+data_chunk_length])
                adv_targ.append(advantages[ind:ind+data_chunk_length])
                # size [T+1 N Dim]-->[T N Dim]-->[T*N,Dim]-->[1,Dim]
                rnn_states_batch.append(rnn_states[ind])
                rnn_states_critic_batch.append(rnn_states_critic[ind])
                rnn_states_joint_batch.append(rnn_states_joint[ind])
                joint_actions_batch.append(joint_actions[ind:ind+data_chunk_length])
                joint_action_log_probs_batch.append(joint_action_log_probs[ind:ind+data_chunk_length])
                thresholds_batch.append(thresholds[ind:ind+data_chunk_length])
                bias_batch.append(bias[ind:ind+data_chunk_length])
                logits_batch.append(logits[ind:ind+data_chunk_length])

            L, N = data_chunk_length, mini_batch_size

            # These are all from_numpys of size (N, L, Dim)
            share_obs_batch = np.stack(share_obs_batch)
            obs_batch = np.stack(obs_batch)

            actions_batch = np.stack(actions_batch)
            one_hot_actions_batch = np.stack(one_hot_actions_batch)
            if self.available_actions is not None:
                available_actions_batch = np.stack(available_actions_batch)
            if self.factor is not None:
                factor_batch = np.stack(factor_batch)
            if self.action_grad is not None:
                action_grad_batch = np.stack(action_grad_batch)
            value_preds_batch = np.stack(value_preds_batch)
            return_batch = np.stack(return_batch)
            ce_gae_batch = np.stack(ce_gae_batch)
            masks_batch = np.stack(masks_batch)
            active_masks_batch = np.stack(active_masks_batch)
            old_action_log_probs_batch = np.stack(old_action_log_probs_batch)
            adv_targ = np.stack(adv_targ)
            joint_actions_batch = np.stack(joint_actions_batch)
            joint_action_log_probs_batch = np.stack(joint_action_log_probs_batch)
            thresholds_batch = np.stack(thresholds_batch)
            bias_batch = np.stack(bias_batch)
            logits_batch = np.stack(logits_batch)

            # States is just a (N, -1) from_numpy
            rnn_states_batch = np.stack(rnn_states_batch).reshape(N, *self.rnn_states.shape[2:])
            rnn_states_critic_batch = np.stack(rnn_states_critic_batch).reshape(N, *self.rnn_states_critic.shape[2:])
            rnn_states_joint_batch = np.stack(rnn_states_joint_batch).reshape(N, *self.rnn_states_joint.shape[2:])

            # Flatten the (L, N, ...) from_numpys to (L * N, ...)
            share_obs_batch = _flatten(L, N, share_obs_batch)
            obs_batch = _flatten(L, N, obs_batch)
            actions_batch = _flatten(L, N, actions_batch)
            one_hot_actions_batch = _flatten(L, N, one_hot_actions_batch)
            if self.available_actions is not None:
                available_actions_batch = _flatten(L, N, available_actions_batch)
            else:
                available_actions_batch = None
            if self.factor is not None:
                factor_batch = _flatten(L, N, factor_batch)
            else:
                factor_batch = None
            if self.action_grad is not None:
                action_grad_batch = _flatten(L, N, action_grad_batch)
            else:
                action_grad_batch = None
            value_preds_batch = _flatten(L, N, value_preds_batch)
            return_batch = _flatten(L, N, return_batch)
            ce_gae_batch = _flatten(L, N, ce_gae_batch)
            masks_batch = _flatten(L, N, masks_batch)
            active_masks_batch = _flatten(L, N, active_masks_batch)
            old_action_log_probs_batch = _flatten(L, N, old_action_log_probs_batch)
            adv_targ = _flatten(L, N, adv_targ)
            joint_actions_batch = _flatten(L, N, joint_actions_batch)
            joint_action_log_probs_batch = _flatten(L, N, joint_action_log_probs_batch)
            thresholds_batch = _flatten(L, N, thresholds_batch)
            bias_batch = _flatten(L, N, bias_batch)
            logits_batch = _flatten(L, N, logits_batch)

            yield share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
                one_hot_actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, \
                old_action_log_probs_batch, adv_targ, available_actions_batch, factor_batch, action_grad_batch, \
                joint_actions_batch, joint_action_log_probs_batch, rnn_states_joint_batch, thresholds_batch, ce_gae_batch, bias_batch, logits_batch

    def naive_recurrent_generator(self, advantages, num_mini_batch):
        n_rollout_threads = self.rewards.shape[1]
        assert n_rollout_threads >= num_mini_batch, (
            "PPO requires the number of processes ({}) "
            "to be greater than or equal to the number of "
            "PPO mini batches ({}).".format(n_rollout_threads, num_mini_batch))
        num_envs_per_batch = n_rollout_threads // num_mini_batch
        perm = torch.randperm(n_rollout_threads).numpy()
        for start_ind in range(0, n_rollout_threads, num_envs_per_batch):
            share_obs_batch = []
            obs_batch = []
            rnn_states_batch = []
            rnn_states_critic_batch = []
            actions_batch = []
            available_actions_batch = []
            value_preds_batch = []
            return_batch = []
            masks_batch = []
            active_masks_batch = []
            old_action_log_probs_batch = []
            adv_targ = []

            for offset in range(num_envs_per_batch):
                ind = perm[start_ind + offset]
                share_obs_batch.append(self.share_obs[:-1, ind])
                obs_batch.append(self.obs[:-1, ind])
                rnn_states_batch.append(self.rnn_states[0:1, ind])
                rnn_states_critic_batch.append(self.rnn_states_critic[0:1, ind])
                actions_batch.append(self.actions[:, ind])
                if self.available_actions is not None:
                    available_actions_batch.append(self.available_actions[:-1, ind])
                value_preds_batch.append(self.value_preds[:-1, ind])
                return_batch.append(self.returns[:-1, ind])
                masks_batch.append(self.masks[:-1, ind])
                active_masks_batch.append(self.active_masks[:-1, ind])
                old_action_log_probs_batch.append(self.action_log_probs[:, ind])
                adv_targ.append(advantages[:, ind])

            # [N[T, dim]]
            T, N = self.episode_length, num_envs_per_batch
            # These are all from_numpys of size (T, N, -1)
            share_obs_batch = np.stack(share_obs_batch, 1)
            obs_batch = np.stack(obs_batch, 1)
            actions_batch = np.stack(actions_batch, 1)
            if self.available_actions is not None:
                available_actions_batch = np.stack(available_actions_batch, 1)
            value_preds_batch = np.stack(value_preds_batch, 1)
            return_batch = np.stack(return_batch, 1)
            masks_batch = np.stack(masks_batch, 1)
            active_masks_batch = np.stack(active_masks_batch, 1)
            old_action_log_probs_batch = np.stack(old_action_log_probs_batch, 1)
            adv_targ = np.stack(adv_targ, 1)

            # States is just a (N, -1) from_numpy [N[1,dim]]
            rnn_states_batch = np.stack(rnn_states_batch, 1).reshape(N, *self.rnn_states.shape[2:])
            rnn_states_critic_batch = np.stack(rnn_states_critic_batch, 1).reshape(N, *self.rnn_states_critic.shape[2:])

            # Flatten the (T, N, ...) from_numpys to (T * N, ...)
            share_obs_batch = _flatten(T, N, share_obs_batch)
            obs_batch = _flatten(T, N, obs_batch)
            actions_batch = _flatten(T, N, actions_batch)
            if self.available_actions is not None:
                available_actions_batch = _flatten(T, N, available_actions_batch)
            else:
                available_actions_batch = None
            value_preds_batch = _flatten(T, N, value_preds_batch)
            return_batch = _flatten(T, N, return_batch)
            masks_batch = _flatten(T, N, masks_batch)
            active_masks_batch = _flatten(T, N, active_masks_batch)
            old_action_log_probs_batch = _flatten(T, N, old_action_log_probs_batch)
            adv_targ = _flatten(T, N, adv_targ)

            yield share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, available_actions_batch
