import time
from debug import debug_print, get_size
import numpy as np
import torch
import torch.nn as nn
from onpolicy.utils.util import get_gard_norm, huber_loss, mse_loss
from onpolicy.utils.valuenorm import ValueNorm
from onpolicy.algorithms.utils.util import check
import random
import math

def extract(x, t):
    """
    Pick one entry along the first axis of ``x`` that is not covered by ``t``.

    ``t`` holds integer indices and its shape must equal the leading axes of
    ``x``. For every leading position ``p`` the result holds ``x[p][t[p]]``:
    axis ``t.dim()`` of ``x`` is indexed by ``t`` and then removed.

    :param x: (torch.Tensor) source tensor with at least ``t.dim() + 1`` axes.
    :param t: (torch.Tensor) index tensor with ``t.shape == x.shape[:t.dim()]``;
        it is handed to ``torch.gather`` and therefore has to be an integer (int64) tensor.

    :return y: (torch.Tensor) tensor of shape ``t.shape + x.shape[t.dim() + 1:]``.
    """
    assert (x.shape[:len(t.shape)] == t.shape), (x.shape, t.shape)
    idx = len(t.shape)
    trailing = x.shape[idx + 1:]
    # Broadcast the indices over the trailing axes as a view; gather only reads
    # the index tensor, so no repeated copy has to be materialised.
    index = t.reshape(*t.shape, 1, *([1] * len(trailing))).expand(*t.shape, 1, *trailing)
    return torch.gather(x, dim=idx, index=index).squeeze(idx)

def _t2n(x):
    """Return ``x`` as a NumPy array.

    NumPy arrays are handed back untouched (same object); anything else is
    treated as a torch tensor and detached, moved to the CPU and converted.
    """
    return x if isinstance(x, np.ndarray) else x.detach().cpu().numpy()

def compute_returns(buffer, next_value, value_normalizer=None):
    """
    Write GAE-based returns and the bootstrap value back into ``buffer``.

    Only the GAE path is implemented: when ``buffer._use_gae`` is false the
    stored returns are kept as they are. In every case ``buffer.value_preds``
    gets ``next_value`` as its last entry and both ``buffer.returns`` and
    ``buffer.value_preds`` are replaced by new arrays of the same shape.

    :param buffer: rollout storage; reads ``rewards``, ``value_preds``, ``masks``, ``bad_masks``,
        ``returns``, ``gamma``, ``gae_lambda``, ``episode_length``, ``n_rollout_threads``,
        ``rnum_agents`` and the ``_use_*`` flags.
    :param next_value: (np.ndarray) value predictions for the step after the last episode step.
    :param value_normalizer: object with a ``denormalize`` method; required when
        ``buffer._use_popart`` or ``buffer._use_valuenorm`` is set.
    """
    rewards = np.concatenate([buffer.rewards], axis=-1)
    masks = np.concatenate([buffer.masks], axis=-1)
    bad_masks = np.concatenate([buffer.bad_masks], axis=-1)
    returns = buffer.returns.copy()
    value_preds = buffer.value_preds.copy()
    value_preds[-1] = next_value

    proper_time_limits = buffer._use_proper_time_limits
    if buffer._use_gae:
        normalized = buffer._use_popart or buffer._use_valuenorm
        gamma = buffer.gamma
        decay = gamma * buffer.gae_lambda

        # Values the TD errors and the returns are expressed in.
        if normalized:
            baseline = value_normalizer.denormalize(value_preds)
            if not proper_time_limits:
                # Without proper time limits only the first entry of the last
                # axis is used (kept as a size-1 axis).
                baseline = baseline[..., 0, None]
        else:
            baseline = value_preds

        delta = rewards + gamma * baseline[1:] * masks[1:] - baseline[:-1]

        gae = 0
        for step in range(rewards.shape[0] - 1, -1, -1):
            not_done = masks[step + 1]
            # The operand order differs between the two paths on purpose so the
            # floating point result matches each original formulation exactly.
            if normalized:
                gae = delta[step] + decay * gae * not_done
            else:
                gae = delta[step] + decay * not_done * gae
            gae *= bad_masks[step + 1]
            returns[step] = gae + baseline[step]

    layout = (buffer.episode_length + 1, 1, buffer.n_rollout_threads, buffer.rnum_agents)
    returns = returns.reshape(*layout).transpose(0, 2, 3, 1)
    value_preds = value_preds.reshape(*layout).transpose(0, 2, 3, 1)

    assert (returns.shape == buffer.returns.shape), (returns.shape, buffer.returns.shape)
    assert (value_preds.shape == buffer.value_preds.shape), (value_preds.shape, buffer.value_preds.shape)

    buffer.returns = returns
    buffer.value_preds = value_preds

