import argparse
import copy
import math
import os
import random
from enum import Enum
from typing import List, Optional
from itertools import combinations, product

import torch
from torch_geometric.data import Data
from tqdm import tqdm
from encoding_schemes import default_type_pred

import gnn_architectures
from model_sparsity import weight_cutoff_model, max_weight_size_in_model
from encoding_schemes import ICLREncoderDecoder, CanonicalEncoderDecoder
import data_parser


def is_monotonic_rule_string_captured(
        model: gnn_architectures.GNN,
        threshold: float,
        can_encoder_decoder: CanonicalEncoderDecoder,
        rule: str,
):
    """Parse a textual monotonic rule and check whether the model captures it.

    The expected format is ``<body> implies <head>``, where the body is one or
    more atoms joined by `` and ``. Atoms look like ``P(X)`` (unary) or
    ``R(Y,X)`` (binary); variables are read as single characters next to the
    brackets/comma. Unary atoms are turned into triples using
    ``default_type_pred`` as the predicate and the atom's predicate as object.

    Args:
        model: GNN forwarded to ``is_monotonic_rule_captured``.
        threshold: decoding threshold forwarded to ``is_monotonic_rule_captured``.
        can_encoder_decoder: encoder/decoder forwarded to ``is_monotonic_rule_captured``.
        rule: the rule string to parse.

    Returns:
        The result of ``is_monotonic_rule_captured`` for the parsed body and head.

    Raises:
        AssertionError: if the string does not match the supported format
            (negation in the body, arity above two, non-unary head, ...).
    """
    halves = rule.split(' implies ')
    assert len(halves) == 2, 'Should only be body and head'
    body_text, head_text = halves
    assert ' not ' not in body_text, 'Only monotonic rules can be extracted'

    # translate every body atom into a (subject, predicate, object) triple
    rule_body = []
    for body_atom in body_text.split(' and '):
        arguments = body_atom.split(',')
        assert len(arguments) <= 2, 'Only unary and binary predicates supported'
        if len(arguments) == 2:
            # binary atom: variables sit directly either side of the comma
            first_var, second_var = arguments[0][-1], arguments[1][0]
            name_and_rest = body_atom.split('(')
            assert len(name_and_rest) == 2, 'Should only be a single opening bracket in binary body atom'
            rule_body.append((first_var, name_and_rest[0], second_var))
        else:
            # unary atom: the variable is the character before the closing bracket
            variable = body_atom[-2]
            name_and_rest = body_atom.split('(')
            assert len(name_and_rest) == 2, 'Should only be a single opening bracket in unary body atom'
            rule_body.append((variable, default_type_pred, name_and_rest[0]))

    # the head must be a single unary atom
    head_arguments = head_text.split(',')
    assert len(head_arguments) == 1, 'Only unary predicates supported in the head'
    head_variable = head_text[-2]
    head_name_and_rest = head_text.split('(')
    assert len(head_name_and_rest) == 2, 'Should only be a single opening bracket in the head'
    rule_head = (head_variable, default_type_pred, head_name_and_rest[0])

    return is_monotonic_rule_captured(model, threshold, can_encoder_decoder, rule_body, rule_head)


