import torch as th
import torch.nn as nn
import torch.nn.functional as F

import random

class IBFCommS(nn.Module):
	"""Message encoder producing a mean, a zero scale and per-cluster projections.

	``forward`` maps inputs through a 3-layer MLP (fc1 -> fc2 -> fc3) to
	``comm_embed_dim`` features for each of ``n_agents`` slots, then passes
	each ``comm_embed_dim``-sized slice through ``n_cluster`` two-layer
	projection heads built from the ``simi_*`` parameters.

	Args:
		input_shape: size of the last dimension of ``inputs`` (in_features of fc1).
		args: config object; this class reads ``n_agents``, ``comm_embed_dim``,
			``rnn_hidden_dim`` and ``n_actions`` from it.
	"""

	def __init__(self, input_shape, args):
		super(IBFCommS, self).__init__()
		self.args = args
		self.n_agents = args.n_agents
		self.n_cluster = 3
		self.comm_embed_dim = args.comm_embed_dim

		# Primary encoder: input -> hidden -> comm_embed_dim features per agent.
		self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
		self.fc2 = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim)
		self.fc3 = nn.Linear(args.rnn_hidden_dim, args.comm_embed_dim * self.n_agents)

		# Alternate encoder with identical shapes. NOTE(review): forward() does
		# not use it; kept so checkpoints and external callers still find it.
		self.fc1_new = nn.Linear(input_shape, args.rnn_hidden_dim)
		self.fc2_new = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim)
		self.fc3_new = nn.Linear(args.rnn_hidden_dim, args.comm_embed_dim * self.n_agents)

		# Maps [input, n_agents*comm_embed_dim message] features to n_actions
		# outputs. NOTE(review): not called in this file; presumably used by the
		# learner — confirm.
		self.inference_model = nn.Sequential(
			nn.Linear(input_shape + args.comm_embed_dim * self.n_agents, 4 * args.comm_embed_dim * self.n_agents),
			nn.ReLU(True),
			nn.Linear(4 * args.comm_embed_dim * self.n_agents, 4 * args.comm_embed_dim * self.n_agents),
			nn.ReLU(True),
			nn.Linear(4 * args.comm_embed_dim * self.n_agents, args.n_actions)
		)

		# Cluster projection heads, initialised uniformly in [-1, 1).
		# nn.ParameterList registers them with the module so they show up in
		# parameters()/state_dict() and follow .to(device). The previous plain
		# lists of `.cuda()` copies were invisible to the optimizer, were not
		# checkpointed, and required a GPU at construction time.
		# Creation order (w1, b1, w2, b2) matches the original so seeded RNG
		# draws are unchanged.
		embed = args.comm_embed_dim
		self.simi_w1 = nn.ParameterList(
			[nn.Parameter(th.rand(embed, embed) * 2 - 1) for _ in range(self.n_cluster)])
		self.simi_b1 = nn.ParameterList(
			[nn.Parameter(th.rand(embed) * 2 - 1) for _ in range(self.n_cluster)])
		self.simi_w2 = nn.ParameterList(
			[nn.Parameter(th.rand(embed, embed) * 2 - 1) for _ in range(self.n_cluster)])
		self.simi_b2 = nn.ParameterList(
			[nn.Parameter(th.rand(embed) * 2 - 1) for _ in range(self.n_cluster)])

		# NOTE(review): not used in this file.
		self.q_layer = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim)

		# Set by switch_mode(); forward() does not currently consult it.
		self.switched = False

		# Cluster-index override slot; forward() does not currently read it.
		self.sov4db1 = None

	def switch_mode(self):
		"""Set the ``switched`` flag (not consulted by forward() in this file)."""
		self.switched = True

	def forward(self, inputs):
		"""Encode ``inputs`` into message statistics.

		Args:
			inputs: tensor whose last dimension is ``input_shape``
				(presumably (batch, input_shape) — confirm with caller).

		Returns:
			mu: fc3 output; last dimension is ``n_agents * comm_embed_dim``.
			sigma: zeros with mu's shape, dtype and device.
			mu_trans: list of ``n_cluster`` tensors, one per projection head,
				each reshaped to (-1, n_agents * comm_embed_dim).
		"""
		hidden = F.relu(self.fc1(inputs))
		hidden = F.relu(self.fc2(hidden))
		mu = self.fc3(hidden)

		# Constant zero scale, allocated on mu's device (no hard-coded .cuda()).
		sigma = th.zeros_like(mu)

		# Apply each cluster head to every comm_embed_dim-sized slice of mu:
		# relu(x @ w1 + b1) @ w2 + b2, then regroup slices per row.
		per_agent = mu.view(-1, self.comm_embed_dim)
		mu_trans = []
		for w1, b1, w2, b2 in zip(self.simi_w1, self.simi_b1, self.simi_w2, self.simi_b2):
			projected = F.relu(per_agent @ w1 + b1) @ w2 + b2
			mu_trans.append(projected.view(-1, self.n_agents * self.comm_embed_dim))

		return mu, sigma, mu_trans
