from helper import *
from RHNN_conv import RHNNConv
from RHNN_basis import RHNNBasis
import time

class BaseModel(torch.nn.Module):
	"""
	Base class for all other models, including the implementation of the loss function.

	The training objective is selected by ``params.margin``:

	* falsy  -> binary cross-entropy between ``pred`` and ``true_label``
	  (``pred`` must already be probabilities, e.g. a sigmoid output);
	* truthy -> ``MarginRankingLoss`` over sampled (positive, negative) score
	  pairs, drawing ``params.neg`` negatives per positive.
	"""
	def __init__(self, params):
		super(BaseModel, self).__init__()

		self.p		= params
		self.act	= torch.tanh
		self.bceloss	= torch.nn.BCELoss()
		self.margin	= self.p.margin
		self.neg	= self.p.neg

		if self.margin:
			self.loss_func = torch.nn.MarginRankingLoss(margin=self.margin)

	def loss(self, pred, true_label):
		"""
		Compute the training loss for a batch of scores.

		Args:
			pred: score tensor, one row per query (presumably (batch, num_ent) —
				TODO confirm against callers).
			true_label: target tensor aligned with ``pred``. In margin mode, entries
				> 0.1 mark positives and entries < 0.1 mark negatives.

		Returns:
			A scalar loss tensor.

		Raises:
			ValueError: margin mode is enabled but ``params.neg == -1``. This
				configuration previously fell through and returned ``None``.
		"""
		if not self.margin:
			return self.bceloss(pred, true_label)

		if self.neg == -1:
			raise ValueError("margin loss requires params.neg != -1 (number of negatives sampled per positive)")

		neg_parts = []
		pos_parts = []
		for row_pred, row_label in zip(pred, true_label):
			pos = row_pred[row_label > 0.1]
			neg = row_pred[row_label < 0.1]
			# pair every positive with self.neg negatives drawn uniformly with replacement
			pos_rep = pos.repeat_interleave(self.neg)
			neg_index = torch.randint_like(pos_rep, len(neg)).long()
			neg_parts.append(neg[neg_index])
			pos_parts.append(pos_rep)
		neg_all = torch.cat(neg_parts, 0)
		pos_all = torch.cat(pos_parts, 0)
		# y = -1: MarginRankingLoss should rank the first input (negatives) below the
		# second (positives). The target is built on the scores' device, not forced to CUDA.
		y = -torch.ones(len(pos_all), device=pos_all.device)
		return self.loss_func(neg_all, pos_all, y)