def is_monotonic_rule_captured(
        model,
        threshold,
        can_encoder_decoder,
        rule_body,
        rule_head,
):
    """Check whether the model derives ``rule_head`` from the facts in ``rule_body``.

    The body atoms are encoded as a graph with ``can_encoder_decoder``, passed
    through ``model``, and the output is decoded using ``threshold``. The rule
    counts as captured when ``rule_head`` is among the decoded facts.

    Args:
        model: callable GNN applied to a ``torch_geometric`` ``Data`` object.
        threshold: threshold forwarded to ``can_encoder_decoder.decode_graph``.
        can_encoder_decoder: object providing ``encode_dataset`` and ``decode_graph``.
        rule_body: sequence (tuple or list) of ``(subject, predicate, object)``
            triples; may be empty.
        rule_head: ``(subject, predicate, object)`` triple looked up in the output.

    Returns:
        ``True`` if ``rule_head`` is a key of the decoded output, else ``False``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Special case for an empty rule body: the encoder is told which constant
    # to create. Test the length rather than comparing with ``()`` so that an
    # empty list (as used by the lattice search) is handled as well.
    if len(rule_body) == 0:
        head_constant = rule_head[0]
        (gr_features, gr_nodes, gr_edge_list, gr_colour_list) = \
            can_encoder_decoder.encode_dataset(rule_body, empty_constant=head_constant)
    else:
        (gr_features, gr_nodes, gr_edge_list, gr_colour_list) = can_encoder_decoder.encode_dataset(rule_body)

    gr_dataset = Data(x=gr_features, edge_index=gr_edge_list, edge_type=gr_colour_list).to(device)
    gnn_output_gr = model(gr_dataset)

    cd_output_dataset_scores_dict = can_encoder_decoder.decode_graph(gr_nodes, gnn_output_gr, threshold)

    return rule_head in cd_output_dataset_scores_dict


# check which monotonic rules mean-GNN captures
# note: this is too slow to run feasibly
def check_all_monotonic_rules_lattice(
        model: gnn_architectures.GNN,
        canonical_encoder_file: str,
        threshold: float,
):
    """Search the lattice of rule bodies for minimal monotonic rules captured by the model.

    For every unary head atom, candidate bodies are explored starting from the
    empty body. Captured bodies are recorded as minimal positives; bodies that
    are not captured are extended by one base atom at a time (following
    algorithm 5 of the journal paper, per the original author's note).
    As noted above the function, this search is too slow to be run feasibly.

    Args:
        model: GNN whose captured rules are extracted.
        canonical_encoder_file: document from which the ``CanonicalEncoderDecoder`` is loaded.
        threshold: decoding threshold used for every rule check.

    Returns:
        List of ``(body, head_atom)`` pairs, where ``body`` is a list of
        ``(subject, predicate, object)`` triples.
    """
    can_encoder_decoder = CanonicalEncoderDecoder(load_from_document=canonical_encoder_file)

    # sanity check on one hand-written rule; uses the `threshold` argument
    # (not a script-level global) so the function also works when imported
    captured = is_monotonic_rule_string_captured(
        model,
        threshold,
        can_encoder_decoder,
        '<http://swat.cse.lehigh.edu/onto/univ-bench.owl#undergraduateDegreeFrom>(Y,X) '
        'implies <http://swat.cse.lehigh.edu/onto/univ-bench.owl#University>(X)')
    print('\n\nRule captured:', captured)

    # list of predicates
    binary_predicate_list = list(can_encoder_decoder.binary_pred_colour_dict.keys())
    unary_predicate_list = list(can_encoder_decoder.unary_pred_position_dict.keys())
    print('\n\nbinary predicates', len(binary_predicate_list), binary_predicate_list)
    print('\n\nunary predicates', len(unary_predicate_list), unary_predicate_list)

    all_base_atoms = [('X', default_type_pred, predicate) for predicate in unary_predicate_list] +\
                     [('Y', predicate, 'X') for predicate in binary_predicate_list]
    head_atoms = [('X', default_type_pred, predicate) for predicate in unary_predicate_list]

    # the following is from algorithm 5 of the journal paper
    sound_rules = []
    for head_atom in tqdm(head_atoms):
        rule_length = 0
        min_pos = []  # minimal bodies found to be captured for this head
        # TODO (?) - make queue
        frontier = [[]]  # [] is an empty body (i.e. \top)
        while frontier:
            gamma = frontier[0]
            if len(gamma) > rule_length:
                rule_length = len(gamma)
                print('Reached rule length:', rule_length)
                print('Frontier size:', len(frontier))
            # pass a tuple so that the empty body compares equal to `()`
            if is_monotonic_rule_captured(model, threshold, can_encoder_decoder, tuple(gamma), head_atom):
                min_pos = [gamma_p for gamma_p in min_pos if not set(gamma) <= set(gamma_p)]
                frontier.pop(0)
                min_pos.append(gamma)
            else:
                # gamma is not captured: remove every frontier body contained
                # in gamma (smallest first) and replace it by its extensions
                while True:
                    # check if smallest gamma_n exists
                    smallest_gamma_n = None
                    smallest_gamma_n_index = -1
                    for i, gamma_n in enumerate(frontier):
                        if set(gamma_n) <= set(gamma):
                            if smallest_gamma_n is None or len(gamma_n) < len(smallest_gamma_n):
                                smallest_gamma_n = gamma_n
                                smallest_gamma_n_index = i
                    if smallest_gamma_n is None:
                        break

                    frontier.pop(smallest_gamma_n_index)

                    if len(smallest_gamma_n) == len(all_base_atoms):  # extension not in gamma_all
                        continue
                    for new_atom in all_base_atoms:
                        # Re-adding an atom yields the same conjunction; such a
                        # body would stay inside gamma and be expanded again,
                        # flooding the frontier with duplicate-laden bodies.
                        if new_atom in smallest_gamma_n:
                            continue
                        gamma_s = smallest_gamma_n + [new_atom]
                        # skip extensions already subsumed by a captured body
                        if not any(set(gamma_p) <= set(gamma_s) for gamma_p in min_pos):
                            frontier.append(gamma_s)

        for gamma in min_pos:
            sound_rules.append((gamma, head_atom))

    print(sound_rules)
    return sound_rules


# check if a candidate rule is subsumed by a sound rule
def rule_subsumed(candidate_rule, sound_rule):
    """Return True if ``sound_rule`` subsumes ``candidate_rule``.

    Both rules are ``(body, head)`` pairs. Subsumption holds when the heads
    are equal and every atom of the sound rule's body also occurs in the
    candidate's body.
    """
    candidate_body, candidate_head = candidate_rule
    sound_body, sound_head = sound_rule
    return candidate_head == sound_head and set(sound_body) <= set(candidate_body)


# prepare a predicate string to use with latex
def clean_predicate_string_for_latex(string):
    """Turn a predicate name into a LaTeX ``\\text{...}`` snippet.

    If the name contains ``#``, only the part between the first and second
    ``#`` is kept. The ``concept:`` prefix and angle brackets are removed,
    underscores are escaped, and the result is wrapped in ``\\text{}``.
    """
    pieces = string.split('#')
    label = pieces[1] if len(pieces) > 1 else string  # drop the redundant namespace part
    substitutions = (
        ('concept:', ''),  # redundant in nell
        ('<', ''),         # not friendly for latex
        ('>', ''),         # not friendly for latex
        ('_', '\\_'),      # not friendly for latex
    )
    for old, new in substitutions:
        label = label.replace(old, new)
    return '\\text{' + label + '}'


# convert rule to human-readable ALCQ notation
def rule_to_alcq_notation(rule):
    """Render a ``(body, head)`` rule as a LaTeX concept inclusion.

    Body atoms whose predicate is ``default_type_pred`` become the cleaned
    concept name; all other atoms become ``\\exists p.\\top``. The conjuncts
    are joined with ``\\sqcap`` and an empty body is written as ``\\top``.

    Returns:
        The formatted string, or ``None`` when ``rule`` is ``None``.
    """
    if rule is None:
        return None
    body, head = rule

    conjuncts = []
    for _subject, pred, obj in body:
        if pred == default_type_pred:
            conjuncts.append(clean_predicate_string_for_latex(obj))
        else:
            conjuncts.append(f'\\exists {clean_predicate_string_for_latex(pred)}.\\top')

    head_concept = clean_predicate_string_for_latex(head[2])
    # \top is only needed when there is nothing else in the body
    body_string = ' \\sqcap '.join(conjuncts) if len(body) != 0 else '\\top'
    return f'{body_string} \\sqsubseteq {head_concept}'


# convert iclr encoded rule to Datalog notation
def iclr_rule_to_datalog_notation(rule, iclr_encoder_decoder):
    """Render an ICLR-encoded ``(body, head)`` rule as a LaTeX Datalog-style rule.

    Unary canonical predicates are mapped back to input predicates through
    ``iclr_encoder_decoder.unary_canonical_to_input_predicate_dict`` and
    written as ``P(X,Y)``. Conjuncts are joined with ``\\land`` and an empty
    body is written as ``\\top``.

    Returns:
        The formatted string, or ``None`` when ``rule`` is ``None``.
    """
    if rule is None:
        return None
    body, head = rule
    to_binary_preds = iclr_encoder_decoder.unary_canonical_to_input_predicate_dict

    conjuncts = []
    for _subject, pred, obj in body:
        if pred == default_type_pred:
            input_predicate = clean_predicate_string_for_latex(to_binary_preds[obj])
            conjuncts.append(f'{input_predicate}(X,Y)')
        else:
            # only binary atoms in the datasets considered
            # thus, any binary encoded fact will appear, for every edge colour
            # sound rules of this form have no meaning in the datasets we use
            # the binary predicate can always be dropped, since it is guaranteed to appear
            # we include it here for completeness
            conjuncts.append(f'\\exists {clean_predicate_string_for_latex(pred)}.\\top')

    head_atom = f'{clean_predicate_string_for_latex(to_binary_preds[head[2]])}(X,Y)'
    # \top is only needed when there is nothing else in the body
    body_string = ' \\land '.join(conjuncts) if len(body) != 0 else '\\top'
    return f'{body_string} \\rightarrow {head_atom}'


# count the number of sound rules that have only unary, only binary, or some mix of unary and binary in the body
def count_sound_rules_with_predicate_arities(rules, unary_predicates, binary_predicates):
    """Count rules by the arity of the predicates occurring in their bodies.

    Args:
        rules: iterable of ``(body, head)`` pairs; each body is an iterable of
            atoms and each atom an iterable of hashable entries.
        unary_predicates: iterable of unary predicate names.
        binary_predicates: iterable of binary predicate names.

    Returns:
        Tuple ``(count_unary, count_binary, count_mixed)``: rules whose body
        mentions only unary predicates, only binary predicates, or both.
        Rules whose body mentions neither (e.g. an empty body) are not counted.
    """
    # build the lookup sets once; membership tests on the original lists
    # were linear and repeated for every entry of every atom
    unary_lookup = set(unary_predicates)
    binary_lookup = set(binary_predicates)

    count_unary, count_binary, count_mixed = 0, 0, 0
    for body, _head in rules:
        body_entries = {entry for atom in body for entry in atom}
        has_unary = not unary_lookup.isdisjoint(body_entries)
        has_binary = not binary_lookup.isdisjoint(body_entries)
        if has_unary and has_binary:
            count_mixed += 1
        elif has_unary:
            count_unary += 1
        elif has_binary:
            count_binary += 1

    return count_unary, count_binary, count_mixed


# check which monotonic rules mean-GNN captures
def check_monotonic_rules(
        model: gnn_architectures.GNN,
        canonical_encoder_file: str,
        iclr22_encoder_file: Optional[str],
        encoding_scheme: str,
        threshold: float,
        max_body_atoms_to_choose: int,
):
    """Enumerate monotonic rules up to a body size and count those the model captures.

    Bodies are built from the base atoms ``(X, default_type_pred, A)`` for each
    unary predicate and ``(Y, R, X)`` for each binary predicate. Body sizes are
    processed in increasing order, and candidates subsumed by a captured rule
    with a smaller body are skipped.

    Args:
        model: GNN whose captured rules are counted.
        canonical_encoder_file: document from which the canonical encoder is loaded.
        iclr22_encoder_file: document for the ICLR encoder; only read when
            ``encoding_scheme`` is ``'iclr22'``.
        encoding_scheme: ``'iclr22'`` selects Datalog-style output for the
            sample rules; any other value selects ALCQ-style output.
        threshold: decoding threshold used for every rule check.
        max_body_atoms_to_choose: largest body size to enumerate (inclusive).

    Returns:
        Tuple ``(sound_counts, sample_sound_rule_dict, arity_counts)``:
        captured-rule counts per body size, one formatted sample rule per body
        size (``None`` where nothing was captured), and the result of
        ``count_sound_rules_with_predicate_arities`` over all captured rules.
    """
    can_encoder_decoder = CanonicalEncoderDecoder(load_from_document=canonical_encoder_file)
    uses_iclr_encoding = encoding_scheme == 'iclr22'
    iclr_encoder_decoder = (
        ICLREncoderDecoder(load_from_document=iclr22_encoder_file) if uses_iclr_encoding else None
    )

    # list of predicates
    binary_predicate_list = list(can_encoder_decoder.binary_pred_colour_dict.keys())
    unary_predicate_list = list(can_encoder_decoder.unary_pred_position_dict.keys())

    head_atoms = [('X', default_type_pred, predicate) for predicate in unary_predicate_list]
    unary_base_atoms = [('X', default_type_pred, predicate) for predicate in unary_predicate_list]
    binary_base_atoms = [('Y', predicate, 'X') for predicate in binary_predicate_list]
    all_base_atoms = unary_base_atoms + binary_base_atoms

    sound_counts = {}
    sound_rules = []  # every captured rule so far; used for the subsumption filter
    sample_sound_rule_dict = {}  # one sample captured rule per body size

    for body_size in range(max_body_atoms_to_choose + 1):
        print(f'Checking rules with {body_size} body atoms')
        candidates = list(product(combinations(all_base_atoms, body_size), head_atoms))
        print(f'{len(candidates)} possible rules')

        # drop candidates subsumed by captured rules with fewer body atoms
        remaining = []
        for candidate in candidates:
            for sound_rule in sound_rules:
                if rule_subsumed(candidate, sound_rule):
                    break
            else:
                remaining.append(candidate)
        print(f'Filtered, checking {len(remaining)} rules')

        # shuffling makes the first captured rule a random sample
        random.shuffle(remaining)

        captured_here = []
        for body, head in tqdm(remaining):
            if is_monotonic_rule_captured(model, threshold, can_encoder_decoder, body, head):
                captured_here.append((body, head))
        sound_rules.extend(captured_here)

        sound_counts[body_size] = len(captured_here)
        sample_sound_rule_dict[body_size] = captured_here[0] if captured_here else None
        print('Count of sound rules by number of body atoms:', sound_counts)

    # format the sample rules: Datalog notation for ICLR encoding, ALCQ otherwise
    formatted_samples = {}
    for body_size, sample_rule in sample_sound_rule_dict.items():
        if uses_iclr_encoding:
            formatted_samples[body_size] = iclr_rule_to_datalog_notation(sample_rule, iclr_encoder_decoder)
        else:
            formatted_samples[body_size] = rule_to_alcq_notation(sample_rule)
    sample_sound_rule_dict = formatted_samples

    # compute count of number of rules with different predicate arity
    arity_counts = count_sound_rules_with_predicate_arities(sound_rules, unary_predicate_list, binary_predicate_list)

    return sound_counts, sample_sound_rule_dict, arity_counts


def explanation_concept(
        dataset,
        constant,
        ell,
        binary_predicate_list,
):
    """Build a LaTeX concept describing the neighbourhood of ``constant`` up to depth ``ell``.

    The concept is the conjunction of the atomic concepts asserted for
    ``constant`` (facts with ``default_type_pred``) and, when ``ell > 0``, for
    each binary predicate with ``n > 0`` facts pointing at ``constant``, an
    ``\\exists_{n}`` conjunct listing the recursively built concepts of the
    linked constants followed by a ``\\leq_{n}`` conjunct.

    Args:
        dataset: collection of ``(subject, predicate, object)`` triples; it is
            iterated several times.
        constant: the constant to describe.
        ell: remaining recursion depth; ``0`` yields only the atomic part.
        binary_predicate_list: binary predicates to follow.

    Returns:
        Tuple ``(concept, subconcept_count)`` with the LaTeX string and the
        number of sub-concepts it was assembled from.
    """
    atomics_to_include = {A for c, predicate, A in dataset if c == constant and predicate == default_type_pred}

    conjuncts = ['\\top'] + [clean_predicate_string_for_latex(atomic) for atomic in atomics_to_include]
    concept = ' ~\\sqcap~ '.join(conjuncts)
    subconcept_count = len(conjuncts)  # track number of sub-concepts
    if atomics_to_include:
        # \top is redundant once there is at least one atomic concept
        concept = concept.replace('\\top ~\\sqcap~ ', '')
        subconcept_count -= 1

    if ell == 0:
        return concept, subconcept_count

    for predicate in binary_predicate_list:
        # this check corresponds to data being passed in direction of the edges
        linked_constants = [s for s, p, o in dataset if o == constant and p == predicate]
        if not linked_constants:
            continue
        n = len(linked_constants)
        nested = [explanation_concept(dataset, ci, ell - 1, binary_predicate_list) for ci in linked_constants]
        role = clean_predicate_string_for_latex(predicate)
        fillers = ', '.join(sub_concept for sub_concept, _ in nested)
        concept += f' ~\\sqcap~ \\exists_{{{n}}} {role}.({fillers}) ~\\sqcap~ \\leq_{{{n}}} {role}.\\top'
        # one for the existential, one for the upper bound, plus the nested counts
        subconcept_count += 2 + sum(count for _, count in nested)

    return concept, subconcept_count


def greedy_reduced_test_graph(
        layers: int,
        binary_predicate_list: List,
        prediction,
        model: gnn_architectures.GNN,
        can_encoder_decoder,
        threshold: float,
        test_graph,
):
    """Greedily look for a small subset of ``test_graph`` that still yields ``prediction``.

    Candidate datasets are tried in this order, returning the first one for
    which the model still predicts ``prediction``:

    1. a single dummy binary fact about the root constant (the "empty" body),
    2. the atomic (``default_type_pred``) facts of the root constant,
    3. the above plus, added one predicate at a time, the first fact of each
       binary predicate pointing at the root constant,
    4. the above plus, added one predicate at a time, the remaining facts of
       each binary predicate pointing at the root constant.

    Only the 1-hop neighbourhood of the root is considered. The previous
    version also built a 2-hop neighbourhood structure on every call without
    ever reading it, and re-ran the model on unchanged datasets; both have
    been removed. Skipping the repeated evaluations assumes the model output
    is deterministic for a given dataset.

    Args:
        layers: number of GNN layers; must be 2.
        binary_predicate_list: binary predicates, in the order they are tried.
            Must be non-empty (its first entry is used for the dummy fact).
        prediction: ``(constant, predicate, object)`` fact to preserve.
        model: GNN used to recompute predictions.
        can_encoder_decoder: encoder/decoder forwarded to ``model_predictions``.
        threshold: decoding threshold forwarded to ``model_predictions``.
        test_graph: iterable of ``(subject, predicate, object)`` triples.

    Returns:
        The first reduced dataset (a list of triples) preserving the
        prediction, or ``test_graph`` itself if none of the candidates does.
    """
    assert layers == 2, 'Solution hard-coded for 2 layers: generalises to any number of layers though'

    root_constant, _, _ = prediction

    def prediction_holds(candidate_dataset):
        # re-run the model on the candidate and look for the prediction
        return prediction in model_predictions(model, can_encoder_decoder, threshold, candidate_dataset)

    # collect the 1-hop neighbourhood of the root
    root_atomics = []
    root_binaries = {binary_pred: [] for binary_pred in binary_predicate_list}  # root binary facts for each pred
    for s, p, o in test_graph:
        if s == root_constant and p == default_type_pred:
            root_atomics.append((s, p, o))
        elif o == root_constant:  # this check corresponds to data being passed in direction of the edges
            root_binaries[p].append((s, p, o))

    # check empty dataset first
    empty_body_dataset = [(root_constant, binary_predicate_list[0], 'dummy_constant')]
    if prediction_holds(empty_body_dataset):
        return empty_body_dataset

    # now check dataset with only root atomics
    # only call model if dataset is non-empty
    dataset = list(root_atomics)
    if dataset and prediction_holds(dataset):
        return dataset

    # include single existential for each one-hop binary, one at a time
    # from experiments, this seems sufficient most of the time, so no need to continue
    # (only failed 10/2000 times)
    for binary_pred in binary_predicate_list:
        incoming_facts = root_binaries[binary_pred]
        if not incoming_facts:
            continue  # dataset unchanged: the previous model answer still applies
        dataset.append(incoming_facts[0])
        if prediction_holds(dataset):
            return dataset

    # include one-hop binaries, one at a time
    for binary_pred in binary_predicate_list:
        remaining_facts = root_binaries[binary_pred][1:]  # the rest of the binaries
        if not remaining_facts:
            continue  # dataset unchanged: the previous model answer still applies
        dataset.extend(remaining_facts)
        if prediction_holds(dataset):
            return dataset

    print('Algorithm failed to reduce dataset size')

    # TODO: continue, if you want more greedy checking (only checking 1 hop for now);
    # this would require collecting the 2-hop neighbourhood of the root as well

    return test_graph


# get predictions from the model
def model_predictions(
        model: gnn_architectures.GNN,
        can_encoder_decoder,
        threshold: float,
        test_graph,
):
    """Return the facts the model predicts for ``test_graph``.

    The graph is encoded with ``can_encoder_decoder``, passed through
    ``model`` and decoded with ``threshold``.

    Args:
        model: GNN applied to the encoded graph.
        can_encoder_decoder: object providing ``encode_dataset`` and ``decode_graph``.
        threshold: threshold forwarded to ``decode_graph``.
        test_graph: collection of ``(subject, predicate, object)`` triples.

    Returns:
        The keys view of the decoded fact-to-score dictionary.
    """
    # encode input graph
    (test_x, test_nodes, test_edge_list, test_edge_colour_list) = can_encoder_decoder.encode_dataset(test_graph)
    test_data = Data(x=test_x, edge_index=test_edge_list, edge_type=test_edge_colour_list)

    # Move the input next to the model's weights. Without this, a model that
    # was moved to the GPU (as the script entry point does) is fed CPU tensors.
    parameters = getattr(model, 'parameters', None)
    first_parameter = next(iter(parameters()), None) if callable(parameters) else None
    if first_parameter is not None:
        test_data = test_data.to(first_parameter.device)

    # use model
    gnn_output = model(test_data)

    # decode
    cd_output_dataset_scores_dict = can_encoder_decoder.decode_graph(test_nodes, gnn_output, threshold)
    model_predicted_facts = cd_output_dataset_scores_dict.keys()
    return model_predicted_facts


# get rule explanations for mean GNNs
# assumes canonical encoding
def explain_predictions(
        model: gnn_architectures.GNN,
        canonical_encoder_file: str,
        threshold: float,
        test_graph_path: str,
        test_positive_examples_path: str,
        reduce_rule_size=True,
):
    """Produce explanatory rules for the model's true-positive predictions.

    The test graph and the positive examples are parsed from file, the model
    is applied to the test graph, and every positive example that the model
    predicts is explained by a rule built with ``explanation_concept``.

    Args:
        model: GNN to explain; its ``num_layers`` attribute sets the concept depth.
        canonical_encoder_file: document from which the canonical encoder is loaded.
        threshold: decoding threshold for model predictions.
        test_graph_path: path of the input graph file.
        test_positive_examples_path: path of the positive examples file.
        reduce_rule_size: when true, each explanation is built from the
            dataset returned by ``greedy_reduced_test_graph`` instead of the
            full test graph.

    Returns:
        Tuple ``(avg_concepts_per_rule, sample_rules)``: the mean number of
        concepts per explanatory rule (``0.0`` when there are no true
        positives) and up to 50 ``(fact, rule)`` LaTeX string pairs.

    Raises:
        AssertionError: if either input path does not exist.
    """
    can_encoder_decoder = CanonicalEncoderDecoder(load_from_document=canonical_encoder_file)
    # list of predicates
    binary_predicate_list = list(can_encoder_decoder.binary_pred_colour_dict.keys())

    # load input graph
    assert os.path.exists(test_graph_path)
    print("Loading graph data from {}".format(test_graph_path))
    test_graph = data_parser.parse(test_graph_path)

    # load test positives
    assert os.path.exists(test_positive_examples_path)
    print("Loading examples data from {}".format(test_positive_examples_path))
    test_positive_examples_dataset = data_parser.parse(test_positive_examples_path)

    # get model predictions, and true positives
    model_predicted_facts = model_predictions(model, can_encoder_decoder, threshold, test_graph)
    true_positives = [fact for fact in test_positive_examples_dataset if fact in model_predicted_facts]

    print('Explaining predictions')
    if reduce_rule_size:
        print('Using a greedy algorithm to reduce the explanatory rule size')
    random.shuffle(true_positives)  # so the stored sample rules are a random selection
    sample_rules = []
    total_concepts = 0
    layers = model.num_layers  # same for every prediction
    for prediction in tqdm(true_positives):
        constant, _, atomic_concept = prediction
        if reduce_rule_size:
            explain_test_graph = greedy_reduced_test_graph(layers, binary_predicate_list, prediction,
                                                           model, can_encoder_decoder,
                                                           threshold, test_graph)
        else:
            explain_test_graph = test_graph
        body, num_concepts_in_rule = explanation_concept(explain_test_graph, constant, layers, binary_predicate_list)
        rule = f'{body} \\sqsubseteq {clean_predicate_string_for_latex(atomic_concept)}'
        num_concepts_in_rule += 1  # account for the head concept
        if len(sample_rules) < 50:  # store the first 50 rules to return, as examples
            clean_fact = f'{clean_predicate_string_for_latex(prediction[2])}' \
                         f'({clean_predicate_string_for_latex(prediction[0])})'
            sample_rules.append((clean_fact, rule))
        total_concepts += num_concepts_in_rule

    # avoid dividing by zero when the model has no true positives
    avg_concepts_per_rule = total_concepts / len(true_positives) if true_positives else 0.0

    return avg_concepts_per_rule, sample_rules


if __name__ == '__main__':
    # Command-line entry point: load a trained model and run the lattice search.
    parser = argparse.ArgumentParser(description="Extract monotonic sound rules")
    # both paths are needed below, so fail early with a usage message if one is missing
    parser.add_argument('--model-path', required=True, help='Path to model file')
    parser.add_argument('--encoder-path', required=True, help='Path to canonical encoder')
    args = parser.parse_args()

    model_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code - only load model files from a trusted source.
    # map_location lets a model saved on a GPU machine be loaded on a CPU-only host.
    loaded_model: gnn_architectures.GNN = torch.load(
        args.model_path, map_location=model_device).to(model_device)
    # the last stored evaluation threshold is used for decoding
    model_threshold = loaded_model.eval_thresholds[-1]

    check_all_monotonic_rules_lattice(loaded_model, args.encoder_path, model_threshold)
