import gc
import copy
gc.enable()

import os
import sys
from sys import *
#from random import *
from collections import defaultdict
import torch
import torch.nn as nn
from torch import *
from torch.nn import *
from torch.optim import *
from random import shuffle
from random import randint
import time
import datetime
import json
import torch.nn.functional as F
from model import *

# from apex import amp

# Remove the first "--local_rank..." argument (presumably injected by a
# distributed launcher) so that the dataset path ends up at sys.argv[1].
for idx, arg in enumerate(sys.argv):
	if arg.startswith("--local_rank"):
		del sys.argv[idx]
		break

device = 'cuda'

# The dataset directory is the first positional command-line argument; fail
# with a usage line (exit status 1) instead of an opaque IndexError.
if len(sys.argv) < 2:
	sys.exit(f"usage: {sys.argv[0]} DATA_DIR")
DATA_DIR = sys.argv[1]
LOCAL = True

# Batch size is fixed to 1; the GPU count is only reported below, it is not
# used to size batches.
batch_size = 1
assert batch_size == 1

print(f"num_gpus = {torch.cuda.device_count()}", flush=True)
print(f"DATA_DIR = {DATA_DIR}", flush=True)

#print(LOCAL)

def _load_id_dict(path):
	"""Parse a dictionary file whose lines are "<integer id><TAB><name>".

	Leading/trailing whitespace is stripped from each line and blank lines
	are skipped. The file is read as UTF-8 so results do not depend on the
	platform's locale encoding.

	Returns:
		(name2id, id2name): the forward and reverse mappings.

	Raises:
		ValueError: if a non-blank line does not have exactly two
			tab-separated fields or its id is not an integer.
	"""
	name2id = {}
	id2name = {}
	with open(path, encoding='utf-8') as fin:
		for raw in fin:
			line = raw.strip()
			if not line:
				continue
			idx_text, name = line.split('\t')
			idx = int(idx_text)
			name2id[name] = idx
			id2name[idx] = name
	return name2id, id2name


entity2id, id2entity = _load_id_dict(f'{DATA_DIR}/entities.dict')
relation2id, id2relation = _load_id_dict(f'{DATA_DIR}/relations.dict')

# Entity count, and the number of base relations read from relations.dict.
E = len(entity2id)
mov = len(relation2id)  # offset from a base relation id to its inverse id

# Give every base relation id k an inverse relation named "<name>_REV" at id k + mov.
for i in range(mov):
	id2relation[mov + i] = f"{id2relation[i]}_REV"

# Relation vocabulary size: base relations, their inverses, plus one extra
# trailing id (R - 1), which is presumably a self-loop relation (compare the
# commented-out add_data(..., R-1, ...) call further down).
R = 2 * mov + 1

class Graph:
	"""Adjacency table of directed, relation-labelled edges.

	``e[h][r]`` is the set of tail ids reached from head ``h`` through
	relation ``r``. Storage is a dense table holding one set per
	(head, relation) pair, so memory grows with
	``num_entities * num_relations``.
	"""

	def __init__(self, num_entities=None, num_relations=None):
		"""Allocate an empty table.

		Args:
			num_entities: number of head rows; defaults to the module-level
				``E``, looked up when the constructor is called.
			num_relations: number of relation columns per row; defaults to
				the module-level ``R``, looked up when the constructor is called.
		"""
		if num_entities is None:
			num_entities = E
		if num_relations is None:
			num_relations = R
		# A fresh set for every cell; the rows and sets must not be shared.
		self.e = [[set() for _ in range(num_relations)] for _ in range(num_entities)]

	def add(self, h, r, t):
		"""Record the edge h --r--> t (re-adding an existing edge is a no-op)."""
		self.e[h][r].add(t)

G = Graph()
G_test = Graph()  # second graph of the same size; its use is outside this view

# (head, relation) -> set of tail ids. defaultdict(set) behaves like the former
# lambda factory but, unlike a lambda, can be pickled.
train_t = defaultdict(set)
answer_valid_t = defaultdict(set)
# train_h[r]: set of head ids seen with relation id r.
train_h = [set() for _ in range(R)]


def add_data(train_t, train_h, G_list, h, r, t, add=True):
	"""Register the triple (h, r, t) in the given graphs and lookup tables.

	Args:
		train_t: mapping (head, relation) -> set of tails; updated in place.
		train_h: sequence indexed by relation id holding sets of heads;
			updated in place.
		G_list: objects exposing ``add(h, r, t)`` (e.g. ``Graph``) that
			receive the edge when ``add`` is true.
		h, r, t: head id, relation id, tail id.
		add: when false, only the lookup tables are updated and the graphs
			in ``G_list`` are left untouched.

	Side effect:
		Also records ``t`` in the module-level ``answer_valid_t[(h, r)]``,
		regardless of which ``train_t`` is passed in.

	Raises:
		AssertionError: if ``r`` is outside ``[0, R)``, or if ``add`` is true
			and ``r`` is the reserved last relation id ``R - 1``.
	"""
	assert 0 <= r < R
	if add:
		# Relation id R - 1 must never become a graph edge. Check before
		# touching any graph so a rejected triple leaves no partial state.
		assert r != R - 1
		for graph in G_list:
			graph.add(h, r, t)

	train_h[r].add(h)
	train_t[(h, r)].add(t)
	answer_valid_t[(h, r)].add(t)

# for i in range(E):
# 	add_data(train_t, train_h, G, i, R-1, i, False)


def _count_lines(path):
	"""Return the number of lines in the UTF-8 text file at ``path``."""
	with open(path, encoding='utf-8') as fin:
		return sum(1 for _ in fin)


ntrain = _count_lines(f"{DATA_DIR}/train.txt")
nvalid = _count_lines(f"{DATA_DIR}/valid.txt")
ntest = _count_lines(f"{DATA_DIR}/test.txt")

# One '&'-separated row: #entities & #base relations & #train & #valid & #test
# (the format presumably targets a LaTeX table).
print(f"{E}&{mov}&{ntrain}&{nvalid}&{ntest}")