import gc
import copy
gc.enable()

import os
import sys
from sys import *
#from random import *
from collections import defaultdict
import torch
import torch.nn as nn
from torch import *
from torch.nn import *
from torch.optim import *
from random import shuffle
from random import randint
import time
import datetime
import json
import torch.nn.functional as F
from model import *

# from apex import amp

# Remove the first command-line entry that starts with "--local_rank"
# (presumably injected by a distributed launcher -- TODO confirm) so that the
# positional arguments read further down keep their expected indices.
for idx, arg in enumerate(sys.argv):
	if arg.startswith("--local_rank"):
		del sys.argv[idx]
		break

device = 'cuda'

# Positional command-line arguments: dataset directory, then output directory.
DATA_DIR = sys.argv[1]
OUTPUT_DIR = sys.argv[2]

# Default hyper-parameters. They are converted to strings right below so that
# defaults and command-line overrides share one representation; `args()`
# converts them back when a value is requested.
_args = {
	'RotatE': 'RotatE',
	'em_epoch': 2,
	'start': 0,
	'hop': 1,
	'pgnd_num': 256,
	'pgnd_polish_selflink_only': True,
	'pgnd_use_gamma': True,
	'pgnd_panelty': 0.1,
	'max_rules': 1000,
	'max_rule_len': 4,
	'max_train_h': 5000,
	'max_beam_rules': 2000,
	'max_rules_per_h': 300,
	'use_ranking_loss': False,
	'use_negative_rules': False,
	'model_rate': 1,
	'predictor_lr': 5e-5,
	'generator_lr': 1e-3,
	'predictor_epoch': 100000,
	'generator_epoch': 10000,
	'generator_embed_dim': 512,
	'generator_hidden_dim': 256,
	'cuda_mem': 32,
}
_args = {name: str(default) for name, default in _args.items()}

# Every argument after the two directories has the form key=value; the text
# after the first '=' (which may itself contain '=') replaces the stored value.
for arg in sys.argv[3:]:
	name, _, raw = arg.partition('=')
	_args[name] = raw

def args(k, apply=eval):
	"""Look up hyper-parameter `k` and convert its stored string with `apply`.

	Parameters:
		k: key into the module-level `_args` table.
		apply: callable applied to the stored string; defaults to `eval`.

	Returns:
		The converted value. The lookup is also reported through `print`.

	Raises:
		KeyError: if `k` is not present in `_args`. KeyError derives from
			Exception, so callers catching Exception keep working.
	"""
	if k not in _args:
		print(f"Error: No argument '{k}'.")
		raise KeyError(f"No argument '{k}'.")

	# SECURITY NOTE(review): with the default `apply=eval` the stored string is
	# executed as Python code, and stored strings can come straight from the
	# command line. Only run this script with trusted arguments.
	v = apply(_args[k])
	print(f"Argument Used: {k} = {_args[k]} ({v})")
	return v

# batch_size = torch.cuda.device_count()
# The evaluation and training loops below slice batches by `batch_size`;
# only a batch size of one is supported.
batch_size = 1
assert batch_size == 1





print(f"num_gpus = {torch.cuda.device_count()}", flush=True)
print(f"DATA_DIR = {DATA_DIR}", flush=True)
print(f"OUTPUT_DIR = {OUTPUT_DIR}", flush=True)

#print(LOCAL)

# Appended to for the whole run; never closed explicitly, so it lives until
# the interpreter exits.
log_file = open(f"{OUTPUT_DIR}/train_log.txt", 'a')

# Keep a handle on the builtin before `print` is rebound below.
old_print = print

def write(printstr, end='\n'):
	"""Echo `printstr` to stdout and append it to the training log.

	Parameters:
		printstr: message to emit; converted with `str()` so that non-string
			values (which builtin print accepts) do not fail on concatenation.
		end: terminator written after the message in both outputs.
	"""
	text = str(printstr)
	old_print(text, end=end)
	log_file.write(text + end)
	log_file.flush()

# From here on every module-level `print(...)` call also goes to the log file.
print = write

# Lookup tables between names and integer ids, in both directions.
entity2id = dict()
relation2id = dict()
id2entity = dict()
id2relation = dict()

