import torch as th
import torch.nn as nn
import torch.nn.functional as F

class att_cls_model(nn.Module):
	"""Score ``n_cluster`` clusters with a dot-product attention.

	One query is projected from the communication vector and one key per
	cluster is projected from the observation input; the score of cluster ``c``
	is the dot product ``<q, k_c>`` in a ``hidden_dim``-dimensional space.
	"""
	def __init__(self, input_shape, comm_dim, n_cluster, hidden_dim=8):
		super(att_cls_model, self).__init__()
		self.input_shape = input_shape
		self.comm_dim = comm_dim
		self.n_cluster = n_cluster
		self.hidden_dim = hidden_dim

		self.q_layer = nn.Linear(comm_dim, self.hidden_dim)
		# One independent key projection per cluster.
		self.k_layer = nn.ModuleList([nn.Linear(input_shape, self.hidden_dim) for _ in range(n_cluster)])

	def forward(self, input, comm):
		"""Return per-cluster attention scores.

		Args:
			input: tensor whose last dim is ``input_shape``.
			comm: tensor whose last dim is ``comm_dim``; its leading dims must
				match (or broadcast with) those of ``input``.

		Returns:
			Tensor of shape ``(*leading_dims, n_cluster)``.
		"""
		q = self.q_layer(comm).unsqueeze(-1)  # (..., hidden, 1)
		ks = th.stack([layer(input) for layer in self.k_layer], dim=-1)  # (..., hidden, n_cluster)
		# Reduce over the hidden axis only (dim=-2) so any leading batch/time
		# dims are preserved; for 2-D inputs this equals the old sum(dim=1).
		return (q * ks).sum(dim=-2)

class IBFCommM(nn.Module):
	"""Communication-message encoder with per-cluster similarity transforms.

	``forward`` maps agent inputs to a message ``mu`` of width
	``comm_embed_dim * n_agents`` (plus an all-zero ``sigma`` of the same shape)
	and also returns the ``n_cluster`` per-cluster transforms of the raw
	``fc3`` output.
	"""
	def __init__(self, input_shape, args):
		super(IBFCommM, self).__init__()
		self.args = args
		self.n_agents = args.n_agents
		self.n_cluster = 3
		self.comm_embed_dim = args.comm_embed_dim
		d = args.comm_embed_dim
		msg_dim = d * self.n_agents

		self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
		self.fc2 = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim)
		self.fc3 = nn.Linear(args.rnn_hidden_dim, msg_dim)

		self.inference_model = nn.Sequential(
			nn.Linear(input_shape + msg_dim, 4 * msg_dim),
			nn.ReLU(True),
			nn.Linear(4 * msg_dim, 4 * msg_dim),
			nn.ReLU(True),
			nn.Linear(4 * msg_dim, args.n_actions)
		)

		# Per-cluster two-layer similarity transforms (W1, b1, W2, b2), each
		# initialised uniformly in [-1, 1). They are registered under their
		# historical names (simi_w11 .. simi_w13, simi_b11 .., simi_w21 ..,
		# simi_b21 ..) so existing checkpoints still load, and are created
		# kind-by-kind in the original order so a fixed seed yields the same
		# initial values.
		kinds = (
			("simi_w1", (d, d)),
			("simi_b1", (d,)),
			("simi_w2", (d, d)),
			("simi_b2", (d,)),
		)
		for prefix, shape in kinds:
			params = []
			for i in range(self.n_cluster):
				p = nn.Parameter(th.rand(shape) * 2 - 1)
				setattr(self, "%s%d" % (prefix, i + 1), p)
				params.append(p)
			# Plain list aliasing the registered parameters, indexed by cluster in forward.
			setattr(self, prefix, params)

		self.cls_model = att_cls_model(input_shape, msg_dim, self.n_cluster)

		# pcav4: not read inside this class; presumably assigned/used by callers.
		self.prob = None

		# simipca: fixed random projection from the concatenated cluster
		# transforms (n_cluster * d) back to d. It is a plain tensor, not a
		# Parameter or buffer, so it is neither trained nor saved in state_dict.
		# NOTE(review): a reloaded model therefore re-draws mat unless the RNG
		# seed matches; consider register_buffer (changes state_dict keys).
		self.mat = th.rand((self.n_cluster * d, d)) * 2 - 1

		# mlev18: per-cluster on/off mask applied to the rows of ``mat``.
		self.actv = th.ones((self.n_cluster))

		# Keep the previous GPU placement when CUDA exists, but no longer crash
		# on CPU-only hosts; forward() moves these to the input's device anyway.
		if th.cuda.is_available():
			self.mat = self.mat.cuda()
			self.actv = self.actv.cuda()

	def forward(self, inputs):
		"""Encode ``inputs`` into a communication message.

		Args:
			inputs: tensor whose last dim is ``input_shape``; assumed to be
				``(batch, input_shape)`` — TODO confirm against callers.

		Returns:
			mu: ``(batch, n_agents * comm_embed_dim)`` message obtained by
				projecting the concatenated cluster transforms through the
				``actv``-masked matrix ``mat``.
			sigma: zeros with the same shape, dtype and device as ``mu``.
			mu_trans: list of ``n_cluster`` tensors, each
				``(batch, n_agents * comm_embed_dim)``, the per-cluster
				transforms of the raw ``fc3`` output.
		"""
		d = self.comm_embed_dim
		x = F.relu(self.fc1(inputs))
		x = F.relu(self.fc2(x))
		raw = self.fc3(x)  # one d-dim slot per agent

		# Apply every cluster's 2-layer transform to each agent slot independently.
		per_slot = raw.view(-1, d)
		mu_trans = []
		for w1, b1, w2, b2 in zip(self.simi_w1, self.simi_b1, self.simi_w2, self.simi_b2):
			hidden = F.relu(per_slot @ w1 + b1)
			mu_trans.append((hidden @ w2 + b2).view(-1, self.n_agents * d))

		# mlev18: expand the per-cluster mask to one entry per row of ``mat``
		# -> (n_cluster * d, 1), zeroing rows of inactive clusters.
		device = raw.device
		mat = self.mat.to(device)
		mat_mask = self.actv.to(device).float().repeat_interleave(d).unsqueeze(1)

		# (batch, n_agents, n_cluster * d) @ (n_cluster * d, d) -> (batch, n_agents, d)
		stacked = th.cat([t.view(-1, self.n_agents, d) for t in mu_trans], dim=-1)
		mu = stacked.matmul(mat * mat_mask).view(-1, self.n_agents * d)
		sigma = th.zeros_like(mu)

		return mu, sigma, mu_trans
