# -*- coding: utf-8 -*-
"""
Created on Tue Sep  6 02:04:27 2022

@author: 86153
"""

import time
import os
from numpy.core.numeric import indices
from torch.distributions.normal import Normal
from algorithms.utils import collect, mem_report
from algorithms.models import GaussianActor, GraphConvolutionalModel, MLP, CategoricalActor
from tqdm.std import trange
#from algorithms.algorithm import ReplayBuffer
from ray.state import actors
from gym.spaces.box import Box
from gym.spaces.discrete import Discrete
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.optim import Adam
import numpy as np
import pickle
from copy import deepcopy as dp
from algorithms.models import CategoricalActor, EnsembledModel, SquashedGaussianActor, ParameterizedModel_MBPPO
import random
import multiprocessing as mp
from torch import distributed as dist
import argparse
from algorithms.algo.buffer import MultiCollect,Trajectory,TrajectoryBuffer,ModelBuffer


class CommunicationController(nn.Module):
    """
    Gate network that scores the two possible critic input ranges.

    Takes a flattened joint observation and returns a probability pair over
    the last axis: index 0 selects the local-observation critic, index 1 the
    full-state critic.

    Args:
        input_dim: size of the flattened input vector.
        hidden_dim: width of the two hidden ReLU layers (default 64).
    """

    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        widths = (input_dim, hidden_dim, hidden_dim)
        layers = []
        # Two hidden Linear+ReLU stages: input_dim -> hidden_dim -> hidden_dim.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), nn.ReLU()])
        # Output head: two choices turned into probabilities.
        layers.extend([nn.Linear(hidden_dim, 2), nn.Softmax(dim=-1)])
        # Kept as an nn.Sequential named `net` so state_dict keys stay `net.<idx>.*`.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the [..., 2] selection probabilities for input ``x``."""
        selection_probs = self.net(x)
        return selection_probs