class RHKHBase(BaseModel):
	"""
	The base class of our RHKH model, from which all other RHKH model implementations inherit.

	It owns:
	* the initial entity and relation embeddings;
	* the per-relation role tensor ``multi_rel``;
	* the one or two RHNN graph-convolution layers;
	* a per-entity ``bias`` and the ``shift_network`` used by ``shift_onehot``.

	It also provides the shared encoder passes ``forward_base`` and ``forward_base_multi``.
	"""
	def __init__(self, edge_index, edge_type, num_rel, params=None):
		super(RHKHBase, self).__init__(params)

		# ``edge_index`` is unpacked into an index tensor and an edge-order tensor;
		# the two rows of the index tensor are then swapped.
		self.edge_index, self.edge_order = edge_index
		self.edge_index = self.edge_index[[1, 0]]

		self.edge_type = edge_type
		# With a single GCN layer the hidden size must equal the output size.
		# Note: this writes back into the shared params object.
		self.p.gcn_dim = self.p.embed_dim if self.p.gcn_layer == 1 else self.p.gcn_dim
		self.init_embed = get_param((self.p.num_ent, self.p.init_dim))
		self.device = 'cuda'

		# Both score-function branches used to build the very same table. TransE's
		# negated copy is appended at run time in forward_base.
		if self.p.num_bases > 0:
			self.init_rel = get_param((self.p.num_bases, self.p.init_dim))
		else:
			self.init_rel = get_param((num_rel, self.p.init_dim))

		# Role embeddings: 6 slots per relation, init_dim wide with multi_r, else 2 * init_dim.
		# NOTE(review): only a ``.view`` of the created tensor is stored. If ``get_param``
		# returns an ``nn.Parameter`` (presumably — helper not shown), that view is a plain
		# tensor. The parameter is then NOT registered on this module: it is missing from
		# ``parameters()``/``state_dict()`` and is not moved by the model's ``.to()``/``.cuda()``.
		# Registering it would change checkpoints and training; confirm intent first.
		multi_rel_dim = self.p.init_dim if self.p.multi_r else self.p.init_dim * 2
		multi_rel = get_param((num_rel * 6, multi_rel_dim))
		self.multi_rel = multi_rel.view(num_rel, 6, multi_rel_dim)

		if self.p.num_bases > 0:
			self.conv1 = RHNNBasis(self.p.init_dim, self.p.gcn_dim, num_rel, self.p.num_bases, act=self.act, params=self.p)
		else:
			self.conv1 = RHNNConv(self.p.init_dim, self.p.gcn_dim, num_rel, act=self.act, params=self.p)
		self.conv2 = RHNNConv(self.p.gcn_dim, self.p.embed_dim, num_rel, act=self.act, params=self.p) if self.p.gcn_layer == 2 else None

		self.register_parameter('bias', Parameter(torch.zeros(self.p.num_ent)))
		# number of entity slots handled by the shift transforms
		self.arity = 5
		self.shift_network = torch.nn.Linear(self.p.embed_dim + self.arity, self.p.embed_dim)

	def forward_base(self, sub, rel, drop1, drop2):
		"""
		Run the GCN encoder and look up the query relations.

		``sub`` is accepted for interface compatibility but is not used here.
		``drop1`` is applied after conv1, and ``drop2`` after conv2 (two-layer models only).

		Returns:
			(rel_emb, x): rows of the final relation table selected by ``rel``, and the
			encoded embeddings of all entities.
		"""
		r = self.init_rel if self.p.score_func != 'transe' else torch.cat([self.init_rel, -self.init_rel], dim=0)
		x, r, _ = self.conv1(self.init_embed, self.edge_index, self.edge_order, self.edge_type, rel_embed=r)
		x = drop1(x)
		if self.p.gcn_layer == 2:
			x, r, _ = self.conv2(x, self.edge_index, self.edge_order, self.edge_type, rel_embed=r)
			x = drop2(x)

		rel_emb = torch.index_select(r, 0, rel)
		return rel_emb, x

	def forward_base_multi(self, sub, rel, drop1, drop2):
		"""
		Like ``forward_base``, but also propagates the role tensor through the convolutions.

		Returns:
			(rel_emb, x, multi_r_emb): as in ``forward_base``, plus the rows of the
			convolved role tensor selected by ``rel``.
		"""
		r = self.init_rel
		multi_r = self.multi_rel.cuda()
		x, r, multi_r = self.conv1(self.init_embed, self.edge_index, self.edge_order, self.edge_type, rel_embed=r, multi_rel_embed=multi_r)
		x = drop1(x)
		if self.p.gcn_layer == 2:
			x, r, multi_r = self.conv2(x, self.edge_index, self.edge_order, self.edge_type, rel_embed=r, multi_rel_embed=multi_r)
			x = drop2(x)

		rel_emb = torch.index_select(r, 0, rel)
		multi_r_emb = torch.index_select(multi_r, 0, rel)
		return rel_emb, x, multi_r_emb

	def shift_onehot(self, entity_embed):
		"""
		Append a one-hot slot code to every entity slot and project back to ``embed_dim``.

		``entity_embed`` must be 3-D, (batch, slots, embed_dim), with ``slots == self.arity``
		for the concatenation. The code tensor is created on the input's device rather
		than forced onto CUDA.
		"""
		batch = entity_embed.size(0)
		slot_code = torch.eye(self.arity, device=entity_embed.device).unsqueeze(0).repeat(batch, 1, 1)
		return self.shift_network(torch.cat((entity_embed, slot_code), 2))

	def shift(self, v, dim, sh):
		"""Return slot ``dim`` of ``v`` (kept as a size-1 axis), rotated left by ``sh`` features."""
		return torch.cat((v[:, dim:dim + 1, sh:], v[:, dim:dim + 1, :sh]), 2)

	def shift_rotate(self, entity_embed):
		"""
		Rotate each slot i of the first five slots left by ``int((i + 1) * emb_dim / 6)`` features.

		Slots beyond the fifth are dropped, so the result is (batch, 5, emb_dim).
		"""
		emb_dim = entity_embed.size()[2]
		return torch.cat([self.shift(entity_embed, i, int((i + 1) * emb_dim / 6)) for i in range(5)], 1)

