from argparse import ArgumentParser
import requests

# Command-line interface. --rules, --dataset and --queries are needed by the
# first phase (generating MGNN queries); --scores and --threshold are only
# read once MGNN has scored those queries, so they stay optional here.
parser = ArgumentParser()
parser.add_argument('--rules', required=True, help="File with all the rules learned by AnyBURL")
parser.add_argument('--dataset', required=True, help="Incomplete graph for testing")
parser.add_argument('--threshold', type=float, help="Threshold selected for testing")
parser.add_argument('--queries', required=True, help="Queries for MeGaNN")
parser.add_argument('--scores', help="Scores for the queries for MeGaNN")


args = parser.parse_args()

# Base URL of the RDFox REST endpoint used for all requests below.
rdfox_server = "http://localhost:8080"
# NOTE(review): not referenced elsewhere in this script; remove if unused.
type_predicate = 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'


def _anyburl_term_to_datalog_term(term):
    """Translate one AnyBURL term to its Datalog form.

    A single uppercase letter (X, Y, A, B, ...) is taken to be an AnyBURL
    variable and becomes ``?X``; any other term is a constant and becomes the
    IRI ``<term>``, the same way dataset entities are written as ``<entity>``.
    NOTE(review): assumes AnyBURL variables are always single uppercase
    letters and no entity is named by a single uppercase letter.

    Raises:
        ValueError: if the term is empty.
    """
    term = term.strip()
    if not term:
        raise ValueError("Empty term in AnyBURL atom")
    if len(term) == 1 and term.isupper():
        return '?' + term
    return '<' + term + '>'


def anyburl_atom_to_datalog_atom(atom):
    """Translate an AnyBURL atom into RDFox Datalog syntax.

    AnyBURL atoms have the form ``string(T1,T2)``, e.g. ``p(X,Y)`` or, for
    rules with constants, ``p(X,/m/09c7w0)``. The result has the form
    ``<string>[t1,t2]``, e.g. ``<p>[?X,?Y]`` or ``<p>[?X,</m/09c7w0>]``.

    Args:
        atom: the AnyBURL atom. Surrounding whitespace (such as a line
            terminator) and a missing closing ')' are tolerated.

    Returns:
        The Datalog atom as a string.

    Raises:
        ValueError: if the atom has no '(' or does not contain two
            comma-separated terms.
    """
    predicate, paren, arguments = atom.strip().partition('(')
    if not paren:
        raise ValueError("Not an AnyBURL atom: {!r}".format(atom))
    # The caller may already have chopped off the final character, so the
    # closing parenthesis is optional.
    if arguments.endswith(')'):
        arguments = arguments[:-1]
    first, comma, second = arguments.partition(',')
    if not comma:
        raise ValueError("Expected two terms in AnyBURL atom: {!r}".format(atom))
    return '<{}>[{},{}]'.format(predicate,
                                _anyburl_term_to_datalog_term(first),
                                _anyburl_term_to_datalog_term(second))


def anyburl_rule_to_datalog_rule(the_rule):
    """Translate an AnyBURL rule into an RDFox Datalog rule.

    AnyBURL rules have the form ``h(X,Y) <= p(X,Z), q(Y,Z)`` (possibly with a
    trailing line terminator). Datalog rules have the form
    ``<h>[?X,?Y] :- <p>[?X,?Z], <q>[?Y,?Z] .`` followed by a newline.

    Args:
        the_rule: the AnyBURL rule text, with or without a line terminator.

    Returns:
        The Datalog rule as a string ending in " .\\n".

    Raises:
        ValueError: if the text contains no " <= " separator.
    """
    # Strip surrounding whitespace rather than blindly dropping the last
    # character: a line without a trailing newline would otherwise lose the
    # closing ')' of its final atom, and CRLF lines would keep a '\r'.
    head, separator, body = the_rule.strip().partition(' <= ')
    if not separator:
        raise ValueError("Not an AnyBURL rule: {!r}".format(the_rule))
    newhead = anyburl_atom_to_datalog_atom(head)
    newbody = [anyburl_atom_to_datalog_atom(atom) for atom in body.split(', ')]
    return newhead + ' :- ' + ', '.join(newbody) + ' .\n'


