'''
TODO: 需要调整的地方
1. 不同的 agent 需要调整的地方请搜索地图名字
'''

import time
import os
from numpy.core.numeric import indices
from torch.distributions.normal import Normal
from algorithms.utils import collect, mem_report
from algorithms.models import GaussianActor, GraphConvolutionalModel, MLP, CategoricalActor
from tqdm.std import trange
#from algorithms.algorithm import ReplayBuffer
#from ray.state import actors
from gym.spaces.box import Box
from gym.spaces.discrete import Discrete
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.optim import Adam
import numpy as np
import pickle
from copy import deepcopy as dp
from algorithms.models import CategoricalActor, EnsembledModel, SquashedGaussianActor, ParameterizedModel_MBPPO
import random
import multiprocessing as mp
# import torch.multiprocessing as mp
from torch import distributed as dist
import argparse
from algorithms.algo.buffer import DynamicMultiCollect, MultiCollect,Trajectory,TrajectoryBuffer,ModelBuffer, TrajectoryBad, TrajectoryBufferBad
from algorithms.algo.normalization_utils import ZFilter, RunningStat
import traci  # noqa


def translate_action(action_bias, action_scale, action):
    """Map a raw policy action into the environment's action range.

    The action is converted to a float tensor, detached and squeezed, clipped
    to [-1, 1] for safety, then affinely rescaled so that -1 maps to
    ``action_bias - action_scale`` and +1 maps to ``action_bias + action_scale``.

    Returns:
        The rescaled action as a NumPy array (moved to CPU first).
    """
    raw = torch.as_tensor(action, dtype=torch.float).detach().squeeze()
    unit = torch.clamp(raw, min=-1.0, max=1.0)
    lower = action_bias - action_scale
    upper = action_bias + action_scale
    rescaled = 0.5 * (unit + 1.0) * (upper - lower) + lower
    return rescaled.cpu().numpy()

def translate_action_2(action_bias, action_scale, action):
    """Tensor-in / tensor-out rescaling of a policy action.

    ``action`` must already be a tensor. It is detached, squeezed and clipped
    to [-1, 1], then mapped linearly so that -1 -> ``action_bias - action_scale``
    and +1 -> ``action_bias + action_scale``. Unlike ``translate_action`` the
    result is returned as a tensor (no NumPy conversion).
    """
    bounded = torch.clamp(action.detach().squeeze(), min=-1.0, max=1.0)
    lo, hi = action_bias - action_scale, action_bias + action_scale
    return 0.5 * (bounded + 1.0) * (hi - lo) + lo

def transfer_action_real_power(a, result):
    """Flatten per-row actions selected by a mask into one 1-D array.

    For each row ``i`` of ``result``, the entries of ``a[i]`` at the nonzero
    positions of ``result[i]`` are kept. If ``a[i]`` has columns beyond
    ``len(kept)``, the mean of those trailing columns is added to every kept
    value. The per-row results are concatenated in row order.

    Returns:
        A 1-D float array (empty when ``result`` has no rows).
    """
    # Seed with a float64 empty array so the result dtype matches a running
    # concatenation that starts from ``np.array([])``.
    pieces = [np.array([])]
    for i, mask in enumerate(result):
        kept = a[i, :][np.nonzero(mask)]
        leftover = a[i, len(kept):]
        if len(leftover) == 0:
            pieces.append(kept)
        else:
            pieces.append(np.mean(leftover) + kept)
    # Single concatenation instead of growing the array once per row.
    return np.concatenate(pieces)

