import numpy as np
import json
import pickle
import torch
import random
import time
from tqdm import tqdm
import os


def load_rules(file, threshold = 0.01, k_rules = 100,
               guidance_path = 'guidance.pkl', conf_path = 'confs.pkl'):
    """Load mined rules from JSON, filter them per head relation and pickle the result.

    A rule is kept when ``not rule['acyclic']``, ``rule['conf'] > threshold`` and
    ``rule['body_supp'] > 2``. Only the first ``k_rules`` kept rules (in file
    order) are retained per relation.

    Args:
        file: path to a JSON object mapping a relation id (string key) to a list
            of rule dicts with keys 'acyclic', 'conf', 'body_supp', 'body_rels'
            and 'var_constraints'.
        threshold: rules must have confidence strictly above this value.
        k_rules: maximum number of rules retained per relation.
        guidance_path: where the guidance dict is pickled.
        conf_path: where the per-relation confidence arrays are pickled.

    Returns:
        ``(guidance, avg_conf)``. ``guidance`` maps the int relation id to
        ``(bodies, confs, constraints)``, which are three index-aligned lists of
        equal length. ``avg_conf`` is the mean confidence over every rule that
        passed the filter (before the ``k_rules`` cut), or 0.0 if none passed.
    """
    with open(file) as fin:
        rules = json.load(fin)
    guidance = {}
    conf = {}
    # Confidences of all rules passing the filter, before the k_rules cut.
    c_avg = []
    for i, content in rules.items():
        kept = [
            j for j in content
            if not j['acyclic'] and j['conf'] > threshold and j['body_supp'] > 2
        ]
        c_avg.extend(j['conf'] for j in kept)
        # Cut all three lists together so bodies, confidences and constraints
        # stay index-aligned (constraints used to be left untruncated).
        kept = kept[:k_rules]
        rels = [j['body_rels'] for j in kept]
        confs = [j['conf'] for j in kept]
        constraints = [j['var_constraints'] for j in kept]
        guidance[int(i)] = (rels, confs, constraints)
        print(confs)
        conf[int(i)] = np.array(confs)
    with open(guidance_path, 'wb') as fh:
        pickle.dump(guidance, fh)
    with open(conf_path, 'wb') as fc:
        pickle.dump(conf, fc)
    # An empty rule set previously raised ZeroDivisionError.
    avg_conf = sum(c_avg) / len(c_avg) if c_avg else 0.0
    return guidance, avg_conf

def rebuild_data(rules, dataset, name):
    """Build flattened 1p and rule-derived queries with answer sets from a triple file.

    Reads ``entity2id.json`` and ``relation2id.json`` from the current
    directory. Each tab-separated line of ``dataset`` (head, relation, tail,
    optional extra columns) is mapped to integer ids. For every unique triple
    it emits:
      * the triple itself as a one-hop query ``((head, rel), ('e', 'r'))``;
      * one query ``((head, body), shape)`` per rule body of 1-3 relations found
        in ``rule[rel]``. Bodies of other lengths are skipped.
    Each emitted query is paired with its source triple. The tail is added to
    the answer set keyed by the query's ``(head, rel_or_body)`` anchor.

    Output: ``name + 'queries_.pkl'`` holds a de-duplicated list of
    ``(query, (head, rel, tail))`` pairs in first-seen order, and
    ``name + 'answers_.pkl'`` holds the anchor -> set-of-tails dict.

    NOTE(review): this iterates ``rule[rel]`` directly as a sequence of bodies.
    The guidance pickle written by ``load_rules`` stores
    ``(bodies, confs, constraints)`` tuples instead, so the two formats look
    incompatible. Confirm which rules pickle this is meant to read.
    """
    with open('entity2id.json', 'r') as f:
        entity2id = json.load(f)
    with open('relation2id.json', 'r') as f:
        rel2id = json.load(f)
    # NOTE: pickle.load can execute arbitrary code; only use trusted rule files.
    with open(rules, 'rb') as fh:
        rule = pickle.load(fh)

    triplets = []
    with open(dataset, 'r') as fr:
        for line in fr:
            if not line.strip():
                continue  # tolerate blank (e.g. trailing) lines
            # Drop the terminator so a 3-column file's tail is not looked up as 'x\n'.
            line_split = line.rstrip('\r\n').split('\t')
            head = int(entity2id[line_split[0]])
            tail = int(entity2id[line_split[2]])
            rel = int(rel2id[line_split[1]])
            triplets.append((head, rel, tail))
    # De-duplicate while keeping file order (set() gave an arbitrary order).
    triplets = list(dict.fromkeys(triplets))

    e = 'e'
    r = 'r'
    # Query shape per body length. Length 1 is a bare 'r' (not ('r',)), which
    # is the same shape as the one-hop query of the triple itself.
    shapes = {1: (e, r), 2: (e, (r, r)), 3: (e, (r, r, r))}
    queries = []
    answers = {}
    for head, rel, tail in tqdm(triplets):
        fact = (head, rel, tail)
        # The triple itself is also a query, for better 1p training.
        answers.setdefault((head, rel), set()).add(tail)
        queries.append((((head, rel), shapes[1]), fact))

        for body in rule.get(rel, []):
            shape = shapes.get(len(body))
            if shape is None:
                # Unsupported body length. The old code silently re-used the
                # previous iteration's query here.
                continue
            anchor = (head, tuple(body))
            answers.setdefault(anchor, set()).add(tail)
            queries.append(((anchor, shape), fact))
    # Order-preserving de-duplication gives deterministic pickle contents.
    queries = list(dict.fromkeys(queries))

    with open(name + 'queries_.pkl', 'wb') as fq:
        pickle.dump(queries, fq)

    with open(name + 'answers_.pkl', 'wb') as fa:
        pickle.dump(answers, fa)