def assert_response_ok(response, message):
    """Raise an Exception when a REST call to the RDFox endpoint failed.

    Args:
        response: the HTTP response to inspect; only its ``ok``,
            ``status_code`` and ``text`` attributes are read.
        message: description of the failed operation, placed at the start
            of the exception text.

    Raises:
        Exception: if ``response.ok`` is false. The text is *message*
            followed by the received status code and the response body.
    """
    if response.ok:
        return
    status_report = "\nStatus received={}\n{}".format(response.status_code, response.text)
    raise Exception(message + status_report)


# Read the dataset (one tab-separated triple per line) and keep every triple
# as a ready-to-upload fact of the form "<s> <p> <o> .\n". The file is closed
# as soon as it has been read.
dataset = []
with open(args.dataset, 'r') as dataset_file:
    for line in dataset_file:
        # Drop the line terminator ('\n' or '\r\n') so it does not leak into
        # the last IRI; blank lines (e.g. at end of file) carry no fact.
        line = line.rstrip('\r\n')
        if not line:
            continue
        ent1, ent2, ent3 = line.split('\t', 2)
        dataset.append('<{}> <{}> <{}> .\n'.format(ent1, ent2, ent3))
print("{} facts read".format(len(dataset)))

# Read the AnyBURL rules file and translate every rule to Datalog. Each
# Datalog rule is mapped to the set of facts it entails on the dataset; the
# sets start empty and are filled in when the rules are applied.
rules_to_entailed_facts = {}
with open(args.rules, "r") as rules_file:
    for line in rules_file:
        # Skip blank lines, which have no fourth column to read.
        if not line.strip():
            continue
        # The rule itself is in the fourth tab-separated column.
        rule = anyburl_rule_to_datalog_rule(line.split('\t', 3)[3])
        rules_to_entailed_facts[rule] = set()
print("{} rules read".format(len(rules_to_entailed_facts)))

print("Applying rules to dataset...")

# Prefixes RDFox adds in front of the relative IRIs we upload. They are
# removed (the first, then the second) so entity names match the raw strings
# that MeGaNN has learned from.
RDFOX_IRI_PREFIX = "http://oxfordsemantic.tech/RDFox/"
RDFOX_IRI_BASE = "http://oxfordsemantic.tech"


def _strip_rdfox_prefix(iri):
    """Return *iri* with the RDFox-added prefixes removed, if present."""
    if iri.startswith(RDFOX_IRI_PREFIX):
        iri = iri[len(RDFOX_IRI_PREFIX):]
    if iri.startswith(RDFOX_IRI_BASE):
        iri = iri[len(RDFOX_IRI_BASE):]
    return iri


def _datastore_exists(name):
    """Return True if the RDFox server lists a datastore called *name*.

    The listing is read as tab-separated text with a header line ('?Name');
    the first column of every following non-empty row is a datastore name,
    possibly wrapped in double quotes.
    NOTE(review): quoting inferred from the TSV format -- confirm against
    the RDFox version in use.
    """
    listing = requests.get(rdfox_server + "/datastores")
    assert_response_ok(listing, "Failed to list datastores.")
    rows = listing.text.split('\n')[1:]
    return name in {row.split('\t')[0].strip('"') for row in rows if row}


# The dataset is the same for every rule: serialise it once here instead of
# rebuilding the whole string by repeated concatenation for each rule.
dataset_payload = "".join(dataset)

