# -*- coding: utf-8 -*-
"""
Created on Tue Sep  6 01:56:13 2022

@author: 86153
"""


import time
import os
from numpy.core.numeric import indices
from torch.distributions.normal import Normal
from algorithms.utils import collect, mem_report
from algorithms.models import GaussianActor, GraphConvolutionalModel, MLP, CategoricalActor
from tqdm.std import trange
#from algorithms.algorithm import ReplayBuffer
#from ray.state import actors
from gym.spaces.box import Box
from gym.spaces.discrete import Discrete
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.optim import Adam
import numpy as np
import pickle
from copy import deepcopy as dp
from algorithms.models import CategoricalActor, EnsembledModel, SquashedGaussianActor, ParameterizedModel_MBPPO
import random
import multiprocessing as mp
# import torch.multiprocessing as mp
from torch import distributed as dist
import argparse

class MultiCollect:
    '''
    Collects per-agent tensors over each agent's neighbourhood.

    The neighbourhood of every agent is read from an adjacency matrix
    (torch.Tensor, [n_agent, n_agent]); a self-loop is always added, so an
    agent is always part of its own neighbourhood.  Three collection modes
    are offered: ``gather``, ``reduce_mean`` and ``reduce_sum``.
    Everything outward would be in the same device specified in the
    initialization parameter.
    '''

    def __init__(self, adjacency, device='cuda'):
        '''
        Build the neighbour lookup tables from an adjacency matrix.

        Args:
            adjacency: tensor of size [n_agent, n_agent]; any entry > 0 is
                treated as a connection.
            device: device on which all internal tensors are kept.

        Attributes set:
            n: number of agents (first dimension of ``adjacency``).
            degree: per-agent neighbour count, self-loop included.
            indices: list of n index tensors; ``indices[i]`` holds the agent
                ids that agent i collects from.
        '''
        self.device = device
        self.n = adjacency.size()[0]

        # Binarise and move to the target device BEFORE combining with the
        # identity matrix: the identity is created on `device`, and mixing a
        # CPU adjacency with a CUDA identity would raise a device mismatch.
        adjacency = (adjacency > 0).to(device)

        # Add self-loops.
        adjacency = adjacency | torch.eye(self.n, device=device).bool()

        # Row sums of the boolean matrix: number of neighbours per agent.
        self.degree = adjacency.sum(dim=1)

        # For every agent, the ids of the agents selected by its adjacency row.
        index_full = torch.arange(self.n, device=device)
        self.indices = [
            torch.masked_select(index_full, adjacency[i]) for i in range(self.n)
        ]

    def gather(self, tensor):
        """
        Input shape: [batch_size, n_agent, dim]
        Return shape: [[batch_size, dim_i] for i in range(n_agent)]
        """
        return self._collect('gather', tensor)

    def reduce_mean(self, tensor):
        """
        Input shape: [batch_size, n_agent, dim]
        Return shape: [[batch_size, dim] for i in range(n_agent)]
        """
        return self._collect('reduce_mean', tensor)

    def reduce_sum(self, tensor):
        """
        Input shape: [batch_size, n_agent, dim]
        Return shape: [[batch_size, dim] for i in range(n_agent)]
        """
        return self._collect('reduce_sum', tensor)

    def _collect(self, method, tensor):
        """
        Combine, for every agent, the entries of ``tensor`` that belong to
        the agents listed in ``self.indices``.

        Args:
            method: 'gather' flattens the selected entries per agent,
                'reduce_mean' averages them; any other value sums them.
            tensor: [batch_size, n_agent, dim].  A 1-D input is treated as
                [1, n_agent] and a 2-D input as [batch_size, n_agent, 1].

        Returns:
            For 'gather': a Python list with one [batch_size, -1] tensor per
            agent.  Otherwise: one tensor stacked along dim 1.
        """
        tensor = tensor.to(self.device)

        # Normalise the input to three dimensions.
        if tensor.dim() == 1:
            tensor = tensor.unsqueeze(0)
        if tensor.dim() == 2:
            tensor = tensor.unsqueeze(-1)
        b, n, _ = tensor.shape

        result = []
        for i in range(n):
            selected = torch.index_select(tensor, dim=1, index=self.indices[i])
            if method == 'gather':
                result.append(selected.view(b, -1))
            elif method == 'reduce_mean':
                result.append(selected.mean(dim=1))
            else:
                result.append(selected.sum(dim=1))

        # Only the reductions yield equally shaped pieces that can be stacked.
        if method != 'gather':
            result = torch.stack(result, dim=1)

        return result

