import gc
import copy
gc.enable()

import os
import sys
from sys import *
#from random import *
from collections import defaultdict
import torch
import torch.nn as nn
from torch import *
from torch.nn import *
from torch.optim import *
from random import shuffle
from random import randint
import time
import datetime
import json
import torch.nn.functional as F
from model import *

# relation2id: relation name (as written in relations.dict) -> integer id.
# id2relation: integer id -> shortened display label for that relation.
relation2id = dict()
id2relation = dict()


def _short_relation_label(relation):
	"""Return the display label for *relation*.

	Keeps the last three '/'-separated segments of the name, or only the last
	two when those two alone already span 30 or more characters. Names with
	fewer separators than that are returned unchanged.
	"""
	segments = relation.split('/')
	last_two = '/'.join(segments[-2:])
	if len(last_two) >= 30:
		return last_two
	return '/'.join(segments[-3:])


# Each non-blank line is "<id>\t<relation name>".
with open('dataset/umls_325/relations.dict', encoding='utf-8') as fin:
	for line in fin:
		line = line.strip()
		if not line:
			# Tolerate blank lines (e.g. a trailing empty line) instead of
			# failing on the two-field unpack below.
			continue
		rid, relation = line.split('\t')
		relation2id[relation] = int(rid)
		id2relation[int(rid)] = _short_relation_label(relation)

# Number of base relations read from relations.dict.
R = len(relation2id)

# Reversed relations live at ids [R, 2R): id k + R is the '!'-prefixed
# counterpart of base id k. `inv` maps every id to its counterpart in both
# directions (inv[k] == k + R and inv[k + R] == k).
mov = R
inv = list(range(mov, 2 * mov)) + list(range(mov))

# NOTE(review): assumes the ids read above are exactly 0..R-1; otherwise the
# lookup below raises KeyError.
for i in range(R):
	id2relation[i + mov] = "!" + id2relation[i]

# Load the checkpoint named on the command line.
if len(sys.argv) < 2:
	sys.exit(f"usage: {sys.argv[0]} CHECKPOINT")
# map_location='cpu' lets a checkpoint saved with GPU tensors be opened on a
# machine without CUDA; this script only reads values from it.
# NOTE(review): torch.load unpickles the file -- only pass trusted checkpoints.
a = torch.load(sys.argv[1], map_location='cpu')
r = int(a['r'])  # id of the relation reported below ("Relation:")

def has_revlink(rule):
	"""Return True if two adjacent relations in *rule* are each other's inverse.

	Each neighbouring pair (rule[k-1], rule[k]) is checked against the
	module-level ``inv`` table: the pair matches when rule[k-1] == inv[rule[k]].
	Rules with fewer than two elements never match.
	"""
	for k in range(1, len(rule)):
		if rule[k - 1] == inv[rule[k]]:
			return True
	return False


def contains_r(rule):
	"""Return True if relation id ``a['r']`` occurs anywhere in *rule*.

	Every element of *rule* is converted with int() before comparison.
	"""
	target = int(a['r'])
	ids_in_rule = {int(x) for x in rule}
	return target in ids_in_rule

def prt(p, n=10):
	r"""Print the first *n* rules of ranking *p* as LaTeX table rows.

	p -- sequence of indices into a['rule_list'], in the order to print them.
	n -- maximum number of rows. If *p* is shorter, all of it is printed
	     (previously this raised IndexError).

	Each row has the form ``&$\gets$&$X\relarr{r1}U\relarr{r2}Y$\\``.
	Labels starting with '!' (the reversed relations) are printed with
	``\relarrl`` and the '!' removed. Underscores are escaped for LaTeX.
	Path variables are X, U, V, W; any further hops use Z_{k}.
	"""
	var_names = "XUVW"
	for rule_idx in p[:n]:
		labels = [id2relation[int(x)] for x in a['rule_list'][rule_idx]]

		print("&$\\gets$&$", end='')
		for pos, label in enumerate(labels):
			label = label.replace('_', '\\_')
			# Rules longer than the four named variables used to crash here.
			var = var_names[pos] if pos < len(var_names) else f"Z_{{{pos}}}"
			if label.startswith('!'):
				print(f"{var}\\relarrl{{{label[1:]}}}", end="")
			else:
				print(f"{var}\\relarr{{{label}}}", end="")

		print("Y$\\\\")



# Raw rule weights from the checkpoint, indexed in parallel with a['rule_list'].
weight = a['predictor']['rule_weight_raw']


print("Relation:", id2relation[int(a['r'])])

# Ranking 1: every rule, largest raw weight first.
print("general:")
p = sorted(
	range(len(a['rule_list'])),
	key=lambda idx: weight[idx],
	reverse=True,
)
prt(p)


def _no_self_key(idx):
	"""Sort key (used with reverse=True) for the "no self" ranking.

	Rules that do not contain a['r'] come first. Within each group, rules with
	no adjacent inverse pair (see has_revlink) come first. Remaining ties are
	broken by descending raw weight.
	"""
	rule = a['rule_list'][idx]
	return (not contains_r(rule), not has_revlink(rule), weight[idx])


# Ranking 2: rules free of a['r'] and of adjacent inverse pairs first.
print("no self:")
p = sorted(range(len(a['rule_list'])), key=_no_self_key, reverse=True)
prt(p, n=40)

# print("revlink:")
# p = sorted(range(len(a['rule_list'])), 
# 	key=lambda i : (has_revlink(a['rule_list'][i]), weight[i]),
# 	reverse=True)
# prt(p,n=40)