import torch as th
import torch.nn as nn
import torch.nn.functional as F


class IBFComm(nn.Module):
	"""Recurrent message encoder used for inter-agent communication.

	Each row of ``inputs`` is encoded through fc1 -> GRUCell -> fc2 -> fc3
	into a message-mean vector of size ``comm_embed_dim * n_agents``.

	The module also owns ``inference_model`` (input size
	``input_shape + comm_embed_dim * n_agents``, output size ``n_actions``),
	which is NOT used by ``forward``. Presumably the owning controller applies
	it to inputs concatenated with messages; TODO confirm against the caller.

	Args:
		input_shape: Feature size of each input row.
		args: Config object; this class reads ``n_agents``, ``rnn_hidden_dim``,
			``comm_embed_dim`` and ``n_actions`` from it.
	"""

	def __init__(self, input_shape, args):
		super(IBFComm, self).__init__()
		self.args = args
		self.n_agents = args.n_agents

		self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
		self.fc2 = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim)
		# Flat message mean: n_agents chunks of comm_embed_dim each.
		self.fc3 = nn.Linear(args.rnn_hidden_dim, args.comm_embed_dim * self.n_agents)

		msg_dim = args.comm_embed_dim * self.n_agents
		self.inference_model = nn.Sequential(
			nn.Linear(input_shape + msg_dim, 4 * msg_dim),
			nn.ReLU(True),
			nn.Linear(4 * msg_dim, 4 * msg_dim),
			nn.ReLU(True),
			nn.Linear(4 * msg_dim, args.n_actions)
		)
		# Recurrent core between fc1 and fc2.
		self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim)

	def forward(self, inputs, hidden_state=None):
		"""Encode ``inputs`` into message parameters and advance the GRU.

		Args:
			inputs: Tensor assumed to be ``(batch, input_shape)``.
			hidden_state: Previous GRU state whose elements reshape to
				``(batch, rnn_hidden_dim)`` (any layout, e.g. split per agent),
				or ``None`` to start from zeros.

		Returns:
			Tuple ``(mu, sigma, h)``:
			``mu`` has shape ``(batch, comm_embed_dim * n_agents)``;
			``sigma`` is an all-zero tensor shaped like ``mu``;
			``h`` is the UPDATED GRU state, in the same shape as
			``hidden_state`` (or ``(batch, rnn_hidden_dim)`` if it was None).
		"""
		if hidden_state is None:
			# Match the inputs' device/dtype instead of hard-coding CUDA,
			# so the module also works on CPU or a non-default GPU.
			hidden_state = inputs.new_zeros((inputs.shape[0], self.args.rnn_hidden_dim))

		x = F.relu(self.fc1(inputs))
		h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim)
		h = self.rnn(x, h_in)
		x = F.relu(self.fc2(h))
		mu = self.fc3(x)

		# There is no learned variance head; sigma is a constant zero tensor
		# on mu's device/dtype.
		sigma = th.zeros_like(mu)

		# Return the new state (not the incoming one) so the recurrence
		# carries information across calls. Reshape back to the caller's
		# layout: GRUCell guarantees h has the same element count as h_in.
		return mu, sigma, h.reshape(hidden_state.shape)
