# import objects
from meteor_reasoner.utils.operate_dataset import print_dataset
from meteor_reasoner.materialization.materialize import coalescing_d

from meteor_reasoner.classes.atom import Atom
from meteor_reasoner.classes.term import Term
from meteor_reasoner.classes.literal import Literal, Operator
from meteor_reasoner.classes.rule import Rule
from meteor_reasoner.classes.interval import Interval

import networkx as nx
import random
from matplotlib import pyplot as plt
import decimal
import collections
import argparse
import json
import utils
import typing
from collections import defaultdict
import tqdm
from tqdm.contrib.concurrent import process_map
import functools
import numpy as np

from numpy.random import choice


def dump_graph(G: nx.DiGraph):
    """Draw ``G`` with matplotlib for debugging and show the figure.

    Nodes that carry a ``data`` attribute are labelled with element 0 of that
    attribute and the string form of element 2; edges that carry a ``rule``
    attribute are labelled with it.
    """
    layout = nx.spring_layout(G)
    nx.draw_networkx(G, layout)

    # Only nodes that actually have a 'data' attribute appear in this mapping.
    labels = {
        name: (payload[0], str(payload[2]))
        for name, payload in nx.get_node_attributes(G, 'data').items()
    }
    nx.draw_networkx_labels(G, layout, labels=labels)

    rule_labels = nx.get_edge_attributes(G, 'rule')
    nx.draw_networkx_edge_labels(G, layout, edge_labels=rule_labels)
    plt.show()


class Instance():
    """One generated sample: a rule program, a dataset, a query and its label."""

    def __init__(self, program: typing.List[Rule], data: defaultdict, query: typing.Tuple[str, typing.Tuple, Interval],  valid: bool):
        # Plain value holder: the arguments are stored as given, without
        # copying or validation.
        self.program, self.data = program, data
        self.query, self.valid = query, valid

    def __str__(self):
        """Render the dataset, the program (one rule per line) and the query."""
        data_text = utils.dump_dataset(self.data)
        program_text = "\n".join(utils.dump_rule(rule)
                                 for rule in self.program)
        query_text = utils.dump_single_data(self.query)
        return "".join((
            "Data:", data_text, "\n",
            "Program:", program_text, "\n\n",
            "Query:", query_text, "\n",
        ))


def default_argument_parser():
    """Build the command-line parser for the DatalogMTL data generator.

    Returns:
        An ``argparse.ArgumentParser`` with the debug switch ``--single``,
        the boolean feature switches, the optional integer overrides (left as
        ``None`` when not given) and the sample count ``-n`` (default 10).
    """
    parser = argparse.ArgumentParser(description="generate DatalogMTL data")
    parser.add_argument("--single", action='store_true',
                        help="Generate one instance (for debug purpose)")

    # Boolean feature switches, registered in the order they are listed.
    feature_switches = (
        "--rational-number",
        "--multiple-body-atoms",
        "--recursive",
        "--mixed-operators",
        "--multiple-rules",
        "--variable",
        "--multiple-variables",
    )
    for switch in feature_switches:
        parser.add_argument(switch, action='store_true')

    # Integer-valued options without an explicit default.
    for option in ("--fixed-new-nodes", "--fixed-rules",
                   "--addl-rules", "--addl-data"):
        parser.add_argument(option, type=int)

    parser.add_argument("-n", default=10,
                        type=int, help="The number of samples generated")
    return parser


def id_to_char(id: int) -> str:
    """Map a zero-based index to an upper-case letter (0 -> 'A', ..., 25 -> 'Z').

    Args:
        id: index of the letter; must satisfy ``0 <= id < 26``.

    Returns:
        The single-character string for that letter.

    Raises:
        AssertionError: if ``id`` is outside ``0..25``.
    """
    # The lower bound matters as well: a negative index would otherwise yield
    # a character before 'A' (e.g. -1 -> '@') instead of failing.
    assert 0 <= id < 26, "id must be in range(26), got %r" % (id,)
    return chr(ord('A') + id)


