import torch as th
import torch.nn as nn
import torch.nn.functional as F


class IBFComm(nn.Module):
	"""Communication module that emits Gaussian message parameters per agent.

	``forward`` currently returns a hand-written "oracle" message built
	directly from the input features rather than from the learned layers.
	The learned encoder (``fc1`` -> ``fc2`` -> ``fc3``) is kept available
	through ``_learned_message``. ``inference_model`` is constructed here
	but is not called anywhere in this class.

	Args:
		input_shape: Size of the last dimension of the tensors passed to
			``forward`` (the ``in_features`` of ``fc1``).
		args: Configuration object providing ``n_agents``,
			``rnn_hidden_dim``, ``comm_embed_dim`` and ``n_actions``.
	"""

	# Size of the last dimension of the oracle message returned by forward().
	_ORACLE_MSG_DIM = 3

	def __init__(self, input_shape, args):
		super().__init__()
		self.args = args
		self.n_agents = args.n_agents

		msg_dim = args.comm_embed_dim * self.n_agents
		hidden_dim = 4 * msg_dim

		self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
		self.fc2 = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim)
		self.fc3 = nn.Linear(args.rnn_hidden_dim, msg_dim)

		self.inference_model = nn.Sequential(
			nn.Linear(input_shape + msg_dim, hidden_dim),
			nn.ReLU(True),
			nn.Linear(hidden_dim, hidden_dim),
			nn.ReLU(True),
			nn.Linear(hidden_dim, args.n_actions)
		)

	def forward(self, inputs):
		"""Return the oracle message ``(mu, sigma)`` for a batch of inputs.

		For every row whose third feature (``inputs[i, 2]``) is
		non-negative, that feature is truncated to an integer and used as
		an agent index; the row's first two features are copied into
		``mu[i, agent, 0:2]``. Every other entry of ``mu`` is zero.

		Args:
			inputs: Tensor indexed as ``inputs[row, feature]`` with at
				least three features. NOTE(review): presumably shaped
				``(batch, input_shape)`` -- confirm against the caller.

		Returns:
			Tuple ``(mu, sigma)``; both have shape
			``(inputs.shape[0], n_agents, 3)`` and live on
			``inputs.device``. ``mu`` carries no gradient and ``sigma``
			is all zeros.
		"""
		batch_size = inputs.shape[0]
		# Allocate on the caller's device instead of forcing CUDA, so the
		# module also works on CPU-only hosts and on non-default GPUs.
		mu = th.zeros(
			(batch_size, self.n_agents, self._ORACLE_MSG_DIM),
			device=inputs.device,
		)

		# The oracle output is detached, so read the features without grad.
		features = inputs.detach()
		agent_col = features[:, 2]
		rows = (agent_col >= 0).nonzero(as_tuple=True)[0]
		# .long() truncates toward zero, matching int() on a scalar tensor.
		agents = agent_col[rows].long()
		mu[rows, agents, 0:2] = features[rows, 0:2].to(mu.dtype)

		sigma = th.zeros(mu.shape, device=mu.device)
		return mu, sigma

	def _learned_message(self, inputs):
		"""Compute ``(mu, sigma)`` from the learned encoder layers.

		This is the path that previously sat unreachable after the oracle
		``return`` in ``forward``; it is not called by ``forward``.

		Returns:
			Tuple ``(mu, sigma)`` where ``mu`` is the raw output of
			``fc3`` (last dimension ``comm_embed_dim * n_agents``) and
			``sigma`` is a zero tensor of the same shape on the same
			device.
		"""
		hidden = F.relu(self.fc1(inputs))
		hidden = F.relu(self.fc2(hidden))
		mu = self.fc3(hidden)
		sigma = th.zeros(mu.shape, device=mu.device)
		return mu, sigma
