from envs import REGISTRY as env_REGISTRY
from functools import partial
from components.episode_buffer import EpisodeBatch
from multiprocessing import Pipe, Process
import numpy as np
import torch as th
import os 
import time
from scipy.stats import kendalltau
from itertools import combinations
from components.reward_model import reward_model,team_reward_model
from components.transforms import OneHot

# Based (very) heavily on SubprocVecEnv from OpenAI Baselines
# https://github.com/openai/baselines/blob/master/baselines/common/vec_env/subproc_vec_env.py
class ParallelRunner:

    def __init__(self, args, logger):
        """Start one environment worker per parallel slot and build the reward models.

        Args:
            args: Experiment configuration. The attributes read directly here
                are ``use_cuda``, ``batch_size_run``, ``env``, ``env_args``,
                ``use_human_pref``, ``n_reward_functions`` and
                ``use_team_reward``.
            logger: Logger object; only stored on the instance here.
        """
        self.args = args
        self.logger = logger
        self.device = 'cuda' if self.args.use_cuda else 'cpu'
        self.batch_size = self.args.batch_size_run

        # One pipe per environment: the parent keeps one end, the worker
        # process receives the other.
        pipes = [Pipe() for _ in range(self.batch_size)]
        self.parent_conns, self.worker_conns = zip(*pipes)
        env_fn = env_REGISTRY[self.args.env]
        self.ps = [
            Process(
                target=env_worker,
                args=(worker_conn, CloudpickleWrapper(partial(env_fn, **self.args.env_args))),
            )
            for worker_conn in self.worker_conns
        ]
        for proc in self.ps:
            proc.daemon = True
            proc.start()

        # The first worker is asked for the environment description.
        first_conn = self.parent_conns[0]
        first_conn.send(("get_env_info", None))
        self.env_info = first_conn.recv()
        self.episode_limit = self.env_info["episode_limit"]
        n_actions = self.env_info["n_actions"]
        n_agents = self.env_info["n_agents"]
        self.one_hot = OneHot(n_actions)
        self.agent_order = OneHot(n_agents)

        self.use_human_pref = self.args.use_human_pref

        # The per-agent model input is the observation extended by
        # n_agents + n_actions entries.
        state_dim = self.env_info["state_shape"]
        model_obs_dim = self.env_info["obs_shape"] + n_agents + n_actions
        self.reward_models = [
            reward_model(state_dim, model_obs_dim, n_actions).to(self.device)
            for _ in range(self.args.n_reward_functions)
        ]
        if self.args.use_team_reward:
            self.team_reward_models = [
                team_reward_model(state_dim, model_obs_dim, n_actions, n_agents=n_agents).to(self.device)
                for _ in range(self.args.n_reward_functions)
            ]

        # Try to load the round-0 checkpoints for every model.
        self.get_reward_function(initial=True)

        # Step counters.
        self.t = 0
        self.t_env = 0

        # Return / statistics accumulators for training and testing.
        self.train_returns = []
        self.ori_train_returns = []
        self.ori_ind_returns = []
        self.ori_team_returns = []

        self.test_returns = []
        self.ori_test_returns = []
        self.ori_test_ind_returns = []
        self.ori_test_team_returns = []
        self.train_stats = {}
        self.test_stats = {}

        # Zero template with one row per environment and one column per agent.
        self.actions_all = th.zeros((self.batch_size, n_agents))
        self.log_train_stats_t = -100000
        
    def get_reward_function(self,path1='',path2='',initial=False) :
        path_ori_1 = path1
        path_ori_2 = path2
        for reward_num in range(self.args.n_reward_functions) :
            if initial :
                path1 = '{}/{}_{}_{}_{}_0_{}.pt'.format(self.args.reward_model_path,self.args.LLM_model,self.args.map_name,self.args.reward_model_info,self.args.model_info_num,reward_num)
                path2 = '{}/{}_{}_{}_{}_0_{}_team.pt'.format(self.args.reward_model_path,self.args.LLM_model,self.args.map_name,self.args.reward_model_info,self.args.model_info_num,reward_num)
            else :
                path1 = path_ori_1+'{}.pt'.format(reward_num)
                path2 = path_ori_2+'{}_team.pt'.format(reward_num)
            
            print(path1)
            print(path2)
            if os.path.exists(path1) :
                self.reward_models[reward_num].load_state_dict(th.load(path1))
                print('ind : load successfully')
                if self.args.use_team_reward and os.path.exists(path2):
                    self.team_reward_models[reward_num].load_state_dict(th.load(path2))
                    print('team : load successfully')
            else :
                print('cannot find saved model')
            
    def get_reward(self, traj, end, n_agent,reward_num,team=False) :
        
        #for name, param in self.reward_model.named_parameters():
            #print(f"Parameter: {name}, Requires_grad: {param.requires_grad}")

        r = 0
        s = traj['state'][0,:end,:]
        #print(s.requires_grad)

        s_next = traj['state'][0,1:end+1,:]
        #print(s_next.requires_grad)
        end_point = traj['obs'].shape[3]-self.env_info["n_agents"]
        start_point = end_point-self.env_info["n_actions"]
        #print(end_point)
        #print(start_point)
        if team :
            for agent in range(n_agent) :
                action = traj['obs'][0,1:end+1,agent,start_point:end_point]
                if agent == 0 :
                    action_all = action
                else :
                    action_all = th.cat((action_all,action),dim=-1)
            r = th.sum(self.team_reward_models[reward_num](s,s_next,action_all))
            
        else :
            for agent in range(n_agent) :
                obs = traj['obs'][0,:end,agent,:]
                action = traj['obs'][0,1:end+1,agent,start_point:end_point]
                r_step = self.reward_models[reward_num](s,s_next,obs,action)

                r = r+ th.sum(r_step)
        return r

    
    def _build_inputs(self, batch, t,n_agents):
        # Assumes homogenous agents with flat observations.
        # Other MACs might want to e.g. delegate building inputs to each agent
        bs = batch.batch_size
        inputs = []
        inputs.append(batch["obs"][:, t])  # b1av
        if self.args.obs_last_action:
            if t == 0:
                inputs.append(th.zeros_like(batch["actions_onehot"][:, t]))
            else:
                inputs.append(batch["actions_onehot"][:, t-1])
        if self.args.obs_agent_id:
            inputs.append(th.eye(n_agents, device=batch.device).unsqueeze(0).expand(bs, -1, -1))

        inputs = th.cat([x.reshape(bs, n_agents, -1) for x in inputs], dim=-1)
        return inputs
    
    def update_reward_in_replay_memory(self,replay_data:EpisodeBatch) :   
        max_ep_t = replay_data.max_t_filled()
        actions = replay_data["actions"].to(self.device) # [88,61,3,1]
        state = replay_data["state"].to(self.device) # [88,61,48]
        state_next = replay_data["state"].to(self.device) 
        obs = replay_data["obs"].to(self.device) # [88,61,3,30]
        reward = replay_data["reward"].to(self.device)
        batch_size = obs.shape[0]
        n_agents = obs.shape[2]
        
        start_point = obs.shape[3]
        end_point = start_point+self.env_info["n_actions"]
        for step in range(max_ep_t-1) :       
            reward[:,step] = 0
            s = state[:,step,:].to(self.device)
            s_next = state[:,step+1,:].to(self.device)
            obs = self._build_inputs(replay_data,step,n_agents)
            for n_agent in range(n_agents) :
                o = obs[:,n_agent,:].to(self.device)
                a = obs[:,n_agent,start_point:end_point].to(self.device)
                ind_reward = 0
                for reward_num in range(self.args.n_reward_functions) :
                    ind_reward = ind_reward + self.reward_models[reward_num](s,s_next,o,a)
                ind_reward = ind_reward / self.args.n_reward_functions
                reward[:,step] += ind_reward
                if n_agent == 0 :
                    a_all = a 
                else :
                    a_all = th.cat((a_all,a)).to(self.device)
            team_reward = self.team_reward_model(s,s_next,a_all)
        replay_data['reward'] = reward
        return replay_data, team_reward

    def get_one_step_reward(self, traj, step, reward_num,n_agent) :
        r = [0 for i in range(n_agent)]
        s = traj['state'][0,step,:].reshape((1,-1))
        s_next = traj['state'][0,step+1,:].reshape((1,-1))
        end_point = traj['obs'].shape[3]-self.env_info["n_agents"]
        start_point = end_point-self.env_info["n_actions"]
        
        for agent in range(n_agent) :
            obs = traj['obs'][0,step,agent,:].reshape((1,-1))
            action = traj['obs'][0,step+1,agent,start_point:end_point].reshape((1,-1))
            r[agent] = self.reward_models[reward_num](s,s_next,obs,action)
        return r   

    def get_loss(self,num_1,num_2,r_hat,rank) :
        p = (th.exp(r_hat[num_1])+1e-6)/(th.exp(r_hat[num_1])+th.exp(r_hat[num_2])+1e-6)
        if rank[num_1] < rank[num_2] :
            pref = 1
        elif rank[num_1] > rank[num_2] :
            pref = 0 
        else :
            pref = 0.5
        loss_1 = -1 * ((pref* th.log(p) + (1-pref)*th.log(1-p)))
        return p,loss_1
    
    def load_traj(self,num=1, save_directory='replay_buffer/') :
        data = ['actions','avail_actions','obs','probs','reward','state','terminated']
        replay_data = dict()
        for d in data :
            replay_data[d] = th.load(save_directory+'{}_{}.pt'.format(num,d)).to(self.device)
        return replay_data
    
    def setup(self, scheme, groups, preprocess, mac):
        """Store the batch description and controller, and prepare the batch factory.

        ``self.new_batch()`` afterwards creates an ``EpisodeBatch`` for
        ``self.batch_size`` episodes of ``self.episode_limit + 1`` steps on
        ``self.args.device``.
        """
        steps_per_episode = self.episode_limit + 1
        batch_factory = partial(
            EpisodeBatch,
            scheme,
            groups,
            self.batch_size,
            steps_per_episode,
            preprocess=preprocess,
            device=self.args.device,
        )
        self.new_batch = batch_factory
        self.mac = mac
        self.scheme, self.groups, self.preprocess = scheme, groups, preprocess

    def get_env_info(self):
        return self.env_info

    def save_replay(self):
        self.env.save_replay()
        
    def update_reward_model(self,
                            seq=0,
                            traj_path='replay_buffer/3m_LLM_1_0',
                            device='cuda',
                            model_path='save_reward_model',
                            training_steps=100,
                            update_threshold=0.5,
                            use_history_len = 3,
                            info_num = 0 
                           ) :
        """Fine-tune every reward model of the ensemble on stored preference data.

        For each ensemble index the checkpoint of round ``seq-1`` is loaded,
        optimised with Adam on trajectory-level (``args.compare_team``)
        and/or step-level (``args.compare_agents``) preference losses, and
        saved as the checkpoint of round ``seq``.

        Args:
            seq: Index of the current round. Preference files of rounds
                ``max(0, seq - use_history_len) .. seq - 1`` are used.
            traj_path: Overwritten inside the method before any use, so the
                passed value has no effect.
            device: Device on which the loss accumulators are created.
            model_path: Directory holding the reward-model checkpoints.
            training_steps: Number of optimisation iterations per model.
            update_threshold: Weight of the team preference loss in the
                team-only branch.
            use_history_len: Number of most recent rounds whose preference
                data are used.
            info_num: Identifier embedded in the checkpoint file names.
        """
        
        ### Set training data ######################################
        
        ### 1. Trajectory Preference
        pref_traj_all = list()
        pref_step_all = list()
        
        # First round whose preference data are still inside the history window.
        start_num = 0
        if seq < use_history_len :
            start_num = 0
        else :
            start_num = seq-use_history_len

        if self.args.compare_team :
            for traj_num in range(start_num,seq) :
                save_traj_pref_directory = '{}/{}_{}_{}_{}_traj.pt'.format(self.args.preference_save_path,self.args.LLM_model,self.args.map_name,self.args.reward_model_info,traj_num)
                pref_traj_data = th.load(save_traj_pref_directory)
                pref_traj_all.append(pref_traj_data)
                                            
        ### Step Preference
        if self.args.compare_agents :
            for traj_num in range(start_num,seq) :
                save_step_pref_directory = '{}/{}_{}_{}_{}_step.pt'.format(self.args.preference_save_path,self.args.LLM_model,self.args.map_name,self.args.reward_model_info,traj_num)
                pref_step_data = th.load(save_step_pref_directory)
                pref_step_all.append(pref_step_data)


        
        ### Training Reward models #################################
        n_agents = self.args.n_agents
        # The traj_path argument is discarded here; the value is recomputed
        # again per preference file inside the training loop.
        traj_path = '{}/{}_{}_{}_{}/'.format(self.args.replay_buffer_save_path,self.args.LLM_model,self.args.map_name,self.args.reward_model_info,seq-1)
        for reward_num in range(self.args.n_reward_functions) :

            ### Set the model : Load data and set optimization
            # Start from the previous round's checkpoint (seq-1).
            load_model_path = model_path+'/{}_{}_{}_{}_{}_{}.pt'.format(self.args.LLM_model,self.args.map_name,self.args.reward_model_info,info_num,seq-1,reward_num)
            self.reward_models[reward_num].load_state_dict(th.load(load_model_path))
            #self.reward_model.load_state_dict(th.load(load_model_path))
            opt= th.optim.Adam(self.reward_models[reward_num].parameters(),lr=self.args.reward_model_lr)
            #opt= th.optim.Adam(self.reward_model.parameters(),lr=self.args.reward_model_lr)
           
            save_model_path = model_path+'/{}_{}_{}_{}_{}_{}.pt'.format(self.args.LLM_model,self.args.map_name,self.args.reward_model_info,info_num,seq,reward_num)

            if self.args.use_team_reward :
                load_model_path2 = model_path+'/{}_{}_{}_{}_{}_{}_team.pt'.format(self.args.LLM_model,self.args.map_name,self.args.reward_model_info,info_num,seq-1,reward_num)
                self.team_reward_models[reward_num].load_state_dict(th.load(load_model_path2))
                opt2 = th.optim.Adam(self.team_reward_models[reward_num].parameters(),lr=self.args.reward_model_lr)#*10)
                save_model_path2 = model_path+'/{}_{}_{}_{}_{}_{}_team.pt'.format(self.args.LLM_model,self.args.map_name,self.args.reward_model_info,info_num,seq,reward_num)

            # NOTE(review): mse_loss_fn is only referenced by commented-out code below.
            mse_loss_fn = th.nn.MSELoss()
            
            reset_cnt = 0 
            training_step = 0
            # One pass of this loop is one optimisation step over all
            # preference data, or a model reset when the pass was unusable.
            while True :
                reset = False
                if self.args.compare_team :
                    loss_cnt = 0 
                    loss_team_cnt = 0
                    ## Trajectory
                    loss = th.zeros(1, requires_grad=True).to(device)
                    # NOTE(review): loss_team_vs_ind is never updated (its only
                    # update is commented out), so it stays zero.
                    loss_team_vs_ind = th.zeros(1, requires_grad=True).to(device) 
                    loss_team = th.zeros(1, requires_grad=True).to(device)
                    for i,pref_traj_data in enumerate(pref_traj_all) :
                        traj_path = '{}/{}_{}_{}_{}/'.format(self.args.replay_buffer_save_path,self.args.LLM_model,self.args.map_name,self.args.reward_model_info,i+start_num)
                        for t_1, t_2, pref in pref_traj_data :
                            t_1 = int(t_1.item())
                            t_2 = int(t_2.item())
                            pref = float(pref.item())
                            #print(pref)
                            # Rows in which either trajectory number is 0 are skipped.
                            if t_1 != 0 and t_2 != 0 :
                                ### When the length is fixed
                                traj_1 = self.load_traj(num=t_1,save_directory=traj_path)
                                traj_2 = self.load_traj(num=t_2,save_directory=traj_path)
                                end_1 = th.where(traj_1['terminated']==1)[1].item()
                                end_2 = th.where(traj_2['terminated']==1)[1].item()
                                #end_state = [end_1,end_2]
                                # Both trajectories are scored over the length of the shorter one.
                                if end_1 >= end_2 :
                                    end_state = [end_2,end_2]
                                else :
                                    end_state = [end_1,end_1]
                                    
                                if self.args.use_team_reward :
                                    r_3 = self.get_reward(traj_1, end_state[0],n_agents,reward_num,team=True)
                                    r_4 = self.get_reward(traj_2, end_state[1],n_agents,reward_num,team=True)
                                    # Modelled probability that trajectory 1 is preferred.
                                    p2 = th.exp(r_3) / (th.exp(r_3)+th.exp(r_4)+1e-6)
                                    if p2 != 0.0 and p2 != 1.0 :
                                        loss_team_1 = -1 * ((pref* th.log(p2) + (1-pref)*th.log(1-p2)))
                                        if th.isnan(loss_team_1) == False :
                                            loss_team = loss_team + loss_team_1
                                            loss_team_cnt += 1 
                                            #r_ind_1 = r_1.clone().detach()
                                            #r_ind_1.requires_grad_(False)
                                            #r_ind_2 = r_2.clone().detach()
                                            #r_ind_2.requires_grad_(False)
                                            #loss_team_vs_ind = loss_team_vs_ind + (mse_loss_fn(r_ind_1,r_3) + mse_loss_fn(r_ind_2,r_4)) / 2
                                else :
                                    r_1 = self.get_reward(traj_1, end_state[0],n_agents,reward_num,team=False)
                                    r_2 = self.get_reward(traj_2, end_state[1],n_agents,reward_num,team=False)

                                    '''
                                    r_1 = self.get_reward(traj_1, end_state,n_agents,reward_num,team=False)
                                    r_2 = self.get_reward(traj_2, end_state,n_agents,reward_num,team=False)
                                    '''
                                    p = th.exp(r_1) / (th.exp(r_1)+th.exp(r_2)+1e-6)

                                    if p !=0.0 and p != 1.0 :
                                        loss_1 = -1 * ((pref* th.log(p) + (1-pref)*th.log(1-p)))
                                        #print(loss_1)
                                        if th.isnan(loss_1) == False : 
                                            loss = loss+loss_1
                                            loss_cnt += 1
                                        else :
                                            # A NaN loss marks this pass as unusable.
                                            reset = True
                    if loss_cnt != 0 and self.args.use_team_reward == False :
                        output_loss_1 = loss.item()/loss_cnt
                    else :
                        output_loss_1 = 0
                    # NOTE(review): with use_team_reward only loss_team_cnt is
                    # incremented above, so loss_cnt stays 0 and every pass ends
                    # in a reset -- confirm whether loss_team_cnt should be
                    # checked in that mode.
                    if loss_cnt == 0 :
                        reset = True

                if self.args.compare_agents :
                    loss2_cnt = 0 
                    loss2 = th.zeros(1, requires_grad=True).to(device)
                    for i,pref_step_data in enumerate(pref_step_all) :
                        traj_path = '{}/{}_{}_{}_{}/'.format(self.args.replay_buffer_save_path,self.args.LLM_model,self.args.map_name,self.args.reward_model_info,i+start_num)
                        ## Step update
                        # Row layout as read here: [trajectory number, step, rank of each agent ...].
                        for pref_d in pref_step_data :
                            traj_num = pref_d[0]
                            traj_num = int(traj_num.item())
                            if traj_num != 0 :
                                step = pref_d[1]
                                step = int(step.item())
                                rank = pref_d[2:]
                                # One slot per unordered pair of ranked agents.
                                loss2_all = [0 for _ in range(int((len(rank)*(len(rank)-1)/2))) ]
                                
                                traj = self.load_traj(num=traj_num,save_directory=traj_path)
                                if self.args.scalability_test : 
                                    r_hat = self.get_one_step_reward(traj,step,reward_num,self.args.compare_agents_num)                                
                                else :
                                    r_hat = self.get_one_step_reward(traj,step,reward_num,n_agents)                                
                                pref_cnt = 0
                                for pair in combinations(range(len(rank)),2) :
                                    p,loss2_all[pref_cnt] = self.get_loss(pair[0],pair[1],r_hat,rank)
                                    if p!=0.0 and p!=1.0 and th.isnan(loss2_all[pref_cnt]) == False :
                                        loss2 = loss2 + loss2_all[pref_cnt]
                                        loss2_cnt+=1
                                        # pref_cnt only advances on success, so a
                                        # rejected pair's slot is overwritten by the next pair.
                                        pref_cnt+=1
                                    else :
                                        reset = True
                    if loss2_cnt != 0 :
                        output_loss_2 = loss2.item()/loss2_cnt
                    else :
                        reset = True

                #### Back propagate
                if reset is False:
                    if self.args.compare_agents and self.args.compare_team :
                        if loss_cnt != 0 and loss2_cnt != 0 :
                            opt.zero_grad()
                            loss = loss / loss_cnt
                            loss2 = loss2 / loss2_cnt
                            # Weighted mix of trajectory-level and step-level losses.
                            final_loss = (1-self.args.step_weight)*loss + self.args.step_weight*loss2
                            final_loss.backward()
                            opt.step()

                        if self.args.use_team_reward :
                            # NOTE(review): the optimiser step sits in the
                            # `loss_team_cnt == 0` branch, so the team model is
                            # only stepped when no team loss was accumulated --
                            # looks inverted compared with the team-only branch below.
                            if loss_team_cnt != 0 :
                                #loss_team = (update_threshold*loss_team + (1-update_threshold)*loss_team_vs_ind) / loss_team_cnt
                                loss_team = loss_team / loss_team_cnt
                            else : 
                                opt2.zero_grad()
                                #loss_team = (update_threshold*loss_team + (1-update_threshold)*loss_team_vs_ind)
                                #loss_team = loss_team + loss_team_vs_ind
                                loss_team.backward()
                                opt2.step()
                            #print('{} : {:.5f}'.format(training_step, loss_team)
                        else :
                            print('{} : {:.5f} / {:.5f}'.format(training_step,output_loss_1,output_loss_2)) 

                    elif self.args.compare_agents == False and self.args.compare_team :
                        if self.args.use_team_reward :
                            if loss_team_cnt != 0 :
                                opt2.zero_grad()
                                loss_team = (update_threshold*loss_team + (1-update_threshold)*loss_team_vs_ind) / loss_team_cnt
                                loss_team.backward()
                                opt2.step()
                            #loss_team = loss_team + loss_team_vs_ind
                        else :
                            if loss_cnt != 0 :
                                opt.zero_grad()
                                final_loss = loss / loss_cnt
                                final_loss.requires_grad_(True)
                                final_loss.backward()
                                opt.step()

                        if self.args.use_team_reward :     
                            # NOTE(review): divides by loss_cnt, not loss_team_cnt,
                            # and loss_team was already averaged above -- confirm.
                            if loss_cnt != 0 :
                                output_loss_3 = loss_team.item()/loss_cnt
                                print('{} : {:.5f} '.format(training_step,output_loss_3))       
                        else :
                            print('{} : {:.5f} '.format(training_step,output_loss_1))             

                    elif self.args.compare_agents and self.args.compare_team == False :
                        if loss2_cnt != 0 :
                            opt.zero_grad()
                            loss2 = loss2 / loss2_cnt
                            final_loss = loss2
                            final_loss.backward()
                            opt.step()
                        print('{} : {:.5f} '.format(training_step,output_loss_2))             
                else :
                    # Unusable pass: replace the model(s) by freshly initialised
                    # ones, recreate the optimiser(s) and restart the step count.
                    print('reset! : {}'.format(reset_cnt+1))
                    self.reward_models[reward_num] = reward_model(self.env_info["state_shape"],self.env_info["obs_shape"]+self.env_info["n_agents"]+self.env_info["n_actions"],self.env_info["n_actions"]).to(self.device)
                    opt= th.optim.Adam(self.reward_models[reward_num].parameters(),lr=self.args.reward_model_lr)

                    if self.args.use_team_reward :
                        self.team_reward_models[reward_num] = team_reward_model(self.env_info["state_shape"],self.env_info["obs_shape"]+self.env_info["n_agents"]+self.env_info["n_actions"],self.env_info["n_actions"],n_agents=self.env_info["n_agents"]).to(self.device)
                        opt2 = th.optim.Adam(self.team_reward_models[reward_num].parameters(),lr=self.args.reward_model_lr)#*10)
                    reset_cnt += 1
                    # After 20 resets give up: jump to the final step so the loop
                    # ends; otherwise restart from step 0 (-1 + the increment below).
                    if reset_cnt >= 20 :
                        training_step = training_steps
                    else :
                        training_step = -1
                    reset = False
                    print(training_step)
                training_step += 1
                if training_step >= training_steps :
                    break
                    

            # Persist the result as the checkpoint of round `seq`.
            print(save_model_path)
            if self.args.use_team_reward :
                th.save(self.team_reward_models[reward_num].state_dict(), save_model_path2)
            th.save(self.reward_models[reward_num].state_dict(), save_model_path)
    

    def open_env(self) :
        """Create a fresh set of pipes and daemon worker processes, one per environment.

        Replaces ``self.parent_conns``, ``self.worker_conns`` and ``self.ps``;
        no message is sent to the new workers here.
        """
        env_fn = env_REGISTRY[self.args.env]
        pipes = [Pipe() for _ in range(self.batch_size)]
        self.parent_conns, self.worker_conns = zip(*pipes)

        self.ps = [
            Process(
                target=env_worker,
                args=(worker_conn, CloudpickleWrapper(partial(env_fn, **self.args.env_args))),
            )
            for worker_conn in self.worker_conns
        ]

        for proc in self.ps:
            proc.daemon = True
            proc.start()
        
        #for parent_conn in self.parent_conns:
        #    parent_conn.send(("reset", None))

    def close_env(self):
        for parent_conn in self.parent_conns:
            parent_conn.send(("close", None))

    def reset(self):
        self.batch = self.new_batch()

        # Reset the envs
        for parent_conn in self.parent_conns:
            parent_conn.send(("reset", None))

        pre_transition_data = {
            "state": [],
            "avail_actions": [],
            "obs": []
        }
        # Get the obs, state and avail_actions back
        for parent_conn in self.parent_conns:
            data = parent_conn.recv()
            pre_transition_data["state"].append(data["state"])
            pre_transition_data["avail_actions"].append(data["avail_actions"])
            pre_transition_data["obs"].append(data["obs"])

        self.batch.update(pre_transition_data, ts=0)

        self.t = 0
        self.env_steps_this_run = 0
    
    def get_kendalltau(self,d_set) :
        assert len(d_set) >= 2, "need more than two reward sets"
        k = len(d_set)
        taus = []
        for pair in combinations(range(k),2) :
            tau,_ = kendalltau(d_set[pair[0]], d_set[pair[1]])
            taus.append(tau)
        mean_tau = sum(taus)/len(taus)
        return mean_tau

    def run(self, test_mode=False,kendall_need=True):
        self.reset()

        all_terminated = False
        episode_returns = [0 for _ in range(self.batch_size)]
        ori_episode_returns = [0 for _ in range(self.batch_size)]
        ori_ind_episode_returns = [0 for _ in range(self.batch_size)]
        ori_team_episode_returns = [0 for _ in range(self.batch_size)]
        episode_lengths = [0 for _ in range(self.batch_size)]
        self.mac.init_hidden(batch_size=self.batch_size)
        terminated = [False for _ in range(self.batch_size)]
        envs_not_terminated = [b_idx for b_idx, termed in enumerate(terminated) if not termed]
        final_env_infos = []  # may store extra stats like battle won. this is filled in ORDER OF TERMINATION
        
        save_probs = getattr(self.args, "save_probs", False)
        cnt = 0 
        while True:

            # Pass the entire batch of experiences up till now to the agents
            # Receive the actions for each agent at this timestep in a batch for each un-terminated env
            if save_probs:
                actions, probs = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, bs=envs_not_terminated, test_mode=test_mode)
            else:
                actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, bs=envs_not_terminated, test_mode=test_mode)
                
            cpu_actions = actions.to("cpu").numpy()

            # Update the actions taken
            actions_chosen = {
                "actions": actions.unsqueeze(1).to("cpu"),
            }
            if save_probs:
                actions_chosen["probs"] = probs.unsqueeze(1).to("cpu")
            
            self.batch.update(actions_chosen, bs=envs_not_terminated, ts=self.t, mark_filled=False)
            
            # Send actions to each env
            action_idx = 0
            for idx, parent_conn in enumerate(self.parent_conns):
                if idx in envs_not_terminated: # We produced actions for this env
                    if not terminated[idx]: # Only send the actions to the env if it hasn't terminated
                        parent_conn.send(("step", cpu_actions[action_idx]))
                    action_idx += 1 # actions is not a list over every env

            # Update envs_not_terminated
            envs_not_terminated = [b_idx for b_idx, termed in enumerate(terminated) if not termed]
            all_terminated = all(terminated)
            if all_terminated:
                break

            # Post step data we will insert for the current timestep
            post_transition_data = {
                "reward": [],
                "terminated": [],
                "ori_reward" : [],
                "int_ind_reward":[],
                "int_ind_reward_std":[],
                "int_team_reward":[],
                "int_team_reward_std":[],
                "kendalltau":[],
                "int_ind_reward_all":[],
                "ind_reward":[],
                "ind_terminated":[]                
            }
            # Data for the next step we will insert in order to select an action
            pre_transition_data = {
                "state": [],
                "avail_actions": [],
                "obs": []
            }
            actions_all = self.actions_all.clone()
            for agent,not_terminated in enumerate(envs_not_terminated) :
                actions_all[not_terminated] = actions[agent]
                
            # Receive data back for each unterminated env
            for idx, parent_conn in enumerate(self.parent_conns):
                if not terminated[idx]:
                    data = parent_conn.recv()
                    # Remaining data for this current timestep
                    post_transition_data["ori_reward"].append((data["reward"],))

                    ################################################################
                    ### Intrinsic Individual Reward from Reward Function ###########
                    ################################################################
                    
                    s = data["pre_state"]
                    s_next = data["state"]
                    o = data["pre_obs"]
                    intrinsic_reward = th.zeros(1)
                    intrinsic_ind_reward= th.zeros(1)
                    intrinsic_ind_reward_std = th.zeros(1)
                    intrinsic_team_reward =th.zeros(1) 
                    intrinsic_team_reward_std = th.zeros(1) 
                    ind_agent_terminated = list()
                    int_agent_reward = list()
                    

                    ind_reward_total = [ [] for _ in range(self.args.n_reward_functions) ]
                    for agent in range(self.env_info["n_agents"]) :
                        a_onehot = self.one_hot.transform(actions_all[idx][agent].clone().detach().requires_grad_(True)).to(self.device)
                        if cnt != 0 :
                            a_prev_onehot = self.one_hot.transform(prev_actions[idx][agent].clone().detach().requires_grad_(True)).to(self.device)
                        else :
                            a_prev_onehot = th.zeros((self.env_info["n_actions"])).to(self.device)
                        
                        if agent == 0 :
                            a_onehots = a_onehot.clone()
                        else :
                            a_onehots = th.cat((a_onehots,a_onehot)).to(self.device)

                        o_agent = self.agent_order.transform(th.tensor(agent)).to(self.device)
                        o_all = th.cat((th.tensor(o[agent]).to(self.device),a_prev_onehot,o_agent)).to(self.device)
                        ind_rewards = list()
                        for num in range(self.args.n_reward_functions) :
                            re = self.reward_models[num](th.tensor(s).reshape((1,-1)).to(self.device),th.tensor(s_next).reshape((1,-1)).to(self.device),o_all.reshape((1,-1)),a_onehot.reshape((1,-1)))
                            ind_rewards.append(re)
                            ind_reward_total[num].append(re)
                        ind_rewards = th.tensor(ind_rewards,dtype=th.float)
                        
                        ind_reward = th.mean(ind_rewards)
                        ind_reward_std = th.std(ind_rewards)
                        
                        int_agent_reward.append(ind_reward.item())
                        
                        ## if agent died, reward == 0 
                        #if actions_all[idx][agent].item() > 0.0 :
                        intrinsic_ind_reward += ind_reward
                        intrinsic_ind_reward_std += ind_reward_std         
                            
                            
                        if sum(data["obs"][agent]) == 0 :
                            ind_agent_terminated.append(1)
                        else :
                            ind_agent_terminated.append(0)
     
                    if kendall_need :
                        kendalltau = self.get_kendalltau(th.tensor(ind_reward_total))
                        post_transition_data["kendalltau"].append((kendalltau,))
                    else :
                        post_transition_data["kendalltau"].append((0,))                        
                    intrinsic_ind_reward_std = intrinsic_ind_reward_std / self.env_info["n_agents"]

                    post_transition_data["ind_reward"].append(int_agent_reward)
                    
                    if self.args.use_team_reward :
                        intrinsic_team_rewards = list()
                        for num in range(self.args.n_reward_functions) :
                            intrinsic_team_rewards.append(self.team_reward_models[num](th.tensor(s).reshape((1,-1)).to(self.device),th.tensor(s_next).reshape((1,-1)).to(self.device),a_onehots.reshape((1,-1))))
                        intrinsic_team_rewards = th.tensor(intrinsic_team_rewards,dtype=th.float)
                        
                        intrinsic_team_reward = th.mean(intrinsic_team_rewards)
                        intrinsic_team_reward_std = th.std(intrinsic_team_rewards)
                        
                        intrinsic_reward = intrinsic_team_reward
                        post_transition_data["int_team_reward"].append((intrinsic_team_reward,))
                        post_transition_data["int_team_reward_std"].append((intrinsic_team_reward_std,))
                        
                    else :
                        intrinsic_reward= intrinsic_ind_reward
                        post_transition_data["int_team_reward_std"].append((0,))
                        post_transition_data["int_team_reward"].append((0,))
                    if self.args.use_std_to_reward :
                        if self.args.use_team_reward : 
                            intrinsic_reward = intrinsic_reward + (intrinsic_team_reward_std+intrinsic_ind_reward_std)/2
                        else :
                            intrinsic_reward = intrinsic_reward + intrinsic_ind_reward_std
                    #if self.args.use_kendalltau_as_reward and kendalltau <= 0 :
                        #intrinsic_reward = intrinsic_reward - kendalltau
                        
                    if self.args.use_ori_reward :
                        if self.args.use_intrinsic_reward_as_contribution :
                            intrinsic_reward = th.tensor(data["reward"])
                        else :
                            intrinsic_reward = intrinsic_reward*0.2 + data["reward"]
                            #intrinsic_reward = intrinsic_reward * (1-self.args.ori_percent) + data["reward"] * self.args.ori_percent
                    if self.args.env == "sc2" :
                        if 'battle_won' in data["info"] and self.args.use_extrinsic_reward :
                            if data["info"]['battle_won'] :
                                intrinsic_reward = intrinsic_reward+100
                    elif self.args.env == "gfootball" :
                        if 'score' in data["info"] and self.args.use_extrinsic_reward :
                            if data["info"]['score'][0] == 1 or data["info"]['score'][0] == '1' :
                                intrinsic_reward = intrinsic_reward+100
                    post_transition_data["reward"].append((intrinsic_reward.item(),))
                    post_transition_data["int_ind_reward"].append((intrinsic_ind_reward.item(),))
                    post_transition_data["int_ind_reward_std"].append((intrinsic_ind_reward_std.item(),))
                    post_transition_data["int_ind_reward_all"].append(ind_reward_total)
                    episode_returns[idx] += intrinsic_reward
                    ori_episode_returns[idx] += data["reward"]
                    ori_ind_episode_returns[idx] += intrinsic_ind_reward
                    
                    '''
                    if self.args.use_team_reward :
                        ori_team_episode_returns[idx] += intrinsic_team_reward
                    else :
                        ori_team_episode_returns[idx] += 0
                    '''
                    
                    episode_lengths[idx] += 1
                    if not test_mode:
                        self.env_steps_this_run += 1

                    env_terminated = False
                    if data["terminated"]:
                        final_env_infos.append(data["info"])
                    if data["terminated"] and not data["info"].get("episode_limit", False):
                        env_terminated = True
                    terminated[idx] = data["terminated"]
                    post_transition_data["terminated"].append((env_terminated,))
                    post_transition_data["ind_terminated"].append(ind_agent_terminated)

                    # Data for the next timestep needed to select an action
                    pre_transition_data["state"].append(data["state"])
                    pre_transition_data["avail_actions"].append(data["avail_actions"])
                    pre_transition_data["obs"].append(data["obs"])

            # Add post_transiton data into the batch
            self.batch.update(post_transition_data, bs=envs_not_terminated, ts=self.t, mark_filled=False)

            # Move onto the next timestep
            prev_actions = actions_all.clone()
            self.t += 1
            cnt+=1

            # Add the pre-transition data
            self.batch.update(pre_transition_data, bs=envs_not_terminated, ts=self.t, mark_filled=True)

        if not test_mode:
            self.t_env += self.env_steps_this_run

        # Get stats back for each env
        for parent_conn in self.parent_conns:
            parent_conn.send(("get_stats",None))

        env_stats = []
        for parent_conn in self.parent_conns:
            env_stat = parent_conn.recv()
            env_stats.append(env_stat)

        cur_stats = self.test_stats if test_mode else self.train_stats

        cur_returns = self.test_returns if test_mode else self.train_returns
        ori_cur_returns = self.ori_test_returns if test_mode else self.ori_train_returns
        
        ori_cur_ind_returns = self.ori_test_ind_returns if test_mode else self.ori_ind_returns
        ori_cur_team_returns = self.ori_test_team_returns if test_mode else self.ori_team_returns
        
        log_prefix = "test_" if test_mode else ""
        infos = [cur_stats] + final_env_infos

        cur_stats.update({k: sum(d.get(k, 0) for d in infos) for k in set.union(*[set(d) for d in infos])})
        cur_stats["n_episodes"] = self.batch_size + cur_stats.get("n_episodes", 0)
        cur_stats["ep_length"] = sum(episode_lengths) + cur_stats.get("ep_length", 0)

        cur_returns.extend(episode_returns)
        ori_cur_returns.extend(ori_episode_returns)
        ori_cur_ind_returns.extend(ori_ind_episode_returns)
        ori_cur_team_returns.extend(ori_team_episode_returns)
        
        #print(cur_returns)
        #print(ori_cur_returns)

        n_test_runs = max(1, self.args.test_nepisode // self.batch_size) * self.batch_size
        if test_mode and (len(self.test_returns) == n_test_runs):
            self._log(cur_returns,ori_cur_returns,ori_cur_ind_returns,ori_cur_team_returns,cur_stats, log_prefix)
        elif self.t_env - self.log_train_stats_t >= self.args.runner_log_interval:
            self._log(cur_returns,ori_cur_returns,ori_cur_ind_returns,ori_cur_team_returns, cur_stats, log_prefix)
            if hasattr(self.mac.action_selector, "epsilon"):
                self.logger.log_stat("epsilon", self.mac.action_selector.epsilon, self.t_env)
            self.log_train_stats_t = self.t_env

        return self.batch

    def _log(self, returns, ori_returns,ori_cur_ind_returns,ori_cur_team_returns,stats, prefix):
        self.logger.log_stat(prefix + "return_mean", np.mean(returns), self.t_env)
        self.logger.log_stat(prefix + "return_std", np.std(returns), self.t_env)
        self.logger.log_stat(prefix + "ori_return_mean", np.mean(ori_returns), self.t_env)
        self.logger.log_stat(prefix + "ori_return_std", np.std(ori_returns), self.t_env)
        self.logger.log_stat(prefix + "ind_return_mean", np.mean(ori_cur_ind_returns), self.t_env)
        self.logger.log_stat(prefix + "ind_return_std", np.std(ori_cur_ind_returns), self.t_env)
        self.logger.log_stat(prefix + "team_return_mean", np.mean(ori_cur_team_returns), self.t_env)
        self.logger.log_stat(prefix + "team_return_std", np.std(ori_cur_team_returns), self.t_env)
        returns.clear()

        for k, v in stats.items():
            if k != "n_episodes":
                self.logger.log_stat(prefix + k + "_mean" , v/stats["n_episodes"], self.t_env)
        stats.clear()


def env_worker(remote, env_fn):
    """Serve environment commands received on ``remote`` until asked to close.

    The environment is created by calling ``env_fn.x()``. The worker then
    repeatedly receives ``(cmd, data)`` pairs from ``remote`` and handles:

    * ``"step"``: steps the environment with ``data`` as the actions and sends
      back a dict with the state/observations from before and after the step,
      the available actions, the reward, the terminated flag and the info.
    * ``"reset"``: resets the environment and sends the state, available
      actions and observations.
    * ``"get_env_info"`` / ``"get_stats"``: sends the result of the
      corresponding environment method.
    * ``"close"``: closes the environment and ``remote`` and returns.

    Raises:
        NotImplementedError: if an unrecognised command is received.
    """
    environment = env_fn.x()

    while True:
        command, payload = remote.recv()

        if command == "close":
            environment.close()
            remote.close()
            return

        if command == "step":
            # Capture state and observations before the transition is applied.
            state_before = environment.get_state()
            obs_before = environment.get_obs()
            reward, terminated, env_info = environment.step(payload)
            reply = {
                "pre_state": state_before,
                # Data for the next timestep needed to pick an action
                "state": environment.get_state(),
                "avail_actions": environment.get_avail_actions(),
                "obs": environment.get_obs(),
                "pre_obs": obs_before,
                # Rest of the data for the current timestep
                "reward": reward,
                "terminated": terminated,
                "info": env_info,
            }
        elif command == "reset":
            environment.reset()
            reply = {
                "state": environment.get_state(),
                "avail_actions": environment.get_avail_actions(),
                "obs": environment.get_obs(),
            }
        elif command == "get_env_info":
            reply = environment.get_env_info()
        elif command == "get_stats":
            reply = environment.get_stats()
        else:
            raise NotImplementedError

        remote.send(reply)


class CloudpickleWrapper:
    """
    Holds an object in ``x`` and serializes it with cloudpickle when pickled
    (otherwise multiprocessing tries to use pickle); the stored bytes are
    restored with ``pickle.loads``.
    """

    def __init__(self, x):
        # The wrapped object; callers access it as ``wrapper.x``.
        self.x = x

    def __getstate__(self):
        # Imported lazily, so cloudpickle is only needed when pickling happens.
        from cloudpickle import dumps

        payload = dumps(self.x)
        return payload

    def __setstate__(self, ob):
        # ``ob`` is the byte string produced by ``__getstate__``.
        from pickle import loads

        self.x = loads(ob)

