import torch as th
import torch.nn as nn
import torch.nn.functional as F


class IBFCommS(nn.Module):
	"""Message encoder with per-cluster, block-diagonal "simi" transforms.

	The module holds two parallel encoder stacks (``fc1..fc3`` and
	``fc1_new..fc3_new``) that map an input vector to a message mean of
	width ``n_agents * comm_embed_dim``, plus two parallel sets of
	``n_cluster`` two-layer transforms applied to that mean.  Every transform
	weight is multiplied by ``simi_mask`` before use, so only the
	``comm_embed_dim x comm_embed_dim`` blocks on the diagonal contribute.

	``inference_model`` is constructed here but is not used by ``forward``.

	Notes:
		* The transform weights are stored in ``nn.ParameterList`` objects and
		  the mask is a registered buffer, so they appear in ``parameters()``
		  / ``state_dict()`` and follow ``.to()`` / ``.cuda()`` together with
		  the rest of the module.  Nothing is pinned to CUDA in the
		  constructor; move the module to the desired device after creating it.
		* Because of the above, ``state_dict()`` contains ``simi_*`` entries.
		  Checkpoints written without them need ``strict=False`` when loaded.

	Args:
		input_shape: Size of the last dimension of the tensors passed to
			``forward``.
		args: Configuration object providing ``n_agents``, ``rnn_hidden_dim``,
			``comm_embed_dim`` and ``n_actions``.
	"""

	def __init__(self, input_shape, args):
		super().__init__()
		self.args = args
		self.n_agents = args.n_agents
		self.n_cluster = 3

		hidden_dim = args.rnn_hidden_dim
		embed_dim = args.comm_embed_dim
		# Width of the concatenated message: one embed_dim slice per agent.
		msg_dim = embed_dim * self.n_agents

		# Encoder used while the module has not been switched.
		self.fc1 = nn.Linear(input_shape, hidden_dim)
		self.fc2 = nn.Linear(hidden_dim, hidden_dim)
		self.fc3 = nn.Linear(hidden_dim, msg_dim)

		# Encoder used once switch_mode() has been called.
		self.fc1_new = nn.Linear(input_shape, hidden_dim)
		self.fc2_new = nn.Linear(hidden_dim, hidden_dim)
		self.fc3_new = nn.Linear(hidden_dim, msg_dim)

		self.inference_model = nn.Sequential(
			nn.Linear(input_shape + msg_dim, 4 * msg_dim),
			nn.ReLU(True),
			nn.Linear(4 * msg_dim, 4 * msg_dim),
			nn.ReLU(True),
			nn.Linear(4 * msg_dim, args.n_actions)
		)

		# Cluster transforms paired with the original encoder.  Creation order
		# (w1, b1, w2, b2) is kept so that seeded initialisation is stable.
		self.simi_w1 = self._uniform_params(msg_dim, msg_dim)
		self.simi_b1 = self._uniform_params(msg_dim)
		self.simi_w2 = self._uniform_params(msg_dim, msg_dim)
		self.simi_b2 = self._uniform_params(msg_dim)

		# Cluster transforms paired with the "_new" encoder.
		self.simi_w1_new = self._uniform_params(msg_dim, msg_dim)
		self.simi_b1_new = self._uniform_params(msg_dim)
		self.simi_w2_new = self._uniform_params(msg_dim, msg_dim)
		self.simi_b2_new = self._uniform_params(msg_dim)

		# Block-diagonal 0/1 mask: ones inside each agent's
		# embed_dim x embed_dim block, zeros everywhere else.
		mask = th.zeros(msg_dim, msg_dim)
		for agent in range(self.n_agents):
			lo, hi = embed_dim * agent, embed_dim * (agent + 1)
			mask[lo:hi, lo:hi] = 1
		# A buffer (not a parameter): moves with the module, never optimised.
		self.register_buffer("simi_mask", mask)

		# When True, forward() uses the "_new" encoder and appends the "_new"
		# cluster transforms to its output.
		self.switched = False

	def _uniform_params(self, *shape):
		"""Return ``n_cluster`` registered parameters drawn uniformly from [-1, 1)."""
		return nn.ParameterList(
			[nn.Parameter(th.rand(*shape) * 2 - 1) for _ in range(self.n_cluster)]
		)

	@staticmethod
	def _encode(inputs, fc1, fc2, fc3):
		"""Run ``inputs`` through a ReLU-ReLU-linear three-layer stack."""
		hidden = F.relu(fc1(inputs))
		hidden = F.relu(fc2(hidden))
		return fc3(hidden)

	def _cluster_transform(self, mu, w1, b1, w2, b2):
		"""Apply every cluster's masked two-layer transform to ``mu``.

		Returns a list with one tensor per cluster, each shaped like ``mu``.
		"""
		mask = self.simi_mask
		return [
			F.relu(mu @ (w1_k * mask) + b1_k) @ (w2_k * mask) + b2_k
			for w1_k, b1_k, w2_k, b2_k in zip(w1, b1, w2, b2)
		]

	def switch_mode(self):
		"""Switch ``forward`` to the "_new" encoder; there is no way back."""
		self.switched = True

	def forward(self, inputs):
		"""Encode ``inputs`` and compute the per-cluster transformed means.

		Args:
			inputs: Tensor whose last dimension equals the ``input_shape``
				given to the constructor.

		Returns:
			A tuple ``(mu, sigma, mu_trans)``:

			* ``mu`` - output of the active encoder (original before
			  ``switch_mode()``, "_new" afterwards).
			* ``sigma`` - all-zero tensor with the shape, dtype and device of
			  ``mu``.
			* ``mu_trans`` - Python list of tensors shaped like ``mu``.
			  Before switching it holds ``n_cluster`` entries computed from the
			  original encoder.  After switching, the ``n_cluster`` entries
			  computed from the "_new" encoder are appended (list
			  concatenation, not element-wise addition), giving
			  ``2 * n_cluster`` entries.
		"""
		mu = self._encode(inputs, self.fc1, self.fc2, self.fc3)
		# sigma is fixed at zero; zeros_like keeps it on mu's device and dtype.
		sigma = th.zeros_like(mu)

		mu_trans = self._cluster_transform(
			mu, self.simi_w1, self.simi_b1, self.simi_w2, self.simi_b2
		)

		if self.switched:
			# The returned mu comes from the "_new" encoder, while the first
			# n_cluster entries of mu_trans still come from the original one.
			mu = self._encode(inputs, self.fc1_new, self.fc2_new, self.fc3_new)
			mu_trans = mu_trans + self._cluster_transform(
				mu, self.simi_w1_new, self.simi_b1_new, self.simi_w2_new, self.simi_b2_new
			)

		return mu, sigma, mu_trans
