#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: ----
"""

import torch
from torch_geometric.data import Data

from utils import encode_input_dataset, decode_and_get_threshold, load_predicates

import argparse

# Command-line interface. Arguments are parsed at module level, so `args` is
# available to everything below.
parser = argparse.ArgumentParser(
    description="Check rules captured by trained GNNs",
)
# Dataset whose predicates are loaded via `load_predicates`.
parser.add_argument(
    "--dataset-name",
    help="Name of the dataset, for saving and loading files",
)
# Encoding scheme handed to `encode_input_dataset`; restricted to NEC / EC.
parser.add_argument(
    "--encoding-scheme",
    choices=["NEC", "EC"],
    nargs="?",
    default="EC",
    help="Choose whether to encode with edge colours or not",
)
# Threshold forwarded to `decode_and_get_threshold`.
parser.add_argument(
    "--threshold",
    default=0.5,
    type=float,
    help="Maximum threshold in the last layer of the model for which the rules are guaranteed to be captured",
)
# Base name (no extension) of the model file under ./models/.
parser.add_argument(
    "--load-model-name",
    help="Filename to load trained model",
)
# Path of the tab-separated rules file to check.
parser.add_argument(
    "--rules",
    help="File with all the rules we want to check",
)

args = parser.parse_args()

if __name__ == "__main__":

    # Run on the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
    # code. Only load model files that come from a trusted source.
    # `map_location` remaps the stored tensors onto the selected device, so a
    # model saved on a GPU machine can still be loaded on a CPU-only one.
    model = torch.load("./models/{}.pt".format(args.load_model_name),
                       map_location=device)
    model.to(device)

    # Load binary and unary predicates into memory
    binaryPredicates, unaryPredicates = load_predicates(args.dataset_name)
    num_binary = len(binaryPredicates)
    num_unary = len(unaryPredicates)
    # Total number of predicates (binary + unary).
    mask_threshold = num_binary + num_unary


    def rule_atom_to_fact(atom):
        """Convert a rule atom ``rel(ent1,ent2)`` into the fact string ``ent1 rel ent2``.

        The relation is the text before the first '('. The entities are the
        first two comma-separated items of the text between that '(' and the
        next ')'.

        Raises:
            IndexError: if the atom has no '(' or fewer than two
                comma-separated arguments.
        """
        pieces = atom.split('(')
        relation = pieces[0]
        arguments = pieces[1].split(')')[0].split(',')
        subject, obj = arguments[0], arguments[1]
        return ' '.join((subject, relation, obj))


    def process_rule(input_rule):
        """Split a textual rule ``head <= atom1 atom2 ...`` into head and body facts.

        Args:
            input_rule: rule string whose whitespace-separated tokens are the
                head atom, the literal ``<=``, and zero or more body atoms.

        Returns:
            A pair ``(head, body)`` where ``head`` is the fact string produced
            by ``rule_atom_to_fact`` for the head atom and ``body`` is the set
            of fact strings for the body atoms (empty if there are none).

        Raises:
            ValueError: if the second token of the rule is not ``<=``.
        """
        # split() without an argument ignores repeated/trailing whitespace, so
        # no empty "atoms" are handed to rule_atom_to_fact.
        atoms = input_rule.split()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(atoms) < 2 or atoms[1] != '<=':
            raise ValueError(
                "Malformed rule, expected 'head <= body': {!r}".format(input_rule))
        the_head = rule_atom_to_fact(atoms[0])
        the_body = {rule_atom_to_fact(atom) for atom in atoms[2:]}
        return the_head, the_body


    def _explores_as_tree(graph, root, seen):
        """Depth-first walk from `root`; return False if a seen entity is reached again.

        `graph` maps each entity to the set of its neighbours and is mutated:
        when an edge a -> b is followed, the reverse edge b -> a is removed so
        that the parent is not mistaken for a cycle. `seen` is updated in place
        with every entity reached.
        """
        frontier = [root]
        while frontier:
            selected = frontier.pop()
            for ent in graph[selected]:
                if ent in seen:
                    return False
                # Since we are exploring connection a to b, remove connection b to a
                graph[ent].remove(selected)
                frontier.append(ent)
                seen.add(ent)
        return True

    def rule_is_tree_like(the_head, the_body):
        """Check whether the body of a rule forms trees rooted at the head's entities.

        An undirected graph is built over the entities of the body facts.
        Edges between the head's own entities are ignored (the self-loop on x
        when the head is ``x rel x``, the x-y edge when it is ``x rel y``).
        The graph is then explored from each head entity. The rule is
        rejected if any entity is reached twice, i.e. there is a cycle or a
        path connecting the two head entities. Entities that are not
        reachable from a head entity are not inspected.

        Args:
            the_head: head fact as ``"ent1 rel ent2"``.
            the_body: iterable of body facts, each as ``"ent1 rel ent2"``.

        Returns:
            True if every exploration finishes without revisiting an entity,
            False otherwise.
        """
        head_ent1, _, head_ent2 = the_head.split(' ')
        # Each node is a key, and its value is the set of its neighbours.
        # Entities are added as whole strings: set("X1") would instead give
        # the characters {'X', '1'}.
        graph = {}
        graph.setdefault(head_ent1, set())
        graph.setdefault(head_ent2, set())
        for atom in the_body:
            ent1, _, ent2 = atom.split(' ')
            graph.setdefault(ent1, set()).add(ent2)
            graph.setdefault(ent2, set()).add(ent1)
        # Remove edges among the head entities themselves. Index by the head
        # entities explicitly, not by the variables left over from the loop.
        # discard() is a no-op when the body has no such edge.
        graph[head_ent1].discard(head_ent2)
        graph[head_ent2].discard(head_ent1)
        seen = {head_ent1, head_ent2}
        # Explore once per distinct head entity. A second walk from the same
        # root would see its own children as already visited.
        roots = [head_ent1] if head_ent1 == head_ent2 else [head_ent1, head_ent2]
        for root in roots:
            if not _explores_as_tree(graph, root, seen):
                return False
        return True


    # Read the rules up front (closing the file afterwards) so the total can
    # be reported before processing. Blank lines carry no rule and are skipped.
    with open(args.rules, 'r') as rules_file:
        input_rules = [line for line in rules_file if line.strip()]
    print("Total number of rules to test: {}".format(len(input_rules)))
    counter = 0
    counter_treelike = 0
    counter_yes = 0
    counter_yes_treelike = 0
    # Rules whose head is derived by the model are written to this file.
    with open(args.rules + '_entailed_general.txt', 'w') as output:
        for line in input_rules:
            counter += 1
            if counter % 250 == 0:
                print("Rules processed: {}".format(counter))
            # The rule is the fourth tab-separated field. Strip only an actual
            # end-of-line character, so a final line without a trailing
            # newline does not lose its last character.
            rule = line.split('\t')[3].rstrip('\n')
            head, body = process_rule(rule)
            # Computed once and reused for both tree-like counters.
            is_tree_like = rule_is_tree_like(head, body)
            if is_tree_like:
                counter_treelike += 1
            # Encode the rule body as the input dataset, with no query facts.
            empty_query_set = set()
            x, edge_list, edge_type, node_dict, _, _, _ = encode_input_dataset(args.encoding_scheme,
                                                                               body,
                                                                               empty_query_set,
                                                                               binaryPredicates,
                                                                               unaryPredicates)
            data = Data(x=x, edge_index=edge_list, edge_type=edge_type).to(device)
            # Find what facts are predicted from this encoded graph
            GNN_entailed_decoded = decode_and_get_threshold(node_dict, num_binary, num_unary,
                                                            binaryPredicates, unaryPredicates,
                                                            model(data), args.threshold)
            # Check if the target fact has been derived. Each element is a
            # pair whose first item is the derived head and whose second is
            # the threshold.
            rule_entailed = any(derived[0] == head for derived in GNN_entailed_decoded)
            if rule_entailed:
                counter_yes += 1
                if is_tree_like:
                    counter_yes_treelike += 1
                output.write(rule + '\n')
    # The `with` block has already closed the output file.
    print("Total number of rules covered by MeGaNN: {}".format(counter_yes))
    print("Total number of tree-like rules: {}".format(counter_treelike))
    print("Total number of tree-like rules covered by MeGaNN: {}".format(counter_yes_treelike))
