# -*- coding: utf-8 -*-
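# Communication range is controlled in three main places:
# 1. `max_adj` (built in `_init_actors` / `_init_vs`): the maximum communication range;
#    it is currently an all-ones matrix, i.e. every agent can reach every other agent.
# 2. `CommunicationController`: its actor outputs one of `n_agent` radius choices per agent.
# 3. `CPPOAgent_CommunicationConstraint.act`: raises the adjacency matrix to the chosen radius.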
#
# Code changes needed:
# 1. In config: set self.alpha = 0.1
# 2. In config: adjust cost_bound based on c * rollout_length,
#    where c = exp(alpha * k) - 1. Set c appropriately with alpha = 0.1
#
# Map parameters:
# 1. Monaco: 28
# 2. catchup: 16, k=10, length=600, bound=1200
# 3. Eight: n=14, k=10, length=1400, bound=3000
# 4. Grid: 25
# 5. Ring: 22, k=15, length=1500, bound=3000
# 6. Large_city
# 7. Pandemic
# 8. Powergrid: 40
# 9. real_power
# 10. slowdown: 8, k=4, length=600, bound=300 (roughly 0.49 * 600 = 294)

import time
import os
from numpy.core.numeric import indices
from torch.distributions.normal import Normal
from algorithms.utils import collect, mem_report
from algorithms.models import GaussianActor, GraphConvolutionalModel, MLP, CategoricalActor
from tqdm.std import trange
#from algorithms.algorithm import ReplayBuffer
from ray.state import actors
from gym.spaces.box import Box
from gym.spaces.discrete import Discrete
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.optim import Adam
import numpy as np
import pickle
from copy import deepcopy as dp
from algorithms.models import CategoricalActor, EnsembledModel, SquashedGaussianActor, ParameterizedModel_MBPPO
import random
import multiprocessing as mp
from torch import distributed as dist
import argparse
from algorithms.algo.buffer import MultiCollect, DynamicMultiCollect, TrajectoryBad, TrajectoryBuffer, ModelBuffer

class CommunicationController(nn.Module):
    '''
    Upper-level policy that picks a communication radius for every agent.

    The actor maps each agent's observation to a categorical distribution over
    ``n_agent`` radius indices (0 .. n_agent - 1); the critic maps the observation
    averaged over agents to a scalar value.
    '''
    def __init__(self, n_agent, obs_dim, hidden_dim=64, device='cuda'):
        super().__init__()
        self.n_agent = n_agent
        self.obs_dim = obs_dim
        self.device = device
        self.softmax = nn.Softmax(dim=-1)

        # Actor: for each agent, one score per candidate radius index (n_agent of them).
        self.actor = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, self.n_agent),
        )

        # Critic: estimates the value of the current communication policy.
        self.critic = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

        self.to(device)

    def forward(self, x):
        """
        Args:
            x: observations, [n_agent, obs_dim] or [batch_size, n_agent, obs_dim];
               a 2-D input is given a leading batch dimension.

        Returns:
            (radius_probs, value): softmax *probabilities* (not logits) of shape
            [batch_size, n_agent, n_agent], and the critic output evaluated on the
            mean observation over agents, [batch_size, 1].
        """
        if x.dim() == 2:
            x = x.unsqueeze(0)
        radius_probs = self.softmax(self.actor(x))
        value = self.critic(x.mean(dim=1))
        return radius_probs, value

    @staticmethod
    def _to_python(t):
        # Single-element tensors become Python scalars (the original return type);
        # per-agent results become lists with the leading batch dimension dropped.
        return t.item() if t.numel() == 1 else t.squeeze(0).tolist()

    def get_action(self, s, deterministic=False):
        """
        Choose a radius index for every agent from one observation.

        Args:
            s: observations, [n_agent, obs_dim].
            deterministic: take the most likely radius instead of sampling.

        Returns:
            (action, log_prob, entropy, value): ``action`` and ``log_prob`` are Python
            scalars for a single agent and per-agent lists otherwise; ``entropy`` is the
            mean entropy as a tensor; ``value`` is a float.
        """
        s = torch.as_tensor(s, dtype=torch.float32, device=self.device)
        radius_probs, value = self(s.unsqueeze(0))  # radius_probs: [1, n_agent, n_agent]

        # forward() already returns probabilities. Passing them as ``logits=`` would apply
        # a second softmax and flatten the distribution.
        radius_dist = Categorical(probs=radius_probs)
        if deterministic:
            # Most likely radius per agent (the distribution's mode).
            action = radius_probs.argmax(dim=-1)
        else:
            action = radius_dist.sample()

        log_prob = radius_dist.log_prob(action)
        entropy = radius_dist.entropy().mean()
        return self._to_python(action), self._to_python(log_prob), entropy, value.item()

    def evaluate(self, s):
        """
        Gradient-free evaluation of the greedy radius choice.

        Returns:
            (value, entropy, log_probs): critic value, mean entropy, and the
            log-probability of each agent's most likely radius index.
        """
        with torch.no_grad():
            radius_probs, value = self(s)
            radius_dist = Categorical(probs=radius_probs)
            log_probs = radius_dist.log_prob(radius_probs.argmax(dim=-1))
            entropy = radius_dist.entropy().mean()
        return value, entropy, log_probs

class CPPOAgent_CommunicationConstraint(nn.ModuleList):
    """
    Everything in and out is torch Tensor.
    """
    def __init__(self, logger, device, agent_args, env_args, **kwargs):
        """
        Build the per-agent actors and critics, the communication controller and the optimizers.

        Args:
            logger: logging/checkpoint helper (other methods use its ``log``, ``save`` and ``prefix``).
            device: torch device on which all networks and tensors are placed.
            agent_args: hyper-parameter namespace (gamma, lamda, clip, learning rates, adjacency
                matrix ``adj``, ``pi_args``/``v_args`` network specs, cost_bound, lagrangian_coef, ...).
            env_args: namespace whose ``env`` and ``algo`` fields name the environment and algorithm.
            **kwargs: ignored.
        """
        super().__init__()
        self.logger = logger
        self.device = device
        self.n_agent = agent_args.n_agent
        self.gamma = agent_args.gamma
        self.lamda = agent_args.lamda
        self.clip = agent_args.clip
        self.target_kl = agent_args.target_kl
        self.v_coeff = agent_args.v_coeff
        self.v_thres = agent_args.v_thres
        self.costv_thres = agent_args.costv_thres
        self.entropy_coeff = agent_args.entropy_coeff
        self.lr = agent_args.lr
        self.lr_v = agent_args.lr_v
        self.lr_c = agent_args.lr_c
        self.n_update_v = agent_args.n_update_v
        self.n_update_pi = agent_args.n_update_pi
        self.n_minibatch = agent_args.n_minibatch
        self.use_reduced_v = agent_args.use_reduced_v
        self.use_rtg = agent_args.use_rtg
        self.use_gae_returns = agent_args.use_gae_returns
        self.env_name = env_args.env
        self.algo_name = env_args.algo
        self.advantage_norm = agent_args.advantage_norm
        self.observation_dim = agent_args.observation_dim
        self.action_space = agent_args.action_space
        self.discrete = isinstance(agent_args.action_space, Discrete)
        self.lambda_p = agent_args.lambda_p

        if self.discrete:
            self.action_dim = self.action_space.n
            self.action_shape = self.action_dim
        else:
            self.action_shape = self.action_space.shape
            self.action_dim = self.action_space.shape[0]
            self.squeeze = agent_args.squeeze

        self.adj = torch.as_tensor(agent_args.adj, device=self.device, dtype=torch.float)

        self.pi_args = agent_args.pi_args
        self.v_args = agent_args.v_args
        self.comm_kl_threshold = getattr(agent_args, 'comm_kl_threshold', 0.01)
        self.comm_cost_coeff = getattr(agent_args, 'comm_cost_coeff', 0.1)

        # Networks are sized for the maximal communication range; a policy then
        # narrows the range at run time.
        self.init_start = True
        # Keep the collector returned by _init_actors as well, so get_logp() can be
        # used before act() has built its own collector.
        self.collect_pi, self.actors = self._init_actors(self.init_start, self.adj)
        self.collect_v, self.vs, self.cost_vs = self._init_vs(self.init_start, self.adj)

        self.optimizer_pi = Adam(self.actors.parameters(), lr=self.lr)
        self.optimizer_v = Adam(self.vs.parameters(), lr=self.lr_v)
        self.optimizer_costv = Adam(self.cost_vs.parameters(), lr=self.lr_v)

        self.use_comm_controller = getattr(agent_args, 'use_comm_controller', False)
        if self.use_comm_controller:
            self.comm_controller = CommunicationController(
                n_agent=self.n_agent,
                obs_dim=self.observation_dim,
                device=self.device
            )
            # Optimizer of the upper-level (communication-radius) policy; it can only
            # exist when the controller does.
            self.optimizer_comm = Adam(self.comm_controller.parameters(), lr=self.lr)
        else:
            # NOTE(review): get_logp()/updateAgent() still call the controller
            # unconditionally, so training requires use_comm_controller=True.
            self.comm_controller = None
            self.optimizer_comm = None

        # Lagrangian (primal-dual) parameters for the cost constraint.
        self.cost_bound = agent_args.cost_bound
        self.lagrangian_coef = agent_args.lagrangian_coef
        self.lamda_lagr = torch.nn.Parameter(torch.tensor(0.0, device=self.device))

    def act(self, s, if_test=False):
        """
        Choose a communication radius, then every agent's action distribution.

        Inputs: ``s`` of shape [batch_size, n_agent, dim] or [n_agent, dim].
        Note: gradient-free (runs under ``torch.no_grad()``); use ``get_logp()`` for
        gradient-enabled log-probabilities. The action distribution is squeezed back
        to the number of dimensions of the input.

        Steps:
        1. If the communication controller is enabled, sample one radius index per agent,
           keep the most frequent one and raise ``self.adj`` to that power.
        2. Rebuild ``self.collect_pi`` with that adjacency and gather each agent's
           neighbourhood observations.
        3. Evaluate every agent's actor (greedy one-hot when ``if_test`` and discrete).

        Returns:
            (action distribution, radius distribution, chosen radius as a tensor). The
            last two are ``None`` when the communication controller is disabled.
        """
        with torch.no_grad():
            s = torch.as_tensor(s, dtype=torch.float32, device=self.device)
            # Remember the caller's rank so outputs can be squeezed back to it.
            dim = s.dim()

            # Defaults when no communication controller is used, so the return
            # statements below never see unbound names.
            radius_dist = None
            dynamic_radius_tensor = None
            dynamic_adj = None

            if self.use_comm_controller:
                # forward() returns probabilities, so build the distribution with
                # ``probs=``. Treating them as logits would soft-max twice and disagree
                # with the log-probabilities computed in get_logp().
                radius_probs, _ = self.comm_controller(s)
                radius_dist = Categorical(probs=radius_probs)
                radius_action = radius_dist.sample()
                self.radius_values = radius_action.detach().cpu().numpy().astype(int)

                # k-hop aggregation uses a single power of the adjacency matrix, so the
                # per-agent radii are reduced to their most frequent value (mode).
                # Taking the maximum instead is another possible strategy.
                candidates, counts = np.unique(self.radius_values, return_counts=True)
                dynamic_radius = int(candidates[np.argmax(counts)])
                dynamic_radius_tensor = torch.tensor(dynamic_radius, device=self.device)
                dynamic_adj = torch.matrix_power(self.adj, dynamic_radius)

            # Ensure a leading batch dimension.
            while s.dim() <= 2:
                s = s.unsqueeze(0)
            init_start = False
            self.collect_pi = DynamicMultiCollect(init_start, dynamic_adj, device=self.device, max_adj=self.max_adj)
            # Gather, for every agent, the observations of its current neighbourhood.
            s = self.collect_pi.dynamic_gather(s)

            if self.discrete:
                if if_test:
                    # Evaluation: no exploration, a one-hot distribution on each agent's
                    # most likely action (first batch element).
                    probs = []
                    for i in range(self.n_agent):
                        index = torch.argmax(self.actors[i](s[i])[0])
                        action_p = torch.zeros(1, self.action_dim, device=self.device)
                        action_p[0, index] = 1.0
                        probs.append(action_p)
                    probs = torch.stack(probs, dim=1)
                else:
                    probs = torch.stack([self.actors[i](s[i])[0] for i in range(self.n_agent)], dim=0)

                while probs.dim() > dim:
                    probs = probs.squeeze(0)
                return Categorical(probs), radius_dist, dynamic_radius_tensor

            means, stds = [], []
            for i in range(self.n_agent):
                mean, std = self.actors[i](s[i])
                means.append(mean)
                stds.append(std)
            means = torch.stack(means, dim=1)
            stds = torch.stack(stds, dim=1)
            while means.dim() > dim:
                means = means.squeeze(0)
                stds = stds.squeeze(0)
            return Normal(means, stds), radius_dist, dynamic_radius_tensor
    
    def get_logp(self, s, k, a):
        """
        Gradient-enabled log-probabilities of actions ``a`` and communication radii ``k``.

        Args:
            s: observations, [batch_size, n_agent, dim] or [n_agent, dim].
            k: radius indices chosen by the communication controller; assumed to hold one
               index per agent (e.g. [batch_size, n_agent, 1]) — TODO confirm against the
               rollout buffer that fills ``traj['k']``.
            a: actions, [batch_size, n_agent, action_dim] or [n_agent, action_dim].

        Returns:
            (log_prob, log_prob_k): per-agent action log-probabilities (dim() == 3) and the
            log-probability of each index in ``k`` under the current controller, shaped
            [batch_size, n_agent, -1].

        Steps:
        1. Convert inputs to tensors with a leading batch dimension.
        2. Evaluate the communication controller on the raw observations.
        3. Gather neighbourhood observations and evaluate every agent's actor.
        """
        # 1. Tensors with a leading batch dimension.
        s = torch.as_tensor(s, dtype=torch.float32, device=self.device)
        k = torch.as_tensor(k, device=self.device)
        a = torch.as_tensor(a, device=self.device)
        while s.dim() <= 2:
            s = s.unsqueeze(0)
            k = k.unsqueeze(0)
            a = a.unsqueeze(0)
        while a.dim() < s.dim():
            a = a.unsqueeze(-1)

        # 2. Radius policy. forward() returns probabilities. Select the probability of
        #    the radius actually taken, so the PPO ratio in updateAgent compares the same
        #    event; returning the whole distribution would leave ``k`` unused and
        #    broadcast the ratio over every radius.
        probs_k, _ = self.comm_controller(s)
        k_index = k.long().reshape(*probs_k.shape[:-1], -1)
        log_probs_k = torch.log(torch.gather(probs_k, dim=-1, index=k_index))

        # 3. Action policy on neighbourhood-gathered observations.
        # NOTE(review): self.collect_pi is the collector stored on self (rebuilt on every
        # act() call), i.e. the adjacency of the most recent radius, not necessarily the
        # one active when each transition was collected.
        s = self.collect_pi.dynamic_gather(s)
        log_prob = []
        for i in range(self.n_agent):
            a_i = a.select(dim=1, index=i)
            if self.discrete:
                probs = self.actors[i](s[i])
                log_prob.append(torch.log(torch.gather(probs, dim=-1, index=a_i.long())))
            else:
                log_prob.append(self.actors[i](s[i], a_i))

        log_prob = torch.stack(log_prob, dim=1)
        while log_prob.dim() < 3:
            log_prob = log_prob.unsqueeze(-1)
        return log_prob, log_probs_k

    def get_logp_k(self, s, k):
        """
        Log-probability of the communication-radius indices ``k`` under the current
        communication controller.

        Args:
            s: observations, [batch_size, n_agent, dim] or [n_agent, dim].
            k: radius indices; assumed to hold one index per agent (e.g.
               [batch_size, n_agent, 1]) — TODO confirm against the rollout buffer.

        Returns:
            Tensor of shape [batch_size, n_agent, -1] (dim() == 3) holding
            log pi_k(k | s) for every agent.
        """
        s = torch.as_tensor(s, dtype=torch.float32, device=self.device)
        k = torch.as_tensor(k, device=self.device)
        while s.dim() <= 2:
            s = s.unsqueeze(0)
            k = k.unsqueeze(0)

        # The controller's actor takes per-agent observations of size obs_dim (as in
        # act()), not neighbourhood-gathered ones, and it returns (probs, value) where
        # probs are already softmax-normalised.
        probs_k, _ = self.comm_controller(s)
        k_index = k.long().reshape(*probs_k.shape[:-1], -1)
        return torch.log(torch.gather(probs_k, dim=-1, index=k_index))

    def updateAgent(self, trajs, clip=None):
        '''
        Update the action policy, the communication (radius) policy, the Lagrange
        multiplier and both critics from a list of trajectories.

        Args:
            trajs: trajectories exposing ``length`` and indexable by the field names of
                ``TrajectoryBad.names()`` (s, a, r, c, k, s1, d, logp, logp_k are read).
            clip: PPO clipping range; defaults to ``self.clip``.

        Returns:
            [mean reward, action-entropy estimate, max |KL| of the action policy,
             mean cost, radius-entropy estimate, max |KL| of the radius policy]

        Steps:
        1. Pad every trajectory to the longest one and stack them into one batch.
        2. For each of ``n_update_pi`` rounds:
            2.1 recompute advantages with the current critics;
            2.2 one clipped primal-dual PPO step for the action policy, one clipped step
                for the communication policy, and a Lagrange-multiplier update;
            2.3 one regression step for the reward and cost critics.
        '''
        time_t = time.time()
        if clip is None:
            clip = self.clip
        n_minibatch = self.n_minibatch

        # 1. Pad to a common length: the done flag is padded with True (episode over),
        #    every other field with zeros; then stack along a new trajectory dimension.
        names = TrajectoryBad.names()
        traj_all = {name: [] for name in names}
        max_traj_length = max(t.length for t in trajs)
        for one_traj in trajs:
            for name in names:
                tensor_shape = one_traj[name].shape
                pad_shape = [max_traj_length - tensor_shape[0]] + list(tensor_shape[1:])
                if name == 'd':
                    pad = torch.ones(pad_shape, dtype=torch.bool, device=self.device)
                else:
                    pad = torch.zeros(pad_shape, dtype=one_traj[name].dtype, device=self.device)
                traj_all[name].append(torch.cat([one_traj[name], pad], dim=0))
        traj = {name: torch.stack(value, dim=0) for name, value in traj_all.items()}

        # 2. Update rounds.
        for i_update in range(self.n_update_pi):
            s, a, r, c, k, s1, d, logp, logp_k = [
                traj[name].to(self.device)
                for name in ('s', 'a', 'r', 'c', 'k', 's1', 'd', 'logp', 'logp_k')
            ]

            # 2.1 Advantages and returns, all in shape [batch_size, T, n_agent, dim].
            (value_old, returns, advantages, reduced_advantages,
             cost_value_old, cost_returns, cost_advantages, cost_reduced_advantages) = self._process_traj(**traj)
            advantages_old = reduced_advantages if self.use_reduced_v else advantages
            cost_advantages_old = cost_reduced_advantages if self.use_reduced_v else cost_advantages

            # Merge (trajectory, time) into one batch dimension: [-1, n_agent, dim].
            b, T, n, d_s = s.size()
            d_a = a.size()[-1]
            d_k = k.size()[-1]
            s = s.view(-1, n, d_s)
            a = a.view(-1, n, d_a)
            k = k.view(-1, n, d_k)
            logp = logp.view(-1, n, d_a)
            logp_k = logp_k.view(-1, n, d_k)
            advantages_old = advantages_old.view(-1, n, 1)
            cost_advantages_old = cost_advantages_old.view(-1, n, 1)
            returns = returns.view(-1, n, 1)
            cost_returns = cost_returns.view(-1, n, 1)
            value_old = value_old.view(-1, n, 1)
            cost_value_old = cost_value_old.view(-1, n, 1)

            batch_total = logp.size()[0]
            batch_size = int(batch_total / n_minibatch)
            kl_all = []
            kl_all_k = []

            # 2.2 Policy losses (a single pass per update round).
            for _ in range(1):
                batch_state, batch_action, batch_k, batch_logp, batch_logp_k, batch_advantages_old, batch_cost_advantages_old = [
                    s, a, k, logp, logp_k, advantages_old, cost_advantages_old]

                if n_minibatch > 1:
                    # Random minibatch of batch_size distinct rows.
                    idxs = torch.randperm(batch_total, device=self.device)[:batch_size]
                    batch_state, batch_action, batch_k, batch_logp, batch_logp_k, batch_advantages_old, batch_cost_advantages_old = [
                        item[idxs] for item in (batch_state, batch_action, batch_k, batch_logp, batch_logp_k,
                                                batch_advantages_old, batch_cost_advantages_old)]

                # Inputs of the Lagrange-multiplier update.
                aver_episode_cost = c.sum(dim=1)
                imp_weights = torch.ones_like(batch_advantages_old)
                cost_adv_targ = batch_cost_advantages_old

                # Log-probabilities of the action policy and the radius policy.
                batch_logp_new, batch_logp_k_new = self.get_logp(batch_state, batch_k, batch_action)

                # Approximate KL divergences (mean log-probability differences) and ratios.
                logp_diff = batch_logp_new - batch_logp
                logp_k_diff = batch_logp_k_new - batch_logp_k
                kl = logp_diff.mean()
                kl_k = logp_k_diff.mean()
                ratio = torch.exp(logp_diff)
                ratio_k = torch.exp(logp_k_diff)
                # NOTE: a joint ratio (ratio * ratio_k) would need a different clip range.

                # Primal-dual update of the Lagrange multiplier (kept non-negative).
                delta_lamda_lagr = -((aver_episode_cost.mean() - self.cost_bound) * (1 - self.gamma) + (imp_weights * cost_adv_targ)).mean().detach()
                new_lamda_lagr_value = self.lamda_lagr - (delta_lamda_lagr * self.lagrangian_coef)
                new_lamda_lagr = torch.relu(new_lamda_lagr_value)
                self.lamda_lagr.data = new_lamda_lagr.data

                # Clipped surrogate on the Lagrangian advantage A - lambda * A_cost.
                lagr_advantages = batch_advantages_old - new_lamda_lagr.detach() * batch_cost_advantages_old
                surr1 = ratio * lagr_advantages
                surr2 = ratio.clamp(1 - clip, 1 + clip) * lagr_advantages
                loss_surr = torch.min(surr1, surr2).mean()

                # Entropy estimates (negative mean log-probability) to encourage exploration.
                loss_entropy = - torch.mean(batch_logp_new)
                loss_entropy_k = - torch.mean(batch_logp_k_new)

                # Upper-level (communication) policy loss.
                # NOTE(review): minimising -min(ratio_k * A_cost, ...) pushes the radius
                # policy towards *higher* cost advantage; confirm the sign is intended.
                surr1_k = ratio_k * batch_cost_advantages_old
                surr2_k = ratio_k.clamp(1 - clip, 1 + clip) * batch_cost_advantages_old
                loss_surr_k = torch.min(surr1_k, surr2_k).mean()
                loss_comm = -loss_surr_k - self.entropy_coeff * loss_entropy_k

                # Lower-level (action) policy loss.
                loss_pi = - loss_surr - self.entropy_coeff * loss_entropy

                self.optimizer_pi.zero_grad()
                self.optimizer_comm.zero_grad()
                loss_pi.backward()
                loss_comm.backward()
                self.optimizer_pi.step()
                self.optimizer_comm.step()

                self.logger.log(surr_loss=loss_surr, surr_k_loss=loss_surr_k, loss_comm=loss_comm, entropy=loss_entropy, kl_divergence=kl, pi_update=None)
                kl_all.append(kl.abs().item())
                kl_all_k.append(kl_k.abs().item())

                # Early stop when the new policy drifts too far from the old one.
                if self.target_kl is not None and kl.abs() > 1.5 * self.target_kl:
                    break

            self.logger.log(pi_update_step=i_update)

            # 2.3 Critic losses for the reward value and the cost value.
            for _ in range(1):
                batch_returns = returns
                batch_cost_returns = cost_returns
                batch_state = s
                if n_minibatch > 1:
                    # Random minibatch (with replacement); the same indices must be
                    # applied to the states and to both regression targets.
                    idxs = torch.randint(0, batch_total, (batch_size,), device=self.device)
                    batch_returns, batch_cost_returns, batch_state = [
                        item[idxs] for item in (batch_returns, batch_cost_returns, batch_state)]

                batch_v_new, batch_cost_v_new = self._evalV(batch_state)
                loss_v = ((batch_v_new - batch_returns) ** 2).mean()
                loss_cost_v = ((batch_cost_v_new - batch_cost_returns) ** 2).mean()

                self.optimizer_v.zero_grad()
                self.optimizer_costv.zero_grad()
                loss_v.backward()
                loss_cost_v.backward()
                self.optimizer_v.step()
                self.optimizer_costv.step()

                # Loss relative to the variance of the targets (1e-8 avoids division by zero):
                # a large ratio means the critic explains little of the return spread.
                var_v = ((batch_returns - batch_returns.mean()) ** 2).mean()
                var_cost_v = ((batch_cost_returns - batch_cost_returns.mean()) ** 2).mean()
                rel_v_loss = loss_v / (var_v + 1e-8)
                rel_cost_v_loss = loss_cost_v / (var_cost_v + 1e-8)

                self.logger.log(v_loss=loss_v, v_update=None, v_var=var_v, rel_v_loss=rel_v_loss, rel_cost_v_loss=rel_cost_v_loss)

                # Early-termination thresholds for the critic updates.
                if rel_v_loss < self.v_thres:
                    break
                if rel_cost_v_loss < self.costv_thres:
                    break

            self.logger.log(v_update_step=i_update, update=None, reward=r.mean().item(), value=value_old.mean().item(), clip=clip, returns=returns.mean().item(), advantages=advantages_old.abs().mean().item())
        self.logger.log(agent_update_time=time.time() - time_t)

        return [r.mean().item(), loss_entropy.item(), max(kl_all), c.mean().item(), loss_entropy_k.item(), max(kl_all_k)]

    def checkConverged(self, ls_info):
        """Convergence hook: this agent never reports convergence, whatever ``ls_info`` holds."""
        converged = False
        return converged

    def save(self, info=None):
        """Hand this agent to ``self.logger.save`` for checkpointing, forwarding ``info``."""
        checkpoint_logger = self.logger
        checkpoint_logger.save(self, info=info)

    def load(self, state_dict):
        """Load this agent's parameters from the entry ``state_dict[self.logger.prefix]``."""
        agent_state = state_dict[self.logger.prefix]
        self.load_state_dict(agent_state)

#_____________________________________________________________________________________________________--

    def save_nets(self, dir_name, episode):
        """
        Save the actors' state dict to ``<dir_name>/Models/<episode>best_actor.pt``.

        Missing directories, including ``dir_name`` itself, are created; an existing
        ``Models`` directory is reused.
        """
        model_dir = os.path.join(dir_name, 'Models')
        # makedirs(exist_ok=True) avoids the exists/mkdir race and creates parents.
        os.makedirs(model_dir, exist_ok=True)
        torch.save(self.actors.state_dict(), os.path.join(model_dir, f'{episode}best_actor.pt'))
        print('RL saved successfully')

    def load_nets(self, dir_name, episode):
        """
        Load actor weights written by ``save_nets`` for ``episode``.

        ``map_location=self.device`` lets a checkpoint saved on another device (e.g. a
        GPU) be restored on this agent's device.
        """
        path = os.path.join(dir_name, 'Models', f'{episode}best_actor.pt')
        self.actors.load_state_dict(torch.load(path, map_location=self.device))
#_____________________________________________________________________________________________________--
        

    def _evalV(self, s):
        """
        Evaluate every agent's reward critic and cost critic.

        Args:
            s: observations, expected as [-1, n_agent, dim].

        Returns:
            (values, cost_values): the outputs of ``self.vs[i]`` and ``self.cost_vs[i]``
            on agent i's gathered observation, each stacked along dim 1.
        """
        # Neighbourhood aggregation through the critics' collector.
        gathered = self.collect_v.dynamic_gather(s.to(self.device))

        reward_outputs, cost_outputs = [], []
        for idx in range(self.n_agent):
            agent_input = gathered[idx]
            reward_outputs.append(self.vs[idx](agent_input))
            cost_outputs.append(self.cost_vs[idx](agent_input))

        return torch.stack(reward_outputs, dim=1), torch.stack(cost_outputs, dim=1)
    
    def _init_actors(self, init_start, adj=None):
        '''
        Create one policy network per agent, sized for the maximal neighbourhood.

        Side effects:
            - for env 'UAV_9d' with algo 'CPPO', replaces ``self.adj`` by a 25x25 all-ones matrix;
            - sets ``self.max_radius`` to 1 and ``self.max_adj`` to an all-True boolean
              matrix shaped like ``self.adj``;
            - overwrites ``self.pi_args.sizes[0]`` (the actor input width) for each agent.

        Args:
            init_start: flag forwarded unchanged to ``DynamicMultiCollect``.
            adj: unused; kept for interface compatibility.

        Returns:
            (collect_pi, actors): the collector built on ``self.max_adj`` and an
            ``nn.ModuleList`` with one actor per agent.
        '''
        if self.env_name == 'UAV_9d' and self.algo_name == 'CPPO':
            self.adj = torch.as_tensor(np.ones((25, 25)), device=self.device, dtype=torch.float)

        # Maximal range: the adjacency is forced to all ones so every agent's input
        # covers all agents.
        self.max_radius = 1
        powered_adj = torch.matrix_power(self.adj, self.max_radius).to(self.device)
        self.max_adj = torch.ones_like(powered_adj, dtype=torch.bool)

        collect_pi = DynamicMultiCollect(init_start, self.max_adj, device=self.device)

        actors = nn.ModuleList()
        for agent_id in range(self.n_agent):
            # Input width = neighbour count (self included) * per-agent observation size.
            self.pi_args.sizes[0] = collect_pi.degree[agent_id] * self.observation_dim
            if self.discrete:
                actor = CategoricalActor(**self.pi_args._toDict())
            else:
                actor = GaussianActor(action_dim=self.action_dim, **self.pi_args._toDict())
            actors.append(actor.to(self.device))

        return collect_pi, actors
    
    def _init_vs(self, init_start, adj=None):
        '''
        Create one reward critic and one cost critic per agent, sized for the maximal neighbourhood.

        Side effects:
            - for env 'UAV_9d', replaces ``self.adj`` by a 25x25 all-ones matrix;
            - sets ``self.max_radius`` to 1 and ``self.max_adj`` to an all-True boolean
              matrix shaped like ``self.adj``;
            - overwrites ``self.v_args.sizes[0]`` (the critic input width) for each agent.

        Args:
            init_start: flag forwarded unchanged to ``DynamicMultiCollect``.
            adj: unused; kept for interface compatibility.

        Returns:
            (collect_v, vs, cost_vs): the collector built on ``self.max_adj`` and two
            ``nn.ModuleList``s holding each agent's reward and cost critic.
        '''
        if self.env_name == 'UAV_9d':
            self.adj = torch.as_tensor(np.ones((25, 25)), device=self.device, dtype=torch.float)

        # Maximal range: an all-ones adjacency so every critic sees all agents.
        self.max_radius = 1
        powered_adj = torch.matrix_power(self.adj, self.max_radius).to(self.device)
        self.max_adj = torch.ones_like(powered_adj, dtype=torch.bool)

        collect_v = DynamicMultiCollect(init_start, self.max_adj, device=self.device)

        vs, cost_vs = nn.ModuleList(), nn.ModuleList()
        for agent_id in range(self.n_agent):
            # Input width = neighbour count (self included) * per-agent observation size.
            self.v_args.sizes[0] = collect_v.degree[agent_id] * self.observation_dim
            build_critic = self.v_args.network
            # Reward and cost critics are built back to back for each agent.
            vs.append(build_critic(**self.v_args._toDict()).to(self.device))
            cost_vs.append(build_critic(**self.v_args._toDict()).to(self.device))

        return collect_v, vs, cost_vs
    
    def _process_traj(self, s, a, r, c, k, s1, d, logp, logp_k):
        """
        Compute GAE advantages and returns for both the reward and the cost signal.

        Inputs are all in shape [batch_size, T, n_agent, dim]; only s, r, c, s1 and d
        are read (the other fields are accepted so ``**traj`` can be passed directly).

        Returns:
            value, returns, advantages, reduced_advantages,
            cost_value, cost_returns, cost_advantages, cost_reduced_advantages
            — none of them attached to an autograd graph.

        Steps:
        1. Evaluate both critics without gradients (the results only build targets).
        2. Backward GAE recursion over time.
        3. Sum advantages over each agent's neighbourhood via ``collect_v.reduce_sum``.
        4. Optionally normalise advantages over the time dimension (mean 0, std 1).
        """
        b, T, n, dim_s = s.shape
        s, r, c, s1, d = [item.to(self.device) for item in (s, r, c, s1, d)]

        # 1. Critic estimates. They serve only as regression targets and advantage
        #    baselines. Without no_grad, `returns`/`cost_returns` would keep a graph back
        #    into the critics, and the critic losses in updateAgent would also
        #    backpropagate through their own targets.
        with torch.no_grad():
            value, cost_value = self._evalV(s.view(-1, n, dim_s))
            value = value.view(b, T, n, -1)
            cost_value = cost_value.view(b, T, n, -1)
            # Bootstrap values of the state that follows the last step.
            prev_value, prev_cost_value = self._evalV(s1.select(1, T - 1))

        returns = torch.zeros(value.size(), device=self.device)
        cost_returns = torch.zeros(cost_value.size(), device=self.device)
        deltas, advantages = torch.zeros_like(returns), torch.zeros_like(returns)
        cost_deltas, cost_advantages = torch.zeros_like(cost_returns), torch.zeros_like(cost_returns)

        if not self.use_rtg:
            prev_return = prev_value
            prev_cost_return = prev_cost_value
        else:
            prev_return = torch.zeros_like(prev_value)
            prev_cost_return = torch.zeros_like(prev_cost_value)

        prev_advantage = torch.zeros_like(prev_return)
        prev_cost_advantage = torch.zeros_like(prev_cost_return)

        d_mask = d.float()

        # 2. Backward GAE recursion; (1 - d) stops bootstrapping across episode ends.
        for t in reversed(range(T)):
            not_done = 1 - d_mask.select(1, t)
            deltas[:, t, :, :] = r.select(1, t) + self.gamma * not_done * prev_value - value.select(1, t).detach()
            cost_deltas[:, t, :, :] = c.select(1, t) + self.gamma * not_done * prev_cost_value - cost_value.select(1, t).detach()

            advantages[:, t, :, :] = deltas.select(1, t) + self.gamma * self.lamda * not_done * prev_advantage
            cost_advantages[:, t, :, :] = cost_deltas.select(1, t) + self.gamma * self.lamda * not_done * prev_cost_advantage

            if self.use_gae_returns:
                returns[:, t, :, :] = value.select(1, t).detach() + advantages.select(1, t)
                cost_returns[:, t, :, :] = cost_value.select(1, t).detach() + cost_advantages.select(1, t)
            else:
                returns[:, t, :, :] = r.select(1, t) + self.gamma * not_done * prev_return
                cost_returns[:, t, :, :] = c.select(1, t) + self.gamma * not_done * prev_cost_return

            prev_return = returns.select(1, t)
            prev_value = value.select(1, t)
            prev_advantage = advantages.select(1, t)

            prev_cost_return = cost_returns.select(1, t)
            prev_cost_value = cost_value.select(1, t)
            prev_cost_advantage = cost_advantages.select(1, t)

        # 3. Neighbourhood-summed advantages for both signals.
        reduced_advantages = self.collect_v.reduce_sum(advantages.view(-1, n, 1)).view(advantages.size())
        cost_reduced_advantages = self.collect_v.reduce_sum(cost_advantages.view(-1, n, 1)).view(cost_advantages.size())

        # 4. Optional normalisation over the time dimension (needs more than one step).
        if self.advantage_norm and reduced_advantages.size()[1] > 1:
            reduced_advantages = (reduced_advantages - reduced_advantages.mean(dim=1, keepdim=True)) / (reduced_advantages.std(dim=1, keepdim=True) + 1e-5)
            advantages = (advantages - advantages.mean(dim=1, keepdim=True)) / (advantages.std(dim=1, keepdim=True) + 1e-5)

            cost_reduced_advantages = (cost_reduced_advantages - cost_reduced_advantages.mean(dim=1, keepdim=True)) / (cost_reduced_advantages.std(dim=1, keepdim=True) + 1e-5)
            cost_advantages = (cost_advantages - cost_advantages.mean(dim=1, keepdim=True)) / (cost_advantages.std(dim=1, keepdim=True) + 1e-5)

        return value.detach(), returns, advantages.detach(), reduced_advantages.detach(), cost_value.detach(), cost_returns, cost_advantages.detach(), cost_reduced_advantages.detach()