class DynamicMultiCollect(MultiCollect):
    '''
    MultiCollect variant whose adjacency matrix can be replaced at run time.

    ``max_adj`` determines, per agent, how many neighbour slots a ``gather``
    result is zero-padded to when the current adjacency selects fewer
    neighbours.
    '''

    def __init__(self, init_start, adj, device='cuda', max_adj=None):
        '''
        Initialise the current and the maximal neighbourhood tables.

        Args:
            init_start: when equal to True, ``degree`` is set to ``n`` for
                every agent instead of the row sums of ``max_adj``.
            adj: current adjacency matrix, [n_agent, n_agent].
            device: device on which all internal tensors are kept.
            max_adj: adjacency used only to size the padded gather output;
                defaults to ``adj``.
        '''
        # Forward the caller's device; hard-coding 'cuda' here made the class
        # unusable on CPU-only hosts and ignored the `device` argument.
        super().__init__(adj, device=device)

        if max_adj is None:
            max_adj = adj

        # Boolean maximal adjacency with self-loops, and the raw boolean
        # current adjacency (no self-loops added to `self.adj`).
        self.max_adj = (max_adj > 0).to(device)
        self.max_adj = self.max_adj | torch.eye(self.n, device=device).bool()
        self.adj = (adj > 0).to(device)

        # Row sums of the maximal adjacency; overridden with n when starting up.
        self.degree = self.max_adj.sum(dim=1)
        # Kept as an equality test so non-bool flags behave as before.
        if init_start == True:
            self.degree = torch.ones_like(self.degree) * self.n

        # Neighbour ids under the maximal adjacency, one index tensor per agent.
        index_full = torch.arange(self.n, device=device)
        self.max_indices = [
            torch.masked_select(index_full, self.max_adj[i]) for i in range(self.n)
        ]

    def dynamic_gather(self, tensor):
        """Alias for the padded 'gather' collection of ``tensor``."""
        return self._collect("gather", tensor)

    def update_adj(self, new_adj):
        """
        Replace the current adjacency and recompute ``degree`` and ``indices``.

        Args:
            new_adj: new adjacency matrix (torch.Tensor, [n_agent, n_agent]).

        NOTE(review): ``self.adj``, ``self.n`` and ``self.max_indices`` are
        not refreshed here; only ``degree`` and ``indices`` change.
        """
        device = self.device
        n = new_adj.size(0)

        # Move to the target device first so the OR with the identity matrix
        # (created on `device`) never mixes devices.
        new_adj = (new_adj > 0).to(device)
        new_adj = new_adj | torch.eye(n, device=device).bool()  # add self-loops

        self.degree = new_adj.sum(dim=1)

        index_full = torch.arange(n, device=device)
        self.indices = [torch.masked_select(index_full, new_adj[i]) for i in range(n)]

    @staticmethod
    def pad_tensor_list(tensor_list):
        """
        Zero-pad tensors of differing width and stack them.

        Args:
            tensor_list: list of tensors whose ``size(1)`` may differ.
                Padding is applied to the last dimension, so the tensors are
                presumably 2-D ([batch_size, dim_i]) -- TODO confirm.

        Returns:
            The padded tensors stacked along a new leading dimension.
        """
        max_dim = max(t.size(1) for t in tensor_list)
        padded = [
            torch.nn.functional.pad(t, (0, max_dim - t.size(1))) for t in tensor_list
        ]
        return torch.stack(padded, dim=0)

    def _collect(self, method, tensor):
        """
        Like ``MultiCollect._collect``, but 'gather' results are zero-padded.

        Args:
            method: 'gather', 'reduce_mean', or anything else for a sum.
            tensor: [batch_size, n_agent, dim]; 1-D and 2-D inputs are
                expanded to three dimensions first.

        Returns:
            For 'gather': list with one [batch_size, -1] tensor per agent,
            where agent i's neighbour axis is padded with zeros up to
            ``len(self.max_indices[i])`` when it is shorter.  Otherwise the
            stacked reduction, exactly as produced by the parent class.
        """
        # The reductions need no padding; reuse the parent implementation.
        if method != 'gather':
            return super()._collect(method, tensor)

        tensor = tensor.to(self.device)

        # Normalise the input to three dimensions.
        if tensor.dim() == 1:
            tensor = tensor.unsqueeze(0)
        if tensor.dim() == 2:
            tensor = tensor.unsqueeze(-1)
        b, n, depth = tensor.shape

        result = []
        for i in range(n):
            selected = torch.index_select(tensor, dim=1, index=self.indices[i])

            # Fill the unused neighbour slots with zeros.  new_zeros keeps the
            # dtype and device of the data, so padded and unpadded agents
            # yield the same dtype.
            missing = len(self.max_indices[i]) - len(self.indices[i])
            if missing > 0:
                padding = selected.new_zeros((b, missing, depth))
                selected = torch.cat([selected, padding], dim=1)

            result.append(selected.reshape(b, -1))

        return result
    