# Compute the set of facts entailed by each rule over the dataset.
for counter, rule in enumerate(rules_to_entailed_facts, start=1):
    if counter % 100 == 0:
        print("Rules processed: {}".format(counter))

    # Start every rule from a datastore holding only the dataset: create the
    # 'temp' datastore if the server does not have it yet, else wipe it.
    if _datastore_exists("temp"):
        response = requests.delete(rdfox_server + "/datastores/temp/content")
        assert_response_ok(response, "Failed to clear content from datastore.")
    else:
        response = requests.post(rdfox_server + "/datastores/temp", params={'type': 'par-complex-nn'})
        assert_response_ok(response, "Failed to create datastore.")

    # Send the dataset, then the single rule, to the datastore.
    response = requests.post(rdfox_server + "/datastores/temp/content", data=dataset_payload)
    assert_response_ok(response, "Failed to add dataset to datastore.")
    response = requests.post(rdfox_server + "/datastores/temp/content", data=rule)
    assert_response_ok(response, "Failed to add rule {} to datastore.".format(rule))

    # Every Datalog rule starts with '<predicate>[', so the head predicate is
    # the text before the first '['. Ask for all subject/object pairs of it.
    # NOTE(review): this matches every fact with that predicate in the store,
    # i.e. the dataset's own facts as well as those derived by the rule --
    # confirm that is the intended notion of "entailed".
    head_predicate = rule.split('[')[0]
    sparql_text = "SELECT ?p ?q WHERE {{ ?p {} ?q  }}".format(head_predicate)
    response = requests.get(rdfox_server + "/datastores/temp/sparql", params={"query": sparql_text})
    assert_response_ok(response, "Failed to run return entailed facts.")

    # Skip the header line (?p, ?q) and any empty line; every other line is
    # "<subject>\t<object>". The '<' and '>' are removed because MeGaNN has
    # learned without them.
    relation = head_predicate[1:-1]
    entailed = rules_to_entailed_facts[rule]
    for answer in response.text.split('\n')[1:]:
        if not answer:
            continue
        subject, obj = answer.split('\t')[:2]
        entailed.add("{}\t{}\t{}".format(_strip_rdfox_prefix(subject[1:-1]),
                                         relation,
                                         _strip_rdfox_prefix(obj[1:-1])))

print("Writing queries for MGNN...")
# Every fact entailed by any rule is a query for MGNN. Collect them without
# duplicates and write them sorted, so the file is identical across runs
# (the iteration order of a set of strings changes between interpreter runs).
all_facts = set()
for entailed in rules_to_entailed_facts.values():
    all_facts.update(entailed)
with open(args.queries, 'w') as queries:
    queries.writelines(fact + '\n' for fact in sorted(all_facts))

# Second phase: the scores file is produced by MGNN from the queries file
# written above, so it does not exist on the first run. Stop with a clear
# message in that case (or when a needed argument is missing) instead of a
# traceback; rerun with --scores and --threshold once MGNN has scored them.
if args.scores is None:
    raise SystemExit("No --scores given: score the queries in {} with MGNN, "
                     "then rerun with --scores and --threshold.".format(args.queries))
if args.threshold is None:
    raise SystemExit("--threshold is required to count the rules captured by MGNN.")
threshold = float(args.threshold)

print("Reading scores file...")
# Each line is: subject <TAB> relation <TAB> object <TAB> score. The first
# three columns are looked up against the facts written to the queries file.
facts_to_scores = {}
try:
    scores_file = open(args.scores, 'r')
except FileNotFoundError:
    raise SystemExit("Scores file {} not found: score the queries in {} with MGNN "
                     "first.".format(args.scores, args.queries)) from None
with scores_file:
    for line in scores_file:
        line = line.rstrip('\r\n')
        if not line:
            continue
        ent1, ent2, ent3, score = line.split('\t', 3)
        facts_to_scores["{}\t{}\t{}".format(ent1, ent2, ent3)] = float(score)

print("Checking number of rules captured on dataset...")
# A rule is captured when none of the facts it entails scores below the
# threshold; a fact missing from the scores file counts as score 0. A rule
# that entails no facts is therefore captured trivially.
counter = sum(
    1
    for entailed in rules_to_entailed_facts.values()
    if not any(facts_to_scores.get(fact, 0) < threshold for fact in entailed)
)

print("MGNN captures {} rules out of {} rules learned by AnyBURL on this dataset".format(counter,
                                                                                         len(rules_to_entailed_facts)))