def rebuild_test(rules, dataset, name, avg_conf = 0.0):
    """Attach rule-derived guidance queries to every fact (and its inverse) of a temporal split.

    Reads ``entity2id.json``, ``relation2id.json`` and ``ts2id.json`` from the
    current directory. Each tab-separated line of ``dataset`` is
    (head, relation, tail, timestamp). Two quadruples are emitted per line:
    the fact ``(head, rel, tail, ts)`` and its inverse
    ``(tail, rel + len(rel2id), head, ts)``.

    Each quadruple gets a guidance list. The first entry is
    ``[(head, rel), ('e', 'r'), avg_conf, []]``, the fact itself as a one-hop
    query. After that comes one ``[(head, body), shape, conf, var_constraints]``
    entry per rule body of 1-3 relations stored for ``rel`` in the rules
    pickle. Bodies of other lengths are skipped.

    Args:
        rules: path to a pickle mapping relation id to
            ``(bodies, confs, var_constraints)``, the layout ``load_rules`` writes.
        dataset: path to the tab-separated split file.
        name: prefix for the output pickle paths.
        avg_conf: confidence assigned to the one-hop query of each fact.

    Output: ``name + 'queries_conf_constraints_3.pkl'`` holds a list of
    ``[quadruple, guidance]``. ``name + 'answers_conf_constraints_3.pkl'``
    maps each quadruple to the set of its tails.
    """
    with open('entity2id.json', 'r') as f:
        entity2id = json.load(f)
    with open('relation2id.json', 'r') as f:
        rel2id = json.load(f)
    with open('ts2id.json', 'r') as f:
        ts2id = json.load(f)
    # NOTE: pickle.load can execute arbitrary code; only use trusted rule files.
    with open(rules, 'rb') as fh:
        rule = pickle.load(fh)

    # Inverse relations are encoded as rel + number of base relations.
    inverse_len = len(rel2id)
    triplets = []
    with open(dataset, 'r') as fr:
        for line in fr:
            if not line.strip():
                continue  # tolerate blank (e.g. trailing) lines
            line_split = line.split('\t')
            head = int(entity2id[line_split[0]])
            tail = int(entity2id[line_split[2]])
            rel = int(rel2id[line_split[1]])
            ts = int(ts2id[line_split[3].strip()])
            triplets.append((head, rel, tail, ts))
            triplets.append((tail, rel + inverse_len, head, ts))

    e = 'e'
    r = 'r'
    # Query shape per body length. Length 1 is a bare 'r' (not ('r',)), which
    # is the same shape as the one-hop entry below.
    shapes = {1: (e, r), 2: (e, (r, r)), 3: (e, (r, r, r))}
    queries = []
    answers = {}
    for quad in tqdm(triplets):
        head, rel, tail, _ = quad
        answers.setdefault(quad, set()).add(tail)
        # The fact itself as a one-hop query with the fallback confidence and
        # no variable constraints.
        guidance = [[(head, rel), shapes[1], avg_conf, []]]

        fired_rules, confs, var_constraints = rule.get(rel, ([], [], []))
        for body, conf, constraint in zip(fired_rules, confs, var_constraints):
            shape = shapes.get(len(body))
            if shape is None:
                # Unsupported body length. The old code re-appended the previous
                # rule's entry here, or raised NameError on the first rule.
                continue
            guidance.append([(head, tuple(body)), shape, conf, constraint])
        queries.append([quad, guidance])

    with open(name + 'queries_conf_constraints_3.pkl', 'wb') as fq:
        pickle.dump(queries, fq)

    with open(name + 'answers_conf_constraints_3.pkl', 'wb') as fa:
        pickle.dump(answers, fa)


