#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: ----
"""

import torch
from torch_geometric.data import Data

from utils import encode_input_dataset, decode_and_get_threshold, load_predicates

import requests

import argparse

from itertools import combinations_with_replacement, product, permutations

from math import factorial

from tqdm import tqdm

import re

parser = argparse.ArgumentParser(description="Extract rules from trained GNNs")
# Both names are used to build file paths; without them the script only fails much later
# (e.g. loading "./models/None.pt"), so make argparse reject the call up front.
parser.add_argument('--dataset-name',
                    required=True,
                    help='Name of the dataset, for saving and loading files')
parser.add_argument('--encoding-scheme',
                    default='EC',
                    nargs='?',
                    choices=['NEC', 'EC'],
                    help='Choose whether to encode with edge colours or not')
parser.add_argument('--threshold',
                    type=float,
                    default=0.5,
                    help='Maximum threshold in the last layer of the model for which the rules are guaranteed to be captured')
parser.add_argument('--load-model-name',
                    required=True,
                    help='Filename to load trained model')
# NOTE(review): --prefix is parsed but not read anywhere in this script -- confirm whether it is still needed.
parser.add_argument('--prefix',
                    help='Prefix that should be added to the extracted rules.',
                    default=None)
parser.add_argument('--max-atoms-in-body',
                    type=int,
                    default=2,
                    help='Maximum number of atoms in a candidate rule body (must be a positive integer)')

args = parser.parse_args()

if __name__ == "__main__":

    # Validate explicitly: unlike `assert`, parser.error is not stripped under `python -O`
    # and reports the problem as an ordinary command-line usage error.
    if args.max_atoms_in_body <= 0:
        parser.error("Number of atoms must be a positive integer")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The file is loaded as a whole model object (it is moved with .to() and called
    # directly), not a state_dict.  PyTorch >= 2.6 defaults to weights_only=True, which
    # refuses such files, so opt out wherever the keyword exists (PyTorch >= 1.13).
    # map_location lets a model saved on a GPU be loaded on a CPU-only machine.
    # NOTE: unpickling can execute arbitrary code -- only load model files you trust.
    import inspect
    load_kwargs = {'map_location': device}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load("./models/{}.pt".format(args.load_model_name), **load_kwargs)
    model.to(device)
    # The model is only used for inference: put dropout / batch-norm layers (if any)
    # into evaluation mode so outputs are deterministic.
    model.eval()

    # Load binary and unary predicates into memory
    binaryPredicates, unaryPredicates = load_predicates(args.dataset_name)
    num_binary = len(binaryPredicates)
    num_unary = len(unaryPredicates)
    mask_threshold = num_binary + num_unary

    # Counters reported at the end of the run.
    total_num_rules = 0
    num_rules_duplicate_variable = 0

    def RDF_to_dlog(RDF_string):
        '''Turn one RDF triple string into a Datalog atom string.

        The triple is split on spaces that do not sit inside a quoted literal.
        The variable names are the subject (and object) with their first and
        last characters (the angle brackets) removed, upper-cased.

        * ``<s> <rdf:type> C .``  ->  ``<C>[?S]``   (class token used verbatim)
        * ``<s> p <o> .``         ->  ``<p>[?S,?O]`` (predicate token used verbatim)
        '''
        parts = re.split(''' (?=(?:[^'"]|'[^']*'|"[^"]*")*$)''', RDF_string)
        subject_var = parts[0][1:-1].upper()
        if parts[1] != '<http://www.w3.org/1999/02/22-rdf-syntax-ns#type>':
            object_var = parts[2][1:-1].upper()
            return "<{}>[?{},?{}]".format(parts[1], subject_var, object_var)
        return "<{}>[?{}]".format(parts[2], subject_var)


    def create_rule_body(predicates, variables, pred_dict):
        """Build the RDF triples of one candidate rule body.

        Args:
            predicates: the body's predicates, in order.
            variables: one variable name per argument slot, consumed left to
                right (one per unary predicate, two per binary predicate).
            pred_dict: maps each predicate to its arity (must be 1 or 2).

        Returns:
            The set of triple strings making up the body, or
            ``'Duplicate'`` if a binary atom would use the same variable in
            both positions (a self-loop, so the body is not tree shaped), or
            ``None`` if the same atom would appear twice or if
            ``no_loose_variables`` rejects the body.
        """
        pool = list(variables)
        cursor = 0
        triples = set()
        atom_variables = []  # per-atom variable lists, for the connectivity check
        for predicate in predicates:
            arity = pred_dict[predicate]
            assert arity == 1 or arity == 2, "Arity must be 1 or 2"
            if arity == 2:
                subject, obj = pool[cursor], pool[cursor + 1]
                cursor += 2
                if subject == obj:
                    # Same variable twice in one binary atom.
                    return 'Duplicate'
                atom_variables.append([subject, obj])
                triple = "<{}> {} <{}> .".format(subject, predicate, obj)
            else:
                subject = pool[cursor]
                cursor += 1
                atom_variables.append([subject])
                triple = "<{}> <http://www.w3.org/1999/02/22-rdf-syntax-ns#type> {} .".format(subject, predicate)
            if triple in triples:
                # The identical atom already occurs in this body.
                return None
            triples.add(triple)
        return triples if no_loose_variables(atom_variables) else None
    
    def no_loose_variables(variables_in_body):
        """Return True if the atoms of a candidate body form one connected group.

        ``variables_in_body`` holds, for each body atom, the list of its
        variables.  Two atoms are linked when they share a variable; the body is
        accepted only if every atom can be reached from the first atom through
        such links.  A body made of several disconnected pieces is a forest
        rather than a tree-shaped rule body, so it is rejected -- this includes
        the case of two multi-atom components, which a per-atom "shares a
        variable with some other atom" test would miss.  Bodies with zero or
        one atom are trivially connected.
        """
        if len(variables_in_body) <= 1:
            return True
        reached = set(variables_in_body[0])   # variables of atoms connected to atom 0
        pending = list(variables_in_body[1:])  # atoms not yet shown to be connected
        grew = True
        # Repeatedly absorb pending atoms that touch the reached variables until
        # nothing changes; any atom left over lies in a separate component.
        while pending and grew:
            grew = False
            unreached = []
            for atom in pending:
                if reached.intersection(atom):
                    reached.update(atom)
                    grew = True
                else:
                    unreached.append(atom)
            pending = unreached
        return not pending

    
    class node:
        """Minimal linked cell: ``data`` is the payload, ``prev`` the cell it links back to."""
        def __init__(self, dataval=None, prevval=None):
            self.prev = prevval
            self.data = dataval
    def rhyme_scheme_combinations(iterable, r):
        """Yield every "rhyme scheme" of length ``r`` drawn from ``iterable``.

        A rhyme scheme (restricted growth string) fills ``r`` positions with
        items of the pool such that position 0 holds ``pool[0]`` and each later
        position holds either an item already used or the next unused one.
        Schemes correspond one-to-one to the set partitions of the ``r``
        positions and are yielded in lexicographic order of their pool indices,
        e.g. for pool ``'abc'`` and ``r = 3``: aaa, aab, aba, abb, abc.

        Nothing is yielded when ``r`` exceeds the pool size; ``r == 0`` yields
        the empty tuple once (as ``itertools.combinations`` does).

        Raises:
            ValueError: if ``r`` is negative.
        """
        if r < 0:
            raise ValueError("r must be non-negative")
        pool = tuple(iterable)
        if r > len(pool):
            return
        indices = [0] * r

        def _extend(position, highest):
            # indices[:position] is a valid prefix whose largest value is `highest`;
            # the next position may reuse any value seen so far or open value highest + 1.
            if position >= r:
                yield tuple(pool[i] for i in indices)
                return
            for value in range(highest + 2):
                indices[position] = value
                yield from _extend(position + 1, max(highest, value))

        # Position 0 is always pool[0], so filling starts at position 1.
        yield from _extend(1, 0)
    
    def is_body_subset(head_dlogBodies, entailing_body, two_vars_in_head):
        """Return True if ``entailing_body`` occurs inside ``head_dlogBodies``.

        Every predicate-aligned selection of atoms from ``head_dlogBodies``
        (see ``match_predicates``) is tried; the answer is True as soon as one
        of them also matches ``entailing_body`` under a consistent one-to-one
        variable renaming that keeps the head variables fixed (see
        ``can_match_variables``).
        """
        candidates = match_predicates(head_dlogBodies, entailing_body)
        return any(can_match_variables(candidate, entailing_body, two_vars_in_head)
                   for candidate in candidates)
    
    def match_predicates(head_dlogBodies, entailing_body):
        """List every predicate-aligned selection of atoms from ``head_dlogBodies``.

        Each returned list picks distinct atoms of ``head_dlogBodies`` such that
        its k-th atom has the same predicate (the text before ``[``) as
        ``entailing_body[k]``.  ``entailing_body`` must be non-empty.
        """
        wanted = entailing_body[0].partition("[")[0]
        rest = entailing_body[1:]
        matches = []
        for idx, atom in enumerate(head_dlogBodies):
            if atom.partition("[")[0] != wanted:
                continue
            if not rest:
                matches.append([atom])
                continue
            # Remaining atoms may not reuse the one just chosen.
            remaining = head_dlogBodies[:idx] + head_dlogBodies[idx + 1:]
            matches.extend([atom] + tail for tail in match_predicates(remaining, rest))
        return matches
    
    def can_match_variables(subsetBody, entailing_body, two_vars_in_head):
        """Check that two predicate-aligned bodies agree up to variable renaming.

        ``subsetBody[k]`` is compared with ``entailing_body[k]``; a variable is
        the single character following each ``?``.  Walking the variables
        position by position, a renaming from ``subsetBody``'s variables to
        ``entailing_body``'s is built, and True is returned only if it is a
        consistent bijection.  ``A`` -- and ``B`` when ``two_vars_in_head`` is
        true -- must map to themselves.
        """
        pinned = ('A', 'B') if two_vars_in_head else ('A',)
        forward = {name: name for name in pinned}   # subsetBody var -> entailing_body var
        backward = {name: name for name in pinned}  # entailing_body var -> subsetBody var
        for position, target_atom in enumerate(entailing_body):
            source_parts = subsetBody[position].split('?')
            target_parts = target_atom.split('?')
            for k in range(1, len(source_parts)):
                src = source_parts[k][0]
                dst = target_parts[k][0]
                src_seen = src in forward
                if src_seen != (dst in backward):
                    # One side is already bound and the other is not.
                    return False
                if not src_seen:
                    forward[src] = dst
                    backward[dst] = src
                elif forward[src] != dst or backward[dst] != src:
                    return False
        return True
    
    # Credit to this post: https://stackoverflow.com/questions/6116978/how-to-replace-multiple-substrings-of-a-string
    def multiple_replace(string, rep_dict):
        """Replace every key of ``rep_dict`` in ``string`` by its value, simultaneously.

        All keys are matched in one left-to-right pass, trying longer keys first
        at each position, so replacements never cascade -- e.g.
        ``{"?A": "?B", "?B": "?A"}`` swaps the two variables.  An empty
        ``rep_dict`` returns ``string`` unchanged.
        """
        if not rep_dict:
            # An empty alternation would match "" everywhere and then fail the dict lookup.
            return string
        pattern = re.compile("|".join(re.escape(k) for k in sorted(rep_dict, key=len, reverse=True)),
                             flags=re.DOTALL)
        return pattern.sub(lambda match: rep_dict[match.group(0)], string)
        
    # Arity of each predicate; a name listed as both binary and unary ends up unary.
    predicates_dict = {}
    for pred in binaryPredicates:
        predicates_dict[pred] = 2
    for pred in unaryPredicates:
        predicates_dict[pred] = 1

    # Every rule written to the output file during this run.
    novel_rules = set()
    # Head atom (Datalog string, variables renamed to ?A/?B) -> bodies already known to
    # derive it.  Each list is seeded with the single-atom body [head]: a candidate body
    # containing the head atom itself only yields a tautology, so it is never new.
    heads_dict = {}
    num_predicates = len(predicates_dict)
    threshold_tag = int(args.threshold * 10)
    rules_path = './rules/extracted/{}_{}_extrarules.txt'.format(args.load_model_name, threshold_tag)
    thresholds_path = './rules/extracted/{}_{}_extrarules_thresholds.txt'.format(args.load_model_name, threshold_tag)
    # Splits a triple on spaces that lie outside quoted literals.
    triple_splitter = re.compile(''' (?=(?:[^'"]|'[^']*'|"[^"]*")*$)''')
    rdf_type = '<http://www.w3.org/1999/02/22-rdf-syntax-ns#type>'

    def _swap_with_a(atoms, var):
        """Swap variable ?var with ?A in every atom (using ?# as a temporary name)."""
        return [atom.replace("?{}".format(var), "?#").replace("?A", "?{}".format(var)).replace("?#", "?A")
                for atom in atoms]

    # The `with` flushes and closes both output files even if extraction fails part-way;
    # no_grad() skips autograd bookkeeping, since the model is only used for inference here.
    with open(rules_path, 'a+', encoding='utf-8') as w, \
            open(thresholds_path, 'a+', encoding='utf-8') as w_threshold, \
            torch.no_grad():
        for body_length in range(1, args.max_atoms_in_body + 1):
            print("Searching rule bodies of length {}".format(body_length))
            # Number of multisets of size body_length over the predicates, C(n + k - 1, k);
            # integer division keeps the progress-bar total exact.
            if num_predicates > 0:
                total_num_comb_with_replacement = (factorial(num_predicates + body_length - 1)
                                                   // (factorial(body_length) * factorial(num_predicates - 1)))
            else:
                total_num_comb_with_replacement = 0
            for body_predicates in tqdm(combinations_with_replacement(predicates_dict.keys(), r=body_length),
                                        total=total_num_comb_with_replacement):
                max_num_variables = sum(predicates_dict[p] for p in body_predicates)
                variables_string = ''.join(chr(ord('a') + i) for i in range(max_num_variables))
                # Each rhyme scheme decides which argument slots share a variable.
                for variables in rhyme_scheme_combinations(variables_string, r=max_num_variables):
                    total_num_rules += 1
                    body = create_rule_body(body_predicates, variables, predicates_dict)
                    if body == 'Duplicate':
                        num_rules_duplicate_variable += 1
                        continue
                    if body is None:
                        continue

                    # Encode the body as a graph and decode the facts the GNN derives from it.
                    x, edge_list, edge_type, node_dict, _, _, _ = encode_input_dataset(args.encoding_scheme,
                                                                                      body,
                                                                                      set(),
                                                                                      binaryPredicates,
                                                                                      unaryPredicates)
                    data = Data(x=x, edge_index=edge_list, edge_type=edge_type).to(device)
                    GNN_entailed_decoded = decode_and_get_threshold(node_dict, num_binary, num_unary,
                                                                    binaryPredicates, unaryPredicates,
                                                                    model(data), args.threshold)
                    # Sorted so the atom order in written rules does not depend on
                    # (hash-randomised) set iteration order.
                    dlogBodies = [RDF_to_dlog(atom) for atom in sorted(body)]

                    for head_threshold in GNN_entailed_decoded:
                        head = head_threshold[0] + ' .'
                        extracted_threshold = head_threshold[1]
                        head_string = RDF_to_dlog(head)
                        head_list = triple_splitter.split(head)
                        two_vars_in_head = False
                        # Rename the head variables to ?A (and ?B) and permute the body's
                        # variables accordingly, so bodies of the same head are comparable.
                        if head_list[1] == rdf_type:
                            head_var = head_list[0][1:-1].upper()
                            head_string = head_string.replace("?{}".format(head_var), "?A")
                            head_dlogBodies = _swap_with_a(dlogBodies, head_var)
                        else:
                            head_var1 = head_list[0][1:-1].upper()
                            head_var2 = head_list[2][1:-1].upper()
                            if head_var1 == head_var2:
                                head_string = head_string.replace("?{}".format(head_var1), "?A")
                                head_dlogBodies = _swap_with_a(dlogBodies, head_var1)
                            else:
                                two_vars_in_head = True
                                var1 = "?{}".format(head_var1)
                                var2 = "?{}".format(head_var2)
                                head_string = multiple_replace(head_string, {var1: "?A", var2: "?B"})
                                # Permutation of body variables sending var1 -> ?A and var2 -> ?B.
                                if head_var1 == "B" and head_var2 != "A":
                                    renaming = {"?B": "?A", var2: "?B", "?A": var2}
                                elif head_var2 == "A" and head_var1 != "B":
                                    renaming = {"?A": "?B", var1: "?A", "?B": var1}
                                else:
                                    # Identity, plain A<->B swap, or two disjoint swaps;
                                    # repeated keys collapse (the later entry wins).
                                    renaming = {var1: "?A", var2: "?B", "?A": var1, "?B": var2}
                                head_dlogBodies = [multiple_replace(atom, renaming) for atom in dlogBodies]

                        known_bodies = heads_dict.setdefault(head_string, [[head_string]])
                        # Compare against every known body, including the [head] seed on a
                        # head's first occurrence, so tautological rules are never written.
                        if any(is_body_subset(head_dlogBodies, known_body, two_vars_in_head)
                               for known_body in known_bodies):
                            continue
                        known_bodies.append(head_dlogBodies)
                        dlog_rule = head_string + " :- " + ', '.join(head_dlogBodies) + " .\n"
                        novel_rules.add(dlog_rule)
                        w.write(dlog_rule)
                        w_threshold.write(str(float(extracted_threshold)) + '\n')
    print("Total number of rules = {}".format(total_num_rules))
    print("Number of rules pruned due to containing a duplicate variable = {}".format(num_rules_duplicate_variable))
