import gc
import copy
gc.enable()

import os
import sys
from sys import *
#from random import *
from collections import defaultdict
import torch
import torch.nn as nn
from torch import *
from torch.nn import *
from torch.optim import *
from random import shuffle
from random import randint
import time
import datetime
import json
import torch.nn.functional as F
from model import *

# from apex import amp

# Remove the first command-line argument that starts with "--local_rank" (if
# any) so the positional arguments read further down keep their indices.
for i, argument in enumerate(sys.argv):
	if argument.startswith("--local_rank"):
		del sys.argv[i]
		break

device = 'cuda'

# Positional arguments: data directory, output directory, and a flag that is
# true only when the third argument is the literal string 'True'.
DATA_DIR, OUTPUT_DIR = sys.argv[1], sys.argv[2]
LOCAL = sys.argv[3] == 'True'

# The GPU count is queried but then overridden: the script always runs with a
# batch size of exactly one.
batch_size = torch.cuda.device_count()
batch_size = 1
assert batch_size == 1


print(f"num_gpus = {torch.cuda.device_count()}", flush=True)
print(f"DATA_DIR = {DATA_DIR}", flush=True)
print(f"OUTPUT_DIR = {OUTPUT_DIR}", flush=True)

#print(LOCAL)

# Append-mode training log; it is used by write() when LOCAL is false.
log_file = open(f"{OUTPUT_DIR}/train_log.txt", 'a')

# Keep a handle on the built-in print before the name is rebound further down.
old_print = print

def write(printstr, end='\n'):
	"""Emit one log message.

	When LOCAL is true the message is forwarded to the original built-in
	print (``old_print``); otherwise it is appended to ``log_file`` and the
	file is flushed immediately.

	Args:
		printstr: The message. Any object is accepted; it is converted with
			``str()`` before being written to the log file, so both branches
			accept the same inputs.
		end: Text appended after the message (defaults to a newline).
	"""
	if LOCAL:
		old_print(printstr, end=end)
	else:
		# str() keeps non-string messages (numbers, tensors, ...) from
		# raising TypeError on concatenation.
		log_file.write(str(printstr) + end)
		log_file.flush()

# From here on, module-level print() goes through write().
print = write

# Vocabularies in both directions, parsed from "<id>\t<name>" lines.
entity2id, id2entity = dict(), dict()
relation2id, id2relation = dict(), dict()

with open(f'{DATA_DIR}/entities.dict') as fin:
	for line in fin:
		eid, entity = line.strip().split('\t')
		entity2id[entity] = int(eid)
		id2entity[int(eid)] = entity

with open(f'{DATA_DIR}/relations.dict') as fin:
	for line in fin:
		rid, relation = line.strip().split('\t')
		relation2id[relation] = int(rid)
		id2relation[int(rid)] = relation

E = len(entity2id)
R = len(relation2id)

# Every relation i gets an inverse with id i + mov, named "<relation>_REV".
mov = R
for i in range(mov):
	id2relation[i + mov] = id2relation[i] + "_REV"

# Forward relations, their inverses, and one extra id (R - 1) that add_data
# refuses to store in a graph.
R = 2 * mov + 1

class Graph:
	"""Adjacency structure indexed by head entity and relation.

	``self.e[h][r]`` is the set of tail entities ``t`` for which
	``add(h, r, t)`` has been called. The table is sized from the
	module-level entity count ``E`` and relation count ``R`` at construction.
	"""

	def __init__(self):
		# One independent set per (entity, relation) pair.
		self.e = [[set() for _ in range(R)] for _ in range(E)]

	def add(self, h, r, t):
		"""Record the edge h --r--> t (duplicates are absorbed by the set)."""
		tails = self.e[h][r]
		tails.add(t)

# G is filled from the training file only; G_test also receives the
# validation triples (see the loading code below).
G, G_test = Graph(), Graph()

# (head, relation) -> set of tails, and per-relation sets of heads.
train_t = defaultdict(set)
answer_valid_t = defaultdict(set)
train_h = [set() for _ in range(R)]