class RHKH_TransE(RHKHBase):
	"""
	Implementation of RHKH model with N-TransE as score function.

	The query embedding is built as follows:
	* take the entities listed for the instance in ``params.id2entity_instance``;
	* optionally multiply them by role embeddings and/or apply a shift transform;
	* sum the result.

	The object is predicted as ``sub_emb + rel_emb``, and every entity is scored with
	``sigmoid(gamma - ||obj_emb - entity||_1)``.
	"""
	def __init__(self, edge_index, edge_type, params=None):
		# Zero-argument super(): ``super(self.__class__, self)`` recurses forever once this
		# class is subclassed, because self.__class__ is then the subclass.
		super().__init__(edge_index, edge_type, params.num_rel, params)
		self.drop = torch.nn.Dropout(self.p.id_drop)

	def forward(self, sub, rel):
		"""
		Score every entity as the missing object of each query.

		Args:
			sub: query ids. With ``params.position`` each id is looked up in
				``params.id2sub_pos``: entry[0] replaces the id, and entry[1] supplies the
				role-slot indices.
			rel: relation id tensor, one per query.

		Returns:
			Tensor of sigmoid scores with one column per entity.
		"""
		if self.p.position:
			sub_list = []
			pos_list = []
			for s in sub:
				entry = self.p.id2sub_pos[s]
				sub_list.append(entry[0])
				pos_list.append(list(entry[1]))
			sub = torch.LongTensor(sub_list)
			# the trailing column is dropped; only the leading columns index role slots here
			pos = torch.from_numpy(np.array(pos_list)).to(self.device)[:, :-1]

		if self.p.multi_r:
			rel_emb, all_ent, multi_rel_embed = self.forward_base_multi(sub, rel, self.drop, self.drop)
		else:
			multi_rel_embed = self.multi_rel.cuda()
			rel_emb, all_ent = self.forward_base(sub, rel, self.drop, self.drop)

		# Append one all-zero row so that index -1 selects a zero embedding
		# (presumably used for padding in id2entity_instance — TODO confirm).
		entity_embed_add1 = torch.cat((all_ent, all_ent.new_zeros(1, all_ent.size(1))), 0)
		instance_ids = self.p.id2entity_instance[sub]

		if self.p.position:
			if not self.p.multi_r:
				multi_rel_embed = torch.index_select(multi_rel_embed, 0, rel)
			# one extra all-zero role slot, mirroring the entity padding row
			role_pad = multi_rel_embed.new_zeros(multi_rel_embed.size(0), 1, multi_rel_embed.size(2))
			multi_rel_embed_add1 = torch.cat((multi_rel_embed, role_pad), 1)
			# Feature width is read from the tensor; the former hard-coded 200 only
			# worked for 200-wide role embeddings.
			gather_index = pos.unsqueeze(2).repeat(1, 1, multi_rel_embed_add1.size(2))
			role_embed = torch.gather(multi_rel_embed_add1, 1, gather_index).squeeze(1)
			entity_embed_all = entity_embed_add1[instance_ids] * role_embed
		else:
			entity_embed_all = entity_embed_add1[instance_ids]

		if self.p.shift == 1:
			entity_embed_all = self.shift_onehot(entity_embed_all)
		elif self.p.shift == 2:
			entity_embed_all = self.shift_rotate(entity_embed_all)

		sub_emb = torch.sum(entity_embed_all, 1)
		obj_emb = sub_emb + rel_emb

		x = self.p.gamma - torch.norm(obj_emb.unsqueeze(1) - all_ent, p=1, dim=2)
		return torch.sigmoid(x)

