#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: ----
"""
import csv
import time
import torch
from torch_geometric.data import Data

import numpy as np

from itertools import combinations

import requests

import re

from tqdm import tqdm

# Code in this file for interfacing with RDFox is based on the examples found here:
# https://docs.oxfordsemantic.tech/getting-started.html

# Base URL of the RDFox REST endpoint; the datastore, content and SPARQL URLs
# used below are built by appending paths to it.
rdfox_server = "http://localhost:8080"
# Full IRI of rdf:type. A triple whose predicate is this IRI encodes a unary
# fact (its object is the unary predicate); any other triple is binary.
RDF_type_string = 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'


def assert_response_ok(response, message):
    '''Raise an exception if a REST endpoint reported a failure.

    Args:
        response: response object returned by the endpoint; only its ``ok``,
            ``status_code`` and ``text`` attributes are read.
        message: description of the failed operation. The status code and
            the response body are appended to it.

    Raises:
        Exception: if ``response.ok`` is falsy.
    '''
    if response.ok:
        return
    details = "\nStatus received={}\n{}".format(response.status_code,
                                                 response.text)
    raise Exception(message + details)


def dlog_to_RDF(dlog_array, prefix_dict, append_str=''):
    '''Translate Datalog atoms into RDF triple strings.

    Each atom looks like ``prefix:Name[?X,?Y]`` (binary) or
    ``prefix:Name[?X]`` (unary). Every variable is turned into a constant:
    the leading character (the ``?``) is dropped, the rest is lower-cased
    and ``append_str`` is appended. Binary atoms become
    ``<x> <iri> <y> .`` and unary atoms become ``<x> <rdf:type> <iri> .``,
    where ``iri`` is ``prefix_dict[prefix] + Name``.

    Args:
        dlog_array: iterable of atom strings (no spaces inside an atom).
        prefix_dict: maps each prefix name to its IRI.
        append_str: suffix added to every constant name.

    Returns:
        List of triple strings, one per atom, in input order.

    Raises:
        KeyError: if an atom's prefix is not in ``prefix_dict``.
        AssertionError: if an atom has neither one nor two arguments.
    '''
    triples = []
    for atom in dlog_array:
        prefix, local_part = atom.split(':')
        pieces = local_part.split('[')
        # pieces[1] is the argument list followed by its closing ']'
        arguments = pieces[1][:-1].split(',')
        constants = [arg[1:].lower() + append_str for arg in arguments]
        pred_name = pieces[0]
        if len(constants) == 2:
            subject, obj = constants
            parts = (subject, prefix_dict[prefix] + pred_name, obj)
        else:
            assert len(constants) == 1
            parts = (constants[0], RDF_type_string, prefix_dict[prefix] + pred_name)
        triples.append('<{}> <{}> <{}> .'.format(*parts))
    return triples


def _parse_dlog_prefixes(dlog_rules):
    '''Collect the leading ``PREFIX name: <iri>`` declarations of a rule file.

    ``dlog_rules`` holds one whitespace-split token list per line. Blank
    lines before, between or after the declarations are skipped.

    Returns:
        ``(prefix_dict, i)``. ``prefix_dict`` maps each prefix name (without
        its trailing ``:``) to its IRI (without the ``< >``). ``i`` is the
        index of the first line that is neither blank nor a PREFIX line.
    '''
    prefix_dict = {}
    i = 0
    while i < len(dlog_rules) and (not dlog_rules[i] or dlog_rules[i][0] == 'PREFIX'):
        if dlog_rules[i]:
            name, iri = dlog_rules[i][1], dlog_rules[i][2]
            prefix_dict[name[:-1]] = iri[1:-1]
        i += 1
    return prefix_dict, i


def _index_entailed_triples(entailed_triples):
    '''Group entailed ``(s, p, o)`` triples by their constant(s).

    rdf:type triples are keyed by their subject and collect their objects
    (unary predicates). All other triples are keyed by the ordered
    ``(subject, object)`` pair and collect their predicates.
    '''
    type_iri = "<{}>".format(RDF_type_string)
    entailed_dict = {}
    for triple in entailed_triples:
        if triple[1] == type_iri:
            entailed_dict.setdefault(triple[0], set()).add(triple[2])
        else:
            # Not sorted: the ordering of the pair is significant.
            entailed_dict.setdefault((triple[0], triple[2]), set()).add(triple[1])
    return entailed_dict


def _entailment_mask(y, node_dict, entailed_dict, binaryPredicates, unaryPredicates):
    '''Return a mask that is 0.0 where ``y`` is 0 but the fact is entailed.

    Pair nodes are only checked against binary predicates (the first
    ``len(binaryPredicates)`` columns). Single-constant nodes are only
    checked against unary predicates (the remaining columns). All other
    positions of the mask are 1.0.
    '''
    num_binary = len(binaryPredicates)
    mask = torch.ones(y.size())
    for node_index, node_feature_vec in enumerate(y):
        consts = node_dict[node_index]
        entailed = entailed_dict.get(consts)
        if not entailed:
            continue
        if isinstance(consts, tuple):
            offset, predicates = 0, binaryPredicates
            targets = node_feature_vec[:num_binary]
        else:
            offset, predicates = num_binary, unaryPredicates
            targets = node_feature_vec[num_binary:]
        for predicate_index, target in enumerate(targets):
            # Compare with 0.5 rather than testing float equality with 0.
            if target < 0.5 and predicates[predicate_index] in entailed:
                mask[node_index][offset + predicate_index] = 0.0
    return mask


def create_training_dataset(dataset_name, encoding_scheme, unaryPredicates,
                            binaryPredicates, expand_neighbourhood=False):
    '''Build one PyTorch Geometric graph per rule of a ground rule file.

    For each rule in ``./rules/ground/<dataset_name>_rules.dlog``, the rule
    body is encoded as the input features ``x``. The rule head is encoded as
    the target on the same graph. Target positions that are 0 but are
    entailed from the body under the whole rule set are masked out. The
    entailment is computed by the RDFox server at ``rdfox_server``.

    Args:
        dataset_name: used for the rule file path and as the RDFox datastore
            name.
        encoding_scheme: graph encoding scheme forwarded to
            ``encode_input_dataset``.
        unaryPredicates: list of unary predicate IRIs.
        binaryPredicates: list of binary predicate IRIs.
        expand_neighbourhood: if True, each rule yields two examples. The
            first is encoded with ``training=True`` and the second with
            ``training=False``. Otherwise a single example is produced with
            ``training=False``.

    Returns:
        ``(dataset, node_dicts)``. ``dataset`` is a list of ``Data`` objects
        whose ``y`` is the head encoding concatenated column-wise with its
        mask. ``node_dicts`` holds the matching node index -> constant(s)
        maps.

    Raises:
        Exception: if any RDFox REST call returns a non-OK status.
    '''
    datastore_url = rdfox_server + "/datastores/{}".format(dataset_name)
    content_url = datastore_url + "/content"
    sparql_url = datastore_url + "/sparql"

    # If there are no datastores on the server yet, create ours. Otherwise
    # clear it in case data from a previous run is still present.
    if requests.get(rdfox_server + "/datastores").text == '?Name\n':
        response = requests.post(datastore_url, params={'type': 'par-complex-nn'})
        assert_response_ok(response, "Failed to create datastore.")
    else:
        response = requests.delete(content_url)
        assert_response_ok(response, "Failed to clear content from datastore.")

    # Upload the ruleset as-is; RDFox parses the PREFIX declarations itself.
    with open("./rules/ground/{}_rules.dlog".format(dataset_name), encoding='utf-8') as f:
        rules_text = f.read()
    response = requests.post(content_url, data=rules_text)
    assert_response_ok(response, "Failed to add rule.")

    # One whitespace-split token list per line; blank lines become [].
    dlog_rules = [line.split() for line in rules_text.splitlines()]
    prefix_dict, first_rule = _parse_dlog_prefixes(dlog_rules)
    num_lines = len(dlog_rules)
    training_flags = [True, False] if expand_neighbourhood else [False]

    dataset = []
    node_dicts = []
    print("Processing rules")
    for line_index in range(first_rule, num_lines):
        if line_index % 10 == 0:
            print("Line {}  of {}".format(line_index, num_lines))
        tokens = dlog_rules[line_index]
        if not tokens:
            continue  # blank line, e.g. between rules or at the end of file
        # A rule reads ``Head[?X] :- B1[?X,?Y], B2[?Y] .``: token 0 is the
        # head, token 1 is ':-', and the final '.' token ends the rule.
        head = [tokens[0]]
        body = [atom[:-1] if atom.endswith(',') else atom for atom in tokens[2:-1]]

        RDF_body = dlog_to_RDF(body, prefix_dict)
        RDF_head = dlog_to_RDF(head, prefix_dict)

        for training in training_flags:
            # Encode the body as the input graph and the head on that very
            # same graph, so that x and y share node indices.
            (x, edge_list, edge_type, node_dict,
             _, _, y) = encode_input_dataset(encoding_scheme, RDF_body, RDF_head,
                                             binaryPredicates, unaryPredicates,
                                             training=training)
            node_dicts.append(node_dict)

            # Add the body facts so RDFox materialises their consequences.
            response = requests.post(content_url, data='\n'.join(RDF_body))
            assert_response_ok(response, "Failed to add facts to datastore.")

            # Return all entailed facts.
            response = requests.get(sparql_url,
                                    params={"query": "SELECT ?p ?r ?q WHERE { ?p ?r ?q }"})
            assert_response_ok(response, "Failed to run return entailed facts.")
            # Answer is tab-separated: drop the header and the trailing empty line.
            entailed_triples = [tuple(row.split('\t'))
                                for row in response.text.split('\n')][1:-1]
            entailed_dict = _index_entailed_triples(entailed_triples)

            mask = _entailment_mask(y, node_dict, entailed_dict,
                                    binaryPredicates, unaryPredicates)

            # Delete all facts. SPARQL 1.1 Protocol only allows updates via POST.
            response = requests.post(sparql_url,
                                     data={"update": "DELETE { ?p ?r ?q } WHERE { ?p ?r ?q }"})
            assert_response_ok(response, "Failed to delete facts.")

            dataset.append(Data(x=x, y=torch.cat((y, mask), 1), edge_index=edge_list,
                                edge_type=edge_type))
    return dataset, node_dicts


def index_to_pred(index, node_dict, binaryPredicates, unaryPredicates, num_binary):
    '''Render a ``(node, column)`` feature index as ``"Predicate(constants)"``.

    Columns below ``num_binary`` name binary predicates. Higher columns name
    unary predicates, offset by ``num_binary``.
    '''
    node, column = index[0], index[1]
    constants = node_dict[node]
    if column >= num_binary:
        pred_name = unaryPredicates[column - num_binary]
    else:
        pred_name = binaryPredicates[column]
    return "{}({})".format(pred_name, constants)


def load_predicates(dataset_name):
    '''Load a dataset's predicate signature from its CSV file.

    Reads ``./predicates/<dataset_name>_predicates.csv``, where every row is
    ``predicate,arity``. Rows with arity 1 are unary predicates; any other
    arity is treated as binary.

    The file is parsed with the ``csv`` module, so quoted fields are handled.
    Blank rows are skipped; they appear, for example, when a ``csv.writer``
    output written without ``newline=''`` is read back on Windows. A last
    row without a trailing newline is also accepted.

    Returns:
        ``(binaryPredicates, unaryPredicates)``: two lists in file order.

    Raises:
        FileNotFoundError: if the CSV file does not exist.
    '''
    binaryPredicates = []
    unaryPredicates = []

    file_path = './predicates/{}_predicates.csv'.format(dataset_name)

    try:
        f = open(file_path, 'r', newline='')
    except FileNotFoundError as err:
        raise FileNotFoundError(
            'Predicates csv file for {} dataset not found.'.format(dataset_name)) from err

    with f:
        for row in csv.reader(f):
            if not row:
                continue  # blank line
            predicate, arity = row[0], row[1]
            # int() ignores surrounding whitespace such as a stray '\r'.
            if int(arity) == 1:
                unaryPredicates.append(predicate)
            else:
                binaryPredicates.append(predicate)
    return binaryPredicates, unaryPredicates


def find_predicates(ttl_file_path, dlog_file_path, output_file_path):
    '''Write a CSV listing every predicate used by a fact file and a rule file.

    This essentially recovers the signature (set of predicates) a dataset
    works with. Each output row is ``<predicate IRI>,<arity>``. Binary
    predicates are written first and unary ones second, each group sorted,
    so repeated runs on the same input produce an identical file.

    Args:
        ttl_file_path: ``.nt`` or ``.ttl`` file with one triple per line.
            Blank lines and ``#`` comment lines are skipped.
        dlog_file_path: ``.dlog`` rule file whose leading lines are
            ``PREFIX name: <iri>`` declarations.
        output_file_path: path of the ``.csv`` file to write.

    Raises:
        AssertionError: if a path does not have the expected extension.
        Exception: if a rule atom has an arity other than 1 or 2.
    '''
    assert ttl_file_path[-3:] == ".nt" or ttl_file_path[-4:] == ".ttl"
    assert (dlog_file_path[-5:] == ".dlog")
    assert (output_file_path[-4:] == ".csv")

    type_iri = '<http://www.w3.org/1999/02/22-rdf-syntax-ns#type>'
    # Split on spaces that are not inside a quoted literal.
    triple_splitter = re.compile(''' (?=(?:[^'"]|'[^']*'|"[^"]*")*$)''')

    binaryPredicates = set()
    unaryPredicates = set()
    with open(ttl_file_path, 'r', encoding='utf-8') as f:
        for RDF_triple in tqdm(f):
            line = RDF_triple.strip()
            if not line or line.startswith('#'):
                continue  # blank line or comment
            RDF_list = triple_splitter.split(line)
            # The second item is either a binary predicate or rdf:type; in the
            # latter case the third item is a unary predicate.
            if RDF_list[1] == type_iri:
                unaryPredicates.add(RDF_list[2])
            else:
                if RDF_list[1] not in binaryPredicates:
                    print(RDF_list[1])
                binaryPredicates.add(RDF_list[1])

    with open(dlog_file_path, encoding='utf-8') as f:
        dlog_rules = [line.split() for line in f]

    # Leading PREFIX declarations (blank lines around them are skipped).
    prefix_dict = {}
    i = 0
    while i < len(dlog_rules) and (not dlog_rules[i] or dlog_rules[i][0] == 'PREFIX'):
        if dlog_rules[i]:
            # Drop the ':' after the prefix name and the < > around the IRI.
            prefix_dict[dlog_rules[i][1][:-1]] = dlog_rules[i][2][1:-1]
        i += 1

    for tokens in dlog_rules[i:]:
        if tokens and tokens[-1] == '.':
            tokens = tokens[:-1]  # drop the rule terminator
        for entry in tokens:
            if entry == ':-':
                continue
            arity = entry.count('?')
            prefix, predicate = entry.split(':', 1)
            # Keep only the predicate name, dropping the variable list.
            predicate = predicate.split('[')[0]
            iri = "<{}{}>".format(prefix_dict[prefix], predicate)
            if arity == 2:
                binaryPredicates.add(iri)
            elif arity == 1:
                unaryPredicates.add(iri)
            else:
                raise Exception("Arity of {} included in rules file. "
                                "Allowed arities = 1,2".format(arity))

    # newline='' stops the csv module's '\r\n' becoming '\r\r\n' on Windows.
    with open(output_file_path, mode='w', newline='') as output_file:
        writer = csv.writer(output_file, delimiter=',')
        writer.writerows([pred, 2] for pred in sorted(binaryPredicates))
        writer.writerows([pred, 1] for pred in sorted(unaryPredicates))


def encode_input_dataset(encoding_scheme, input_dataset, query_dataset, binaryPredicates, unaryPredicates,
                         training=False):
    """Encode a dataset as a graph, which is created anew.

    This is not quite like the encoding in the paper, because we add edges
    and nodes for the union of the initial dataset and the entailed dataset,
    but use the k-hot encodings only for the initial dataset.

    Nodes are one singleton node per constant (indices ``0..n-1``), followed
    by a pair node for both ``(a, b)`` and ``(b, a)`` for every pair
    returned by ``process``. Edge types:
      0: pair node (a, b) <-> singleton node of a
      1: pair node (a, b) <-> singleton node of b (only when
         ``encoding_scheme == 'EC'``)
      2: pair node (a, b) <-> pair node (b, a), added once per such couple
      3: singleton a <-> singleton b for every pair node (a, b)

    Args:
        encoding_scheme: see edge type 1 above.
        input_dataset: iterable of whitespace-separated triple strings
            ``'<s> <p> <o> .'``, k-hot encoded into ``x``.
        query_dataset: triple strings in the same format, k-hot encoded into
            ``x_query`` on the same graph.
        binaryPredicates: binary predicate IRIs; they take feature columns
            ``0 .. len(binaryPredicates) - 1``.
        unaryPredicates: unary predicate IRIs; they take the columns after
            the binary ones.
        training: forwarded to ``process``.

    Returns:
        ``(x, edge_index, edge_type, node_to_const_dict, const_node_dict,
        pred_dict, x_query)``:
          x: FloatTensor ``(num_nodes, num_predicates)`` of input facts.
          edge_index: LongTensor ``(2, num_edges)``; ``(2, 0)`` if empty.
          edge_type: LongTensor ``(num_edges,)`` with the types above.
          node_to_const_dict: node index -> constant or pair of constants.
          const_node_dict: constant or pair of constants -> node index.
          pred_dict: predicate -> feature column.
          x_query: FloatTensor like ``x`` for ``query_dataset``. Query facts
            whose constant(s) or predicate have no slot in this graph are
            dropped, with a single warning printed.
    """
    #  Assign a feature-vector position to each predicate: binary first, then unary.
    num_binary = len(binaryPredicates)
    feature_dimension = num_binary + len(unaryPredicates)
    pred_dict = {pred: i for i, pred in enumerate(binaryPredicates)}
    for i, pred in enumerate(unaryPredicates):
        pred_dict[pred] = num_binary + i

    #  Constants, pairs of constants, and constant(s) -> predicates maps for
    #  the input and query datasets.
    all_constants, all_pairs_of_constants, input_dataset_constants_to_pred_dict, \
        query_dataset_constants_to_pred_dict = process(input_dataset, query_dataset, training)

    #  A node for each individual constant, then one for each pair in both
    #  orientations.
    singleton_nodes = list(all_constants)
    num_singleton_nodes = len(singleton_nodes)
    pair_nodes = set()
    for pair in all_pairs_of_constants:
        pair_nodes.add(tuple(pair))
        pair_nodes.add((pair[1], pair[0]))
    pair_nodes = list(pair_nodes)
    #  Map from constants (and, filled in below, pairs) to node indices.
    const_node_dict = {const: i for i, const in enumerate(singleton_nodes)}
    nodes = singleton_nodes + pair_nodes

    edge_list = []
    edge_type_list = []
    #  Pairs of singleton node indices that co-occur in some pair node.
    pairs_as_nodes = set()
    for i, pair in enumerate(pair_nodes):
        pair_node = i + num_singleton_nodes
        first_node = const_node_dict[pair[0]]
        second_node = const_node_dict[pair[1]]
        edge_list += [(first_node, pair_node), (pair_node, first_node)]
        edge_type_list += [0, 0]
        if encoding_scheme == 'EC':
            edge_list += [(second_node, pair_node), (pair_node, second_node)]
            edge_type_list += [1, 1]
        reversed_pair = (pair[1], pair[0])
        # The reversed pair is only in the map once it has been visited, so
        # each (a, b)/(b, a) couple is linked exactly once, and (a, a) gets
        # no self-loop.
        if reversed_pair in const_node_dict:
            reversed_node = const_node_dict[reversed_pair]
            edge_list += [(reversed_node, pair_node), (pair_node, reversed_node)]
            edge_type_list += [2, 2]
        const_node_dict[pair] = pair_node
        pairs_as_nodes.add((first_node, second_node))
        pairs_as_nodes.add((second_node, first_node))
    edge_list = edge_list + list(pairs_as_nodes)
    edge_type_list = edge_type_list + [3] * len(pairs_as_nodes)

    node_to_const_dict = dict(enumerate(nodes))
    if edge_list:
        return_edge_list = torch.LongTensor(edge_list).t().contiguous()
    else:
        return_edge_list = torch.LongTensor([[], []])
    return_edge_type_list = torch.LongTensor(edge_type_list)

    x = np.zeros((len(nodes), feature_dimension))
    for const, preds in input_dataset_constants_to_pred_dict.items():
        const_index = const_node_dict[const]
        for pred in preds:
            x[const_index][pred_dict[pred]] = 1
    x = torch.FloatTensor(x)

    x_query = np.zeros((len(nodes), feature_dimension))
    exists_nonrepresentable_query = False
    for const, preds in query_dataset_constants_to_pred_dict.items():
        # Drop query facts that are not representable in this graph.
        try:
            const_index = const_node_dict[const]
            for pred in preds:
                x_query[const_index][pred_dict[pred]] = 1
        except KeyError:
            if not exists_nonrepresentable_query:
                exists_nonrepresentable_query = True
                print("Warning: some examples cannot be represented in the encoding of the input graph.")
    x_query = torch.FloatTensor(x_query)

    return x, return_edge_list, return_edge_type_list, node_to_const_dict, const_node_dict, pred_dict, x_query


def process(input_dataset, query_dataset, training=False):
    '''Index the facts of an input and a query dataset by their constants.

    Facts are whitespace-separated triple strings. If the second field is
    the rdf:type IRI, the fact is unary: its key is the subject and its
    predicate is the third field. Otherwise it is binary: its key is the
    ordered (not sorted) ``(subject, object)`` pair and its predicate is the
    second field.

    Input facts always contribute their constants (and, if binary, their
    pair) to the graph. Unary query facts never do. Binary query facts
    contribute their pair and constants only when ``training`` is true.

    Returns:
        ``(all_constants, all_pairs_of_constants, input_index, query_index)``,
        where the two indexes map each key to its set of predicates.
    '''
    type_iri = '<http://www.w3.org/1999/02/22-rdf-syntax-ns#type>'

    def split_fact(RDF_triple):
        # -> (predicate, key, is_binary)
        fields = str.split(RDF_triple)
        if fields[1] == type_iri:
            return fields[2], fields[0], False
        return fields[1], (fields[0], fields[2]), True

    all_constants = set()
    all_pairs_of_constants = set()
    input_index = {}
    query_index = {}

    for RDF_triple in input_dataset:
        pred, key, is_binary = split_fact(RDF_triple)
        if is_binary:
            all_pairs_of_constants.add(key)
            all_constants.update(key)
        else:
            all_constants.add(key)
        input_index.setdefault(key, set()).add(pred)

    for RDF_triple in query_dataset:
        pred, key, is_binary = split_fact(RDF_triple)
        if is_binary and training:
            all_pairs_of_constants.add(key)
            all_constants.update(key)
        query_index.setdefault(key, set()).add(pred)

    return all_constants, all_pairs_of_constants, input_index, query_index


def encode_entailed_dataset(dataset, node_to_const_dict, const_to_node_dict, pred_dict):
    """Encode ``dataset`` as k-hot node features on an existing graph.

    The graph is given as two mappings that must be inverse to each other:
    ``node_to_const_dict`` (node index -> constant or pair of constants) and
    ``const_to_node_dict`` (the reverse). ``pred_dict`` maps each predicate
    to its feature column.

    Returns:
        FloatTensor with one row per node and one column per predicate. An
        entry is 1 where the corresponding fact is in ``dataset``, else 0.

    Raises:
        AssertionError: if either mapping is None, the mappings are not
            inverse, ``pred_dict`` is empty, or ``dataset`` uses a predicate
            missing from ``pred_dict``.
        KeyError: if ``dataset`` mentions a constant or pair with no node.
    """
    empty_graph_message = ("Attempted to encode an entailed dataset on an empty graph."
                           " Feature not yet supported.")
    assert node_to_const_dict is not None, empty_graph_message
    assert const_to_node_dict is not None, empty_graph_message
    # Both maps must describe the same bijection between nodes and constants.
    for const, node in const_to_node_dict.items():
        assert node_to_const_dict[node] == const, ("Error: graph representation is not bijective;"
                                                   "this is definitely a bug.")
    feature_dimension = len(pred_dict)
    assert feature_dimension > 0, "Error: feature dimension is 0 or negative. Check your input."

    # Only the constant(s) -> predicates index of the dataset is needed.
    facts_by_key = process(dataset, set())[2]

    # Every predicate used by the dataset must have a feature column.
    for predicates in facts_by_key.values():
        for pred in predicates:
            assert pred in pred_dict

    x = np.zeros((len(node_to_const_dict), feature_dimension))
    for key, predicates in facts_by_key.items():
        row = const_to_node_dict[key]
        for pred in predicates:
            x[row, pred_dict[pred]] = 1
    return torch.FloatTensor(x)


def remove_unencodable(dataset, node_dict):
    '''Keep only the facts that can be encoded on a given graph.

    A unary (rdf:type) fact is kept if its subject is a node of the graph.
    A binary fact is kept if its ordered ``(subject, object)`` pair is a
    node, i.e. there is a supporting pair node. Facts mentioning unknown
    constants or unsupported pairs are dropped.

    Args:
        dataset: sized iterable of triple strings whose fields are separated
            by single spaces. Spaces inside quoted literals are not split on.
        node_dict: mapping node index -> constant (str) or pair of constants
            (tuple).

    Returns:
        List of the kept triples, in their original order. A summary of how
        many were removed is printed.
    '''
    # A set gives O(1) membership instead of a scan over all nodes per fact.
    nodes_as_constants = set(node_dict.values())
    # Split on spaces that are not inside a quoted literal. Credit:
    # https://stackoverflow.com/questions/2785755/how-to-split-but-ignore-separators-in-quoted-strings-in-python
    triple_splitter = re.compile(''' (?=(?:[^'"]|'[^']*'|"[^"]*")*$)''')
    type_iri = '<http://www.w3.org/1999/02/22-rdf-syntax-ns#type>'

    filtered_rdf_triples = []
    for RDF_triple in dataset:
        RDF_list = triple_splitter.split(RDF_triple)
        if RDF_list[1] == type_iri:
            key = RDF_list[0]
        else:
            key = (RDF_list[0], RDF_list[2])  # Not sorted
        if key in nodes_as_constants:
            filtered_rdf_triples.append(RDF_triple)
    print("Removed " + str(len(dataset) - len(filtered_rdf_triples)) + " unencodable triples from a total of " +
          str(len(dataset)))
    return filtered_rdf_triples


def decode(node_dict, num_binary, num_unary, binaryPredicates, unaryPredicates,
           feature_vectors, threshold):
    '''Turn thresholded feature vectors back into tab-separated RDF facts.

    Every position ``(node, column)`` with a value ``>= threshold`` is
    considered. Pair nodes only yield binary facts (columns below
    ``num_binary``). Single-constant nodes only yield rdf:type facts (the
    remaining columns). All other positions are ignored.

    Returns:
        Set of ``"s\\tp\\to"`` strings.
    '''
    GNN_dataset = set()
    for const_index, pred_index in torch.nonzero(feature_vectors >= threshold).tolist():
        const = node_dict[const_index]
        is_pair = type(const) is tuple
        if is_pair and pred_index < num_binary:
            predicate = binaryPredicates[pred_index]
            GNN_dataset.add("{}\t{}\t{}".format(const[0], predicate, const[1]))
        elif not is_pair and pred_index >= num_binary:
            predicate = unaryPredicates[pred_index - num_binary]
            GNN_dataset.add(
                "{}\t<http://www.w3.org/1999/02/22-rdf-syntax-ns#type>\t{}".format(const, predicate))
    return GNN_dataset


def decode_with_scores(node_dict, num_binary, num_unary, binaryPredicates, unaryPredicates, feature_vectors,
                       query_tensor):
    '''Decode the query facts of a GNN output together with their scores.

    ``feature_vectors`` is multiplied element-wise by ``query_tensor``, so
    only positions selected by the query (non-zero there) and with a
    non-zero score are reported. Pair nodes only yield binary facts (columns
    below ``num_binary``). Single-constant nodes only yield rdf:type facts.

    Returns:
        Set of ``"s\\tp\\to\\tscore"`` strings, where ``score`` is a plain
        Python number.

    Raises:
        AssertionError: if the two tensors have different sizes.
    '''
    # We only need the scores of the query atoms, so keep only those.
    assert query_tensor.size() == feature_vectors.size(), "ERROR: GNN input and output have different dimensions."
    feature_vectors = feature_vectors * query_tensor
    scored_GNN_dataset = set()

    for const_index, pred_index in torch.nonzero(feature_vectors).tolist():
        # .item() so the score is written as a number, not as "tensor(...)".
        score = feature_vectors[const_index][pred_index].item()
        const = node_dict[const_index]
        if type(const) is tuple:  # Pair node: only binary predicates apply.
            if pred_index < num_binary:
                predicate = binaryPredicates[pred_index]
                scored_RDF_triplet = "{}\t{}\t{}\t{}".format(const[0], predicate, const[1], score)
                scored_GNN_dataset.add(scored_RDF_triplet)
        elif pred_index >= num_binary:  # Single constant: unary predicates.
            predicate = unaryPredicates[pred_index - num_binary]
            scored_RDF_triplet = "{}\t<http://www.w3.org/1999/02/22-rdf-syntax-ns#type>\t{}\t{}".format(
                const, predicate, score)
            scored_GNN_dataset.add(scored_RDF_triplet)
    return scored_GNN_dataset


def decode_and_get_threshold(node_dict, num_binary, num_unary, binaryPredicates, unaryPredicates,
                             feature_vectors, threshold):
    '''Decode every feature-vector entry that is >= ``threshold`` into a fact.

    Each decoded fact is paired with the feature-vector entry that produced
    it. The selection uses ``>=``, so any threshold strictly above that entry
    would stop the fact from being extracted.

    :param node_dict: maps a row index to a tuple ``(subject, object)`` or to
        a single constant.
    :param num_binary: columns below this are binary predicates, the rest
        are unary.
    :param num_unary: number of unary predicates (not used here).
    :param binaryPredicates: binary predicate names, indexed by column.
    :param unaryPredicates: unary predicate names, indexed by
        ``column - num_binary``.
    :param feature_vectors: 2-D GNN output tensor.
    :param threshold: minimum entry value for a fact to be decoded.
    :returns: set of ``(space-separated fact string, 0-d tensor entry)`` pairs.
    '''
    selected = torch.nonzero(feature_vectors >= threshold).tolist()
    decoded = set()
    for position in selected:
        row, col = position[0], position[1]
        entry = feature_vectors[row, col]
        entity = node_dict[row]
        if type(entity) is tuple:
            # Pairs of constants only take binary-predicate columns.
            if col >= num_binary:
                continue
            fact = "{} {} {}".format(entity[0], binaryPredicates[col], entity[1])
        else:
            # Single constants only take unary-predicate columns.
            if col < num_binary:
                continue
            fact = "{} <http://www.w3.org/1999/02/22-rdf-syntax-ns#type> {}".format(
                entity, unaryPredicates[col - num_binary])
        decoded.add((fact, entry))
    return decoded


def predict_entailed_fast(encoding_scheme, model, binaryPredicates,
                          unaryPredicates, dataset, query_dataset, max_iterations=1,
                          threshold=0.5, device='cpu'):
    '''Predict what facts are entailed by a given GNN.

    The model is applied repeatedly. After each round, the newly entailed
    facts are added to ``dataset``, which is then re-encoded for the next
    round. The loop stops when a round produces no new facts, or after
    ``max_iterations`` rounds. Use max_iterations = None if you want to
    continue until fixpoint.

    :param encoding_scheme: passed through to ``encode_input_dataset``.
    :param model: callable applied to a torch_geometric ``Data`` object.
    :param binaryPredicates: list of binary predicate names.
    :param unaryPredicates: list of unary predicate names.
    :param dataset: set of facts. It is not mutated; a new set is built each
        round via ``union``.
    :param query_dataset: passed through to ``encode_input_dataset``.
    :param max_iterations: maximum number of rounds, or None for no limit.
    :param threshold: passed to ``decode`` to select entailed facts.
    :param device: device the encoded graph is moved to.
    :returns: set of all facts decoded across all rounds.
    '''
    num_binary = len(binaryPredicates)
    num_unary = len(unaryPredicates)

    all_entailed_facts = set()
    num_iterations = 1
    while True:
        print("GNN iteration {}".format(num_iterations), end='\r')
        encoded = encode_input_dataset(encoding_scheme, dataset, query_dataset,
                                       binaryPredicates, unaryPredicates)
        # Elsewhere in this file (output_scores) this call is unpacked into 7
        # values, but this function used to unpack 6. Only the first four
        # are needed here, so slice them to accept either arity.
        dataset_x, edge_list, edge_type, node_to_const_dict = encoded[:4]
        test_data = Data(x=dataset_x, edge_index=edge_list,
                         edge_type=edge_type).to(device)
        # Inference only: the model output is decoded into facts, so no
        # autograd graph needs to be recorded.
        with torch.no_grad():
            entailed_facts_encoded = model(test_data)
        entailed_facts_decoded = decode(node_to_const_dict, num_binary,
                                        num_unary, binaryPredicates,
                                        unaryPredicates,
                                        entailed_facts_encoded, threshold)
        if not entailed_facts_decoded - all_entailed_facts:
            # Then no new facts have been entailed
            print('\n')
            print("No change in entailed dataset")
            break
        all_entailed_facts = all_entailed_facts.union(entailed_facts_decoded)
        dataset = dataset.union(entailed_facts_decoded)
        if max_iterations is not None and num_iterations >= max_iterations:
            break
        num_iterations += 1
    return all_entailed_facts


def output_scores(encoding_scheme, model, binaryPredicates, unaryPredicates, incomplete_graph, examples, device='cpu'):
    '''Give the scores for the facts in the query dataset.

    The method has four stages:
    1. Encode ``incomplete_graph`` together with the query facts in ``examples``.
    2. Wrap the encoding in a torch_geometric ``Data`` object on ``device``.
    3. Run ``model`` on it.
    4. Decode the output, keeping only the entries selected by the query
       encoding (the 7th value returned by ``encode_input_dataset``).

    :returns: the set of scored fact strings produced by ``decode_with_scores``.
    '''
    binary_count, unary_count = len(binaryPredicates), len(unaryPredicates)

    print("Encoding input dataset...")
    encoding = encode_input_dataset(encoding_scheme, incomplete_graph, examples,
                                    binaryPredicates, unaryPredicates)
    (node_features, edges, edge_labels,
     node_to_const, _const_to_node, _preds, query_features) = encoding

    print("Encapsulating input data...")
    graph = Data(x=node_features, edge_index=edges, edge_type=edge_labels).to(device)

    print("Applying model to data...")
    model_output = model(graph)

    print("Decoding...")
    scored_facts = decode_with_scores(node_to_const, binary_count, unary_count,
                                      binaryPredicates, unaryPredicates,
                                      model_output, query_features)
    print("Done.")
    return scored_facts
