from __future__ import division
import torch
import numpy as np
import torch.nn as nn
from gym import spaces
import torch.nn.functional as F
from torch.autograd import Variable

from utils import norm_col_init, weights_init
from perception import NoisyLinear, RNN, AttentionLayer

def build_model(env, args, device):
    """Instantiate the network selected by ``args.model`` and set it to train mode.

    The model name is matched by substring, in the order listed below; the
    first match wins.  Every multi-agent model additionally receives
    ``num_agents`` (from ``env.n``) and ``num_targets`` (from
    ``env.num_target``) as attributes after construction.

    Args:
        env: environment exposing ``observation_space`` and ``action_space``
            (and ``n`` / ``num_target`` for the multi-agent models).
        args: namespace whose ``model`` attribute names the architecture; it
            is also forwarded to the model constructor.
        device: torch device forwarded to the model constructor.

    Returns:
        The constructed model, already switched to training mode.

    Raises:
        ValueError: if ``args.model`` matches none of the known architectures
            (previously this surfaced as an ``UnboundLocalError``).
    """
    name = args.model
    multi_agent = True

    if 'single' in name:
        model_cls = A3C_Single
        multi_agent = False
    elif 'ToM-v1' in name:
        # partial obs, communication needed
        model_cls = A3C_Multi
    elif 'ToM-v2' in name:
        # no communication, only ToM
        model_cls = A3C_Multi_v2
    elif 'comm' in name:
        # partial obs with communication without ToM
        model_cls = A3C_Multi_v3
    elif 'ToM-v4' in name:
        # partial obs with obstacles and centralized communication and ToM
        model_cls = A3C_Multi_v4
    elif 'ToM-v5' in name:
        # partial obs with obstacles and decentralized communication, target cover is learnt
        model_cls = A3C_Multi_v5
    elif 'baseline' in name and 'decentralized' in name:
        model_cls = A3C_baseline_decentralized
    elif 'baseline' in name and 'center' in name:
        model_cls = A3C_baseline_centralized
    else:
        raise ValueError("Unknown model name: {!r}".format(name))

    model = model_cls(env.observation_space, env.action_space, args, device)
    if multi_agent:
        model.num_agents = env.n
        model.num_targets = env.num_target

    model.train()
    return model


def wrap_action(self, action):
    """Affinely rescale ``action`` using ``self.action_low`` / ``self.action_high``.

    Singleton dimensions are squeezed out first.  The mapping sends -1 to
    ``action_low`` and +1 to ``action_high``.
    """
    low = self.action_low
    high = self.action_high
    squeezed = np.squeeze(action)
    scaled = squeezed * (high - low) / 2
    midpoint = (high + low) / 2.0
    return scaled + midpoint


def sample_action(mu_multi, sigma_multi, device, test=False):
    """Pick discrete actions from the logits ``mu_multi``.

    ``sigma_multi`` and ``device`` are accepted but not used.

    Returns a tuple ``(action_env, entropy, log_prob, prob)``:
      * ``test=False``: the action is sampled from the softmax distribution,
        ``action_env`` is the sampled index tensor with ``squeeze(0)`` applied
        and ``log_prob`` is gathered at the sampled index.
      * ``test=True``: the action is the arg-max, ``action_env`` is a NumPy
        array and ``log_prob`` is the full log-softmax (not gathered).
    """
    logits = mu_multi
    log_prob = F.log_softmax(logits, dim=-1)
    prob = F.softmax(logits, dim=-1)
    entropy = -(prob * log_prob).sum(-1, keepdim=True)

    if not test:
        sampled = prob.multinomial(1).data
        # log-probability of the sampled action only
        picked_log_prob = log_prob.gather(-1, Variable(sampled))
        return sampled.squeeze(0), entropy, picked_log_prob, prob

    # greedy branch
    _, greedy = prob.max(-1)
    return greedy.data.cpu().numpy(), entropy, log_prob, prob


class ValueNet(nn.Module):
    """Single linear value head.

    When ``'ns'`` appears in ``head_name`` the head is a ``NoisyLinear``
    layer and ``self.noise`` is True; otherwise it is a plain ``nn.Linear``
    whose weight is initialised with ``norm_col_init(..., 0.1)`` and whose
    bias is zeroed.

    Args:
        input_dim: size of the incoming feature vector.
        head_name: model name string; only the ``'ns'`` substring is inspected.
        num: number of value outputs.
    """

    def __init__(self, input_dim, head_name, num=1):
        super(ValueNet, self).__init__()
        if 'ns' in head_name:
            self.noise = True
            self.critic_linear = NoisyLinear(input_dim, num, sigma_init=0.017)
        else:
            self.noise = False
            self.critic_linear = nn.Linear(input_dim, num)
            self.critic_linear.weight.data = norm_col_init(self.critic_linear.weight.data, 0.1)
            self.critic_linear.bias.data.fill_(0)

    def forward(self, x):
        """Return the value estimate for feature tensor ``x``."""
        return self.critic_linear(x)

    def sample_noise(self):
        """Resample the noise of the head; no-op for a non-noisy head."""
        if self.noise:
            self.critic_linear.sample_noise()

    def remove_noise(self):
        """Remove the noise of the head; no-op for a non-noisy head.

        Previously this resampled the noise instead of removing it.
        """
        if self.noise:
            # assumes NoisyLinear exposes remove_noise() alongside
            # sample_noise() -- TODO confirm against perception.NoisyLinear
            self.critic_linear.remove_noise()


class AMCValueNet(nn.Module):
    """Value head that sums per-member marginal values over a coalition.

    For each row ``j`` of the input, the head scores the concatenation of a
    summary of rows ``[:j]`` (produced by ``self.attention``; a zero vector
    for ``j == 0``) with row ``j`` itself, and the scores are summed.

    Args:
        input_dim: feature size of one coalition member.
        head_name: model name string; ``'ns'`` and ``'onlyJ'`` are inspected.
        num: number of outputs of the linear head.
        device: device on which the zero summary for the first member is created.
    """

    def __init__(self, input_dim, head_name, num=1, device=torch.device('cpu')):
        super(AMCValueNet, self).__init__()
        self.head_name = head_name
        self.device = device

        # NOTE(review): this is a plain `if`, not part of the chain below, so
        # whatever it sets is always overwritten by one of the two following
        # branches and `self.noise` always ends up False.  Left untouched
        # because forward() needs the 2 * input_dim head built in the `else`
        # branch; confirm the intended behaviour for 'ns' heads.
        if 'ns' in head_name:
            self.noise = True
            self.critic_linear = NoisyLinear(input_dim, num, sigma_init=0.017)
        if 'onlyJ' in head_name:
            # NOTE(review): this branch creates neither `self.attention` nor a
            # 2 * input_dim head, both of which forward() uses -- verify that
            # 'onlyJ' heads are never run through forward().
            self.noise = False
            self.critic_linear = nn.Linear(input_dim, num)
            self.critic_linear.weight.data = norm_col_init(self.critic_linear.weight.data, 0.1)
            self.critic_linear.bias.data.fill_(0)
        else:
            self.noise = False
            self.critic_linear = nn.Linear(2 * input_dim, num)
            self.critic_linear.weight.data = norm_col_init(self.critic_linear.weight.data, 0.1)
            self.critic_linear.bias.data.fill_(0)

            self.attention = AttentionLayer(input_dim, input_dim, device)
        self.feature_dim = input_dim

    def forward(self, x, goal):
        """Return the summed marginal value of the coalition ``x``.

        Args:
            x: 2-D tensor, one row per coalition member.
            goal: accepted for interface compatibility; not used.

        Returns:
            Tensor of shape ``[1]`` holding the sum of all per-member outputs.
        """
        _, feature_dim = x.shape
        value = []

        coalition = x.view(-1, feature_dim)
        n = coalition.shape[0]

        # the first member has no predecessors: pair it with a zero summary
        feature = torch.zeros([self.feature_dim]).to(self.device)
        value.append(self.critic_linear(torch.cat([feature, coalition[0]])))
        for j in range(1, n):
            # summary of all members that joined before member j
            _, feature = self.attention(coalition[:j].unsqueeze(0))
            value.append(self.critic_linear(torch.cat([feature.squeeze(), coalition[j]])))  # delta f = f[:j]-f[:j-1]

        # concatenate the per-member outputs and add them up
        value = torch.cat(value).sum()

        return value.unsqueeze(0)

    def sample_noise(self):
        """Resample the noise of the head; no-op when ``self.noise`` is False."""
        if self.noise:
            self.critic_linear.sample_noise()

    def remove_noise(self):
        """Remove the noise of the head; no-op when ``self.noise`` is False.

        Previously this resampled the noise instead of removing it.
        """
        if self.noise:
            # assumes NoisyLinear exposes remove_noise() -- TODO confirm
            self.critic_linear.remove_noise()


class PolicyNet(nn.Module):
    """Discrete policy head: one linear layer followed by ReLU and sampling.

    When ``'ns'`` appears in ``head_name`` the layer is a ``NoisyLinear`` and
    ``self.noise`` is True; otherwise it is an ``nn.Linear`` initialised with
    ``norm_col_init(..., 0.1)`` and a zero bias.

    Args:
        input_dim: size of the incoming feature vector.
        action_space: discrete space; ``action_space.n`` is the number of logits.
        head_name: model name string; only the ``'ns'`` substring is inspected.
        device: forwarded to ``sample_action``.
    """

    def __init__(self, input_dim, action_space, head_name, device):
        super(PolicyNet, self).__init__()
        self.head_name = head_name
        self.device = device
        num_outputs = action_space.n

        if 'ns' in head_name:
            self.noise = True
            self.actor_linear = NoisyLinear(input_dim, num_outputs, sigma_init=0.017)
        else:
            self.noise = False
            self.actor_linear = nn.Linear(input_dim, num_outputs)

            # init layers
            self.actor_linear.weight.data = norm_col_init(self.actor_linear.weight.data, 0.1)
            self.actor_linear.bias.data.fill_(0)

    def forward(self, x, test=False, available_actions = None):
        """Compute logits for ``x`` and choose actions.

        Args:
            x: feature tensor fed to the linear layer.
            test: forwarded to ``sample_action`` (greedy instead of sampled).
            available_actions: optional mask; entries equal to 0 mark logits
                that are overwritten with ``1e-10``.

        Returns:
            ``(action, entropy, log_prob, prob)`` as produced by ``sample_action``.
        """
        mu = F.relu(self.actor_linear(x))
        if available_actions is not None:
            # mask unavailable actions
            # size of mu&available actions: b*n*m, 2
            idx = (available_actions == 0)
            # Write into a copy: ReLU keeps its output for the backward pass,
            # so modifying that tensor in place makes autograd fail later.
            mu = mu.clone()
            mu[idx] = 1e-10
        sigma = torch.ones_like(mu)
        action, entropy, log_prob, prob = sample_action(mu, sigma, self.device, test)
        return action, entropy, log_prob, prob

    def sample_noise(self):
        """Resample the noise of the policy layer; no-op for a non-noisy head."""
        if self.noise:
            # Only `actor_linear` exists; the former call on a second layer
            # (`actor_linear2`) raised AttributeError because it is never created.
            self.actor_linear.sample_noise()

    def remove_noise(self):
        """Remove the noise of the policy layer; no-op for a non-noisy head.

        Previously this resampled the noise instead of removing it.
        """
        if self.noise:
            # assumes NoisyLinear exposes remove_noise() -- TODO confirm
            self.actor_linear.remove_noise()

'''
class EncodeBiRNN(torch.nn.Module):
    def __init__(self, dim_in, lstm_out=128, head_name='birnn_lstm', device=None):
        super(EncodeBiRNN, self).__init__()
        self.head_name = head_name

        self.encoder = BiRNN(dim_in, int(lstm_out / 2), 1, device, 'gru')

        self.feature_dim = self.encoder.feature_dim
        self.global_feature_dim = self.encoder.feature_dim
        self.apply(weights_init)
        self.train()

    def forward(self, inputs):
        x = inputs
        cn, hn = self.encoder(x)

        feature = cn  # shape: [bs, num_camera, lstm_dim]

        global_feature = hn.permute(1, 0, 2).reshape(-1)

        return feature, global_feature
'''

class EncodeLinear(torch.nn.Module):
    """Two stacked ``Linear -> ReLU`` layers applied to the last input axis.

    Attributes:
        features: the ``nn.Sequential`` holding the four layers.
        head_name: stored as given; not used by this class.
        feature_dim: output width (``dim_out``).

    The ``device`` argument is accepted but not used.
    """

    def __init__(self, dim_in, dim_out=32, head_name='lstm', device=None):
        super(EncodeLinear, self).__init__()

        layers = [
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out),
            nn.ReLU(),
        ]
        self.features = nn.Sequential(*layers)

        self.feature_dim = dim_out
        self.head_name = head_name
        self.train()

    def forward(self, inputs):
        """Encode ``inputs`` and return the resulting feature tensor."""
        return self.features(inputs)


class A3C_Single(torch.nn.Module):  # single vision Tracking
    """Single-agent actor-critic: attention encoder, value head and policy head.

    Args:
        obs_space: sequence of observation spaces; ``obs_space[0].shape[1]``
            is used as the per-entity observation width.
        action_spaces: sequence of action spaces; only the first is used.
        args: namespace providing ``lstm_out`` (feature width) and ``model``
            (head name).
        device: torch device handed to the sub-modules.
    """

    def __init__(self, obs_space, action_spaces, args, device=torch.device('cpu')):
        super(A3C_Single, self).__init__()
        self.n = len(obs_space)
        obs_dim = obs_space[0].shape[1]

        lstm_out = args.lstm_out
        head_name = args.model

        self.head_name = head_name

        self.encoder = AttentionLayer(obs_dim, lstm_out, device)
        self.critic = ValueNet(lstm_out, head_name, 1)
        self.actor = PolicyNet(lstm_out, action_spaces[0], head_name, device)

        self.train()
        self.device = device

    def forward(self, inputs, test=False):
        """Run encoder, actor and critic on ``inputs``.

        Returns:
            ``(values, actions, entropies, log_probs)``.
        """
        data = Variable(inputs, requires_grad=True)
        # the encoder returns a pair; only its second element is used here
        _, feature = self.encoder(data)

        # PolicyNet.forward returns four values; the probabilities are not
        # part of this model's return value, so they are dropped here.
        actions, entropies, log_probs, _ = self.actor(feature, test)
        values = self.critic(feature)

        return values, actions, entropies, log_probs

    def sample_noise(self):
        """Resample the noise of both heads (previously the actor twice)."""
        self.actor.sample_noise()
        self.critic.sample_noise()

    def remove_noise(self):
        """Remove the noise of both heads (previously the actor twice)."""
        self.actor.remove_noise()
        self.critic.remove_noise()


class GoalLayer(nn.Module):
    """Small MLP ending in a 2-way softmax over the last axis.

    Attributes:
        feature_dim: expected input width (``in_dim``).
        net: ``Linear(in_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, 2) -> Softmax``.
        device: stored as given.
    """

    def __init__(self,in_dim, hidden_dim=32, device=torch.device('cpu')):
        super(GoalLayer, self).__init__()
        self.feature_dim = in_dim

        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
            nn.Softmax(dim=-1),
        ]
        self.net = nn.Sequential(*layers)

        self.train()
        self.device = device

    def forward(self,inputs):
        """Return the 2-way probabilities (goal assignment) for ``inputs``."""
        return self.net(inputs)