class RHKH_DistMult(RHKHBase):
	"""
	Implementation of RHKH model with N-DistMult as score function.

	The query embedding is the (optionally role-weighted and/or shifted) sum of the
	instance's entity embeddings. The object is predicted as ``sub_emb * rel_emb``,
	and every entity is scored with ``sigmoid(obj_emb . entity + bias)``.
	"""
	def __init__(self, edge_index, edge_type, params=None):
		# Zero-argument super(): ``super(self.__class__, self)`` recurses forever once this
		# class is subclassed, because self.__class__ is then the subclass.
		super().__init__(edge_index, edge_type, params.num_rel, params)
		self.drop = torch.nn.Dropout(self.p.id_drop)

	def forward(self, sub, rel):
		"""
		Score every entity as the missing object of each query.

		Args:
			sub: query ids. With ``params.position`` each id is looked up in
				``params.id2sub_pos``: entry[0] replaces the id, and entry[1] supplies the
				role-slot indices.
			rel: relation id tensor, one per query.

		Returns:
			Tensor of sigmoid scores with one column per entity.
		"""
		if self.p.position:
			sub_list = []
			pos_list = []
			for s in sub:
				entry = self.p.id2sub_pos[s]
				sub_list.append(entry[0])
				pos_list.append(list(entry[1]))
			sub = torch.LongTensor(sub_list)
			# the trailing column is dropped; only the leading columns index role slots here
			pos = torch.from_numpy(np.array(pos_list)).to(self.device)[:, :-1]

		if self.p.multi_r:
			rel_emb, all_ent, multi_rel_embed = self.forward_base_multi(sub, rel, self.drop, self.drop)
		else:
			multi_rel_embed = self.multi_rel.cuda()
			rel_emb, all_ent = self.forward_base(sub, rel, self.drop, self.drop)

		# Append one all-zero row so that index -1 selects a zero embedding
		# (presumably used for padding in id2entity_instance — TODO confirm).
		entity_embed_add1 = torch.cat((all_ent, all_ent.new_zeros(1, all_ent.size(1))), 0)
		instance_ids = self.p.id2entity_instance[sub]

		if self.p.position:
			if not self.p.multi_r:
				multi_rel_embed = torch.index_select(multi_rel_embed, 0, rel)
			# one extra all-zero role slot, mirroring the entity padding row
			role_pad = multi_rel_embed.new_zeros(multi_rel_embed.size(0), 1, multi_rel_embed.size(2))
			multi_rel_embed_add1 = torch.cat((multi_rel_embed, role_pad), 1)
			# Feature width is read from the tensor; the former hard-coded 200 only
			# worked for 200-wide role embeddings.
			gather_index = pos.unsqueeze(2).repeat(1, 1, multi_rel_embed_add1.size(2))
			role_embed = torch.gather(multi_rel_embed_add1, 1, gather_index).squeeze(1)
			entity_embed_all = entity_embed_add1[instance_ids] * role_embed
		else:
			entity_embed_all = entity_embed_add1[instance_ids]

		if self.p.shift == 1:
			entity_embed_all = self.shift_onehot(entity_embed_all)
		elif self.p.shift == 2:
			entity_embed_all = self.shift_rotate(entity_embed_all)

		sub_emb = torch.sum(entity_embed_all, 1)
		obj_emb = sub_emb * rel_emb

		x = torch.mm(obj_emb, all_ent.transpose(1, 0))
		x += self.bias.expand_as(x)
		return torch.sigmoid(x)