def generate_graph(graph: nx.DiGraph, features: set, additional_args: dict) -> typing.Tuple[nx.DiGraph, typing.List[str], str]:
    """Extend ``graph`` in place by one round of new predicate nodes and edges.

    New nodes are named with consecutive letters (via ``id_to_char``) starting
    at the current node count. One of them is picked as the round's output
    node, and edges ``node -> out_node`` are added from a shuffled walk over
    all nodes. Each edge is created with ``has_rule=False`` and each node
    with ``with_data=False`` and a single ``Term("nan")`` entity.

    Args:
        graph: graph to extend (mutated).
        features: enabled feature names; ``recursive``,
            ``multiple_body_atoms``, ``multiple_rules``, ``variable`` and
            ``multiple_variables`` are consulted here.
        additional_args: optional overrides; ``fixed_new_nodes`` fixes the
            number of nodes added.

    Returns:
        ``(graph, new_node, out_node)``: the same graph object, the list of
        newly added predicate names, and the name chosen as output node.

    Raises:
        AssertionError: if the node count falls outside the bounds implied by
            the features, or if no cycle exists although ``recursive`` is
            enabled together with ``variable``.
    """
    n_existing = len(graph.nodes)

    # Bounds on how many nodes this round may add. The feature checks below
    # deliberately override the defaults, the later one winning.
    if n_existing == 0:
        min_n, max_n = 2, 2
    else:
        min_n, max_n = 1, 1
    if "recursive" in features:
        min_n, max_n = 1, 1
    if "multiple_body_atoms" in features:
        min_n, max_n = 2, 6

    if "fixed_new_nodes" in additional_args:
        n = additional_args['fixed_new_nodes']
    else:
        n = random.randint(min_n, max_n)
    assert (n >= min_n)
    assert (n <= max_n)

    new_node = []
    for i in range(n_existing, n_existing + n):
        predicate = id_to_char(i)
        graph.add_node(predicate, with_data=False,
                       entity=tuple([Term("nan")]))
        new_node.append(predicate)

    out_node = random.choice(new_node)

    node_list = list(graph.nodes)
    random.shuffle(node_list)

    single_body_atom = "multiple_body_atoms" not in features
    for node in node_list:
        # A self-loop on the output node is only allowed in recursive mode.
        # Compare the names directly rather than their hashes.
        if node == out_node and ("recursive" not in features):
            continue
        # Without multiple body atoms a node may feed at most one rule ...
        if single_body_atom and graph.out_degree(node) >= 1:
            continue
        # `random.random() < 1` is always true, so the edge is always added;
        # the call is kept so the sequence of random draws stays the same.
        if "multiple_rules" not in features or random.random() < 1:
            graph.add_edge(node, out_node, has_rule=False)
        # ... and the output node takes at most one incoming edge.
        if single_body_atom and graph.in_degree(out_node) >= 1:
            break

    if "variable" in features:
        try:
            cycle = nx.find_cycle(graph, orientation="original")
        except nx.NetworkXNoCycle:
            # Only the "no cycle" outcome is expected here; any other error
            # now propagates instead of being swallowed.
            assert ("recursive" not in features)
            it_order = list(nx.topological_sort(graph))
        else:
            # Visit the first cycle edge's endpoints before everything else.
            it_order = [cycle[0][0], cycle[0][1]] + list(graph.nodes)
            variable_letter = utils.random_upper_case()
            # NOTE(review): `node` is whatever the edge loop above visited
            # last, not necessarily a node on the cycle. Kept as in the
            # original; confirm whether cycle[0][0] was intended.
            graph.nodes[node]['entity'] = (
                Term(variable_letter, type="variable"), )

        for node in it_order:
            if node not in new_node:
                continue
            in_edges = graph.in_edges(node)
            if len(in_edges) == 0:
                # Source node: invent fresh variables.
                n_variable = 1
                if "multiple_variables" in features:
                    n_variable = random.randint(2, 4)
                terms = [Term(utils.random_upper_case(), type="variable")
                         for _ in range(n_variable)]
                graph.nodes[node]['entity'] = tuple(terms)
            else:
                # Derived node: collect the terms of every predecessor.
                entity = set()
                for edge in in_edges:
                    in_node: str = edge[0]
                    entity.update(graph.nodes[in_node]['entity'])
                graph.nodes[node]['entity'] = tuple(entity)

    return graph, new_node, out_node


