from helper import *
from message_passing import MessagePassing

class RHNNConv(MessagePassing):
	"""Relation-aware message-passing layer.

	Each forward pass aggregates messages over the given edges ('in' mode,
	weight ``w_in``) and over one self-loop edge per entity ('loop' mode,
	weight ``w_loop``), averages the two dropout-ed results, adds an optional
	bias, applies batch norm and finally ``act``. Relation embeddings are
	projected with ``w_rel``.

	Args:
		in_channels: input embedding size.
		out_channels: output embedding size.
		num_rels: number of relations (stored and shown in ``repr``).
		act: activation applied to the entity output (identity by default).
		params: configuration object; the layer reads ``init_dim``, ``dropout``,
			``bias``, ``opn``, ``norm``, ``entity_conv`` and several index
			tables (``id2entity_instance``, ``id2entity_instance_loop``, ...)
			from it.
	"""

	def __init__(self, in_channels, out_channels, num_rels, act=lambda x: x, params=None):
		# super(self.__class__, self) would recurse forever in any subclass.
		super().__init__()

		self.p = params
		self.in_channels = in_channels
		self.out_channels = out_channels
		self.num_rels = num_rels
		self.act = act
		# Device type string ('cuda' / 'cpu'), recorded on the first forward call.
		self.device = None

		# Keep this creation order: it fixes parameter/state_dict order and the
		# order in which initialisation consumes the RNG.
		self.w_loop = get_param((in_channels, out_channels))
		self.w_in = get_param((in_channels, out_channels))
		# NOTE(review): forward only uses modes 'in' and 'loop', so w_out is unused in this file.
		self.w_out = get_param((in_channels, out_channels))
		self.w_rel = get_param((in_channels, out_channels))
		self.loop_rel = get_param((1, in_channels))

		# Six per-position embeddings for the self-loop relation, viewed as (1, 6, init_dim).
		# NOTE(review): presumably get_param returns an nn.Parameter; the .view() result is a
		# plain tensor, so it is not registered with the module (absent from parameters() and
		# state_dict(), not moved by module.to()). Registering it would change state_dict
		# keys — confirm whether it is meant to be learnable.
		self.multi_loop_rel = get_param((6, in_channels)).view(1, 6, self.p.init_dim)

		self.drop = torch.nn.Dropout(self.p.dropout)
		self.bn = torch.nn.BatchNorm1d(out_channels)

		if self.p.bias:
			self.register_parameter('bias', Parameter(torch.zeros(out_channels)))

		# Projects 5 concatenated member embeddings of size init_dim back to init_dim.
		self.in_conv = torch.nn.Linear(self.p.init_dim * 5, self.p.init_dim)

	def forward(self, x, edge_index, edge_order, edge_type, rel_embed, multi_rel_embed=None):
		"""Run one round of message passing.

		Args:
			x: entity embeddings; ``x.size(0)`` is the number of entities.
			edge_index: edge endpoints (replaced by ``self.p.norepeat_edge_index``
				when ``self.p.opn == 'norela'``).
			edge_order: per-edge slot indices into the multi-relation embedding.
			edge_type: relation id of each edge.
			rel_embed: relation embeddings; a learned self-loop row is appended.
			multi_rel_embed: optional per-slot relation embeddings.

		Returns:
			``(act(out), rel_embed @ w_rel, multi_rel_embed @ w_rel)`` when
			``multi_rel_embed`` is given, else ``(act(out), rel_embed @ w_rel, 1)``.
			The appended self-loop row is dropped from the relation outputs.

		Raises:
			NotImplementedError: if ``self.p.norm`` is not 0, 1, 2 or 3.
		"""
		# Everything created here follows the input's device instead of a fixed 'cuda'.
		device = x.device
		if self.device is None:
			self.device = device.type

		id2entity_instance = self.p.id2entity_instance
		id2entity_instance_loop = self.p.id2entity_instance_loop

		if self.p.opn == 'norela':
			edge_index, edge_type = self.p.norepeat_edge_index, self.p.norepeat_edge_type
			id2entity_instance = self.p.id2norepeat_entity_instance

		# Append an all-zero row so that index -1 in the instance tables selects a zero vector.
		input_add1 = torch.cat((x, x.new_zeros(1, x.size(1))), 0)

		if multi_rel_embed is not None:
			# Keep the member embeddings separate; message() weights them per slot.
			instance_dict_embedding = input_add1[id2entity_instance]
			instance_dict_embedding_loop = input_add1[id2entity_instance_loop]
		elif self.p.entity_conv:
			# Concatenate each instance's member embeddings and project with in_conv.
			members = input_add1[id2entity_instance]
			instance_dict_embedding = self.in_conv(members.reshape(members.size(0), -1))
			members_loop = input_add1[id2entity_instance_loop]
			instance_dict_embedding_loop = self.in_conv(members_loop.reshape(members_loop.size(0), -1))
		else:
			# Mean over each instance's members, not counting -1 (zero-row) entries.
			instance_dict_embedding = torch.sum(input_add1[id2entity_instance], 1) / self._count_valid(id2entity_instance).to(device)
			instance_dict_embedding_loop = torch.sum(input_add1[id2entity_instance_loop], 1) / self._count_valid(id2entity_instance_loop).to(device)

		rel_embed = torch.cat([rel_embed, self.loop_rel], dim=0)
		if multi_rel_embed is not None:
			multi_rel_embed = multi_rel_embed.to(device)
			# Explicit move: multi_loop_rel is not moved by module.to() (see __init__).
			self.multi_loop_rel = self.multi_loop_rel.to(device)
			multi_rel_embed = torch.cat([multi_rel_embed, self.multi_loop_rel], dim=0)
			# Extra all-zero slot appended along dim 1.
			zero_slot = multi_rel_embed.new_zeros(multi_rel_embed.size(0), 1, multi_rel_embed.size(2))
			multi_rel_embed_add1 = torch.cat((multi_rel_embed, zero_slot), 1)
		else:
			multi_rel_embed_add1 = None

		num_ent = x.size(0)

		self.in_index = edge_index
		self.in_type = edge_type

		# One self-loop per entity, typed with the appended loop relation (last row of
		# rel_embed); every position of a loop edge selects slot 0.
		ent_ids = torch.arange(num_ent, device=device)
		self.loop_index = torch.stack([ent_ids, ent_ids])
		self.loop_type = torch.full((num_ent,), rel_embed.size(0) - 1, dtype=torch.long, device=device)
		self.loop_edge_order = torch.zeros((num_ent, 5), dtype=torch.long, device=device)

		if self.p.norm == 0:
			self.in_norm = None
		elif self.p.norm == 1:
			self.in_norm = self.compute_norm_new(self.in_index, num_ent)
		elif self.p.norm == 2:
			self.in_norm = self.compute_norm_complex(self.in_index, self.in_type, num_ent)
		elif self.p.norm == 3:
			self.in_norm = self.compute_norm_simple(self.in_index, self.in_type, num_ent)
		else:
			raise NotImplementedError('unsupported norm: {}'.format(self.p.norm))

		in_res = self.propagate('add', instance_dict_embedding, edge_index, x=x, edge_type=self.in_type, rel_embed=rel_embed, edge_norm=self.in_norm, mode='in', use_norm=self.p.norm, entity_conv=self.p.entity_conv, multi_rel_embed=multi_rel_embed_add1, edge_order=edge_order)
		loop_res = self.propagate('add', instance_dict_embedding_loop, self.loop_index, x=x, edge_type=self.loop_type, rel_embed=rel_embed, edge_norm=None, mode='loop', use_norm=0, entity_conv=self.p.entity_conv, multi_rel_embed=multi_rel_embed_add1, edge_order=self.loop_edge_order)

		out = self.drop(in_res) * (1 / 2) + self.drop(loop_res) * (1 / 2)
		if self.p.bias:
			out = out + self.bias

		out = self.bn(out)

		if multi_rel_embed is not None:
			return self.act(out), torch.matmul(rel_embed, self.w_rel)[:-1], torch.matmul(multi_rel_embed, self.w_rel)[:-1]
		return self.act(out), torch.matmul(rel_embed, self.w_rel)[:-1], 1

	def rel_transform(self, ent_embed, rel_embed):
		"""Compose entity and relation embeddings according to ``self.p.opn``.

		'corr' uses ``ccorr``, 'sub' subtracts, 'mult' multiplies element-wise,
		and 'norela' / 'oldnorela' return ``ent_embed`` unchanged.

		Raises:
			NotImplementedError: for any other ``opn``.
		"""
		if self.p.opn == 'corr':
			trans_embed = ccorr(ent_embed, rel_embed)
		elif self.p.opn == 'sub':
			trans_embed = ent_embed - rel_embed
		elif self.p.opn == 'mult':
			trans_embed = ent_embed * rel_embed
		elif self.p.opn in ('norela', 'oldnorela'):
			trans_embed = ent_embed
		else:
			raise NotImplementedError

		return trans_embed

	def message(self, x_j, edge_index, edge_type, rel_embed, edge_norm, mode, use_norm=0, entity_conv=0, multi_rel_embed=None, edge_order=None):
		"""Compute the message carried by each edge.

		``mode`` selects the projection weight ``w_<mode>``, and ``edge_type``
		selects each edge's relation embedding.

		When ``multi_rel_embed`` is given, ``x_j`` is first multiplied
		position-wise by the relation slots chosen by ``edge_order``. It is then
		reduced, with ``in_conv`` if ``self.p.entity_conv`` and otherwise with a
		mean over the entries that are not -1.

		The result is composed with the relation by ``rel_transform``, projected
		with the weight, and scaled by ``edge_norm`` when one is given.

		NOTE(review): the parameter list is kept byte-for-byte and unannotated, in
		case the MessagePassing base class introspects it by name; ``entity_conv``
		itself is unused (``self.p.entity_conv`` is read instead).
		"""
		weight = getattr(self, 'w_{}'.format(mode))

		edge_type = edge_type.to(rel_embed.device)
		rel_emb = torch.index_select(rel_embed, 0, edge_type)

		if multi_rel_embed is not None:
			multi_rel_emb = torch.index_select(multi_rel_embed, 0, edge_type.to(multi_rel_embed.device))
			edge_order = edge_order.to(multi_rel_emb.device)
			# For every position of each edge, pick the relation slot named by edge_order.
			gather_index = edge_order.unsqueeze(2).repeat(1, 1, multi_rel_emb.size(2))
			multi_rel_emb = torch.gather(multi_rel_emb, 1, gather_index).squeeze(1)

			x_j = x_j * multi_rel_emb

			if self.p.entity_conv:
				x_j = self.in_conv(x_j.reshape(x_j.size(0), -1))
			else:
				# NOTE(review): member counts always come from self.p.id2entity_instance, even in
				# 'loop' mode (whose embeddings forward builds from id2entity_instance_loop) and in
				# 'norela' mode (id2norepeat_entity_instance) — confirm this is intended.
				valid = self._count_valid(self.p.id2entity_instance[edge_index[1]])
				x_j = torch.sum(x_j, 1) / valid.to(x_j.device)

		xj_rel = self.rel_transform(x_j, rel_emb)
		out = torch.mm(xj_rel, weight)

		if edge_norm is None:
			return out
		edge_norm = edge_norm.to(out.device)
		if use_norm == 1:
			# compute_norm_new returns a 1-D tensor; reshape it to broadcast over features.
			edge_norm = edge_norm.view(-1, 1)
		return out * edge_norm

	def update(self, aggr_out):
		"""Return the aggregated messages unchanged."""
		return aggr_out

	@staticmethod
	def _count_valid(index):
		"""Count, per row of ``index``, the entries that are not -1, as an (N, 1) tensor.

		-1 entries select the zero row appended in ``forward`` and must not enter
		the mean. The count is clamped to at least 1 so that a row of only -1
		yields a zero vector instead of 0/0 = NaN.
		"""
		return torch.count_nonzero(index + 1, dim=1).reshape(-1, 1).clamp(min=1)

	def compute_norm(self, edge_index, num_ent):
		"""Symmetric edge weights ``deg[row]^-1/2 * deg[col]^-1/2``, with degree counted on ``row``.

		Entities with zero degree get weight 0 instead of inf.
		"""
		row, col = edge_index
		edge_weight = torch.ones_like(row).float()
		deg = scatter_add(edge_weight, row, dim=0, dim_size=num_ent)
		deg_inv = deg.pow(-0.5)

		deg_inv[deg_inv == float('inf')] = 0
		norm = deg_inv[row] * edge_weight * deg_inv[col]

		return norm

	def compute_norm_new(self, edge_index, num_ent):
		"""Edge weights ``1 / deg[row]`` as a 1-D tensor, with degree counted on ``row``.

		Computed as ``deg[row]^-1/2`` squared; zero-degree entities give 0.
		"""
		row, col = edge_index

		edge_weight = torch.ones_like(row).float()
		deg = scatter_add(edge_weight, row, dim=0, dim_size=num_ent)
		deg_inv = deg.pow(-0.5)

		deg_inv[deg_inv == float('inf')] = 0
		norm = deg_inv[row] * edge_weight * deg_inv[row]

		return norm

	def compute_norm_complex(self, edge_index, edge_type, num_ent):
		"""Edge weights from the entity/relation incidence matrix ``self.p.ent_edge_matrix``.

		With ``N1 = D_v^-1/2 * H`` (entity degrees on ``row``) and ``D_e`` the
		per-relation edge counts, returns, for each edge, the row sum of
		``N1 @ diag(D_e^-1/2) @ N1.T`` at its ``row`` entity, shaped (E, 1).

		Assumes ``H`` is a dense (num_ent, self.p.num_rel) tensor — TODO confirm.
		NOTE(review): uses ``self.p.num_rel``, not ``self.num_rels``.
		"""
		row, col = edge_index
		edge_weight = torch.ones_like(row).float()
		deg = scatter_add(edge_weight, row, dim=0, dim_size=num_ent)

		edge_deg = scatter_add(torch.ones_like(edge_type).float(), edge_type, dim=0, dim_size=self.p.num_rel)
		H = self.p.ent_edge_matrix
		deg_inv = deg.unsqueeze(1).pow(-0.5)
		deg_inv[deg_inv == float('inf')] = 0

		edge_deg_inv = edge_deg.pow(-0.5)
		edge_deg_inv[edge_deg_inv == float('inf')] = 0
		norm1 = deg_inv * H
		# Row sums of N1 @ diag(d) @ N1.T equal N1 @ (d * N1.sum(0)); this avoids
		# materialising the num_ent x num_ent product.
		norm = torch.mm(norm1, (edge_deg_inv * norm1.sum(0)).unsqueeze(1))
		return norm[row]

	def compute_norm_simple(self, edge_index, edge_type, num_ent):
		"""Edge weights ``(D_v^-1/2 * H) @ D_e^-1/2`` gathered at each edge's ``row``, shaped (E, 1).

		Assumes ``H`` (``self.p.ent_edge_matrix``) is (num_ent, self.p.num_rel) — TODO confirm.
		"""
		row, col = edge_index
		edge_weight = torch.ones_like(row).float()
		deg = scatter_add(edge_weight, row, dim=0, dim_size=num_ent)
		edge_deg = scatter_add(torch.ones_like(edge_type).float(), edge_type, dim=0, dim_size=self.p.num_rel)
		H = self.p.ent_edge_matrix
		deg_inv = deg.unsqueeze(1).pow(-0.5)
		deg_inv[deg_inv == float('inf')] = 0

		edge_deg_inv = edge_deg.unsqueeze(1).pow(-0.5)
		edge_deg_inv[edge_deg_inv == float('inf')] = 0
		norm1 = deg_inv * H
		norm = torch.mm(norm1, edge_deg_inv)
		return norm[row]

	def __repr__(self):
		return '{}({}, {}, num_rels={})'.format(
			self.__class__.__name__, self.in_channels, self.out_channels, self.num_rels)