def dist1(entity_emb,e,l,u):
	"""
	Rescale ``entity_emb`` depending on whether ``e`` lies inside the interval [l, u].

	With width ``w = u - l + 1`` and offset ``k = 0.5 * (w - 1) * (w - 1 / w)``:
	* inside (``l <= e <= u``): ``entity_emb / w``;
	* outside: ``entity_emb * w - k``.

	All arguments are combined element-wise with broadcasting. No absolute value
	is taken, so ``u < l`` gives a non-positive width (``w == 0`` when ``u - l == -1``).
	"""
	width = u - l + 1
	offset = 0.5 * (width - 1) * (width - 1 / width)
	inside = (e >= l) & (e <= u)
	return torch.where(inside, entity_emb / width, entity_emb * width - offset)

def dist2(entity_emb,l):
	"""Divide ``entity_emb`` by ``l`` wherever ``entity_emb < l``; keep the other entries unchanged."""
	below = entity_emb < l
	return torch.where(below, entity_emb / l, entity_emb)
	
class RHKH_ConvE(RHKHBase):
	"""
	Implementation of RHKH model with N-ConvE as score function.

	The instance's summed entity embedding and the relation embedding are stacked into a
	(2*k_w, k_h) image, convolved, and projected back to ``embed_dim``. The result is
	scored against every entity with ``sigmoid(x . entity + bias)``.
	"""
	def __init__(self, edge_index, edge_type, params=None):
		# Zero-argument super(): ``super(self.__class__, self)`` recurses forever once this
		# class is subclassed, because self.__class__ is then the subclass.
		super().__init__(edge_index, edge_type, params.num_rel, params)

		self.bn0		= torch.nn.BatchNorm2d(1)
		self.bn1		= torch.nn.BatchNorm2d(self.p.num_filt)
		self.bn2		= torch.nn.BatchNorm1d(self.p.embed_dim)

		self.hidden_drop	= torch.nn.Dropout(self.p.id_drop)
		self.hidden_drop2	= torch.nn.Dropout(self.p.id_drop2)
		self.feature_drop	= torch.nn.Dropout(self.p.feat_drop)
		self.m_conv1		= torch.nn.Conv2d(1, out_channels=self.p.num_filt, kernel_size=(self.p.ker_sz, self.p.ker_sz), stride=1, padding=0, bias=self.p.bias)

		# output size of an unpadded, stride-1 conv over the (2*k_w, k_h) input, flattened for fc
		flat_sz_h		= int(2 * self.p.k_w) - self.p.ker_sz + 1
		flat_sz_w		= self.p.k_h - self.p.ker_sz + 1
		self.flat_sz		= flat_sz_h * flat_sz_w * self.p.num_filt
		self.fc			= torch.nn.Linear(self.flat_sz, self.p.embed_dim)

	def concat(self, e1_embed, rel_embed):
		"""
		Stack subject and relation embeddings into a (batch, 1, 2*k_w, k_h) image.

		Transposing before the reshape interleaves subject and relation features.
		"""
		e1_embed = e1_embed.view(-1, 1, self.p.embed_dim)
		rel_embed = rel_embed.view(-1, 1, self.p.embed_dim)
		stack_inp = torch.cat([e1_embed, rel_embed], 1)
		return torch.transpose(stack_inp, 2, 1).reshape((-1, 1, 2 * self.p.k_w, self.p.k_h))

	def forward(self, sub, rel):
		"""
		Score every entity as the missing object of each query.

		Args:
			sub: query ids. With ``params.position`` each id is looked up in
				``params.id2sub_pos``: entry[0] replaces the id, and entry[1] supplies the
				role-slot indices.
			rel: relation id tensor, one per query.

		Returns:
			Tensor of sigmoid scores with one column per entity.
		"""
		if self.p.position:
			sub_list = []
			pos_list = []
			for s in sub:
				entry = self.p.id2sub_pos[s]
				sub_list.append(entry[0])
				pos_list.append(list(entry[1]))
			sub = torch.LongTensor(sub_list)
			# the trailing column is dropped; only the leading columns index role slots here
			pos = torch.from_numpy(np.array(pos_list)).to(self.device)[:, :-1]

		if self.p.multi_r:
			rel_emb, all_ent, multi_rel_embed = self.forward_base_multi(sub, rel, self.hidden_drop, self.feature_drop)
		else:
			multi_rel_embed = self.multi_rel.cuda()
			rel_emb, all_ent = self.forward_base(sub, rel, self.hidden_drop, self.feature_drop)

		# Append one all-zero row so that index -1 selects a zero embedding
		# (presumably used for padding in id2entity_instance — TODO confirm).
		entity_embed_add1 = torch.cat((all_ent, all_ent.new_zeros(1, all_ent.size(1))), 0)
		instance_ids = self.p.id2entity_instance[sub]

		if self.p.position:
			if not self.p.multi_r:
				multi_rel_embed = torch.index_select(multi_rel_embed, 0, rel)
			# one extra all-zero role slot, mirroring the entity padding row
			role_pad = multi_rel_embed.new_zeros(multi_rel_embed.size(0), 1, multi_rel_embed.size(2))
			multi_rel_embed_add1 = torch.cat((multi_rel_embed, role_pad), 1)
			# Feature width is read from the tensor; the former hard-coded 200 only
			# worked for 200-wide role embeddings.
			gather_index = pos.unsqueeze(2).repeat(1, 1, multi_rel_embed_add1.size(2))
			role_embed = torch.gather(multi_rel_embed_add1, 1, gather_index).squeeze(1)
			entity_embed_all = entity_embed_add1[instance_ids] * role_embed
		else:
			entity_embed_all = entity_embed_add1[instance_ids]

		if self.p.shift == 1:
			entity_embed_all = self.shift_onehot(entity_embed_all)
		elif self.p.shift == 2:
			entity_embed_all = self.shift_rotate(entity_embed_all)

		sub_emb = torch.sum(entity_embed_all, 1)

		stk_inp = self.concat(sub_emb, rel_emb)
		x = self.bn0(stk_inp)
		x = self.m_conv1(x)
		x = self.bn1(x)
		x = F.relu(x)
		x = self.feature_drop(x)
		x = x.view(-1, self.flat_sz)
		x = self.fc(x)
		x = self.hidden_drop2(x)
		x = self.bn2(x)
		x = F.relu(x)

		x = torch.mm(x, all_ent.transpose(1, 0))
		x += self.bias.expand_as(x)
		return torch.sigmoid(x)