def extract_program(graph: nx.Graph) -> typing.Tuple[typing.List[Rule], defaultdict]:
    """Read the rule program and the dataset stored on ``graph``.

    Every edge whose ``has_rule`` attribute is true contributes its ``rule``
    literal to the body of a rule headed by the edge's target node; one rule
    is built per target. Every node whose ``with_data`` attribute is true
    contributes its ``data`` entry to the dataset.

    Returns:
        ``(rules, data)``: the list of rules and a nested ``defaultdict``
        filled through ``utils.merge_data``.
    """
    # Group body literals by the head predicate (edge target) they belong to.
    bodies_by_head = {}
    for _, head, attrs in graph.edges(data=True):
        if attrs['has_rule']:
            bodies_by_head.setdefault(head, []).append(attrs['rule'])

    rules = [
        Rule(head=Atom(head, graph.nodes[head]['entity'], None),
             body=list(body))
        for head, body in bodies_by_head.items()
    ]

    data = defaultdict(lambda: defaultdict(list))
    for name in graph.nodes:
        node_attrs = graph.nodes[name]
        if node_attrs['with_data']:
            utils.merge_data(data, node_attrs['data'])
    return rules, data


def evaluate_graph(graph: nx.Graph) -> defaultdict:
    """Run ``utils.infer`` on the program and data currently stored on ``graph``.

    Returns:
        The two values produced by ``utils.infer`` as a pair (note that this
        is a 2-tuple, despite the single-type return annotation).
    """
    program, facts = extract_program(graph)
    inferred, newly_derived = utils.infer(facts, program)
    return inferred, newly_derived


def generate_rules(graph: nx.Graph, known_data: defaultdict, features: set) -> bool:
    """Attach a temporal rule literal to every edge that does not have one yet.

    Each such edge ``u -> v`` receives a literal over ``u`` with one operator
    and a random interval, and is marked ``has_rule=True``. The rules headed
    by the affected targets are then evaluated on top of what the graph
    entailed before this call.

    Args:
        graph: graph to annotate (edges are mutated even when False is
            returned).
        known_data: accepted but not used by this function.
        features: enabled feature names; ``rational_number`` and
            ``mixed_operators`` are consulted here.

    Returns:
        False when ``mixed_operators`` is enabled but exactly one distinct
        operator was drawn, or when the new rules derive nothing; True
        otherwise.
    """
    def rand_num(minv, maxv):
        # Inclusive integer by default, Decimal in rational mode.
        if "rational_number" not in features:
            return random.randint(minv, maxv)
        return decimal.Decimal(random.random())*(maxv - minv) + minv

    possible_operators = ["Boxminus", "Boxplus", "Diamondplus", "Diamondminus"]
    mixed = "mixed_operators" in features
    if mixed:
        selected_operators = set()
    else:
        single_operator = random.choice(possible_operators)

    # Evaluated before any new literal is attached, so this reflects only
    # the rules that already existed.
    data, _ = evaluate_graph(graph)
    new_target = set()

    for u, v, attrs in graph.edges(data=True):
        if attrs['has_rule']:
            continue
        if mixed:
            # Operators that were already drawn get a much lower weight.
            weights = np.array([0.1 if op in selected_operators else 0.9
                                for op in possible_operators])
            operator = choice(possible_operators, 1,
                              p=weights/weights.sum())[0]
            selected_operators.add(operator)
        else:
            operator = single_operator
        start = rand_num(0, 10)
        end = rand_num(start, 15)
        attrs['has_rule'] = True
        attrs['rule'] = Literal(
            Atom(u, graph.nodes[u]['entity'], None),
            [Operator(operator, Interval(start, end, False, False))])
        new_target.add(v)

    if mixed and len(selected_operators) == 1:
        return False

    # Collect the complete bodies of every rule whose head was touched.
    bodies_by_head = {}
    for u, v, attrs in graph.edges(data=True):
        if v not in new_target:
            continue
        assert (graph.edges[u, v]['has_rule'])
        bodies_by_head.setdefault(v, []).append(attrs['rule'])

    new_rules = [
        Rule(head=Atom(head, graph.nodes[head]['entity'], None),
             body=list(body))
        for head, body in bodies_by_head.items()
    ]

    _, delta_new = utils.infer(data, new_rules)
    return len(delta_new) != 0


