import gc
import copy
import os
import sys
from sys import *
from collections import defaultdict
import torch
import torch.nn as nn
from torch import *
from torch.nn import *
from torch.nn.parallel import DistributedDataParallel
from torch.optim import *
from random import shuffle
from random import randint
import time
import json
import torch.nn.functional as F



# GPU memory budget in GiB. Predictor.clean() empties the CUDA cache once
# allocated + cached memory reaches 90% of this value, and Predictor.con_score()
# scales its chunk size by CUDA_MEM // 16.
CUDA_MEM = 32

class Predictor(nn.Module):
	def __init__(self, E, R, DATA_DIR, RotatE, print=print):
		super(Predictor, self).__init__()
		self.E = E
		self.R = R
		self.DATA_DIR = DATA_DIR
		self.training = True
		self.pre_mr = [0]
		self.pre_mrr = [0]
		self.pre_h1 = [0]
		self.pre_h3 = [0]
		self.pre_h10 = [0]
		self.MAX_RULE_NUM = 3000
		self.MAX_SAMPLED_RULES = self.MAX_RULE_NUM
		self.MAX_RULE_LEN_HARD = 4
		self.print = print
		for i in range(1, E + 2):
			self.pre_mr.append(self.pre_mr[-1] + i)
			self.pre_mrr.append(self.pre_mrr[-1] + 1.0 / i)
			self.pre_h1.append(self.pre_h1[-1] + (1 if i <= 1 else 0))
			self.pre_h3.append(self.pre_h3[-1] + (1 if i <= 3 else 0))
			self.pre_h10.append(self.pre_h10[-1] + (1 if i <= 10 else 0))


		self.use_ranking_loss = True
		self.parameterize_relation_embed = True
		self.allow_neg_rules = False

		# load rotate
		if RotatE is not None:
			import numpy
			config = json.load(open(f"{DATA_DIR}/{RotatE}/config.json"))
			self.gamma = config['gamma']
			self.embed_dim = config['hidden_dim']
			self.embed_range = (self.gamma + 2.0) / self.embed_dim
			self.entity_embed = torch.tensor(numpy.load(f"{DATA_DIR}/{RotatE}/entity_embedding.npy"))
			relation_embed = torch.tensor(numpy.load(f"{DATA_DIR}/{RotatE}/relation_embedding.npy"))
			self.relation_embed = (torch.cat([relation_embed, -relation_embed, torch.zeros(self.embed_dim).unsqueeze(0)], dim=0))
		else:
			self.parameterize_relation_embed = False
			self.gamma = 0.0
			self.embed_dim = 1
			self.embed_range = 1.0
			self.entity_embed = torch.zeros(E, 2).float()
			self.relation_embed = torch.zeros(R, 1).float()

		assert self.entity_embed.size()[0] == E
		assert self.relation_embed.size()[0] == R

		self.entity_embed = self.entity_embed.detach().cuda()
		self.additional_grounding_buffer = dict()

		self.print(f"RotatE loaded: gamma = {self.gamma} embed_dim = {self.embed_dim}")

	@staticmethod
	def set_cuda_mem(x):
		CUDA_MEM = x

	def set_training(self, tr):
		self.training = tr
		if tr:
			self.train()
		else:
			self.eval()


	def calc_embed(self, h_embed, r_embed):
		if isinstance(h_embed, int):
			h_embed = self.entity_embed.index_select(0, torch.tensor(h_embed).cuda()).squeeze().cuda()
		if isinstance(r_embed, int):
			r_embed = self.relation_embed.index_select(0, torch.tensor(r_embed).cuda()).squeeze().cuda()

		pi = 3.14159265358979323846
		re_h, im_h = torch.chunk(h_embed, 2, dim=-1)

		# if r_embed.size()[-1] != h_embed.size()[-1]:
		phase = r_embed / (self.embed_range / pi)
		re_r = torch.cos(phase)
		im_r = torch.sin(phase)
		# else:
		#   re_r, im_r = torch.chunk(r_embed, 2, dim=-1)

		re_res = re_h * re_r - im_h * im_r
		im_res = re_h * im_r + im_h * re_r

		return torch.cat([re_res, im_res], dim=-1)

	@staticmethod
	def embed_norm(emb):
		"""Sum of complex moduli of an embedding, computed on the GPU.

		The last dimension of `emb` is split into a real half and an imaginary
		half; each (re, im) pair contributes its 2-norm, and the contributions
		are summed over the last dimension.
		"""
		halves = torch.chunk(emb.cuda(), 2, dim=-1)
		modulus = torch.stack(halves, dim=-1).norm(dim=-1)
		return modulus.sum(dim=-1)

	def index_select(self, tensor, index):
		if self.training:
			if not isinstance(index, torch.Tensor):
				index = torch.tensor(index)
			index = index.to(tensor.device)
			return tensor.index_select(0, index).squeeze(0)
		else:
			return tensor[index]

	@staticmethod
	def load_batch(batch):
		return tuple(map(lambda x : x.cuda() if isinstance(x, torch.Tensor) else x, batch[0]))

	def rule_embed(self, rule_cho=None):
		rule_embed = torch.zeros(len(self.rule_list) if rule_cho is None else rule_cho.size()[-1], self.embed_dim).cuda()
		for i in range(self.MAX_RULE_LEN):
			rule_embed += self.index_select(self.relation_embed, self.rules[i] if rule_cho is None else self.rules[i][rule_cho])
		return rule_embed

	def load_ckpt(self, ckpt):
		self.rules = ckpt['rules'].to(self.rules.device)
		self.rule_list = ckpt['rule_list']
		# self.rule_weight = torch.nn.Parameter(ckpt['parameters']['rule_weight'])

	def relation_init_load(self, r, paths):
		self.r = r
		self.set_training(False)
		self.relation_embed = self.relation_embed.cuda()

		self.MAX_RULE_LEN = 0
		for path in paths:
			self.MAX_RULE_LEN = max(self.MAX_RULE_LEN, len(path))

		pad = self.R - 1
		gen_end = self.R - 1
		gen_pad = self.R
		rules = []
		rules_gen = []
		rule_list = []
		for path in paths:
			npad = (self.MAX_RULE_LEN - len(path))
			rules.append(path + (pad, ) * npad)
			rules_gen.append((self.r, ) + path + (gen_end, ) + (gen_pad, ) * npad)
			rule_list.append(tuple(path))
		
		self.rules = torch.LongTensor(rules).t()
		self.rules_gen = torch.LongTensor(rules_gen)
		self.rule_list = rule_list

		num_rules = len(self.rule_list)
		with torch.no_grad():
			self.rule_value = torch.zeros(num_rules).cuda()
			self.tmp_rule_embed = self.rule_embed().detach()
			self.prior_value = torch.zeros(num_rules).cuda()
			self.num_init = 0
			self.prior_coef = 0.01

		self.pad = pad
		self.gen_pad = gen_pad
		self.gen_end = gen_end

	def set_prior_value(self, p):
		del self.prior_value
		self.prior_value = p


	def relation_init_begin(self, r, paths=None):
		prior_value = None
		if paths is None:
			DATA_DIR = self.DATA_DIR

			# load rules
			with open(f"{DATA_DIR}/Rules/rules_{r}.txt") as file:
				tmp_paths = []
				for i, line in enumerate(file):
					try:
						path, prec = line.split('\t')
						path = tuple(map(int, path.split(' ')))
						prec = float(prec) * (0.8 ** max(0, len(path) - 4))
						if len(path) <= self.MAX_RULE_LEN_HARD:
							tmp_paths.append((path, prec, i))
					except:
						continue

			# assert len(tmp_paths) > 0
			tmp_paths = sorted(tmp_paths, key=lambda x : (x[1], x[2]), reverse=True)[:self.MAX_RULE_NUM]
			tmp_paths = [((r,), 10000, -1)] + tmp_paths
			paths = [path for path, _, _ in tmp_paths]
			prior_value = torch.tensor([prec for _, prec, _ in tmp_paths]).cuda()
			self.print(f"Read {len(paths)}/{self.MAX_RULE_NUM+1} rules MAX_RULE_LEN_HARD = {self.MAX_RULE_LEN_HARD}")

		self.relation_init_load(r, paths)
		
		if prior_value is not None:
			self.prior_value = prior_value.cuda()

	def relation_init_pretrain(self, r, pre):
		pre = torch.load(pre)
		self.relation_init_begin(r, pre['rule_list'])
		self.relation_embed = torch.nn.Parameter(self.relation_embed)
		self.rule_weight_raw = torch.nn.Parameter(pre['parameters']['rule_weight'].cpu())

	@staticmethod
	def clean():
		alloc = torch.cuda.memory_allocated() / 1024 / 1024 / 1024
		cache = torch.cuda.memory_cached() / 1024 / 1024 / 1024
		# self.print(f"alloc = {alloc}G cache = {cache}G")
		if alloc + cache >= CUDA_MEM * 0.9:
			torch.cuda.empty_cache()

	def additional_groundings(self, h, i, num, groundings):
		"""Find extra candidate entities for rule `i` applied to head `h`.

		Args:
			h: head entity (passed to calc_embed as the head).
			i: index of the rule in `self.rule_list` / `self.tmp_rule_embed`.
			num: number of candidates wanted.
			groundings: entity ids to exclude.

		Returns:
			CPU index tensor: every non-excluded entity whose distance to the
			rotated head is at most `gamma`, followed (if fewer than `num` were
			found) by the nearest remaining entities. Results are cached in
			`self.additional_grounding_buffer` per (h, rule).
		"""
		cache = self.additional_grounding_buffer
		key = (h, tuple(self.rule_list[i]))
		if key in cache:
			return cache[key]

		with torch.no_grad():
			target = self.calc_embed(h, self.tmp_rule_embed[i])
			dist = self.embed_norm(self.entity_embed - target)

			# Push already-known groundings out of reach.
			excluded = torch.LongTensor(groundings).cuda()
			dist[excluded] = 1e10

			# Everything within the margin is always returned.
			picked = torch.arange(self.E).cuda()[dist <= self.gamma]
			num -= picked.size()[-1]
			dist[picked] = 1e10

			# Top up with the closest still-available entities.
			num = min(num, (dist < 1e9).sum().item())
			if num > 0:
				nearest = dist.topk(num, dim=0, largest=False, sorted=False)[1]
				picked = torch.cat([picked, nearest], dim=0)
			picked = picked.cpu()
			self.clean()

		cache[key] = picked
		return picked

	def con_score(self, h, con_rule, con_entity, con_weight, rule_cho=None, self_rule_embed=None):
		num_con = con_entity.size()[-1]
		batch_size = int((CUDA_MEM // 16) * 65536 * (1 if self.training else 4) * (1000 / self.embed_dim))
		# self.print(f"batch_size = {batch_size}")
		# self.print(f"con_score num_con = {num_con} batch_size = {batch_size} {num_con == batch_size}")
		batch_size = (num_con + batch_size - 1) // batch_size
		batch_size = (num_con + batch_size - 1) // batch_size

		if self_rule_embed is None:
			self_rule_embed = self.rule_embed(rule_cho=rule_cho)
		elif rule_cho is not None:
			self_rule_embed = self_rule_embed[rule_cho]
		rule_embed = self.calc_embed(h, self_rule_embed)
		results = []

		with torch.no_grad():
			grad_batch = randint(0, num_con - 1)
			if batch_size != num_con:
				perm = torch.randperm(num_con).cuda()

		# grad_count = 0
		for i in range(0, num_con, batch_size):
			if batch_size != num_con:
				batch = perm[i : i + batch_size]
				crule = con_rule[batch]
				centity = con_entity[batch]
			else:
				crule = con_rule
				centity = con_entity

			# if i <= grad_batch < i + batch_size:  grad_count += 1
			if self.training and (i <= grad_batch < i + batch_size):
				con_diff = rule_embed.index_select(0, crule)
				con_diff -= self.entity_embed.index_select(0, centity)
				res = self.embed_norm(con_diff)
			else:
				with torch.no_grad():
					con_diff = rule_embed[crule]
					con_diff -= self.entity_embed[centity]
					res = self.embed_norm(con_diff).detach()
			results.append(res)
			del crule, centity, con_diff, res
			self.clean()
			# torch.cuda.empty_cache()
		# assert grad_count == 1

		score = torch.cat(results, dim=0)
		if batch_size != num_con:
			with torch.no_grad():
				invperm = torch.empty(num_con).long().cuda()
				invperm[perm] = torch.arange(num_con).cuda()
			score = self.index_select(score, invperm)

		score = (self.gamma - score).sigmoid()

		# score = (1 - score / self.gamma).exp()
		# # try to compute softmax
		# score_sum = torch.sparse.sum(torch.sparse.FloatTensor(
		#   torch.stack([con_rule, con_entity], dim=0),
		#   score,
		#   torch.Size([len(self.rule_list), self.E])
		# ), -1).to_dense()
		# # dim = num_rules
		# score = score / self.index_select(score_sum, con_rule)


		score = score * con_weight
		return score

	def rule_init_step(self, batch, weight=1):
		"""Estimate a per-rule value from one batch and accumulate it.

		Args:
			batch: tuple `(h, _, t_list, mask, con_rule, con_entity, con_weight)`;
				`mask` must be a boolean tensor indexed by entity id.
			weight: scale applied to this batch's value and added to `num_init`.

		Returns:
			Tensor with one value per rule: (score mass on masked entities minus
			score mass on unmasked entities) / total grounding weight, times
			`weight`. If `self.rule_value` exists, the value is also added to it.
		"""
		num_rules = len(self.rule_list)
		h, _, t_list, mask, con_rule, con_entity, con_weight = self.load_batch([batch])
		with torch.no_grad():
			con_score = self.con_score(h, con_rule, con_entity, con_weight, self_rule_embed=self.tmp_rule_embed) # * con_count

			indices = torch.stack([con_rule, con_entity], dim=0)
			def calc_value(con_score):
				# Scatter the per-pair values into a (rule, entity) sparse matrix
				# and sum over entities. torch.sparse_coo_tensor replaces the
				# deprecated torch.sparse.FloatTensor(indices, values, size)
				# constructor.
				return torch.sparse.sum(torch.sparse_coo_tensor(
					indices,
					con_score,
					torch.Size([num_rules, self.E])
				).cuda(), -1).to_dense()

			pos = calc_value(con_score * mask[con_entity])
			neg = calc_value(con_score * ~mask[con_entity])
			# Clamp so rules without groundings do not divide by zero.
			num = calc_value(con_weight).clamp(min=0.001)

			value = (pos - neg) / num * weight
			#print(torch.isnan(value).sum())

			if hasattr(self, 'rule_value'):
				self.rule_value += value
				# self.num_init += 1
				self.num_init += weight

		self.clean()
		return value

	def rule_cho(self, value, non_neg_value, output_mask=False, num_samples=None):
		if num_samples is None:
			num_samples = self.MAX_SAMPLED_RULES
		with torch.no_grad():
			value = value.clone()
			value[torch.isnan(value)] = 0
			num_rules = value.size()[-1]
			topk_ind = value.topk(min(num_samples - 1, num_rules), dim=0, largest=True, sorted=False)[1]
			cho = torch.zeros(num_rules).bool().cuda()
			cho[topk_ind] = 1
			# print(f"{cho.size()} {(non_neg_value <= 0).size()}")
			cho &= (non_neg_value >= 0)

			# print(cho)

			cho[0] = 1

			if output_mask:
				return cho
			else:
				return torch.arange(num_rules).cuda()[cho]

	def choose_rules(self, batch, num_samples):
		"""Pick the rules to use for `batch`.

		Rules are ranked by (batch value + prior_coef * prior_value) * rule
		weight; rules whose batch value is negative are excluded by rule_cho().
		Returns the chosen rule indices.
		"""
		with torch.no_grad():
			batch_value = self.rule_init_step(batch)
			prior = self.prior_coef * self.prior_value
			ranked = (batch_value + prior) * self.rule_weight()
			return self.rule_cho(ranked, batch_value, num_samples=num_samples)

	def rule_init_end(self):
		"""Finish rule initialisation: prune the rule set and create trainable weights.

		Uses the values accumulated by rule_init_step() (`rule_value`, `num_init`)
		plus the prior to select rules via rule_cho(), filters `rule_value`,
		`prior_value`, `rules`, `rules_gen` and `rule_list` down to the selection,
		optionally turns `relation_embed` into a CPU Parameter, and creates
		`rule_weight_raw`.

		Returns:
			The selected `rule_value` tensor after in-place division by its
			(clamped) maximum. `self.rule_value` is deleted afterwards.
		"""
		with torch.no_grad():
			self.rule_value[torch.isnan(self.rule_value)] = 0
			# Mean accumulated value plus the weighted prior.
			value = self.rule_value / self.num_init + self.prior_coef * self.prior_value
			# special: remove nonneg
			# (passing rule_value as non_neg_value makes rule_cho drop rules
			# whose accumulated value is negative; all-ones keeps them)
			# self.print("allow_neg_rules: " + self.allow_neg_rules)
			if self.allow_neg_rules:
				cho = self.rule_cho(value, torch.ones_like(value).cuda(), output_mask=True)
			else:
				cho = self.rule_cho(value, self.rule_value, output_mask=True)

			# NOTE(review): cho is a CUDA boolean mask (built in rule_cho) while
			# relation_init_load() builds self.rules / self.rules_gen on the CPU —
			# confirm this cross-device indexing is supported by the PyTorch
			# version in use.
			value = value[cho]
			self.rule_value = self.rule_value[cho]
			self.prior_value = self.prior_value[cho]
			self.rules = self.rules[:, cho]
			self.rules_gen = self.rules_gen[cho]

			# Filter the Python-side rule list with the same mask.
			cho_list = cho.detach().cpu().numpy().tolist()
			new_rule_list = []
			for r, c in zip(self.rule_list, cho_list):
				if c:
					new_rule_list.append(r)
			self.rule_list = new_rule_list

			num_cho = cho.sum().item()
			assert value.size()[-1] == num_cho
			assert self.rules.size()[-1] == num_cho
			assert len(self.rule_list) == num_cho
			del cho


		if self.parameterize_relation_embed:
			self.relation_embed = torch.nn.Parameter(self.relation_embed.cpu())

		# rule_weight_init
		# weight_init aliases self.rule_value here, so the in-place division
		# below also rescales self.rule_value (and therefore the return value).
		weight_init = self.rule_value
		self.print(f"init_end r = {self.r} num_cho = {num_cho} weight_init = [{weight_init.min().item(), weight_init.max().item()}]")
		weight_init /= weight_init.max().clamp(min=0.0001)
		# clamp() returns a new tensor: from here on weight_init no longer
		# aliases self.rule_value.
		weight_init = weight_init.clamp(min=0.0001)
		weight_init[0] = 1
		self.rule_weight_raw = torch.nn.Parameter(weight_init)

		# self.rule_weight_raw = torch.nn.Parameter(value.log())
		# self.rule_weight_raw = torch.nn.Parameter(torch.randn_like(value) / 10)
		ret = self.rule_value
		del self.rule_value
		# del self.tmp_rule_embed
		return ret

	def set_rule_coef(self, value):
		self.rule_coef = value.cuda()

	def rule_weight(self):
		# return (self.rule_weight_raw - self.rule_weight_raw.max().detach()).exp()
		return self.rule_weight_raw


	def relation_init_end(self):
		"""No-op hook called when work on a relation is finished.

		Releasing tmp_rule_embed, additional_grounding_buffer, prior_value and
		rule_coef here is currently disabled.
		"""
		return None


	def forward(self, batch):
		"""Score all entities for one query and return (loss, metrics).

		Args:
			batch: sequence whose first element is the tuple
				`(h, _, t_list, mask, rule_cho, con_rule, con_entity, con_weight)`
				(see load_batch). `mask` marks entities by id and `t_list` holds
				the target entity ids.

		Returns:
			`(loss, result)`. In training mode `loss` is the softmax mass on
			unmasked entities plus, when `use_ranking_loss` is set, a ranking
			term per target; `result` is all zeros. In eval mode `loss` is 0 and
			`result` holds `[count, MR, MRR, H@1, H@3, H@10]` sums over the
			targets, with ties averaged over the tied rank span.

		Note: in eval mode `mask[h]` is set in place on the (GPU copy of the) mask.
		"""
		E = self.E
		num_rules = len(self.rule_list)

		h, _, t_list, mask, rule_cho, con_rule, con_entity, con_weight = self.load_batch(batch)

		con_score = self.con_score(h, con_rule, con_entity, con_weight, rule_cho=rule_cho) * self.index_select(self.rule_weight(), con_rule)

		# Sum the weighted pair scores per entity. torch.sparse_coo_tensor
		# replaces the deprecated torch.sparse.FloatTensor(indices, values, size)
		# constructor.
		score = torch.sparse.sum(torch.sparse_coo_tensor(
			torch.stack([con_entity, con_rule], dim=0),
			con_score,
			torch.Size([E, num_rules])
		), -1).to_dense()

		# self.print(f"score = [{score.min()}, {score.max()}]")

		loss = torch.tensor(0.0).cuda()
		result = torch.zeros(6).cuda()

		if self.training:
			score = (score * 1).softmax(dim=-1)
			pos = score.masked_select(mask.bool())
			neg = score.masked_select(~mask.bool())

			# Base loss: probability mass assigned to unmasked entities.
			loss = neg.sum()

			if self.use_ranking_loss:
				for t in t_list:
					s =  score.index_select(0, torch.tensor(t).cuda()).squeeze()
					# Mean margin of the negatives that outrank target t.
					wrong = (neg > s)
					# self.print(f"train s = {s.item()} score = [{score.min().item()}, {score.max().item()}] rank = {(score > s).sum().item()}")
					loss += ((neg - s) * wrong).sum() / wrong.sum().clamp(min=1)

			# A NaN loss is replaced by a zero that still supports backward().
			if torch.isnan(loss).item():
				loss = torch.tensor(0.0).cuda().requires_grad_()

		else:
			score[torch.isnan(score)] = 0
			# Exclude the head itself from the competitors.
			mask[h] = 1
			for t in t_list:
				incorrect = score.masked_select(~mask.bool())
				# Ranks rankl + 1 .. rankr are the positions t could take among
				# ties; each metric is averaged over that span via prefix sums.
				rankl = (incorrect > score[t]).sum().item() #+ 1
				rankr = (incorrect >= score[t]).sum().item() + 1
				result[0] += 1
				result[1] += (self.pre_mr[rankr] - self.pre_mr[rankl]) / (rankr - rankl)
				result[2] += (self.pre_mrr[rankr] - self.pre_mrr[rankl]) / (rankr - rankl)
				result[3] += (self.pre_h1[rankr] - self.pre_h1[rankl]) / (rankr - rankl)
				result[4] += (self.pre_h3[rankr] - self.pre_h3[rankl]) / (rankr - rankl)
				result[5] += (self.pre_h10[rankr] - self.pre_h10[rankl]) / (rankr - rankl)
				# self.print(result)

		return loss, result.detach()

class Generator(nn.Module):
	def __init__(self, num_relations, embedding_dim, hidden_dim, use_cuda=True):
		"""Build the LSTM rule generator.

		Args:
			num_relations: number of relation ids; id `num_relations` is the
				end-of-rule token and `num_relations + 1` is padding.
			embedding_dim: size of each token embedding.
			hidden_dim: LSTM hidden size.
			use_cuda: move the module (and later its inputs) to the GPU.
		"""
		super(Generator, self).__init__()

		# Vocabulary = relations + end token + padding; labels exclude padding.
		end_token = num_relations
		pad_token = num_relations + 1
		vocab_size = num_relations + 2
		label_size = num_relations + 1
		layer_count = 1

		self.num_relations = num_relations
		self.embedding_dim = embedding_dim
		self.hidden_dim = hidden_dim
		self.mov = num_relations // 2
		self.vocab_size = vocab_size
		self.label_size = label_size
		self.ending_idx = end_token
		self.padding_idx = pad_token
		self.num_layers = layer_count
		self.use_cuda = use_cuda
		self.print = print

		# The LSTM input is the token embedding concatenated with the head
		# relation embedding, hence twice the embedding size.
		self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_token)
		self.rnn = nn.LSTM(embedding_dim * 2, hidden_dim, layer_count, batch_first=True)
		self.linear = nn.Linear(hidden_dim, label_size)
		self.criterion = nn.CrossEntropyLoss(reduction='none')

		if use_cuda:
			self.cuda()

	def inv(self, r):
		if r < self.mov:
			return r + self.mov
		else:
			return r - self.mov

	def zero_state(self, batch_size): 
		state_shape = (self.num_layers, batch_size, self.hidden_dim)
		h0 = c0 = torch.zeros(*state_shape, requires_grad=False)
		if self.use_cuda:
			return (h0.cuda(), c0.cuda())
		else:
			return (h0, c0)

	def forward(self, inputs, relation, hidden):
		"""Run the LSTM over `inputs`, conditioned on the head relation.

		Args:
			inputs: (batch, seq) token ids.
			relation: (batch,) head relation ids; their embedding is repeated at
				every step and concatenated to the token embedding.
			hidden: LSTM state tuple.

		Returns:
			`(logits, hidden)` with logits over `label_size` classes per step.
		"""
		seq_len = inputs.size(1)
		token_emb = self.embedding(inputs)
		head_emb = self.embedding(relation).unsqueeze(1).expand(-1, seq_len, -1)
		rnn_input = torch.cat([token_emb, head_emb], dim=-1)
		outputs, hidden = self.rnn(rnn_input, hidden)
		logits = self.linear(outputs)
		Predictor.clean()
		return logits, hidden

	def loss(self, inputs, target, mask, weight):
		"""Weighted cross-entropy of next-token predictions.

		Args:
			inputs: (batch, seq) token ids; column 0 is the head relation.
			target: (batch, seq) next-token labels.
			mask: (batch, seq) boolean mask of positions that count.
			weight: (batch,) per-sequence weight.

		Returns:
			Scalar: sum of weighted per-token losses over masked positions,
			divided by the sum of the weights at those positions.
		"""
		if self.use_cuda:
			inputs, target, mask, weight = (
				t.cuda() for t in (inputs, target, mask, weight)
			)

		state = self.zero_state(inputs.size(0))
		logits, _ = self.forward(inputs, inputs[:, 0], state)

		# Flatten to the masked positions only.
		flat_logits = torch.masked_select(logits, mask.unsqueeze(-1)).view(-1, self.label_size)
		flat_target = torch.masked_select(target, mask)
		# Broadcast each sequence weight over its positions, then mask.
		token_weight = torch.masked_select((mask.t() * weight).t(), mask)

		token_loss = self.criterion(flat_logits, flat_target)
		return (token_loss * token_weight).sum() / token_weight.sum()

	def sample(self, relation):
		"""Sample one rule for head `relation`, token by token.

		Returns:
			List starting with `relation`, followed by the sampled relation ids;
			the end token is not included.

		NOTE(review): sampling continues until the end token is drawn; there is
		no length cap.
		"""
		rule = [relation]
		head = torch.LongTensor([relation])
		if self.use_cuda:
			head = head.cuda()
		state = self.zero_state(1)

		while True:
			# Feed only the most recent token; the LSTM state carries the rest.
			step_input = torch.LongTensor([[rule[-1]]])
			if self.use_cuda:
				step_input = step_input.cuda()
			logits, state = self.forward(step_input, head, state)
			probs = torch.softmax(logits.squeeze(0).squeeze(0), dim=-1)
			drawn = torch.multinomial(probs, 1).item()
			if drawn == self.ending_idx:
				return rule
			rule.append(drawn)

	def log_probability(self, rule):
		"""Log-probability of generating `rule` followed by the end token.

		Args:
			rule: sequence `[head relation, r1, r2, ...]` without the end token.
				It is not modified.

		Returns:
			Sum over steps of the log-softmax score of each next token (a scalar
			tensor, or 0.0 if there are no steps).
		"""
		# Work on a copy: appending the end token to the caller's list would
		# change the rule seen by the caller and make repeated calls on the
		# same list return different values.
		rule = list(rule) + [self.ending_idx]
		relation = torch.LongTensor([rule[0]])
		if self.use_cuda:
			relation = relation.cuda()
		hidden = self.zero_state(1)
		log_prob = 0.0
		for k in range(1, len(rule)):
			inputs = torch.LongTensor([[rule[k-1]]])
			if self.use_cuda:
				inputs = inputs.cuda()
			logits, hidden = self.forward(inputs, relation, hidden)
			log_prob += torch.log_softmax(logits.squeeze(0).squeeze(0), dim=-1)[rule[k]]
		return log_prob

	def next_relation_log_probability(self, seq):
		"""Log-probabilities of the token that follows `seq`.

		Args:
			seq: list of token ids; `seq[0]` is the head relation.

		Returns:
			Python list of `label_size` log-probabilities, computed from the
			last-step logits multiplied by 5 before the log-softmax.
		"""
		tokens = torch.LongTensor([seq])
		head = torch.LongTensor([seq[0]])
		if self.use_cuda:
			tokens = tokens.cuda()
			head = head.cuda()

		state = self.zero_state(1)
		logits, _ = self.forward(tokens, head, state)
		last_logits = logits[0, -1, :]
		sharpened = torch.log_softmax(last_logits * 5, dim=-1)
		return sharpened.data.cpu().numpy().tolist()

	def beam_search(self, relation, num_samples, max_len):
		"""Beam-search the highest-scoring rules for head `relation`.

		Args:
			relation: head relation id; every rule starts with it.
			num_samples: beam width and maximum number of returned rules.
			max_len: number of expansion steps; on the last step only the end
				token may be appended.

		Returns:
			Up to `num_samples` rules (lists starting with `relation`, end token
			stripped), ordered by descending accumulated log-probability.
		"""
		with torch.no_grad():
			finished = []
			beam = [[[relation], 0]]
			for k in range(max_len):
				self.print(f"k = {k} |prev| = {len(beam)}")
				is_last_step = (k + 1) == max_len
				next_tokens = [self.ending_idx] if is_last_step else range(self.label_size)

				extended = list()
				for _i, (rule, score) in enumerate(beam):
					assert rule[-1] != self.ending_idx
					log_prob = self.next_relation_log_probability(rule)
					for i in next_tokens:
						entry = (rule + [i], score + log_prob[i])
						# Rules that just received the end token are complete.
						if i == self.ending_idx:
							finished.append(entry)
						else:
							extended.append(entry)

					Predictor.clean()
					if _i % 100 == 0:
						self.print(f"beam_search k = {k} i = {_i}")

				# Keep only the best `num_samples` open and finished rules.
				beam = sorted(extended, key=lambda item: item[1], reverse=True)[:num_samples]
				finished = sorted(finished, key=lambda item: item[1], reverse=True)[:num_samples]

			self.print(f"beam_search |rules| = {len(finished)}")
			rules_out = list()
			for rule, score in finished:
				assert rule[-1] == self.ending_idx
				rules_out.append(rule[:-1])
			return rules_out