class OnPolicyRunner:
    '''
    class name: onpolicyrunner
    propose: 用于倒入参数运行代码

    function list:
        1. run: 主函数
        2. test: run 中需要每隔 100 次做一次验证
        3. rollout_env: 产生轨迹
        4. updateModel: 更新模型
        5, testModel: 不常用
    '''
    def __init__(self, logger, run_args, alg_args, agent, env_learn, env_test, env_args, **kwargs):
        '''
        Store run / algorithm / environment settings and build per-env state.

        Args:
            logger: object exposing ``log(**kwargs)``.
            run_args: needs ``name``, ``init_checkpoint`` and ``start_step``.
            alg_args: algorithm hyper-parameters read below; the model-based
                fields are only read when ``alg_args.model_based`` is truthy.
            agent: policy object; must expose ``discrete``, plus ``load(path)``
                when a checkpoint is given and ``load_model(path)`` when a
                pretrained model is requested.
            env_learn: environment used for training rollouts (reset here).
            env_test: environment used for evaluation.
            env_args: needs ``env`` (environment name) and ``algo`` (algorithm name).
            **kwargs: accepted and ignored.
        '''
        self.logger = logger
        self.name = run_args.name
        if run_args.init_checkpoint is not None:
            agent.load(run_args.init_checkpoint)
            logger.log(interaction=run_args.start_step)
        self.start_step = run_args.start_step
        self.env_name = env_args.env
        self.algo_name = env_args.algo

        # algorithm arguments
        self.n_iter = alg_args.n_iter
        self.n_inner_iter = alg_args.n_inner_iter
        self.n_warmup = alg_args.n_warmup
        self.n_model_update = alg_args.n_model_update
        self.n_model_update_warmup = alg_args.n_model_update_warmup
        self.n_test = alg_args.n_test
        self.test_interval = alg_args.test_interval
        self.rollout_length = alg_args.rollout_length
        self.test_length = alg_args.test_length
        self.max_episode_len = alg_args.max_episode_len
        self.clip_scheme = getattr(alg_args, "clip_scheme", None)
        self.commscopecontrol = alg_args.CommScopeControl
        if self.commscopecontrol == True:
            # Weight of the communication-scope cost exp(alpha * k) - 1.
            self.alpha = alg_args.alpha

        # agent initialization
        self.agent = agent
        self.device = self.agent.device if hasattr(self.agent, "device") else "cpu"

        # environment initialization
        self.env_learn = env_learn
        self.env_test = env_test
        if self.env_name == 'PowerGrid' and self.env_learn.n_agent == 40:
            self.running_state = ZFilter((self.env_learn.n_agent, self.env_learn.n_s), clip=5.0)
        if self.env_name == 'Large_city':
            self.running_state = ZFilter((self.env_learn.n_agent, self.env_learn.n_s), clip=5.0)

        # buffer initialization
        self.discrete = agent.discrete
        action_dtype = torch.long if self.discrete else torch.float
        self.model_based = alg_args.model_based
        self.model_batch_size = alg_args.model_batch_size
        if self.model_based:
            self.n_traj = alg_args.n_traj
            self.model_traj_length = alg_args.model_traj_length
            self.model_error_thres = alg_args.model_error_thres
            self.model_buffer = ModelBuffer(alg_args.model_buffer_size)
            self.model_update_length = alg_args.model_update_length
            self.model_validate_interval = alg_args.model_validate_interval
            self.model_length_schedule = alg_args.model_length_schedule
            self.model_prob = alg_args.model_prob
        self.s, self.episode_len, self.episode_reward = self.env_learn.reset(), 0, 0
        self.cost_reward = 0

        # load pretrained model
        self.load_pretrained_model = alg_args.load_pretrained_model
        if self.model_based and self.load_pretrained_model:
            self.agent.load_model(alg_args.pretrained_model)

        if self.env_name == 'Real_Power':
            n_agents = self.env_test.n_agents
            low = np.array([self.env_test.action_space.low] * n_agents)
            high = np.array([self.env_test.action_space.high] * n_agents)
            # Attribute name (with its historical spelling) kept for callers.
            self.real_power_action_meam = (low + high) / 2
            self.real_power_action_var = (high - low) / 2
            self.running_state = ZFilter((self.env_learn.n_agents, self.env_learn.obs_size), clip=5.0)  #1.0

        elif self.env_name == 'Pandemic':
            self.running_state = ZFilter((self.env_learn.n_agent, self.env_learn.n_s), clip=5.0)

            # Fixed min/max bounds used to scale Pandemic states to roughly [-1, 1].
            num_persons = 500
            person_bounds = [num_persons] * 10
            s_max = []
            for i in range(len(self.env_learn.Nums_Location)):
                s_max.append(np.concatenate((np.array([self.env_learn.Nums_Location[i]] * 3),
                                             np.array(person_bounds + [4, 1, 120]))))
            s_max.append(np.array([1, 1, 1] + person_bounds + [4, 1, 120]))
            s_max = np.array(s_max)
            # Lower bound is zero everywhere. Deriving it from s_max (instead of a
            # hard-coded 10x16 array) keeps it valid for any number of locations.
            s_min = np.zeros_like(s_max)
            self.s_mean = (s_max + s_min) / 2
            self.s_std = (s_max - s_min) / 2
            
    def run(self):
        '''
        Main training loop.

        1. If model-based and no pretrained model was loaded, warm up the model
           buffer with ``n_warmup`` real rollouts, then fit the model
           (``updateModel(n_model_update_warmup)``).
        2. For each of ``n_iter`` iterations:
            2.1 every 10 iterations evaluate with ``self.test(iter)``; every 1000
                iterations (excluding 0) save the networks.
            2.2 collect real trajectories with ``rollout_env()``.
            2.3 model-based only: store them and refit the model every 100 iterations.
            2.4 run up to ``n_inner_iter`` agent updates, stopping early when
                ``agent.checkConverged`` reports convergence.
        '''
        # 1. Model warm-up (only when the model is trained from scratch).
        if self.model_based and not self.load_pretrained_model:
            for _ in trange(self.n_warmup):
                trajs = self.rollout_env()
                self.model_buffer.storeTrajs(trajs)
            self.updateModel(self.n_model_update_warmup)  # Sample trajectories, then shorten them.

        # 2. Main loop.
        for it in trange(self.n_iter):

            # 2.1 Test every 10 iterations; save a checkpoint every 1000.
            if it % 10 == 0:
                self.test(it)
            if it % 1000 == 0 and it != 0:
                self.agent.save_nets(f'./checkpoints/{self.name}', it)

            # 2.2 Interact with the environment to collect real trajectories.
            real_trajs = self.rollout_env()  #  TO cheak: rollout n_step, maybe multi trajs

            # 2.3 Model-based only: store real data and periodically refit the model.
            t1 = time.time()
            if self.model_based:
                self.model_buffer.storeTrajs(real_trajs)
                # train the environment model
                if it % 100 == 0:
                    self.updateModel()
            t2 = time.time()
            print('t=', t2 - t1)

            # 2.4 Update the agent.
            agentInfo = []
            trajs = real_trajs
            inner = -1  # so `inner + 1` logs 0 when n_inner_iter == 0
            for inner in trange(self.n_inner_iter):
                if self.model_based:
                    # Use model-generated trajectories with probability model_prob.
                    use_model = np.random.uniform() < self.model_prob
                    if use_model:
                        if self.model_length_schedule is not None:
                            trajs = self.rollout_model(real_trajs, self.model_length_schedule(it))
                        else:
                            trajs = self.rollout_model(real_trajs)
                    else:
                        # Fall back to the real trajectories. Previously `trajs`
                        # kept whatever the last model rollout produced.
                        trajs = real_trajs
                # Optionally pass the clip schedule to the agent update.
                if self.clip_scheme is not None:
                    info = self.agent.updateAgent(trajs, self.clip_scheme(it))
                else:
                    info = self.agent.updateAgent(trajs)

                agentInfo.append(info)
                if self.agent.checkConverged(agentInfo):
                    break
            self.logger.log(inner_iter=inner + 1, iter=it)

    def test(self, nnn):
        '''
        Evaluate the current policy on ``self.env_test``.

        Runs ``self.n_test`` episodes, each until any done flag is set or
        ``self.test_length`` steps are reached. It then:
          - logs the mean return and length;
          - with communication-scope control, also logs the mean scope ``k``
            over all test steps and the mean per-episode cost
            sum(exp(alpha * k) - 1);
          - dumps the (state, action, reward) trajectories to
            ``checkpoints/<name>/test.pickle`` and ``test.txt``.

        Args:
            nnn: index of the calling training iteration. For the 'eight' and
                'ring' environments the test env is only reset on the first
                episode of the call where ``nnn == 0``.

        Returns:
            Mean episode return (rescaled with ``env.rescaleReward`` when the
            env provides it, except for 'Large_city').

        Note: The environment should return sth like [n_agent, dim] or
            [batch_size, n_agent, dim] in either numpy or torch.
        '''
        # 1. Initialisation.
        time_t = time.time()
        length = self.test_length
        use_scope = self.commscopecontrol == True
        env = self.env_test
        returns = []
        test_cost_returns = []
        scaled = []
        lengths = []
        episodes = []
        # Scope values from every step of every test episode, so the logged
        # test_scope_mean covers the whole evaluation, not just the last episode.
        k_list = []

        # 2. Run the test episodes.
        for i in trange(self.n_test):
            episode = []

            if self.env_name == 'eight' or self.env_name == 'ring':
                if i == 0 and nnn == 0:
                    env.reset()  # for figure eight env
            elif self.env_name == "Large_city":
                env.clear()
                env.reset()
            else:
                env.reset()  # for another env

            d, ep_ret, ep_len, ep_c = np.array([False]), 0, 0, 0

            while not (d.any() or (ep_len == length)):
                # Observe (and, per environment, normalise) the current state.
                s = env.get_state_()
                if self.env_name == 'PowerGrid' and env.n_agent == 40:
                    s = self.running_state(s)
                elif self.env_name == "Pandemic":
                    s = (s - self.s_mean) / self.s_std
                elif self.env_name == 'Real_Power':
                    s = self.running_state(np.array(s))
                elif self.env_name == 'Large_city':
                    s = self.running_state(s)
                s = torch.as_tensor(s, dtype=torch.float, device=self.device)

                # Sample an action from the policy.
                if use_scope:
                    pi, scope_dist, k = self.agent.act(s, if_test=True)
                    a = pi.sample()
                else:
                    a = self.agent.act(s, if_test=True).sample()  # a is a tensor
                a = a.detach().cpu().numpy()  # might not be squeezed at the last dimension. env should deal with this though.

                # Environment step. s1 is not read by this loop (the next state
                # is re-read via get_state_()), but the running_state(s1) calls
                # are kept because the filter may update its statistics.
                squeeze_action = self.algo_name == 'IA2C' or self.algo_name == 'IC3Net'
                if self.algo_name == 'IC3Net' and self.env_name in ('Monaco', 'Grid'):
                    s1, r, d, _ = env.step(np.squeeze(a))
                elif self.env_name == 'PowerGrid':
                    s1, r, d, _ = env.step(np.squeeze(a) if squeeze_action else a)
                    if env.n_agent == 40:
                        s1 = self.running_state(s1)
                elif self.env_name == "Pandemic":
                    s1, r, d, _ = env.step(np.squeeze(a) if squeeze_action else a)
                    s1 = (s1 - self.s_mean) / self.s_std
                elif self.env_name == 'Large_city':
                    s1, r, d, _ = env.step(np.squeeze(a) if squeeze_action else a)
                    s1 = self.running_state(s1)
                elif self.env_name == 'Real_Power':
                    r, d, info = env.step(a)
                    s1 = self.running_state(np.array(env.get_state_()))
                    # Every agent receives the totally-controllable ratio as reward.
                    r = np.array([info["totally_controllable_ratio"]] * env.n_agents)
                    d = [d] * env.n_agents
                else:
                    s1, r, d, _ = env.step(a)

                episode.append((s.tolist(), a.tolist(), r.tolist()))
                d = np.array(d)
                ep_ret += r.sum()

                # Track the communication scope and its cost.
                if use_scope:
                    k_list.append(k)
                    ep_c += (torch.exp(self.alpha * k) - 1).item()
                ep_len += 1
                self.logger.log(interaction=None)

            if self.env_name != "Large_city" and hasattr(env, 'rescaleReward'):
                scaled.append(ep_ret)
                ep_ret = env.rescaleReward(ep_ret, ep_len)
            returns.append(ep_ret)
            lengths.append(ep_len)
            episodes.append(episode)
            if use_scope:
                test_cost_returns.append(ep_c)

        # 3. Aggregate, log and dump results.
        returns = np.stack(returns, axis=0)
        lengths = np.stack(lengths, axis=0)
        if use_scope:
            test_scope = torch.stack(k_list).float().mean().item()
            test_cost_returns = np.stack(test_cost_returns, axis=0)

        self.logger.log(test_episode_reward=returns.mean().item(), test_episode_len=lengths.mean().item(), test_round=None)
        if use_scope:
            self.logger.log(test_scope_mean=test_scope, test_episode_cost_reward=test_cost_returns.mean().item())
        print(f"{self.n_test} episodes average accumulated reward: {returns.mean()}")
        if self.env_name != "Large_city" and hasattr(env, 'rescaleReward'):
            print(f"scaled reward {np.mean(scaled)}")

        # The output directory may not exist yet when the first test runs.
        os.makedirs(f"checkpoints/{self.name}", exist_ok=True)
        with open(f"checkpoints/{self.name}/test.pickle", "wb") as f:
            pickle.dump(episodes, f)
        with open(f"checkpoints/{self.name}/test.txt", "w") as f:
            for episode in episodes:
                for step in episode:
                    f.write(f"{step[0]}, {step[1]}, {step[2]}\n")
                f.write("\n")
        self.logger.log(test_time=time.time() - time_t)
        return returns.mean()

    def rollout_env_judge(self, length = 0):
        """
        Function name: rollout_env
        Propose: 用于和环境进行交互拿到一个
        Outpu: trajs
        Note：The environment should return sth like [n_agent, dim] or [batch_size, n_agent, dim] in either numpy or torch.
        
        1. 初始化相关配置
        2. 循环
            2.1 将状态s变为tensor
            2.2 计算一个分布，从分布中 sample 一个动作
            2.3 计算一下动作的对数值，policy gradient 算loss的时候要用
            2.4 转移到下一步，然后拿到下一个时刻的状态动作值（奖励用的 average reward)
            2.5 存储 traj
            2.6 计算累积奖励 episode_reward
            2.7 对不同的环境做一些简单的调整，主要是判断一下奖励信息是否正确
        """

        # 1. 初始化相关配置
        time_t = time.time()
        if length <= 0:
            length = self.rollout_length
        env = self.env_learn

        # 轨迹的生成要分离开
        trajs = []
        k_list = []
        if self.commscopecontrol == True:
            traj = TrajectoryBufferBad(device=self.device)
        else: 
            traj = TrajectoryBuffer(device=self.device)
        start = time.time()
        
        if self.env_name == 'Real_Power':
            totally_controllable_ratio = 0

        if self.env_name == 'catchup' or self.env_name == 'slowdown' or self.env_name == 'Grid' or self.env_name == 'Monaco':
            env.reset() 

        # 2. 循环
        for t in range(length):
        # d, ep_len = np.array([False]), 0
        # while not(d.any() or (ep_len == length)):
            # ep_len+=1

            s = env.get_state_()    

            if self.env_name == 'PowerGrid' and env.n_agent==40:
                s = self.running_state(s)
            elif self.env_name == "Pandemic":
                #s = self.running_state(s)
                s = (s - self.s_mean) / self.s_std
            elif self.env_name == 'Real_Power':
                s = np.array(s)
                s = self.running_state(s)
            elif self.env_name == 'Large_city':
                s = s
                s = self.running_state(s)

            # 2.1 将状态s变为tensor
            s = torch.as_tensor(s, dtype=torch.float, device=self.device)

            # 2.2 判断是否需要开启控制范围
            if self.commscopecontrol == True:
                dist, scope_dist, k = self.agent.act(s)
            else:
                # 2.2 计算一个分布，从分布中 sample 一个动作
                dist = self.agent.act(s)
            a = dist.sample()

            # 2.3 计算一下动作的对数值，policy gradient 算 loss 的时候要用
            # TODO： 这里需要知道一下 log 后面 loss 药用
            if self.commscopecontrol == True:
                logp_k = scope_dist.log_prob(k)
                
            logp = dist.log_prob(a)         
            a = a.detach().cpu().numpy()
            if self.commscopecontrol == True:
                k_list.append(k)
            
            # 2.4 转移到下一步，然后拿到下一个时刻的状态动作值（奖励用的 average reward)
            if (self.env_name == 'Monaco' and self.algo_name == 'IC3Net') or (self.env_name == 'Grid' and self.algo_name == 'IC3Net'):
                s1, r, d, _ = env.step(np.squeeze(a))
            elif self.env_name == 'PowerGrid' :
                if self.algo_name == 'IA2C' or self.algo_name == 'IC3Net':
                    s1, r, d, _ = env.step(np.squeeze(a))
                    d = np.array([d]*env.n_agent) 
                else:
                    s1, r, d, _ = env.step(a)
                    d = np.array([d]*env.n_agent)  
                if env.n_agent==40:
                    s1 = self.running_state(s1)
            elif self.env_name == "Pandemic":
                if self.algo_name == 'IA2C' or self.algo_name == 'IC3Net':
                    s1, r, d, _ = env.step(np.squeeze(a))
                    #s1 = self.running_state(s1)
                    s1 = (s1 - self.s_mean) / self.s_std
                else:
                    s1, r, d, _ = env.step(a)
                    #s1 = self.running_state(s1)
                    s1 = (s1 - self.s_mean) / self.s_std
            elif self.env_name == 'Large_city':   
                if self.algo_name == 'IA2C' or self.algo_name == 'IC3Net':
                    s1, r, d, _ = env.step(np.squeeze(a))   
                else:
                    s1, r, d, _ = env.step(a)
                s1 = self.running_state(s1)
                r = r
                # if self.episode_len + 1 == self.max_episode_len:
                #     done = np.array([True]*env.n_agent, dtype=np.float32)
            elif self.env_name == 'Real_Power':
                r, d, info = env.step(a)
                s1 = env.get_state_()
                s1 = np.array(s1)
                s1 = self.running_state(s1)
                r = np.array([r/env.n_agents]*env.n_agents)
                #r = np.clip(r, -1, 0)
                d = np.array([d]*env.n_agents)
                totally_controllable_ratio += info["totally_controllable_ratio"]
                r = np.array([info["totally_controllable_ratio"]]*env.n_agents)
                #r = np.array([info["totally_controllable_ratio"]/(env.n_agents**2)]*env.n_agents)
                r = np.float32(r)
            else:    
                s1, r, d, _ = env.step(a)

            # 2.5 判断是否需要通信，如果需要的化就要增加通信成本约束
            # 选择幂函数形式，跳数越大代价越高

            # new function: 增加新的统计函数来统计一下数组
            print(t,k)

            if self.commscopecontrol == True:
                c = torch.exp(self.alpha * k) - 1

            # 2.5 存储 traj
            if self.commscopecontrol == True:
                traj.store_commscope(s, a, r, c, k, s1, d, logp, logp_k)
            else:
                traj.store(s, a, r, s1, d, logp)

            # 2.6 计算累积奖励 episode_reward
            episode_r = r
            
            # 2.7 添加一个成本约束
            if self.commscopecontrol == True:
                c_r = c

            if hasattr(env, '_comparable_reward'):
                episode_r = env._comparable_reward()
            if episode_r.ndim > 1:
                episode_r = episode_r.mean(axis=0)
                if self.commscopecontrol == True:
                    c_r = c_r.mean(axis=0)
            #if episode_r.ndim == 1:
                #episode_r = episode_r.sum()

            self.episode_reward += episode_r
            if self.commscopecontrol == True:
                self.cost_reward += c_r

            self.episode_len += 1
            self.logger.log(interaction=None)
            if self.episode_len == self.max_episode_len:
                d = np.zeros(d.shape, dtype=np.float32)
            d = np.array(d)
            
            # 2.7 对不同的环境做一些简单的调整，主要是判断一下奖励信息是否正确