def add_additional_data(data: defaultdict, rate: int):
    """Add extra facts over predicates that are not yet keys of ``data``.

    Up to ``len(data) * rate`` facts are merged into ``data`` through
    ``utils.merge_data``; each uses a random unused upper-case letter as
    predicate, a single ``Term("nan")`` entity and a random integer interval
    whose bounds lie between 0 and 20.

    Fewer facts are added when the 26 available letters run out. The previous
    rejection-sampling loop never terminated in that situation.

    Args:
        data: dataset to extend (mutated).
        rate: number of facts to add per predicate initially in ``data``.
    """
    # NOTE(review): only the keys of `data` are avoided; predicates that occur
    # solely in rule heads are not known here and may be picked.
    alphabet = [id_to_char(i) for i in range(26)]
    num = len(data.keys())*rate
    for _ in range(num):
        # Recomputed every round because merging presumably registers the new
        # predicate as a key of `data` (not visible from here).
        unused = [letter for letter in alphabet if letter not in data]
        if not unused:
            break
        predicate = random.choice(unused)
        entity = tuple([Term("nan")])
        lrange = random.randint(0, 15)
        rrange = random.randint(lrange, 20)
        interval = Interval(lrange, rrange, False, False)
        utils.merge_data(data, (predicate, entity, interval))


def add_additional_rules(data: defaultdict, rules: typing.List[Rule], rate: int):
    """Append ``len(rules) * rate`` extra single-literal rules to ``rules``.

    Each new rule has a head predicate drawn with
    ``utils.random_upper_case()`` until it is not a key of ``data``, and a
    body consisting of one literal over a random key of ``data`` with one
    random temporal operator and a random integer interval. Both atoms use a
    single ``Term("nan")`` entity.

    Args:
        data: dataset whose keys supply the body predicates (not modified).
        rules: rule list to extend (mutated).
        rate: number of rules to add per rule initially in ``rules``.
    """
    possible_operators = ["Boxminus", "Boxplus", "Diamondplus", "Diamondminus"]
    nanentity = tuple([Term("nan")])
    # Snapshot taken once: the count and the known predicates do not change
    # while rules are being appended.
    known_predicate = list(data.keys())
    for _ in range(len(rules)*rate):
        new_predicate = utils.random_upper_case()
        while new_predicate in known_predicate:
            new_predicate = utils.random_upper_case()

        lrange = random.randint(0, 15)
        rrange = random.randint(lrange, 20)
        interval = Interval(lrange, rrange, False, False)

        head = Atom(new_predicate, nanentity, None)
        body_atom = Atom(random.choice(known_predicate), nanentity, None)
        temporal_operator = Operator(
            random.choice(possible_operators), interval)
        rules.append(
            Rule(head=head, body=[Literal(body_atom, [temporal_operator])]))


