import torch as th
import torch.nn as nn
import torch.nn.functional as F


class OrcComm(nn.Module):
	"""Communication module that derives its message directly from the input.

	``forward`` does not use any learned parameters: it takes the first
	``args.comm_embed_dim`` columns of the input, gates them with the last
	of those columns, and tiles the result once per agent.  The
	``inference_model`` MLP is constructed here but is not called by
	``forward``; it is kept so that external callers and saved checkpoints
	that reference it continue to work.

	Args:
		input_shape: Number of input features fed to ``inference_model``
			(in addition to the tiled message).
		args: Configuration object providing ``n_agents``,
			``comm_embed_dim`` and ``n_actions``.
	"""

	def __init__(self, input_shape, args):
		super().__init__()
		self.args = args
		self.n_agents = args.n_agents

		message_dim = args.comm_embed_dim * self.n_agents
		hidden_dim = 4 * message_dim

		self.inference_model = nn.Sequential(
			nn.Linear(input_shape + message_dim, hidden_dim),
			nn.ReLU(True),
			nn.Linear(hidden_dim, hidden_dim),
			nn.ReLU(True),
			nn.Linear(hidden_dim, args.n_actions)
		)

	def forward(self, inputs):
		"""Build the (mu, sigma) message pair from ``inputs``.

		The last column of the embedding slice acts as a multiplicative
		gate for the other columns: where the gate is 1 they pass through
		unchanged, where it is 0 they are replaced by -1.  The gate column
		itself is kept as-is.
		NOTE(review): the gate is presumably expected to be exactly 0 or 1
		(an earlier commented-out assert checked this) -- confirm with the
		caller.

		Args:
			inputs: Tensor whose second dimension holds at least
				``args.comm_embed_dim`` features.  It is NOT modified.

		Returns:
			A tuple ``(mu, sigma)``.  ``mu`` has
			``comm_embed_dim * n_agents`` columns (the gated embedding
			repeated once per agent).  ``sigma`` is an all-zero tensor with
			the same shape, dtype and device as ``mu``.
		"""
		embed_dim = self.args.comm_embed_dim
		embedding = inputs[:, :embed_dim]

		# Compute the gated values out-of-place: the slice above is a view of
		# ``inputs``, so writing into it would silently overwrite the caller's
		# tensor (and fails under autograd for leaf tensors requiring grad).
		gate = embedding[:, -1:]
		gated = -1 * (1 - gate) + embedding[:, :-1] * gate
		mu = th.cat([gated, gate], dim=1)

		# Tile the message so every agent receives an identical copy.
		mu = mu.reshape(-1, 1, embed_dim).repeat(1, self.n_agents, 1)
		mu = mu.reshape(-1, embed_dim * self.n_agents)

		# Follow mu's device instead of hard-coding CUDA so the module also
		# runs on CPU-only machines and never mixes devices.
		sigma = th.zeros_like(mu)

		return mu, sigma
