from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
import json
from model import A3C_Single
from utils import ensure_shared_grads

class HLoss(nn.Module):
    """Entropy / prior-matching regulariser computed from raw logits.

    Without a prior the module returns the mean Shannon entropy of
    ``softmax(x)``. With a prior it returns the *negated* mean KL divergence
    ``-KL(softmax(x) || prior)``.
    """

    def __init__(self):
        super(HLoss, self).__init__()

    def forward(self, x, prior=None):
        """Return the scalar regulariser for logits ``x``.

        Args:
            x: logits. Classes are read from dim 1 when ``prior`` is None and
                from the last dim otherwise (the original convention is kept).
            prior: optional class prior with ``x.size(-1)`` entries per row.
                May be a tensor, list or ndarray; it is moved to ``x``'s
                device and dtype before use.

        Returns:
            A 0-dim tensor.
        """
        if prior is None:
            probs = F.softmax(x, dim=1)
            return -(probs * F.log_softmax(x, dim=1)).sum(1).mean()

        # The prior is typically built once on the CPU; align it with the
        # logits so the subtraction below cannot hit a device mismatch.
        prior = torch.as_tensor(prior, dtype=x.dtype, device=x.device)
        log_prior = torch.log(prior).view(-1, x.size(-1))
        probs = F.softmax(x, dim=-1)
        kl = (probs * (F.log_softmax(x, dim=-1) - log_prior)).sum(-1)
        return -kl.mean()