def generate(features: set, additional_args: dict, valid: bool) -> typing.Optional[Instance]:
    """Generate one program/dataset/query instance, or ``None`` on failure.

    The predicate graph is grown in ``n_rule`` rounds. Each round calls
    ``generate_graph``, attaches a random data interval to the newly created
    nodes, and calls ``generate_rules``; a round that yields nothing is
    rolled back and retried. Afterwards the graph is pruned with
    ``remove_unrelated_rule`` relative to the last round's output node, the
    program is evaluated, and a query over that node is built.

    Args:
        features: enabled feature names.
        additional_args: optional overrides; ``fixed_rules``, ``addl_data``
            and ``addl_rules`` are read here, the dict is also forwarded to
            ``generate_graph``.
        valid: when true the query interval is cut out of a derived interval;
            when false a random interval is drawn until ``utils.entail``
            rejects it.

    Returns:
        The generated ``Instance``, or ``None`` when the retry budget of
        either the graph construction or the negative-query search is
        exhausted.

    Raises:
        AssertionError: for inconsistent feature combinations (see the
            asserts below).
    """
    def rand_num(minv, maxv):
        # Decimal in rational mode, inclusive integer otherwise.
        if "rational_number" in features:
            return decimal.Decimal(random.random())*(maxv - minv) + minv
        else:
            return random.randint(minv, maxv)

    # Feature prerequisites.
    if "mixed_operators" in features:
        assert ("multiple_body_atoms" in features)

    if "multiple_variables" in features:
        assert ("variable" in features)

    n_rule = 1
    if "multiple_rules" in features:
        n_rule = random.randint(2, 6)
    if "fixed_rules" in additional_args.keys():
        # An explicit rule count must agree with the multiple_rules feature.
        n_rule = additional_args['fixed_rules']
        if n_rule > 1:
            assert ("multiple_rules" in features)
        else:
            assert ("multiple_rules" not in features)
        # n_rule = 2
    # NOTE(review): with fixed_rules=0 the loop below never runs and
    # `out_node` is unbound when it is read after the loop.
    graph = nx.DiGraph()
    # NOTE(review): never filled here; generate_rules receives it but does
    # not read it.
    known_data = defaultdict(lambda: defaultdict(list))
    # Total number of attempts across all rounds (not reset per round).
    counter = 0
    for j in range(n_rule):
        while True:
            counter += 1
            # Snapshot so that a failed attempt can be rolled back.
            old_graph = graph.copy()
            # Generate Graph
            graph, generated_nodes, out_node = generate_graph(
                graph, features, additional_args)

            # dump_graph(graph)

            # Generate Data
            for node in generated_nodes:
                # The output node only gets its own data in recursive mode.
                if hash(node) == hash(out_node) and "recursive" not in features:
                    continue
                lrange = rand_num(0, 15)
                rrange = rand_num(lrange, 20)
                # Ground the node's entity: constants become "nan", variables
                # are replaced by a random lower-case constant.
                terms = []
                for k in graph.nodes[node]['entity']:
                    if k.type == "constant":
                        terms.append(Term("nan"))
                    elif k.type == "variable":
                        terms.append(
                            Term(utils.random_lower_case(), type="constant"))
                entity = tuple(terms)
                interval = Interval(lrange, rrange, False, False)
                graph.nodes[node]['data'] = (
                    node, entity, interval)
                graph.nodes[node]['with_data'] = True
            # dump_graph(graph)
            available = generate_rules(graph, known_data, features)
            # dump_graph(graph)
            if available:
                # print(out_node)
                # nx.draw_networkx(graph)
                # plt.show()
                break
            # Attempt produced nothing new: discard it and retry, unless the
            # overall budget is used up.
            graph = old_graph
            if counter > n_rule*64:
                print("Failed to generate")
                return None

    # Validation
    # The first extraction is only needed for the set of data predicates;
    # the program is extracted again after pruning.
    rules, data = extract_program(graph)
    query_predicate = out_node
    remove_unrelated_rule(graph, data.keys(), query_predicate)
    rules, data = extract_program(graph)
    dataset, delta_new = evaluate_graph(graph)

    # Pick one entity of the query predicate together with its intervals.
    entity, intervals = random.choice(list(delta_new[query_predicate].items()))
    if valid:
        # Shrink a derived interval from both sides; the end shrink is capped
        # so that the right bound does not pass the left bound.
        # NOTE(review): the last two Interval arguments are presumably the
        # open-end flags; the query is always built with both False,
        # whatever the derived interval uses - confirm.
        interval: Interval = random.choice(list(intervals))
        start_shrink = rand_num(0, interval.right_value - interval.left_value)
        end_shrink = rand_num(0, max(interval.right_value -
                                     interval.left_value-start_shrink - 1, 0))
        final_interval = Interval(interval.left_value +
                                  start_shrink, interval.right_value-end_shrink, False, False)
    else:
        # Draw random intervals until one is not entailed; give up once more
        # than 128 attempts have been made.
        counter = 0
        while True:
            counter += 1
            lf = rand_num(-34, 28)
            rf = rand_num(lf, 30)
            final_interval = Interval(lf, rf, False, False)
            if not utils.entail(dataset, (query_predicate, entity, final_interval)):
                break
            if counter > 128:
                return None
    # Optional extra facts and rules, added after the query has been fixed.
    if "addl_data" in additional_args.keys():
        add_additional_data(data, additional_args['addl_data'])
    if "addl_rules" in additional_args.keys():
        assert ("variable" not in features)
        add_additional_rules(data, rules, additional_args['addl_rules'])
    query = (query_predicate, entity, final_interval)
    instance = Instance(rules, data, query, valid)

    return instance
    # print(graph)


