#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: ----
"""
import torch
from torch_geometric.data import Data

import argparse
import os.path

from utils import predict_entailed_fast, load_predicates, dlog_to_RDF, encode_input_dataset, output_scores

# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser(description="Evaluate a trained GNN")

# --- Which dataset / model to evaluate -------------------------------------
parser.add_argument('--dataset-name',
                    help='Name of the dataset, for saving and loading files')
parser.add_argument('--load-model-name',
                    help='Filename of trained model to load')
parser.add_argument('--encoding-scheme',
                    default='NEC',
                    nargs='?',
                    choices=['NEC', 'EC'],
                    help='Choose whether to encode with edge colours or not')
parser.add_argument('--threshold',
                    type=float,
                    default=0.5,
                    help='Threshold applied to the GNN output scores (default: 0.5)')

# --- What to run ------------------------------------------------------------
parser.add_argument('--test-ground-rules', action='store_true',
                    help='Report number of ground rules learnt')
parser.add_argument('--test-data', action='store_true',
                    help='Report results on test dataset')
parser.add_argument('--verbose', action='store_true',
                    help='Additionally report exact numbers and FNs, FPs')
parser.add_argument('--find-max-threshold', action='store_true',
                    help='Binary search to find max threshold such that all rules are learnt')
parser.add_argument('--binary-search-tolerance',
                    type=float,
                    default=0.0001,
                    help='Tolerance in binary search algorithm for max threshold')

# --- Inputs / outputs of the --test-data run --------------------------------
parser.add_argument('--test-graph',
                    nargs='?',
                    default=None,
                    help='Filename of graph test data')
parser.add_argument('--test-examples',
                    nargs='?',
                    default=None,
                    help='Filename of (positive and negative) examples')
# Parsed as an int so that a supplied value is a number rather than the raw
# command-line string.
# NOTE(review): the evaluation call visible in this file passes
# max_iterations=1 and does not read this option -- confirm whether it should.
parser.add_argument('--max-iterations',
                    nargs='?',
                    type=int,
                    default=None,
                    help='Maximum number of GNN iterations to test. Default is None, for infinite')
parser.add_argument('--print-entailed-facts',
                    default=None,
                    help='Print the facts that have been derived in the provided filename.')
parser.add_argument('--get-scores',
                    action='store_true',
                    help='Rather than using the threshold, get the scores for the facts in the query set')

# Parsed once at module level: print_and_log and the main section below read
# their options from this namespace.
args = parser.parse_args()


def print_and_log(message):
    """Echo ``message`` on stdout and append it, as one line, to the log file
    of the model being evaluated (``./models/<load_model_name>.txt``)."""
    print(message)
    log_path = "./models/{}.txt".format(args.load_model_name)
    with open(log_path, 'a+') as log_file:
        log_file.write(message + '\n')


if __name__ == "__main__":

    print_and_log("Running evaluation of {} using {} encoding, threshold={}".format(args.dataset_name,
                                                                                    args.encoding_scheme,
                                                                                    args.threshold))
    print_and_log("Evaluating model {}".format(args.load_model_name))

    if args.test_data:
        # --test-data needs both input files. They are validated explicitly
        # (rather than with assert, which is stripped under `python -O`) so a
        # missing option or file yields a usage error instead of a TypeError
        # or AssertionError.
        test_data_path = args.test_graph
        test_data_output_path = args.test_examples
        for required_option, required_path in (('--test-graph', test_data_path),
                                               ('--test-examples', test_data_output_path)):
            if required_path is None:
                parser.error("{} is required when --test-data is given".format(required_option))
            if not os.path.exists(required_path):
                parser.error("{} file does not exist: {}".format(required_option, required_path))

    # Load binary and unary predicates into memory
    binaryPredicates, unaryPredicates = load_predicates(args.dataset_name)
    num_binary = len(binaryPredicates)
    num_unary = len(unaryPredicates)
    mask_threshold = num_binary + num_unary

    # Run on the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    @torch.no_grad()
    def evaluate(eval_model, max_iterations=1, threshold=0.5):
        """Evaluate the model on test dataset.

        The input facts are read from ``test_data_path`` and the query
        examples from ``test_data_output_path`` (one fact per line). With
        ``--get-scores`` the facts are handed to ``output_scores``; otherwise
        to ``predict_entailed_fast`` together with ``max_iterations`` and
        ``threshold``. If ``--print-entailed-facts`` names a file, every
        element of the result is written to it, one per line.

        Args:
            eval_model: The model to evaluate; it is switched to eval mode.
            max_iterations: Forwarded to ``predict_entailed_fast`` only; it is
                not used in ``--get-scores`` mode.
            threshold: Forwarded to ``predict_entailed_fast`` only; it is not
                used in ``--get-scores`` mode.

        Returns:
            Whatever the selected helper returned.
        """
        def _read_facts(path):
            """Return the set of lines of ``path`` without their line terminator."""
            with open(path, 'r') as fact_file:
                # Only strip a newline that is actually there: the last line
                # of a file may legitimately lack one, and must not lose its
                # final character.
                return {line[:-1] if line.endswith('\n') else line
                        for line in fact_file}

        eval_model.eval()

        incomplete_dataset = _read_facts(test_data_path)
        examples_dataset = _read_facts(test_data_output_path)

        if args.get_scores:
            GNN_entailed_dataset = output_scores(args.encoding_scheme,
                                                 eval_model,
                                                 binaryPredicates,
                                                 unaryPredicates,
                                                 incomplete_dataset,
                                                 examples_dataset,
                                                 device=device)
        else:
            GNN_entailed_dataset = predict_entailed_fast(args.encoding_scheme,
                                                         eval_model,
                                                         binaryPredicates,
                                                         unaryPredicates,
                                                         incomplete_dataset,
                                                         examples_dataset,
                                                         max_iterations,
                                                         threshold, device=device)
        if args.print_entailed_facts is not None:
            with open(args.print_entailed_facts, 'w') as output:
                for fact in GNN_entailed_dataset:
                    output.write(fact + '\n')

        return GNN_entailed_dataset
    
    # Load the serialized model object and move it to the evaluation device.
    # map_location remaps storages that were saved on another device (e.g. a
    # CUDA checkpoint opened on a CPU-only host), which would otherwise fail
    # to deserialize.
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    # NOTE(review): recent PyTorch releases default to weights_only=True, which
    # rejects fully pickled modules -- confirm against the installed version.
    model = torch.load('./models/' + args.load_model_name + '.pt', map_location=device).to(device)
    model.eval()
    
    def check_ground_rules(model, threshold, verbose=True):
        '''Check what true rules are learnt by this model.

        Reads ``./rules/ground/<dataset_name>_rules.dlog``. The file is
        expected to start with ``PREFIX`` lines, optionally followed by blank
        lines, followed by one rule per line of the form
        ``head :- atom, atom, ... .`` with tokens separated by single spaces.
        For every rule the body is encoded as the input graph and the head as
        the target over the same nodes; the rule counts as learnt when the
        model output at the head's position is strictly greater than
        ``threshold``.

        Args:
            model: The model to query (called as ``model(data)``).
            threshold: Decision threshold for the model output.
            verbose: When true, print and log the number of learnt and
                unlearnt rules.

        Returns:
            A pair ``(learnt_rules, unlearnt_rules)`` of lists holding the
            rule texts (without the final ``.`` token).
        '''
        prefix_dict = {}
        with open("./rules/ground/{}_rules.dlog".format(args.dataset_name)) as f:
            dlog_rules = [line.split(' ') for line in f]
        num_lines = len(dlog_rules)
        unlearnt_rules = []
        learnt_rules = []

        # Leading prefix declarations. The bounds check keeps a file that
        # contains nothing but prefixes from raising IndexError.
        i = 0
        while i < num_lines and dlog_rules[i][0] == 'PREFIX':
            iri = dlog_rules[i][2]
            if iri[-1:] == '\n':
                iri = iri[:-1]
            # Key: second token minus its last character; value: third token
            # without the surrounding < >.
            prefix_dict[dlog_rules[i][1][:-1]] = iri[1:-1]
            i += 1

        # Skip however many more blank lines required
        while i < num_lines and dlog_rules[i][0] == '\n':
            i += 1

        # Now start interpreting dlog rules and converting to RDF triples
        for raw_tokens in dlog_rules[i:]:
            # Tolerate blank lines between or after rules instead of crashing
            # on an empty token list.
            if not ''.join(raw_tokens).strip():
                continue

            # We can always ignore the final element in each array because it's '.\n'
            rule_tokens = raw_tokens[:-1]
            # Since the 2nd element is :-, everything after it is the body;
            # trailing commas (signifying \land) are cleaned off.
            body = [pred[:-1] if pred.endswith(',') else pred
                    for pred in rule_tokens[2:]]

            head = [rule_tokens[0]]

            body = dlog_to_RDF(body, prefix_dict)

            RDF_head = dlog_to_RDF(head, prefix_dict)
            x, edge_list, edge_type, node_dict, _, _, _ = encode_input_dataset(args.encoding_scheme, body, binaryPredicates, unaryPredicates, training=False)
            y, _, _, y_node_dict, _, _, _ = encode_input_dataset(args.encoding_scheme, RDF_head, binaryPredicates, unaryPredicates, node_dict=node_dict, edge_list=edge_list, edge_type_list=edge_type)
            assert(node_dict == y_node_dict)
            data = Data(x=x, y=y, edge_index=edge_list, edge_type=edge_type)
            data = data.to(device)
            # Pure inference: no autograd graph is needed, and this function
            # is called repeatedly by the threshold binary search.
            with torch.no_grad():
                out = model(data)
            # NOTE(review): the head position is located with y >= threshold,
            # which presumably relies on y being a 0/1 encoding and on the
            # threshold lying in (0, 1] -- confirm.
            head_index = torch.nonzero(y>=threshold)
            assert(len(head_index) == 1)
            head_index = head_index[0]
            GNN_value = out[head_index[0], head_index[1]]
            nice_formatted_rule = ' '.join(rule_tokens)
            if GNN_value <= threshold:
                unlearnt_rules.append(nice_formatted_rule)
            else:
                learnt_rules.append(nice_formatted_rule)
        if verbose:
            print_and_log("Number of unlearnt rules = {}".format(len(unlearnt_rules)))
            print_and_log("Number of learnt rules = {}".format(len(learnt_rules)))
        return learnt_rules, unlearnt_rules

    if args.find_max_threshold:
        # Bisect [0, 1] for the largest threshold at which no ground rule is
        # left unlearnt; the interval is halved until it is narrower than the
        # requested tolerance.
        search_lo, search_hi = 0.0, 1.0
        tolerance = args.binary_search_tolerance
        while search_hi - search_lo >= tolerance:
            gap = search_hi - search_lo
            print("Error = {}, tolerance = {}".format(gap, tolerance), end='\r')
            candidate = search_lo + gap / 2.0
            missed_rules = check_ground_rules(model, candidate, verbose=False)[1]
            if missed_rules:
                # Some rule is not learnt at this threshold: look lower.
                search_hi = candidate
            else:
                # Every rule is learnt: a higher threshold may still work.
                search_lo = candidate
        print("Maximum threshold with all rules learnt = {}".format(search_lo))

    if args.test_ground_rules:
        # Reports the learnt / unlearnt counts at the user-supplied threshold.
        check_ground_rules(model, args.threshold)

    if args.test_data:
        print_and_log("Checking fixpoint of GNN")
        evaluate(model, max_iterations=1, threshold=args.threshold)
