import torch as th
import torch.nn as nn
import torch.nn.functional as F

import random

class IBFCommSV(nn.Module):
	"""Message encoder producing communication embeddings plus per-cluster transforms.

	From the input features it computes a message vector ``mu`` of size
	``n_agents * comm_embed_dim`` and ``n_cluster`` learned two-layer
	transforms of it. ``inference_model`` is built here but is not called by
	:meth:`forward`; presumably external code uses it (TODO confirm).

	Args:
		input_shape: size of the last dimension of ``inputs``.
		args: config object; reads ``n_agents``, ``comm_embed_dim``,
			``rnn_hidden_dim`` and ``n_actions``.
	"""

	def __init__(self, input_shape, args):
		super(IBFCommSV, self).__init__()
		self.args = args
		self.n_agents = args.n_agents
		self.n_cluster = 3
		self.comm_embed_dim = args.comm_embed_dim

		msg_dim = args.comm_embed_dim * self.n_agents

		self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
		self.fc2 = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim)
		# Produces only the message means; the spread is fixed in forward().
		self.fc3 = nn.Linear(args.rnn_hidden_dim, msg_dim)

		# One raw (un-activated) score per cluster transform.
		self.fc_sigma = nn.Linear(args.rnn_hidden_dim, self.n_cluster)

		# Takes input_shape + n_agents*comm_embed_dim features and outputs
		# n_actions values. Presumably the inputs concatenated with the
		# messages -- confirm with the caller.
		self.inference_model = nn.Sequential(
			nn.Linear(input_shape + msg_dim, 4 * msg_dim),
			nn.ReLU(True),
			nn.Linear(4 * msg_dim, 4 * msg_dim),
			nn.ReLU(True),
			nn.Linear(4 * msg_dim, args.n_actions),
		)

		# Per-cluster two-layer transforms applied to each comm_embed_dim chunk
		# of the message, initialised uniformly in [-1, 1).
		# nn.ParameterList (not a plain list, and no .cuda() here) registers
		# the weights on the module. That way the optimiser sees them, they are
		# saved in state_dict, and they follow module.to()/.cuda().
		# Creation order (w1, b1, w2, b2) keeps RNG consumption unchanged.
		def _uniform(*shape):
			return nn.Parameter(th.rand(*shape) * 2 - 1)

		d = args.comm_embed_dim
		self.simi_w1 = nn.ParameterList([_uniform(d, d) for _ in range(self.n_cluster)])
		self.simi_b1 = nn.ParameterList([_uniform(d) for _ in range(self.n_cluster)])
		self.simi_w2 = nn.ParameterList([_uniform(d, d) for _ in range(self.n_cluster)])
		self.simi_b2 = nn.ParameterList([_uniform(d) for _ in range(self.n_cluster)])

	def forward(self, inputs):
		"""Encode ``inputs`` into messages and their per-cluster transforms.

		Args:
			inputs: tensor whose last dimension is ``input_shape``; assumed
				2-D ``(batch, input_shape)`` (TODO confirm with caller).

		Returns:
			mu: message tensor with last dim ``n_agents * comm_embed_dim``.
			sigma: zeros with the same shape, dtype and device as ``mu``
				(the spread is fixed at zero; no learned variance).
			mu_trans: list of ``n_cluster`` tensors, each reshaped to
				``(-1, n_agents * comm_embed_dim)``.
			sigma_trans: raw ``fc_sigma`` output, last dim ``n_cluster``.
		"""
		h = F.relu(self.fc1(inputs))
		h = F.relu(self.fc2(h))
		mu = self.fc3(h)
		sigma_trans = self.fc_sigma(h)

		# Built from mu so it lives on the module's device instead of being
		# hard-wired to CUDA.
		sigma = th.zeros_like(mu)

		# Each cluster's MLP is applied independently to every
		# comm_embed_dim-sized chunk of the message.
		chunks = mu.view(-1, self.comm_embed_dim)
		flat_dim = self.n_agents * self.comm_embed_dim
		mu_trans = []
		for w1, b1, w2, b2 in zip(self.simi_w1, self.simi_b1, self.simi_w2, self.simi_b2):
			hidden = F.relu(chunks @ w1 + b1)
			mu_trans.append((hidden @ w2 + b2).view(-1, flat_dim))

		return mu, sigma, mu_trans, sigma_trans
