import torch as th
import torch.nn as nn
import torch.nn.functional as F


class IdComm(nn.Module):
	"""Identity-style communication module.

	``forward`` does not learn a message encoder. It takes the first
	``args.comm_embed_dim`` features of each input row as the message mean,
	tiles it once per agent, and pairs it with an all-zero sigma.

	``inference_model`` is built here but not called by ``forward``;
	presumably it is used by a caller elsewhere (e.g. a learner) — TODO confirm.
	"""

	def __init__(self, input_shape, args):
		"""
		Args:
			input_shape: number of per-agent input features fed to
				``inference_model`` (before the concatenated messages).
			args: config object; this class reads ``n_agents``,
				``comm_embed_dim`` and ``n_actions`` from it.
		"""
		super().__init__()
		self.args = args
		self.n_agents = args.n_agents

		# Width of the concatenated messages from all agents, and the hidden width.
		msg_dim = args.comm_embed_dim * self.n_agents
		hidden = 4 * msg_dim

		# Layer order is kept identical to the original so state_dict keys
		# (inference_model.0/2/4.*) stay compatible with existing checkpoints.
		self.inference_model = nn.Sequential(
			nn.Linear(input_shape + msg_dim, hidden),
			nn.ReLU(True),
			nn.Linear(hidden, hidden),
			nn.ReLU(True),
			nn.Linear(hidden, args.n_actions)
		)

	def forward(self, inputs):
		"""Return ``(mu, sigma)`` for the identity message.

		Args:
			inputs: tensor indexed as ``inputs[:, :comm_embed_dim]``;
				assumed to be 2-D ``(batch, features)`` — TODO confirm with callers.

		Returns:
			mu: shape ``(-1, comm_embed_dim * n_agents)``; each row is the
				leading ``comm_embed_dim`` features repeated ``n_agents`` times.
			sigma: zeros with the same shape, device and dtype as ``mu``.
		"""
		embed_dim = self.args.comm_embed_dim
		mu = (
			inputs[:, :embed_dim]
			.reshape(-1, 1, embed_dim)
			.repeat(1, self.n_agents, 1)
			.reshape(-1, embed_dim * self.n_agents)
		)
		# zeros_like follows mu's device/dtype; the previous `.cuda()` call
		# crashed on CPU and ignored mu's actual device and dtype.
		sigma = th.zeros_like(mu)

		return mu, sigma