def mult_data_builder(rule, dataset, index, num, progress=True):
    """Build flattened queries and answers for one chunk of a triple list.

    The chunk is ``dataset[index * num:(index + 1) * num]``; the last chunk may
    be shorter or empty. For each triple ``(head, rel, tail, ...)`` it records:
      * the triple as a one-hop query ``((head, rel), ('e', 'r'))``;
      * one query ``((head, body), shape)`` per rule body of 1-3 relations in
        ``rule[rel]``, where shape is ``('e', 'r')``, ``('e', ('r', 'r'))`` or
        ``('e', ('r', 'r', 'r'))``. Bodies of other lengths are skipped.

    Args:
        rule: mapping from relation id to an iterable of rule bodies (sequences
            of relation ids). NOTE(review): this is not the
            ``(bodies, confs, constraints)`` layout ``load_rules`` writes.
            Confirm what callers pass.
        dataset: sliceable sequence whose items hold head, rel and tail at
            positions 0, 1 and 2.
        index: chunk index.
        num: chunk size.
        progress: wrap the chunk in a tqdm progress bar. Pass False when running
            many workers in parallel.

    Returns:
        ``(queries, answers)``. ``queries`` is a de-duplicated list of query
        tuples in first-seen order. ``answers`` maps each query anchor
        ``(head, rel_or_body)`` to its set of tails. (The old version returned
        ``[]``.)
    """
    e = 'e'
    r = 'r'
    # Query shape per body length. Length 1 is a bare 'r', matching the
    # one-hop shape used for the triple itself.
    shapes = {1: (e, r), 2: (e, (r, r)), 3: (e, (r, r, r))}
    start = index * num
    # A real slice: dataset[a, b] was tuple indexing and raised TypeError on lists.
    triplets = dataset[start:start + num]
    if progress:
        triplets = tqdm(triplets)

    queries = {}  # dict used as an insertion-ordered set
    answers = {}
    for i in triplets:
        head, rel, tail = i[0], i[1], i[2]
        # The triple itself is also a query, for better 1p training.
        anchor = (head, rel)
        answers.setdefault(anchor, set()).add(tail)
        queries[(anchor, shapes[1])] = None

        for body in rule.get(rel, []):
            shape = shapes.get(len(body))
            if shape is None:
                continue  # unsupported length; previously re-used a stale query
            # Anchor on the body itself (length-1 bodies used to anchor on rel).
            anchor = (head, tuple(body))
            answers.setdefault(anchor, set()).add(tail)
            queries[(anchor, shape)] = None
    return list(queries), answers




# Script entry: filter the mined rules (load_rules writes guidance.pkl and
# confs.pkl), then build guidance-annotated queries for the validation and
# test splits. The returned average confidence is not used below; 0.01 is
# passed explicitly as the one-hop confidence.
guidance, avg = load_rules('18_rules_tr.json')
# NOTE(review): purpose of this pause is unclear; kept for identical behavior.
time.sleep(1)
# Rebuilding the train split (and the older rebuild_data path) is disabled.
for _split_file, _out_prefix in (
    ('valid.txt', 'valid_repro_123_inverse_tr_0.01'),
    ('test.txt', 'test_repro_123_inverse_tr_0.01'),
):
    rebuild_test('guidance.pkl', _split_file, _out_prefix, 0.01)