import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from onpolicy.utils.util import get_gard_norm, huber_loss, mse_loss
from onpolicy.utils.popart import PopArt
from onpolicy.algorithms.utils.util import check
from onpolicy.algorithms.utils.distance import D_dict
from torch.autograd import Variable
from torch.nn.utils import clip_grad_norm
import pickle
import copy
class R_MAPPO():
    """
    Trainer class for MAPPO to update policies.
    :param args: (argparse.Namespace) arguments containing relevant model, policy, and env information.
    :param policy: (R_MAPPO_Policy) policy to update.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """
    def __init__(self,
                 args,
                 policy,
                 device=torch.device("cpu")):
        """
        Build the trainer state from the run configuration.

        Every hyper-parameter is read from ``args``; nothing is validated here.
        Besides copying configuration, this creates the trainer-owned trainable
        tensors (with their Adam optimizers) and the per-agent numpy buffers that
        ``ppo_update`` reads and writes.

        :param args: (argparse.Namespace) arguments containing relevant model, policy, and env information.
        :param policy: (R_MAPPO_Policy) policy to update.
        :param device: (torch.device) specifies the device to run on (cpu/gpu).
        """
        self.args = args
        self.device = device
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.policy = policy

        if self.args.dynamic_clip_tag:
            # One clip value per agent, plus bookkeeping derived from sum(log(1 + clip)).
            self.clip_param = np.ones(self.args.n_agents) * args.clip_param
            self.clip_constant = np.log(self.clip_param + 1).sum()

            self.all_updates = int(self.args.num_env_steps) // self.args.episode_length // self.args.n_rollout_threads
            self.curr_delta_update = 0
            self.min_clip_constant = np.log(self.args.min_clip_params + 1) * self.args.n_agents
            self.curr_constant = self.clip_constant
        else:
            self.clip_param = args.clip_param

        # The value-clipping range always uses the scalar configured clip_param,
        # even when the policy clip is per-agent.
        self.value_clip_param = args.clip_param

        self.ppo_epoch = args.ppo_epoch
        self.num_mini_batch = args.num_mini_batch
        self.data_chunk_length = args.data_chunk_length
        self.value_loss_coef = args.value_loss_coef
        self.entropy_coef = args.entropy_coef
        self.max_grad_norm = args.max_grad_norm
        self.huber_delta = args.huber_delta

        self._use_recurrent_policy = args.use_recurrent_policy
        self._use_naive_recurrent = args.use_naive_recurrent_policy
        self._use_max_grad_norm = args.use_max_grad_norm
        self._use_clipped_value_loss = args.use_clipped_value_loss
        self._use_huber_loss = args.use_huber_loss
        self._use_popart = args.use_popart
        self._use_value_active_masks = args.use_value_active_masks
        self._use_policy_active_masks = args.use_policy_active_masks

        if self._use_popart:
            self.value_normalizer = PopArt(1, device=self.device)
        else:
            self.value_normalizer = None

        # Trainable float64 leaf tensors of shape
        # (n_agents, episode_length * n_rollout_threads, 1), each with its own optimizer.
        # Created directly as leaves instead of going through the deprecated
        # torch.autograd.Variable wrapper; dtype, device and values are unchanged.
        dual_shape = [self.args.n_agents, self.args.episode_length * self.args.n_rollout_threads, 1]
        self.log_inv_lambda_s = torch.ones(dual_shape, dtype=torch.float64,
                                           device=self.device, requires_grad=True)
        self.mu_s = torch.ones(dual_shape, dtype=torch.float64,
                               device=self.device, requires_grad=True)

        self.opt_lambda = torch.optim.Adam([self.log_inv_lambda_s], lr=0.1)
        self.opt_mu = torch.optim.Adam([self.mu_s], lr=0.1)
        self.es_tag = False
        if self.args.penalty_method:
            print('init 11 dtar_sqrt_kl = {} dtar_kl = {}'.format(self.args.dtar_sqrt_kl, self.args.dtar_kl))
            self.beta_kl = np.ones(self.args.n_agents) * self.args.beta_kl
            self.beta_sqrt_kl = np.ones(self.args.n_agents) * self.args.beta_sqrt_kl

            # NOTE: set_dtar may overwrite args.NP_KL_delta, which is read further
            # down, so this call has to stay ahead of the NP_KL_delta setup.
            self.set_dtar()

        # Diagnostics filled in by ppo_update; None until the first update.
        self.term_kl = None
        self.term_sqrt_kl = None
        self.p_loss_part1 = None
        self.p_loss_part2 = None
        # Initialised alongside part1/part2 so readers never hit a missing attribute.
        self.p_loss_part3 = None
        self.d_coeff = None
        self.d_term = None
        self.NP_decay_max = int(self.args.num_env_steps * self.args.NP_decay_rate)
        print('NP_decay_rate = {} NP_decay_max = {}'.format(self.args.NP_decay_rate,self.NP_decay_max))
        self.NP_coeff = self.args.NP_coeff

        # Trainable float32 scalar (stored as a log value: ppo_update exponentiates it).
        self.NP_auto_log_coeff = torch.zeros(1, requires_grad=True, **self.tpdv)

        self.opt_NP_auto = torch.optim.Adam([self.NP_auto_log_coeff ], lr=self.args.NP_auto_lr)

        if self.args.NP_auto_target is None:
            # Default target: minus the number of joint actions.
            self.NP_auto_target = - (self.args.n_actions ** self.args.n_agents)
        else:
            self.NP_auto_target = self.args.NP_auto_target

        self.term1_grad_norm = None
        self.term2_grad_norm = None
        if self.args.overflow_save:
            self.overflow = np.zeros(self.args.n_agents)

        # Per-agent coefficient for the distance penalty; *_init takes precedence when given.
        if self.args.NP_delta_init is not None:
            self.delta = self.args.NP_delta_init * np.ones(self.args.n_agents)
        else:
            self.delta = self.args.NP_delta * np.ones(self.args.n_agents)
        print('init delta = {}'.format(self.delta))
        self.term_dist = np.zeros(self.args.n_agents)

        if self.args.check_optimal_V_bound:
            # Warm-start Q table used by calc_matrix_V_table: (n_states, n_joint_actions).
            n_joint_actions = self.args.n_actions ** self.args.n_agents
            self.old_Q_table = np.zeros([self.args.shape_shape, n_joint_actions])

        # Per-agent coefficient for the KL penalty; *_init takes precedence when given.
        if self.args.NP_KL_delta_init is not None:
            self.NP_KL_delta = self.args.NP_KL_delta_init * np.ones(self.args.n_agents)
        else:
            self.NP_KL_delta = self.args.NP_KL_delta * np.ones(self.args.n_agents)
        print('init KL_delta = {}'.format(self.NP_KL_delta))
        self.NP_KL_term_dist = np.zeros(self.args.n_agents)
        self.NP_KL_d_coeff = None
        self.NP_KL_d_term = None
    def set_dtar(self, dtar_kl=None, dtar_sqrt_kl=None, kl_para1=None, kl_para2=None,
                 sqrt_kl_para1=None, sqrt_kl_para2=None):
        """
        (Re)compute the KL / sqrt-KL targets, their lower/upper bands and the clip bounds.

        Every argument left as None falls back to a value derived from ``self.args``.

        :param dtar_kl: KL target; defaults to args.dtar_kl_specific when set, else args.dtar_kl.
        :param dtar_sqrt_kl: sqrt-KL target; defaults to sqrt(args.dtar_kl) when
            args.inner_refine is set, else args.dtar_sqrt_kl.
        :param kl_para1: factor for the KL band [dtar_kl / kl_para1, dtar_kl * kl_para1].
        :param kl_para2: stored as self.kl_para2; not used inside this method.
        :param sqrt_kl_para1: factor for the sqrt-KL band; defaults to sqrt(self.kl_para1)
            when args.inner_refine_sqrt is set, else args.sqrt_kl_para1.
        :param sqrt_kl_para2: stored as self.sqrt_kl_para2; not used inside this method.
        """
        cfg = self.args

        # --- KL target and its band -------------------------------------------------
        if cfg.dtar_kl_specific is None:
            fallback_dtar_kl = cfg.dtar_kl
        else:
            fallback_dtar_kl = cfg.dtar_kl_specific
        if dtar_kl is None:
            dtar_kl = fallback_dtar_kl
        self.dtar_kl = dtar_kl

        if kl_para1 is None:
            kl_para1 = cfg.kl_para1
        self.kl_para1 = kl_para1
        if kl_para2 is None:
            kl_para2 = cfg.kl_para2
        self.kl_para2 = kl_para2

        self.kl_lower = self.dtar_kl / self.kl_para1
        self.kl_upper = self.dtar_kl * self.kl_para1

        # --- sqrt-KL target ---------------------------------------------------------
        if cfg.inner_refine:
            # NOTE(review): this derives from args.dtar_kl, not from self.dtar_kl, so
            # dtar_kl_specific / the dtar_kl argument are ignored here - confirm intent.
            fallback_dtar_sqrt_kl = np.sqrt(cfg.dtar_kl)
            if cfg.NP_KL_term:
                # Side effect: rewrites the shared args namespace in place.
                cfg.NP_KL_delta = np.sqrt(cfg.NP_delta)
        else:
            fallback_dtar_sqrt_kl = cfg.dtar_sqrt_kl
        if dtar_sqrt_kl is None:
            dtar_sqrt_kl = fallback_dtar_sqrt_kl
        self.dtar_sqrt_kl = dtar_sqrt_kl
        print('after set: dtar_sqrt_kl = {} dtar_kl = {}'.format(self.dtar_sqrt_kl, self.dtar_kl))

        # --- sqrt-KL band -----------------------------------------------------------
        sqrt_kl_para1_tmp = np.sqrt(self.kl_para1) if cfg.inner_refine_sqrt else cfg.sqrt_kl_para1
        print(f'sqrt_kl_para1_tmp = {sqrt_kl_para1_tmp}')

        if sqrt_kl_para1 is None:
            sqrt_kl_para1 = sqrt_kl_para1_tmp
        self.sqrt_kl_para1 = sqrt_kl_para1
        if sqrt_kl_para2 is None:
            sqrt_kl_para2 = cfg.sqrt_kl_para2
        self.sqrt_kl_para2 = sqrt_kl_para2

        self.sqrt_kl_lower = self.dtar_sqrt_kl / self.sqrt_kl_para1
        self.sqrt_kl_upper = self.dtar_sqrt_kl * self.sqrt_kl_para1

        # --- clip bounds ------------------------------------------------------------
        # 'central' scales the configured bounds by the KL target, 'fixed' uses them
        # as-is; any other mode leaves para_*_bound untouched.
        mode = cfg.DPO_beta_clip_mode
        if mode == 'central' or mode == 'fixed':
            scale = self.dtar_kl if mode == 'central' else 1.0
            self.para_upper_bound = scale * cfg.para_upper_bound
            self.para_lower_bound = scale * cfg.para_lower_bound
    def get_matrix_state_table(self):
        """
        Query the policy on every one-hot state and stack the results per state.

        For each row of an ``args.obs_shape`` identity matrix, the same one-hot
        vector is given to every agent as both observation and shared observation,
        together with zero RNN states and all-ones masks.

        :return value_all: values from ``policy.get_actions`` stacked along dim 1.
        :return probs_all: outputs of ``policy.get_probs`` stacked along dim 1.
            Original author's note: value = (n_agents, n_states, 1),
            probs = (n_agents, n_states, n_actions) - depends on the policy's
            output shapes, TODO confirm.
        """
        n_agents = self.args.n_agents
        per_state_values = []
        per_state_probs = []

        for one_hot in np.eye(self.args.obs_shape):
            # Every agent sees the same one-hot state; obs and share_obs are separate arrays.
            obs_input = np.array([one_hot] * n_agents)
            share_obs_input = np.array([one_hot] * n_agents)

            masks_inputs = np.ones([n_agents, 1])
            actor_rnn = np.zeros([n_agents, 1, self.args.hidden_size])
            critic_rnn = np.zeros([n_agents, 1, self.args.hidden_size])

            # Only the value head is kept; sampled actions and new RNN states are discarded.
            value, _action, _log_prob, _rnn, _rnn_critic = self.policy.get_actions(
                share_obs_input, obs_input, actor_rnn, critic_rnn, masks_inputs)
            per_state_values.append(value)
            per_state_probs.append(self.policy.get_probs(obs_input, actor_rnn, masks_inputs))

        return torch.stack(per_state_values, dim=1), torch.stack(per_state_probs, dim=1)

    def calc_matrix_V_table(self,probs,optimal_check=False):
        """
        Evaluate a joint policy on the tabular model ``args.calc_V_envs`` and return V(s).

        Runs iterative policy evaluation on Q(s, joint_a), warm-started from
        ``self.old_Q_table``, until the norm of the change between two sweeps drops
        below ``args.NP_dp_eps``. The converged table is written back to
        ``self.old_Q_table``.

        Joint-action indices are decoded base ``n_actions`` with the
        least-significant digit belonging to agent 0. The original code contained a
        ``reversed(res)`` whose result was discarded, so this has always been the
        effective order; it is kept unchanged here.

        :param probs: per-agent action probabilities indexed as probs[agent, state, action];
            ignored when ``optimal_check`` is True.
        :param optimal_check: (bool) if True, evaluate the deterministic joint policy
            given by ``args.optimal_pi_table[s]`` instead of ``probs``.

        :return V: (np.ndarray) state values, one entry per state.
        """
        env = self.args.calc_V_envs
        n_actions = self.args.n_actions
        n_agents = self.args.n_agents
        n_states = self.args.shape_shape
        gamma = self.args.gamma
        dp_eps = self.args.NP_dp_eps
        n_joint_actions = n_actions ** n_agents

        def joint_to_idv(joint_a):
            """Decode a joint-action index into per-agent actions (agent 0 = lowest digit)."""
            digits = []
            for _ in range(n_agents):
                joint_a, action = divmod(joint_a, n_actions)
                digits.append(action)
            return digits

        # The decoding does not depend on the state or the sweep, so do it once.
        decoded_actions = [joint_to_idv(joint_a) for joint_a in range(n_joint_actions)]

        # joint_pi[s, joint_a]: probability of the joint action under the evaluated policy.
        joint_pi = np.zeros([n_states, n_joint_actions])
        for s in range(n_states):
            if optimal_check:
                joint_pi[s,self.args.optimal_pi_table[s]] = 1
                continue
            for joint_a, idv_a in enumerate(decoded_actions):
                # Agents act independently: product of the individual probabilities.
                joint_pi[s, joint_a] = 1
                for i in range(n_agents):
                    joint_pi[s, joint_a] *= probs[i, s, idv_a[i]]

        Q = copy.deepcopy(self.old_Q_table)
        new_Q = copy.deepcopy(Q)

        # Large sentinel so the first sweep always runs for any sensible dp_eps.
        dp_judge = 100000000000000
        value_iter = 0
        while dp_judge >= dp_eps:
            next_V = np.sum(Q * joint_pi, axis=-1)
            for s in range(n_states):
                for joint_a, idv_a in enumerate(decoded_actions):
                    # Hand the env its own list so the cached decoding cannot be mutated.
                    r, p = env.get_model_info(s, list(idv_a))
                    new_Q[s, joint_a] = r + gamma * (np.dot(p, next_V))
            dp_judge = np.linalg.norm(Q - new_Q)
            Q = copy.deepcopy(new_Q)

            print('value_iter = {} dp_judge = {}'.format(value_iter, dp_judge))
            value_iter += 1

        # Keep the converged table as the warm start for the next call.
        self.old_Q_table = copy.deepcopy(Q)
        return np.sum(Q * joint_pi, axis=-1)

    def cal_value_loss(self, values, value_preds_batch, return_batch, active_masks_batch):
        """
        Compute the (optionally clipped, optionally masked) value-function loss.

        :param values: (torch.Tensor) current value function predictions.
        :param value_preds_batch: (torch.Tensor) "old" value predictions from the data batch
            (anchor for the clipped prediction).
        :param return_batch: (torch.Tensor) reward-to-go returns; passed through
            ``self.value_normalizer`` first when PopArt is enabled.
        :param active_masks_batch: (torch.Tensor) weights applied to the per-element loss
            when ``_use_value_active_masks`` is set.

        :return value_loss: (torch.Tensor) scalar value function loss.
        """
        # Prediction restricted to stay within +-value_clip_param of the old prediction.
        clip_range = self.value_clip_param
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-clip_range, clip_range)

        if self._use_popart:
            # NOTE(review): the normalizer is invoked once per error term; if its call
            # also updates running statistics, the two targets differ - confirm in PopArt.
            error_clipped = self.value_normalizer(return_batch) - value_pred_clipped
            error_original = self.value_normalizer(return_batch) - values
        else:
            error_clipped = return_batch - value_pred_clipped
            error_original = return_batch - values

        # Pick the element-wise loss once, then apply it to both error tensors.
        if self._use_huber_loss:
            def elementwise_loss(err):
                return huber_loss(err, self.huber_delta)
        else:
            elementwise_loss = mse_loss
        value_loss_clipped = elementwise_loss(error_clipped)
        value_loss_original = elementwise_loss(error_original)

        # With clipping enabled, take the larger of the two losses per element.
        value_loss = value_loss_original
        if self._use_clipped_value_loss:
            value_loss = torch.max(value_loss_original, value_loss_clipped)

        if not self._use_value_active_masks:
            return value_loss.mean()
        # Mask-weighted mean over the batch.
        return (value_loss * active_masks_batch).sum() / active_masks_batch.sum()

    def ppo_update(self, sample, update_actor=True,aga_update_tag = False,update_index = -1,curr_update_num=None):
        """
        Update actor and critic networks.
        :param sample: (Tuple) contains data batch with which to update networks.
        :update_actor: (bool) whether to update actor network.

        :return value_loss: (torch.Tensor) value function loss.
        :return critic_grad_norm: (torch.Tensor) gradient norm from critic up9date.
        ;return policy_loss: (torch.Tensor) actor(policy) loss value.
        :return dist_entropy: (torch.Tensor) action entropies.
        :return actor_grad_norm: (torch.Tensor) gradient norm from actor update.
        :return imp_weights: (torch.Tensor) importance sampling weights.
        """
        torch.autograd.set_detect_anomaly(True)

        base_ret,q_ret,sp_ret,penalty_ret = sample
        share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
        value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
        adv_targ, available_actions_batch = base_ret




        # print('share_obs_batch = {}'.format(share_obs_batch.shape))
        # print('obs_batch = {}'.format(share_obs_batch.shape))

        old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)
        adv_targ = check(adv_targ).to(**self.tpdv)
        value_preds_batch = check(value_preds_batch).to(**self.tpdv)
        return_batch = check(return_batch).to(**self.tpdv)
        active_masks_batch = check(active_masks_batch).to(**self.tpdv)


        # Reshape to do in a single forward pass for all steps
        # print('aga_update_tag = {}, update_index = {}'.format(aga_update_tag,update_index))
        if self.args.idv_para and aga_update_tag and self.args.aga_tag:
            # print('case1')
            values_all, action_log_probs, dist_entropy = self.policy.evaluate_actions_single(share_obs_batch,
                                                                                  obs_batch,
                                                                                  rnn_states_batch,
                                                                                  rnn_states_critic_batch,
                                                                                  actions_batch,
                                                                                  masks_batch,
                                                                                  available_actions_batch,
                                                                                  active_masks_batch,update_index=update_index)
        else:
            prob_merge_tag = not self.args.dynamic_clip_tag
            values_all, action_log_probs, dist_entropy = self.policy.evaluate_actions(share_obs_batch,
                                                                                  obs_batch,
                                                                                  rnn_states_batch,
                                                                                  rnn_states_critic_batch,
                                                                                  actions_batch,
                                                                                  masks_batch,
                                                                                  available_actions_batch,
                                                                                  active_masks_batch,prob_merge=prob_merge_tag)
        # actor update
        #imp_weights = (episode_length * agent_num, 1)
        # imp_weights = torch.exp(action_log_probs - old_action_log_probs_batch)
        # print('imp_weights = {}'.format(imp_weights.shape))
        self.es_tag = False
        if self.args.penalty_method:
            if self.args.new_penalty_method:
                if self.args.NP_code_check:
                    # dist_function = D_dict[self.args.NP_dist_name]

                    eps_kl = 1e-9
                    eps_sqrt = 1e-12
                    if self.args.env_name == 'mujoco':
                        with torch.no_grad():
                            old_dist = self.policy.get_dist(obs_batch, rnn_states_batch, masks_batch, target=True)
                        new_dist = self.policy.get_dist(obs_batch, rnn_states_batch, masks_batch)
                        kl = []
                        for idv_old, idv_new in zip(old_dist, new_dist):
                            idv_kl = torch.distributions.kl_divergence(idv_old, idv_new)
                            kl.append(idv_kl)
                        kl = torch.stack(kl, dim=1)
                        # print('n_actions = {} kl = {}, idv_kl = {}'.format(self.args.n_actions,kl.shape,idv_kl.shape))
                    else:
                        if (self.args.env_name == 'StarCraft2' or self.args.env_name == 'smacv2') and not self.args.NP_ava_test:
                            probs = self.policy.get_probs(obs_batch, rnn_states_batch, masks_batch,
                                                          available_actions=available_actions_batch)
                        else:
                            probs = self.policy.get_probs(obs_batch, rnn_states_batch, masks_batch)

                        old_probs_batch = penalty_ret[0]

                        old_probs_batch = check(old_probs_batch).to(**self.tpdv)

                        # print('n_actions = {} old_probs_batch = {}, probs = {}'.format(self.args.n_actions,old_probs_batch.shape,probs.shape))
                        kl = old_probs_batch * (torch.log(old_probs_batch + eps_kl) - torch.log(probs + eps_kl))
                    # print('prev kl = {}'.format(kl.shape))
                    # print('obs_batch = {}'.format(obs_batch.shape))
                    # kl and obs_batch is with shape (episode_length * n_agents, n_actions)
                    kl = torch.sum(kl, dim=-1, keepdim=True).reshape([-1, 1])
                    # print('after kl = {}'.format(kl.shape))
                    imp_weights = torch.exp(action_log_probs - old_action_log_probs_batch)
                    term1 = imp_weights * adv_targ

                    dist_function = D_dict[self.args.NP_dist_name]
                    # print('dist_function = {}'.format(dist_function))
                    sqrt_kl = dist_function(probs, old_probs_batch)



                    term1 = term1.reshape([-1, self.args.n_agents])
                    sqrt_kl = sqrt_kl.reshape([-1, self.args.n_agents])
                    kl = kl.reshape([-1, self.args.n_agents])

                    if self.args.env_name == 'mujoco':
                        term1_active_masks_batch = active_masks_batch.repeat([1, self.args.n_actions])
                        term1_active_masks_batch = term1_active_masks_batch.reshape([-1, self.args.n_agents])
                    else:
                        term1_active_masks_batch = active_masks_batch.reshape([-1, self.args.n_agents])
                    # print('active_masks = {}'.format(active_masks_batch.shape))
                    # print('n_agents = {}, n_actions = {}'.format(self.args.n_agents, self.args.n_actions))
                    policy_active_masks_batch = active_masks_batch.reshape([-1, self.args.n_agents])

                    if self._use_policy_active_masks:
                        term1 = (-term1 * term1_active_masks_batch).sum(dim=0) / term1_active_masks_batch.sum(dim=0)

                        term_sqrt_kl = (sqrt_kl * term1_active_masks_batch).sum(dim=0) / term1_active_masks_batch.sum(dim=0)

                        term_kl = (kl * policy_active_masks_batch).sum(dim=0) / active_masks_batch.sum(dim=0)
                        # term1 = (-torch.sum(term1,dim=-1,keepdim=True) * active_masks_batch).sum() / active_masks_batch.sum()
                        # print('kl = {} sqrt_kl = {} active_masks = {}'.format(kl.shape,sqrt_kl.shape,active_masks_batch.shape))
                        # term_sqrt_kl = (torch.sum(sqrt_kl,dim=-1,keepdim=True) * active_masks_batch).sum() / active_masks_batch.sum()
                        # term_kl = (torch.sum(kl, dim=-1, keepdim=True) * active_masks_batch).sum() / active_masks_batch.sum()
                    else:
                        term1 = (-term1).mean(dim=0)
                        term_sqrt_kl = sqrt_kl.mean(dim=0)
                        term_kl = kl.mean(dim=0)
                        # term1 = -torch.sum(term1, dim=-1, keepdim=True).mean()
                        # term_sqrt_kl = torch.sum(sqrt_kl, dim=-1, keepdim=True).mean()
                        # term_kl = torch.sum(kl, dim=-1, keepdim=True).mean()
                    if self.args.idv_beta:
                        if self.args.NP_term_update_mode == 'mean':
                            self.NP_KL_term_dist += term_kl.clone().detach().cpu().numpy()
                            self.term_dist += term_sqrt_kl.clone().detach().cpu().numpy()
                            # self.term_sqrt_kl += term_sqrt_kl.clone().detach().cpu().numpy()
                            # self.term_kl += term_kl.clone().detach().cpu().numpy()
                        elif self.args.NP_term_update_mode == 'single':
                            self.term_dist = term_sqrt_kl.clone().detach().cpu().numpy()
                            self.NP_KL_term_dist = term_kl.clone().detach().cpu().numpy()
                    else:
                        self.term_dist = torch.ones_like(term_sqrt_kl) * (term_sqrt_kl.mean())
                        self.NP_KL_term_dist = torch.ones_like(term_kl) * (term_kl.mean())

                    sqrt_coeff = torch.tensor(self.delta).to(**self.tpdv).detach()
                    # kl_coeff = torch.tensor(self.beta_kl).to(**self.tpdv).detach()
                    kl_coeff = torch.tensor(self.NP_KL_delta).to(**self.tpdv).detach()
                    term2 = sqrt_coeff * term_sqrt_kl
                    term3 = kl_coeff * term_kl
                    policy_loss = term1 + term2 + term3

                    self.p_loss_part1 = term1.mean()
                    self.p_loss_part2 = term2.mean()
                    self.p_loss_part3 = term3.mean()
                    self.d_coeff = sqrt_coeff.mean()
                    self.d_term = term_sqrt_kl.clone().detach().cpu().numpy().mean()
                    policy_loss = policy_loss.mean()
                    self.es_tag = False
                    clip_rate = 0


                else:
                    dist_function = D_dict[self.args.NP_dist_name]
                    print('NO code check')
                    if self.args.env_name == 'mujoco':
                        # print('continous action have not be done !!!')
                        with torch.no_grad():
                            old_dist = self.policy.get_dist(obs_batch, rnn_states_batch, masks_batch,target=True)
                        new_dist = self.policy.get_dist(obs_batch, rnn_states_batch, masks_batch)
                        d_term = []
                        for idv_old,idv_new in zip(old_dist,new_dist):
                            idv_d_term = dist_function(idv_old,idv_new,continuous=True)
                            # print('idv_d_term = {}'.format(idv_d_term.shape))
                            d_term.append(idv_d_term)
                        d_term = torch.stack(d_term,dim = 1).reshape([-1, 1])

                        # print('d_term = {}'.format(d_term.shape))
                    else:
                        if (self.args.env_name == 'StarCraft2' or self.args.env_name == 'smacv2')  and not self.args.NP_ava_test:
                            probs = self.policy.get_probs(obs_batch, rnn_states_batch, masks_batch,
                                                          available_actions=available_actions_batch)
                        else:
                            probs = self.policy.get_probs(obs_batch, rnn_states_batch, masks_batch)
                        old_probs_batch = penalty_ret[0]
                        old_probs_batch = check(old_probs_batch).to(**self.tpdv)

                        if self.args.NP_DPO_check:
                            eps_kl = 1e-9
                            eps_sqrt = 1e-12
                            # print('n_actions = {} old_probs_batch = {}, probs = {}'.format(self.args.n_actions,old_probs_batch.shape,probs.shape))
                            kl_term = old_probs_batch * (torch.log(old_probs_batch + eps_kl) - torch.log(probs + eps_kl))
                            kl_term = torch.sum(kl_term, dim=-1, keepdim=True).reshape([-1, 1])
                            d_term = torch.sqrt(torch.max(kl_term + eps_sqrt, eps_sqrt * torch.ones_like(kl_term)))
                        else:
                            d_term = dist_function(probs, old_probs_batch)

                    if self.args.NP_KL_term:
                        if self.args.env_name == 'mujoco':
                            with torch.no_grad():
                                old_dist = self.policy.get_dist(obs_batch, rnn_states_batch, masks_batch, target=True)
                            new_dist = self.policy.get_dist(obs_batch, rnn_states_batch, masks_batch)
                            kl = []
                            for idv_old, idv_new in zip(old_dist, new_dist):
                                idv_kl = torch.distributions.kl_divergence(idv_old, idv_new)
                                kl.append(idv_kl)
                            kl_term = torch.stack(kl, dim=1)
                        else:
                            eps_kl = 1e-9
                            # print('n_actions = {} old_probs_batch = {}, probs = {}'.format(self.args.n_actions,old_probs_batch.shape,probs.shape))
                            kl_term = old_probs_batch * (torch.log(old_probs_batch + eps_kl) - torch.log(probs + eps_kl) )
                        kl_term = torch.sum(kl_term, dim=-1, keepdim=True).reshape([-1, 1])
                    imp_weights = torch.exp(action_log_probs - old_action_log_probs_batch)



                    # print('imp_weights = {}'.format(imp_weights.shape))
                    if self.args.NP_use_clip:
                        surr1 = imp_weights * adv_targ
                        surr2 = torch.clamp(imp_weights, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
                        if self._use_policy_active_masks:
                            term1 = (-torch.sum(torch.min(surr1, surr2),
                                                             dim=-1,
                                                             keepdim=True) * active_masks_batch).sum() / active_masks_batch.sum()
                        else:
                            term1 = -torch.sum(torch.min(surr1, surr2), dim=-1, keepdim=True).mean()
                    else:
                        term1 = imp_weights * adv_targ
                    # print('term1 = {}'.format(term1.shape))
                    print(f'd_term before reshape = {d_term}')
                    term1 = term1.reshape([-1, self.args.n_agents])
                    d_term = d_term.reshape([-1, self.args.n_agents])
                    print(f'd_term after reshape = {d_term}')
                    if self.args.env_name == 'mujoco':
                        term1_active_masks_batch = active_masks_batch.repeat([1, self.args.n_actions])
                        term1_active_masks_batch = term1_active_masks_batch.reshape([-1, self.args.n_agents])
                    else:
                        term1_active_masks_batch = active_masks_batch.reshape([-1, self.args.n_agents])
                    # print('active_masks = {}'.format(active_masks_batch.shape))
                    # print('n_agents = {}, n_actions = {}'.format(self.args.n_agents, self.args.n_actions))
                    policy_active_masks_batch = active_masks_batch.reshape([-1, self.args.n_agents])

                    policy_active_masks_batch = active_masks_batch.reshape([-1, self.args.n_agents])
                    if self._use_policy_active_masks:
                        term1 = (-term1 * term1_active_masks_batch).sum(dim=0) / term1_active_masks_batch.sum(dim=0)
                        d_term = (d_term * term1_active_masks_batch).sum(dim=0) / term1_active_masks_batch.sum(dim=0)
                    else:
                        term1 = (-term1).mean(dim=0)
                        d_term = d_term.mean(dim=0)

                    if self.args.NP_KL_term:
                        kl_term = kl_term.reshape([-1, self.args.n_agents])

                        if self._use_policy_active_masks:
                            kl_term = (kl_term * policy_active_masks_batch).sum(dim=0) / active_masks_batch.sum(dim=0)
                        else:
                            kl_term = kl_term.mean(dim=0)
                        if self.args.NP_term_update_check:
                            self.NP_KL_term_dist = kl_term.clone().detach().cpu().numpy()
                        else:
                            self.NP_KL_term_dist += kl_term.clone().detach().cpu().numpy()

                    if self.args.NP_term_update_check:
                        self.term_dist = d_term.clone().detach().cpu().numpy()
                    else:
                        self.term_dist += d_term.clone().detach().cpu().numpy()

                    print(f'term_dist = {self.term_dist}')
                    self.p_loss_part1 = term1.clone().detach().cpu().numpy().mean()
                    self.d_term = d_term.clone().detach().cpu().numpy().mean()

                    if self.args.NP_decay_mode == 'theory':
                        d_coeff = np.sqrt(2) / (1 - self.args.gamma)
                    elif self.args.NP_decay_mode == 'balance':
                        d_coeff = self.p_loss_part1 / self.d_term * self.args.NP_balance_rate
                    elif self.args.NP_decay_mode == 'auto':
                        d_coeff = self.NP_auto_log_coeff.exp().clone().detach().cpu().numpy()[0]
                        d_coeff = d_coeff.astype(np.float64)
                    elif self.args.NP_decay_mode == 'clip' and self.args.NP_clip_coeff_refine:
                        d_coeff = np.sqrt(self.args.clip_param)
                    else:
                        d_coeff = self.NP_coeff

                    # print('before gamma correct: omega = {}'.format(d_coeff))
                    if self.args.NP_gamma_correct_omega:
                        d_coeff /= (1 - self.args.gamma)
                    # print('after gamma correct: omega = {}'.format(d_coeff))
                    if self.args.NP_add_delta:
                        d_coeff = torch.tensor(self.delta).to(**self.tpdv)
                    else:
                        d_coeff = torch.full_like(d_term, d_coeff).to(**self.tpdv)

                    # print('d_term = {} term1 = {}'.format(type(d_term),type(term1)))

                    # print('d_term = {}'.format(self.d_term))

                    self.d_coeff = d_coeff.mean()

                    if self.args.NP_KL_term:
                        kl_d_coeff = torch.tensor(self.NP_KL_delta).to(**self.tpdv)
                        self.NP_KL_d_coeff = kl_d_coeff.mean()

                    if self.args.NP_decay_mode == 'auto':
                        auto_loss = (-self.NP_auto_log_coeff * ((d_term + self.NP_auto_target).detach())).mean()
                    # print('d_coeff = {} d_coeff_type = {}'.format(d_coeff,type(d_coeff)))

                    term2 = d_coeff * d_term
                    self.p_loss_part2 =term2.mean()

                    if self.args.NP_KL_term:
                        term3 = kl_d_coeff * kl_term
                        self.p_loss_part3 = term3.mean()
                    else:
                        term3 = 0
                        self.p_loss_part3 = 0


                    policy_loss = term1 + term2 + term3
                    policy_loss = policy_loss.mean()
                    # print('policy_loss = {}'.format(policy_loss))
                    if self.args.NP_grad_check:


                        self.actor_zero_grad()
                        term1 = term1.mean()
                        term1.backward(retain_graph=True)

                        term1_grad_norm = 0
                        if self.args.idv_para:
                            if aga_update_tag and self.args.aga_tag:
                                term1_grad_norm = get_gard_norm(self.policy.actor[update_index].parameters())
                            else:
                                for i in range(self.args.n_agents):
                                    idv_term1_grad = get_gard_norm(self.policy.actor[i].parameters())
                                    term1_grad_norm += idv_term1_grad
                                term1_grad_norm /= self.args.n_agents
                        else:
                            term1_grad_norm = get_gard_norm(self.policy.actor.parameters())
                        self.term1_grad_norm = term1_grad_norm
                        self.actor_zero_grad()


                        term2 = term2.mean()
                        term2.backward(retain_graph=True)

                        term2_grad_norm = 0
                        if self.args.idv_para:
                            if aga_update_tag and self.args.aga_tag:
                                term2_grad_norm = get_gard_norm(self.policy.actor[update_index].parameters())
                            else:
                                for i in range(self.args.n_agents):
                                    idv_term2_grad = get_gard_norm(self.policy.actor[i].parameters())
                                    term2_grad_norm += idv_term2_grad
                                term2_grad_norm /= self.args.n_agents
                        else:
                            term2_grad_norm = get_gard_norm(self.policy.actor.parameters())
                        self.term2_grad_norm = term2_grad_norm
                        self.actor_zero_grad()

                    if self.args.NP_use_clip:
                        clip_check = (surr1 != surr2)
                        clip_check_sum = clip_check.sum()
                        clip_check_total = torch.ones_like(clip_check).sum()
                        # print('clip_check = {}'.format(clip_check))
                        # print('clip_check_sum = {}'.format(clip_check_sum))
                        # print('clip_check_total = {}'.format(clip_check_total))
                        clip_rate = float(clip_check_sum) / float(clip_check_total)
                    else:
                        clip_rate = 0
            else:
                if self.args.NP_code_check:
                    dist_function = D_dict[self.args.NP_dist_name]

                    if self.args.env_name == 'mujoco':
                        # print('continous action have not be done !!!')
                        with torch.no_grad():
                            old_dist = self.policy.get_dist(obs_batch, rnn_states_batch, masks_batch, target=True)
                        new_dist = self.policy.get_dist(obs_batch, rnn_states_batch, masks_batch)
                        d_term = []
                        for idv_old, idv_new in zip(old_dist, new_dist):
                            idv_d_term = dist_function(idv_old, idv_new, continuous=True)
                            # print('idv_d_term = {}'.format(idv_d_term.shape))
                            d_term.append(idv_d_term)
                        d_term = torch.stack(d_term, dim=1).reshape([-1, 1])

                        # print('d_term = {}'.format(d_term.shape))
                    else:
                        if (self.args.env_name == 'StarCraft2' or self.args.env_name == 'smacv2')  and not self.args.NP_ava_test:
                            probs = self.policy.get_probs(obs_batch, rnn_states_batch, masks_batch,
                                                          available_actions=available_actions_batch)
                        else:

                            probs = self.policy.get_probs(obs_batch, rnn_states_batch, masks_batch)
                        old_probs_batch = penalty_ret[0]
                        old_probs_batch = check(old_probs_batch).to(**self.tpdv)

                        if self.args.NP_DPO_check:
                            eps_kl = 1e-9
                            eps_sqrt = 1e-12
                            # print('n_actions = {} old_probs_batch = {}, probs = {}'.format(self.args.n_actions,old_probs_batch.shape,probs.shape))
                            kl_term = old_probs_batch * (
                                        torch.log(old_probs_batch + eps_kl) - torch.log(probs + eps_kl))
                            kl_term = torch.sum(kl_term, dim=-1, keepdim=True).reshape([-1, 1])
                            d_term = torch.sqrt(torch.max(kl_term + eps_sqrt, eps_sqrt * torch.ones_like(kl_term)))
                        else:
                            d_term = dist_function(probs, old_probs_batch)


                    if self.args.env_name == 'mujoco':
                        with torch.no_grad():
                            old_dist = self.policy.get_dist(obs_batch, rnn_states_batch, masks_batch, target=True)
                        new_dist = self.policy.get_dist(obs_batch, rnn_states_batch, masks_batch)
                        kl = []
                        for idv_old, idv_new in zip(old_dist, new_dist):
                            idv_kl = torch.distributions.kl_divergence(idv_old, idv_new)
                            kl.append(idv_kl)
                        kl_term = torch.stack(kl, dim=1)
                    else:
                        eps_kl = 1e-9
                        # print('n_actions = {} old_probs_batch = {}, probs = {}'.format(self.args.n_actions,old_probs_batch.shape,probs.shape))
                        kl_term = old_probs_batch * (
                                    torch.log(old_probs_batch + eps_kl) - torch.log(probs + eps_kl))
                    kl_term = torch.sum(kl_term, dim=-1, keepdim=True).reshape([-1, 1])
                    imp_weights = torch.exp(action_log_probs - old_action_log_probs_batch)

                    # print('imp_weights = {}'.format(imp_weights.shape))

                    term1 = imp_weights * adv_targ
                    # print('term1 = {}'.format(term1.shape))
                    term1 = term1.reshape([-1, self.args.n_agents])
                    d_term = d_term.reshape([-1, self.args.n_agents])

                    if self.args.env_name == 'mujoco':
                        term1_active_masks_batch = active_masks_batch.repeat([1, self.args.n_actions])
                        term1_active_masks_batch = term1_active_masks_batch.reshape([-1, self.args.n_agents])
                    else:
                        term1_active_masks_batch = active_masks_batch.reshape([-1, self.args.n_agents])
                    # print('active_masks = {}'.format(active_masks_batch.shape))
                    # print('n_agents = {}, n_actions = {}'.format(self.args.n_agents, self.args.n_actions))


                    policy_active_masks_batch = active_masks_batch.reshape([-1, self.args.n_agents])
                    if self._use_policy_active_masks:
                        term1 = (-term1 * term1_active_masks_batch).sum(dim=0) / term1_active_masks_batch.sum(dim=0)
                        d_term = (d_term * term1_active_masks_batch).sum(dim=0) / term1_active_masks_batch.sum(dim=0)
                    else:
                        term1 = (-term1).mean(dim=0)
                        d_term = d_term.mean(dim=0)


                    kl_term = kl_term.reshape([-1, self.args.n_agents])

                    if self._use_policy_active_masks:
                        kl_term = (kl_term * policy_active_masks_batch).sum(dim=0) / active_masks_batch.sum(dim=0)
                    else:
                        kl_term = kl_term.mean(dim=0)
                    # self.NP_KL_term_dist += kl_term.clone().detach().cpu().numpy()
                    self.term_kl = kl_term.clone().detach().cpu().numpy()

                    # self.term_dist += d_term.clone().detach().cpu().numpy()
                    self.term_sqrt_kl = d_term.clone().detach().cpu().numpy()

                    self.p_loss_part1 = term1.clone().detach().cpu().numpy().mean()
                    self.d_term = d_term.clone().detach().cpu().numpy().mean()


                    # print('after gamma correct: omega = {}'.format(d_coeff))

                    d_coeff = torch.tensor(self.beta_sqrt_kl).to(**self.tpdv)


                    # print('d_term = {} term1 = {}'.format(type(d_term),type(term1)))

                    # print('d_term = {}'.format(self.d_term))

                    self.d_coeff = d_coeff.mean()


                    kl_d_coeff = torch.tensor(self.beta_kl).to(**self.tpdv)
                    self.NP_KL_d_coeff = kl_d_coeff.mean()

                    # print('d_coeff = {} d_coeff_type = {}'.format(d_coeff,type(d_coeff)))

                    term2 = d_coeff * d_term
                    self.p_loss_part2 = term2.mean()


                    term3 = kl_d_coeff * kl_term
                    self.p_loss_part3 = term3.mean()


                    policy_loss = term1 + term2 + term3
                    policy_loss = policy_loss.mean()
                    # print('policy_loss = {}'.format(policy_loss))

                    clip_rate = 0
                else:
                    eps_kl = 1e-9
                    eps_sqrt = 1e-12
                    if self.args.env_name == 'mujoco':
                        with torch.no_grad():
                            old_dist = self.policy.get_dist(obs_batch, rnn_states_batch, masks_batch,target=True)
                        new_dist = self.policy.get_dist(obs_batch, rnn_states_batch, masks_batch)
                        kl = []
                        for idv_old,idv_new in zip(old_dist,new_dist):
                            idv_kl = torch.distributions.kl_divergence(idv_old,idv_new)
                            kl.append(idv_kl)
                        kl = torch.stack(kl,dim = 1)
                        # print('n_actions = {} kl = {}, idv_kl = {}'.format(self.args.n_actions,kl.shape,idv_kl.shape))
                    else:
                        if (self.args.env_name == 'StarCraft2' or self.args.env_name == 'smacv2')  and not self.args.NP_ava_test:
                            probs = self.policy.get_probs(obs_batch, rnn_states_batch, masks_batch,available_actions=available_actions_batch)
                        else:
                            print(f'here, you are not using available_actions env_name = {self.args.env_name}')
                            probs = self.policy.get_probs(obs_batch, rnn_states_batch, masks_batch)

                        if self.args.target_for_dist:
                            if (self.args.env_name == 'StarCraft2' or self.args.env_name == 'smacv2')  and not self.args.NP_ava_test:
                                old_probs_batch = self.policy.get_probs(obs_batch, rnn_states_batch, masks_batch,
                                                              available_actions=available_actions_batch, target=True)
                            else:
                                old_probs_batch = self.policy.get_probs(obs_batch, rnn_states_batch, masks_batch, target=True)
                            old_probs_batch = old_probs_batch.detach()
                        else:
                            old_probs_batch = penalty_ret[0]

                            old_probs_batch = check(old_probs_batch).to(**self.tpdv)

                        # print('n_actions = {} old_probs_batch = {}, probs = {}'.format(self.args.n_actions,old_probs_batch.shape,probs.shape))
                        kl = old_probs_batch * (torch.log(old_probs_batch + eps_kl) - torch.log(probs + eps_kl) )
                    # print('prev kl = {}'.format(kl.shape))
                    # print('obs_batch = {}'.format(obs_batch.shape))
                    # kl and obs_batch is with shape (episode_length * n_agents, n_actions)
                    kl = torch.sum(kl,dim = -1,keepdim=True).reshape([-1,1])
                    # print('after kl = {}'.format(kl.shape))
                    imp_weights = torch.exp(action_log_probs - old_action_log_probs_batch)
                    term1 = imp_weights * adv_targ
                    if not self.args.early_stop:
                        # print('dpo_NP_check = {}'.format(self.args.dpo_NP_check))
                        if self.args.DPO_NP_check:
                            dist_function = D_dict[self.args.NP_dist_name]
                            # print('dist_function = {}'.format(dist_function))
                            if self.args.env_name == 'mujoco':
                                with torch.no_grad():
                                    old_dist = self.policy.get_dist(obs_batch, rnn_states_batch, masks_batch,
                                                                    target=True)
                                new_dist = self.policy.get_dist(obs_batch, rnn_states_batch, masks_batch)
                                d_term = []
                                for idv_old, idv_new in zip(old_dist, new_dist):
                                    idv_d_term = dist_function(idv_old, idv_new, continuous=True)
                                    # print('idv_d_term = {}'.format(idv_d_term.shape))
                                    d_term.append(idv_d_term)
                                sqrt_kl = torch.stack(d_term, dim=1).reshape([-1, 1])
                            else:
                                sqrt_kl = dist_function(probs, old_probs_batch)
                        else:
                            sqrt_kl = torch.sqrt(torch.max(kl + eps_sqrt,eps_sqrt * torch.ones_like(kl)))

                        if self.args.check_kl_output:
                            kl_1 = old_probs_batch * torch.log(old_probs_batch + eps_kl)
                            kl_2 = old_probs_batch * torch.log(probs + eps_kl)
                            arr_length = kl.shape[0]
                            kl_1 = kl_1.reshape([arr_length, -1])
                            kl_2 = kl_2.reshape([arr_length, -1])
                            for i in range(kl.shape[0]):
                                print(
                                    'kl_term[{}] = {}, kl_1[{}] = {}, kl_2[{}] = {} kl+eps_sqrt = {}, sqrt ={}'.format(i, kl[i][0],
                                                                                                                       i, kl_1[i],
                                                                                                                       i, kl_2[i],
                                                                                                                       kl[i][
                                                                                                                           0] + eps_sqrt,
                                                                                                                       torch.sqrt(
                                                                                                                           kl[i][
                                                                                                                               0] + eps_sqrt)))
                            print('eps_sqrt = {}'.format(eps_sqrt))
                            for i in range(kl.shape[0]):
                                print('sqrt_kl_term[{}] = {}'.format(i, sqrt_kl[i]))

                        term1 = term1.reshape([-1,self.args.n_agents])
                        sqrt_kl = sqrt_kl.reshape([-1,self.args.n_agents])
                        kl = kl.reshape([-1,self.args.n_agents])
                        # print(f'sqrt_kl = {sqrt_kl}\nkl = {kl}')
                        if self.args.env_name == 'mujoco':
                            term1_active_masks_batch = active_masks_batch.repeat([1, self.args.n_actions])
                            term1_active_masks_batch = term1_active_masks_batch.reshape([-1,self.args.n_agents])
                        else:
                            term1_active_masks_batch = active_masks_batch.reshape([-1, self.args.n_agents])
                        # print('active_masks = {}'.format(active_masks_batch.shape))
                        # print('n_agents = {}, n_actions = {}'.format(self.args.n_agents, self.args.n_actions))
                        policy_active_masks_batch = active_masks_batch.reshape([-1,self.args.n_agents])

                        if self._use_policy_active_masks:
                            term1 = (-term1 * term1_active_masks_batch).sum(dim = 0) / term1_active_masks_batch.sum(dim = 0)
                            if self.args.DPO_term_sqrt_update_mode == 'one':
                                term_sqrt_kl = (sqrt_kl * term1_active_masks_batch).sum(dim=0) / term1_active_masks_batch.sum(
                                    dim=0)
                            elif self.args.DPO_term_sqrt_update_mode == 'two':
                                term_sqrt_kl = (sqrt_kl * policy_active_masks_batch).sum(dim = 0) / active_masks_batch.sum(dim = 0)
                            term_kl  = (kl * policy_active_masks_batch).sum(dim = 0) / active_masks_batch.sum(dim = 0)
                            # term1 = (-torch.sum(term1,dim=-1,keepdim=True) * active_masks_batch).sum() / active_masks_batch.sum()
                            # print('kl = {} sqrt_kl = {} active_masks = {}'.format(kl.shape,sqrt_kl.shape,active_masks_batch.shape))
                            # term_sqrt_kl = (torch.sum(sqrt_kl,dim=-1,keepdim=True) * active_masks_batch).sum() / active_masks_batch.sum()
                            # term_kl = (torch.sum(kl, dim=-1, keepdim=True) * active_masks_batch).sum() / active_masks_batch.sum()
                        else:
                            term1 = (-term1).mean(dim=0)
                            term_sqrt_kl = sqrt_kl.mean(dim=0)
                            term_kl = kl.mean(dim=0)
                            # term1 = -torch.sum(term1, dim=-1, keepdim=True).mean()
                            # term_sqrt_kl = torch.sum(sqrt_kl, dim=-1, keepdim=True).mean()
                            # term_kl = torch.sum(kl, dim=-1, keepdim=True).mean()
                        # print(f'term_sqrt_kl = {term_sqrt_kl}\nterm_kl = {term_kl}')
                        if self.args.idv_beta:
                            if self.args.DPO_term_update_mode == 'mean':
                                self.term_sqrt_kl += term_sqrt_kl.clone().detach().cpu().numpy()
                                self.term_kl += term_kl.clone().detach().cpu().numpy()
                            elif self.args.DPO_term_update_mode == 'single':
                                self.term_sqrt_kl = term_sqrt_kl.clone().detach().cpu().numpy()
                                self.term_kl = term_kl.clone().detach().cpu().numpy()
                        else:
                            self.term_sqrt_kl = torch.ones_like(term_sqrt_kl) * (term_sqrt_kl.mean())
                            self.term_kl = torch.ones_like(term_kl) * (term_kl.mean())
                        # print(f'self.term_sqrt_kl = {self.term_sqrt_kl}\nterm_kl = {self.term_kl}')
                        if self.args.dpo_policy_div_agent_num:
                            term1 /= self.args.n_agents

                        if self.args.dpo_check_kl_baseline:
                            imp_weights = torch.exp(action_log_probs - old_action_log_probs_batch)
                            surr1 = imp_weights * adv_targ
                            surr2 = torch.clamp(imp_weights, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
                            if self._use_policy_active_masks:
                                policy_action_loss = (-torch.sum(torch.min(surr1, surr2),
                                                                 dim=-1,
                                                                 keepdim=True) * active_masks_batch).sum() / active_masks_batch.sum()
                            else:
                                policy_action_loss = -torch.sum(torch.min(surr1, surr2), dim=-1, keepdim=True).mean()

                            policy_loss = policy_action_loss
                        else:
                            sqrt_coeff = torch.tensor(self.beta_sqrt_kl).to(**self.tpdv).detach()
                            kl_coeff = torch.tensor(self.beta_kl).to(**self.tpdv).detach()
                            policy_loss = term1 + sqrt_coeff * term_sqrt_kl + kl_coeff * term_kl
                            policy_loss = policy_loss.mean()
                        self.es_tag = False
                    else:
                        if self._use_policy_active_masks:
                            term1 = (-torch.sum(term1,dim=-1,keepdim=True) * active_masks_batch).sum() / active_masks_batch.sum()
                        else:
                            term1 = -torch.sum(term1, dim=-1, keepdim=True).mean()
                        policy_loss = term1
                        self.es_kl.append(torch.mean(kl))
                    clip_rate = 0
        else:
            if self.args.dynamic_clip_tag:
                if self.args.idv_para and aga_update_tag and self.args.aga_tag:
                    imp_weights = torch.exp(action_log_probs - old_action_log_probs_batch)
                    surr1 = imp_weights * adv_targ
                    surr2 = torch.clamp(imp_weights, 1.0 - self.clip_param[update_index], 1.0 + self.clip_param[update_index]) * adv_targ
                else:
                    surr1 = []
                    surr2 = []
                    agent_adv_targ = adv_targ.reshape([-1, self.args.n_agents, 1])
                    if self.args.env_name == 'mujoco':
                        old_action_log_probs_batch = old_action_log_probs_batch.mean(axis = -1)
                    agent_old_action_log_probs = old_action_log_probs_batch.reshape([-1, self.args.n_agents, 1])
                    # print('init_agent_old_action_log_probs = {},agent_old_action_log_probs = {}'.format(old_action_log_probs_batch.shape,agent_old_action_log_probs.shape))
                    # print('old_batch = {}'.format(old_action_log_probs_batch))
                    imp_weights = []
                    for i in range(self.args.n_agents):
                        # print('action_log_probs[{}] = {},  agent_old_action_log_probs[:, {}] = {}'.format(i,action_log_probs[i].shape,i,agent_old_action_log_probs[:, i].shape ))

                        agent_imp_weights = torch.exp(action_log_probs[i] - agent_old_action_log_probs[:, i])
                        agent_surr1 = agent_imp_weights * agent_adv_targ[:, i]
                        agent_surr2 = torch.clamp(agent_imp_weights, 1.0 - self.clip_param[i],
                                                  1.0 + self.clip_param[i]) * agent_adv_targ[:, i]
                        surr1.append(agent_surr1)
                        surr2.append(agent_surr2)
                        imp_weights.append(agent_imp_weights)
                    surr1 = torch.stack(surr1, dim=1).reshape([-1, 1])
                    surr2 = torch.stack(surr2, dim=1).reshape([-1, 1])
                    imp_weights = torch.stack(imp_weights, dim=1).reshape([-1, 1])
            else:
                imp_weights = torch.exp(action_log_probs - old_action_log_probs_batch)
                surr1 = imp_weights * adv_targ
                surr2 = torch.clamp(imp_weights, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
            clip_check = (surr1 != surr2)
            clip_check_sum = clip_check.sum()
            clip_check_total = torch.ones_like(clip_check).sum()
            # print('clip_check = {}'.format(clip_check))
            # print('clip_check_sum = {}'.format(clip_check_sum))
            # print('clip_check_total = {}'.format(clip_check_total))
            clip_rate = float(clip_check_sum) / float(clip_check_total)
            print('clip_rate = {}'.format(clip_rate))

            if self._use_policy_active_masks:
                policy_action_loss = (-torch.sum(torch.min(surr1, surr2),
                                                 dim=-1,
                                                 keepdim=True) * active_masks_batch).sum() / active_masks_batch.sum()
            else:
                policy_action_loss = -torch.sum(torch.min(surr1, surr2), dim=-1, keepdim=True).mean()

            policy_loss = policy_action_loss


        if self.args.NP_decay_mode == 'auto':
            self.opt_NP_auto.zero_grad()
            auto_loss.backward()
            self.opt_NP_auto.step()
        else:
            auto_loss = 0


        self.actor_zero_grad(aga_update_tag,update_index)

        # print('policy_loss = {}'.format(policy_loss))
        if update_actor:
            (policy_loss - dist_entropy * self.entropy_coef).backward()

        if self._use_max_grad_norm:
            if self.args.idv_para:
                if aga_update_tag and self.args.aga_tag:
                    actor_grad_norm = nn.utils.clip_grad_norm_(self.policy.actor[update_index].parameters(), self.max_grad_norm)
                else:
                    actor_grad_norm = 0
                    for i in range(self.args.n_agents):
                        idv_actor_grad_norm = nn.utils.clip_grad_norm_(self.policy.actor[i].parameters(),
                                                                   self.max_grad_norm)
                        actor_grad_norm += idv_actor_grad_norm
                    actor_grad_norm /= self.args.n_agents

            else:
                actor_grad_norm = nn.utils.clip_grad_norm_(self.policy.actor.parameters(), self.max_grad_norm)
        else:
            if self.args.idv_para:
                if aga_update_tag and self.args.aga_tag:
                    actor_grad_norm = get_gard_norm(self.policy.actor[update_index].parameters())
                else:
                    actor_grad_norm = 0
                    for i in range(self.args.n_agents):
                        idv_actor_grad_norm = get_gard_norm(self.policy.actor[i].parameters())
                        actor_grad_norm += idv_actor_grad_norm
                    actor_grad_norm /= self.args.n_agents

            else:
                actor_grad_norm = get_gard_norm(self.policy.actor.parameters())




        if self.args.idv_para:
            if aga_update_tag and self.args.aga_tag:
                self.policy.actor_optimizer[update_index].step()
            else:
                for i in range(self.args.n_agents):
                    self.policy.actor_optimizer[i].step()
        else:
            self.policy.actor_optimizer.step()

        if self.args.soft_target and self.args.soft_target_mode == 'mini_batch':
            self.policy.soft_update_policy(self.args.soft_target_tau)

        # critic update
        if self.args.use_q:
            act_idx = torch.from_numpy(actions_batch).to(**self.tpdv)
            act_idx = act_idx.long()
            # act_idx =
            # print('act_idx.dtype = {} act_idx = {} values = {}'.format(act_idx.dtype,act_idx.shape,values.shape))
            values = torch.gather(values_all,index = act_idx,dim = -1)
        else:
            values = values_all

        value_loss = self.cal_value_loss(values, value_preds_batch, return_batch, active_masks_batch)

        self.critic_zero_grad(aga_update_tag,update_index)

        # print('value_loss = {}'.format(value_loss.dtype))

        (value_loss * self.value_loss_coef).backward()

        if self._use_max_grad_norm:
            if self.args.idv_para:
                if aga_update_tag and self.args.aga_tag:
                    critic_grad_norm = nn.utils.clip_grad_norm_(self.policy.critic[update_index].parameters(),
                                                               self.max_grad_norm)
                else:
                    critic_grad_norm = 0
                    for i in range(self.args.n_agents):
                        idv_critic_grad_norm = nn.utils.clip_grad_norm_(self.policy.critic[i].parameters(),
                                                                       self.max_grad_norm)
                        critic_grad_norm += idv_critic_grad_norm
                    critic_grad_norm /= self.args.n_agents

            else:
                critic_grad_norm = nn.utils.clip_grad_norm_(self.policy.critic.parameters(), self.max_grad_norm)
        else:
            if self.args.idv_para:
                if aga_update_tag and self.args.aga_tag:
                    critic_grad_norm = get_gard_norm(self.policy.critic[update_index].parameters())
                else:
                    critic_grad_norm = 0
                    for i in range(self.args.n_agents):
                        idv_critic_grad_norm = get_gard_norm(self.policy.critic[i].parameters())
                        critic_grad_norm += idv_critic_grad_norm
                    critic_grad_norm /= self.args.n_agents

            else:
                critic_grad_norm = get_gard_norm(self.policy.critic.parameters())


        if self.args.idv_para:
            if aga_update_tag and self.args.aga_tag:
                self.policy.critic_optimizer[update_index].step()
            else:
                for i in range(self.args.n_agents):
                    self.policy.critic_optimizer[i].step()
        else:
            self.policy.critic_optimizer.step()

        # if self.args.dynamic_clip_tag and self.args.use_q and curr_update_num < self.args.clip_update_num:
        #     # all_probs = self.policy.get_probs(obs_batch,rnn_states_batch, masks_batch,available_actions_batch)
        #     old_A = old_values_batch - baseline_batch
        #     self.update_policy_clip(old_probs_batch,old_A)

        return value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, imp_weights,clip_rate,auto_loss

    def update_dtar(self, time_steps):
        """
        Compute the scheduled dtar value for the given time step.

        The schedule is selected by ``self.args.NP_dtar_update_mode`` and moves
        between ``self.args.NP_dtar_max`` and ``self.args.NP_dtar_min`` over a
        horizon of ``self.args.NP_dtar_cycle`` steps:

        * ``'linear'``: starts at NP_dtar_max for ``time_steps == 0`` and
          interpolates linearly to NP_dtar_min, which it holds once
          ``time_steps >= NP_dtar_cycle``.
        * ``'cos'``: cosine wave with period NP_dtar_cycle; equals NP_dtar_max
          at the start of each cycle and NP_dtar_min half-way through it.
        * ``'linear_cycle'``: same linear interpolation as ``'linear'`` but
          applied to ``time_steps % NP_dtar_cycle``, so it restarts from
          NP_dtar_max at the beginning of every cycle.

        :param time_steps: (int) current step count used to index the schedule.

        :return new_dtar: scheduled value for this step.
        :raises ValueError: if NP_dtar_update_mode is not one of the modes above.
        """
        mode = self.args.NP_dtar_update_mode
        cycle = self.args.NP_dtar_cycle
        dtar_min = self.args.NP_dtar_min
        dtar_max = self.args.NP_dtar_max

        if mode == 'linear':
            # Saturate at one full cycle so the value stays at dtar_min afterwards.
            coeff_steps = np.min([time_steps, cycle])
            coeff_rate = coeff_steps / cycle
            new_dtar = coeff_rate * dtar_min + (1 - coeff_rate) * dtar_max
        elif mode == 'cos':
            # Map the position inside the current cycle to an angle in [0, 2*pi).
            coeff_steps = time_steps % cycle
            coeff_rate = coeff_steps / cycle * np.pi * 2
            # (cos + 1) / 2 lies in [0, 1]; rescale it to [dtar_min, dtar_max].
            new_dtar = 0.5 * (np.cos(coeff_rate) + 1) * (dtar_max - dtar_min) + dtar_min
        elif mode == 'linear_cycle':
            # Wrap around instead of saturating: the ramp restarts every cycle.
            coeff_steps = time_steps % cycle
            coeff_rate = coeff_steps / cycle
            new_dtar = coeff_rate * dtar_min + (1 - coeff_rate) * dtar_max
        else:
            # Previously an unknown mode fell through and failed with an
            # UnboundLocalError on the return statement; report it explicitly.
            raise ValueError(
                "Unknown NP_dtar_update_mode {!r}; expected 'linear', 'cos' or "
                "'linear_cycle'".format(mode))
        return new_dtar
    def train(self, buffer, update_actor=True,time_steps=None):
        """
        Perform a training update using minibatch GD.

        Before the minibatch loop, the schedules that are enabled in ``self.args``
        are refreshed from the statistics accumulated since the previous call:
        NP coefficient decay, per-agent deltas (new penalty method), per-agent
        beta coefficients (penalty method) and the dynamic clip parameters.
        The accumulators are then reset, the advantages are normalised and
        ``ppo_update`` is run on every minibatch of every epoch.

        :param buffer: (SharedReplayBuffer) buffer containing training data.
        :param update_actor: (bool) whether to update actor network.
        :param time_steps: (int) current step count used by the step-based schedules.
            It is required when NP coefficient decay or the dtar update is enabled;
            the weighted clip blending is skipped when it is None.

        :return train_info: (dict) contains information regarding training update (e.g. loss, grad norms, etc).
        """
        # Number of passes over the buffer for this call.
        if self.args.new_period and buffer.aga_update_tag and self.args.aga_tag:
            epoch_num = self.ppo_epoch * self.args.period
        else:
            epoch_num = self.ppo_epoch

        # ---- new penalty method: linear decay of the penalty coefficient ----
        if self.args.new_penalty_method and self.args.NP_coeff_decay:
            if self.args.NP_decay_mode == 'linear':
                coeff_steps = np.min([time_steps,self.NP_decay_max])
                coeff_rate = coeff_steps / self.NP_decay_max
                self.NP_coeff = coeff_rate * self.args.NP_coeff_end + (1 - coeff_rate) * self.args.NP_coeff
                print('time_steps = {}, coeff_steps = {}, coeff_rate = {}, NP_decay_max = {}'.format(time_steps,coeff_steps,coeff_rate,self.NP_decay_max))

        # ---- new penalty method: per-agent delta update ----
        if self.args.new_penalty_method and self.args.NP_add_delta:
            if self.args.NP_delta_mode == 'fix':
                self.delta = self.args.NP_delta * np.ones(self.args.n_agents)
            elif self.args.NP_delta_mode == 'adaptive':
                print('prev delta = {}'.format(self.delta))
                if self.term_dist is not None:
                    if self.args.NP_term_update_mode == 'mean':
                        self.term_dist /= epoch_num
                    print('term_dist = {}'.format(self.term_dist))
                    ada_lower = self.args.NP_delta / self.kl_para1
                    ada_upper = self.args.NP_delta * self.kl_para1
                    print('ada_lower = {} ada_upper = {}'.format(ada_lower,ada_upper))
                    for i in range(self.args.n_agents):
                        if self.term_dist[i] < ada_lower:
                            self.delta[i] /= self.kl_para2
                            # Message fixed: it used to say "increase" for the division.
                            print(
                                'for agent {}:: self.term_dist[{}] = {} < ada_lower = {} decrease self.delta[{}] to {}'
                                .format(i, i, self.term_dist[i], ada_lower, i, self.delta[i]))
                        elif self.term_dist[i] > ada_upper:
                            self.delta[i] *= self.kl_para2
                            # Message fixed: it used to print "<" for this ">" branch.
                            print(
                                'for agent {}:: self.term_dist[{}] = {} > ada_upper = {} increase self.delta[{}] to {}'
                                .format(i, i, self.term_dist[i], ada_upper, i, self.delta[i]))
                self.term_dist = np.zeros(self.args.n_agents)
                if self.args.NP_delta_clip_mode == 'central':
                    clip_lower_bound = self.args.NP_delta * self.para_lower_bound
                    clip_upper_bound = self.args.NP_delta * self.para_upper_bound
                elif self.args.NP_delta_clip_mode == 'fixed':
                    clip_lower_bound = 1.0 * self.para_lower_bound
                    clip_upper_bound = 1.0 * self.para_upper_bound
                # Keep every agent's delta inside [clip_lower_bound, clip_upper_bound].
                for i in range(self.args.n_agents):
                    if self.delta[i] < clip_lower_bound:
                        self.delta[i] = clip_lower_bound
                    if self.delta[i] > clip_upper_bound:
                        self.delta[i] = clip_upper_bound
                print('KL_clip_lower_bound = {}, KL_clip_upper_bound = {}'.format(clip_lower_bound,
                                                                                  clip_upper_bound))
                print('after delta = {}'.format(self.delta))

                # Same adaptive rule applied to the KL delta.
                if self.args.NP_KL_term:
                    print('prev kl_delta = {}'.format(self.NP_KL_delta))
                    if self.args.NP_term_update_mode == 'mean':
                        self.NP_KL_term_dist /= epoch_num
                    print('NP_KL_term_dist = {}'.format(self.NP_KL_term_dist))
                    ada_lower = self.args.NP_KL_delta / self.args.NP_kl_para1
                    ada_upper = self.args.NP_KL_delta * self.args.NP_kl_para1
                    print('ada_lower = {} ada_upper = {}'.format(ada_lower, ada_upper))
                    for i in range(self.args.n_agents):
                        if self.NP_KL_term_dist[i] < ada_lower:
                            self.NP_KL_delta[i] /= self.args.NP_kl_para2
                            print(
                                'for agent {}:: self.NP_KL_term_dist[{}] = {} < ada_lower = {} decrease self.NP_KL_delta[{}] to {}'
                                .format(i, i, self.NP_KL_term_dist[i], ada_lower, i, self.NP_KL_delta[i]))
                        elif self.NP_KL_term_dist[i] > ada_upper:
                            self.NP_KL_delta[i] *= self.args.NP_kl_para2
                            print(
                                'for agent {}:: self.NP_KL_term_dist[{}] = {} > ada_upper = {} increase self.NP_KL_delta[{}] to {}'
                                .format(i, i, self.NP_KL_term_dist[i], ada_upper, i, self.NP_KL_delta[i]))

                    self.NP_KL_term_dist = np.zeros(self.args.n_agents)
                    if self.args.NP_delta_clip_mode == 'central':
                        KL_clip_lower_bound = self.args.NP_KL_delta * self.args.NP_KL_para_lower_bound
                        KL_clip_upper_bound = self.args.NP_KL_delta * self.args.NP_KL_para_upper_bound
                    elif self.args.NP_delta_clip_mode == 'fixed':
                        KL_clip_lower_bound = 1.0 * self.args.NP_KL_para_lower_bound
                        KL_clip_upper_bound = 1.0 * self.args.NP_KL_para_upper_bound
                    for i in range(self.args.n_agents):
                        if self.NP_KL_delta[i] < KL_clip_lower_bound:
                            self.NP_KL_delta[i] = KL_clip_lower_bound
                        if self.NP_KL_delta[i] > KL_clip_upper_bound:
                            self.NP_KL_delta[i] = KL_clip_upper_bound
                    print('KL_clip_lower_bound = {}, KL_clip_upper_bound = {}'.format(KL_clip_lower_bound,KL_clip_upper_bound))
                    print('after kl_delta = {}'.format(self.NP_KL_delta))

        # Reset the distance accumulator for the updates performed below.
        if self.args.new_penalty_method:
            self.term_dist = np.zeros(self.args.n_agents)

        if self.args.NP_dtar_update:
            new_dtar = self.update_dtar(time_steps)
            print(f'time_steps = {time_steps} new_dtar = {new_dtar}')
            self.set_dtar(dtar_kl=new_dtar)

        # ---- penalty method: per-agent beta coefficients ----
        if self.args.penalty_method and self.term_kl is not None and (not self.args.new_penalty_method ):
            print('term_sqrt_kl = {} term_kl = {} old_beta_sqrt_kl = {}, old_beta_kl = {}'.format(self.term_sqrt_kl,
                                                                                                  self.term_kl,
                                                                                                  self.beta_sqrt_kl,
                                                                                                  self.beta_kl))
            print('prev beta_kl = {}'.format(self.beta_kl))

            if self.args.penalty_beta_type == 'adaptive':
                if self.args.DPO_term_update_mode == 'mean':
                    self.term_kl /= epoch_num
                print('self.kl_lower = {}, self.kl_upper = {}'.format(self.kl_lower,self.kl_upper))
                for i in range(self.args.n_agents):
                    if self.term_kl[i] < self.kl_lower:
                        self.beta_kl[i] /= self.kl_para2
                        print('for agent {}:: self.term_kl[{}] = {} < self.kl_lower = {} decrease self.beta_kl[{}] to {}'
                              .format(i,i,self.term_kl[i],self.kl_lower,i,self.beta_kl[i]))
                    elif self.term_kl[i] > self.kl_upper:
                        self.beta_kl[i] *= self.kl_para2
                        # Message fixed: it used to print "<" for this ">" branch.
                        print(
                            'for agent {}:: self.term_kl[{}] = {} > self.kl_upper = {} increase self.beta_kl[{}] to {}'
                            .format(i, i, self.term_kl[i], self.kl_upper, i, self.beta_kl[i]))
                print('after beta_kl = {}'.format(self.beta_kl))
            elif self.args.penalty_beta_type == 'fixed':
                self.beta_kl = np.ones(self.args.n_agents) * self.dtar_kl
            elif self.args.penalty_beta_type == 'adaptive_rule2':
                if self.args.use_q:
                    old_pi = buffer.old_probs_all[:-1]

                    # e.g. (39,1,3,5) (episode_length,rollout_thread,agents, actions)
                    old_pi = old_pi.reshape(-1, self.args.n_actions)
                    old_q = buffer.old_values_all[:-1]
                    old_q = old_q.reshape(-1, self.args.n_actions)
                    old_baseline = buffer.baseline[:-1]
                    old_baseline = old_baseline.reshape(-1, 1)
                    old_A = old_q - old_baseline
                    E_A_square = old_A * old_A * old_pi
                    E_A_square = 2 * np.sqrt( np.sum(E_A_square) )
                elif self.args.sp_use_q:
                    old_pi = buffer.sp_prob_all[:-1]

                    old_pi = old_pi.reshape(-1, self.args.sp_num)
                    old_q = buffer.sp_value_all[:-1]
                    old_q = old_q.reshape(-1, self.args.sp_num)
                    old_baseline = np.mean(old_q, axis=-1, keepdims=True)
                    old_A = old_q - old_baseline
                    E_A_square = old_A * old_A
                    E_A_square = 2 * np.sqrt(np.mean(E_A_square))
                else:
                    # Previously this fell through to an UnboundLocalError on E_A_square.
                    raise ValueError("penalty_beta_type 'adaptive_rule2' requires use_q or sp_use_q")
                # Broadcast to one coefficient per agent: the code below indexes
                # beta_sqrt_kl[i], which failed on the scalar stored here before.
                self.beta_sqrt_kl = np.ones(self.args.n_agents) * (0.9 * E_A_square)

                print('before_ramge_clip: new_beta_sqrt_kl = {}'.format(self.beta_sqrt_kl))
            if self.args.penalty_beta_sqrt_type == 'adaptive':
                if self.args.DPO_term_update_mode == 'mean':
                    # Fixed: this branch used to divide term_kl (a second time)
                    # and left term_sqrt_kl, the statistic compared below, unaveraged.
                    self.term_sqrt_kl /= epoch_num
                print('self.sqrt_kl_lower = {}, self.sqrt_kl_upper = {}'.format(self.sqrt_kl_lower, self.sqrt_kl_upper))
                for i in range(self.args.n_agents):
                    if self.term_sqrt_kl[i] < self.sqrt_kl_lower:
                        self.beta_sqrt_kl[i] /= self.sqrt_kl_para2
                        print('for agent {}:: self.term_sqrt_kl[{}] = {} < self.sqrt_kl_lower = {} decrease self.beta_sqrt_kl[{}] to {}'
                              .format(i,i,self.term_sqrt_kl[i],self.sqrt_kl_lower,i,self.beta_sqrt_kl[i]))

                    elif self.term_sqrt_kl[i] > self.sqrt_kl_upper:
                        self.beta_sqrt_kl[i] *= self.sqrt_kl_para2
                        print('for agent {}:: self.term_sqrt_kl[{}] = {} > self.sqrt_kl_upper = {} increase self.beta_sqrt_kl[{}] to {}'
                              .format(i,i,self.term_sqrt_kl[i],self.sqrt_kl_upper,i,self.beta_sqrt_kl[i]))
            elif self.args.penalty_beta_sqrt_type == 'fixed':
                self.beta_sqrt_kl = np.ones(self.args.n_agents) * self.dtar_sqrt_kl

            # Clamp both coefficient vectors to [para_lower_bound, para_upper_bound].
            for i in range(self.args.n_agents):
                if self.beta_kl[i] < self.para_lower_bound:
                    self.beta_kl[i] = self.para_lower_bound
                if self.beta_kl[i] > self.para_upper_bound:
                    self.beta_kl[i] = self.para_upper_bound
                if self.beta_sqrt_kl[i] < self.para_lower_bound:
                    self.beta_sqrt_kl[i] = self.para_lower_bound
                if self.beta_sqrt_kl[i] > self.para_upper_bound:
                    self.beta_sqrt_kl[i] = self.para_upper_bound
            print('para_lower_bound = {} para_upper_bound = {}'.format(self.para_lower_bound, self.para_upper_bound))
            print('after_ramge_clip: new_beta_sqrt_kl = {} new_beta_kl = {}'.format(self.beta_sqrt_kl,self.beta_kl))

            # Optional ablations: switch one of the two penalty terms off.
            if self.args.no_sqrt_kl:
                for i in range(self.args.n_agents):
                    self.beta_sqrt_kl[i] = 0
            if self.args.no_kl:
                for i in range(self.args.n_agents):
                    self.beta_kl[i] = 0

            print('new_beta_sqrt_kl = {}, new_beta_kl = {}'.format(self.beta_sqrt_kl, self.beta_kl))

        # Reset the KL accumulators for the updates performed below.
        if self.args.penalty_method:
            self.term_kl =  np.zeros(self.args.n_agents)
            self.term_sqrt_kl = np.zeros(self.args.n_agents)

        # ---- dynamic clipping: solve for new clip parameters ----
        if self.args.dynamic_clip_tag:
            if self.args.use_q:
                if self.args.all_state_clip:
                    value,probs = self.get_matrix_state_table()
                    # Fixed: the original appended to old_q / old_pi, which are
                    # not lists at this point, so this path always crashed.
                    # NOTE(review): the leading axis of size 1 mirrors what the
                    # append-then-np.array code would have produced; confirm it is
                    # the layout update_policy_clip_ver_* expects.
                    old_q = np.array([value.detach().cpu().numpy()])
                    old_pi = np.array([probs.detach().cpu().numpy()])
                    old_baseline = np.sum(old_q * old_pi,axis = -1,keepdims= True)
                    old_A = old_q - old_baseline

                else:
                    old_pi = buffer.old_probs_all[:-1]

                    # e.g. (39,1,3,5) (episode_length,rollout_thread,agents, actions)
                    old_pi = old_pi.reshape(-1, self.args.n_actions)
                    old_q = buffer.old_values_all[:-1]
                    old_q = old_q.reshape(-1, self.args.n_actions)
                    old_baseline = buffer.baseline[:-1]
                    old_baseline = old_baseline.reshape(-1, 1)
                    old_A = old_q - old_baseline
            elif self.args.sp_clip:
                old_pi = buffer.sp_prob_all[:-1]

                old_pi = old_pi.reshape(-1, self.args.sp_num)
                old_q = buffer.sp_value_all[:-1]
                old_q = old_q.reshape(-1, self.args.sp_num)
                old_baseline = np.mean(old_q,axis=-1,keepdims=True)
                old_A = old_q - old_baseline

            print('all_state_clip = {}'.format(self.args.all_state_clip))
            print('old_pi_shape = {}'.format(old_pi.shape))
            print('buffer.obs = {}'.format(buffer.obs.shape))
            # e.g. (40,10,3,30) (epsiode_limit + 1, rollout_num,agents, state_dim)
            if self.args.true_rho_s:
                # Discounted sum over time of agent 0's observations averaged over rollouts.
                sample_state = buffer.obs[:-1,:,0,:] #(39,10,30)
                sample_state = np.mean(sample_state,axis = 1)
                episode_length = sample_state.shape[0]
                gamma_list = np.ones([episode_length,1])
                for i in range(1,episode_length):
                    gamma_list[i] = self.args.gamma * gamma_list[i - 1]
                rho_s = np.sum(gamma_list * sample_state,axis = 0)
                print('rho_s = {}'.format(rho_s))
            else:
                rho_s = None
            if self.args.dcmode == 1:
                if self.args.delta_decay:
                    decay_constant = (1 - (2*self.curr_delta_update)/self.all_updates )*self.clip_constant
                    print('self.min_clip_constant = {}, decay_constant = {}'.format(self.min_clip_constant,decay_constant))
                    self.curr_constant = np.maximum(self.min_clip_constant,  decay_constant)
                else:
                    self.curr_constant = self.clip_constant

                if self.args.delta_reset:
                    # Restart from a uniform split of the budget: eps = exp(delta) - 1.
                    idv_delta = self.curr_constant / self.args.n_agents
                    idv_eps = np.exp(idv_delta) - 1
                    self.clip_param = np.ones(self.args.n_agents) * idv_eps
                print('clip update {}/{} init clip params = {}'.format(self.curr_delta_update,self.all_updates,self.clip_param))

                # Average the deltas found by the repeated solves.
                all_deltas = []
                for i in range(self.args.clip_update_num):
                    clip_iter, clip_solve_loss = self.update_policy_clip_ver_1(old_pi, old_A)
                    print('clip step {}: clip_params = {}'.format(i,self.clip_param))
                    all_deltas.append(self.clip_delta)
                all_deltas = np.array(all_deltas)
                final_delta = np.mean(all_deltas,axis = 0)
                final_eps = np.exp(final_delta) - 1
                self.clip_param = final_eps
                self.curr_delta_update += 1

            elif self.args.dcmode == 2:
                clip_iter, clip_solve_loss = self.update_policy_clip_ver_2(old_pi, old_A,rho_s)

            print('solved_clip_params = {}'.format( self.clip_param))
            if self.args.weighted_clip and time_steps is not None:
                # Blend from weighted_clip_init towards the solved value as steps grow.
                ratios = np.minimum(float(time_steps)/float(self.args.weighted_clip_step),1.0)
                self.clip_param = ratios * self.clip_param + (1.0 - ratios) * self.args.weighted_clip_init
                print('weighted_clip_params = {}, time_steps = {} ratios = {}'.format(self.clip_param,time_steps,ratios))

        if self.args.penalty_method and self.args.env_name == 'mujoco' and self.args.correct_kl:
            self.policy.hard_update_policy()

        # ---- advantages ----
        if self.args.use_q:
            if self._use_popart:
                advantages = self.value_normalizer.denormalize(buffer.value_curr[:-1]) - self.value_normalizer.denormalize(buffer.baseline[:-1])
            else:
                advantages = buffer.value_curr[:-1] - buffer.baseline[:-1]
        else:
            if self._use_popart:
                advantages = buffer.returns[:-1] - self.value_normalizer.denormalize(buffer.value_preds[:-1])
            else:
                advantages = buffer.returns[:-1] - buffer.value_preds[:-1]

        # Normalise with statistics taken over active entries only
        # (entries whose active mask is 0 are excluded through NaN).
        advantages_copy = advantages.copy()
        advantages_copy[buffer.active_masks[:-1] == 0.0] = np.nan
        mean_advantages = np.nanmean(advantages_copy)
        std_advantages = np.nanstd(advantages_copy)
        advantages = (advantages - mean_advantages) / (std_advantages + 1e-5)

        # Running sums of the per-minibatch statistics.
        train_info = dict.fromkeys([
            'value_loss',
            'policy_loss',
            'dist_entropy',
            'actor_grad_norm',
            'critic_grad_norm',
            'ratio',
            'clip_rate',
            'kl_div',
            'p_loss_part1',
            'p_loss_part2',
            'd_coeff',
            'd_term',
            'auto_loss',
            'term1_grad_norm',
            'term2_grad_norm',
            'grad_ratio',
        ], 0)

        if self.args.optim_reset and self.args.aga_tag and buffer.aga_update_tag:
            self.policy.optim_reset(buffer.update_index)
        curr_update_num = 0
        for epoch_cnt in range(epoch_num):
            self.es_tag = False
            self.es_kl = []
            if self._use_recurrent_policy:
                data_generator = buffer.recurrent_generator(advantages, self.num_mini_batch, self.data_chunk_length)
            elif self._use_naive_recurrent:
                data_generator = buffer.naive_recurrent_generator(advantages, self.num_mini_batch)
            else:
                data_generator = buffer.feed_forward_generator(advantages, self.num_mini_batch)

            for sample in data_generator:

                value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, imp_weights,clip_rate,auto_loss \
                    = self.ppo_update(sample, update_actor,aga_update_tag = buffer.aga_update_tag,update_index = buffer.update_index,curr_update_num = epoch_cnt)

                train_info['value_loss'] += value_loss.item()
                train_info['policy_loss'] += policy_loss.item()
                train_info['dist_entropy'] += dist_entropy.item()
                train_info['actor_grad_norm'] += actor_grad_norm
                train_info['critic_grad_norm'] += critic_grad_norm
                train_info['ratio'] += imp_weights.mean()
                train_info['clip_rate'] += clip_rate
                train_info['auto_loss'] += auto_loss
                if self.args.new_penalty_method:
                    train_info['p_loss_part1'] += self.p_loss_part1
                    train_info['p_loss_part2'] += self.p_loss_part2
                    train_info['d_coeff'] += self.d_coeff
                    train_info['d_term'] += self.d_term
                    if self.args.NP_grad_check:
                        train_info['term1_grad_norm'] += self.term1_grad_norm
                        train_info['term2_grad_norm'] += self.term2_grad_norm
                        train_info['grad_ratio'] = train_info['term1_grad_norm'] / train_info['term2_grad_norm']
                if self.term_kl is not None:
                    train_info['kl_div'] += self.term_kl.mean()
                curr_update_num += 1
                if self.args.early_stop:
                    # NOTE(review): presumably ppo_update appends KL values to
                    # self.es_kl; after this line it is a tensor, not a list,
                    # until the next epoch resets it - verify against ppo_update.
                    self.es_kl = torch.mean(torch.tensor(self.es_kl))
                    self.es_tag = self.es_kl > self.args.es_judge
                    if self.es_tag:
                        print('es_kl = {}, es_judge = {}'.format(self.es_kl,self.args.es_judge))
                        print('early stop after {} epoch, total updates {}'.format(epoch_cnt,curr_update_num))
                        train_info['early_stop_epoch'] = curr_update_num
                        train_info['early_stop_kl'] = self.es_kl
                        break
            if self.es_tag:
                break
        if not self.es_tag and self.args.early_stop:
            train_info['early_stop_epoch'] = curr_update_num
            train_info['early_stop_kl'] = self.es_kl

        if self.args.target_dec or self.args.target_for_dist:
            if self.args.soft_target and self.args.soft_target_mode == 'batch':
                self.policy.soft_update_policy(self.args.soft_target_tau)
            else:
                if buffer.aga_update_tag and self.args.aga_tag:
                    self.policy.hard_update_policy()

        if self.args.sp_use_q:
            if self.args.sp_update_policy == 'hard':
                self.policy.hard_update_policy()
            else:
                self.policy.soft_update_policy()

        if self.args.use_q or self.args.sp_use_q:
            self.policy.soft_update_critic()

        # Average over the updates that were actually performed. The nominal
        # ppo_epoch * num_mini_batch is wrong when epoch_num is scaled by
        # `period` or when early stopping cuts the loop short.
        num_updates = max(curr_update_num, 1)

        for k in train_info.keys():
            if k in ['early_stop_epoch','early_stop_kl','grad_ratio']:
                continue
            train_info[k] /= num_updates
        out_info_key = [ 'min_clip', 'max_clip', 'mean_clip', 'clip_rate']
        if self.args.dynamic_clip_tag and self.args.use_q:
            train_info['clip_iteration'] = clip_iter
            train_info['clip_solve_loss'] = clip_solve_loss
            out_info_key.append('clip_iteration')
            out_info_key.append('clip_solve_loss')

        train_info['min_clip'] = np.min(self.clip_param)
        train_info['max_clip'] = np.max(self.clip_param)
        train_info['mean_clip'] = np.mean(self.clip_param)

        if self.args.check_optimal_V_bound:
            approx_value,probs = self.get_matrix_state_table()
            probs = probs.clone().detach().cpu().numpy()
            optimal_value = self.args.optimal_V_table
            value = self.calc_matrix_V_table(probs,optimal_check=self.args.NP_recalc_optimal_V)
            v_dist = np.linalg.norm(optimal_value - value)
            key = 'joint_V_bound'
            train_info[key] = v_dist
            if self.args.check_V_details:
                print('V_table = {}\noptimal_V_table = {}\nv_dist = {}'.format(value,optimal_value,v_dist))
                for i in range(self.args.n_agents):
                    value_i = approx_value[i]
                    value_i = value_i.reshape([-1])
                    value_i = value_i.clone().detach().cpu().numpy()
                    print('for agent {}::V_net = {}'.format(i,value_i))

        if self.args.agent_dist_and_coeff:
            for i in range(self.args.n_agents):
                train_info['NP_delta_agent_{}'.format(i)] = self.beta_sqrt_kl[i]
                train_info['distance_agent_{}'.format(i)] =self.term_sqrt_kl[i]
                train_info['NP_KL_delta_agent_{}'.format(i)] = self.beta_kl[i]
                train_info['KL_agent_{}'.format(i)] = self.term_kl[i]

        # One-line summary of the clip statistics.
        out_info = ''.join('{}: {}    '.format(key, train_info[key]) for key in out_info_key)
        print(out_info)
        return train_info

    def prep_training(self):
        """
        Call ``train()`` on the actor and critic networks of the policy.

        With ``args.idv_para`` set, ``policy.actor`` / ``policy.critic`` are
        indexed per agent and every agent's pair is switched in turn;
        otherwise the single shared pair is switched.
        """
        if not self.args.idv_para:
            # Shared parameters: one actor and one critic.
            self.policy.actor.train()
            self.policy.critic.train()
            return

        for agent_id in range(self.args.n_agents):
            self.policy.actor[agent_id].train()
            self.policy.critic[agent_id].train()

    def prep_rollout(self):
        """
        Call ``eval()`` on the actor and critic networks of the policy.

        With ``args.idv_para`` set, ``policy.actor`` / ``policy.critic`` are
        indexed per agent and every agent's pair is switched in turn;
        otherwise the single shared pair is switched.
        """
        if not self.args.idv_para:
            # Shared parameters: one actor and one critic.
            self.policy.actor.eval()
            self.policy.critic.eval()
            return

        for agent_id in range(self.args.n_agents):
            self.policy.actor[agent_id].eval()
            self.policy.critic[agent_id].eval()

    def actor_zero_grad(self,aga_update_tag = False,update_index = -1):
        """
        Call ``zero_grad()`` on the actor optimizer(s) of the policy.

        :param aga_update_tag: (bool) when truthy together with ``args.aga_tag``
            (and ``args.idv_para``), only the optimizer at ``update_index`` is cleared.
        :param update_index: (int) index into ``policy.actor_optimizer`` used in that case.
        """
        if not self.args.idv_para:
            # Shared parameters: a single optimizer object.
            self.policy.actor_optimizer.zero_grad()
            return

        if aga_update_tag and self.args.aga_tag:
            self.policy.actor_optimizer[update_index].zero_grad()
            return

        # Per-agent optimizers, all cleared.
        for agent_id in range(self.args.n_agents):
            self.policy.actor_optimizer[agent_id].zero_grad()
    def critic_zero_grad(self,aga_update_tag = False,update_index = -1):
        """
        Call ``zero_grad()`` on the critic optimizer(s) of the policy.

        :param aga_update_tag: (bool) when truthy together with ``args.aga_tag``
            (and ``args.idv_para``), only the optimizer at ``update_index`` is cleared.
        :param update_index: (int) index into ``policy.critic_optimizer`` used in that case.
        """
        if not self.args.idv_para:
            # Shared parameters: a single optimizer object.
            self.policy.critic_optimizer.zero_grad()
            return

        if aga_update_tag and self.args.aga_tag:
            self.policy.critic_optimizer[update_index].zero_grad()
            return

        # Per-agent optimizers, all cleared.
        for agent_id in range(self.args.n_agents):
            self.policy.critic_optimizer[agent_id].zero_grad()