class CPPOAgent_AdaTwoState(nn.ModuleList):
    """
    Per-agent PPO learner whose critic adaptively picks between two input ranges.

    For every sample a gate network (``comm_controller``) selects either a
    critic that sees only each agent's own observation, V(o), or a critic that
    sees the observations gathered over the ``radius_v`` neighbourhood, V(s)
    (the full state when that neighbourhood covers every agent).

    Everything in and out is torch Tensor.
    """

    def __init__(self, logger, device, agent_args, env_args, **kwargs):
        super().__init__()
        self.logger = logger
        self.device = device
        self.n_agent = agent_args.n_agent
        self.gamma = agent_args.gamma
        self.lamda = agent_args.lamda
        self.clip = agent_args.clip
        self.target_kl = agent_args.target_kl
        self.v_coeff = agent_args.v_coeff
        self.v_thres = agent_args.v_thres
        self.entropy_coeff = agent_args.entropy_coeff
        self.lr = agent_args.lr
        self.lr_v = agent_args.lr_v
        self.n_update_v = agent_args.n_update_v  # NOTE(review): not read anywhere in this class
        self.n_update_pi = agent_args.n_update_pi
        self.n_minibatch = agent_args.n_minibatch
        self.use_reduced_v = agent_args.use_reduced_v
        self.use_rtg = agent_args.use_rtg
        self.use_gae_returns = agent_args.use_gae_returns
        self.env_name = env_args.env
        self.algo_name = env_args.algo
        self.advantage_norm = agent_args.advantage_norm
        self.observation_dim = agent_args.observation_dim
        self.action_space = agent_args.action_space
        self.discrete = isinstance(agent_args.action_space, Discrete)
        if self.discrete:
            self.action_dim = self.action_space.n
            self.action_shape = self.action_dim
        else:
            self.action_shape = self.action_space.shape
            self.action_dim = self.action_space.shape[0]
            self.squeeze = agent_args.squeeze

        self.adj = torch.as_tensor(agent_args.adj, device=self.device, dtype=torch.float)
        self.radius_v = agent_args.radius_v
        self.radius_pi = agent_args.radius_pi
        self.pi_args = agent_args.pi_args
        self.v_args = agent_args.v_args
        self.collect_pi, self.actors = self._init_actors()
        self.collect_v, self.vs = self._init_vs()

        # Gate that chooses, per sample, between the local and the neighbourhood critic.
        self.comm_controller = CommunicationController(self.observation_dim * self.n_agent).to(self.device)
        # Critics that only see each agent's own observation: V(o).
        self.collect_v_local, self.vs_local = self._init_vs_local()

        # The gate is trained jointly with both critic families by the value optimizer.
        self.optimizer_v = Adam(list(self.vs.parameters()) + list(self.vs_local.parameters()) + list(self.comm_controller.parameters()), lr=self.lr_v)
        self.optimizer_pi = Adam(self.actors.parameters(), lr=self.lr)

    def act(self, s, if_test=False, requires_log=False):
        """
        Build the joint action distribution for observations ``s`` (gradient-free).

        Args:
            s: observation tensor, [batch_size, n_agent, dim] or [n_agent, dim].
            if_test: for discrete actions, return a deterministic (greedy)
                distribution putting all mass on each agent's most probable
                action for every sample in the batch.
            requires_log: unused; kept for interface compatibility.

        Returns:
            ``Categorical`` (discrete) or ``Normal`` (continuous) whose leading
            dimensions match ``s``. Use ``get_logp()`` when gradients are needed.
        """
        with torch.no_grad():
            dim = s.dim()
            # Add leading axes so the per-agent inputs are [batch_size, dim].
            while s.dim() <= 2:
                s = s.unsqueeze(0)
            s = s.to(self.device)

            # s[i] is agent i's gathered input (its size is degree[i] * observation_dim,
            # see _init_actors).
            s = self.collect_pi.gather(s)

            if self.discrete:
                probs = torch.stack([self.actors[i](s[i]) for i in range(self.n_agent)], dim=1)
                if if_test:
                    # Greedy evaluation: one-hot on the arg-max action of every sample,
                    # kept on the policy's device/dtype like the stochastic branch.
                    greedy = probs.argmax(dim=-1)
                    probs = torch.nn.functional.one_hot(greedy, num_classes=self.action_dim).to(probs.dtype)
                # Drop the batch axes that were added above.
                while probs.dim() > dim:
                    probs = probs.squeeze(0)
                return Categorical(probs)

            means, stds = [], []
            for i in range(self.n_agent):
                mean, std = self.actors[i](s[i])
                means.append(mean)
                stds.append(std)
            means = torch.stack(means, dim=1)
            stds = torch.stack(stds, dim=1)
            while means.dim() > dim:
                means = means.squeeze(0)
                stds = stds.squeeze(0)
            return Normal(means, stds)

    def get_logp(self, s, a):
        """
        Log-probability of actions ``a`` under the current policies (gradient-enabled).

        Args:
            s: observations, [batch_size, n_agent, dim] or [n_agent, dim].
            a: actions with matching leading dimensions; for discrete action
               spaces they are used (as long) gather indices into the policy output.

        Returns:
            Tensor of shape [batch_size, n_agent, ...] with at least 3 dimensions.
        """
        s = torch.as_tensor(s, dtype=torch.float32, device=self.device)
        # gather()/the actors need the actions on the same device as the policy output.
        a = torch.as_tensor(a, device=self.device)

        while s.dim() <= 2:
            s = s.unsqueeze(0)
            a = a.unsqueeze(0)
        while a.dim() < s.dim():
            a = a.unsqueeze(-1)

        s = self.collect_pi.gather(s)
        # Now s[i].dim() == 2, a.dim() == 3
        log_prob = []
        for i in range(self.n_agent):
            if self.discrete:
                probs = self.actors[i](s[i])
                log_prob.append(torch.log(torch.gather(probs, dim=-1, index=torch.select(a, dim=1, index=i).long())))
            else:
                log_prob.append(self.actors[i](s[i], a.select(dim=1, index=i)))
        log_prob = torch.stack(log_prob, dim=1)
        while log_prob.dim() < 3:
            log_prob = log_prob.unsqueeze(-1)
        return log_prob

    def _collate_trajs(self, trajs):
        """Pad every trajectory to the longest length and stack them along a new batch axis.

        Padded steps are marked done (``d`` padded with True); every other field
        is padded with zeros of its own dtype.
        """
        names = Trajectory.names()
        max_traj_length = max(traj.length for traj in trajs)
        collated = {}
        for name in names:
            padded = []
            for traj in trajs:
                chunk = traj[name]
                pad_shape = [max_traj_length - chunk.shape[0]] + list(chunk.shape[1:])
                if name == 'd':
                    pad = torch.ones(pad_shape, dtype=torch.bool, device=self.device)
                else:
                    pad = torch.zeros(pad_shape, dtype=chunk.dtype, device=self.device)
                padded.append(torch.cat([chunk, pad], dim=0))
            collated[name] = torch.stack(padded, dim=0)
        return collated

    def updateAgent(self, trajs, clip=None):
        """
        Run ``n_update_pi`` rounds of PPO actor + critic updates on ``trajs``.

        Steps:
            1. Pad all trajectories to the longest one and stack them.
            2. Each round: recompute values/returns/advantages, take one clipped
               PPO policy step, then one critic (value) regression step.

        Args:
            trajs: list of trajectories exposing ``length`` and per-field tensors.
            clip: PPO clipping threshold; defaults to ``self.clip``.

        Returns:
            [mean reward, last entropy estimate, max |KL| of the last round].
        """
        time_t = time.time()
        if clip is None:
            clip = self.clip
        n_minibatch = self.n_minibatch

        traj = self._collate_trajs(trajs)

        # Loop-invariant: move the batch to the device and flatten it once.
        # All tensors are [batch_size, T, n_agent, dim].
        s, a, r, s1, d, logp = [traj[key].to(self.device) for key in ('s', 'a', 'r', 's1', 'd', 'logp')]
        _, _, n, d_s = s.size()
        d_a = a.size()[-1]
        flat_s = s.reshape(-1, n, d_s)
        flat_a = a.reshape(-1, n, d_a)
        flat_logp = logp.reshape(-1, n, d_a)
        batch_total = flat_logp.size()[0]
        batch_size = int(batch_total / n_minibatch)

        for i_update in range(self.n_update_pi):
            # Targets are recomputed every round with the current critic (and are detached).
            value_old, returns, advantages, reduced_advantages = self._process_traj(s=s, a=a, r=r, s1=s1, d=d, logp=logp)
            advantages_old = reduced_advantages if self.use_reduced_v else advantages
            advantages_old = advantages_old.reshape(-1, n, 1)
            returns = returns.reshape(-1, n, 1)
            value_old = value_old.reshape(-1, n, 1)
            kl_all = []

            # Policy (actor) update.
            for _ in range(1):
                batch_state, batch_action, batch_logp, batch_advantages_old = flat_s, flat_a, flat_logp, advantages_old
                if n_minibatch > 1:
                    idxs = np.random.choice(range(batch_total), size=batch_size, replace=False)
                    batch_state, batch_action, batch_logp, batch_advantages_old = [item[idxs] for item in (batch_state, batch_action, batch_logp, batch_advantages_old)]

                batch_logp_new = self.get_logp(batch_state, batch_action)
                logp_diff = batch_logp_new - batch_logp
                kl = logp_diff.mean()
                ratio = torch.exp(logp_diff)

                # Clipped surrogate objective; the clamp bounds the size of each update.
                surr1 = ratio * batch_advantages_old
                surr2 = ratio.clamp(1 - clip, 1 + clip) * batch_advantages_old
                loss_surr = torch.min(surr1, surr2).mean()
                # Sample-based entropy estimate: -E[log pi(a|s)].
                loss_entropy = - torch.mean(batch_logp_new)
                loss_pi = - loss_surr - self.entropy_coeff * loss_entropy

                self.optimizer_pi.zero_grad()
                loss_pi.backward()
                self.optimizer_pi.step()

                self.logger.log(surr_loss = loss_surr, entropy = loss_entropy, kl_divergence = kl, pi_update=None)
                kl_all.append(kl.abs().item())
                if self.target_kl is not None and kl.abs() > 1.5 * self.target_kl:
                    break
            self.logger.log(pi_update_step=i_update)

            # Critic (value) update.
            for _ in range(1):
                batch_returns = returns
                batch_state = flat_s
                if n_minibatch > 1:
                    # batch_total is already the number of samples (an int).
                    idxs = np.random.randint(0, batch_total, size=batch_size)
                    batch_returns, batch_state = [item[idxs] for item in (batch_returns, batch_state)]

                batch_v_new = self._evalV(batch_state)
                loss_v = ((batch_v_new - batch_returns) ** 2).mean()

                self.optimizer_v.zero_grad()
                loss_v.backward()
                self.optimizer_v.step()

                # Value loss relative to the spread of the returns (1e-8 avoids division by zero).
                var_v = ((batch_returns - batch_returns.mean()) ** 2).mean()
                rel_v_loss = loss_v / (var_v + 1e-8)

                self.logger.log(v_loss=loss_v, v_update=None, v_var=var_v, rel_v_loss=rel_v_loss)
                if rel_v_loss < self.v_thres:
                    break
            self.logger.log(v_update_step=i_update)
            self.logger.log(update=None, reward=r.mean().item(), value=value_old.mean().item(), clip=clip, returns=returns.mean().item(), advantages=advantages_old.abs().mean().item())
        self.logger.log(agent_update_time=time.time()-time_t)

        return [r.mean().item(), loss_entropy.item(), max(kl_all)]

    def checkConverged(self, ls_info):
        return False

    def save(self, info=None):
        self.logger.save(self, info=info)

    def load(self, state_dict):
        self.load_state_dict(state_dict[self.logger.prefix])

    def save_nets(self, dir_name, episode):
        """Save the actor weights to ``<dir_name>/Models/<episode>best_actor.pt``."""
        model_dir = os.path.join(dir_name, 'Models')
        # makedirs also creates missing parents and tolerates an existing directory.
        os.makedirs(model_dir, exist_ok=True)
        torch.save(self.actors.state_dict(), os.path.join(model_dir, str(episode) + 'best_actor.pt'))
        print('RL saved successfully')

    def load_nets(self, dir_name, episode):
        """Load actor weights written by ``save_nets`` onto this agent's device."""
        path = os.path.join(dir_name, 'Models', str(episode) + 'best_actor.pt')
        # map_location lets checkpoints saved on another GPU / on GPU load here.
        self.actors.load_state_dict(torch.load(path, map_location=self.device))

    def _evalV(self, s):
        """
        Evaluate the adaptive critic on ``s`` of shape [-1, n_agent, dim].

        The gate picks, per sample, either V(o) (``vs_local``) or V(s) (``vs``).
        The choice is a hard arg-max in the forward pass. A straight-through
        estimator routes the gradient into ``comm_controller``; a plain arg-max
        would leave the gate untrainable.

        Returns:
            Stacked per-agent values, [-1, n_agent, ...].
        """
        s = s.to(self.device)

        # Flatten the joint observation for the gate: [batch_size, n_agent * obs_dim].
        flat_s = s.reshape(s.shape[0], -1)
        comm_probs = self.comm_controller(flat_s)  # [batch_size, 2]

        # Exact 0/1 weights in the forward pass (p - p.detach() == 0), softmax gradient backward.
        hard = torch.nn.functional.one_hot(comm_probs.argmax(dim=-1), num_classes=comm_probs.shape[-1]).to(comm_probs.dtype)
        gate = hard + (comm_probs - comm_probs.detach())

        s_local = self.collect_v_local.gather(s)
        s_global = self.collect_v.gather(s)

        values_local = []
        values_global = []
        for i in range(self.n_agent):
            values_local.append(self.vs_local[i](s_local[i]))
            values_global.append(self.vs[i](s_global[i]))
        values_local = torch.stack(values_local, dim=1)    # V(o)
        values_global = torch.stack(values_global, dim=1)  # V(s)

        w_local = gate[:, 0].view(-1, 1, 1)
        w_global = gate[:, 1].view(-1, 1, 1)
        return w_local * values_local + w_global * values_global

    def _init_actors(self):
        """
        Build one actor per agent over its radius-``radius_pi`` neighbourhood.

        ``torch.matrix_power(self.adj, self.radius_pi)`` defines which agents'
        observations each actor receives. For env 'UAV_9d' with algo 'CPPO' the
        adjacency is replaced by an all-ones 25x25 matrix, i.e. every agent sees
        all 25 agents.

        Returns:
            (collect_pi, actors): the neighbourhood collector and an nn.ModuleList
            of actors.
        """
        if self.env_name == 'UAV_9d' and self.algo_name == 'CPPO':
            self.adj = torch.as_tensor(np.ones((25, 25)), device=self.device, dtype=torch.float)

        collect_pi = MultiCollect(torch.matrix_power(self.adj, self.radius_pi), device=self.device)

        actors = nn.ModuleList()
        for i in range(self.n_agent):
            # Input width = (number of gathered agents, incl. self) * observation size.
            # NOTE: this mutates the shared self.pi_args on each iteration.
            self.pi_args.sizes[0] = collect_pi.degree[i] * self.observation_dim
            if self.discrete:
                actors.append(CategoricalActor(**self.pi_args._toDict()).to(self.device))
            else:
                actors.append(GaussianActor(action_dim=self.action_dim, **self.pi_args._toDict()).to(self.device))

        return collect_pi, actors

    def _init_vs(self):
        """
        Build one neighbourhood critic V(s) per agent over radius ``radius_v``.

        For env 'UAV_9d' the adjacency is replaced by an all-ones 25x25 matrix
        (every agent sees all 25 agents). The network class comes from
        ``self.v_args.network``.

        Returns:
            (collect_v, vs): the neighbourhood collector and an nn.ModuleList of critics.
        """
        if self.env_name == 'UAV_9d':
            self.adj = torch.as_tensor(np.ones((25, 25)), device=self.device, dtype=torch.float)

        collect_v = MultiCollect(torch.matrix_power(self.adj, self.radius_v), device=self.device)

        vs = nn.ModuleList()
        for i in range(self.n_agent):
            # Input width = (number of gathered agents, incl. self) * observation size.
            self.v_args.sizes[0] = collect_v.degree[i] * self.observation_dim
            v_fn = self.v_args.network
            vs.append(v_fn(**self.v_args._toDict()).to(self.device))

        return collect_v, vs

    def _init_vs_local(self):
        """
        Build critics that only use each agent's own observation, V(o).

        An identity adjacency makes every agent gather only itself.
        """
        collect_v_local = MultiCollect(torch.eye(self.n_agent, device=self.device), device=self.device)
        vs_local = nn.ModuleList()
        for i in range(self.n_agent):
            self.v_args.sizes[0] = self.observation_dim  # Only local observation
            v_fn = self.v_args.network
            vs_local.append(v_fn(**self.v_args._toDict()).to(self.device))
        return collect_v_local, vs_local

    def _process_traj(self, s, a, r, s1, d, logp):
        """
        Compute critic values, returns and GAE advantages for a padded batch.

        Inputs are all in shape [batch_size, T, n_agent, dim]; ``a`` and
        ``logp`` are only moved to the device.

        Returns:
            (value, returns, advantages, reduced_advantages). None of them carries
            autograd history, so ``returns`` can serve as a fixed critic target.
            ``reduced_advantages`` sums advantages over each agent's
            ``collect_v`` neighbourhood. With ``advantage_norm`` (and T > 1) both
            advantage tensors are normalised over the time axis.
        """
        # Everything computed here is a training target; without no_grad the
        # critic loss would back-propagate through the bootstrapped values.
        with torch.no_grad():
            b, T, n, dim_s = s.shape
            s, a, r, s1, d, logp = [item.to(self.device) for item in [s, a, r, s1, d, logp]]
            value = self._evalV(s.reshape(-1, n, dim_s)).view(b, T, n, -1)
            returns = torch.zeros(value.size(), device=self.device)
            deltas, advantages = torch.zeros_like(returns), torch.zeros_like(returns)

            # Bootstrap from the value of the final next-state.
            prev_value = self._evalV(s1.select(1, T - 1))
            if not self.use_rtg:
                prev_return = prev_value
            else:
                prev_return = torch.zeros_like(prev_value)
            prev_advantage = torch.zeros_like(prev_return)
            d_mask = d.float()

            # Backward GAE recursion; (1 - done) stops bootstrapping across episode ends.
            for t in reversed(range(T)):
                not_done = 1 - d_mask.select(1, t)
                deltas[:, t, :, :] = r.select(1, t) + self.gamma * not_done * prev_value - value.select(1, t)
                advantages[:, t, :, :] = deltas.select(1, t) + self.gamma * self.lamda * not_done * prev_advantage
                if self.use_gae_returns:
                    returns[:, t, :, :] = value.select(1, t) + advantages.select(1, t)
                else:
                    returns[:, t, :, :] = r.select(1, t) + self.gamma * not_done * prev_return
                prev_return = returns.select(1, t)
                prev_value = value.select(1, t)
                prev_advantage = advantages.select(1, t)

            reduced_advantages = self.collect_v.reduce_sum(advantages.view(-1, n, 1)).view(advantages.size())

            # Normalise over time (dim=1) to zero mean / unit std when requested.
            if self.advantage_norm and reduced_advantages.size()[1] > 1:
                reduced_advantages = (reduced_advantages - reduced_advantages.mean(dim=1, keepdim=True)) / (reduced_advantages.std(dim=1, keepdim=True) + 1e-5)
                advantages = (advantages - advantages.mean(dim=1, keepdim=True)) / (advantages.std(dim=1, keepdim=True) + 1e-5)

        return value, returns, advantages, reduced_advantages