def add_data(train_t, train_h, G_list, h, r, t, add=True):
	"""Register the triple (h, r, t) in the given indexes and graphs.

	Args:
		train_t: Mapping (h, r) -> set of tails; ``t`` is added under (h, r).
		train_h: Sequence indexed by relation id holding sets of heads; ``h``
			is added to ``train_h[r]``.
		G_list: Graphs that receive the edge when ``add`` is true.
		h, r, t: Head entity id, relation id, tail entity id.
		add: When false, the graphs in ``G_list`` are left untouched.

	The triple is also always recorded in the module-level ``answer_valid_t``.

	Raises:
		AssertionError: If ``r`` is outside ``[0, R)``, or if an edge with the
			last relation id (``R - 1``) is added to a graph.
	"""
	assert 0 <= r < R
	if add:
		for graph in G_list:
			graph.add(h, r, t)
			# The last relation id must never end up in a graph.
			assert r != R - 1

	train_h[r].add(h)
	key = (h, r)
	train_t[key].add(t)
	answer_valid_t[key].add(t)

# for i in range(E):
# 	add_data(train_t, train_h, G, i, R-1, i, False)


def _read_triples(path):
	"""Yield one (head id, relation id, tail id) tuple per line of the
	tab-separated triple file at ``path``, mapped through the vocabularies."""
	with open(path) as fin:
		for line in fin:
			head, rel, tail = line.strip().split('\t')
			yield entity2id[head], relation2id[rel], entity2id[tail]


# Training triples: stored in both graphs, once as given and once as the
# inverse relation (relation id shifted by `mov`, head and tail swapped).
for h, r, t in _read_triples(f"{DATA_DIR}/train.txt"):
	add_data(train_t, train_h, [G, G_test], h, r, t)
	add_data(train_t, train_h, [G, G_test], t, r + mov, h)

# Freeze the per-relation head sets into tuples (same list object, in place).
train_h[:] = [tuple(heads) for heads in train_h]

# Validation triples: stored in G_test only.
valid_h = [set() for _ in range(R)]
valid_t = defaultdict(set)
for h, r, t in _read_triples(f"{DATA_DIR}/valid.txt"):
	add_data(valid_t, valid_h, [G_test], h, r, t)
	add_data(valid_t, valid_h, [G_test], t, r + mov, h)

# Test triples: not stored in any graph. The test answers start as a copy of
# everything add_data has recorded in answer_valid_t so far.
answer_test_t = copy.deepcopy(answer_valid_t)
test_h = [set() for _ in range(R)]
test_t = defaultdict(set)
for h, r, t in _read_triples(f"{DATA_DIR}/test.txt"):
	for query, target in (((h, r), t), ((t, r + mov), h)):
		answer_test_t[query].add(target)
		test_t[query].add(target)
		test_h[query[1]].add(query[0])

# print(max(map(len, train_h)))
# for i in range(R):
# 	old_print(i, len(train_h[i]))


# Shared predictor; every relation handled below works on its own deep copy.
model_init = Predictor(E, R, DATA_DIR, "RotatE_500", print=print)

# global trie_edges, con_rule, con_entity, con_matrix


# Fourth and fifth positional arguments: first relation id to process and the
# stride between processed relation ids.
start, hop = int(sys.argv[4]), int(sys.argv[5])


print(f"Work range: [{start}, {R-1}, {hop}]")