class TrajectoryBad:
    '''
    A single trajectory with the extended field set
    s, a, r, c, k, s1, d, logp, logp_k.

    Note that every instance carries a ``names`` list attribute which shadows
    the ``names`` classmethod on instances; ``TrajectoryBad.names()`` is only
    callable on the class itself.
    '''

    _FIELDS = ("s", "a", "r", "c", "k", "s1", "d", "logp", "logp_k")

    def __init__(self, **kwargs):
        """
        Data are of size [T, n_agent, dim].

        Every field name must be supplied as a keyword argument; a missing
        one raises KeyError.  Extra keyword arguments are ignored.
        """
        self.names = list(self._FIELDS)
        self.dict = {name: kwargs[name] for name in self.names}
        self.length = self.dict["s"].size()[0]

    def getFraction(self, length, start=None):
        '''
        Return a contiguous slice of this trajectory along the first axis.

        Args:
            length: requested number of steps; capped at the trajectory length.
            start: first step of the slice.  ``None`` draws a random valid
                start; other values are clamped to the valid range.

        Returns:
            A new TrajectoryBad holding all fields, including c, k and
            logp_k (returning a plain Trajectory here would drop them).
        '''
        if self.length < length:
            length = self.length
        start_max = self.length - length
        if start is None:
            start = torch.randint(low=0, high=start_max + 1, size=(1,)).item()
        start = min(max(start, 0), start_max)

        new_dict = {name: self.dict[name][start:start + length] for name in self.names}
        return TrajectoryBad(**new_dict)

    def __getitem__(self, key):
        """Return the tensor stored under field name ``key``."""
        assert key in self.names
        return self.dict[key]

    @classmethod
    def names(cls):
        """Return the list of field names of this trajectory type."""
        return list(cls._FIELDS)

