import torch as th
import torch.nn as nn
import torch.nn.functional as F

import random

class IBFCommSP(nn.Module):
	"""Message encoder that mixes several per-cluster transforms of one embedding.

	A three-layer MLP (``fc1`` -> ``fc2`` -> ``fc3``) maps the input to one
	``comm_embed_dim``-sized embedding per agent.  Each of the ``n_cluster``
	two-layer "similarity" transforms (``simi_w1/b1`` -> ReLU -> ``simi_w2/b2``)
	is applied to every embedding, the results are concatenated along the
	feature axis and projected back to ``comm_embed_dim`` with the fixed
	matrix ``mat``.

	Attributes set here and read by ``forward``:
		prob: ``None`` or a tensor compared against ``0.1`` to decide which
			clusters' rows of ``mat`` are kept.  Presumably one entry per
			cluster, assigned by external code -- TODO confirm with caller.
		mat: non-trainable ``(n_cluster * comm_embed_dim, comm_embed_dim)``
			projection, drawn uniformly from [-1, 1).

	``inference_model``, ``switched`` and ``sov4db1`` are created here but are
	not used by ``forward``.

	NOTE: the similarity weights are held in ``nn.ParameterList`` and ``mat``
	is a registered buffer, so they follow ``.to()``/``.cuda()``, are returned
	by ``parameters()`` and are stored in ``state_dict()``.  Checkpoints
	written while they were plain Python attributes do not contain these keys.
	"""

	def __init__(self, input_shape, args):
		"""Build the encoder.

		Args:
			input_shape: size of the last dimension of the tensors given to
				``forward``.
			args: configuration object; the attributes ``n_agents``,
				``comm_embed_dim``, ``rnn_hidden_dim`` and ``n_actions`` are read.
		"""
		super(IBFCommSP, self).__init__()
		self.args = args
		self.n_agents = args.n_agents
		self.n_cluster = 3
		self.comm_embed_dim = args.comm_embed_dim

		embed_dim = args.comm_embed_dim
		msg_dim = embed_dim * self.n_agents

		#sov4: a single shared encoder (the pcav1 variant used one encoder per cluster)
		self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
		self.fc2 = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim)
		self.fc3 = nn.Linear(args.rnn_hidden_dim, msg_dim)

		# The sov7 variant fed input_shape + 2 * msg_dim features into the first layer.
		self.inference_model = nn.Sequential(
			nn.Linear(input_shape + msg_dim, 4 * msg_dim),
			nn.ReLU(True),
			nn.Linear(4 * msg_dim, 4 * msg_dim),
			nn.ReLU(True),
			nn.Linear(4 * msg_dim, args.n_actions)
		)

		# simi: per-cluster two-layer transforms.  Creation order (w1, b1, w2, b2,
		# then mat) is kept so a seeded RNG yields the same initial values.
		self.simi_w1 = self._uniform_parameters(self.n_cluster, embed_dim, embed_dim)
		self.simi_b1 = self._uniform_parameters(self.n_cluster, embed_dim)
		self.simi_w2 = self._uniform_parameters(self.n_cluster, embed_dim, embed_dim)
		self.simi_b2 = self._uniform_parameters(self.n_cluster, embed_dim)

		#sov2
		self.switched = False

		self.sov4db1 = None

		#pcav4
		self.prob = None

		#simipca: fixed random projection from the concatenated cluster outputs
		# back to comm_embed_dim.  A buffer rather than a Parameter: it is not
		# trained, but it must move with the module and be saved with it.
		self.register_buffer("mat", th.rand((self.n_cluster * embed_dim, embed_dim)) * 2 - 1)

	@staticmethod
	def _uniform_parameters(count, *shape):
		"""Return ``count`` registered parameters of ``shape`` drawn from U[-1, 1)."""
		return nn.ParameterList(
			[nn.Parameter(th.rand(*shape) * 2 - 1) for _ in range(count)]
		)

	def switch_mode(self):
		"""Disabled hook; always raises.

		Raises:
			NotImplementedError: always.  It is a ``RuntimeError`` subclass, so
				handlers written for the previous bare ``raise`` (which produced
				a ``RuntimeError`` when no exception was active) still catch it.
		"""
		raise NotImplementedError("switch_mode is disabled for IBFCommSP")

	def forward(self, inputs):
		"""Encode ``inputs`` into per-agent message means.

		Args:
			inputs: tensor whose last dimension is ``input_shape``.

		Returns:
			A tuple ``(mu, sigma, mu_trans)``:
			mu: ``(N, n_agents * comm_embed_dim)`` mixed embedding, where ``N`` is
				the product of the leading dimensions of ``inputs``.
			sigma: zeros with the same shape, dtype and device as ``mu``.
			mu_trans: list of ``n_cluster`` tensors shaped like ``mu``, one per
				similarity transform, before mixing.
		"""
		embed_dim = self.comm_embed_dim
		msg_dim = self.n_agents * embed_dim

		#sov4
		hidden = F.relu(self.fc1(inputs))
		hidden = F.relu(self.fc2(hidden))
		per_agent = self.fc3(hidden).view(-1, embed_dim)

		# Every cluster applies its own two-layer transform to each agent embedding.
		mu_trans = []
		for w1, b1, w2, b2 in zip(self.simi_w1, self.simi_b1, self.simi_w2, self.simi_b2):
			transformed = F.relu(per_agent @ w1 + b1) @ w2 + b2
			mu_trans.append(transformed.view(-1, msg_dim))

		#pcav4: zero the rows of ``mat`` that belong to clusters with prob <= 0.1
		mixing = self.mat
		if self.prob is not None:
			keep = (self.prob > 0.1).to(dtype=mixing.dtype, device=mixing.device)
			# Expand each per-cluster flag to the comm_embed_dim rows it owns.
			row_mask = keep.unsqueeze(1).repeat(1, embed_dim).view(-1).unsqueeze(1)
			mixing = mixing * row_mask

		# Concatenate cluster outputs on the feature axis, then project back.
		stacked = th.cat(
			[t.view(-1, self.n_agents, embed_dim) for t in mu_trans], dim=-1
		)
		mu = stacked.matmul(mixing).view(-1, msg_dim)
		sigma = th.zeros_like(mu)

		return mu, sigma, mu_trans