class ToM(nn.Module):
    """Theory-of-mind module: a recurrent model of every other agent plus a goal head.

    For every ordered (agent, other agent) pair the recurrent net consumes the
    other agent's camera state concatenated with a message feature.  Its output
    is combined with the observing agent's own recurrent output and with each
    per-target attention feature, and the goal head scores each target.
    """

    def __init__(self,args,attention_feature_dim,cam_state_dim=3,device=torch.device('cpu')):
        super(ToM, self).__init__()
        hidden = args.lstm_out

        self.cam_state_size = cam_state_dim      # location [0,0] + rotation[0]
        self.msg_size = attention_feature_dim    # previous msg transformed by self attention
        self.GRU_feature_dim = hidden
        self.RNN_feature_dim = hidden

        rnn_input_dim = self.cam_state_size + self.msg_size
        self.RNN = RNN(rnn_input_dim, self.RNN_feature_dim, 1, device, 'GRU')

        goal_input_dim = self.GRU_feature_dim + attention_feature_dim + self.RNN.feature_dim
        self.other_goal = GoalLayer(goal_input_dim, device=device)

        self.train()
        self.device = device

    def forward(self,ToM_inputs,GRU_output,attention_features,h_others):
        """Predict the goals of the other agents.

        Args:
            ToM_inputs: ``[n, n-1, cam_dim + msg_dim]`` per-pair inputs.
            GRU_output: own recurrent output, reshaped here to ``[n, 1, 1, -1]``.
            attention_features: ``[n, m, att_dim]`` per-target features.
            h_others: hidden states, reshaped here to ``[1, n*(n-1), hidden]``.

        Returns:
            ``(other_goals, hn_others)``: the goal-head output over the
            ``[n, n-1, m]`` grid and the new hidden states as ``[n, n-1, -1]``.
        """
        n_agents = ToM_inputs.size()[0]
        n_others = n_agents - 1
        n_targets = attention_features.size()[1]
        att_dim = attention_features.size()[2]
        gru_dim = GRU_output.size()[-1]
        rnn_dim = self.RNN_feature_dim

        # one length-1 sequence per ordered agent pair
        rnn_in = ToM_inputs.reshape(n_agents * n_others, 1, -1)
        rnn_h0 = h_others.reshape(1, -1, rnn_dim)
        rnn_out, rnn_hn = self.RNN(rnn_in, rnn_h0)
        hn_others = rnn_hn.reshape(n_agents, n_others, -1)

        # broadcast the three feature sources onto the [n, n-1, m] grid
        grid = (n_agents, n_others, n_targets)
        own_part = GRU_output.reshape(n_agents, 1, 1, -1).expand(grid + (gru_dim,))
        other_part = rnn_out.reshape(n_agents, n_others, 1, -1).expand(grid + (rnn_dim,))
        target_part = attention_features.unsqueeze(1).expand(grid + (att_dim,))

        stacked = torch.cat((own_part, other_part, target_part), 3)
        other_goals = self.other_goal(stacked)

        return other_goals, hn_others