def remove_unrelated_rule(graph: nx.DiGraph, known_nodes, query_node):
    """Prune ``graph`` in place to the nodes that matter for ``query_node``.

    A node is kept only when both conditions hold:

    * it is one of ``known_nodes`` or reachable from one of them, and
    * it is ``query_node`` or ``query_node`` is reachable from it.

    The previous implementation added up membership counters and kept every
    node scoring at least two. That also kept a known node lying downstream
    of another known node even when it could not reach the query.

    Args:
        graph: graph to prune (mutated).
        known_nodes: iterable of nodes that carry data; each must be in
            ``graph``.
        query_node: node the query is about; must be in ``graph``.
    """
    known = list(known_nodes)

    # Forward closure: everything that can be derived from the known nodes.
    derivable = set(known)
    for known_node in known:
        derivable |= nx.descendants(graph, known_node)

    # Backward closure: everything the query node depends on.
    relevant = set(nx.ancestors(graph, query_node))
    relevant.add(query_node)

    keep = derivable & relevant
    # Materialise the list first: the node view must not change while it is
    # being iterated.
    for node in [n for n in graph.nodes() if n not in keep]:
        graph.remove_node(node)


def to_json(instance: Instance):
    """Convert ``instance`` into a dict with keys data, rule, query and valid.

    The dataset, rules and query are rendered through the ``utils.dump_*``
    helpers; ``valid`` is copied as is.
    """
    payload = {}
    payload["data"] = utils.dump_dataset_array(instance.data)
    payload["rule"] = [utils.dump_rule(rule) for rule in instance.program]
    payload["query"] = utils.dump_single_data(instance.query)
    payload["valid"] = instance.valid
    return payload


def to_sign(features: set, additional_args: dict) -> str:
    """Build a signature string for a feature/argument combination.

    The feature names (or the single entry ``basic`` when there are none)
    and a ``key=value`` entry for every additional argument whose value is
    not ``None`` are sorted and joined with ``-``.
    """
    parts = list(features) or ["basic"]
    parts.extend(
        "%s=%s" % (key, str(value))
        for key, value in additional_args.items()
        if value is not None
    )
    return "-".join(sorted(parts))


def _generate_until_success(features: set, additional_args: dict, valid: bool):
    """Call ``generate`` repeatedly until it returns an instance, not ``None``."""
    instance = None
    while instance is None:
        instance = generate(features, additional_args, valid)
    return instance


def main():
    """Command-line entry point.

    Parses the arguments and turns the enabled switches into the feature set
    and the integer options into ``additional_args``. With ``--single`` one
    valid instance is generated and printed. Otherwise ``n`` valid and then
    ``n`` invalid instances are generated and written as JSON to
    ``data/<signature>.json``, the signature coming from ``to_sign``.
    """
    # Local import: only needed to create the output directory.
    import os

    args = default_argument_parser().parse_args()
    features = set()
    # "multiple_variables" has to be listed here, otherwise the
    # --multiple-variables switch is parsed but never reaches generate().
    FEATURES_LIST = ["rational_number", "multiple_body_atoms",
                     "recursive", "variable", "multiple_variables",
                     "mixed_operators",
                     "multiple_rules", "multiple_rules_2", "multiple_rules_3",
                     "multiple_body_atoms_2", "multiple_body_atoms_3", "multiple_body_atoms_4"]
    ADDITIONAL_ARGS_KEYS = ["fixed_new_nodes",
                            "fixed_rules", "addl_data", "addl_rules"]
    additional_args = {}
    for arg in vars(args):
        value = getattr(args, arg)
        if value and arg in FEATURES_LIST:
            features.add(arg)
        # Integer options count as given as soon as they are not None
        # (so an explicit 0 is kept).
        if arg in ADDITIONAL_ARGS_KEYS and value is not None:
            additional_args[arg] = value

    if args.single:
        print(str(_generate_until_success(features, additional_args, True)))
        return

    results = []
    # First all valid samples, then all invalid ones, one progress bar each.
    for valid in (True, False):
        for _ in tqdm.tqdm(range(0, args.n)):
            instance = _generate_until_success(
                features, additional_args, valid)
            results.append(to_json(instance))

    # Create the output directory up front so that a missing folder does not
    # discard the generated samples, and close the file deterministically.
    os.makedirs("data", exist_ok=True)
    file_name = "data/%s.json" % to_sign(features, additional_args)
    with open(file_name, "w") as output_file:
        json.dump(results, output_file)

# Run the generator only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