class Agent(object):
    def __init__(self, model, env, args, state, device):
        self.model = model
        self.env = env
        self.num_agents = env.n
        #self.num_targets = env.observation_space.shape[1]
        self.num_targets = env.num_target
        self.state_dim = env.observation_space.shape[2]
        self.model_name = args.model 
        self.trigger_threshold = 0.45
        self.prior = torch.FloatTensor(np.array([0.7, 0.3]))  # communication edge prior

        if 'continuous' in args.model:
            self.continuous = True
            self.action_high = [env.action_space[i].high for i in range(self.num_agents)]
            self.action_low = [env.action_space[i].low for i in range(self.num_agents)]
            self.dim_action = env.action_space[0].shape[0]
        else:
            self.dim_action = 1
            self.continuous = False

        self.model_name = args.model
        self.eps_len = 0
        self.eps_num = 0
        self.args = args
        self.values = []
        self.log_probs = []
        self.rewards = []
        self.entropies = []
        self.rewards_eps = []
        self.done = True
        self.info = None
        self.reward = 0
        self.device = device
        self.lstm_out = args.lstm_out
        self.reward_mean = None
        self.reward_std = 1
        self.num_steps = 0
        self.n_steps = 0
        self.vk = 0
        self.state = state
        self.rank = 0
        # evaluation for ToM & Comm
        self.comm_ToM_loss = torch.zeros(1)
        self.no_comm_ToM_loss = torch.zeros(1)
        self.ToM_loss = torch.zeros(1)

        self.hself = torch.zeros(self.num_agents, self.lstm_out).to(device)
        #self.hothers = torch.zeros(self.num_agents, self.num_agents-1, self.lstm_out).to(device)
        if 'ToM-v5' in self.model_name:
            self.hToM = torch.zeros(self.num_agents, self.num_agents, self.lstm_out).to(device)
        elif 'ToM' in self.model_name:
            self.hToM = torch.zeros(self.num_agents, self.num_agents-1, self.lstm_out).to(device)
        else:
            self.hToM = torch.zeros(1)
        self.ToM_msgs = torch.zeros(self.num_agents,self.num_agents-1, self.lstm_out).to(device)
        self.real_msgs = torch.zeros(self.num_agents,self.num_agents-1, self.num_targets, self.state_dim).to(device)
        self.cam_states = None # cam_dim=3 ndarray
        self.ToM_file = None
        #self.ToM_goals = []
        #self.real_goals = []
        self.ToM_history = []
        self.Policy_history = []
        self.step_history = []
        self.loss_history = []
        self.ToM_targets = []
        self.real_targets = []
        self.edge_logits = []
        #self.prev_ToM = torch.zeros(self.num_agents, self.num_agents-1, self.num_targets, 1)
        #self.prev_ToM = torch.cat((1 - self.prev_ToM, self.prev_ToM),-1)  # all ToM goals are initially set 0 

    def msg_transfer(self, msg, current_cam, target_cam):
        # current cam : k
        # target cam : i
        new_msg = msg
        new_agent_idx = target_cam / self.num_agents
        target_x = self.cam_states[current_cam][0] + msg[2] * torch.cos(msg[3] + self.cam_states[current_cam][2])
        target_y = self.cam_states[current_cam][1] + msg[2] * torch.sin(msg[3] + self.cam_states[current_cam][2])
        x_delt = target_x - self.cam_states[target_cam][0]
        y_delt = target_y - self.cam_states[target_cam][1]
        new_distance = torch.sqrt(y_delt * y_delt + x_delt * x_delt)
        new_angle = torch.atan2(y_delt, x_delt) / np.pi * 180 - self.cam_states[target_cam][2]
        if new_angle > 180:
            new_angle -= 360
        if new_angle < -180:
            new_angle += 360
        new_msg[0] = new_agent_idx
        new_msg[2] = new_distance
        new_msg[3] = new_angle
        return new_msg

    def obs_msg_fusion(self):
        # state [n,m,obs_dim]
        # msgs [n,n-1,m,obs_dim]
        for i in range(self.num_agents):
            for j in range(self.num_targets):
                if (self.state[i,j]==torch.zeros(self.state_dim)).all():
                    # target j not observed by agent i, try to search it in msgs
                    Found=False
                    for k in range(self.num_agents):
                        # msgs sent by other agents
                        if i==k: continue
                        if i>k and (self.real_msgs[k,i-1,j]!=torch.zeros(self.state_dim)).any():
                            msg=self.real_msgs[k,i-1,j]
                            Found=True
                            break
                        elif i<k and (self.real_msgs[k,i,j]!=torch.zeros(self.state_dim)).any():
                            msg=self.real_msgs[k,i,j]
                            Found=True
                            break
                    if Found:
                        self.state[i,j]=self.msg_transfer(msg,k,i)

    def get_cam_states(self):
        cam_states = self.env.get_cam_states()
        cam_states = torch.from_numpy(np.array(cam_states)).float().to(self.device)
        return cam_states
    
    def get_comm_domain(self):
        comm_domain = self.env.get_comm_domain()
        comm_domain = torch.from_numpy(comm_domain).unsqueeze(-1).bool()
        comm_domain = comm_domain.to(self.device)
        return comm_domain

    def get_available_actions(self):
        available_actions = self.env.get_available_actions()
        available_actions = torch.from_numpy(available_actions).to(self.device)
        return available_actions

    def punish(self, actions, num_agents, num_targets):
        # each target should be tracked by one and only one camera, so punish other actions
        punish = np.abs(np.sum(actions,0) - 1)
        punish = self.args.punish_rate * np.sum(punish)/(num_agents * num_targets)
        print("punish")
        return punish

    def action_train(self):
        """Run one training-time environment step and record the transition.

        The model call signature is selected by substring matches on
        ``self.model_name``. The resulting actions are converted to a numpy
        array and passed to ``env.step``; the reward is optionally punished
        and normalised; the transition is appended to ``self.Policy_history``
        (and to ``self.ToM_history`` for ToM models); finally ``self.state``,
        ``self.reward`` and the recurrent hidden states are advanced.
        """
        self.n_steps += 1
        # Optional action mask; the numpy copy is what gets stored in the histories.
        if self.args.mask_actions:
            available_actions = self.get_available_actions()
            available_actions_data = available_actions.cpu().numpy()
        else:
            available_actions = None
            available_actions_data = 0
        # Model dispatch. Branch order matters because the tests are substring
        # matches: the first matching branch wins.
        # NOTE(review): 'comm' is tested before 'ToM-v4'/'ToM-v5', so a name
        # containing both would take the comm branch and leave hn_ToM unbound
        # for the hidden-state update at the end - confirm no such name exists.
        if 'ToM-v1' in self.model_name:
            # v1 uses the raw env camera states and fuses received messages
            # into the observation before the forward pass.
            self.cam_states=self.env.get_cam_states()
            self.obs_msg_fusion()
            value_multi, actions, entropy, log_prob, hn_self, hn_ToM, next_ToM_msgs, next_real_msgs, ToM_goals = \
                self.model(self.state, self.hself, self.hToM,self.cam_states,self.ToM_msgs)
        elif 'ToM-v2' in self.model_name:
            self.cam_states = self.get_cam_states()
            value_multi, actions, entropy, log_prob, hn_self, hn_ToM, ToM_goals, ToM_target, real_cover, probs =\
                self.model(self.state, self.hself, self.hToM, self.cam_states)
        elif 'comm' in self.model_name:
            value_multi, actions, entropy, log_prob, hn_self, edge_logits, comm_edges, probs =\
                self.model(self.state, self.hself)
        elif 'ToM-v4' in self.model_name:
            self.cam_states = self.get_cam_states()
            value_multi, actions, entropy, log_prob, hn_self, hn_ToM, ToM_goals, edge_logits, comm_edges, probs =\
                self.model(self.state, self.hself, self.hToM, self.cam_states, available_actions = available_actions)
        elif 'ToM-v5' in self.model_name:
            self.cam_states = self.get_cam_states()
            self.comm_domain = self.get_comm_domain()
            value_multi, actions, entropy, log_prob, hn_self, hn_ToM, ToM_goals, edge_logits, comm_edges, probs, real_cover, ToM_target_cover =\
                self.model(self.state, self.hself, self.hToM, self.cam_states, self.comm_domain, available_actions = available_actions)
        elif 'center' in self.model_name:
            self.cam_states = self.get_cam_states()
            value_multi, actions, entropy, log_prob, probs = self.model(self.state, self.cam_states)
        else:
            value_multi, actions, entropy, log_prob, probs = self.model(self.state, available_actions = available_actions)
        
        actions_env = actions.cpu().numpy() # only ndarrays can be processed by the environment
        # if self.env.reset_type != 3:
        #     state_multi, reward_multi, self.done, self.info = self.env.step(actions)
        # else:
        state_multi, reward_multi, self.done, self.info = self.env.step(actions_env)#,obstacle=True)
        if "center" not in self.model_name:
            reward_multi = reward_multi.repeat(self.num_agents) # all agents share the same reward
        # save original reward for evaluation, but normalize it for training
        self.reward_org = reward_multi.copy()
        # punish actions for collision or neglection

        if self.args.punish_rate > 0:
            punish = self.punish(actions_env, self.num_agents, self.num_targets)
            reward_multi -= punish
            # NOTE(review): debug output on every step - probably should be removed.
            print(punish)
        if self.args.norm_reward:
            reward_multi = self.reward_normalizer(reward_multi)

        # save state for training
        # NOTE(review): there is no fallback branch. If model_name matches
        # none of the tests below (e.g. a plain 'ToM-v1' name), Policy_data
        # is unbound and the append raises NameError.
        if 'decentralized' in self.model_name or 'comm' in self.model_name:
            Policy_data = {"state":self.state.detach().cpu().numpy(), "cam_states": np.zeros((1,1)), "actions": actions_env, "reward": reward_multi,\
                "comm_domain": np.zeros((1,1,1)), "available_actions": available_actions_data}
        elif 'center' in self.model_name:
            # The reward was not repeated above for 'center', so it is repeated here.
            Policy_data = {"state":self.state.detach().cpu().numpy(), "cam_states": self.cam_states.detach().cpu().numpy(), "actions": actions_env, "reward": reward_multi.repeat(self.num_agents),\
                "comm_domain": np.zeros((1,1,1)), "available_actions": available_actions_data}
        elif 'ToM-v4' in self.model_name or 'ToM-v2' in self.model_name:
            Policy_data = {"state":self.state.detach().cpu().numpy(), "cam_states": self.cam_states.detach().cpu().numpy(), "actions": actions_env, "reward": reward_multi,\
                "comm_domain": np.zeros((1,1,1)), "available_actions": available_actions_data}
            #test_data = {"actions": actions_env, "hself": self.hself.detach().numpy(), "hothers": self.hothers.detach().numpy(), "entropy": entropy.detach().numpy(), "comm":edge_logits.detach().numpy(), "log_probs": log_prob.detach().numpy()}
            #self.step_history.append(test_data)
        elif 'ToM-v5' in self.model_name:
            Policy_data = {"state":self.state.detach().cpu().numpy(), "cam_states": self.cam_states.detach().cpu().numpy(),"actions": actions_env, "reward": reward_multi,\
                "comm_domain":self.comm_domain.detach().cpu().numpy(),"available_actions": available_actions_data}
        self.Policy_history.append(Policy_data)
        
        if 'ToM' in self.model_name:
            # Two-channel encoding of the taken actions: (1 - a, a) on the last axis.
            real_goals = torch.cat((1-actions,actions),-1)
            # NOTE(review): same gap as above - a 'ToM' name without v2/v4/v5
            # leaves ToM_data unbound before the append.
            if 'v4' in self.model_name or 'v2' in self.model_name:
                ToM_data = {"state":self.state.detach().cpu().numpy(), "cam_states":self.cam_states.detach().cpu().numpy(), "comm_domain": np.zeros((1,1,1)),\
                    "real":real_goals.detach().cpu().numpy(), "available_actions": available_actions_data}
            elif 'v5' in self.model_name:
                ToM_data = {"state":self.state.detach().cpu().numpy(), "cam_states":self.cam_states.detach().cpu().numpy(), "comm_domain":self.comm_domain.detach().cpu().numpy(),\
                    "real":real_goals.detach().cpu().numpy(), "available_actions": available_actions_data}
            self.ToM_history.append(ToM_data)

        # A per-agent list of done flags is collapsed to its sum.
        if isinstance(self.done, list): self.done = np.sum(self.done)
        self.state = torch.from_numpy(np.array(state_multi)).float().to(self.device)
            
        self.reward = torch.tensor(reward_multi).float().to(self.device)
        self.eps_len += 1
        # The on-policy rollout buffers are not filled any more; the
        # transitions are kept in Policy_history / ToM_history instead.
        # self.values.append(value_multi)
        # self.entropies.append(entropy)
        # self.log_probs.append(log_prob)
        # self.rewards.append(self.reward.unsqueeze(1))

        # Carry the recurrent state (and, for v1, the messages) to the next step.
        if 'ToM' in self.model_name:
            # self.ToM_goals.append(ToM_goals)
            # #self.real_goals.append(probs.detach())
            # actions = actions.detach()
            # self.real_goals.append(torch.cat((1-actions,actions),-1))
            self.hself=hn_self
            self.hToM=hn_ToM
        if 'comm' in self.model_name:
            self.hself = hn_self
        if 'ToM-v1' in self.model_name:
            self.real_msgs=next_real_msgs
            self.ToM_msgs=next_ToM_msgs
        if 'ToM-v2' in self.model_name:
            pass
            #self.ToM_targets.append(ToM_target)
            #self.real_targets.append(real_target)
        if 'ToM-v4' in self.model_name or 'ToM-v5' in self.model_name:
            pass
            #self.edge_logits.append(edge_logits)
            #self.prev_ToM = ToM_goals

    def action_test(self):
        """Run one evaluation-time environment step under ``torch.no_grad()``.

        Mirrors the model dispatch of ``action_train`` but passes an extra
        positional ``True`` to most model calls, publishes the sampled
        communication edges on ``self.env.comm_edges`` for the models that
        produce them, and records ``self.info['cost']`` as ``self.rotation``.
        Nothing is appended to the training histories.
        """
        if self.args.mask_actions:
            available_actions = self.get_available_actions()
        else:
            available_actions = None
        if 'ToM-v1' in self.model_name:
            # v1 uses the raw env camera states and fuses received messages first.
            self.cam_states=self.env.get_cam_states()
            self.obs_msg_fusion()
            with torch.no_grad():
                value_multi, actions, entropy, log_prob, hn_self, hn_ToM, next_ToM_msgs, next_real_msgs, ToM_goals = \
                    self.model(self.state, self.hself, self.hToM,self.cam_states,self.ToM_msgs,True)
        elif 'ToM-v2' in self.model_name:
            with torch.no_grad():
                self.cam_states=self.get_cam_states()
                value_multi, actions, entropy, log_prob, hn_self, hn_ToM, ToM_goals, ToM_target, real_cover, probs = \
                    self.model(self.state, self.hself, self.hToM, self.cam_states, True)
        elif 'comm' in self.model_name:
            with torch.no_grad():
                value_multi, actions, entropy, log_prob, hn_self, edge_logits, comm_edges, probs = self.model(self.state, self.hself, True)
                self.env.comm_edges = comm_edges
                # Communication cost bookkeeping: edge count, and edges * 32.
                self.comm_cnt = torch.sum(comm_edges)
                self.comm_bit = self.comm_cnt * 32
        elif 'ToM-v4' in self.model_name:
            with torch.no_grad():
                self.cam_states = self.get_cam_states()
                value_multi, actions, entropy, log_prob, hn_self, hn_ToM, ToM_goals, edge_logits, comm_edges, probs=\
                    self.model(self.state, self.hself, self.hToM, self.cam_states, True, available_actions = available_actions)
                
                self.env.comm_edges = comm_edges
                # for evaluation, should be commented when real training
                # (the block below is disabled code kept as a string literal)
                '''
                KL_criterion = torch.nn.KLDivLoss(reduction='sum')

                real_goals_duplicate = probs.reshape(1, self.num_agents, self.num_targets, -1).repeat(self.num_agents, 1, 1, 1)
                idx= (torch.ones(self.num_agents, self.num_agents) - torch.diag(torch.ones(self.num_agents))).bool()
                real_goals_duplicate = real_goals_duplicate[idx].reshape(self.num_agents, self.num_agents-1, self.num_targets, -1)
                comm_edges = comm_edges[idx].reshape(self.num_agents, self.num_agents-1, 1, 1)
                
                comm_real_prob = (real_goals_duplicate * comm_edges).float()
                comm_ToM_prob = (ToM_goals * comm_edges).float()
                self.comm_ToM_loss = KL_criterion(comm_ToM_prob.log(),comm_real_prob)

                no_comm_real_prob = (real_goals_duplicate * (1-comm_edges)).float()
                no_comm_ToM_prob = (ToM_goals * (1-comm_edges)).float()
                self.no_comm_ToM_loss = KL_criterion(no_comm_ToM_prob.log(),no_comm_real_prob)

                self.ToM_loss = KL_criterion(ToM_goals.log(),real_goals_duplicate)
                '''
        elif 'ToM-v5' in self.model_name:
            with torch.no_grad():
                self.cam_states = self.get_cam_states()
                self.comm_domain = self.get_comm_domain()
                #prev_ToM = self.prev_ToM
                value_multi, actions, entropy, log_prob, hn_self, hn_ToM, ToM_goals, edge_logits, comm_edges, probs, real_cover, ToM_target_cover=\
                    self.model(self.state, self.hself, self.hToM, self.cam_states, self.comm_domain, True, available_actions = available_actions)
                
                # Communication cost bookkeeping: edge count, and edges * num_targets.
                self.comm_cnt = torch.sum(comm_edges)
                self.comm_bit = self.comm_cnt * self.num_targets
                self.env.comm_edges = comm_edges

        elif 'center' in self.model_name:
            with torch.no_grad():
                self.cam_states = self.get_cam_states()
                # NOTE(review): unlike the other branches, no `True` flag is
                # passed here - confirm the 'center' model has no test mode.
                value_multi, actions, entropy, log_prob, probs = self.model(self.state, self.cam_states)
        else:
            with torch.no_grad():
                value_multi, actions, entropy, log_prob, probs = self.model(self.state, True, available_actions = available_actions)

        # if self.env.reset_type != 3:
        #     state_multi, self.reward, self.done, self.info = self.env.step(actions)
        # else:
        # NOTE(review): `actions` is passed to the env as a torch tensor here,
        # whereas action_train converts it with .cpu().numpy() first - confirm
        # the env accepts tensors (this would likely fail for CUDA tensors).
        state_multi, self.reward, self.done, self.info = self.env.step(actions)#, obstacle=True)

        # A per-agent list of done flags is collapsed to its sum.
        if isinstance(self.done, list): self.done = np.sum(self.done)
        self.state = torch.from_numpy(np.array(state_multi)).float().to(self.device)
        self.rotation = self.info['cost']
        self.eps_len += 1

        if 'ToM' in self.model_name:
            # update rnn hidden states & msgs
            self.hself=hn_self
            self.hToM=hn_ToM
        if 'comm' in self.model_name:
            self.hself = hn_self
        if 'ToM-v1' in self.model_name:
            self.real_msgs=next_real_msgs
            self.ToM_msgs=next_ToM_msgs
        if 'ToM-v3' in self.model_name:
            pass
            #self.prev_ToM = ToM_goals

    def reset(self):
        obs = self.env.reset()
        self.state = torch.from_numpy(np.array(obs)).float().to(self.device)
        #self.prev_ToM = torch.zeros(self.num_agents, self.num_agents-1, self.num_targets, 1)
        #self.prev_ToM = torch.cat((1 - self.prev_ToM, self.prev_ToM),-1)  # all ToM goals are initially set 0 

        self.eps_len = 0
        self.eps_num += 1
        self.reset_rnn_hidden()
        
        self.model.sample_noise()

    def clean_buffer(self, done):
        """Empty the per-rollout buffers and return ``self``.

        Args:
            done: when truthy, the per-episode reward buffer
                ``self.rewards_eps`` is emptied as well.
        """
        # Model outputs, ground-truth signals and ToM/communication traces.
        # Each attribute gets its own fresh list.
        per_rollout = (
            'values', 'log_probs', 'entropies',
            'rewards', 'obs_tracker',
            'ToM_targets', 'real_targets', 'edge_logits',
        )
        for attr in per_rollout:
            setattr(self, attr, [])

        if done:
            self.rewards_eps = []

        return self

    def reward_normalizer(self, reward):
        reward = np.array(reward)
        self.num_steps += 1
        if self.num_steps == 1:
            self.reward_mean = reward
            self.vk = 0
            self.reward_std = 1
        else:
            delt = reward - self.reward_mean
            self.reward_mean = self.reward_mean + delt/self.num_steps
            self.vk = self.vk + delt * (reward-self.reward_mean)
            self.reward_std = np.sqrt(self.vk/(self.num_steps - 1))
        reward = (reward - self.reward_mean) / (self.reward_std + 1e-8)
        return reward

    def reset_rnn_hidden(self):
        self.hself = torch.zeros(self.num_agents, self.lstm_out).to(self.device)
        if 'ToM-v4' in self.model_name:
            self.hToM = torch.zeros(self.num_agents, self.num_agents-1, self.lstm_out).to(self.device)
        elif 'ToM-v5' in self.model_name:
            self.hToM = torch.zeros(self.num_agents, self.num_agents, self.lstm_out).to(self.device)
        self.ToM_msgs = torch.zeros(self.num_agents,self.num_agents-1, self.lstm_out).to(self.device)
        self.real_msgs = torch.zeros(self.num_agents,self.num_agents-1, self.num_targets, self.state_dim).to(self.device)

    def update_rnn_hidden(self):
        self.hself = Variable(self.hself.data)
        self.hToM = Variable(self.hToM.data)
        self.ToM_msgs = Variable(self.ToM_msgs)
        self.real_msgs = Variable(self.real_msgs)

    '''
    def test_model(self):
        hself = torch.zeros(self.num_agents, self.lstm_out).to(self.device)
        hothers = torch.zeros(self.num_agents, self.num_agents-1, self.lstm_out).to(self.device)
        for i in range(len(self.rewards)):
            #print(i)
            state = self.Policy_history[i]["state"]
            cam_states = self.Policy_history[i]["cam_states"]
            value_multi, actions, entropy, log_prob, hn_self, hn_others, ToM_goals, edge_logits, comm_edges, probs =\
                self.model(state, hself, hothers, cam_states)
            print(torch.sum(torch.abs(hself - self.Policy_history[i]["hself"])))
            print(torch.sum(torch.abs(hothers - self.Policy_history[i]["hothers"])))
            hself = hn_self
            hothers = hn_others
        self.Policy_history = []
    '''

    def optimize(self, params, optimizer, shared_model, training_mode, device_share):
        """Accumulate A3C / GAE losses over the stored rollout and log them.

        Args:
            params: parameters whose gradient norm would be clipped.
            optimizer: optimizer that would apply the update.
            shared_model: shared model that would receive the local gradients.
            training_mode: not used by this method.
            device_share: device of ``shared_model``, forwarded to
                ``ensure_shared_grads``.

        Returns:
            ``None``. The summed policy / value / entropy / sparsity losses
            are appended to ``self.loss_history`` and the method returns
            before any gradient step is taken (see the NOTE at the early
            return below).

        NOTE(review): this method reads ``self.rewards``, ``self.values``,
        ``self.log_probs``, ``self.entropies`` and ``self.edge_logits``, but
        the code that appends to them in ``action_train`` is commented out.
        Unless another caller fills these buffers, ``self.rewards[0]`` raises
        IndexError - confirm how this method is meant to be driven.
        """
        R = torch.zeros(len(self.rewards[0]), 1).to(self.device)
        if not self.done:
            # Bootstrap the return from the value of the current state.
            # NOTE(review): there are no dedicated branches for 'comm' or
            # 'ToM-v5' models here; they fall through to the bare
            # self.model(self.state) call - confirm that is intended.
            if 'ToM-v1' in self.model_name:
                self.cam_states = self.env.get_cam_states()
                self.obs_msg_fusion()
                value_multi, *others = self.model(self.state, self.hself, self.hToM, self.cam_states, self.ToM_msgs)
            elif 'ToM-v2' in self.model_name or 'ToM-v3' in self.model_name or 'ToM-v4' in self.model_name:
                # Same inputs as the corresponding action_train calls: tensor
                # camera states and the ToM hidden state (`hToM`; the former
                # `hothers` attribute no longer exists).
                self.cam_states = self.get_cam_states()
                value_multi, *others = self.model(self.state, self.hself, self.hToM, self.cam_states)
            elif 'center' in self.model_name:
                self.cam_states = self.get_cam_states()
                value_multi, *others = self.model(self.state, self.cam_states)
            else:
                value_multi, *others = self.model(self.state)
            for i in range(len(self.rewards[0])):  # num_agent
                R[i][0] = value_multi[i].data

        self.values.append(Variable(R).to(self.device))

        policy_loss = torch.zeros(1).to(self.device)
        value_loss = torch.zeros(1).to(self.device)
        entropies = torch.zeros(1, self.dim_action).to(self.device)
        w_entropies = float(self.args.entropy)

        ToM_loss = torch.zeros(1).to(self.device)
        Sparsity_loss = torch.zeros(1).to(self.device)

        KL_criterion = torch.nn.KLDivLoss(reduction='sum')
        KL_single = torch.nn.KLDivLoss(reduction='none')
        BCE_criterion = torch.nn.BCELoss(reduction='sum')
        criterionH = HLoss()

        R = Variable(R, requires_grad=True).to(self.device)
        gae = torch.zeros(1, 1).to(self.device)

        # Walk the rollout backwards, accumulating discounted returns and the
        # generalized advantage estimate.
        for i in reversed(range(len(self.rewards))):
            R = self.args.gamma * R + self.rewards[i]
            advantage = R - self.values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)
            # Generalized Advantage Estimation
            delta_t = self.rewards[i] + self.args.gamma * self.values[i + 1].data - self.values[i].data
            gae = gae * self.args.gamma * self.args.tau + delta_t

            # Only gae[0] is used: valid only if all agents share the same gae.
            policy_loss = policy_loss - (w_entropies * self.entropies[i]) - (self.log_probs[i] * gae[0])
            entropies += self.entropies[i].sum()

            if 'ToM' in self.model_name:
                '''
                # compute ToM Loss, including agent & target
                ToM_goals = self.ToM_goals[i]
                real_goals = self.real_goals[i]
                
                real_goals_duplicate = real_goals.reshape(1, self.num_agents, self.num_targets, -1).repeat(self.num_agents, 1, 1, 1)
                idx= (torch.ones(self.num_agents, self.num_agents) - torch.diag(torch.ones(self.num_agents))).bool()
                real_goals_duplicate = real_goals_duplicate[idx].reshape(self.num_agents, self.num_agents-1, self.num_targets, -1)
                
                real_goals_duplicate = real_goals[1:].unsqueeze(0)                   
                # ToM_targets = self.ToM_targets[i]
                # real_targets = self.real_targets[i]
                # real_targets_duplicate = real_targets[1:].unsqueeze(0)
                for j in range(1,self.num_agents):
                    goals = torch.cat((real_goals[:j],real_goals[j+1:]),0).unsqueeze(0)
                    real_goals_duplicate = torch.cat((real_goals_duplicate,goals),0)
                    # targets = torch.cat((real_targets[:i],real_targets[i+1:]),0).unsqueeze(0)
                    # real_targets_duplicate = torch.cat((real_targets_duplicate,targets),0)
                
                real_prob = real_goals_duplicate.float()
                ToM_prob = ToM_goals.float()
                ToM_loss = ToM_loss + KL_criterion(ToM_prob.log(),real_prob)
                #ToM_target_loss = ToM_target_loss + BCE_criterion(ToM_targets.float(), real_targets_duplicate.float())
                '''
                if 'ToM-v3' in self.model_name or 'ToM-v4' in self.model_name:
                    # Sparsity loss: pull the edge distribution towards the prior.
                    edge_logits = self.edge_logits[i].reshape(self.num_agents * self.num_agents, -1)  # (n * n) * 2
                    Sparsity_loss += -criterionH(edge_logits, self.prior)

                '''
                # Trigger Loss
                if 'v3-part' in self.model_name and i < len(self.rewards) - 1:
                    trigger = self.comm_trigger[i+1]
                    real_trigger = KL_single(ToM_prob.log(), real_prob)
                    real_trigger = ((torch.sum(real_trigger, [-2,-1]) / self.num_targets) > self.trigger_threshold).unsqueeze(-1)
                    Trigger_loss += BCE_criterion(trigger.float(), real_trigger.float())
                '''
        test_data = {"policy": policy_loss.sum().data, "value": value_loss.sum().data, "entropy": entropies.sum().data, "Sparsity": Sparsity_loss.sum().data}
        self.loss_history.append(test_data)
        # NOTE(review): this early return makes the whole update below
        # unreachable, so no gradient step is ever taken and the buffers are
        # never cleaned here. It looks like a debugging leftover; it is kept
        # because removing it changes the return value and training behaviour.
        return
        self.model.zero_grad()
        loss = policy_loss.sum() + 0.5 * value_loss.sum() + 0.5 * Sparsity_loss.sum() # + 2 * ToM_loss.sum()
        loss.backward(retain_graph=True)

        torch.nn.utils.clip_grad_norm_(params, 50)
        ensure_shared_grads(self.model, shared_model, self.device, device_share)
        optimizer.step()
        self.clean_buffer(self.done)
        return policy_loss, value_loss, Sparsity_loss, entropies

    '''
    def optimize_ToM(self, params, optimizer_ToM, shared_model, training_mode, device_share):
        ToM_loss = torch.zeros(1).to(self.device)
        ToM_loss.requires_grad = True
        KL_criterion = torch.nn.KLDivLoss(reduction='sum')
        for i in range(len(self.ToM_goals)):
            ToM_goals = self.ToM_goals[i]
            real_goals = self.real_goals[i]

            real_goals_duplicate = real_goals.reshape(1, self.num_agents, self.num_targets, -1).repeat(self.num_agents, 1, 1, 1)
            idx= (torch.ones(self.num_agents, self.num_agents) - torch.diag(torch.ones(self.num_agents))).bool()
            real_goals_duplicate = real_goals_duplicate[idx].reshape(self.num_agents, self.num_agents-1, self.num_targets, -1)

            real_prob = real_goals_duplicate.float()
            ToM_prob = ToM_goals.float()
            ToM_loss = ToM_loss + KL_criterion(ToM_prob.log(),real_prob)

        self.model.zero_grad()
        # for name,param in self.model.named_parameters():
        #     if not('ToM' in name or 'other' in name):
        #         if (param.grad is not None):
        #             print(param.grad[0])
        loss = 3 * ToM_loss
        loss.backward(retain_graph=True)

        torch.nn.utils.clip_grad_norm_(params, 50)
        ensure_shared_grads(self.model, shared_model, self.device, device_share)
        optimizer_ToM.step()
        # clean ToM & real goal buffer here
        self.ToM_goals = []
        self.real_goals = []
        return ToM_loss    
    '''