import torch as th
import torch.nn as nn
import torch.nn.functional as F


class IBFComm(nn.Module):
	"""Encode a flat observation into per-agent message parameters ``(mu, sigma)``.

	The last ``3 * n_agents`` columns of the input are read as one 3-d feature
	vector per agent. These vectors are summarised over the agent axis with
	order-independent statistics (means, second moments, pairwise cross
	moments), projected to 3 features by ``fc_comm``, and concatenated with the
	remaining input columns before a 3-layer MLP produces ``mu``.
	"""

	# Width of each agent's feature vector packed at the end of ``inputs``.
	_COMM_DIM = 3

	def __init__(self, input_shape, args):
		super().__init__()
		self.args = args
		self.n_agents = args.n_agents

		d = self._COMM_DIM
		# d means + d second moments + d*(d-1)/2 pairwise cross moments
		# (9 for d=3), projected back down to d features.
		n_stats = 2 * d + d * (d - 1) // 2
		self.fc_comm = nn.Linear(n_stats, d)

		# Input to fc1: the non-comm input columns plus the d projected comm features.
		self.fc1 = nn.Linear(input_shape - d * self.n_agents + d, args.rnn_hidden_dim)
		self.fc2 = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim)
		self.fc3 = nn.Linear(args.rnn_hidden_dim, args.comm_embed_dim * self.n_agents)

		# Not used by forward() below; presumably called by the owning
		# learner on (inputs, messages) to predict actions -- verify at call site.
		self.inference_model = nn.Sequential(
			nn.Linear(input_shape + args.comm_embed_dim * self.n_agents, 4 * args.comm_embed_dim * self.n_agents),
			nn.ReLU(True),
			nn.Linear(4 * args.comm_embed_dim * self.n_agents, 4 * args.comm_embed_dim * self.n_agents),
			nn.ReLU(True),
			nn.Linear(4 * args.comm_embed_dim * self.n_agents, args.n_actions)
		)

	def forward(self, inputs):
		"""Compute message parameters for a batch of flat inputs.

		Args:
			inputs: 2-D tensor ``(batch, input_shape)``. Its last
				``3 * n_agents`` columns hold one 3-d vector per agent.

		Returns:
			Tuple ``(mu, sigma)``, each of shape
			``(batch, comm_embed_dim * n_agents)``. ``mu`` is the fc3 output.
			``sigma`` is a tensor of ones on ``mu``'s device and dtype.
		"""
		d = self._COMM_DIM
		split = inputs.shape[-1] - d * self.n_agents
		own = inputs[:, :split]
		comm = inputs[:, split:].reshape(-1, self.n_agents, d)

		# Statistics reduced over the agent axis, so they do not depend on agent order.
		# triu_indices yields the pairs (0,1), (0,2), (1,2) for d=3.
		rows, cols = th.triu_indices(d, d, offset=1, device=inputs.device)
		stats = th.cat([
			comm.mean(dim=1),
			(comm ** 2).mean(dim=1),
			(comm[:, :, rows] * comm[:, :, cols]).mean(dim=1),
		], dim=-1)
		comm_feat = self.fc_comm(stats)

		x = F.relu(self.fc1(th.cat([own, comm_feat], dim=-1)))
		x = F.relu(self.fc2(x))
		mu = self.fc3(x)

		# Fixed unit scale. ones_like keeps sigma on mu's device and dtype.
		# A hard-coded .cuda() fails on CPU and can pick the wrong GPU.
		sigma = th.ones_like(mu)
		return mu, sigma