#-----------------------------------------------------------------------------------------  
            # for CACC_env(catchup and slowdown)
            if self.env_name == 'catchup' or self.env_name == 'slowdown':  
                if self.env_name == 'catchup':
                
                    # 判断是否当前回合已经结束，或者回合的最大长度已经达到
                    if self.episode_len == self.max_episode_len:                         
                        
                        # 记录当前回合奖励，回合长度以及其他信息。check 一下reward是否=0，看一下reset 是否成功
                        self.logger.log(episode_reward=self.episode_reward.sum()/600, episode_len = self.episode_len, episode=None)
                        if self.commscopecontrol == True:
                            train_scope = torch.stack(k_list).float().mean().item()
                            self.logger.log(train_scope_mean = train_scope, train_cost = self.cost_reward.sum().item())
                        try:
                            _, self.episode_reward, self.episode_len = self.env_learn.reset(), 0, 0#TODO:catch up the error
                        except Exception as e:
                            print('reset error!:', e)
                            _, self.episode_reward, self.episode_len = self.env_learn.reset(), 0, 0  # TODO:catch up the error
                            if self.model_based == False:
                                trajs += traj.retrieve()
                                traj = TrajectoryBuffer(device=self.device)
                    
                    # 如果是 model-based 则会从buffer中检索轨迹，并创建一个学习心得TrajectoryBuffer来存储轨迹
                    if self.episode_len == self.max_episode_len:
                        if self.model_based:
                            trajs += traj.retrieve()
                            traj = TrajectoryBuffer(device=self.device)    
       
                elif self.env_name == 'slowdown':                                      
                    if d.any() or (self.episode_len == self.max_episode_len):      #for slowdown   
                    #if self.episode_len == self.max_episode_len:  
                        self.logger.log(episode_reward=self.episode_reward.sum()/600, episode_len = self.episode_len, episode=None)
                        if self.commscopecontrol == True:
                            train_scope = torch.stack(k_list).float().mean().item()
                            self.logger.log(train_scope_mean = train_scope, test_cost = self.cost_reward.sum().item() / 600)
                        try:
                            _, self.episode_reward, self.episode_len = self.env_learn.reset(), 0, 0#TODO:catch up the error
                            if self.commscopecontrol == True:
                                self.cost_reward = 0
                        except Exception as e:
                            print('reset error!:', e)
                            _, self.episode_reward, self.episode_len = self.env_learn.reset(), 0, 0  # TODO:catch up the error
                            if self.model_based == False:
                                trajs += traj.retrieve()
                                traj = TrajectoryBuffer(device=self.device)                           
                    if self.episode_len == self.max_episode_len:
                        if self.model_based:
                            trajs += traj.retrieve()
                            traj = TrajectoryBuffer(device=self.device)