for r in range(start, R - 1, hop):

	# Not read in the part of the file shown here.
	num_EM_epochs = 2

	# Grounding controls read by make_batch.
	min_groundings = 256
	only_self_link = r < mov  # true for the non-inverse relation ids
	allow_nonpos_add = True
	non_grounding_panelty = 0.1

	# Size caps; of these, only MAX_AG_RULES is read in the part of the file
	# shown here.
	MAX_TRAIN_H = 5000
	MAX_RULES = MAX_AG_RULES = 2000
	MAX_SAMPLES = 300
	MAX_RULE_LEN = 4

	print(f"Work r = {r}")


	generator = Generator(R - 1, 512, 256).cuda()
	generator.train()

	rule_coef, additional_grounding_buffer = dict(), dict()

	def print(s):
		"""Log ``s`` through write(), prefixed with the current wall-clock
		time and the relation id ``r`` being processed."""
		timestr = f"{datetime.datetime.now():%H:%M:%S.%f}"
		return write(f"[{timestr}] r = {r} {s}")

	# Work on a private deep copy of the shared predictor so the per-relation
	# attributes set below do not touch model_init.
	model = copy.deepcopy(model_init)
	model.additional_grounding_buffer = additional_grounding_buffer

	# Route logging of both networks through the timestamped,
	# relation-tagged print defined above.
	model.print = print
	generator.print = print

	# NOTE(review): presumably initialises relation-specific state (including
	# model.rule_list) from the given checkpoint path — confirm in model.py.
	model.relation_init_pretrain(r, f'{OUTPUT_DIR}/pretrained/model_{r}.ckpt')
	# Only the first MAX_AG_RULES rules are eligible for additional groundings
	# in make_batch.
	ag_rules = set(model.rule_list[:MAX_AG_RULES])

	def make_edges(rule_cho=None):
		"""Build a prefix trie over the relation paths of the chosen rules.

		Args:
			rule_cho: Ids into ``model.rule_list`` selecting the rules to
				include, as an iterable or a tensor. ``None`` selects every
				rule in order.

		Returns:
			A defaultdict mapping each path prefix (a tuple of relation ids)
			to a two-element list ``[finished, next_relations]``:
			``finished`` lists ``(position in rule_cho, rule id)`` for every
			rule whose full path equals the prefix, and ``next_relations`` is
			the set of relation ids that extend the prefix towards some rule.

		Raises:
			AssertionError: If the first chosen rule is not the single-step
				path ``(r,)`` for the current relation.
		"""
		if rule_cho is None:
			rule_cho = range(len(model.rule_list))
		elif isinstance(rule_cho, torch.Tensor):
			rule_cho = rule_cho.detach().cpu().numpy().tolist()

		trie_edges = defaultdict(lambda : [[], set()])

		for rule, rule_id in enumerate(rule_cho):
			path = model.rule_list[rule_id]
			if rule == 0:
				# this path can have no groundings, so it is kept out of the trie
				assert path == (r,)
				continue
			path = tuple(path)
			# Every proper prefix points at the relation that follows it ...
			for depth, relation in enumerate(path):
				trie_edges[path[:depth]][1].add(relation)
			# ... and the complete path marks where this rule ends.
			trie_edges[path][0].append((rule, rule_id))
		return trie_edges

	def make_batch(h, r, t_list, trie_edges, answer=None, G=G, count=False, rule_cho=None, add_cho=False):
		"""Assemble one query record for head ``h`` and relation ``r``.

		The graph ``G`` is walked from ``h`` along every path stored in
		``trie_edges``; each entity reached at the end of a rule's path becomes
		a "contribution" (rule index, entity, weight 1.0). Rules listed in
		``ag_rules`` additionally receive entities from
		``model.additional_groundings`` with weight ``non_grounding_panelty``.

		Args:
			h: Head entity id of the query.
			r: Relation id of the query (only passed through to the result).
			t_list: Target tail entities; converted to a list.
			trie_edges: Trie produced by make_edges.
			answer: Entities to mark in the returned mask; defaults to t_list.
			G: Graph to walk. The default is bound when this function is
				defined, i.e. to the module-level G.
			count: When true, also return a per-contribution path count.
			rule_cho: Optional selection of rule ids (iterable or tensor).
				Passing one forces add_cho to True.
			add_cho: When true, the rule selection is included in the result.

		Returns:
			``(h, r, t_list, mask, con_rule, con_entity, con_weight)``, with
			``con_count`` appended when ``count`` is true, and with the rule
			selection inserted right after ``mask`` when ``add_cho`` is true.
			``mask`` is a boolean tensor of length E that is true at the
			answer entities.
		"""
		t_list = list(t_list)
		if answer is None:
			answer = t_list

		# _rule_cho is what gets returned to the caller; rule_cho is the plain
		# Python iterable used for indexing below.
		_rule_cho = rule_cho
		if rule_cho is None:
			rule_cho = range(len(model.rule_list))
			if add_cho:
				_rule_cho = torch.arange(len(model.rule_list))
		else:
			add_cho = True
			if isinstance(rule_cho, torch.Tensor):
				rule_cho = rule_cho.detach().cpu().numpy().tolist()

		con_rule = []
		con_entity = []
		con_weight = []
		con_count = []
		# groundings[k] collects the entities reached by the k-th chosen rule.
		groundings = [[] for i in rule_cho]

		def dfs(path, pos):
			# pos maps each entity reachable from h via `path` to the number of
			# distinct walks that reach it.
			# Note: indexing the defaultdict inserts an empty node if `path`
			# is not in the trie yet.
			edge = trie_edges[path]
			if len(edge[0]) > 0:
				# One contribution per (reached entity, rule ending here).
				for p, c in pos.items():
					for (rule, rule_id) in edge[0]:
						con_entity.append(p)
						con_rule.append(rule)
						con_weight.append(1.0)
						groundings[rule].append(p)
						if count:
							con_count.append(c)
			# This `r` is local to dfs and does not affect make_batch's parameter.
			for r in edge[1]:
				newpos = defaultdict(lambda : 0)
				for p, c in pos.items():
					for q in G.e[p][r]:
						newpos[q] += c
				# Stop descending as soon as nothing is reachable.
				if len(newpos) > 0:
					dfs(path + (r,), newpos)


		dfs(tuple(), {h : 1})

		# Wrap the path-derived contributions as the first chunk of each list;
		# additional groundings are appended as further chunks and concatenated.
		con_rule = [torch.LongTensor(con_rule)]
		con_entity = [torch.LongTensor(con_entity)]
		con_weight = [torch.FloatTensor(con_weight)]
		if count:
			con_count = [torch.FloatTensor(con_count)]

		for rule, rule_id in enumerate(rule_cho):
			if model.rule_list[rule_id] not in ag_rules:
				continue

			# Target number of groundings: min_groundings, eight times that for
			# rule id 0, and zero for every other rule when only_self_link is
			# set. `add` is the shortfall relative to the groundings found and
			# may be zero or negative.
			add = min_groundings * (8 if rule_id == 0 else 1) * (0 if (only_self_link and rule_id != 0) else 1) - len(groundings[rule])

			if add <= 0 and not allow_nonpos_add:
				continue

			# NOTE(review): presumably returns a 1-D tensor of entity ids on
			# the same device as the tensors built above — confirm in model.py.
			ag = model.additional_groundings(h, rule_id, add, groundings[rule])

			con_rule.append(torch.zeros_like(ag).long() + rule)
			con_entity.append(ag)
			con_weight.append(torch.ones_like(ag) * non_grounding_panelty)
			if count:
				# Additional groundings have no path count of their own, so
				# they reuse the penalty value as their count.
				con_count.append(torch.ones_like(ag) * non_grounding_panelty)

		con_rule = torch.cat(con_rule, dim=0)
		con_entity = torch.cat(con_entity, dim=0)
		con_weight = torch.cat(con_weight, dim=0)
		if count:
			con_count = torch.cat(con_count, dim=0)

		if con_rule.size()[-1] == 0:
			# No contributions at all: emit one zero-weight placeholder.
			# NOTE(review): these placeholders are moved to CUDA and are
			# integer tensors, while the regular path builds con_weight and
			# con_count as float tensors — confirm this mismatch is intended.
			con_rule = torch.tensor([0]).cuda()
			con_entity = torch.tensor([h]).cuda()
			con_weight = torch.tensor([0]).cuda()
			if count:
				con_count = torch.tensor([0]).cuda()

		# Boolean mask over all entities, true at the answer entities.
		mask = torch.zeros(E).bool()
		for t in answer:
			mask[t] = 1

		ret = (h, r, t_list, mask.detach(), con_rule.detach(), con_entity.detach(), con_weight.detach())
		if count:
			ret = ret + (con_count.detach(),)
		if add_cho:
			# Insert the rule selection right after the mask.
			ret = ret[:4] + (_rule_cho, ) + ret[4:]
		return ret

	# Trie over every rule of the current relation.
	trie_edges = make_edges()

	model = model.cuda()
	for name, param in model.named_parameters():
		print(f"Model Parameter: {name} ({param.type()}:{param.size()})")

	train_batch, valid_batch, test_batch, gen_batch = [], [], [], []

	# Validation queries walk the default graph G.
	valid_heads = valid_h[r]
	for i, h in enumerate(valid_heads):
		if not i % 10:
			print(f"valid_batch: {i}/{len(valid_heads)}")
		batch = make_batch(h, r, valid_t[(h, r)], trie_edges, answer=answer_valid_t[(h, r)], add_cho=True)
		valid_batch.append(batch)

	# Test queries walk G_test instead.
	test_heads = test_h[r]
	for i, h in enumerate(test_heads):
		if not i % 10:
			print(f"test_batch: {i}/{len(test_heads)}")
		batch = make_batch(h, r, test_t[(h, r)], trie_edges, answer=answer_test_t[(h, r)], G=G_test, add_cho=True)
		test_batch.append(batch)

	model.relation_init_end()
	model.clean()
	gc.collect()

	def export(result):
		"""Divide every entry of ``result`` by its first entry (clamped to at
		least 1) and return all entries but the first as a plain Python list."""
		denominator = result[0].clamp(min=1)
		normalised = (result / denominator)[1:]
		return normalised.detach().cpu().numpy().tolist()

	def format(result):
		"""Render the values returned by export(result) as space-separated
		numbers with four decimal places."""
		return " ".join("%.4lf" % entry for entry in export(result))

	def value(result, eps=0):
		"""Condense an accumulated result into a 5-tuple of scores.

		The first element is a weighted sum of exported entries 1-4 plus a
		small term that shrinks as exported entry 0 grows; the remaining
		elements are entries 2, 3, 4 and the negated entry 0. ``eps`` is added
		to every element except the fourth.
		NOTE(review): presumably the tuple is compared lexicographically to
		rank results — confirm at the call sites.
		"""
		stats = export(result)
		weighted = 1.0 * stats[1] + 0.9 * stats[2] + 0.8 * stats[3] + 0.7 * stats[4]
		primary = weighted + 0.01 / max(1, stats[0]) + eps
		return (
			primary,
			stats[2] + eps,
			stats[3] + eps,
			stats[4],
			-stats[0] + eps,
		)

	def evaluate(valid_batch):
		"""Run the model over ``valid_batch`` with training mode switched off
		and return the element-wise sum of the second model output as a
		6-element CPU tensor. Progress is logged every 100 items."""
		model.set_training(False)
		totals = torch.zeros(6)
		with torch.no_grad():
			for offset in range(0, len(valid_batch), batch_size):
				chunk = valid_batch[offset : offset + batch_size]
				totals += model(chunk)[1].cpu()
				if offset > 0 and offset % 100 == 0:
					print(f"eval #{offset}/{len(valid_batch)}")

		return totals

	best_result = evaluate(valid_batch)
	test_result = evaluate(test_batch)

	# One tab-separated summary line per split: tag, relation id, the raw first
	# accumulator entry, then the entries normalised by export().
	valid_fields = [str(r), str(best_result[0].item())] + ["%.4lf" % x for x in export(best_result)]
	print("__V__\t" + "\t".join(valid_fields))
	test_fields = [str(r), str(test_result[0].item())] + ["%.4lf" % x for x in export(test_result)]
	print("__T__\t" + "\t".join(test_fields))