class RHKH_Box(RHKHBase):
	"""
	Implementation of RHKH model with an interval/box style score function.

	Three linear heads are used, tagged ``# L``, ``# U`` and ``# B`` in the original:
	``fc_layer1`` (L), ``fc_layer2`` (U) and ``fc_layer3`` (B, applied to all entities).

	``params.dist_type`` selects the distance:
	* 0 -> plain L1 distance;
	* 1 -> L1 rescaled by ``dist1`` with bounds L/U;
	* 2 -> L1 rescaled by ``dist2`` with lower bound L.

	Two L1 terms are combined: ``x`` for the full instance, and ``x_b`` for the instance
	with one randomly sampled entity's B term replaced by its raw embedding.
	"""
	def __init__(self, edge_index, edge_type, params=None):
		# Zero-argument super(): ``super(self.__class__, self)`` recurses forever once this
		# class is subclassed, because self.__class__ is then the subclass.
		# (self.p and self.margin are already set by BaseModel.)
		super().__init__(edge_index, edge_type, params.num_rel, params)
		self.drop = torch.nn.Dropout(self.p.id_drop)

		# The heads consume the encoder outputs (rel_emb / all_ent), and their outputs are
		# subtracted from all_ent, so both sides must be ``embed_dim`` wide. The former
		# ``gcn_dim`` only coincides with it for one-layer models or equal settings.
		self.in_dim = self.p.embed_dim
		self.out_dim = self.p.embed_dim

		self.fc_layer1 = torch.nn.Linear(self.in_dim, self.out_dim)  # L
		self.fc_layer2 = torch.nn.Linear(self.in_dim, self.out_dim)  # U
		self.fc_layer3 = torch.nn.Linear(self.in_dim, self.out_dim)  # B

	@staticmethod
	def _gather_slots(table, index):
		"""
		Gather along axis 1 of the 3-D ``table`` (batch, slots, dim) with a 2-D ``index``.

		The feature width is read from ``table``; it was previously hard-coded to 200.
		"""
		return torch.gather(table, 1, index.unsqueeze(2).repeat(1, 1, table.size(2)))

	def forward(self, sub, rel):
		"""
		Score every entity as the missing object of each query.

		Args:
			sub: query ids. With ``params.position`` each id is looked up in
				``params.id2sub_pos``: entry[0] replaces the id, and entry[1] supplies the
				role-slot indices.
			rel: relation id tensor, one per query.

		Returns:
			``sigmoid(x_b + x)`` when no margin is configured; otherwise the raw
			``x + x_b`` scores for the margin ranking loss.
		"""
		if self.p.position:
			sub_list = []
			pos_list = []
			for s in sub:
				entry = self.p.id2sub_pos[s]
				sub_list.append(entry[0])
				pos_list.append(list(entry[1]))
			sub = torch.LongTensor(sub_list)
			# unlike the other score functions, all position columns (incl. the last) are used
			pos_sub_obj = torch.from_numpy(np.array(pos_list)).to(self.device)

		if self.p.multi_r:
			rel_emb, all_ent, multi_rel_embed = self.forward_base_multi(sub, rel, self.drop, self.drop)
		else:
			multi_rel_embed = self.multi_rel.cuda()
			rel_emb, all_ent = self.forward_base(sub, rel, self.drop, self.drop)

		r_center = rel_emb

		if self.p.position:
			if not self.p.multi_r:
				multi_rel_embed = torch.index_select(multi_rel_embed, 0, rel)
			# one extra all-zero role slot
			role_pad = multi_rel_embed.new_zeros(multi_rel_embed.size(0), 1, multi_rel_embed.size(2))
			multi_rel_embed_add1 = torch.cat((multi_rel_embed, role_pad), 1)
			multi_rel_embed = self._gather_slots(multi_rel_embed_add1, pos_sub_obj).squeeze(1)
			r_center_multi = multi_rel_embed

		if self.p.dist_type == 0:
			if self.p.r2fc:
				r_center = self.fc_layer1(rel_emb)
				if self.p.position:
					r_center_multi = self.fc_layer1(multi_rel_embed)
		elif self.p.dist_type == 1:
			L = self.fc_layer1(rel_emb)
			U = self.fc_layer2(rel_emb)
			r_center = (L + U) / 2
			if self.p.position:
				L_multi = self.fc_layer1(multi_rel_embed)
				U_multi = self.fc_layer2(multi_rel_embed)
				r_center_multi = (L_multi + U_multi) / 2
		elif self.p.dist_type == 2:
			L = self.fc_layer1(rel_emb)
			if self.p.position:
				L_multi = self.fc_layer1(multi_rel_embed)
			if self.p.r2fc:
				r_center = self.fc_layer2(rel_emb)
				if self.p.position:
					r_center_multi = self.fc_layer2(multi_rel_embed)

		B = self.fc_layer3(all_ent)

		# Append an all-zero row to both tables so that index -1 selects a zero embedding.
		entity_embed_add1 = torch.cat((all_ent, all_ent.new_zeros(1, all_ent.size(1))), 0)
		B_embed_add1 = torch.cat((B, B.new_zeros(1, B.size(1))), 0)
		instance_ids = self.p.id2entity_instance.cuda()[sub]

		entity_embed_all = B_embed_add1[instance_ids]
		if self.p.shift == 1:
			entity_embed_all = self.shift_onehot(entity_embed_all)
		elif self.p.shift == 2:
			entity_embed_all = self.shift_rotate(entity_embed_all)
		sub_B_emb = torch.sum(entity_embed_all, 1)

		# Count the entries that are not -1 (presumably padding), then draw one slot index
		# in [0, count) per instance. This assumes real entries precede padding — TODO confirm.
		non_zero_sub = torch.count_nonzero(instance_ids + 1, dim=1).reshape(-1, 1)
		rand_sub = torch.rand((non_zero_sub.size()[0], non_zero_sub.size()[1])).cuda()
		non_zero_sub = torch.floor(non_zero_sub * rand_sub).long()
		rand_sample = torch.gather(instance_ids, 1, non_zero_sub)

		if self.p.position:
			r_center_multi_b = self._gather_slots(r_center_multi, non_zero_sub).squeeze(1)
			# Hard-coded slot 5; presumably the object's role slot — TODO confirm.
			pred_index = torch.full_like(non_zero_sub, 5).to(self.device)
			r_center_pred = self._gather_slots(r_center_multi, pred_index).squeeze(1)
			if self.p.dist_type in (1, 2):
				L_multi = self._gather_slots(L_multi, non_zero_sub).squeeze(1)
			if self.p.dist_type == 1:
				U_multi = self._gather_slots(U_multi, non_zero_sub).squeeze(1)

		sub_emb = B_embed_add1[rand_sample].squeeze(1)
		sub_entity_emd = entity_embed_add1[rand_sample].squeeze(1)
		# Replace the sampled entity's B row with its raw embedding in the summed query.
		# This is an exact swap only when no shift transform was applied.
		sub_e_B = sub_B_emb - sub_emb + sub_entity_emd
		if self.p.position:
			obj_b_emb = r_center_multi_b - sub_e_B
			obj_emb = r_center_pred - sub_B_emb
		else:
			obj_b_emb = r_center - sub_e_B
			obj_emb = r_center - sub_B_emb

		x_b = torch.norm(obj_b_emb.unsqueeze(1) - B, p=1, dim=2)
		x = torch.norm(obj_emb.unsqueeze(1) - all_ent, p=1, dim=2)

		num_ent = all_ent.size()[0]
		if self.p.dist_type == 1:
			e_x_b = torch.sum(sub_e_B.unsqueeze(1) + B, dim=2)
			e_x = torch.sum(sub_B_emb.unsqueeze(1) + all_ent, dim=2)
			lower, upper = (L_multi, U_multi) if self.p.position else (L, U)
			L_x = torch.sum(lower, dim=1).unsqueeze(1).repeat(1, num_ent)
			U_x = torch.sum(upper, dim=1).unsqueeze(1).repeat(1, num_ent)
			x = dist1(x, e_x, L_x, U_x)
			x_b = dist1(x_b, e_x_b, L_x, U_x)
		elif self.p.dist_type == 2:
			lower = L_multi if self.p.position else L
			L_x = torch.sum(lower, dim=1).unsqueeze(1).repeat(1, num_ent)
			x = dist2(x, L_x)
			x_b = dist2(x_b, L_x)

		x_b = 0.5 * self.p.gamma - x_b
		x = 0.5 * self.p.gamma - x

		if not self.margin:
			return torch.sigmoid(x_b + x)
		return x + x_b