#----------------------------------------------------------------------------------------- 
            elif self.env_name == 'eight' or self.env_name == 'ring':          
                # if d.any() or (self.episode_len == self.max_episode_len):     
                if self.episode_len == self.max_episode_len:                                
                    self.logger.log(episode_reward=self.episode_reward.sum(), episode_len = self.episode_len, episode=None)
                    if self.commscopecontrol == True:
                        train_scope = torch.stack(k_list).float().mean().item()
                        self.logger.log(train_scope_mean = train_scope, test_cost = self.cost_reward.sum().item())
                    try:
                        self.episode_reward, self.episode_len = 0, 0 #TODO:catch up the error
                        if self.commscopecontrol == True:
                            self.cost_reward = 0
                    except Exception as e:
                        print('reset error!:', e)
                        self.episode_reward, self.episode_len =  0, 0  # TODO:catch up the error
                        if self.model_based == False:
                            trajs += traj.retrieve()
                            traj = TrajectoryBuffer(device=self.device)
                if self.episode_len == self.max_episode_len:
                    if self.model_based:
                        trajs += traj.retrieve()
                        traj = TrajectoryBuffer(device=self.device)
            elif self.env_name == 'PowerGrid':
            # for other_env
                if d.any() or (self.episode_len == self.max_episode_len):      
                # if self.episode_len == self.max_episode_len:                 
                    self.logger.log(episode_reward=self.episode_reward.sum()/env.T, episode_len = self.episode_len, episode=None)
                    if self.commscopecontrol == True:
                        train_scope = torch.stack(k_list).float().mean().item()
                        self.logger.log(scope_mean = train_scope, episode_cost_reward = self.cost_reward.sum().item()/env.T)
                    try:
                        _, self.episode_reward, self.episode_len = self.env_learn.reset(), 0, 0#TODO:catch up the error
                        if self.commscopecontrol == True:
                            self.cost_reward = 0
                    except Exception as e:
                        print('reset error!:', e)
                        _, self.episode_reward, self.episode_len = self.env_learn.reset(), 0, 0  # TODO:catch up the error
                        if self.model_based == False:
                            trajs += traj.retrieve()
                            traj = TrajectoryBuffer(device=self.device)
                            
                if self.episode_len == self.max_episode_len:
                    if self.model_based:
                        trajs += traj.retrieve()
                        traj = TrajectoryBuffer(device=self.device)
            elif self.env_name == 'Real_Power':
            # for other_env
                if d.any() or (self.episode_len == self.max_episode_len):      
                # if self.episode_len == self.max_episode_len:                 
                    
                    self.logger.log(episode_reward=self.episode_reward.sum(), episode_len = self.episode_len, episode=None, totally_controllable_ratio=totally_controllable_ratio)
                    try:
                        _, self.episode_reward, self.episode_len = self.env_learn.reset(), 0, 0#TODO:catch up the error
                    except Exception as e:
                        print('reset error!:', e)
                        _, self.episode_reward, self.episode_len = self.env_learn.reset(), 0, 0  # TODO:catch up the error
                        if self.model_based == False:
                            trajs += traj.retrieve()
                            traj = TrajectoryBuffer(device=self.device)
                            
                if self.episode_len == self.max_episode_len:
                    if self.model_based:
                        trajs += traj.retrieve()
                        traj = TrajectoryBuffer(device=self.device)
            elif self.env_name == 'Large_city':  
                if self.episode_len == self.max_episode_len:      
                # if self.episode_len == self.max_episode_len:                 
                    
                    self.logger.log(episode_reward=self.episode_reward.sum(), episode_len = self.episode_len, episode=None)
                    if self.commscopecontrol == True:
                        train_scope = torch.stack(k_list).float().mean().item()
                        self.logger.log(train_scope_mean = train_scope, test_cost = self.cost_reward.sum().item())
                    try:
                        self.env_learn.clear()
                        _, self.episode_reward, self.episode_len = self.env_learn.reset(), 0, 0#TODO:catch up the error
                        if self.commscopecontrol == True:
                            self.cost_reward = 0
                    except Exception as e:
                        print('reset error!:', e)
                        self.env_learn.clear()
                        _, self.episode_reward, self.episode_len = self.env_learn.reset(), 0, 0  # TODO:catch up the error
                        if self.model_based == False:
                            trajs += traj.retrieve()
                            traj = TrajectoryBuffer(device=self.device)
                            
                if self.episode_len == self.max_episode_len:
                    if self.model_based:
                        trajs += traj.retrieve()
                        traj = TrajectoryBuffer(device=self.device)
            else:
            # for other_env
                if d.any() or (self.episode_len == self.max_episode_len):      
                # if self.episode_len == self.max_episode_len:                 
                    self.logger.log(episode_reward=self.episode_reward.sum(), episode_len = self.episode_len, episode=None)
                    if self.commscopecontrol == True:
                        train_scope = torch.stack(k_list).float().mean().item()
                        self.logger.log(train_scope_mean = train_scope, test_cost = self.cost_reward.sum().item())
                    try:
                        _, self.episode_reward, self.episode_len = self.env_learn.reset(), 0, 0 #TODO:catch up the error
                        if self.commscopecontrol == True:
                            self.cost_reward = 0
                    except Exception as e:
                        print('reset error!:', e)
                        _, self.episode_reward, self.episode_len = self.env_learn.reset(), 0, 0  # TODO:catch up the error
                        if self.model_based == False:
                            trajs += traj.retrieve()
                            traj = TrajectoryBuffer(device=self.device)

                if self.episode_len == self.max_episode_len:
                    if self.model_based:
                        trajs += traj.retrieve()
                        traj = TrajectoryBuffer(device=self.device)