# Each line of entities.dict is "<id>\t<entity name>".
with open(f'{DATA_DIR}/entities.dict') as fin:
	for row in fin:
		raw_id, name = row.strip().split('\t')
		entity2id[name] = int(raw_id)
		id2entity[int(raw_id)] = name

# Each line of relations.dict is "<id>\t<relation name>".
with open(f'{DATA_DIR}/relations.dict') as fin:
	for row in fin:
		raw_id, name = row.strip().split('\t')
		relation2id[name] = int(raw_id)
		id2relation[int(raw_id)] = name

E = len(entity2id)

# Relation ids [mov, 2*mov) are the reversed copies of ids [0, mov).
mov = len(relation2id)
for rel in range(mov):
	id2relation[rel + mov] = id2relation[rel] + "_REV"

# Forward relations + reversed relations + one extra trailing id (R - 1).
R = 2 * mov + 1

class Graph:
	"""Adjacency store where e[h][r] is the set of tails t of edges (h, r, t)."""

	def __init__(self):
		# One tail set per (entity, relation) pair; the dimensions come from
		# the module-level E and R.
		self.e = [[set() for _rel in range(R)] for _ent in range(E)]

	def add(self, h, r, t):
		"""Record the edge (h, r, t); adding the same edge twice is a no-op."""
		tails = self.e[h][r]
		tails.add(t)

# G receives training edges only; G_test additionally receives validation
# edges (see the calls to add_data where the data files are read).
G = Graph()
G_test = Graph()

# (h, r) -> set of tails, and per-relation sets of heads.
train_t = defaultdict(set)
answer_valid_t = defaultdict(set)
train_h = [set() for _ in range(R)]


def add_data(train_t, train_h, G_list, h, r, t, add=True):
	"""Register the triple (h, r, t) in the given containers.

	Parameters:
		train_t: mapping (h, r) -> set of tails to update.
		train_h: per-relation list of head sets to update.
		G_list: graphs that receive the edge when `add` is true.
		h, r, t: head entity id, relation id, tail entity id.
		add: when false, no graph is modified.

	The module-level `answer_valid_t` is always updated as well.
	"""
	assert 0 <= r < R
	if add:
		for graph in G_list:
			graph.add(h, r, t)
			# The last relation id must never be stored as a graph edge.
			assert r != R - 1

	train_h[r].add(h)
	train_t[(h, r)].add(t)
	answer_valid_t[(h, r)].add(t)

# for i in range(E):
# 	add_data(train_t, train_h, G, i, R-1, i, False)


def _read_triples(path):
	"""Yield (h, r, t) id triples from a tab-separated "head, relation, tail" file."""
	with open(path) as fin:
		for row in fin:
			head, rel, tail = row.strip().split('\t')
			yield entity2id[head], relation2id[rel], entity2id[tail]


# Training triples go into both graphs, once forward and once reversed
# (the reversed relation id is r + mov).
for h, r, t in _read_triples(f"{DATA_DIR}/train.txt"):
	add_data(train_t, train_h, [G, G_test], h, r, t)
	add_data(train_t, train_h, [G, G_test], t, r + mov, h)

# Freeze each per-relation head set into a tuple, in place.
train_h[:] = [tuple(heads) for heads in train_h]

# Validation triples are added to G_test only.
valid_h = [set() for _ in range(R)]
valid_t = defaultdict(set)
for h, r, t in _read_triples(f"{DATA_DIR}/valid.txt"):
	add_data(valid_t, valid_h, [G_test], h, r, t)
	add_data(valid_t, valid_h, [G_test], t, r + mov, h)

# Test answers start from everything recorded so far (train + valid) and are
# extended with the test triples; test triples are never added to a graph.
answer_test_t = copy.deepcopy(answer_valid_t)
test_h = [set() for _ in range(R)]
test_t = defaultdict(set)
for h, r, t in _read_triples(f"{DATA_DIR}/test.txt"):
	for head, rel, tail in ((h, r, t), (t, r + mov, h)):
		answer_test_t[(head, rel)].add(tail)
		test_t[(head, rel)].add(tail)
		test_h[rel].add(head)

# print(max(map(len, train_h)))
# for i in range(R):
# 	old_print(i, len(train_h[i]))