class TrajectoryBufferBad:
    '''
    Step-wise buffer for the extended field set
    s, a, r, c, k, s1, d, logp, logp_k.

    Steps are appended with ``store_commscope`` (all fields) or ``store``
    (only s, a, r, s1, d, logp) and turned into trajectories by ``retrieve``.
    '''

    def __init__(self, device="cuda"):
        self.device = device
        self.s, self.a, self.r, self.c, self.k = [], [], [], [], []
        self.s1, self.d, self.logp, self.logp_k = [], [], [], []

    @staticmethod
    def _per_agent(x, n, batch=None):
        """
        Give ``x`` a leading batch axis if it has at most one dimension and
        keep the first ``n`` agent columns.  When ``batch`` is given, a
        freshly unsqueezed ``x`` is also expanded to [batch, n].
        """
        if x.dim() <= 1:
            x = x.unsqueeze(0)
            if batch is not None:
                x = x.expand(batch, n)
        return x[:, :n]

    def store_commscope(self, s, a, r, c, k, s1, d, logp, logp_k):
        """
        Append one step including the c, k and logp_k fields (per the
        original author's note, the variant with communication cost).

        Every input is converted to a tensor on ``self.device`` and stored
        with shape [batch_size, n_agent, dim]; batch_size and n_agent are
        taken from ``s`` after it has been expanded to three dimensions.
        s, r, c, s1, logp and logp_k are stored as float, d as bool; a and k
        keep the dtype inferred by ``torch.as_tensor``.
        """
        device = self.device
        s, r, c, s1, logp, logp_k = (
            torch.as_tensor(item, device=device, dtype=torch.float)
            for item in (s, r, c, s1, logp, logp_k)
        )
        d = torch.as_tensor(d, device=device, dtype=torch.bool)
        k = torch.as_tensor(k, device=device)
        a = torch.as_tensor(a, device=device)

        while s.dim() <= 2:
            s = s.unsqueeze(dim=0)
        b, n, _ = s.size()

        d = self._per_agent(d, n)
        r = self._per_agent(r, n)
        # c and k may be given once for all batch rows and are broadcast.
        c = self._per_agent(c, n, batch=b)
        k = self._per_agent(k, n, batch=b)

        step = {"s": s, "a": a, "r": r, "c": c, "k": k,
                "s1": s1, "d": d, "logp": logp, "logp_k": logp_k}
        for name, item in step.items():
            # reshape (not view) so non-contiguous inputs are accepted too.
            getattr(self, name).append(item.reshape(b, n, -1))

    def store(self, s, a, r, s1, d, logp):
        """
        Append one step without the c, k and logp_k fields.

        Every input is converted to a tensor on ``self.device`` and stored
        with shape [batch_size, n_agent, dim]; batch_size and n_agent are
        taken from ``s`` after it has been expanded to three dimensions.
        """
        device = self.device
        s, r, s1, logp = (
            torch.as_tensor(item, device=device, dtype=torch.float)
            for item in (s, r, s1, logp)
        )
        d = torch.as_tensor(d, device=device, dtype=torch.bool)
        a = torch.as_tensor(a, device=device)

        while s.dim() <= 2:
            s = s.unsqueeze(dim=0)
        b, n, _ = s.size()

        d = self._per_agent(d, n)
        r = self._per_agent(r, n)

        step = {"s": s, "a": a, "r": r, "s1": s1, "d": d, "logp": logp}
        for name, item in step.items():
            # reshape (not view) so non-contiguous inputs are accepted too.
            getattr(self, name).append(item.reshape(b, n, -1))

    def retrieve(self, length=None):
        """
        Turn the stored steps into one trajectory per batch row.

        Stored steps are stacked along a new time axis, so every returned
        trajectory holds data of size [T, n_agent, dim].  ``length`` is
        accepted for interface compatibility and not used.

        Returns:
            [] when nothing was stored.  A list of TrajectoryBad when any of
            c, k, logp_k was recorded; a list of plain Trajectory when the
            buffer was filled only through ``store`` (stacking the empty
            c/k/logp_k lists would otherwise raise).
        """
        if not self.s:
            return []

        base_names = ["s", "a", "r", "s1", "d", "logp"]
        if self.c or self.k or self.logp_k:
            names = ["s", "a", "r", "c", "k", "s1", "d", "logp", "logp_k"]
            traj_cls = TrajectoryBad
        else:
            names = base_names
            traj_cls = Trajectory

        traj_all = {name: torch.stack(getattr(self, name), dim=1) for name in names}
        n_batch = traj_all["s"].size()[0]
        # Split the batch axis into individual trajectories.
        return [
            traj_cls(**{name: traj_all[name][i] for name in names})
            for i in range(n_batch)
        ]

class Trajectory:
    """
    A single trajectory with the fields s, a, r, s1, d, logp.

    Instances expose a ``names`` list attribute, which shadows the ``names``
    classmethod on instances.
    """

    def __init__(self, **kwargs):
        """
        Data are of size [T, n_agent, dim].

        Each field name has to be passed as a keyword argument; additional
        keyword arguments are ignored.
        """
        self.names = ["s", "a", "r", "s1", "d", "logp"]
        self.dict = {}
        for field in self.names:
            self.dict[field] = kwargs[field]
        # Number of steps = size of the first axis of the state tensor.
        self.length = self.dict["s"].size()[0]

    def getFraction(self, length, start=None):
        """
        Cut a window of ``length`` steps out of the trajectory.

        ``length`` is capped at the trajectory length.  Without ``start`` a
        random valid offset is drawn; an explicit ``start`` is clamped into
        the valid range.  Returns a new Trajectory over the window.
        """
        length = min(length, self.length)
        last_start = self.length - length

        if start is None:
            start = torch.randint(low=0, high=last_start + 1, size=(1,)).item()

        # Clamp the offset into [0, last_start].
        if start < 0:
            start = 0
        elif start > last_start:
            start = last_start

        stop = start + length
        window = {}
        for field in self.names:
            window[field] = self.dict[field][start:stop]
        return Trajectory(**window)

    def __getitem__(self, key):
        """Look up the tensor of field ``key``."""
        assert key in self.names
        return self.dict[key]

    @classmethod
    def names(cls):
        """Field names of this trajectory type."""
        return ["s", "a", "r", "s1", "d", "logp"]