#--------------------------------------------------------------------------------------    
        end = time.time()
        print('time in 1 episode is ',end-start)
        trajs += traj.retrieve(length=self.max_episode_len)
        self.logger.log(env_rollout_time=time.time()-time_t)
        return trajs

    def rollout_env(self, length = 0):
        """
        Collect on-policy trajectories by interacting with the real environment.

        Runs ``length`` environment steps (``self.rollout_length`` when
        ``length <= 0``). Per-episode statistics are logged and the environment
        is reset whenever an episode ends; the remaining buffer contents are
        retrieved at the end of the rollout.

        Note: the environment should return something like [n_agent, dim] or
        [batch_size, n_agent, dim], as numpy arrays or torch tensors.

        Outline:
            1. Initialise buffers and per-rollout state.
            2. For each step:
                2.1 normalise the state and convert it to a tensor
                2.2 build the action distribution and sample an action
                2.3 compute log-probabilities (used by the policy-gradient loss)
                2.4 step the environment to get next state, reward and done
                2.5 store the transition
                2.6 accumulate the episode reward (and communication cost)
                2.7 per-environment end-of-episode logging and reset

        Args:
            length: number of environment steps; ``<= 0`` means ``self.rollout_length``.

        Returns:
            list: trajectories accumulated from the buffers' ``retrieve()`` results.
        """
        # 1. Initialise
        time_t = time.time()
        if length <= 0:
            length = self.rollout_length
        env = self.env_learn

        def new_traj_buffer():
            # Create an empty buffer of the type required by the current mode.
            # In communication-scope mode transitions are written with
            # store_commscope (step 2.5), so every buffer -- initial or
            # re-created -- must be a TrajectoryBufferBad, not a TrajectoryBuffer.
            if self.commscopecontrol == True:
                return TrajectoryBufferBad(device=self.device)
            return TrajectoryBuffer(device=self.device)

        trajs = []
        k_list = []  # sampled communication scopes, used for the logged scope mean
        traj = new_traj_buffer()
        start = time.time()

        def reset_episode(clear_env=False):
            # Reset the environment and per-episode counters, retrying once on failure.
            nonlocal traj
            try:
                if clear_env:
                    self.env_learn.clear()
                self.env_learn.reset()
            except Exception as e:
                print('reset error!:', e)
                if clear_env:
                    self.env_learn.clear()
                self.env_learn.reset()
                # After a failed reset a model-free run moves the current buffer
                # into trajs and continues in a fresh buffer.
                if self.model_based == False:
                    trajs.extend(traj.retrieve())
                    traj = new_traj_buffer()
            self.episode_reward, self.episode_len = 0, 0
            if self.commscopecontrol == True:
                self.cost_reward = 0

        if self.env_name == 'Real_Power':
            totally_controllable_ratio = 0

        if self.env_name in ('catchup', 'slowdown', 'Grid', 'Monaco'):
            env.reset()

        # 2. Interaction loop
        for _step in range(length):
            s = env.get_state_()

            # 2.1 Per-environment state normalisation, then convert to a tensor
            if self.env_name == 'PowerGrid' and env.n_agent==40:
                s = self.running_state(s)
            elif self.env_name == "Pandemic":
                s = (s - self.s_mean) / self.s_std
            elif self.env_name == 'Real_Power':
                s = self.running_state(np.array(s))
            elif self.env_name == 'Large_city':
                s = self.running_state(s)
            s = torch.as_tensor(s, dtype=torch.float, device=self.device)

            # 2.2 Action distribution (plus scope distribution and sampled scope k
            # when communication-scope control is enabled), then sample an action.
            if self.commscopecontrol == True:
                act_dist, scope_dist, k = self.agent.act(s)
            else:
                act_dist = self.agent.act(s)
            a = act_dist.sample()

            # 2.3 Log-probabilities, needed later by the policy-gradient loss
            if self.commscopecontrol == True:
                logp_k = scope_dist.log_prob(k)
            logp = act_dist.log_prob(a)
            a = a.detach().cpu().numpy()
            if self.commscopecontrol == True:
                k_list.append(k)

            # 2.4 Step the environment: next state, reward, done
            if self.env_name in ('Monaco', 'Grid') and self.algo_name == 'IC3Net':
                s1, r, d, _ = env.step(np.squeeze(a))
            elif self.env_name == 'PowerGrid':
                if self.algo_name in ('IA2C', 'IC3Net'):
                    s1, r, d, _ = env.step(np.squeeze(a))
                else:
                    s1, r, d, _ = env.step(a)
                d = np.array([d]*env.n_agent)  # broadcast the scalar done to every agent
                if env.n_agent==40:
                    s1 = self.running_state(s1)
            elif self.env_name == "Pandemic":
                if self.algo_name in ('IA2C', 'IC3Net'):
                    s1, r, d, _ = env.step(np.squeeze(a))
                else:
                    s1, r, d, _ = env.step(a)
                s1 = (s1 - self.s_mean) / self.s_std
            elif self.env_name == 'Large_city':
                if self.algo_name in ('IA2C', 'IC3Net'):
                    s1, r, d, _ = env.step(np.squeeze(a))
                else:
                    s1, r, d, _ = env.step(a)
                s1 = self.running_state(s1)
            elif self.env_name == 'Real_Power':
                # This env's step returns (reward, done, info); the state is read separately.
                _, d, info = env.step(a)
                s1 = self.running_state(np.array(env.get_state_()))
                d = np.array([d]*env.n_agents)
                totally_controllable_ratio += info["totally_controllable_ratio"]
                # The per-agent reward is the controllable ratio reported by the env.
                r = np.float32(np.array([info["totally_controllable_ratio"]]*env.n_agents))
            else:
                s1, r, d, _ = env.step(a)

            # 2.5 Store the transition. With scope control, a communication cost
            # that grows exponentially with the scope k is stored as well.
            if self.commscopecontrol == True:
                c = torch.exp(self.alpha * k) - 1
                traj.store_commscope(s, a, r, c, k, s1, d, logp, logp_k)
            else:
                traj.store(s, a, r, s1, d, logp)

            # 2.6 Accumulate episode reward (and communication cost)
            episode_r = r
            if self.commscopecontrol == True:
                c_r = c
            if hasattr(env, '_comparable_reward'):
                episode_r = env._comparable_reward()
            if episode_r.ndim > 1:
                episode_r = episode_r.mean(axis=0)
                if self.commscopecontrol == True:
                    c_r = c_r.mean(axis=0)
            self.episode_reward += episode_r
            if self.commscopecontrol == True:
                self.cost_reward += c_r

            self.episode_len += 1
            self.logger.log(interaction=None)
            # Convert first so a scalar done from env.step also has .shape/.any().
            d = np.array(d)
            if self.episode_len == self.max_episode_len:
                # NOTE(review): the transition was already stored above, so clearing
                # d here only affects the d.any() checks below.
                d = np.zeros(d.shape, dtype=np.float32)

            # 2.7 Per-environment end-of-episode logging and reset
            if self.env_name == 'catchup':
                # catchup episodes end only at the time limit
                if self.episode_len == self.max_episode_len:
                    self.logger.log(episode_reward=self.episode_reward.sum()/600, episode_len = self.episode_len, episode=None)
                    if self.commscopecontrol == True:
                        train_scope = torch.stack(k_list).float().mean().item()
                        self.logger.log(train_scope_mean = train_scope, train_cost = self.cost_reward.sum().item())
                    reset_episode()
            elif self.env_name == 'slowdown':
                if d.any() or (self.episode_len == self.max_episode_len):
                    self.logger.log(episode_reward=self.episode_reward.sum()/600, episode_len = self.episode_len, episode=None)
                    if self.commscopecontrol == True:
                        train_scope = torch.stack(k_list).float().mean().item()
                        self.logger.log(train_scope_mean = train_scope, test_cost = self.cost_reward.sum().item() / 600)
                    reset_episode()
            elif self.env_name in ('eight', 'ring'):
                if self.episode_len == self.max_episode_len:
                    self.logger.log(episode_reward=self.episode_reward.sum(), episode_len = self.episode_len, episode=None)
                    if self.commscopecontrol == True:
                        train_scope = torch.stack(k_list).float().mean().item()
                        self.logger.log(train_scope_mean = train_scope, test_cost = self.cost_reward.sum().item())
                    # The environment is not reset here; only the counters are cleared.
                    self.episode_reward, self.episode_len = 0, 0
                    if self.commscopecontrol == True:
                        self.cost_reward = 0
            elif self.env_name == 'PowerGrid':
                if d.any() or (self.episode_len == self.max_episode_len):
                    self.logger.log(episode_reward=self.episode_reward.sum()/env.T, episode_len = self.episode_len, episode=None)
                    if self.commscopecontrol == True:
                        train_scope = torch.stack(k_list).float().mean().item()
                        self.logger.log(scope_mean = train_scope, episode_cost_reward = self.cost_reward.sum().item()/env.T)
                    reset_episode()
            elif self.env_name == 'Real_Power':
                if d.any() or (self.episode_len == self.max_episode_len):
                    self.logger.log(episode_reward=self.episode_reward.sum(), episode_len = self.episode_len, episode=None, totally_controllable_ratio=totally_controllable_ratio)
                    reset_episode()
                    totally_controllable_ratio = 0  # per-episode statistic
            elif self.env_name == 'Large_city':
                # Large_city episodes end only at the time limit; the env is
                # cleared before every reset.
                if self.episode_len == self.max_episode_len:
                    self.logger.log(episode_reward=self.episode_reward.sum(), episode_len = self.episode_len, episode=None)
                    if self.commscopecontrol == True:
                        train_scope = torch.stack(k_list).float().mean().item()
                        self.logger.log(train_scope_mean = train_scope, test_cost = self.cost_reward.sum().item())
                    reset_episode(clear_env=True)
            else:
                if d.any() or (self.episode_len == self.max_episode_len):
                    self.logger.log(episode_reward=self.episode_reward.sum(), episode_len = self.episode_len, episode=None)
                    if self.commscopecontrol == True:
                        train_scope = torch.stack(k_list).float().mean().item()
                        self.logger.log(train_scope_mean = train_scope, test_cost = self.cost_reward.sum().item())
                    reset_episode()

            # Model-based runs retrieve each finished episode separately.
            # NOTE(review): every branch above sets self.episode_len back to 0 when
            # it reaches max_episode_len, so this condition looks unreachable; all
            # data then comes from the final retrieve(length=...) call. Confirm intent.
            if self.episode_len == self.max_episode_len and self.model_based:
                trajs += traj.retrieve()
                traj = new_traj_buffer()

        end = time.time()
        print('time in 1 episode is ',end-start)
        trajs += traj.retrieve(length=self.max_episode_len)
        self.logger.log(env_rollout_time=time.time()-time_t)
        return trajs
    
    def rollout_model(self, trajs, length=0):
        """
        Generate imagined trajectories with the learned environment model.

        ``self.n_traj`` start states are drawn uniformly (with replacement)
        from every time step of the given real trajectories. The policy and
        ``self.agent.model_step`` are then rolled forward for ``length`` steps
        (``self.model_traj_length`` when ``length <= 0``).

        Args:
            trajs: sequence of dict-like trajectories whose ``'s'`` entries are
                tensors of equal shape; stacked they unpack as (b, T, n, depth).
            length: number of model steps to roll out.

        Returns:
            The result of ``TrajectoryBuffer.retrieve()`` on the imagined transitions.
        """
        started_at = time.time()
        sample_count = self.n_traj
        horizon = length if length > 0 else self.model_traj_length

        # (b, T, n, depth) -> (b*T, n, depth): every recorded step is a candidate start state
        stacked = torch.stack([episode['s'] for episode in trajs], dim=0)
        b, T, n, depth = stacked.shape
        candidates = stacked.view(-1, n, depth)
        picks = torch.randint(low=0, high=b * T, size=(sample_count,), device=self.device)
        state = candidates.index_select(dim=0, index=picks)

        imagined = TrajectoryBuffer(device=self.device)
        for _ in range(horizon):
            policy = self.agent.act(state)
            action = policy.sample()
            log_prob = policy.log_prob(action)
            reward, next_state, done, _ = self.agent.model_step(state, action)

            if self.env_name == 'UAV_9d':
                # UAV_9d post-processes model states and recomputes the reward via the env.
                # NOTE(review): next_state is transformed here and becomes the next
                # `state`, which is transformed again on the following step --
                # confirm get_model_state is idempotent.
                uav_env = self.env_learn
                state = uav_env.get_model_state(state, self.device)
                next_state = uav_env.get_model_state(next_state, self.device)
                reward = uav_env.get_model_reward(next_state, self.device)

            imagined.store(state, action, reward, next_state, done, log_prob)
            state = next_state

        collected = imagined.retrieve()
        self.logger.log(model_rollout_time=time.time() - started_at)
        return collected
    
    def updateModel(self, n=0):
        """
        Fit the learned environment model on trajectories from the model buffer.

        Runs up to ``n`` update steps (``self.n_model_update`` when ``n <= 0``).
        Each step samples ``self.model_batch_size`` trajectories, crops them to
        ``self.model_update_length`` steps and calls ``self.agent.updateModel``.
        Every ``self.model_validate_interval`` steps a fresh batch is validated,
        and training stops early once the returned error is below
        ``self.model_error_thres``.

        The number of update steps actually performed is logged as ``model_update``.
        """
        if n <= 0:
            n = self.n_model_update
        # Tracked explicitly so the final log also works when the loop runs zero
        # times (previously the loop variable was unbound in that case).
        n_done = 0
        for i_model_update in trange(n):
            batch = self.model_buffer.sampleTrajs(self.model_batch_size)
            batch = [traj.getFraction(length=self.model_update_length) for traj in batch]

            self.agent.updateModel(batch, length=self.model_update_length)
            n_done = i_model_update + 1

            if i_model_update % self.model_validate_interval == 0:
                validate_trajs = self.model_buffer.sampleTrajs(self.model_batch_size)
                validate_trajs = [traj.getFraction(length=self.model_update_length) for traj in validate_trajs]
                rel_error = self.agent.validateModel(validate_trajs, length=self.model_update_length)
                if rel_error < self.model_error_thres:
                    break
        self.logger.log(model_update = n_done)

    def testModel(self, n = 0):
        trajs = self.model_buffer.sampleTrajs(self.model_batch_size)
        trajs = [traj.getFraction(length=self.model_update_length) for traj in trajs]
        return self.agent.validateModel(trajs, length=self.model_update_length)