class DiffusionMAPPO():
    def __init__(self, 
                 args,
                 policy,
                 device=torch.device("cpu")):


        self.device = device
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.policy = policy

        
        self.algorithm_name = args.algorithm_name
        self.bc_buffer_limit = args.bc_buffer_limit
        self.clip_param = args.clip_param
        self.ppo_epoch = args.ppo_epoch
        self.bc_epoch = args.bc_epoch
        self.num_mini_batch = args.num_mini_batch
        self.data_chunk_length = args.data_chunk_length
        self.value_loss_coef = args.value_loss_coef
        self.entropy_coef = args.entropy_coef
        self.max_grad_norm = args.max_grad_norm       
        self.huber_delta = args.huber_delta
        self.n_timesteps = args.n_timesteps
        self.t_dim = args.t_dim
        self.value_dim = args.value_dim
        self.bc_loss_coef = args.bc_loss_coef
        self.ppo_loss_coef = args.ppo_loss_coef
        self.aug_latent_actions = args.aug_latent_actions
        self.recompute_adv = args.recompute_adv or args.aug_latent_actions
        self.sep_bc_phase = args.sep_bc_phase
        self.n_rollout_threads = args.n_rollout_threads
        self.rnum_agents = args.rnum_agents
        self.num_agents = args.num_agents
        self.joint_train = args.joint_train
        self.use_symmetry = args.use_symmetry
        self.use_attention = args.use_attention
        self.no_rand_train = args.no_rand_train
        self.norm_reward = args.norm_reward
        self.logit_scaling = args.logit_scaling
        self.use_latent_prob = args.use_latent_prob

        self._use_recurrent_policy = args.use_recurrent_policy
        self._use_naive_recurrent = args.use_naive_recurrent_policy
        self._use_max_grad_norm = args.use_max_grad_norm
        self._use_clipped_value_loss = args.use_clipped_value_loss
        self._use_huber_loss = args.use_huber_loss
        self._use_popart = args.use_popart
        self._use_valuenorm = args.use_valuenorm
        self._use_value_active_masks = False #args.use_value_active_masks
        self._use_policy_active_masks = args.use_policy_active_masks
        self.negative_sample_scale = args.negative_sample_scale
        self.critic_epoch = args.critic_epoch
        self.normalize_advantage = args.normalize_advantage
        self.sep_logprob = args.sep_logprob
        self.act_step = args.act_step
        self.normalize_advantage_mean = args.normalize_advantage_mean
        self.args = args

        assert (self._use_popart and self._use_valuenorm) == False, ("self._use_popart and self._use_valuenorm can not be set True simultaneously")
        
        if self._use_popart:
            self.value_normalizer = self.policy.critic.v_out
        elif self._use_valuenorm:
            self.value_normalizer = ValueNorm(args.rnum_agents, device=self.device).to(self.device)
        else:
            self.value_normalizer = None

        if self.sep_bc_phase:
            # use seperate optimizer and model when using seperate behavior cloning phase
            self.bc_actor_state_dict = self.policy.actor.state_dict()
            self.bc_optimizer_state_dict = self.policy.actor_optimizer.state_dict()
            self.rl_actor_state_dict = self.policy.actor.state_dict()
            self.rl_optimizer_state_dict = self.policy.actor_optimizer.state_dict()

    def cal_value_loss(self, values, value_preds_batch, return_batch, active_masks_batch):
        """
        Compute the critic loss.

        With PopArt / ValueNorm enabled the normalizer statistics are first updated with
        ``return_batch`` and the errors are taken against the normalized returns.
        :param values: (torch.Tensor) value function predictions.
        :param value_preds_batch: (torch.Tensor) "old" value predictions from the data batch (anchor of the clipped loss).
        :param return_batch: (torch.Tensor) reward-to-go returns.
        :param active_masks_batch: (torch.Tensor) weights used only when ``_use_value_active_masks`` is set.

        :return value_loss: (torch.Tensor) scalar value function loss.
        """
        clip = self.clip_param
        # Prediction that may move at most ``clip`` away from the old one.
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-clip, clip)

        if self._use_popart or self._use_valuenorm:
            normalizer = self.value_normalizer
            normalizer.update(return_batch)
            error_clipped = normalizer.normalize(return_batch) - value_pred_clipped
            error_original = normalizer.normalize(return_batch) - values
        else:
            error_clipped = return_batch - value_pred_clipped
            error_original = return_batch - values

        if self._use_huber_loss:
            delta = self.huber_delta
            value_loss_clipped = huber_loss(error_clipped, delta)
            value_loss_original = huber_loss(error_original, delta)
        else:
            value_loss_clipped = mse_loss(error_clipped)
            value_loss_original = mse_loss(error_original)

        # Element-wise worst case of the two losses when clipping is enabled.
        value_loss = (
            torch.max(value_loss_original, value_loss_clipped)
            if self._use_clipped_value_loss
            else value_loss_original
        )

        if not self._use_value_active_masks:
            return value_loss.mean()
        return (value_loss * active_masks_batch).sum() / active_masks_batch.sum()

    def compute_entropy_loss(self, dist_entropy):
        return dist_entropy * self.entropy_coef
    
    def compute_extra_actor_loss(self):
        """Return the additional actor loss and its logging dict: here always ``0`` and an empty dict."""
        return 0, {}

    def bc_update(self, sample, progress=0.0):
        """
        Run one behaviour-cloning gradient step on the actor.

        The target is the last entry along the second-to-last axis of
        ``latent_actions_batch``; the loss itself comes from ``self.policy.actor.bc_loss``.
        :param sample: (Tuple) data batch in the same layout ``ppo_update`` unpacks: 16 leading
            entries with ``use_symmetry`` (extra ``obs_sym_batch`` after ``obs_batch``), 15 without.
            Entries after those (such as the noise batch ``ppo_update`` reads) are accepted and ignored.
        :param progress: (float) unused.

        :return bc_loss: (torch.Tensor) behaviour-cloning loss that was back-propagated.
        :return actor_grad_norm: gradient norm of the actor parameters (clipped when ``_use_max_grad_norm``).
        """
        # Only six entries are needed; the others are unpacked into throw-away names.
        # The starred tail lets the batch tuple used by ppo_update be passed unchanged.
        if self.use_symmetry:
            (_share_obs, obs_batch, _obs_sym, rnn_states_batch, _rnn_states_critic, _actions,
             latent_actions_batch, _sampled_actions, _value_preds, _returns, masks_batch,
             active_masks_batch, _old_log_probs, _log_probs_last, _adv_targ,
             available_actions_batch, *_extra) = sample
        else:
            (_share_obs, obs_batch, rnn_states_batch, _rnn_states_critic, _actions,
             latent_actions_batch, _sampled_actions, _value_preds, _returns, masks_batch,
             active_masks_batch, _old_log_probs, _log_probs_last, _adv_targ,
             available_actions_batch, *_extra) = sample

        latent_actions_batch = check(latent_actions_batch).to(**self.tpdv)
        active_masks_batch = check(active_masks_batch).to(**self.tpdv)
        latent_actions_0 = latent_actions_batch[..., -1, :]

        bc_loss = self.policy.actor.bc_loss(obs_batch,
                                            rnn_states_batch,
                                            latent_actions_0,
                                            masks_batch,
                                            available_actions_batch,
                                            active_masks_batch)

        self.policy.actor_optimizer.zero_grad()
        bc_loss.backward()

        if self._use_max_grad_norm:
            actor_grad_norm = nn.utils.clip_grad_norm_(self.policy.actor.parameters(), self.max_grad_norm)
        else:
            actor_grad_norm = get_gard_norm(self.policy.actor.parameters())

        self.policy.actor_optimizer.step()

        return bc_loss, actor_grad_norm
    
    def bc_loss(self, sample):
        """
        Compute the behaviour-cloning loss for a cloning batch, scaled by ``self.logit_scaling``.

        Side effect: the gradients held by ``self.policy.clone_actor_optimizer`` are zeroed,
        so the caller can call ``backward()`` on the result right away.
        :param sample: (Tuple) ``(obs, actions, masks, active_masks, available_actions)``;
            ``active_masks`` may be None. No recurrent state is passed to the actor.

        :return bc_loss: loss returned by ``self.policy.actor.bc_loss``, multiplied in place by ``logit_scaling``.
        """
        obs, target_actions, masks, active_masks, available_actions = sample

        target_actions = check(target_actions).to(**self.tpdv)
        if active_masks is not None:
            active_masks = check(active_masks).to(**self.tpdv)

        loss = self.policy.actor.bc_loss(
            obs,
            None,
            target_actions,
            masks,
            available_actions,
            active_masks,
        )

        self.policy.clone_actor_optimizer.zero_grad()

        loss *= self.logit_scaling
        return loss
        
    def bc_clone(self, sample):
        """
        Run one cloning step: compute ``self.bc_loss(sample)``, back-propagate it and
        step ``self.policy.clone_actor_optimizer``.
        :param sample: (Tuple) cloning batch, forwarded unchanged to ``bc_loss``.

        :return bc_loss: the loss that was back-propagated.
        :return actor_grad_norm: gradient norm of the actor parameters (clipped when ``_use_max_grad_norm``).
        """
        loss = self.bc_loss(sample)
        loss.backward()

        actor_params = self.policy.actor.parameters()
        if self._use_max_grad_norm:
            grad_norm = nn.utils.clip_grad_norm_(actor_params, self.max_grad_norm)
        else:
            grad_norm = get_gard_norm(actor_params)

        self.policy.clone_actor_optimizer.step()
        return loss, grad_norm
    

    def entropy_coef_scheduler(self, progress):
        """
        Return the multiplier applied to the entropy bonus at training ``progress``.

        Scheduling is switched off: the multiplier is the constant ``1``. The
        progress-dependent schedule that used to follow the early return could
        never execute and has been dropped.
        :param progress: (float) training progress; currently ignored.

        :return: (int) always ``1``.
        """
        return 1
    

    # Storage filled by insert_bc_data: the inserted samples and, at the same
    # index, the reward recorded with each of them.
    # NOTE(review): these are class attributes, so the two lists are shared by
    # every instance that does not rebind them -- confirm the sharing is intended.
    bc_loss_buffer = []
    bc_reward_info = []
    def insert_bc_data(self, sample, progress=0.0, reward = 0):
        """
        Append ``sample`` and its ``reward`` to the behaviour-cloning storage.

        When the storage then holds more than ``self.bc_buffer_limit`` samples, the
        oldest sample and its reward are dropped (first in, first out).
        :param sample: object stored as-is in ``self.bc_loss_buffer``.
        :param progress: (float) unused.
        :param reward: value stored at the same position in ``self.bc_reward_info``.
        """
        samples = self.bc_loss_buffer
        samples.append(sample)
        rewards = self.bc_reward_info
        rewards.append(reward)

        if len(samples) > self.bc_buffer_limit:
            # Evict index 0 from both lists so they stay aligned.
            del samples[0]
            del rewards[0]

    def ppo_update(self, sample, update_actor=True, update_critic=True, progress=0.0, joint_ppo=False):
        """
        Update actor and critic networks.
        :param sample: (Tuple) contains data batch with which to update networks. With
            ``use_symmetry`` the tuple carries an extra ``obs_sym_batch`` after ``obs_batch``.
            The noise batch is the last entry; it is required without symmetry and
            optional with symmetry (``None`` is forwarded when it is absent).
        :param update_actor: (bool) whether to step the actor optimizer (gradients are computed either way).
        :param update_critic: (bool) whether to step the critic optimizer (gradients are computed either way).
        :param progress: (float) training progress, forwarded to ``entropy_coef_scheduler``.
        :param joint_ppo: (bool) ignored; ``self.joint_train`` is what gets passed on to the policy.

        :return value_loss: (torch.Tensor) value function loss.
        :return critic_grad_norm: (torch.Tensor) gradient norm from critic update.
        :return policy_action_loss: (torch.Tensor) clipped surrogate loss (without the BC term).
        :return bc_loss: (torch.Tensor) behaviour-cloning loss (zeros when it is not computed).
        :return dist_entropy: (torch.Tensor) action entropies.
        :return actor_grad_norm: (torch.Tensor) gradient norm from actor update.
        :return imp_weights: (torch.Tensor) importance sampling weights.
        :return extra_loss_dict: (dict) extras reported by ``compute_extra_actor_loss``.
        """
        if self.use_symmetry:
            (share_obs_batch, obs_batch, obs_sym_batch, rnn_states_batch, rnn_states_critic_batch,
             actions_batch, latent_actions_batch, sampled_actions_batch,
             value_preds_batch, return_batch, masks_batch, active_masks_batch,
             old_action_log_probs_batch, action_log_probs_last_batch,
             adv_targ, available_actions_batch, *trailing) = sample
            # This layout used to leave ``noises_batch`` unbound, so the cast below
            # always failed with symmetry enabled. Use the trailing entry if present.
            noises_batch = trailing[0] if trailing else None
        else:
            (share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch,
             actions_batch, latent_actions_batch, sampled_actions_batch,
             value_preds_batch, return_batch, masks_batch, active_masks_batch,
             old_action_log_probs_batch, action_log_probs_last_batch,
             adv_targ, available_actions_batch, noises_batch) = sample

        # Agent count as found in the batch, read before the log-probs are reduced below.
        # NOTE(review): ``self.rnum_agents`` is used for the repeats further down --
        # confirm both always agree.
        rnum_agents = action_log_probs_last_batch.shape[1]
        rnn_states_batch = check(rnn_states_batch).to(**self.tpdv)
        rnn_states_critic_batch = check(rnn_states_critic_batch).to(**self.tpdv)
        latent_actions_batch = check(latent_actions_batch).to(**self.tpdv)
        if noises_batch is not None:
            noises_batch = check(noises_batch).to(**self.tpdv)
        sampled_actions_batch = check(sampled_actions_batch).to(**self.tpdv)
        if available_actions_batch is not None:
            available_actions_batch = check(available_actions_batch).to(**self.tpdv)
        old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)
        # Kept separately: the name ``old_action_log_probs_batch`` is rebound below.
        old_latent_probs = old_action_log_probs_batch.clone()
        action_log_probs_last_batch = check(action_log_probs_last_batch).to(**self.tpdv)
        adv_targ = check(adv_targ).to(**self.tpdv)

        value_preds_batch = check(value_preds_batch).to(**self.tpdv)
        return_batch = check(return_batch).to(**self.tpdv)
        masks_batch = check(masks_batch).to(**self.tpdv)
        active_masks_batch = check(active_masks_batch).to(**self.tpdv)

        # The sampling range holds the single value n_timesteps - 1, so after the
        # flip further down every entry of action_ts is 0.
        action_ts = torch.randint(self.n_timesteps - 1, self.n_timesteps, latent_actions_batch.shape[:-2]).to(**self.tpdv).long()

        # The argument is deliberately not consulted; the configured flag wins.
        joint_ppo = self.joint_train

        # Reduce the stored final-step log-probs over agents, weighted by the active masks.
        if self.sep_logprob:
            action_log_probs_last_batch = (action_log_probs_last_batch * active_masks_batch).sum(dim=-2).unsqueeze(-1)
        else:
            action_log_probs_last_batch = (action_log_probs_last_batch * active_masks_batch[:, :, 0]).sum(dim=-1, keepdim=True)

        sampled_actions = sampled_actions_batch
        # One randomly chosen agent index per batch row.
        agent_idx = torch.randint(0, rnum_agents, (latent_actions_batch.shape[0],)).to(**self.tpdv).long()
        old_action_log_probs_batch = action_log_probs_last_batch
        adv_targ = adv_targ[:, :, 0]
        value_preds_batch = value_preds_batch[:, 0]
        return_batch = return_batch[:, 0]

        # Take column 0 and copy it across ``self.rnum_agents`` columns.
        return_batch = return_batch[:, 0, None].repeat(1, self.rnum_agents)
        value_preds_batch = value_preds_batch[:, 0, None].repeat(1, self.rnum_agents)
        adv_targ = adv_targ[:, 0, None].repeat(1, self.rnum_agents)

        # Keep the full masks for the policy; the loss uses the sampled agent's entry.
        r_active_masks_batch = active_masks_batch.clone()
        active_masks_batch = extract(active_masks_batch, agent_idx)
        adv_targ = extract(adv_targ, agent_idx).unsqueeze(-1)

        action_ts = self.n_timesteps - 1 - action_ts

        values, action_log_probs, latent_probs, dist_entropy = self.policy.evaluate_actions(share_obs_batch,
                                                                              obs_batch,
                                                                              rnn_states_batch,
                                                                              rnn_states_critic_batch,
                                                                              sampled_actions,
                                                                              action_ts,
                                                                              agent_idx,
                                                                              masks_batch,
                                                                              available_actions_batch,
                                                                              active_masks_batch, joint_ppo=joint_ppo, noises=noises_batch, r_active_masks_batch=r_active_masks_batch, latent_actions=latent_actions_batch)

        if not (self.sep_bc_phase or self.args.use_mlp):
            # Auxiliary behaviour-cloning term towards the last stored latent action.
            latent_actions_batch = check(latent_actions_batch).to(**self.tpdv)
            active_masks_batch = check(active_masks_batch).to(**self.tpdv)
            latent_actions_0 = latent_actions_batch[..., -1, :]

            bc_loss = self.policy.actor.bc_loss(obs_batch,
                                                rnn_states_batch,
                                                latent_actions_0,
                                                masks_batch,
                                                available_actions_batch,
                                                active_masks_batch)
        else:
            bc_loss = torch.zeros(1).to(**self.tpdv)

        # actor update
        if self.sep_logprob:
            adv_targ = adv_targ.unsqueeze(-1)
            active_masks_batch = active_masks_batch.unsqueeze(-1)

        if self.use_latent_prob:
            # Ratio on the latent log-probs; the new ones are tiled ``act_step`` times
            # before being compared with the stored ones.
            latent_probs = torch.tile(latent_probs[:, None, :, 0], (1, self.act_step, 1)).reshape(-1, latent_probs.shape[-2], 1)
            imp_weights = torch.exp((latent_probs - old_latent_probs))
            adv_targ = adv_targ[:, None, :]
            active_masks_batch = active_masks_batch[:, None, :]
        else:
            imp_weights = torch.exp((action_log_probs - old_action_log_probs_batch))

        # Clipped surrogate objective.
        surr1 = imp_weights * adv_targ
        surr2 = torch.clamp(imp_weights, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ

        if self._use_policy_active_masks:
            policy_action_loss = (-torch.sum(torch.min(surr1, surr2),
                                             dim=-1,
                                             keepdim=True) * active_masks_batch).sum() / active_masks_batch.sum()
        else:
            policy_action_loss = -torch.sum(torch.min(surr1, surr2), dim=-1, keepdim=True).mean()

        policy_loss = policy_action_loss + bc_loss * self.bc_loss_coef

        self.policy.actor_optimizer.zero_grad()

        extra_actor_loss, extra_loss_dict = self.compute_extra_actor_loss()

        (policy_loss - self.compute_entropy_loss(dist_entropy) * self.entropy_coef_scheduler(progress)
        + extra_actor_loss).backward()

        actor_grad_norm = torch.tensor(0.0)

        if update_actor:
            if self._use_max_grad_norm:
                actor_grad_norm = nn.utils.clip_grad_norm_(self.policy.actor.parameters(), self.max_grad_norm)
            else:
                actor_grad_norm = get_gard_norm(self.policy.actor.parameters())
            self.policy.actor_optimizer.step()

        # critic update
        critic_grad_norm = torch.tensor(0.0)
        if self.no_rand_train:
            active_masks_batch = active_masks_batch.reshape(self.rnum_agents, -1).permute(1, 0)
        value_loss = self.cal_value_loss(values, value_preds_batch, return_batch, active_masks_batch)
        self.policy.critic_optimizer.zero_grad()
        (value_loss * self.value_loss_coef).backward()
        if update_critic:
            if self._use_max_grad_norm:
                critic_grad_norm = nn.utils.clip_grad_norm_(self.policy.critic.parameters(), self.max_grad_norm)
            else:
                critic_grad_norm = get_gard_norm(self.policy.critic.parameters())

            self.policy.critic_optimizer.step()

        return value_loss, critic_grad_norm, policy_action_loss, bc_loss, dist_entropy, actor_grad_norm, imp_weights, extra_loss_dict

    def compute_returns(self, buffer):
        """
        Refresh the buffer's value predictions with the current critic and recompute returns.

        The critic is run on ``buffer.share_obs`` / ``buffer.rnn_states`` / ``buffer.masks``;
        index 0 along axis 1 of its output is kept, repeated ``self.rnum_agents`` times
        along axis 2 and stored as ``buffer.value_preds``. The module-level
        ``compute_returns`` then fills in the buffer's returns, using the last entry of
        ``buffer.value_preds`` as the bootstrap value. When ``self.norm_reward`` is set,
        the returns are additionally fed to ``self.policy.ret_rms`` and replaced by their
        normalized version.

        :param buffer: replay buffer whose ``value_preds`` and ``returns`` are overwritten in place.
        """
        # The critic output is converted to numpy right away (``_t2n`` detaches it),
        # so there is no reason to record an autograd graph for the whole buffer.
        with torch.no_grad():
            critic_values, _ = self.policy.critic(buffer.share_obs, buffer.rnn_states, None, buffer.masks)
        buffer.value_preds = _t2n(critic_values[:, 0, :, None, None]).repeat(self.rnum_agents, axis=2)
        next_values = buffer.value_preds[-1]

        compute_returns(buffer, next_values, self.value_normalizer)
        if self.norm_reward:
            returns = buffer.returns
            shape = returns.shape
            # The running statistics operate on returns with the first two axes merged.
            flat_returns = returns.reshape(-1, *shape[2:])
            self.policy.ret_rms.update(flat_returns)
            returns = self.policy.ret_rms.norm(flat_returns, retain_mean=True).reshape(*shape)
            buffer.returns = _t2n(returns)

    @torch.no_grad()
    def augment_latent_actions(self, buffer):
        """
        Re-noise the stored final latent actions and re-evaluate them with the current policy.

        Starting from the last entry along the second-to-last axis of
        ``buffer.latent_actions``, ``self.n_timesteps`` successively noised copies are
        built with ``x <- sqrt(1 - beta_i) * x + sqrt(beta_i) * eps`` using the actor's
        diffusion ``betas``. The copies are stacked most-noised first. For every step
        ``t`` of the episode the policy's ``evaluate_actions`` is called once and the
        results overwrite ``buffer.value_preds[t]``, ``buffer.action_log_probs[t]`` and
        ``buffer.latent_actions[t]``.

        :param buffer: replay buffer that is read and updated in place.
        """
        num_steps = self.n_timesteps
        betas = self.policy.actor.diffusion.betas.clone().cpu().numpy()

        # Noising chain: element 0 is the stored final latent action, each later
        # element adds one more noising step on top of the previous one.
        chain = [buffer.latent_actions[..., -1, :]]
        for step in range(num_steps):
            beta = betas[step]
            previous = chain[-1]
            noise = np.random.randn(*previous.shape)
            chain.append(np.sqrt(1 - beta) * previous + np.sqrt(beta) * noise)
        chain.reverse()
        noised_actions = np.stack(chain, axis=-2)

        def tile_over_steps(per_step_data):
            # Merge the two leading axes, then replicate along a new axis 1.
            return np.concatenate(per_step_data)[:, np.newaxis].repeat(num_steps, axis=1)

        # Timestep indices in descending order: [num_steps - 1, ..., 0].
        descending_steps = torch.arange(num_steps).flip(0).reshape(1, -1).numpy()

        for t in range(buffer.episode_length):
            flat_count = np.concatenate(noised_actions[t]).shape[0]
            action_ts = descending_steps.repeat(flat_count, axis=0)

            values, action_log_probs, _ = self.policy.evaluate_actions(
                np.concatenate(tile_over_steps(buffer.share_obs[t])),
                np.concatenate(tile_over_steps(buffer.obs[t])),
                np.concatenate(tile_over_steps(buffer.rnn_states[t])),
                np.concatenate(tile_over_steps(buffer.rnn_states_critic[t])),
                np.concatenate(tile_over_steps(buffer.sampled_actions[t])),
                np.concatenate(action_ts),
                np.concatenate(tile_over_steps(buffer.masks[t])),
                None,  # available_actions
                np.concatenate(tile_over_steps(buffer.active_masks[t])),
            )

            per_agent_shape = (buffer.n_rollout_threads, buffer.num_agents, num_steps)
            buffer.value_preds[t] = values.reshape(*per_agent_shape).detach().cpu().numpy()
            buffer.action_log_probs[t] = action_log_probs.reshape(*per_agent_shape, 1).detach().cpu().numpy()
            buffer.latent_actions[t] = noised_actions[t]
        
    def compute_advantages(self, buffer):
        """
        Compute advantages (returns minus value predictions) and rescale or clip them.

        Value predictions are denormalized first when PopArt / value normalization is
        enabled. Statistics are taken only over entries whose tiled active mask is
        non-zero. With ``self.normalize_advantage`` the advantages are divided by the
        masked standard deviation (and centred on the masked mean when
        ``self.normalize_advantage_mean`` is set); otherwise they are clipped to the
        masked 5% / 95% quantiles.

        :param buffer: replay buffer providing ``returns``, ``value_preds`` and ``active_masks``.

        :return advantages: (np.ndarray) processed advantages for all but the last buffer entry.
        """
        value_preds = buffer.value_preds[:-1]
        if self._use_popart or self._use_valuenorm:
            value_preds = self.value_normalizer.denormalize(value_preds[..., np.newaxis])[..., 0]
        advantages = buffer.returns[:-1] - value_preds

        # Inactive entries are set to NaN in a copy so the nan-aware statistics
        # below ignore them; the returned advantages keep every entry.
        inactive = np.tile(buffer.active_masks[:-1], (1, 1, self.num_agents, 1)) == 0.
        masked = advantages.copy()
        masked[inactive] = np.nan

        if not self.normalize_advantage:
            lower = np.nanquantile(masked, 0.05)
            upper = np.nanquantile(masked, 0.95)
            return np.clip(advantages, lower, upper)

        center = np.nanmean(masked)
        spread = np.nanstd(masked) + 1e-5
        if self.normalize_advantage_mean:
            return (advantages - center) / spread
        return advantages / spread

    def train(self, buffer, update_actor=True, update_critic=True, progress=0.0):
        """
        Perform a training update using minibatch GD.
        :param buffer: (SharedReplayBuffer) buffer containing training data.
        :param update_actor: (bool) whether to update actor network.
        :param update_critic: (bool) whether to update critic network.
        :param progress: (float) passed through unchanged to `bc_update` and `ppo_update`.

        :return train_info: (dict) contains information regarding training update (e.g. loss, grad norms, etc).
        """
        def make_data_generator(adv):
            # Only the naive-recurrent setting has its own generator; every other
            # configuration (including feed-forward) uses the recurrent generator.
            if self._use_recurrent_policy or not self._use_naive_recurrent:
                return buffer.recurrent_generator(adv, self.num_mini_batch, self.data_chunk_length)
            return buffer.naive_recurrent_generator(adv, self.num_mini_batch)

        buffer.scale_reward()
        self.prep_rollout()
        self.compute_returns(buffer)
        self.prep_training()

        advantages = self.compute_advantages(buffer)

        train_info = {}

        train_info['value_loss'] = []
        train_info['policy_loss'] = []
        train_info['bc_loss'] = []
        train_info['dist_entropy'] = []
        train_info['denorm_value'] = 0
        if self.sep_bc_phase:
            train_info['bc_actor_grad_norm'] = []
        train_info['actor_grad_norm'] = []
        train_info['critic_grad_norm'] = []
        train_info['ratio'] = []

        # Pre-generate one augmented (value_preds, action_log_probs, latent_actions)
        # snapshot per PPO epoch while the networks are in eval mode.
        buffer_replace_data = []
        if self.aug_latent_actions:
            self.prep_rollout()
            for _ in range(self.ppo_epoch):
                self.augment_latent_actions(buffer)
                buffer_replace_data.append((
                    buffer.value_preds.copy(),
                    buffer.action_log_probs.copy(),
                    buffer.latent_actions.copy(),
                ))
            self.prep_training()

        for epoch in range(self.ppo_epoch + self.critic_epoch):
            # Only `ppo_epoch` snapshots exist; the extra critic-only epochs find
            # the list empty and keep whatever the buffer currently holds.
            if buffer_replace_data:
                value_preds, action_log_probs, latent_actions = buffer_replace_data.pop()
                buffer.value_preds[:] = value_preds
                buffer.action_log_probs[:] = action_log_probs
                buffer.latent_actions[:] = latent_actions

            if self.recompute_adv:
                self.prep_rollout()
                self.compute_returns(buffer)
                self.prep_training()
                advantages = self.compute_advantages(buffer)

            # The actor is only stepped during the first `ppo_epoch` epochs.
            step_actor = update_actor and epoch < self.ppo_epoch

            for sample in make_data_generator(advantages):
                if not self.sep_bc_phase:
                    # Interleaved BC updates; their return values are superseded
                    # by the ones `ppo_update` reports below.
                    for _ in range(self.bc_epoch):
                        self.bc_update(sample, progress)

                value_loss, critic_grad_norm, policy_loss, bc_loss, dist_entropy, actor_grad_norm, imp_weights, extra_loss_dict \
                    = self.ppo_update(sample, step_actor, update_critic, progress)

                train_info['value_loss'].append(value_loss.item())
                train_info['policy_loss'].append(policy_loss.item())
                if not self.sep_bc_phase:
                    train_info['bc_loss'].append(bc_loss.item())
                train_info['dist_entropy'].append(dist_entropy.item())
                train_info['actor_grad_norm'].append(actor_grad_norm.item())
                train_info['critic_grad_norm'].append(critic_grad_norm.item())
                train_info['ratio'].append(imp_weights.mean().item())

        # Only the value after the last minibatch is ever reported, so compute it
        # once here instead of over the whole buffer on every minibatch.
        if train_info['value_loss']:
            if self._use_popart or self._use_valuenorm:
                train_info['denorm_value'] = self.value_normalizer.denormalize(buffer.value_preds[:-1, :, :, 0]).mean()
            else:
                train_info['denorm_value'] = buffer.value_preds[:-1].mean()

        if self.sep_bc_phase:
            # save latest rl actor & optimizer
            # NOTE(review): Module.state_dict() holds references to the live
            # parameter tensors, not copies, so this snapshot follows later
            # in-place updates of the actor - confirm whether a deepcopy is intended.
            self.rl_actor_state_dict = self.policy.actor.state_dict()
            self.rl_optimizer_state_dict = self.policy.actor_optimizer.state_dict()

            # load bc actor & optimizer
            self.policy.actor.load_state_dict(self.bc_actor_state_dict)
            self.policy.actor_optimizer.load_state_dict(self.bc_optimizer_state_dict)

            for bc_step in range(self.bc_epoch * 10):
                # Walk `bc_loss_buffer` from its newest entry backwards, wrapping around.
                buffer_len = len(self.bc_loss_buffer)
                bc_sample = self.bc_loss_buffer[buffer_len - 1 - bc_step % buffer_len]
                bc_loss, actor_grad_norm = self.bc_update(bc_sample, progress)

                train_info['bc_loss'].append(bc_loss.item())
                train_info['bc_actor_grad_norm'].append(actor_grad_norm.item())
                # Stop early once the loss is small, but never before `bc_epoch` steps.
                if bc_loss < 0.1 and bc_step > self.bc_epoch:
                    break

            # save latest bc actor & optimizer
            self.bc_actor_state_dict = self.policy.actor.state_dict()
            self.bc_optimizer_state_dict = self.policy.actor_optimizer.state_dict()

            # load latest rl optimizer
            self.policy.actor_optimizer.load_state_dict(self.rl_optimizer_state_dict)

        # Collapse every per-minibatch list to its mean (an empty list yields nan).
        return {key: np.mean(value) for key, value in train_info.items()}

    def prep_training(self):
        """Switch the policy's actor and critic networks to training mode."""
        for network in (self.policy.actor, self.policy.critic):
            network.train()

    def prep_rollout(self):
        """Switch the policy's actor and critic networks to evaluation mode."""
        for network in (self.policy.actor, self.policy.critic):
            network.eval()