class TrajectoryBuffer:
    """
    Step-wise buffer for the fields s, a, r, s1, d, logp.

    Steps are appended with ``store`` and turned into Trajectory objects by
    ``retrieve``.
    """

    def __init__(self, device="cuda"):
        self.device = device
        self.s, self.a, self.r, self.s1, self.d, self.logp = [], [], [], [], [], []

    def store(self, s, a, r, s1, d, logp):
        """
        Append one step.  Would be converted into [batch_size, n_agent, dim].

        batch_size and n_agent are taken from ``s`` after it has been
        expanded to three dimensions.  s, r, s1 and logp are stored as float,
        d as bool; a keeps the dtype inferred by ``torch.as_tensor``.
        """
        device = self.device
        s, r, s1, logp = (
            torch.as_tensor(item, device=device, dtype=torch.float)
            for item in (s, r, s1, logp)
        )
        d = torch.as_tensor(d, device=device, dtype=torch.bool)
        a = torch.as_tensor(a, device=device)

        while s.dim() <= 2:
            s = s.unsqueeze(dim=0)
        b, n, _ = s.size()

        # d and r get a batch axis if needed and are cut to the first n agents.
        if d.dim() <= 1:
            d = d.unsqueeze(0)
        d = d[:, :n]
        if r.dim() <= 1:
            r = r.unsqueeze(0)
        r = r[:, :n]

        step = {"s": s, "a": a, "r": r, "s1": s1, "d": d, "logp": logp}
        for name, item in step.items():
            # reshape (not view) so non-contiguous inputs are accepted too.
            getattr(self, name).append(item.reshape(b, n, -1))

    def retrieve(self, length=None):
        """
        Returns trajectories with s, a, r, s1, d, logp.
        Data are of size [T, n_agent, dim]

        Stored steps are stacked along a new time axis and the batch axis is
        split into one Trajectory per row.  Returns [] when nothing was
        stored.  ``length`` is accepted for interface compatibility and not
        used.
        """
        if not self.s:
            return []

        names = ["s", "a", "r", "s1", "d", "logp"]
        traj_all = {name: torch.stack(getattr(self, name), dim=1) for name in names}
        n_batch = traj_all["s"].size()[0]
        return [
            Trajectory(**{name: traj_all[name][i] for name in names})
            for i in range(n_batch)
        ]

class ModelBuffer:
    """
    Fixed-capacity store of trajectories.

    Once ``max_traj_num`` trajectories are held, each new one overwrites the
    oldest stored trajectory.
    """

    def __init__(self, max_traj_num):
        self.max_traj_num = max_traj_num
        self.trajectories = []
        # Index of the most recently written slot; -1 while nothing is stored.
        self.ptr = -1
        # Number of trajectories currently held (never exceeds max_traj_num).
        self.count = 0

    def storeTraj(self, traj):
        """
        Store one trajectory, evicting the oldest one when the buffer is full.

        Raises:
            IndexError: if the buffer was created with a capacity <= 0.
        """
        if self.max_traj_num <= 0:
            raise IndexError("ModelBuffer has no capacity")

        # Advance BEFORE writing.  `ptr` marks the last written slot, so the
        # slot after it is the oldest one; writing at `ptr` itself would
        # clobber the newest trajectory on the first overwrite.
        self.ptr = (self.ptr + 1) % self.max_traj_num
        if self.count < self.max_traj_num:
            self.trajectories.append(traj)
            self.count += 1
        else:
            self.trajectories[self.ptr] = traj

    def storeTrajs(self, trajs):
        """Store every trajectory of the iterable ``trajs`` in order."""
        for traj in trajs:
            self.storeTraj(traj)

    def sampleTrajs(self, n_traj):
        """
        Draw ``n_traj`` stored trajectories uniformly with replacement.

        Raises:
            ValueError: (from numpy) if the buffer is empty and n_traj > 0.
        """
        traj_idxs = np.random.choice(range(self.count), size=(n_traj,), replace=True)
        return [self.trajectories[i] for i in traj_idxs]