class PropNet(nn.Module):
    """Propagation network over a fully connected graph of ``N`` nodes.

    Nodes and ordered node pairs (edges) are encoded, then ``pstep`` rounds of
    message passing alternately update edge effects (from both endpoint
    effects and the previous edge effect) and node effects (from the node
    encoding, the previous node effect and the summed edge effects).  An edge
    predictor maps the final edge effect, concatenated with the edge encoding,
    to ``edge_dim_out`` values.

    NOTE: the node predictor is intentionally not built (see the comment in
    ``__init__``), so ``forward`` only works with ``ignore_node=True``; any
    other call raises ``AttributeError`` on ``self.node_predictor``.

    Args:
        node_dim_in: width of each node feature.
        edge_dim_in: width of each optional edge feature.
        hidden_dim: width of all hidden representations.
        node_dim_out: stored; unused while the node predictor is disabled.
        edge_dim_out: width of the edge prediction.
        batch_norm: when 1, a ``BatchNorm1d`` is appended to each sub-network.
        pstep: number of propagation rounds.
    """

    def __init__(self, node_dim_in, edge_dim_in, hidden_dim, node_dim_out, edge_dim_out, batch_norm=0, pstep=2):

        super(PropNet, self).__init__()

        self.node_dim_in = node_dim_in
        self.edge_dim_in = edge_dim_in
        self.hidden_dim = hidden_dim

        self.node_dim_out = node_dim_out
        self.edge_dim_out = edge_dim_out

        self.pstep = pstep

        # node encoder
        modules = [
            nn.Linear(node_dim_in, hidden_dim),
            nn.ReLU()]
        if batch_norm == 1:
            modules.append(nn.BatchNorm1d(hidden_dim))
        self.node_encoder = nn.Sequential(*modules)

        # edge encoder
        modules = [
            nn.Linear(node_dim_in * 2 + edge_dim_in, hidden_dim),
            nn.ReLU()]
        if batch_norm == 1:
            modules.append(nn.BatchNorm1d(hidden_dim))
        self.edge_encoder = nn.Sequential(*modules)

        # node propagator
        modules = [
            # input: node_enc, node_rep, edge_agg
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()]
        if batch_norm == 1:
            modules.append(nn.BatchNorm1d(hidden_dim))
        self.node_propagator = nn.Sequential(*modules)

        # edge propagator: one separate network per propagation round
        self.edge_propagators = nn.ModuleList()
        for i in range(pstep):
            modules = [
                # input: node_rep * 2,  edge_rep
                nn.Linear(hidden_dim * 3, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU()]
            if batch_norm == 1:
                modules.append(nn.BatchNorm1d(hidden_dim))
            edge_propagator = nn.Sequential(*modules)
            self.edge_propagators.append(edge_propagator)

        # commented due to None grad problem
        # node predictor
        # modules = [
        #     nn.Linear(hidden_dim * 2, hidden_dim),
        #     nn.ReLU()]
        # if batch_norm == 1:
        #     modules.append(nn.BatchNorm1d(hidden_dim))
        # modules.append(nn.Linear(hidden_dim, node_dim_out))
        # self.node_predictor = nn.Sequential(*modules)

        # edge predictor
        modules = [
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU()]
        if batch_norm == 1:
            modules.append(nn.BatchNorm1d(hidden_dim))
        modules.append(nn.Linear(hidden_dim, edge_dim_out))
        self.edge_predictor = nn.Sequential(*modules)

    def forward(self, node_rep, edge_rep=None, edge_type=None,
                ignore_node=False, ignore_edge=False):
        """Propagate over the graph and predict edge (and node) outputs.

        Args:
            node_rep: ``[B, N, node_dim_in]`` node features.
            edge_rep: optional edge features concatenated to each node pair.
            edge_type: optional edge weights viewable as
                ``[B, N, N, edge_type_num]``; only the last channel is used,
                as a multiplicative mask on the edge representations.
            ignore_node: return only the edge prediction ``[B, N, N, -1]``.
            ignore_edge: return only the node prediction (needs the disabled
                node predictor).

        Returns:
            The edge prediction when ``ignore_node`` is set; otherwise the node
            prediction or the ``(node_pred, edge_pred)`` pair.
        """
        # node_rep: batch x N x node_dim_in
        # edge_rep: N x N x edge_dim_in
        # edge_type: N x N x edge_type_num
        B, N, _ = node_rep.size()

        # node_enc
        node_enc = self.node_encoder(node_rep.view(-1, self.node_dim_in)).view(B, N, self.hidden_dim)
        # edge_enc: every ordered (receiver, sender) pair of node features
        node_rep_r = node_rep[:, :, None, :].repeat(1, 1, N, 1)
        node_rep_s = node_rep[:, None, :, :].repeat(1, N, 1, 1)
        if edge_rep is not None:
            tmp = torch.cat([node_rep_r, node_rep_s, edge_rep], -1)
        else:
            tmp = torch.cat([node_rep_r, node_rep_s], -1)

        edge_enc = self.edge_encoder(tmp.view(B * N * N, -1)).view(B, N, N, self.hidden_dim)

        edge_mask = None
        if edge_type is not None:
            # The number of edge types is inferred from edge_type itself; this
            # class never defines an `edge_type_num` attribute, so the former
            # `self.edge_type_num` lookup raised AttributeError.
            edge_mask = edge_type.view(B, N, N, -1, 1)[:, :, :, -1, :]
            edge_enc = edge_enc * edge_mask

        # start from the encodings (also keeps both names bound when pstep == 0)
        node_effect = node_enc
        edge_effect = edge_enc
        for i in range(self.pstep):
            # calculate edge_effect
            node_effect_r = node_effect[:, :, None, :].repeat(1, 1, N, 1)
            node_effect_s = node_effect[:, None, :, :].repeat(1, N, 1, 1)
            tmp = torch.cat([node_effect_r, node_effect_s, edge_effect], -1)

            edge_effect = self.edge_propagators[i](tmp.view(B * N * N, -1)).view(B, N, N, -1)
            # edge effect: N * N * hidden_dim

            if edge_mask is not None:
                edge_effect = edge_effect * edge_mask

            # calculate node_effect
            edge_effect_agg = edge_effect.sum(2)
            tmp = torch.cat([node_enc, node_effect, edge_effect_agg], -1)
            node_effect = self.node_propagator(tmp.view(B * N, -1)).view(B, N, self.hidden_dim)

        node_effect = torch.cat([node_effect, node_enc], -1).view(B * N, -1)
        edge_effect = torch.cat([edge_effect, edge_enc], -1).view(B * N * N, -1)

        # node_pred: B x N x node_dim_out
        # edge_pred: B x N x N x edge_dim_out
        if ignore_node:
            edge_pred = self.edge_predictor(edge_effect)
            return edge_pred.view(B, N, N, -1)
        if ignore_edge:
            node_pred = self.node_predictor(node_effect)
            return node_pred.view(B, N, -1)

        node_pred = self.node_predictor(node_effect).view(B, N, -1)
        edge_pred = self.edge_predictor(edge_effect).view(B, N, N, -1)
        return node_pred, edge_pred


class Graph_Infer(nn.Module):
    """Infer a directed edge mask between nodes with a PropNet and Gumbel-softmax.

    There are two edge types: 0 stands for a null edge and 1 for a real edge.
    Self-loops are masked out while ``propnet_selfloop`` is False.
    """

    def __init__(self, dim_in, device=torch.device('cpu')):
        super(Graph_Infer, self).__init__()
        self.propnet_selfloop = False
        self.device = device
        # edge type: 2 (0 = null edge, 1 = real edge)
        self.edge_type_num = 2

        self.graph_infer = PropNet(
            node_dim_in=dim_in,
            edge_dim_in=0,
            hidden_dim=dim_in * 3,
            node_dim_out=0,
            edge_dim_out=self.edge_type_num,
            batch_norm=0,
            pstep=2)

    def forward(self, node_features, hard=True):
        """Sample the edge mask for ``node_features`` (``[B, N, dim]``).

        Returns:
            ``(edge_type_logits, edge_type)``: the raw per-edge logits and the
            sampled last-channel ("real edge") value as ``[B, N, N, 1]``.
        """
        batch, n_nodes, _ = node_features.size()
        n_edges = batch * n_nodes * n_nodes
        n_types = self.edge_type_num

        edge_type_logits = self.graph_infer(node_features, ignore_node=True)

        flat_logits = edge_type_logits.view(n_edges, n_types)
        sampled = F.gumbel_softmax(flat_logits, tau=0.5, hard=hard)
        # keep only the channel of the "real edge" type
        real_edge = sampled.view(batch, n_nodes, n_nodes, n_types)[:, :, :, -1]
        edge_type = real_edge.view(batch, n_nodes, n_nodes, 1)

        if self.propnet_selfloop == False:
            # ones everywhere except on the diagonal
            off_diagonal = torch.ones(n_nodes, n_nodes) - torch.eye(n_nodes)
            off_diagonal = off_diagonal.view(1, n_nodes, n_nodes, 1).to(self.device)
            edge_type = edge_type * off_diagonal

        return edge_type_logits, edge_type


class A3C_Multi(torch.nn.Module):
    """Multi-agent actor-critic with communication and a ToM module.

    Every agent encodes its per-target observations, pools them with
    attention, feeds the pooled feature through a GRU and lets the ToM module
    predict which targets the other agents will take.  The actor then takes a
    binary decision for every (agent, target) pair.

    Args:
        obs_space: space whose ``shape`` is ``(num_agents, num_targets, pose_dim)``.
        action_spaces: accepted for interface compatibility; not used.
        args: namespace providing ``lstm_out`` and ``model`` (head name).
        device: torch device handed to the sub-modules.
    """

    def __init__(self, obs_space, action_spaces, args, device=torch.device('cpu')):
        super(A3C_Multi, self).__init__()
        self.num_agents, self.num_targets, self.pose_dim = obs_space.shape

        lstm_out = args.lstm_out
        head_name = args.model
        self.head_name = head_name

        self.encoder = EncodeLinear(self.pose_dim, lstm_out, head_name, device)
        feature_dim = self.encoder.feature_dim

        self.attention = AttentionLayer(feature_dim, lstm_out, device)
        feature_dim = self.attention.feature_dim

        self.GRU=RNN(feature_dim,lstm_out,1,device,'GRU')
        feature_dim=self.GRU.feature_dim + self.attention.feature_dim + 1 # GRU_output + attention_feature + max_prob

        # create actor & critic
        self.actor = PolicyNet(feature_dim, spaces.Discrete(2), head_name, device)

        if 'shap' in head_name:
            self.ShapleyVcritic = AMCValueNet(feature_dim, head_name, 1, device)
        else:
            self.critic = ValueNet(feature_dim, head_name, 1)

        # create ToM
        self.ToM=ToM(args,self.attention.feature_dim,device=device)

        self.train()
        self.device = device

    def forward(self, multi_obs, self_hiddens, other_hiddens, cameras, ToM_msgs, test=False):
        """Run one step for all agents.

        Args:
            multi_obs: ``[n, m, obs_dim]`` observation tensor.
            self_hiddens: ``[n, hidden]`` GRU hidden states.
            other_hiddens: hidden states of the ToM recurrent net.
            cameras: NumPy array ``[n, cam_dim]`` of camera states.
            ToM_msgs: ``[n, n-1, msg_dim]`` messages from the previous step.
            test: forwarded to the actor (greedy instead of sampled actions).

        Returns:
            ``(values, actions, entropies, log_probs, hn_self, hn_others,
            next_ToM_msgs, real_msgs, other_goals)``.
        """
        num_agents=self.num_agents
        num_targets=self.num_targets

        pos_obs = multi_obs[:, :self.num_targets].clone()
        obs_dim = pos_obs.size()[2] # [n,m,obs_dim]

        feature_target = Variable(multi_obs, requires_grad=True)
        feature_target = self.encoder(feature_target)
        att_features, global_feature = self.attention(feature_target)
        #features: [num_agents, num_targets, feature_dim]

        h_self=self_hiddens.unsqueeze(0) # [1, num_agents, hidden_size]
        global_features=global_feature.unsqueeze(1) # [num_agents, 1, feature_dim]
        GRU_outputs,hn_self_tmp=self.GRU(global_features,h_self)
        hn_self=hn_self_tmp.squeeze(0)

        # ToM_feature concat
        cam_states = torch.from_numpy(cameras).float()

        cam_dim=cam_states.size()[1] # cam_dim=3

        cam_states=cam_states.unsqueeze(0).expand(num_agents,num_agents,cam_dim)
        # off-diagonal mask: every agent looks at all agents except itself
        idx=(torch.ones(num_agents,num_agents)-torch.diag(torch.ones(num_agents))).bool()
        camera_states=cam_states[idx].reshape(num_agents,-1,cam_dim) # [n,n-1,3]
        ToM_feature=torch.cat((camera_states,ToM_msgs),2) #[n,n-1,3+msg_dim]

        # ToM prediction
        other_goals,hn_others=self.ToM(ToM_feature,GRU_outputs,att_features,other_hiddens)

        # NOTE(review): the squeeze below and the "+ 1" feature width in
        # __init__ assume a goal tensor whose last axis has size 1, while
        # GoalLayer ends in Linear(hidden_dim, 2) + Softmax.  Confirm which
        # channel is meant before relying on this model.
        masked_idx=(other_goals<0.5).squeeze(-1) # not assigned [n, n-1, m]

        # compute ToM_msg & real_msg
        # expand() only creates an aliased view (all n-1 copies share memory),
        # so it must be cloned before the masked write below.
        real_msgs=(pos_obs.unsqueeze(1)).expand(num_agents,num_agents-1,num_targets,obs_dim).clone()
        real_msgs[masked_idx]=0
        next_ToM_msgs=real_msgs.reshape(-1,num_targets,obs_dim)
        next_ToM_msgs = self.encoder(next_ToM_msgs)
        _,next_ToM_msgs_tmp=self.attention(next_ToM_msgs)
        next_ToM_msgs=next_ToM_msgs_tmp.reshape(num_agents,num_agents-1,-1)

        # decide self goals
        max_prob,_=torch.max(other_goals,1) # max over the other agents
        GRU_outputs=GRU_outputs.expand(num_agents,num_targets,self.GRU.feature_dim)
        actor_feature=torch.cat((att_features,GRU_outputs,max_prob),2) #[n,m,dim]
        critic_dim=actor_feature.size()[-1]
        critic_feature=torch.sum(actor_feature,1) #[n,dim]
        actor_features=actor_feature.reshape(-1,critic_dim) #[n*m,dim]

        # PolicyNet.forward returns four values; the probabilities are not
        # part of this model's return value, so they are dropped here.
        actions, entropies, log_probs, _ = self.actor(actor_features, test)
        actions = actions.reshape(num_agents,num_targets,-1)
        if 'shap' not in self.head_name:
            values = self.critic(critic_feature)
        else:
            values = self.ShapleyVcritic(critic_feature, actions)  # shape [1,1]

        return values, actions, entropies, log_probs, hn_self, hn_others, next_ToM_msgs, real_msgs, other_goals

    def _value_head(self):
        """Return whichever value head __init__ created for this head name."""
        if 'shap' in self.head_name:
            return self.ShapleyVcritic
        return self.critic

    def sample_noise(self):
        """Resample the noise of actor and value head (previously the actor twice)."""
        self.actor.sample_noise()
        self._value_head().sample_noise()

    def remove_noise(self):
        """Remove the noise of actor and value head (previously the actor twice)."""
        self.actor.remove_noise()
        self._value_head().remove_noise()


class A3C_Multi_v2(torch.nn.Module):
    """Multi-agent actor-critic with ToM and no communication.

    Observations are partial.  Every agent encodes its observed entities
    (targets first, then obstacles), pools them with attention and feeds the
    pooled feature through a GRU.  A second GRU (``self.ToM``) models every
    other camera from its relative pose; from it the network predicts which
    targets the other cameras cover (``ToM_target``) and which goals they will
    choose (``other_goal``).  The actor takes a binary decision per
    (agent, target) pair and a shared critic scores the joint state.

    ``num_targets`` is NOT set in ``__init__``; it must be assigned on the
    instance before ``forward`` is called (``build_model`` does this).

    Args:
        obs_space: space whose ``shape`` is ``(num_agents, num_entity, pose_dim)``.
        action_spaces: accepted for interface compatibility; not used.
        args: namespace providing ``lstm_out`` and ``model`` (head name).
        device: torch device handed to the sub-modules.
    """

    # partial obs + ToM + no communication
    def __init__(self, obs_space, action_spaces, args, device=torch.device('cpu')):
        super(A3C_Multi_v2, self).__init__()

        self.num_agents, num_entity, self.pose_dim = obs_space.shape  # num_entity = target + obstacles

        lstm_out = args.lstm_out
        head_name = args.model
        self.head_name = head_name

        self.encoder = EncodeLinear(self.pose_dim, lstm_out, head_name, device)
        feature_dim = self.encoder.feature_dim

        self.attention = AttentionLayer(feature_dim, lstm_out * 2, device)
        feature_dim = self.attention.feature_dim

        self.GRU=RNN(feature_dim,lstm_out,1,device,'GRU')

        # create ToM, including camera ToM & target ToM
        cam_state_size = 4  # [x, y, cos, sin]

        self.ToM_target = nn.Sequential(
            nn.Linear(lstm_out + self.attention.feature_dim*2, lstm_out),
            nn.ReLU(),
            nn.Linear(lstm_out,1),
            nn.Sigmoid()
        )

        self.ToM=RNN(cam_state_size, lstm_out, 1, device, 'GRU')
        self.other_goal=GoalLayer(lstm_out + self.attention.feature_dim + self.GRU.feature_dim, device=device)

        feature_dim=self.GRU.feature_dim + self.attention.feature_dim # GRU_output + attention_feature

        self.reduce_dim = nn.Sequential(
            nn.Linear(feature_dim, lstm_out*2),
            nn.ReLU(),
        )
        feature_dim = lstm_out*2 + 1  # reduced feature + max_prob
        # create actor & critic
        self.actor = PolicyNet(feature_dim, spaces.Discrete(2), head_name, device)

        if 'shap' in head_name:
            self.ShapleyVcritic = AMCValueNet(feature_dim, head_name, 1, device)
        else:
            # centralized training
            self.critic_encoder = AttentionLayer(feature_dim, 6 * feature_dim, device)
            self.critic = ValueNet(6 * feature_dim, head_name, 1)
        self.train()
        self.device = device

    def forward(self, multi_obs, self_hiddens, ToM_hiddens, cameras, test=False):
        """Run one step for all agents, for a single sample or a batch.

        A 4-D ``multi_obs`` (``[batch, n, entities, obs_dim]``) is treated as
        a batch; anything else is treated as a single sample, wrapped into a
        batch of one and unwrapped again before returning.

        Args:
            multi_obs: observations; entities are targets followed by obstacles.
            self_hiddens: hidden states of the agents' own GRU.
            ToM_hiddens: hidden states of the ToM GRU.
            cameras: camera states tensor ``[(batch,) n, cam_dim]``; the last
                component is treated as a rotation in degrees.  The tensor is
                no longer modified in place.
            test: forwarded to the actor (greedy instead of sampled actions).

        Returns:
            ``(values, actions, entropies, log_probs, hn_self, hn_ToM,
            other_goals, ToM_target_cover, real_others_cover, probs)``.
        """
        num_agents=self.num_agents
        num_targets=self.num_targets

        batch = len(multi_obs.size()) == 4
        if not batch:
            # not batch data, so turn them into a size-1 batch
            multi_obs = multi_obs.unsqueeze(0)
            self_hiddens = self_hiddens.unsqueeze(0)
            ToM_hiddens = ToM_hiddens.unsqueeze(0)
            cameras = cameras.unsqueeze(0)

        batch_size = multi_obs.size()[0]
        num_both = multi_obs.size()[2]  # targets + obstacles

        # off-diagonal mask: every agent looks at all agents except itself
        idx = (torch.ones(num_agents, num_agents) - torch.diag(torch.ones(num_agents))).bool()
        # a target counts as covered when the last observation component is non-zero
        real_target_cover = (multi_obs[:,:,:self.num_targets,-1] != 0).reshape(batch_size, num_agents, num_targets, 1).detach()
        real_others_cover = real_target_cover.unsqueeze(1).repeat(1, num_agents, 1, 1, 1)
        real_others_cover = real_others_cover[:,idx].reshape(batch_size, num_agents, num_agents-1, num_targets, 1)

        feature_target = Variable(multi_obs, requires_grad=True)

        feature_target = self.encoder(feature_target)
        feature_target = feature_target.reshape(batch_size * num_agents, num_both, -1)

        att_features, global_feature = self.attention(feature_target)
        att_features = att_features.reshape(batch_size, num_agents, num_both, -1)
        global_feature = global_feature.reshape(batch_size, num_agents, -1)

        # split features into targets & obstacles
        if num_both > num_targets:
            obstacle_features = torch.mean(att_features[:,:,self.num_targets:], 2)
        else:
            # No obstacles observed: use a zero feature on the same device and
            # dtype as the attention features.  (This branch used to reference
            # an undefined global `device_share` and raised NameError.)
            obstacle_features = att_features.new_zeros(batch_size, num_agents, att_features.size()[-1])
        att_features = att_features[:,:,:self.num_targets]

        h_self = self_hiddens.reshape(1, num_agents * batch_size, -1) # [1, num_agents, hidden_size]
        global_features = global_feature.reshape(num_agents * batch_size, 1, -1)
        GRU_outputs,hn_self=self.GRU(global_features,h_self)
        hn_self=hn_self.reshape(batch_size, num_agents, -1)

        # ToM_input
        # Work on a copy so the caller's `cameras` tensor is not modified.
        cam_states = cameras.clone()  #[batch, num_agents, cam_dim]
        # Because there is no communication in this version, cameras do not have access to other rotations
        cam_states[:,:,-1] = 0
        cam_dim = cam_states.size()[-1] # cam_dim=3

        # transform to relative cam_states
        cam_states_duplicate = cam_states.unsqueeze(1).expand(batch_size, num_agents, num_agents, cam_dim)
        cam_states_duplicate = cam_states_duplicate[:,idx].reshape(batch_size,num_agents,num_agents-1,cam_dim) # [n,n-1,3]
        cam_states_relative = cam_states_duplicate - cam_states.unsqueeze(2).expand(batch_size,num_agents, num_agents-1, cam_dim)
        # degrees -> radians
        cam_state_theta = ((cam_states_relative[:,:,:,-1]/180) * np.pi).reshape(batch_size, num_agents, num_agents-1, 1)
        camera_states = torch.cat((cam_states_relative[:,:,:,:2], torch.cos(cam_state_theta), torch.sin(cam_state_theta)),-1)
        camera_states = camera_states.reshape(batch_size*num_agents*(num_agents-1),1,-1) # [n*n-1,1,4]

        h_ToM = ToM_hiddens.reshape(1,-1,self.ToM.feature_dim) #[1,n*(n-1),dim]

        # ToM_camera prediction
        ToM_output,hn_ToM=self.ToM(camera_states,h_ToM)

        # GoalLayer input concat
        hn_ToM = hn_ToM.reshape(batch_size,num_agents,num_agents-1,-1)
        ToM_output = ToM_output.reshape(batch_size,num_agents, num_agents-1, 1, -1)

        GRU_dim=GRU_outputs.size()[-1]
        ToM_dim=ToM_output.size()[-1]
        att_dim=att_features.size()[-1]

        GRU_output_expand = GRU_outputs.reshape(batch_size,num_agents,1,1,-1)
        GRU_output_expand = GRU_output_expand.expand(batch_size,num_agents,num_agents-1,num_targets,GRU_dim)
        # expand with the ToM width (previously GRU_dim, which only worked
        # because both recurrent nets share the same hidden size)
        ToM_output_expand = ToM_output.expand(batch_size,num_agents,num_agents-1,num_targets,ToM_dim)
        att_feature_expand = att_features.unsqueeze(2).expand(batch_size,num_agents,num_agents-1,num_targets,att_dim)

        # ToM target
        obstacle_features = obstacle_features.reshape(batch_size, num_agents, 1, 1, -1).repeat(1, 1, num_agents - 1, num_targets, 1)
        ToM_target_feature = torch.cat((att_feature_expand.detach(), obstacle_features.detach(), ToM_output_expand),-1)

        ToM_target_cover = self.ToM_target(ToM_target_feature)

        # other goals
        goal_feature = torch.cat((GRU_output_expand, ToM_output_expand, att_feature_expand),-1)
        other_goals=self.other_goal(goal_feature) # [n, n-1, m, 2]

        # decide self goals: highest "assigned" probability over the other agents
        max_prob,_=torch.max(other_goals[:,:,:,:,1],2)
        max_prob = max_prob.reshape(batch_size,num_agents, num_targets, 1).detach() #[n,m,1]
        GRU_outputs=GRU_outputs.reshape(batch_size, num_agents, 1, -1).expand(batch_size, num_agents, num_targets, self.GRU.feature_dim)

        actor_feature=torch.cat((att_features,GRU_outputs),-1) #[n,m,dim]
        actor_feature = self.reduce_dim(actor_feature)
        actor_feature = torch.cat((actor_feature,max_prob),-1)
        actor_dim = actor_feature.size()[-1]
        critic_feature = torch.sum(actor_feature,2)

        actor_feature = actor_feature.reshape(-1,actor_dim) #[n*m,dim]

        actions, entropies, log_probs, probs = self.actor(actor_feature, test)

        probs = probs.reshape(batch_size, num_agents, num_targets, -1)
        actions = actions.reshape(batch_size, num_agents, num_targets, -1)
        entropies = entropies.reshape(batch_size, num_agents, num_targets, -1)
        log_probs = log_probs.reshape(batch_size,num_agents, num_targets, -1)

        if 'shap' not in self.head_name:
            _, global_critic_feature = self.critic_encoder(critic_feature)
            # one shared value, repeated for every agent
            values = self.critic(global_critic_feature).unsqueeze(1).repeat(1, num_agents, 1)
        else:
            values = self.ShapleyVcritic(critic_feature, actions)  # shape [1,1]

        if not batch:
            # squeeze all the tensor for env
            values = values.squeeze(0)
            actions = actions.squeeze(0)
            entropies = entropies.squeeze(0)
            log_probs = log_probs.squeeze(0)
            hn_self = hn_self.squeeze(0)
            hn_ToM = hn_ToM.squeeze(0)
            other_goals = other_goals.squeeze(0)
            probs = probs.squeeze(0)
            ToM_target_cover = ToM_target_cover.squeeze(0)
            # NOTE(review): `real_others_cover` is returned WITH its leading
            # batch axis even for a single sample.  The original code squeezed
            # it into an unused local (`real_cover`) and returned the
            # unsqueezed tensor; the returned shape is kept as it was so that
            # existing callers are unaffected -- confirm what they expect.

        return values, actions, entropies, log_probs, hn_self, hn_ToM, other_goals, ToM_target_cover, real_others_cover, probs

    def _value_head(self):
        """Return whichever value head __init__ created for this head name."""
        if 'shap' in self.head_name:
            return self.ShapleyVcritic
        return self.critic

    def sample_noise(self):
        """Resample the noise of actor and value head (previously the actor twice)."""
        self.actor.sample_noise()
        self._value_head().sample_noise()

    def remove_noise(self):
        """Remove the noise of actor and value head (previously the actor twice)."""
        self.actor.remove_noise()
        self._value_head().remove_noise()


class A3C_Multi_v3(torch.nn.Module):
    """Actor-critic model: partial observation + communication, no ToM.

    ``forward`` runs: entity encoder -> attention -> per-agent GRU ->
    ``Graph_Infer`` communication edges -> ``reduce_dim`` -> ``PolicyNet``
    actor built on ``spaces.Discrete(2)``. The critic is either
    ``AttentionLayer`` + ``ValueNet`` or, when ``'shap'`` occurs in
    ``args.model``, an ``AMCValueNet``.
    """
    def __init__(self, obs_space, action_spaces, args, device=torch.device('cpu')):
        """Create all sub-modules.

        Args:
            obs_space: object whose ``shape`` unpacks into
                ``(num_agents, num_targets, pose_dim)``.
            action_spaces: not read by this constructor.
            args: provides ``lstm_out`` (layer width) and ``model`` (name string,
                kept as ``self.head_name``).
            device: handed to the sub-modules and stored as ``self.device``.
        """
        super(A3C_Multi_v3, self).__init__()
        self.num_agents, self.num_targets, self.pose_dim = obs_space.shape
        
        lstm_out = args.lstm_out
        head_name = args.model
        self.head_name = head_name

        self.encoder = EncodeLinear(self.pose_dim, lstm_out, head_name, device)
        feature_dim = self.encoder.feature_dim

        self.attention = AttentionLayer(feature_dim, lstm_out * 2, device)
        feature_dim = self.attention.feature_dim

        self.GRU=RNN(feature_dim,lstm_out,1,device,'GRU')
        
        # NOTE(review): this class builds no ToM module; cam_state_size is never
        # read again in this constructor.
        cam_state_size = 4  # [x, y, cos, sin]

        # GNN based communication inference
        self.graph_infer = Graph_Infer(self.attention.feature_dim, device)

        # Input width of reduce_dim; forward concatenates GRU output,
        # attention features and msgs (there is no max_prob term in this class).
        feature_dim=self.GRU.feature_dim + self.attention.feature_dim*2  # GRU_output + attention_feature + msgs
        self.reduce_dim = nn.Sequential(
            nn.Linear(feature_dim, lstm_out * 1),
            nn.ReLU(),
        )
        feature_dim = lstm_out * 1
        # create actor & critic
        self.actor = PolicyNet(feature_dim, spaces.Discrete(2), head_name, device)

        if 'shap' in head_name:
            self.ShapleyVcritic = AMCValueNet(feature_dim, head_name, 1, device)
        else:
            # centralized training
            self.critic_encoder = AttentionLayer(feature_dim, 6 * feature_dim, device)
            self.critic = ValueNet(6 * feature_dim, head_name, 1)
        self.train()
        self.device = device

    def forward(self, multi_obs, self_hiddens, test=False):
        """Compute values and per-(agent, target) actions with communication.

        Args:
            multi_obs: observation tensor. A 4-D input is treated as batched
                (existing comment: [batch, n, m, obs_dim]); anything else gets a
                leading batch axis of size 1 added.
            self_hiddens: GRU hidden state; reshaped to
                (1, num_agents * batch, -1) before the GRU call.
            test: forwarded unchanged to ``self.actor``.

        Returns:
            Tuple ``(values, actions, entropies, log_probs, hn_self,
            edge_logits, comm_edges, probs)``. ``edge_logits`` is returned after
            reshape to (-1, 2) and softmax. When the input was not batched, the
            leading axis is squeezed from every output except ``edge_logits``.
        """
        num_agents=self.num_agents
        num_targets=self.num_targets  
        
        if len(multi_obs.size()) != 4:
            # not batch data, so turn them into a size-1 batch
            batch = False
            multi_obs = multi_obs.unsqueeze(0)
            self_hiddens = self_hiddens.unsqueeze(0)
        else:
            batch = True

        batch_size = multi_obs.size()[0]
        obs_dim = multi_obs.size()[-1] # [batch,n,m,obs_dim]; obs_dim itself is not used below
        num_both = multi_obs.size()[2]  # entity count on axis 2 of the observation
        
        feature_target = Variable(multi_obs, requires_grad=True)

        feature_target = self.encoder(feature_target)  # [num_agents, num_targets, feature_dim]
        # merge batch and agent axes so attention runs once per (batch, agent)
        feature_target = feature_target.reshape(batch_size * num_agents, num_both, -1)
        att_features, global_feature = self.attention(feature_target)  
        #features: [num_agents, num_targets, feature_dim]
        att_features = att_features.reshape(batch_size, num_agents, num_both, -1)
        global_feature = global_feature.reshape(batch_size, num_agents, -1)
        # keep only the first num_targets entities
        att_features = att_features[:,:,:self.num_targets]

        h_self = self_hiddens.reshape(1, num_agents * batch_size, -1) # [1, num_agents, hidden_size]
        global_features = global_feature.reshape(num_agents * batch_size, 1, -1) # [num_agents, 1, feature_dim]
        
        GRU_outputs,hn_self=self.GRU(global_features,h_self)
        #GRU_outputs=GRU_outputs.squeeze(1) 
        hn_self=hn_self.reshape(batch_size, num_agents, -1)

        ############ GNN based communication ################
        edge_logits, comm_edges = self.graph_infer(global_feature.reshape(batch_size,num_agents,-1))
        edge_logits = edge_logits.reshape(-1,2)
        # from here on edge_logits holds softmax probabilities over the 2 classes
        edge_logits = F.softmax(edge_logits, dim=-1)
        comm_edges = comm_edges.reshape(batch_size, num_agents, num_agents, 1)

        # selected msg
        # comm_msg[b, i, j] is agent j's global feature (the agent axis lands on
        # dim 2, then is repeated along dim 1), gated by comm_edges[b, i, j].
        comm_msg = global_features.reshape(batch_size, 1, num_agents, -1).repeat(1,num_agents,1,1)
        comm_msg = comm_msg * comm_edges
        # NOTE(review): the mean over dim 1 gives, for agent j, its own feature
        # scaled by the mean of comm_edges[b, :, j] -- confirm this is the
        # intended aggregation. The result is tiled across the target axis.
        msgs = torch.mean(comm_msg, 1).reshape(batch_size, num_agents, 1, -1).repeat(1,1, num_targets, 1)
        #msgs = self.reduce_dim(msgs)
        ######### end of GNN based communication ###############

        # decide self goals
        # broadcast each agent's GRU output over its targets
        GRU_outputs = GRU_outputs.reshape(batch_size, num_agents, 1, -1).expand(batch_size, num_agents, num_targets, self.GRU.feature_dim)
        actor_feature=torch.cat((GRU_outputs,att_features,msgs), -1) #[n,m,dim]
        actor_feature = self.reduce_dim(actor_feature)

        actor_dim = actor_feature.size()[-1]

        # critic input: actor features summed over the target axis
        critic_feature = torch.sum(actor_feature,2)
        actor_feature = actor_feature.reshape(-1,actor_dim) #[n*m,dim]

        actions, entropies, log_probs, probs = self.actor(actor_feature, test)

        # restore the (batch, agent, target) layout on the actor outputs
        probs = probs.reshape(batch_size, num_agents, num_targets, -1)
        actions = actions.reshape(batch_size, num_agents, num_targets, -1)
        entropies = entropies.reshape(batch_size, num_agents, num_targets, -1)
        log_probs = log_probs.reshape(batch_size,num_agents, num_targets, -1)

        if 'shap' not in self.head_name:
            #print(global_feature.size(),feature.size())
            # pooled critic output is repeated num_agents times along dim 1
            _, global_critic_feature = self.critic_encoder(critic_feature)
            values = self.critic(global_critic_feature).unsqueeze(1).repeat(1, num_agents, 1)
        else:
            values = self.ShapleyVcritic(critic_feature, actions)  # shape [1,1]
        
        if not batch:
            # squeeze all the tensor for env
            values = values.squeeze(0)
            actions = actions.squeeze(0)
            entropies = entropies.squeeze(0)
            log_probs = log_probs.squeeze(0)
            hn_self = hn_self.squeeze(0)
            comm_edges = comm_edges.squeeze(0)
            probs = probs.squeeze(0)

        return values, actions, entropies, log_probs, hn_self, edge_logits, comm_edges, probs

    def sample_noise(self):
        """Call ``self.actor.sample_noise()`` twice.

        NOTE(review): the repeated call may have been meant for the critic -- confirm.
        """
        self.actor.sample_noise()
        self.actor.sample_noise()

    def remove_noise(self):
        """Call ``self.actor.remove_noise()`` twice.

        NOTE(review): the repeated call may have been meant for the critic -- confirm.
        """
        self.actor.remove_noise()
        self.actor.remove_noise()

class A3C_Multi_v4(torch.nn.Module):
    """Actor-critic model: partial observation + communication + obstacles.

    Adds to the encoder/attention/GRU stack a ToM GRU over relative camera
    states, a ``GoalLayer`` that predicts other agents' goals, and a
    ``Graph_Infer`` module that selects communication edges.

    ``self.num_targets`` is read by ``forward`` but is not assigned in
    ``__init__``; ``build_model`` at the top of this file sets
    ``model.num_targets`` after construction.
    """
    def __init__(self, obs_space, action_spaces, args, device=torch.device('cpu')):
        """Create all sub-modules.

        Args:
            obs_space: object whose ``shape`` unpacks into
                ``(num_agents, num_entity, pose_dim)``; ``num_entity`` is not
                stored.
            action_spaces: not read by this constructor.
            args: provides ``lstm_out`` (layer width) and ``model`` (name string,
                kept as ``self.head_name``; ``'shap'`` selects ``AMCValueNet``).
            device: handed to the sub-modules and stored as ``self.device``.
        """
        super(A3C_Multi_v4, self).__init__()
        self.num_agents, num_entity, self.pose_dim = obs_space.shape  # num_targets = target + obstacles
        #self.num_obstacles = num_entity - self.num_targets

        lstm_out = args.lstm_out
        head_name = args.model
        self.head_name = head_name

        self.encoder = EncodeLinear(self.pose_dim, lstm_out, head_name, device)
        feature_dim = self.encoder.feature_dim

        self.attention = AttentionLayer(feature_dim, lstm_out * 4, device)
        feature_dim = self.attention.feature_dim

        self.GRU=RNN(feature_dim,lstm_out,1,device,'GRU')
        
        # create ToM, including camera ToM & target ToM
        cam_state_size = 4  # [x, y, cos, sin]

        # The ToM_target module below is disabled: it is unused, and keeping it
        # would break ensure_share_grad because its parameters get no gradient.
        # self.ToM_target = nn.Sequential(
        #     nn.Linear(cam_state_size + self.attention.feature_dim, 10),
        #     nn.ReLU(),
        #     nn.Linear(10,1),
        #     nn.Sigmoid()
        # )
        self.ToM=RNN(cam_state_size, lstm_out, 1, device, 'GRU')
        # GoalLayer input: ToM output + attention feature + GRU output + 1 cover flag
        # (matches the goal_feature concatenation in forward)
        self.other_goal=GoalLayer(lstm_out + self.attention.feature_dim + self.GRU.feature_dim + 1, device=device)
        
        # GNN based communication inference
        self.graph_infer = Graph_Infer(self.attention.feature_dim, device=device)

        feature_dim=self.GRU.feature_dim + self.attention.feature_dim + 1 + 2# GRU_output + attention_feature + max_prob + msgs

        # create actor & critic
        self.actor = PolicyNet(feature_dim, spaces.Discrete(2), head_name, device)

        if 'shap' in head_name:
            self.ShapleyVcritic = AMCValueNet(feature_dim, head_name, 1, device)
        else:
            # centralized training
            self.critic_encoder = AttentionLayer(feature_dim, 6 * feature_dim, device)
            self.critic = ValueNet(6 * feature_dim, head_name, 1)
        self.train()
        self.device = device

    def forward(self, multi_obs, self_hiddens, other_hiddens, cameras, test=False, available_actions = None):
        """Compute values, actions, ToM goal predictions and communication edges.

        Args:
            multi_obs: observation tensor. A 4-D input is treated as batched
                (existing comment: [batch, n, m, obs_dim]); anything else gets a
                leading batch axis of size 1 added, as do the next three inputs.
            self_hiddens: hidden state for ``self.GRU``; reshaped to
                (1, num_agents * batch, -1).
            other_hiddens: hidden state for ``self.ToM``; reshaped to
                (1, -1, self.ToM.feature_dim).
            cameras: camera states (existing comment: [batch, num_agents, cam_dim]).
                The first two channels are differenced between agents; the last
                is differenced and scaled by pi/180 before cos/sin.
            test: forwarded unchanged to ``self.actor``.
            available_actions: optional; when given it is reshaped to
                (batch * num_agents * num_targets, -1) and passed to the actor.

        Returns:
            Tuple ``(values, actions, entropies, log_probs, hn_self, hn_others,
            other_goals, edge_logits, comm_edges, probs)``. ``edge_logits`` is
            returned after softmax. When the input was not batched, the leading
            axis is squeezed from every output.
        """
        num_agents = self.num_agents
        num_targets = self.num_targets 
        if len(multi_obs.size()) != 4:
            # not batch data, so turn them into a size-1 batch
            batch = False
            multi_obs = multi_obs.unsqueeze(0)
            self_hiddens = self_hiddens.unsqueeze(0)
            other_hiddens = other_hiddens.unsqueeze(0)
            cameras = cameras.unsqueeze(0)
        else:
            batch = True
        batch_size = multi_obs.size()[0]
        obs_dim = multi_obs.size()[-1] # [batch,n,m,obs_dim]; obs_dim itself is not used below
        num_both = multi_obs.size()[2]  # entity count on axis 2 of the observation

        # real target coverage, 0.4 = 800(radius)/2000(area size)
        #real_target_cover = (multi_obs[:,:,-1] <= 0.4).reshape(num_agents, num_targets, 1).detach()
        # active rule: a target counts as covered when the last observation channel is non-zero
        real_target_cover = (multi_obs[:,:,:self.num_targets,-1] != 0).reshape(batch_size, num_agents, num_targets, 1).detach()

        feature_target = Variable(multi_obs, requires_grad=True)

        feature_target = self.encoder(feature_target)  # [batch_size, num_agents, num_both, feature_dim]

        # merge batch and agent axes so attention runs once per (batch, agent)
        feature_target = feature_target.reshape(batch_size * num_agents, num_both, -1)
        att_features, global_feature = self.attention(feature_target) 
        att_features = att_features.reshape(batch_size, num_agents, num_both, -1)
        global_feature = global_feature.reshape(batch_size, num_agents, -1)
        #features: [batch, num_agents, num_targets, feature_dim]
        
        # The only difference between v3 and v4
        ############ begin ###########
        # keep only the first num_targets entities
        att_features = att_features[:,:,:self.num_targets]
        ############# end ############

        h_self = self_hiddens.reshape(1, num_agents * batch_size, -1) # [1, num_agents * batch, hidden_size]
        global_features = global_feature.reshape(num_agents * batch_size, 1, -1) # [num_agents * batch, 1, feature_dim]
        GRU_outputs, hn_self = self.GRU(global_features, h_self)
        #GRU_outputs=GRU_outputs.squeeze(1) 
        hn_self=hn_self.reshape(batch_size, num_agents, -1)

        # ToM_input
        cam_states = cameras  #[batch, num_agents, cam_dim]
        cam_dim = cam_states.size()[-1] # cam_dim=3

        # transform to relative cam_states 
        # cam_states_duplicate[b, i, j] = cam_states[b, j]
        cam_states_duplicate = cam_states.unsqueeze(1).expand(batch_size, num_agents, num_agents, cam_dim)
        # off-diagonal mask: selects, for every agent i, the entries j != i
        idx = (torch.ones(num_agents,num_agents) - torch.diag(torch.ones(num_agents))).bool()
        cam_states_duplicate = cam_states_duplicate[:,idx].reshape(batch_size, num_agents, num_agents-1, cam_dim) # [batch,n,n-1,3]
        cam_states_relative = cam_states_duplicate - cam_states.unsqueeze(2).expand(batch_size, num_agents, num_agents-1, cam_dim)
        # last channel scaled by pi/180 (degrees -> radians) before cos/sin
        cam_state_theta = ((cam_states_relative[:,:,:,-1]/180) * np.pi).reshape(batch_size, num_agents, num_agents-1, 1)
        camera_states = torch.cat((cam_states_relative[:,:,:,:2], torch.cos(cam_state_theta), torch.sin(cam_state_theta)),-1)
        camera_states = camera_states.reshape(batch_size*num_agents*(num_agents-1),1,-1) # [batch*n*n-1,1,4]
        h_others = other_hiddens.reshape(1,-1,self.ToM.feature_dim) #[1,batch*n*(n-1),dim]

        # ToM_camera prediction
        ToM_output,hn_others = self.ToM(camera_states,h_others)

        # GoalLayer input concat
        hn_others = hn_others.reshape(batch_size, num_agents, num_agents-1, -1)
        ToM_output = ToM_output.reshape(batch_size, num_agents, num_agents-1, 1, -1)
        GRU_dim=GRU_outputs.size()[-1]
        ToM_dim = ToM_output.size()[-1]
        att_dim=att_features.size()[-1]
        
        # broadcast everything to [batch, n, n-1, m, dim] so it can be concatenated
        GRU_output_expand = GRU_outputs.reshape(batch_size, num_agents, 1, 1, -1)
        GRU_output_expand = GRU_output_expand.expand(batch_size, num_agents, num_agents-1, num_targets, GRU_dim)
        ToM_output = ToM_output.expand(batch_size, num_agents, num_agents-1, num_targets, ToM_dim)
        att_feature_expand = att_features.unsqueeze(2).expand(batch_size,num_agents,num_agents-1,num_targets,att_dim)

        # ToM_target: predicted version
        # camera_states = (camera_states.reshape(num_agents,num_agents-1,1,-1)).expand(num_agents, num_agents-1, num_targets, 4)
        # ToM_target_feature = torch.cat((att_feature_expand,camera_states),-1)
        # ToM_target_cover = self.ToM_target(ToM_target_feature) #[n, n-1, m, 1]
        
        # ToM_target: Ground Truth Version
        # ToM_target_cover[b, i, j] = real_target_cover[b, j]; the diagonal (j == i) is then dropped
        ToM_target_cover = real_target_cover.unsqueeze(1).expand(batch_size, num_agents, num_agents, num_targets, 1)
        # same off-diagonal mask as computed above, rebuilt here
        idx = (torch.ones(num_agents, num_agents) - torch.diag(torch.ones(num_agents))).bool()
        ToM_target_cover = ToM_target_cover[:,idx].reshape(batch_size, num_agents, num_agents-1, num_targets, 1)

        # other goals
        goal_feature = torch.cat((GRU_output_expand.detach(), ToM_output, att_feature_expand.detach(), ToM_target_cover),-1) # detach ToM_target here
        other_goals = self.other_goal(goal_feature) # [batch, n, n-1, m, 2]

        ############ GNN based communication ################
        edge_logits, comm_edges = self.graph_infer(global_feature.reshape(batch_size,num_agents,-1))

        # must do softmax to logits for training
        edge_logits = F.softmax(edge_logits, dim=-1)

        # per-target edges: an edge from agent i is kept for target m only where
        # real_target_cover[b, i, m] is set
        comm_target_edges = comm_edges.reshape(batch_size, num_agents, num_agents, 1).repeat(1, 1, 1, num_targets)
        comm_target_edges = comm_target_edges * (real_target_cover.reshape(batch_size, num_agents, 1, num_targets).repeat(1, 1, num_agents, 1))
        # sum over dim 1: number of kept edges per (dim-2 agent, target)
        comm_cnt = torch.sum(comm_target_edges,1).reshape(batch_size, num_agents, num_targets, 1) 

        # selected msg
        comm_msg = other_goals[:,:,:,:,1].detach()     # [batch, n, n-1, m]  comm_edge:[n, n, 1]
        # transfer [n,n-1,m] to [n,n,m], where the added part is the diagonal
        # e.g. [[1,2],[3,4],[5,6]] -> [[0,1,2],[3,0,4],[5,6,0]] 
        mask = torch.ones([num_agents, num_agents-1]).to(self.device)
        mask_u = torch.triu(mask,0).reshape(1, num_agents, num_agents-1, 1).repeat(batch_size, 1, 1, num_targets)
        mask_l = torch.tril(mask,-1).reshape(1, num_agents, num_agents-1, 1).repeat(batch_size, 1, 1, num_targets)
        zeros = torch.zeros([batch_size, num_agents, 1, num_targets]).to(self.device)
        # upper part is shifted right by one column, lower part stays in place
        tri_u = torch.cat((zeros, comm_msg * mask_u), 2)
        tri_l = torch.cat((comm_msg * mask_l, zeros), 2)
        comm_msg = tri_u + tri_l    # [batch, n, n, m] 

        # evaluation
        #comm_edges = 1 - comm_edges
        #comm_cnt = self.num_agents - 1 - comm_cnt
        #comm_edges = torch.zeros(self.num_agents,self.num_agents,1)
        #comm_cnt  = torch.zeros(self.num_agents,self.num_targets,1)
        #end of evaluation
        comm_msg = comm_msg * comm_target_edges
        # sum over dim 1, then append the edge count -> last dim of size 2
        msgs = torch.sum(comm_msg, 1).reshape(batch_size, num_agents, num_targets, 1)
        msgs = torch.cat((msgs,comm_cnt),-1)

        ######### end of GNN based communication ###############

        # decide self goals
        # max over the other-agent axis of channel 1 of other_goals
        max_prob , _ = torch.max(other_goals[:,:,:,:,1],2) 
        max_prob = max_prob.reshape(batch_size, num_agents, num_targets, 1).detach() #[batch,n,m,1]
        GRU_outputs = GRU_outputs.reshape(batch_size, num_agents, 1, -1).expand(batch_size, num_agents, num_targets, self.GRU.feature_dim)
        actor_feature=torch.cat((GRU_outputs,att_features, max_prob, msgs), -1) #[batch,n,m,dim]
        actor_dim = actor_feature.size()[-1]
        # critic input: actor features summed over the target axis
        critic_feature = torch.sum(actor_feature,2)#.reshape(batch_size, 1, -1).repeat(1, num_agents, 1) #expand(num_agents, num_agents*actor_dim) #[b,n,dim]
        actor_feature = actor_feature.reshape(-1,actor_dim) #[batch*n*m,dim]

        if available_actions is not None:
            available_actions = available_actions.reshape(batch_size * num_agents * num_targets, -1)
        actions, entropies, log_probs, probs = self.actor(actor_feature, test, available_actions)

        # restore the (batch, agent, target) layout on the actor outputs
        probs = probs.reshape(batch_size, num_agents, num_targets, -1)
        actions = actions.reshape(batch_size, num_agents, num_targets, -1)
        entropies = entropies.reshape(batch_size, num_agents, num_targets, -1)
        log_probs = log_probs.reshape(batch_size,num_agents, num_targets, -1)

        if 'shap' not in self.head_name:
            #print(global_feature.size(),feature.size())
            # pooled critic output is repeated num_agents times along dim 1
            _, global_critic_feature = self.critic_encoder(critic_feature)
            values = self.critic(global_critic_feature).unsqueeze(1).repeat(1, num_agents, 1)
        else:
            values = self.ShapleyVcritic(critic_feature, actions)  # shape [1,1]

        if not batch:
            # squeeze all the tensor for env
            values = values.squeeze(0)
            actions = actions.squeeze(0)
            entropies = entropies.squeeze(0)
            log_probs = log_probs.squeeze(0)
            hn_self = hn_self.squeeze(0)
            hn_others = hn_others.squeeze(0)
            other_goals = other_goals.squeeze(0)
            edge_logits = edge_logits.squeeze(0)
            comm_edges = comm_edges.squeeze(0)
            probs = probs.squeeze(0)
        return values, actions, entropies, log_probs, hn_self, hn_others, other_goals, edge_logits, comm_edges, probs

    def sample_noise(self):
        """Call ``self.actor.sample_noise()`` twice.

        NOTE(review): the repeated call may have been meant for the critic -- confirm.
        """
        self.actor.sample_noise()
        self.actor.sample_noise()

    def remove_noise(self):
        """Call ``self.actor.remove_noise()`` twice.

        NOTE(review): the repeated call may have been meant for the critic -- confirm.
        """
        self.actor.remove_noise()
        self.actor.remove_noise()

class A3C_Multi_v5(torch.nn.Module):
    """Actor-critic model: partial observation + communication + obstacles,
    with a learned target-cover predictor (``ToM_target``).

    Communication edges are inferred by ``Graph_Infer`` on per-agent node
    features and masked by ``comm_domain``. When ``args.reduce_comm`` is truthy,
    ``forward`` additionally rebuilds the edges from a comparison of actor
    decisions with and without messages, and re-runs the actor.

    ``self.num_targets`` is read by ``forward`` but is not assigned in
    ``__init__``; ``build_model`` at the top of this file sets
    ``model.num_targets`` after construction.
    """
    def __init__(self, obs_space, action_spaces, args, device=torch.device('cpu')):
        """Create all sub-modules.

        Args:
            obs_space: object whose ``shape`` unpacks into
                ``(num_agents, num_entity, pose_dim)``; ``num_entity`` is not
                stored.
            action_spaces: not read by this constructor.
            args: provides ``reduce_comm``, ``lstm_out`` (layer width) and
                ``model`` (name string, kept as ``self.head_name``; ``'shap'``
                selects ``AMCValueNet``).
            device: handed to the sub-modules and stored as ``self.device``.
        """
        super(A3C_Multi_v5, self).__init__()
        self.num_agents, num_entity, self.pose_dim = obs_space.shape  # num_targets = target + obstacles
        #self.num_obstacles = num_entity - self.num_targets
        self.reduce_comm = args.reduce_comm
        lstm_out = args.lstm_out
        head_name = args.model
        self.head_name = head_name

        self.encoder = EncodeLinear(self.pose_dim, lstm_out, head_name, device)
        feature_dim = self.encoder.feature_dim

        self.attention = AttentionLayer(feature_dim, lstm_out * 2, device)
        feature_dim = self.attention.feature_dim

        self.GRU=RNN(feature_dim ,lstm_out ,1,device,'GRU')
        
        # create ToM, including camera ToM & target ToM
        cam_state_size = 4  # [x, y, cos, sin]

        # Target-cover predictor, used in forward to produce ToM_target_cover.
        # Input: attention feature + obstacle feature + ToM output; output: one sigmoid value.
        self.ToM_target = nn.Sequential(
            nn.Linear(lstm_out + self.attention.feature_dim*2, lstm_out),
            nn.ReLU(),
            nn.Linear(lstm_out,1),
            nn.Sigmoid()
        )
        self.ToM=RNN(cam_state_size, lstm_out, 1, device, 'GRU')
        # GoalLayer input: ToM output + attention feature + GRU output
        # (matches the goal_feature concatenation in forward)
        self.other_goal=GoalLayer(lstm_out + self.attention.feature_dim + self.GRU.feature_dim, device=device)
        
        # GNN based communication inference
        self.graph_infer = Graph_Infer(self.attention.feature_dim + self.ToM.feature_dim, device=device)

        # input width of reduce_dim: GRU output + attention feature only
        feature_dim=self.GRU.feature_dim + self.attention.feature_dim#+ 1 + 2# GRU_output + attention_feature + max_prob + msgs

        # ToM goal & msgs have too low dimension, so raise them to match GRU&attention
        # self.raise_dim = nn.Sequential(
        #     nn.Linear(3, 16),
        #     nn.ReLU(),
        #     nn.Linear(16,16),
        #     nn.ReLU()
        # )

        self.reduce_dim = nn.Sequential(
            #nn.Linear(feature_dim, lstm_out * 2),
            #nn.ReLU(),
            nn.Linear(feature_dim , lstm_out),
            nn.ReLU(),
        )
        # actor input: reduced self feature + 3 extra channels (max_prob, msgs, comm_cnt)
        feature_dim = lstm_out + 3
        # create actor & critic
        self.actor = PolicyNet(feature_dim, spaces.Discrete(2), head_name, device)

        if 'shap' in head_name:
            self.ShapleyVcritic = AMCValueNet(feature_dim, head_name, 1, device)
        else:
            # centralized training
            self.critic_encoder = AttentionLayer(feature_dim, 6 * feature_dim, device)
            self.critic = ValueNet(6 * feature_dim, head_name, 1)
        self.train()
        self.device = device

    def forward(self, multi_obs, self_hiddens, ToM_hiddens, cameras, comm_domain, test=False, available_actions = None):
        """Compute values, actions, ToM predictions and communication edges.

        Args:
            multi_obs: observation tensor. A 4-D input is treated as batched
                (existing comment: [batch, n, m, obs_dim]); anything else gets a
                leading batch axis of size 1 added, as do the next four inputs.
            self_hiddens: hidden state for ``self.GRU``; reshaped to
                (1, num_agents * batch, -1).
            ToM_hiddens: hidden state for ``self.ToM``; reshaped to
                (1, -1, self.ToM.feature_dim).
            cameras: camera states (existing comment: [batch, num_agents, cam_dim]).
            comm_domain: mask multiplied onto the inferred edges and also used
                as an index into the edge logits.
                NOTE(review): the indexing requires a boolean tensor broadcastable
                to [batch, n, n, 1] -- confirm against the caller.
            test: forwarded unchanged to ``self.actor``.
            available_actions: optional; when given it is reshaped to
                (batch * num_agents * num_targets, -1) and passed to the actor.

        Returns:
            Tuple ``(values, actions, entropies, log_probs, hn_self, hn_ToM,
            other_goals, edge_logits, comm_edges, probs, real_cover,
            ToM_target_cover)``. ``edge_logits`` holds softmax outputs of shape
            (k, 2) for the edges selected by ``comm_domain`` and is never
            squeezed. When the input was not batched, the leading axis is
            squeezed from every other output.
        """
        num_agents = self.num_agents
        num_targets = self.num_targets

        if len(multi_obs.size()) != 4:
            # not batch data, so turn them into a size-1 batch
            batch = False
            multi_obs = multi_obs.unsqueeze(0)
            self_hiddens = self_hiddens.unsqueeze(0)
            ToM_hiddens = ToM_hiddens.unsqueeze(0)
            cameras = cameras.unsqueeze(0)
            comm_domain = comm_domain.unsqueeze(0)
        else:
            batch = True

        batch_size = multi_obs.size()[0]
        obs_dim = multi_obs.size()[-1] # [batch,n,m,obs_dim]; obs_dim itself is not used below
        num_both = multi_obs.size()[2]  # entity count on axis 2 of the observation

        # compute real cover: whether target covered by an agent is covered by another agent.
        # real target coverage, 0.4 = 800(radius)/2000(area size)
        #real_target_cover = (multi_obs[:,:,-1] <= 0.4).reshape(num_agents, num_targets, 1).detach()
        # off-diagonal mask: selects, for every agent i, the entries j != i
        idx = (torch.ones(num_agents, num_agents) - torch.diag(torch.ones(num_agents))).bool()
        # active rule: a target counts as covered when the last observation channel is non-zero
        real_target_cover = (multi_obs[:,:,:self.num_targets,-1] != 0).reshape(batch_size, num_agents, num_targets, 1).detach()

        # NOTE(review): real_self_cover is not referenced again in this method.
        real_self_cover = real_target_cover.unsqueeze(2).repeat(1, 1, num_agents - 1, 1, 1)
        # real_others_cover[b, i, j] = real_target_cover[b, j]; the diagonal (j == i) is then dropped
        real_others_cover = real_target_cover.unsqueeze(1).repeat(1, num_agents, 1, 1, 1)
        real_others_cover = real_others_cover[:,idx].reshape(batch_size, num_agents, num_agents-1, num_targets, 1)
        real_cover = real_others_cover # for ToM target training

        feature_target = Variable(multi_obs, requires_grad=True)

        feature_target = self.encoder(feature_target)  # [batch_size, num_agents, num_both, feature_dim]

        # merge batch and agent axes so attention runs once per (batch, agent)
        feature_target = feature_target.reshape(batch_size * num_agents, num_both, -1)
        att_features, global_feature = self.attention(feature_target) 
        att_features = att_features.reshape(batch_size, num_agents, num_both, -1)
        global_feature = global_feature.reshape(batch_size, num_agents, -1)
        #features: [batch, num_agents, num_targets, feature_dim]
        
        # split features into targets & obstacles
        ############ begin ###########
        # obstacle feature = mean over the entities after the first num_targets;
        # zeros when the observation carries no such extra entities
        obstacle_features = torch.mean(att_features[:,:,self.num_targets:], 2) if num_both > num_targets else torch.zeros(batch_size, num_agents, att_features.size()[-1]).to(self.device)
        att_features = att_features[:,:,:self.num_targets]
        ############# end ############

        h_self = self_hiddens.reshape(1, num_agents * batch_size, -1) # [1, num_agents * batch, hidden_size]
        global_features = global_feature.reshape(num_agents * batch_size, 1, -1) # [num_agents * batch, 1, feature_dim]
        GRU_outputs, hn_self = self.GRU(global_features, h_self)
        #GRU_outputs=GRU_outputs.squeeze(1) 
        hn_self=hn_self.reshape(batch_size, num_agents, -1)

        # ToM_input
        cam_states = cameras  #[batch, num_agents, cam_dim]
        cam_dim = cam_states.size()[-1] # cam_dim=3

        # transform to relative cam_states 
        # all n x n pairs are kept here (diagonal included); cam_states_duplicate[b, i, j] = cam_states[b, j]
        cam_states_duplicate = cam_states.unsqueeze(1).expand(batch_size, num_agents, num_agents, cam_dim)
        #idx = (torch.ones(num_agents,num_agents) - torch.diag(torch.ones(num_agents))).bool()
        #cam_states_duplicate = cam_states_duplicate[:,idx].reshape(batch_size, num_agents, num_agents-1, cam_dim) # [batch,n,n-1,3]
        cam_states_relative = cam_states_duplicate - cam_states.unsqueeze(2).expand(batch_size, num_agents, num_agents, cam_dim)
        # last channel scaled by pi/180 (degrees -> radians) before cos/sin
        cam_state_theta = ((cam_states_relative[:,:,:,-1]/180) * np.pi).reshape(batch_size, num_agents, num_agents, 1)
        camera_states = torch.cat((cam_states_relative[:,:,:,:2], torch.cos(cam_state_theta), torch.sin(cam_state_theta)),-1)
        camera_states = camera_states.reshape(batch_size*num_agents*(num_agents),1,-1) # [batch*n*n,1,4]
        h_ToM = ToM_hiddens.reshape(1,-1,self.ToM.feature_dim) #[1,batch*n*(n),dim]

        # ToM_camera prediction
        ToM_output,hn_ToM = self.ToM(camera_states,h_ToM)

        # GoalLayer input concat
        hn_ToM = hn_ToM.reshape(batch_size, num_agents, num_agents, -1)
        ToM_output = ToM_output.reshape(batch_size, num_agents, num_agents, -1)
        # drop the diagonal to keep only ToM outputs about other agents
        ToM_output_other = ToM_output[:,idx].reshape(batch_size, num_agents, num_agents-1, 1, -1)

        GRU_dim=GRU_outputs.size()[-1]
        ToM_dim=ToM_output.size()[-1]
        att_dim=att_features.size()[-1]
        # broadcast everything to [batch, n, n-1, m, dim] so it can be concatenated
        GRU_output_expand = GRU_outputs.reshape(batch_size, num_agents, 1, 1, -1)
        GRU_output_expand = GRU_output_expand.expand(batch_size, num_agents, num_agents-1, num_targets, GRU_dim)
        ToM_output_expand = ToM_output_other.expand(batch_size, num_agents, num_agents-1, num_targets, ToM_dim)
        att_feature_expand = att_features.unsqueeze(2).expand(batch_size,num_agents,num_agents-1,num_targets,att_dim)

        # ToM_target: predicted version
        #camera_states_others = camera_states.reshape(batch_size, num_agents, num_agents, -1)[:,idx]
        #camera_states = (camera_states_others.reshape(batch_size,num_agents,num_agents-1,1,-1)).repeat(1, 1, 1, num_targets, 1)
        obstacle_features = obstacle_features.reshape(batch_size, num_agents, 1, 1, -1).repeat(1, 1, num_agents - 1, num_targets, 1)
        # attention and obstacle features are detached here
        ToM_target_feature = torch.cat((att_feature_expand.detach(), obstacle_features.detach(), ToM_output_expand),-1)
        ToM_target_cover = self.ToM_target(ToM_target_feature) #[b, n, n-1, m, 1]
        #print(ToM_target_cover)
        # ToM_target: Ground Truth Version
        #ToM_target_cover = real_target_cover.unsqueeze(1).expand(batch_size, num_agents, num_agents, num_targets, 1)
        #idx = (torch.ones(num_agents, num_agents) - torch.diag(torch.ones(num_agents))).bool()
        #ToM_target_cover = ToM_target_cover[:,idx].reshape(batch_size, num_agents, num_agents-1, num_targets, 1)
        

        # other goals
        # attention and GRU features are detached in this concatenation
        goal_feature = torch.cat((att_feature_expand.detach(), GRU_output_expand.detach(), ToM_output_expand),-1) # detach ToM_target here
        other_goals = self.other_goal(goal_feature) # [batch, n, n-1, m, 2]

        # prepare masks
        # used to turn [n, n-1] layouts into [n, n] layouts with a zero diagonal
        mask = torch.ones([num_agents, num_agents-1]).to(self.device)
        mask_u = torch.triu(mask,0).reshape(1, num_agents, num_agents-1, 1).repeat(batch_size, 1, 1, num_targets)
        mask_l = torch.tril(mask,-1).reshape(1, num_agents, num_agents-1, 1).repeat(batch_size, 1, 1, num_targets)
        zeros = torch.zeros([batch_size, num_agents, 1, num_targets]).to(self.device)

        ############ GNN based communication ################
        # threshold channel 1 of other_goals at 0.5, expand to [b, n, n, m] with
        # a zero diagonal, then add 1 on the diagonal
        ToM_goals = (other_goals[:,:,:,:,1] > 0.5).detach()
        tri_u_ToM = torch.cat((zeros, ToM_goals * mask_u), 2)
        tri_l_ToM = torch.cat((ToM_goals * mask_l, zeros), 2)
        ToM_goals = tri_u_ToM + tri_l_ToM
        diag_idx = torch.diag(torch.ones(num_agents)).bool()
        ToM_goals[:,diag_idx] += 1
        ToM_goals = ToM_goals.unsqueeze(-1) #[b, n, n, m, 1]
        # node feature: attention features weighted by ToM_goals and summed over
        # targets, concatenated with the detached ToM output
        node_features = att_features.unsqueeze(2).repeat(1,1,num_agents,1,1).detach()
        node_features = torch.sum(node_features * ToM_goals, 3) # sum all targets
        node_features = torch.cat((node_features, ToM_output.detach()),-1).reshape(batch_size * num_agents, num_agents, -1)

        #edge_logits, comm_edges = self.graph_infer(global_feature.reshape(batch_size,num_agents,-1))
        # one graph per (batch, agent) pair; diag_idx then keeps entries [b, i, i, :]
        edge_logits, comm_edges = self.graph_infer(node_features)
        edge_logits = edge_logits.reshape(batch_size, num_agents, num_agents, num_agents, -1)[:,diag_idx] #[b, n, n, 2]
        comm_edges = comm_edges.reshape(batch_size, num_agents, num_agents, num_agents, -1)[:,diag_idx]  # [b, n, n, 1]
        comm_edges = comm_edges * comm_domain # only communicate with agents in self domain
        comm_domain_reshape = comm_domain.reshape(-1,1).repeat(1,2) #[b*n*n, 2]

        # only for ablation test
        #comm_edges = torch.zeros(batch_size, num_agents, num_agents, 1).to(self.device)

        # mask indexing flattens the selection, so it is reshaped back to pairs
        edge_logits = edge_logits.reshape(-1,2)[comm_domain_reshape]
        edge_logits = edge_logits.reshape(-1,2) #[k,2] only logits of those edges in comm domain can be saved for training
        edge_logits = F.softmax(edge_logits, dim=-1)

        comm_target_edges = comm_edges.reshape(batch_size, num_agents, num_agents, 1).repeat(1, 1, 1, num_targets)
        # although we use real target cover here, but it is only self real cover duplicate, so it still keeps the decentralized execution mode
        comm_target_edges = comm_target_edges * (real_target_cover.reshape(batch_size, num_agents, 1, num_targets).repeat(1, 1, num_agents, 1)) 
        # sum over dim 1: number of kept edges per (dim-2 agent, target)
        comm_cnt = torch.sum(comm_target_edges,1).reshape(batch_size, num_agents, num_targets, 1) 

        # selected msg
        comm_msg = other_goals[:,:,:,:,1].detach()     # [batch, n, n-1, m]  comm_edge:[n, n, 1]
        # transfer [n,n-1,m] to [n,n,m], where the added part is the diagonal
        # e.g. [[1,2],[3,4],[5,6]] -> [[0,1,2],[3,0,4],[5,6,0]] 
        tri_u = torch.cat((zeros, comm_msg * mask_u), 2)
        tri_l = torch.cat((comm_msg * mask_l, zeros), 2)
        comm_msg = tri_u + tri_l    # [batch, n, n, m] 
        # evaluation
        #comm_edges = 1 - comm_edges
        #comm_cnt = self.num_agents - 1 - comm_cnt
        #comm_edges = torch.zeros(self.num_agents,self.num_agents,1)
        #comm_cnt  = torch.zeros(self.num_agents,self.num_targets,1)
        #end of evaluation
        msgs = comm_msg * comm_target_edges
        # sum over dim 1, then append the edge count -> last dim of size 2
        msgs = torch.sum(msgs, 1).reshape(batch_size, num_agents, num_targets, 1)
        msgs = torch.cat((msgs,comm_cnt),-1)

        ######### end of GNN based communication ###############

        # decide self goals
        # max over the other-agent axis of channel 1 of other_goals
        max_prob , _ = torch.max(other_goals[:,:,:,:,1],2) 
        max_prob = max_prob.reshape(batch_size, num_agents, num_targets, 1).detach() #[batch,n,m,1]
        GRU_outputs = GRU_outputs.reshape(batch_size, num_agents, 1, -1).expand(batch_size, num_agents, num_targets, self.GRU.feature_dim)
        # new version of actor feature, reduce self feature dim
        self_feature = torch.cat((att_features, GRU_outputs), -1) 
        self_feature = self.reduce_dim(self_feature)
        # 3 low-dimensional channels: max_prob (1) + msgs (2)
        ToM_msgs = torch.cat((max_prob, msgs),-1)
        #ToM_msgs = self.raise_dim(ToM_msgs)
        # actor_feature = actor_feature + ToM_msgs
        actor_feature = torch.cat((self_feature,ToM_msgs),-1)

        #actor_feature=torch.cat((GRU_outputs, att_features, max_prob, msgs), -1) #[batch,n,m,dim]
        actor_dim = actor_feature.size()[-1]
        # critic input: actor features summed over the target axis
        critic_feature = torch.sum(actor_feature,2)#.reshape(batch_size, 1, -1).repeat(1, num_agents, 1) #expand(num_agents, num_agents*actor_dim) #[b,n,dim]
        actor_feature = actor_feature.reshape(-1,actor_dim) #[batch*n*m,dim]

        if available_actions is not None:
            available_actions = available_actions.reshape(batch_size * num_agents * num_targets, -1)

        actions, entropies, log_probs, probs = self.actor(actor_feature, test, available_actions)

        
        if self.reduce_comm:
            # Re-run the actor with the 3 message channels zeroed and compare
            # the argmax decisions with the ones obtained with messages.
            zero_msgs = torch.zeros(batch_size, num_agents, num_targets, 3).to(self.device)
            zero_actor_feature = torch.cat((self_feature, zero_msgs),-1).reshape(-1,actor_dim)
            _,_,_,zero_probs = self.actor(zero_actor_feature, test, available_actions)
            ori_probs = probs.reshape(batch_size, num_agents, num_targets, 2)
            zero_probs = zero_probs.reshape(batch_size, num_agents, num_targets,2)
            a = ori_probs.max(-1)[1]
            b = zero_probs.max(-1)[1]
            # XOR of argmax indices, summed over targets: number of changed decisions per agent
            real_edges = torch.sum(a ^ b,-1)

            # 1.0 for agents with at least one changed decision, else 0.0
            real_edges = 1 - (real_edges == 0).float()
            # tile so that entry [b, i, j] carries agent j's flag, then mask by comm_domain
            real_edges = real_edges.unsqueeze(1).repeat(1, num_agents, 1)
            new_edges = real_edges.reshape(batch_size, num_agents, num_agents, 1) * comm_domain
            
            # compute new msgs
            new_target_edges = new_edges.repeat(1, 1, 1, num_targets)
            new_target_edges = new_target_edges * (real_target_cover.reshape(batch_size, num_agents, 1, num_targets).repeat(1, 1, num_agents, 1)) 
            new_cnt = torch.sum(new_target_edges,1).reshape(batch_size, num_agents, num_targets, 1)
            new_msgs = comm_msg * new_target_edges
            new_msgs = torch.sum(new_msgs, 1).reshape(batch_size, num_agents, num_targets, 1)
            new_msgs = torch.cat((new_msgs,new_cnt),-1)
            new_ToM_msgs = torch.cat((max_prob, new_msgs),-1)
            new_actor_feature = torch.cat((self_feature,new_ToM_msgs),-1)
            new_actor_feature = new_actor_feature.reshape(-1,actor_dim)

            # NOTE(review): only the actor outputs are replaced here; critic_feature
            # and the returned comm_edges still come from the edges computed
            # before this branch -- confirm that is intended.
            actions, entropies, log_probs, probs = self.actor(new_actor_feature, test, available_actions)
            #print(actions)
            #print(np.abs(new_actions-actions).sum())
            #print("-------------")
        
        # restore the (batch, agent, target) layout on the actor outputs
        probs = probs.reshape(batch_size, num_agents, num_targets, -1)
        actions = actions.reshape(batch_size, num_agents, num_targets, -1)
        entropies = entropies.reshape(batch_size, num_agents, num_targets, -1)
        log_probs = log_probs.reshape(batch_size,num_agents, num_targets, -1)

        if 'shap' not in self.head_name:
            #print(global_feature.size(),feature.size())
            # pooled critic output is repeated num_agents times along dim 1
            _, global_critic_feature = self.critic_encoder(critic_feature)
            values = self.critic(global_critic_feature).unsqueeze(1).repeat(1, num_agents, 1)
        else:
            values = self.ShapleyVcritic(critic_feature, actions)  # shape [1,1]

        if not batch:
            # squeeze all the tensor for env
            values = values.squeeze(0)
            actions = actions.squeeze(0)
            entropies = entropies.squeeze(0)
            log_probs = log_probs.squeeze(0)
            hn_self = hn_self.squeeze(0)
            hn_ToM = hn_ToM.squeeze(0)
            other_goals = other_goals.squeeze(0)
            # edge_logits is already flattened to (k, 2), so it is left as is
            #edge_logits = edge_logits.squeeze(0)
            comm_edges = comm_edges.squeeze(0)
            probs = probs.squeeze(0)
            real_cover =real_cover.squeeze(0)
            ToM_target_cover = ToM_target_cover.squeeze(0)

        return values, actions, entropies, log_probs, hn_self, hn_ToM, other_goals, edge_logits, comm_edges, probs, real_cover, ToM_target_cover

    def sample_noise(self):
        """Call ``self.actor.sample_noise()`` twice.

        NOTE(review): the repeated call may have been meant for the critic -- confirm.
        """
        self.actor.sample_noise()
        self.actor.sample_noise()

    def remove_noise(self):
        """Call ``self.actor.remove_noise()`` twice.

        NOTE(review): the repeated call may have been meant for the critic -- confirm.
        """
        self.actor.remove_noise()
        self.actor.remove_noise()
'''
class A3C_Multi_v5(torch.nn.Module):
    """Multi-agent actor-critic with ToM goal inference and inferred communication.

    forward() chains: per-entity encoding + attention, a GRU over relative
    camera states (``self.ToM``), a predictor of the other agents' goals, a
    graph-inference module that selects communication edges, and a binary
    (``spaces.Discrete(2)``) policy head evaluated per (agent, target) pair,
    plus a critic.

    ``num_targets`` is NOT assigned in ``__init__``; forward() reads
    ``self.num_targets``, so it must be set on the instance beforehand
    (build_model does ``model.num_targets = env.num_target``).
    """
    # partial obs + communication + obstacle
    def __init__(self, obs_space, action_spaces, args, device=torch.device('cpu')):
        """Build all sub-modules.

        Args:
            obs_space: object whose ``shape`` unpacks to
                (num_agents, num_entity, pose_dim).
            action_spaces: not referenced in this constructor.
            args: namespace providing ``reduce_comm``, ``lstm_out`` and ``model``.
            device: torch device handed to the sub-modules.
        """
        super(A3C_Multi_v5, self).__init__()
        self.num_agents, num_entity, self.pose_dim = obs_space.shape  # num_entity = targets + obstacles (per the original note)
        #self.num_obstacles = num_entity - self.num_targets
        self.reduce_comm = args.reduce_comm
        lstm_out = args.lstm_out
        head_name = args.model
        self.head_name = head_name

        self.encoder = EncodeLinear(self.pose_dim, lstm_out, head_name, device)
        feature_dim = self.encoder.feature_dim

        self.attention = AttentionLayer(feature_dim, lstm_out * 2, device)
        feature_dim = self.attention.feature_dim

        #self.GRU=RNN(feature_dim ,lstm_out ,1,device,'GRU')
        
        # create ToM, including camera ToM & target ToM
        cam_state_size = 4  # [x, y, cos, sin]

        # Original note: "comment this model because it's useless, and if exists,
        # ensure_share_grad will collapse due to none grad cited".
        # NOTE(review): that note looks stale -- ToM_target is instantiated here
        # and is called in forward(); confirm whether it should stay.
        self.ToM_target = nn.Sequential(
            nn.Linear(lstm_out + self.attention.feature_dim*2, lstm_out),
            nn.ReLU(),
            nn.Linear(lstm_out,1),
            nn.Sigmoid()
        )
        self.ToM=RNN(cam_state_size, lstm_out, 1, device, 'GRU')
        self.other_goal=GoalLayer(lstm_out + self.attention.feature_dim, device=device)
        
        # GNN based communication inference
        self.graph_infer = Graph_Infer(self.attention.feature_dim + self.ToM.feature_dim, device=device)

        # input width of reduce_dim: ToM GRU output + attention feature
        feature_dim=self.ToM.feature_dim + self.attention.feature_dim#+ 1 + 2# GRU_output + attention_feature + max_prob + msgs

        # ToM goal & msgs have too low dimension, so raise them to match GRU&attention
        # self.raise_dim = nn.Sequential(
        #     nn.Linear(3, 16),
        #     nn.ReLU(),
        #     nn.Linear(16,16),
        #     nn.ReLU()
        # )

        self.reduce_dim = nn.Sequential(
            #nn.Linear(feature_dim, lstm_out * 2),
            #nn.ReLU(),
            nn.Linear(feature_dim , lstm_out),
            nn.ReLU(),
        )
        # actor input = reduced self feature + 3 scalars (max_prob, summed msg, msg count),
        # matching the concatenation done in forward()
        feature_dim = lstm_out + 3
        # create actor & critic
        self.actor = PolicyNet(feature_dim, spaces.Discrete(2), head_name, device)

        if 'shap' in head_name:
            self.ShapleyVcritic = AMCValueNet(feature_dim, head_name, 1, device)
        else:
            # centralized training
            self.critic_encoder = AttentionLayer(feature_dim, 6 * feature_dim, device)
            self.critic = ValueNet(6 * feature_dim, head_name, 1)
        self.train()
        self.device = device

    def forward(self, multi_obs, self_hiddens, ToM_hiddens, cameras, comm_domain, test=False, available_actions = None):
        """Run one decision step for all agents.

        Args:
            multi_obs: observations, [batch, n, num_both, obs_dim]; any tensor
                that is not 4-D is treated as a single un-batched sample and
                gets a leading batch axis (as do the next four arguments).
            self_hiddens: reshaped to [1, n * batch, -1] and returned as
                ``hn_self`` without being updated (the self GRU is disabled).
            ToM_hiddens: reshaped to [1, -1, self.ToM.feature_dim] and used as
                the initial hidden state of the ToM GRU.
            cameras: [batch, n, cam_dim]; the first two components and the last
                one are used (last one presumably an angle in degrees, since it
                is scaled by pi/180 -- TODO confirm).
            comm_domain: multiplied into the inferred edges and also used as an
                index mask on the edge logits (presumably a bool tensor of
                shape [batch, n, n, 1] -- TODO confirm against the caller).
            test: forwarded to the actor.
            available_actions: optional; reshaped to
                [batch * n * num_targets, -1] and forwarded to the actor.

        Returns:
            Tuple (values, actions, entropies, log_probs, hn_self, hn_ToM,
            other_goals, edge_logits, comm_edges, probs, real_cover,
            ToM_target_cover). For un-batched input every element except
            ``edge_logits`` has its leading axis squeezed. ``edge_logits`` is
            returned after a softmax, restricted to edges inside comm_domain.
        """
        num_agents = self.num_agents
        num_targets = self.num_targets  # set externally, see class docstring

        if len(multi_obs.size()) != 4:
            # not batch data, so turn them into a size-1 batch
            batch = False
            multi_obs = multi_obs.unsqueeze(0)
            self_hiddens = self_hiddens.unsqueeze(0)
            ToM_hiddens = ToM_hiddens.unsqueeze(0)
            cameras = cameras.unsqueeze(0)
            comm_domain = comm_domain.unsqueeze(0)
        else:
            batch = True

        batch_size = multi_obs.size()[0]
        obs_dim = multi_obs.size()[-1] # [batch,n,m,obs_dim]  (obs_dim is not used below)
        num_both = multi_obs.size()[2]

        # compute real cover: whether target covered by an agent is covered by another agent.
        # real target coverage, 0.4 = 800(radius)/2000(area size)
        #real_target_cover = (multi_obs[:,:,-1] <= 0.4).reshape(num_agents, num_targets, 1).detach()
        # off-diagonal selector over agent pairs (i != j); created without an explicit device
        idx = (torch.ones(num_agents, num_agents) - torch.diag(torch.ones(num_agents))).bool()
        # a target counts as covered when the last observation component is non-zero
        real_target_cover = (multi_obs[:,:,:self.num_targets,-1] != 0).reshape(batch_size, num_agents, num_targets, 1).detach()

        # NOTE(review): real_self_cover is not used anywhere below
        real_self_cover = real_target_cover.unsqueeze(2).repeat(1, 1, num_agents - 1, 1, 1)
        real_others_cover = real_target_cover.unsqueeze(1).repeat(1, num_agents, 1, 1, 1)
        real_others_cover = real_others_cover[:,idx].reshape(batch_size, num_agents, num_agents-1, num_targets, 1)
        real_cover = real_others_cover # for ToM target training

        feature_target = Variable(multi_obs, requires_grad=True)

        feature_target = self.encoder(feature_target)  # [batch_size, num_agents, num_both, feature_dim]

        feature_target = feature_target.reshape(batch_size * num_agents, num_both, -1)
        att_features, global_feature = self.attention(feature_target) 
        att_features = att_features.reshape(batch_size, num_agents, num_both, -1)
        global_feature = global_feature.reshape(batch_size, num_agents, -1)
        #features: [batch, num_agents, num_targets, feature_dim]
        
        # split features into targets & obstacles
        ############ begin ###########
        # obstacles are summarised by the mean of their attention features (zeros when there are none)
        obstacle_features = torch.mean(att_features[:,:,self.num_targets:], 2) if num_both > num_targets else torch.zeros(batch_size, num_agents, att_features.size()[-1]).to(self.device)
        att_features = att_features[:,:,:self.num_targets]
        ############# end ############

        h_self = self_hiddens.reshape(1, num_agents * batch_size, -1) # [1, num_agents * batch, hidden_size]
        # the self GRU below is commented out, so the incoming hidden state is passed through unchanged
        hn_self = h_self
        # NOTE(review): global_features is only consumed by the commented-out GRU call
        global_features = global_feature.reshape(num_agents * batch_size, 1, -1) # [num_agents * batch, 1, feature_dim]
        #GRU_outputs, hn_self = self.GRU(global_features, h_self)
        #GRU_outputs=GRU_outputs.squeeze(1) 
        #hn_self=hn_self.reshape(batch_size, num_agents, -1)
        

        # ToM_input
        cam_states = cameras  #[batch, num_agents, cam_dim]
        cam_dim = cam_states.size()[-1] # cam_dim=3

        # transform to relative cam_states 
        cam_states_duplicate = cam_states.unsqueeze(1).expand(batch_size, num_agents, num_agents, cam_dim)
        #idx = (torch.ones(num_agents,num_agents) - torch.diag(torch.ones(num_agents))).bool()
        #cam_states_duplicate = cam_states_duplicate[:,idx].reshape(batch_size, num_agents, num_agents-1, cam_dim) # [batch,n,n-1,3]
        # entry [b, i, j] = state of camera j minus state of camera i
        cam_states_relative = cam_states_duplicate - cam_states.unsqueeze(2).expand(batch_size, num_agents, num_agents, cam_dim)
        # last component is scaled by pi/180 before taking cos/sin
        cam_state_theta = ((cam_states_relative[:,:,:,-1]/180) * np.pi).reshape(batch_size, num_agents, num_agents, 1)
        camera_states = torch.cat((cam_states_relative[:,:,:,:2], torch.cos(cam_state_theta), torch.sin(cam_state_theta)),-1)
        camera_states = camera_states.reshape(batch_size*num_agents*(num_agents),1,-1) # [batch*n*n,1,4]
        h_ToM = ToM_hiddens.reshape(1,-1,self.ToM.feature_dim) #[1,batch*n*(n),dim]

        # ToM_camera prediction
        ToM_output,hn_ToM = self.ToM(camera_states,h_ToM)

        # GoalLayer input concat
        hn_ToM = hn_ToM.reshape(batch_size, num_agents, num_agents, -1)
        ToM_output = ToM_output.reshape(batch_size, num_agents, num_agents, -1)
        # off-diagonal entries: what agent i models about every other agent j
        ToM_output_other = ToM_output[:,idx].reshape(batch_size, num_agents, num_agents-1, 1, -1)
        # diagonal entries: the agent's ToM output about itself
        ToM_output_self = ToM_output[:, torch.diag(torch.ones(num_agents)).bool()].reshape(batch_size, num_agents, 1, -1)

        #GRU_dim=GRU_outputs.size()[-1]
        ToM_dim=ToM_output.size()[-1]
        att_dim=att_features.size()[-1]
        #GRU_output_expand = GRU_outputs.reshape(batch_size, num_agents, 1, 1, -1)
        #GRU_output_expand = GRU_output_expand.expand(batch_size, num_agents, num_agents-1, num_targets, GRU_dim)
        ToM_output_expand = ToM_output_other.expand(batch_size, num_agents, num_agents-1, num_targets, ToM_dim)
        att_feature_expand = att_features.unsqueeze(2).expand(batch_size,num_agents,num_agents-1,num_targets,att_dim)

        # ToM_target: predicted version
        #camera_states_others = camera_states.reshape(batch_size, num_agents, num_agents, -1)[:,idx]
        #camera_states = (camera_states_others.reshape(batch_size,num_agents,num_agents-1,1,-1)).repeat(1, 1, 1, num_targets, 1)
        obstacle_features = obstacle_features.reshape(batch_size, num_agents, 1, 1, -1).repeat(1, 1, num_agents - 1, num_targets, 1)
        # attention and obstacle features are detached, so only the ToM branch receives gradient here
        ToM_target_feature = torch.cat((att_feature_expand.detach(), obstacle_features.detach(), ToM_output_expand),-1)
        ToM_target_cover = self.ToM_target(ToM_target_feature) #[b, n, n-1, m, 1]
        #print(ToM_target_cover)
        # ToM_target: Ground Truth Version
        #ToM_target_cover = real_target_cover.unsqueeze(1).expand(batch_size, num_agents, num_agents, num_targets, 1)
        #idx = (torch.ones(num_agents, num_agents) - torch.diag(torch.ones(num_agents))).bool()
        #ToM_target_cover = ToM_target_cover[:,idx].reshape(batch_size, num_agents, num_agents-1, num_targets, 1)
        

        # other goals
        goal_feature = torch.cat((att_feature_expand.detach(), ToM_output_expand),-1) # detach ToM_target here
        other_goals = self.other_goal(goal_feature) # [batch, n, n-1, m, 2]

        # prepare masks
        # mask_u / mask_l split an [n, n-1] layout so that a zero column can be
        # inserted on the diagonal, turning it into an [n, n] layout (see the example further down)
        mask = torch.ones([num_agents, num_agents-1]).to(self.device)
        mask_u = torch.triu(mask,0).reshape(1, num_agents, num_agents-1, 1).repeat(batch_size, 1, 1, num_targets)
        mask_l = torch.tril(mask,-1).reshape(1, num_agents, num_agents-1, 1).repeat(batch_size, 1, 1, num_targets)
        zeros = torch.zeros([batch_size, num_agents, 1, num_targets]).to(self.device)

        ############ GNN based communication ################
        # binarise the predicted goals (index 1 of the last axis, threshold 0.5); no gradient
        ToM_goals = (other_goals[:,:,:,:,1] > 0.5).detach()
        tri_u_ToM = torch.cat((zeros, ToM_goals * mask_u), 2)
        tri_l_ToM = torch.cat((ToM_goals * mask_l, zeros), 2)
        ToM_goals = tri_u_ToM + tri_l_ToM
        diag_idx = torch.diag(torch.ones(num_agents)).bool()
        # the diagonal (agent about itself) is set to 1 for every target
        ToM_goals[:,diag_idx] += 1
        ToM_goals = ToM_goals.unsqueeze(-1) #[b, n, n, m, 1]
        node_features = att_features.unsqueeze(2).repeat(1,1,num_agents,1,1).detach()
        node_features = torch.sum(node_features * ToM_goals, 3) # sum all targets
        node_features = torch.cat((node_features, ToM_output.detach()),-1).reshape(batch_size * num_agents, num_agents, -1)

        #edge_logits, comm_edges = self.graph_infer(global_feature.reshape(batch_size,num_agents,-1))
        edge_logits, comm_edges = self.graph_infer(node_features)
        # one graph is inferred per (batch, agent); keep only the rows selected by diag_idx
        edge_logits = edge_logits.reshape(batch_size, num_agents, num_agents, num_agents, -1)[:,diag_idx] #[b, n, n, 2]
        comm_edges = comm_edges.reshape(batch_size, num_agents, num_agents, num_agents, -1)[:,diag_idx]  # [b, n, n, 1]
        comm_edges = comm_edges * comm_domain # only communicate with agents in self domain
        comm_domain_reshape = comm_domain.reshape(-1,1).repeat(1,2) #[b*n*n, 2]

        # only for ablation test
        #comm_edges = torch.zeros(batch_size, num_agents, num_agents, 1).to(self.device)

        # comm_domain_reshape is used as an index here (presumably a boolean mask -- TODO confirm dtype)
        edge_logits = edge_logits.reshape(-1,2)[comm_domain_reshape]
        edge_logits = edge_logits.reshape(-1,2) #[k,2] only logits of those edges in comm domain can be saved for training
        # after this line edge_logits holds softmax probabilities, despite its name
        edge_logits = F.softmax(edge_logits, dim=-1)

        comm_target_edges = comm_edges.reshape(batch_size, num_agents, num_agents, 1).repeat(1, 1, 1, num_targets)
        # although we use real target cover here, but it is only self real cover duplicate, so it still keeps the decentralized execution mode
        comm_target_edges = comm_target_edges * (real_target_cover.reshape(batch_size, num_agents, 1, num_targets).repeat(1, 1, num_agents, 1)) 
        # number of active edges per (agent, target), summed over axis 1
        comm_cnt = torch.sum(comm_target_edges,1).reshape(batch_size, num_agents, num_targets, 1) 

        # selected msg
        comm_msg = other_goals[:,:,:,:,1].detach()     # [batch, n, n-1, m]  comm_edge:[n, n, 1]
        # transfer [n,n-1,m] to [n,n,m], where the added part is the diagonal
        # e.g. [[1,2],[3,4],[5,6]] -> [[0,1,2],[3,0,4],[5,6,0]] 
        tri_u = torch.cat((zeros, comm_msg * mask_u), 2)
        tri_l = torch.cat((comm_msg * mask_l, zeros), 2)
        comm_msg = tri_u + tri_l    # [batch, n, n, m] 
        # evaluation
        #comm_edges = 1 - comm_edges
        #comm_cnt = self.num_agents - 1 - comm_cnt
        #comm_edges = torch.zeros(self.num_agents,self.num_agents,1)
        #comm_cnt  = torch.zeros(self.num_agents,self.num_targets,1)
        #end of evaluation
        msgs = comm_msg * comm_target_edges
        msgs = torch.sum(msgs, 1).reshape(batch_size, num_agents, num_targets, 1)
        # per (agent, target): [summed message, number of contributing edges]
        msgs = torch.cat((msgs,comm_cnt),-1)

        ######### end of GNN based communication ###############

        # decide self goals
        # max over the "other agents" axis of the predicted goal value (index 1)
        max_prob , _ = torch.max(other_goals[:,:,:,:,1],2) 
        max_prob = max_prob.reshape(batch_size, num_agents, num_targets, 1).detach() #[batch,n,m,1]
        #GRU_outputs = GRU_outputs.reshape(batch_size, num_agents, 1, -1).expand(batch_size, num_agents, num_targets, self.GRU.feature_dim)
        ToM_output_self = ToM_output_self.reshape(batch_size, num_agents, 1, -1).repeat(1, 1, num_targets, 1).detach()
        # new version of actor feature, reduce self feature dim
        self_feature = torch.cat((att_features, ToM_output_self), -1) 
        self_feature = self.reduce_dim(self_feature)
        ToM_msgs = torch.cat((max_prob, msgs),-1)
        #ToM_msgs = self.raise_dim(ToM_msgs)
        # actor_feature = actor_feature + ToM_msgs
        actor_feature = torch.cat((self_feature,ToM_msgs),-1)

        #actor_feature=torch.cat((GRU_outputs, att_features, max_prob, msgs), -1) #[batch,n,m,dim]
        actor_dim = actor_feature.size()[-1]
        # critic input: actor features summed over the target axis
        critic_feature = torch.sum(actor_feature,2)#.reshape(batch_size, 1, -1).repeat(1, num_agents, 1) #expand(num_agents, num_agents*actor_dim) #[b,n,dim]
        actor_feature = actor_feature.reshape(-1,actor_dim) #[batch*n*m,dim]

        if available_actions is not None:
            available_actions = available_actions.reshape(batch_size * num_agents * num_targets, -1)

        actions, entropies, log_probs, probs = self.actor(actor_feature, test, available_actions)

        
        if self.reduce_comm:
            # Re-run the actor with all message inputs zeroed and compare the
            # arg-max decisions with the ones obtained above.
            zero_msgs = torch.zeros(batch_size, num_agents, num_targets, 3).to(self.device)
            zero_actor_feature = torch.cat((self_feature, zero_msgs),-1).reshape(-1,actor_dim)
            _,_,_,zero_probs = self.actor(zero_actor_feature, test, available_actions)
            ori_probs = probs.reshape(batch_size, num_agents, num_targets, 2)
            zero_probs = zero_probs.reshape(batch_size, num_agents, num_targets,2)
            a = ori_probs.max(-1)[1]
            b = zero_probs.max(-1)[1]
            # XOR of the arg-max indices is 1 where the decision differs; summed over targets
            real_edges = torch.sum(a ^ b,-1)

            # 1.0 for agents whose decision changed on at least one target, else 0.0
            real_edges = 1 - (real_edges == 0).float()
            real_edges = real_edges.unsqueeze(1).repeat(1, num_agents, 1)
            new_edges = real_edges.reshape(batch_size, num_agents, num_agents, 1) * comm_domain
            
            # compute new msgs
            new_target_edges = new_edges.repeat(1, 1, 1, num_targets)
            new_target_edges = new_target_edges * (real_target_cover.reshape(batch_size, num_agents, 1, num_targets).repeat(1, 1, num_agents, 1)) 
            new_cnt = torch.sum(new_target_edges,1).reshape(batch_size, num_agents, num_targets, 1)
            new_msgs = comm_msg * new_target_edges
            new_msgs = torch.sum(new_msgs, 1).reshape(batch_size, num_agents, num_targets, 1)
            new_msgs = torch.cat((new_msgs,new_cnt),-1)
            new_ToM_msgs = torch.cat((max_prob, new_msgs),-1)
            new_actor_feature = torch.cat((self_feature,new_ToM_msgs),-1)
            new_actor_feature = new_actor_feature.reshape(-1,actor_dim)

            # the actor outputs from the first pass are replaced by this third pass
            actions, entropies, log_probs, probs = self.actor(new_actor_feature, test, available_actions)
            #print(actions)
            #print(np.abs(new_actions-actions).sum())
            #print("-------------")
        
        probs = probs.reshape(batch_size, num_agents, num_targets, -1)
        actions = actions.reshape(batch_size, num_agents, num_targets, -1)
        entropies = entropies.reshape(batch_size, num_agents, num_targets, -1)
        log_probs = log_probs.reshape(batch_size,num_agents, num_targets, -1)

        if 'shap' not in self.head_name:
            #print(global_feature.size(),feature.size())
            _, global_critic_feature = self.critic_encoder(critic_feature)
            # one value per batch item, repeated for every agent
            values = self.critic(global_critic_feature).unsqueeze(1).repeat(1, num_agents, 1)
        else:
            values = self.ShapleyVcritic(critic_feature, actions)  # shape [1,1]

        if not batch:
            # squeeze all the tensor for env
            values = values.squeeze(0)
            actions = actions.squeeze(0)
            entropies = entropies.squeeze(0)
            log_probs = log_probs.squeeze(0)
            hn_self = hn_self.squeeze(0)
            hn_ToM = hn_ToM.squeeze(0)
            other_goals = other_goals.squeeze(0)
            #edge_logits = edge_logits.squeeze(0)
            comm_edges = comm_edges.squeeze(0)
            probs = probs.squeeze(0)
            real_cover =real_cover.squeeze(0)
            ToM_target_cover = ToM_target_cover.squeeze(0)

        return values, actions, entropies, log_probs, hn_self, hn_ToM, other_goals, edge_logits, comm_edges, probs, real_cover, ToM_target_cover

    def sample_noise(self):
        """Ask the actor to sample noise (the call is issued twice)."""
        # NOTE(review): both calls target self.actor; the second may have been
        # intended for another sub-module -- confirm.
        self.actor.sample_noise()
        self.actor.sample_noise()

    def remove_noise(self):
        """Ask the actor to remove noise (the call is issued twice)."""
        # NOTE(review): both calls target self.actor; the second may have been
        # intended for another sub-module -- confirm.
        self.actor.remove_noise()
        self.actor.remove_noise()
'''
class A3C_baseline_decentralized(torch.nn.Module):
    """Baseline actor-critic: encoder + attention per agent, no communication.

    The actor emits one binary decision per (agent, target) pair. The critic is
    either an attention-pooled value head or, when the model name contains
    'shap', an AMCValueNet.
    """

    def __init__(self, obs_space, action_spaces, args, device=torch.device('cpu')):
        """Create encoder, attention, actor and critic.

        ``obs_space.shape`` must unpack to (num_agents, num_targets, pose_dim);
        ``args`` supplies ``lstm_out`` and ``model``. ``action_spaces`` is not
        referenced here.
        """
        super(A3C_baseline_decentralized, self).__init__()
        self.num_agents, self.num_targets, self.pose_dim = obs_space.shape

        hidden_size = args.lstm_out
        self.head_name = args.model

        # entity encoder, then attention across the entities seen by one agent
        self.encoder = EncodeLinear(self.pose_dim, hidden_size, self.head_name, device)
        self.attention = AttentionLayer(self.encoder.feature_dim, hidden_size * 2, device)
        actor_dim = self.attention.feature_dim

        # create actor & critic
        self.actor = PolicyNet(actor_dim, spaces.Discrete(2), self.head_name, device)

        if 'shap' not in self.head_name:
            critic_dim = 6 * actor_dim
            self.critic_encoder = AttentionLayer(actor_dim, critic_dim, device)
            self.critic = ValueNet(critic_dim, self.head_name, 1)
        else:
            self.ShapleyVcritic = AMCValueNet(actor_dim, self.head_name, 1, device)

        self.train()
        self.device = device

    def forward(self, inputs, test=False, available_actions = None):
        """Compute values and per-(agent, target) actions.

        ``inputs`` is [batch, agents, entities, dim]; a tensor that is not 4-D
        is treated as one un-batched sample, and the leading axis is squeezed
        from every output again. Only the first ``num_targets`` entities reach
        the actor. Returns (values, actions, entropies, log_probs, probs).
        """
        batch = inputs.dim() == 4
        if not batch:
            inputs = inputs.unsqueeze(0)

        batch_size, num_agents, num_both, _ = (inputs.size(axis) for axis in range(4))
        n_agents, n_targets = self.num_agents, self.num_targets

        encoded = self.encoder(Variable(inputs, requires_grad=True))  # batch, num_agent, num_both, feature_dim
        encoded = encoded.reshape(batch_size * n_agents, num_both, -1)
        feature, global_feature = self.attention(encoded)  # batch * num_agents, num_both, feature_dim

        # drop the trailing (obstacle) entities, keep the targets
        feature = feature[:, :n_targets]
        global_feature = global_feature.reshape(batch_size, num_agents, -1)  # batch, num_agents, feature_dim

        flat_rows = batch_size * n_agents * n_targets
        actor_feature = feature.reshape(flat_rows, -1)
        if available_actions is not None:
            available_actions = available_actions.reshape(flat_rows, -1)

        actor_out = self.actor(actor_feature, test, available_actions)
        out_shape = (batch_size, n_agents, n_targets, -1)
        actions, entropies, log_probs, probs = (item.reshape(out_shape) for item in actor_out)

        if 'shap' in self.head_name:
            values = self.ShapleyVcritic(feature, actions)  # shape [1,1]
        else:
            _, pooled = self.critic_encoder(global_feature)
            # one value per batch item, copied to every agent
            values = self.critic(pooled).reshape(batch_size, 1, -1).repeat(1, n_agents, 1)

        outputs = (values, actions, entropies, log_probs, probs)
        if batch:
            return outputs
        # squeeze all the tensors for the env
        return tuple(item.squeeze(0) for item in outputs)

    def sample_noise(self):
        """Ask the actor to sample noise (issued twice, back to back)."""
        for _ in range(2):
            self.actor.sample_noise()

    def remove_noise(self):
        """Ask the actor to remove noise (issued twice, back to back)."""
        for _ in range(2):
            self.actor.remove_noise()


class A3C_baseline_centralized(torch.nn.Module):
    """Baseline actor-critic that attends over all agents' entities jointly.

    Each observed entity is concatenated with the observing agent's camera
    state, encoded, and then every entity of every agent is flattened into one
    sequence for a single attention pass. The actor emits one binary decision
    per (agent, target) pair.
    """

    def __init__(self, obs_space, action_spaces, args, device=torch.device('cpu')):
        """Create encoder, attention, actor and critic.

        Args:
            obs_space: object whose ``shape`` unpacks to
                (num_agents, num_targets, obs_dim).
            action_spaces: not referenced in this constructor.
            args: namespace providing ``lstm_out`` and ``model``.
            device: torch device handed to the sub-modules.
        """
        super(A3C_baseline_centralized, self).__init__()
        self.num_agents, self.num_targets, self.obs_dim = obs_space.shape
        self.cam_state_size = 3
        self.pose_dim = self.obs_dim + self.cam_state_size   # feature include both obs and camera states

        lstm_out = args.lstm_out
        head_name = args.model
        self.head_name = head_name

        self.encoder = EncodeLinear(self.pose_dim, lstm_out, head_name, device)
        feature_dim = self.encoder.feature_dim

        self.attention = AttentionLayer(feature_dim, lstm_out, device)
        feature_dim = self.attention.feature_dim

        # create actor & critic
        self.actor = PolicyNet(feature_dim, spaces.Discrete(2), head_name, device)

        if 'shap' in head_name:
            self.ShapleyVcritic = AMCValueNet(feature_dim, head_name, 1, device)
        else:
            self.critic = ValueNet(feature_dim, head_name, 1)

        self.train()
        self.device = device

    def forward(self, inputs, cameras, available_actions = None, test=False):
        """Compute values and per-(agent, target) actions.

        Args:
            inputs: [batch, agents, entities, obs_dim]; a tensor that is not
                4-D is treated as one un-batched sample.
            cameras: camera states, reshaped to [batch, num_agents, 1, -1] and
                appended to every entity observed by that agent.
            available_actions: optional action mask. When given it is reshaped
                to [batch * num_agents * num_targets, -1] and passed to the
                actor as its third positional argument, the same way
                A3C_baseline_decentralized does. Previously this argument was
                accepted but silently dropped.
            test: forwarded to the actor.

        Returns:
            (values, actions, entropies, log_probs, probs); for un-batched
            input the leading axis of each is squeezed.
        """
        if len(inputs.size()) != 4:
            batch = False
            inputs = inputs.unsqueeze(0)
        else:
            batch = True

        batch_size = inputs.size()[0]
        num_both = inputs.size()[2]

        feature_target = Variable(inputs, requires_grad=True)

        # broadcast each agent's camera state to all entities it observes
        cam_states = cameras.reshape(batch_size, self.num_agents, 1, -1).repeat(1, 1, num_both, 1)
        feature_target = torch.cat((feature_target, cam_states), -1)

        feature_target = self.encoder(feature_target)  # batch, num_agent, num_both, feature_dim

        # all entities of all agents (and batch items) form ONE attention sequence
        feature_target = feature_target.reshape(-1, self.encoder.feature_dim).unsqueeze(0)  # [1, batch*agent*both, feature_dim]
        feature, global_feature = self.attention(feature_target)

        # for obstacles: keep only the first num_targets entities per agent
        feature = feature.reshape(batch_size, self.num_agents, num_both, -1)
        feature = feature[:, :, :self.num_targets]
        flat_rows = batch_size * self.num_agents * self.num_targets
        feature = feature.reshape(flat_rows, -1)

        if available_actions is not None:
            # one mask row per (batch, agent, target), aligned with `feature`
            available_actions = available_actions.reshape(flat_rows, -1)
            actions, entropies, log_probs, probs = self.actor(feature, test, available_actions)
        else:
            # unchanged call form for callers that never supplied a mask
            actions, entropies, log_probs, probs = self.actor(feature, test)

        actions = actions.reshape(batch_size, self.num_agents, self.num_targets, -1)
        entropies = entropies.reshape(batch_size, self.num_agents, self.num_targets, -1)
        log_probs = log_probs.reshape(batch_size, self.num_agents, self.num_targets, -1)
        probs = probs.reshape(batch_size, self.num_agents, self.num_targets, -1)

        if 'shap' not in self.head_name:
            values = self.critic(global_feature)
        else:
            values = self.ShapleyVcritic(feature, actions)  # shape [1,1]

        if not batch:
            # squeeze all the tensor for env
            values = values.squeeze(0)
            actions = actions.squeeze(0)
            entropies = entropies.squeeze(0)
            log_probs = log_probs.squeeze(0)
            probs = probs.squeeze(0)

        return values, actions, entropies, log_probs, probs

    def sample_noise(self):
        """Ask the actor to sample noise (the call is issued twice)."""
        # NOTE(review): both calls target self.actor; the second may have been
        # intended for another sub-module -- confirm before removing it.
        self.actor.sample_noise()
        self.actor.sample_noise()

    def remove_noise(self):
        """Ask the actor to remove noise (the call is issued twice)."""
        # NOTE(review): both calls target self.actor; the second may have been
        # intended for another sub-module -- confirm before removing it.
        self.actor.remove_noise()
        self.actor.remove_noise()