# The 'RotatE' setting is read as a plain string; the literal text "None"
# is turned into the None object before it is handed to the predictor.
RotatE = args('RotatE', str)
RotatE = None if RotatE == "None" else RotatE

# Template predictor; a deep copy of it is trained in every EM epoch below.
model_init = Predictor(E, R, DATA_DIR, RotatE, print=print)
model_init.set_cuda_mem(args('cuda_mem'))

# global trie_edges, con_rule, con_entity, con_matrix

# Relations handled by this process: start, start + hop, ... below R - 1.
start = args('start')
hop = args('hop')

print(f"Work range: [{start}, {R-1}, {hop}]")

# Main loop: one independent training run per relation id r. The last id
# (R - 1) is excluded from the range.
for r in range(start, R - 1, hop):
	# Rebinds the module-level `print` so that every message is prefixed with a
	# timestamp and the current relation; output still goes through `write`.
	def print(*s):
		s = ' '.join(map(str, s))
		timestr = datetime.datetime.now().strftime("%H:%M:%S.%f")
		return write(f"{timestr} r = {r} | {s}")

	# Settings are re-read (and therefore re-logged by `args`) for every relation.
	num_EM_epochs = args('em_epoch')

	min_groundings = args('pgnd_num')
	only_self_link = args('pgnd_polish_selflink_only')
	allow_nonpos_add = args('pgnd_use_gamma')
	non_grounding_panelty = args('pgnd_panelty')

	MAX_TRAIN_H = args('max_train_h')
	MAX_RULES = args('max_beam_rules')
	MAX_AG_RULES = MAX_RULES
	MAX_SAMPLES = args('max_rules_per_h')
	MAX_RULE_LEN = args('max_rule_len')
	model_init.MAX_RULE_NUM = args('max_rules')
	model_init.MAX_RULE_LEN_HARD = MAX_RULE_LEN
	model_init.use_ranking_loss = args('use_ranking_loss')
	model_init.allow_neg_rules = args('use_negative_rules')
	# model.MAX_RULE_LEN_HARD = MAX_RULE_LEN

	model_rate = args('model_rate')
	predictor_lr = args('predictor_lr')
	generator_lr = args('generator_lr')
	predictor_epoch = args('predictor_epoch')
	generator_epoch = args('generator_epoch')

	print(f"Work r = {r}")

	# No training heads for this relation: emit placeholder result lines in the
	# same __V__/__T__ format used by train_model and move on.
	if len(train_h[r]) == 0:
		for EM_epoch in range(num_EM_epochs + 1):
			EM_str = EM_epoch if EM_epoch < num_EM_epochs else '#'
			print(f"Skipped {EM_str} __V__	{r}	-0.0	-0.0000	-0.0000	-0.0000	-0.0000	-0.0000")
			print(f"Skipped {EM_str} __T__	{r}	-0.0	-0.0000	-0.0000	-0.0000	-0.0000	-0.0000")
		continue



	# One generator per relation; it is kept across the EM epochs of this relation.
	generator = Generator(R - 1, args('generator_embed_dim'), args('generator_hidden_dim')).cuda()
	generator.train()

	# rule (tuple of relation ids) -> coefficient carried over between EM epochs.
	rule_coef = dict()

	# Shared by reference with every per-epoch model copy (assigned below).
	additional_grounding_buffer = dict()

	# num_EM_epochs regular epochs plus one final epoch (logged as '#').
	for EM_epoch in range(num_EM_epochs + 1):

		# Same as the outer `print`, with the EM epoch added to the prefix.
		def print(*s):
			s = ' '.join(map(str, s))
			timestr = datetime.datetime.now().strftime("%H:%M:%S.%f")
			EM_str = EM_epoch if EM_epoch < num_EM_epochs else '#'
			return write(f"[{timestr}] r = {r} EM = {EM_str} | {s}")

		# The final epoch starts from freshly initialised rule coefficients.
		if EM_epoch == num_EM_epochs:
			rule_coef = dict()


		# Fresh predictor for this epoch, cloned from the untouched template.
		model = copy.deepcopy(model_init)
		model.additional_grounding_buffer = additional_grounding_buffer

		model.print = print
		generator.print = print

		# generate rules
		if EM_epoch == 0:
			# First epoch: the predictor sets up its own rule list for r.
			model.relation_init_begin(r)
		else:

			# Later epochs: take rule bodies from the generator's beam search.
			# `sampled` de-duplicates; (r,) and the empty rule are pre-marked so
			# the beam search cannot add them again.
			sampled = set()
			sampled.add(tuple([r]))
			sampled.add(tuple([]))

			# Index 0 is always the rule (r,) with prior 0.0.
			paths = [(r,)]
			prior = [0.0,]
			for _p in generator.beam_search(r, MAX_RULES, MAX_RULE_LEN + 1):
				# Drop the first element of the generated sequence to get the rule body.
				p = tuple(_p)[1:]
				if p in sampled:
					continue
				sampled.add(p)
				paths.append(p)
				v = generator.log_probability(_p)
				prior.append(v)
				if len(sampled) % 100 == 0:
					print(f"sampled # = {len(sampled)} p = {p} v = {v} score = {'***' if p not in rule_coef else rule_coef[p]}")

			print(f"Done |sampled| = {len(sampled)}")

			model.relation_init_load(r, paths)
			model.set_prior_value(torch.tensor(prior).cuda())

		# Rules eligible for additional groundings in make_batch.
		ag_rules = set(model.rule_list[:MAX_AG_RULES])

		def make_edges(rule_cho=None):
			"""Build a prefix trie over the selected rules of `model.rule_list`.

			Parameters:
				rule_cho: ids into `model.rule_list` (sequence or torch.Tensor);
					None selects every rule.

			Returns:
				A defaultdict mapping a path prefix (tuple of relation ids) to
				`[finished, nexts]`: `finished` lists the `(position, rule_id)`
				pairs of rules equal to that prefix, `nexts` is the set of
				relation ids that extend it. The rule at position 0 is left out.
			"""
			if rule_cho is None:
				rule_cho = range(len(model.rule_list))
			elif isinstance(rule_cho, torch.Tensor):
				rule_cho = rule_cho.detach().cpu().numpy().tolist()

			trie = defaultdict(lambda : [[], set()])

			for position, rule_id in enumerate(rule_cho):
				path = model.rule_list[rule_id]
				if position == 0:
					# this path can have no groundings
					assert path == (r,)
					continue
				# Register every proper prefix together with the relation that follows it...
				for depth, relation in enumerate(path):
					trie[tuple(path[:depth])][1].add(relation)
				# ...and mark the complete path as the end of this rule.
				trie[tuple(path)][0].append((position, rule_id))
			return trie

		def make_batch(h, r, t_list, trie_edges, answer=None, G=G, count=False, rule_cho=None, add_cho=False):
			# Build one predictor input for the query (h, r).
			#
			# Parameters:
			#   h, r       - head entity id and relation id of the query.
			#   t_list     - target tails; copied into a list.
			#   trie_edges - rule trie as returned by make_edges.
			#   answer     - entities flagged in the returned mask; defaults to t_list.
			#   G          - graph whose edges are followed while grounding rules.
			#   count      - when true, also return a per-grounding path count.
			#   rule_cho   - optional subset of rule ids; passing one forces add_cho.
			#   add_cho    - when true, the rule selection is included in the result.
			#
			# Returns the tuple
			#   (h, r, t_list, mask, [rule selection], con_rule, con_entity,
			#    con_weight, [con_count])
			# where the bracketed items are present only with add_cho / count.
			t_list = list(t_list)
			if answer is None:
				answer = t_list

			# _rule_cho is what gets returned; rule_cho is the plain iterable used here.
			_rule_cho = rule_cho
			if rule_cho is None:
				rule_cho = range(len(model.rule_list))
				if add_cho:
					_rule_cho = torch.arange(len(model.rule_list))
			else:
				add_cho = True
				if isinstance(rule_cho, torch.Tensor):
					rule_cho = rule_cho.detach().cpu().numpy().tolist()

			# Parallel lists: entry i says that rule position con_rule[i] reaches
			# entity con_entity[i] from h.
			con_rule = []
			con_entity = []
			con_weight = []
			con_count = []
			# groundings[position] = entities reached by the rule at that position.
			groundings = [[] for i in rule_cho]

			def dfs(path, pos):
				# `path` is the relation prefix walked so far; `pos` maps each
				# entity reachable via that prefix to its accumulated count.
				edge = trie_edges[path]
				if len(edge[0]) > 0:
					# One or more rules end exactly at this prefix.
					for p, c in pos.items():
						for (rule, rule_id) in edge[0]:
							con_entity.append(p)
							con_rule.append(rule)
							con_weight.append(1.0)
							groundings[rule].append(p)
							if count:
								con_count.append(c)
				# This `r` is local to dfs and does not touch make_batch's `r`.
				for r in edge[1]:
					newpos = defaultdict(lambda : 0)
					for p, c in pos.items():
						for q in G.e[p][r]:
							newpos[q] += c
					# Stop descending as soon as no entity is reachable.
					if len(newpos) > 0:
						dfs(path + (r,), newpos)


			dfs(tuple(), {h : 1})

			# Wrap in one-element lists so further tensors can be appended and
			# everything concatenated at the end.
			con_rule = [torch.LongTensor(con_rule)]
			con_entity = [torch.LongTensor(con_entity)]
			con_weight = [torch.FloatTensor(con_weight)]
			if count:
				con_count = [torch.FloatTensor(con_count)]

			# Top up eligible rules with extra entities supplied by the model.
			for rule, rule_id in enumerate(rule_cho):
				if model.rule_list[rule_id] not in ag_rules:
					continue

				# Target is min_groundings (times 8 for rule id 0); with
				# only_self_link the target is 0 for every other rule. `add` is
				# the shortfall against the groundings already found.
				add = min_groundings * (8 if rule_id == 0 else 1) * (0 if (only_self_link and rule_id != 0) else 1) - len(groundings[rule])

				if add <= 0 and not allow_nonpos_add:
					continue

				# NOTE(review): semantics of a non-positive `add` depend on
				# model.additional_groundings, which is not visible here.
				ag = model.additional_groundings(h, rule_id, add, groundings[rule])

				con_rule.append(torch.zeros_like(ag).long() + rule)
				con_entity.append(ag)
				# Extra entities get the reduced weight instead of 1.0.
				con_weight.append(torch.ones_like(ag) * non_grounding_panelty)
				if count:
					# extra points can not have con
					con_count.append(torch.ones_like(ag) * non_grounding_panelty)

			con_rule = torch.cat(con_rule, dim=0)
			con_entity = torch.cat(con_entity, dim=0)
			con_weight = torch.cat(con_weight, dim=0)
			if count:
				con_count = torch.cat(con_count, dim=0)

			if con_rule.size()[-1] == 0:
				# no contributions
				# NOTE(review): these placeholders are integer tensors on CUDA,
				# while the tensors built above come from CPU constructors and
				# con_weight is otherwise float -- confirm the model accepts both.
				con_rule = torch.tensor([0]).cuda()
				con_entity = torch.tensor([h]).cuda()
				con_weight = torch.tensor([0]).cuda()
				if count:
					con_count = torch.tensor([0]).cuda()

			# Boolean mask over all E entities, true at the answer entities.
			mask = torch.zeros(E).bool()
			for t in answer:
				mask[t] = 1

			ret = (h, r, t_list, mask.detach(), con_rule.detach(), con_entity.detach(), con_weight.detach())
			if count:
				ret = ret + (con_count.detach(),)
			if add_cho:
				# Insert the rule selection right after the mask.
				ret = ret[:4] + (_rule_cho, ) + ret[4:]
			return ret

		# ---- Rule initialisation on a random sample of training heads ----
		trie_edges = make_edges()
		train_h_sampled = list(train_h[r])
		shuffle(train_h_sampled)
		# Cap the number of heads used for this relation.
		train_h_sampled = train_h_sampled[:MAX_TRAIN_H]

		for (i, h) in enumerate(train_h_sampled):
			if i % model_rate == 0:
				print(f"init: {i}/{len(train_h_sampled)}")
			t_list = train_t[(h, r)]
			# Each head is weighted by its number of known tails.
			model.rule_init_step(make_batch(h, r, t_list, trie_edges), weight=len(t_list))

		gc.collect()
		rule_init_value = model.rule_init_end()
		rule_coef_list = []

		# Reuse a coefficient from the previous EM epoch when the rule is already
		# known; otherwise fall back to the freshly computed initial value.
		for (i, rule) in enumerate(model.rule_list):
			rule = tuple(rule)
			if rule not in rule_coef:
				rule_coef[rule] = rule_init_value[i].item()
			rule_coef_list.append(rule_coef[rule])

		model.set_rule_coef(torch.tensor(rule_coef_list))

		# rule_coef is rebuilt from the trained rule weights after train_model.
		del rule_init_value
		del rule_coef
		del rule_coef_list

		model = model.cuda()
		for name, param in model.named_parameters():
			print(f"Model Parameter: {name} ({param.type()}:{param.size()})")

		train_batch = []
		valid_batch = []
		test_batch = []
		gen_batch = []


		# NOTE(review): the trie is rebuilt here; whether model.rule_list changed
		# inside rule_init_end / set_rule_coef is not visible from this file.
		trie_edges = make_edges()

		# ---- Training batches: one entry per (h, t) pair ----
		for (i, h) in enumerate(train_h_sampled):
			if i % (10 * model_rate) == 0:
				print(f"train_batch: {i}/{len(train_h_sampled)}")
			t_list = train_t[(h, r)]
			batch = make_batch(h, r, t_list, trie_edges, add_cho=True)
			# cho = model.choose_rules(batch, MAX_SAMPLES).cpu()
			# private_edges = make_edges(cho)
			# batch = make_batch(h, r, t_list, private_edges, rule_cho=cho)
			# for gb in zip(*make_gen_batch(cho)):
			# 	gen_batch.append(gb)
			for t in t_list:
				# Same grounding data, but with a single target tail per entry.
				train_batch.append(batch[:2] + ([t], ) + batch[3:])

		gc.collect()

		# ---- Validation batches: grounded on G (the default graph argument) ----
		for (i, h) in enumerate(valid_h[r]):
			if i % (10 * model_rate) == 0:
				print(f"valid_batch: {i}/{len(valid_h[r])}")
			batch = make_batch(h, r, valid_t[(h, r)], trie_edges, answer=answer_valid_t[(h, r)], add_cho=True)
			valid_batch.append(batch)

		# ---- Test batches: grounded on G_test ----
		for (i, h) in enumerate(test_h[r]):
			if i % (10 * model_rate) == 0:
				print(f"test_batch: {i}/{len(test_h[r])}")
			batch = make_batch(h, r, test_t[(h, r)], trie_edges, answer=answer_test_t[(h, r)], G=G_test, add_cho=True)
			test_batch.append(batch)

		model.relation_init_end()
		model.clean()
		gc.collect()

		def train_model(model, train_batch, valid_batch, num_epochs=100000, lr=5e-5, model_rate=1):
			"""Train the predictor and keep the best state seen on validation.

			Parameters:
				model: predictor to optimise (modified in place).
				train_batch: list of training entries; shuffled in place.
				valid_batch: list of validation entries; shuffled in place.
				num_epochs: number of optimisation steps before scaling by model_rate.
				lr: initial Adam learning rate; annealed down to lr / 5.
				model_rate: multiplier applied to step counts and logging periods.

			Returns:
				(best_model, best_result, test_result): the best state dict, its
				validation result tensor, and the result on the enclosing
				scope's `test_batch`.

			Side effects: rebinds the module-level names `best_result` and
			`best_model`, and logs progress through `print`.
			"""
			shuffle(train_batch)
			shuffle(valid_batch)

			num_epochs *= model_rate
			print_epoch = 50 * model_rate
			valid_print_rate = 5
			# Validate every multiple of 100 steps that scales with the validation set size.
			valid_epoch = max(100, int(len(valid_batch) / 400 + 0.5) * 100) * model_rate
			save_rate = 10 * model_rate

			decay = 0
			cum_loss = torch.tensor(0.0)
			# The nested helpers below assign these names, so they are module globals.
			global best_result
			global best_model
			best_result = torch.zeros(6) - 1e10
			best_model = model.state_dict()

			opt = Adam(model.parameters(), lr=lr, weight_decay=decay)
			sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=num_epochs, eta_min=lr/5)

			def export(result):
				# Entry 0 acts as the normaliser (clamped to at least 1); the
				# remaining entries are returned as plain floats.
				result = result / result[0].clamp(min=1)
				return result[1:].detach().cpu().numpy().tolist()

			def format(result):

				return " ".join(map(lambda x : "%.4lf" % x, export(result)))

			def value(result, eps=0):
				# Comparison key for model selection; tuples compare element by
				# element. NOTE(review): the meaning of the individual metrics
				# is defined by the model and is not visible here.
				result = export(result)
				return (1.0 * result[1] + 0.9 * result[2] + 0.8 * result[3] + 0.7 * result[4] + 0.01 / max(1, result[0]) + eps, result[2] + eps, result[3] + eps, result[4], -result[0] + eps)

			def train_step(batch):
				model.set_training(True)
				loss, _ = model(batch)
				opt.zero_grad()
				loss.backward()
				opt.step()
				sch.step()
				return loss

			def evaluate(valid_batch):
				model.set_training(False)
				with torch.no_grad():
					# Zero tensor with the same shape as best_result.
					valid_result = best_result * 0
					for i in range(0, len(valid_batch), batch_size):
						valid_result += model(valid_batch[i : i + batch_size])[1].cpu()
						if i % (print_epoch * valid_print_rate) == 0 and i > 0:
							print(f"eval #{i}/{len(valid_batch)}")

				return valid_result

			def valid():
				global best_result
				global best_model
				valid_result = evaluate(valid_batch)
				updated = False
				same = False
				# Accept results that are better or (within eps) equal; only a
				# strictly better one counts as "updated".
				if value(valid_result, eps=1e-6) > value(best_result):
					if value(valid_result) > value(best_result):
						updated = True
					else:
						same = True
					best_result = valid_result.clone().detach()
					best_model = copy.deepcopy(model.state_dict())
				print(f"valid = {format(valid_result)} {'updated' if updated else ''} {'same' if same else ''}")
				return updated, valid_result


			# Baseline evaluation before any optimisation step.
			valid()
			last_update = 0

			torch.cuda.empty_cache()
			gc.collect()

			for epoch in range(1, num_epochs + 1):
				# Reshuffle once per pass over the training data.
				if epoch % max(1, len(train_batch) // batch_size) == 0:
					shuffle(train_batch)
				batch = [train_batch[(epoch * batch_size + i) % len(train_batch)] for i in range(batch_size)]
				loss = train_step(batch)
				# Detach before accumulating: adding the attached loss would give
				# cum_loss a grad_fn chain that grows with every step.
				cum_loss += loss.detach().cpu()

				if epoch % print_epoch == 0:
					lr_str = "%.2e" % (opt.param_groups[0]['lr'])
					print(f"#{epoch} lr = {lr_str} loss = {cum_loss.item()/print_epoch}")
					cum_loss *= 0

				if epoch % valid_epoch == 0:
					if valid()[0]:
						last_update = epoch
					elif epoch >= max(last_update + 10000, 10000) * model_rate:
						print(f"Early break: Never updated since {last_update}")
						break
					# Stop when the first exported metric is exactly 1.
					if 1-1e-6 < export(best_result)[0] < 1+1e-6:
						print(f"Early break: Perfect")
						break

			with torch.no_grad():
				# Extra candidate: best weights with the relation embedding reset
				# to the template's and rule 0 boosted; valid() keeps it only if
				# it scores at least as well as the current best.
				model.load_state_dict(best_model)
				model.relation_embed *= 0
				model.relation_embed += model_init.relation_embed.cuda()
				model.rule_weight_raw[0] += 1000.0
				valid()

			model.load_state_dict(best_model)
			# test_batch and r come from the enclosing (module-level) scope.
			test_result = evaluate(test_batch)
			print("__V__\t" + ("\t".join([str(r), str(best_result[0].item())] + list(map(lambda x : "%.4lf"%x, export(best_result))))))
			print("__T__\t" + ("\t".join([str(r), str(test_result[0].item())] + list(map(lambda x : "%.4lf"%x, export(test_result))))))

			return best_model, best_result, test_result

		def train_generator(generator, gen_data, num_epochs=10000, lr=1e-3):
			"""Fit the rule generator on `gen_data` with Adam and cosine annealing.

			Parameters:
				generator: object providing `parameters()` and `loss(*gen_data)`.
				gen_data: tuple unpacked into `generator.loss` on every step.
				num_epochs: number of optimisation steps.
				lr: initial learning rate; annealed down to lr / 10.
			"""
			optimizer = Adam(generator.parameters(), lr=lr)
			scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=lr/10)

			report_every = 100
			running_loss = 0
			for step in range(1, num_epochs + 1):
				# The same full data set is used on every step.
				step_loss = generator.loss(*gen_data)

				optimizer.zero_grad()
				step_loss.backward()
				optimizer.step()
				scheduler.step()

				running_loss += step_loss.item()
				# `model` is the predictor from the enclosing scope.
				model.clean()

				if step % report_every != 0:
					continue
				lr_str = "%.2e" % (optimizer.param_groups[0]['lr'])
				print(f"train_generator #{step} lr = {lr_str} loss = {running_loss / report_every}")
				running_loss = 0


		# ---- Train the predictor, then read back its rule weights ----
		best_model, best_result, test_result = train_model(model, train_batch, valid_batch, num_epochs=predictor_epoch, model_rate=model_rate, lr=predictor_lr) # , num_epochs=1000)
		model.load_state_dict(best_model)
		rule_weight = model.rule_weight()

		# Carried into the next EM epoch as the starting coefficient of each rule.
		rule_coef = dict()
		for (i, rule) in enumerate(model.rule_list):
			rule = tuple(rule)
			rule_coef[rule] = rule_weight[i].item()

		del train_batch
		del valid_batch
		del test_batch

		# The generator is not retrained after the final epoch.
		if EM_epoch != num_EM_epochs:

			# One integer counter per rule, same shape as model.rule_weight_raw.
			weight = torch.zeros_like(model.rule_weight_raw).long().cpu()

			def make_gen_data(weight):
				# Turn per-rule counters into generator training data: keep rules
				# with a positive count and split each stored sequence into input
				# (all but the last token) and target (all but the first token).
				# assert cho[0].item() == 0
				nonzero = (weight > 0)
				print(f"make_gen_data |nonzero| = {nonzero.sum().item()}")
				rules = model.rules_gen[nonzero]
				weight = weight[nonzero]
				inp = rules[:, :-1]
				tar = rules[:, 1:]
				# Positions holding the padding token are masked out.
				mask = (tar != model.gen_pad)
				return inp, tar, mask, weight

			for (i, h) in enumerate(train_h_sampled):
				if i % 10 == 0:
					print(f"gen_data: {i}/{len(train_h_sampled)}")
				t_list = train_t[(h, r)]
				batch = make_batch(h, r, t_list, trie_edges)
				# Rules chosen by the predictor for this head, at most
				# MAX_SAMPLES presumably -- verify against Predictor.choose_rules.
				cho = model.choose_rules(batch, MAX_SAMPLES).cpu()
				# Each chosen rule is credited with the head's number of tails.
				weight[cho] += len(t_list)

			gen_data = make_gen_data(weight)

			train_generator(generator, gen_data, num_epochs=generator_epoch, lr=generator_lr) #, num_epochs=1000)
			del gen_data

		# Checkpoint for this relation; the file is overwritten after every EM
		# epoch, so the last epoch's state is what remains on disk.
		ckpt = {
			'r': r,
			'valid': best_result,
			'test': test_result,
			'rule_list': model.rule_list,
			# 'ag_rules': list(ag_rules),
			# Store a copy of the hyper-parameter table itself. Storing the
			# `args` lookup function would be pickled as a reference to
			# `__main__.args` and would carry none of the settings.
			'args': dict(_args),
			'predictor': best_model,
			'generator': generator.state_dict(),
		}
		torch.save(ckpt, f"{OUTPUT_DIR}/model_{r}.ckpt")

		# Release the per-epoch predictor before the next epoch builds a new copy.
		del model
		gc.collect()














 