import random
from lgw.args import get_args
from lgw.rule_generator import RuleGen
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import copy
from addict import Dict
from collections import Counter
from tqdm import tqdm
import pickle
from sklearn.model_selection import train_test_split
import itertools
import json
from os import path
from addict import Dict


class BaseGraphGen:
    def __init__(self, rule_gen: RuleGen, args, id=0, done_nodes=None, gen_graph=True):
        """
        Set up an empty graph; the graph is stored in a dictionary
        (edge -> relation) where an edge is a ``(source, target)`` tuple.

        :param rule_gen: rule generator providing ``D`` (body -> head rules),
            ``rule_prob`` and ``world_mode``; when falsy, empty rules are used
        :param args: an addict ``Dict`` (used as is) or any object accepted by
            ``vars()`` (wrapped into a ``Dict``); ``num_nodes`` and
            ``expand_steps`` are read from it here
        :param id: identifier of this graph, used by ``get_node_id``
        :param done_nodes: nodes that are done adding edges; a falsy value
            starts from a fresh empty set
        :param gen_graph: not used by this base constructor
        :raises AssertionError: if ``num_nodes < 3`` or
            ``expand_steps > num_nodes + 2``
        """
        self.id = id
        # isinstance instead of an exact type comparison, so that a Dict
        # subclass is kept as it is rather than being pushed through vars()
        if isinstance(args, Dict):
            self.args = args
        else:
            self.args = Dict(vars(args))
        if rule_gen:
            self.rules = rule_gen.D
            self.rule_prob = rule_gen.rule_prob
            self.world_mode = rule_gen.world_mode
        else:
            self.rules = {}
            self.world_mode = ""
            self.rule_prob = {}
        self.rule_gen = rule_gen
        self.pos_predicates = set()
        self.neg_predicates = set()
        # node -> list of neighbours, filled by add_edge
        self.source_lookup = {}
        self.target_lookup = {}
        # unit rules (string body) are stored in both directions
        self.unit_rules = {}
        for body, head in self.rules.items():
            # a predicate containing "+" counts as positive, any other as negative
            preds = [head]
            if type(body) is tuple:
                preds.extend(body)
            else:
                preds.append(body)
            for p in preds:
                if "+" in p:
                    self.pos_predicates.add(p)
                else:
                    self.neg_predicates.add(p)
            if type(body) is str:
                self.unit_rules[body] = head
                self.unit_rules[head] = body
        self.nodes = []
        if done_nodes:
            self.done_nodes = done_nodes
        else:
            self.done_nodes = set()  # nodes done adding to the graph
        self.graph = {}
        self.forward_edges = []
        self.backward_edges = []
        self.noise_candidates = []  # list of edges which are noise candidates
        self.target = None  # target for link prediction,
        # which is not present in the graph

        # sanity checks
        if self.args.num_nodes < 3:
            raise AssertionError("Too few nodes")
        if self.args.expand_steps > self.args.num_nodes + 2:
            raise AssertionError(
                "expand_steps {} > num_nodes {} + 2".format(
                    self.args.expand_steps, self.args.num_nodes
                )
            )

        # keep track of which rules used in the graph
        self.used_rules = []
        self.resolution_path = []
        self.used_train_descriptors = {}
        self.gen_nodes()

    def gen_nodes(self):
        """
        Generate nodes of the graph
        :return:
        """
        self.nodes = list(range(self.args.num_nodes))
        self.graph = {}

    def copy_edges(self, edge_dict, target_edge=None, resolution_path=None):
        """
        Given an edge:relation dict, copy it into a new graph with re-indexed
        nodes, and return that graph.

        Nodes are renumbered 0, 1, 2, ... in the order in which they are first
        met while iterating over ``edge_dict``.

        :param edge_dict: mapping of ``(source, target)`` edges to relations
        :param target_edge: optional edge to carry over as the target of the
            copy; its relation is read from ``self.graph``
        :param resolution_path: optional list of nodes, translated to the new
            node ids
        :return: a ``BigGraphGen`` (built with ``gen_graph=False``) holding the
            re-indexed edges
        :raises KeyError: if ``target_edge`` / ``resolution_path`` mention a
            node that does not occur in ``edge_dict``, or ``target_edge`` is
            not an edge of ``self.graph``
        """
        new_node_map = {}
        for edge in edge_dict:
            for node in (edge[0], edge[1]):
                if node not in new_node_map:
                    new_node_map[node] = len(new_node_map)
        gg = BigGraphGen(self.rule_gen, self.args, gen_graph=False)
        gg.graph = {}
        for edge in edge_dict:
            gg.graph[(new_node_map[edge[0]], new_node_map[edge[1]])] = edge_dict[edge]
        gg.nodes = list(range(len(new_node_map)))
        try:
            if target_edge:
                gg.target = {
                    (
                        new_node_map[target_edge[0]],
                        new_node_map[target_edge[1]],
                    ): self.graph[target_edge]
                }
            if resolution_path:
                gg.resolution_path = [new_node_map[node] for node in resolution_path]
        except KeyError as err:
            # fail loudly instead of opening an interactive debugger, which
            # would hang any non-interactive run
            raise KeyError(
                "cannot re-index target_edge={!r} / resolution_path={!r}: "
                "{!r} is unknown".format(target_edge, resolution_path, err.args[0])
            ) from err
        return gg

    def is_reverse_edge(self, edge):
        return edge in self.backward_edges

    def sample_body(self, head, with_replacement=True, gamma=0.8):
        """
        Sample a body from the rule base given the head
        :param head: head of the rule
        :param reverse:
        :return:
        """
        candidate_rules = [
            [body, h]
            for body, h in self.rules.items()
            if head == h and type(body) == tuple
        ]
        candidate_bodies = [rule[0] for rule in candidate_rules]
        if len(candidate_bodies) == 0:
            raise AssertionError("cand bodies end")
        body_probs = [self.rule_gen.rule_prob[body] for body in candidate_bodies]
        if sum(body_probs) <= 0:
            raise AssertionError("div by zero")
        # normalize
        body_probs = np.array(body_probs) / sum(body_probs)
        indx = np.random.choice(list(range(len(candidate_bodies))), p=body_probs)
        if not with_replacement:
            self.rule_gen.rule_prob[candidate_bodies[indx]] = (
                self.rule_gen.rule_prob[candidate_bodies[indx]] * gamma
            )
        return candidate_bodies[indx], candidate_rules[indx]

    def sample_edge(self):
        return random.choice([k for k, v in self.graph.items()])

    def sample_predicate(self, reverse=False):
        """
        Sample a predicate from the rule base
        :param reverse:
        :return:
        """
        if reverse:
            return random.choice(list(self.neg_predicates))
        else:
            return random.choice(list(self.pos_predicates))

    def add_edge(self, edge, rel):
        """
        Add edge to the graph
        :return:
        """
        if edge not in self.graph:
            self.graph[edge] = rel
            if rel[0] not in self.source_lookup:
                self.source_lookup[edge[0]] = []
            self.source_lookup[edge[0]].append(edge[-1])
            if rel[-1] not in self.target_lookup:
                self.target_lookup[edge[-1]] = []
            self.target_lookup[edge[-1]].append(edge[0])
            if "+" in rel:
                self.forward_edges.append(edge)
            else:
                self.backward_edges.append(edge)
        else:
            # "edge already in the graph"
            pass

    def remove_edge(self, edge):
        """
        Removes an edge from the graph
        :param edge:
        :return:
        """
        reverse = self.is_reverse_edge(edge)
        if edge in self.graph:
            del self.graph[edge]
            if reverse:
                self.backward_edges.remove(edge)
            else:
                self.forward_edges.remove(edge)

    def remove_both_edges(self, edge):
        """
        Remove both forward and backward edges if present
        :param edge:
        :return:
        """
        self.remove_edge(edge)
        self.remove_edge((edge[-1], edge[0]))

    def set_target(self, from_available=True):
        """
        During graph generation, first set the correct target
        if the edge is already present in the graph, return its predicate
        :return:
        """
        if from_available:
            available_nodes = list(set(self.nodes) - self.done_nodes)
        else:
            # todo: maybe change this line to select with some probability
            # targets from self.done_nodes, in order to foster edge creation
            # from within a done cycle
            if random.uniform(0, 1) > 0.5:
                available_nodes = list(set(self.nodes) - self.done_nodes)
                # available_nodes = list(set(self.nodes)) <- in this setup, we were ending up with a discontinuity
            else:
                available_nodes = list(set(self.done_nodes))
        node_a, node_b = random.sample(available_nodes, 2)
        target_edge = (node_a, node_b)
        if target_edge in self.graph:
            target_rel = self.graph[target_edge]
        else:
            flip = random.uniform(0, 1) > 0.5
            target_rel = self.sample_predicate(reverse=flip)
        self.target = {target_edge: target_rel}
        return node_a, node_b, target_rel

    def get_unique_used_rules(self):
        rules = self.used_rules
        rules = set([tuple(rule[0]) for rule in rules])
        return rules

    def get_rules_count(self):
        return Counter([tuple(rule[0]) for rule in self.used_rules])

    def print_graph(self):
        print(self.graph)

    def get_all_nodes(self):
        return list(set([a for k, v in self.graph.items() for a in list(k)]))

    def convert_nx_graph(self, directed=True):
        """
        Build a networkx graph holding the edges of ``self.graph``
        (relations are not copied over)
        :param directed: build a ``DiGraph`` when True, a ``Graph`` otherwise
        :return: the networkx graph
        """
        graph_cls = nx.DiGraph if directed else nx.Graph
        G = graph_cls()
        for edge in self.graph.keys():
            source, target = edge[0], edge[1]
            G.add_edge(source, target)
        return G

    def draw_nx_graph(self):
        """
        Draw the networkx graph and save it as ``graph.png`` in the current
        working directory
        :return:
        """
        G = self.convert_nx_graph()
        f = plt.figure()
        try:
            nx.draw(G, ax=f.add_subplot(111), node_size=5)
            f.savefig("graph.png")
        finally:
            # pyplot keeps every figure alive until it is closed explicitly;
            # release it so repeated calls do not pile up open figures
            plt.close(f)

    def get_all_shortest_paths(self, directed=True):
        """
        Get all shortest paths between the two nodes of the target edge,
        using NetworkX, expressed as lists of relations.

        NOTE: ``directed`` is accepted but not forwarded; the NetworkX graph
        is always built with the default setting of ``convert_nx_graph``.
        :return: one list of relations per shortest path
        """
        G = self.convert_nx_graph()
        source, sink = list(self.target)[0]
        edge_paths = []
        for node_path in nx.all_shortest_paths(G, source=source, target=sink):
            # consecutive node pairs are the edges walked by this path
            edge_paths.append(list(zip(node_path[:-1], node_path[1:])))
        return [[self.graph[step] for step in edge_path] for edge_path in edge_paths]

    def solve_graph(self):
        """
        Given the generated graph, return the proof trace to solve given
        the target
        basically find all shortest paths in the graph
        :return:
        """
        rel_paths = self.get_all_shortest_paths(directed=True)
        # print(rel_paths)
        # for now, assume we are not deleting any edge. Later, this condition will not hold
        assert len(rel_paths) > 0

    def get_node_id(self, node):
        return "{}.{}".format(self.id, node)
        # return '{}'.format(node)

    def __eq__(self, other):
        # self_graph = self.convert_nx_graph()
        # other_graph = other.convert_nx_graph()
        # return nx.is_isomorphic(self_graph, other_graph)
        self_rel_paths = self.get_all_shortest_paths()
        self_rel_paths = [" ".join(path) for path in self_rel_paths]
        other_rel_paths = other.get_all_shortest_paths()
        other_rel_paths = [" ".join(path) for path in other_rel_paths]
        return set(self_rel_paths) == set(other_rel_paths)

    def __hash__(self):
        self_rel_paths = self.get_all_shortest_paths()
        self_rel_paths = set([" ".join(path) for path in self_rel_paths])
        return sum([hash(p) for p in self_rel_paths])

    def jsonify(self, is_meta_graph=False):
        edges = []
        for body, rel in self.graph.items():
            # edges.append([body[0],body[1], rg.rel2id[rel]])
            edges.append([body[0], body[1], rel])
        # write query
        if not is_meta_graph:
            edge = list(self.target.keys())[0]
            label = self.target[edge]
        return {
            "edges": edges,
            "query": [edge[0], edge[1], label] if not is_meta_graph else [0, 0, 0],
            "rules": self.used_rules,
            "resolution_path": self.resolution_path if not is_meta_graph else [],
        }


class GraphGen(BaseGraphGen):
    """ Pathwise graph generator
    """

    def __init__(self, rule_gen: RuleGen, args, id=0, done_nodes=None, gen_graph=True):
        """
        Initialise the base graph, corrupt the rules and optionally generate
        the clean graph.

        :param rule_gen: rule generator, passed on to the base class; its
            ``corrupt_prob`` is read by ``corrupt``
        :param args: arguments, passed on to the base class
        :param id: identifier of this graph
        :param done_nodes: nodes that are done adding edges
        :param gen_graph: when True, ``gen_clean_graph`` is run and its result
            is stored in ``self.graph_generated``
        """
        # the constructor arguments belong to __init__, not to super() itself:
        # super(rule_gen, args, ...) raises a TypeError before anything is set up
        super().__init__(rule_gen, args, id, done_nodes, gen_graph)
        # corruption
        self.corrupt_mapping = {}
        self.num_cor_preds = 0
        self.corrupt()
        if gen_graph:
            self.graph_generated = self.gen_clean_graph()

    def gen_clean_graph(self):
        """
        Generate the graph
        1. sample two nodes and a relation between them; this becomes the
           target edge (kept in ``self.target``, it is not added to the graph
           here) and is expanded into a sampled rule body
        2. in every later step sample an edge of the graph, remove it together
           with its inverse, and expand it into a sampled rule body

        An expansion routes the two relations of the body through a node that
        has not been used yet.
        :return: True when all ``args.expand_steps`` expansions were done,
            False when sampling kept failing or no unused node was left
        """
        # expand_steps = random.choice(range(1,self.args.expand_steps))
        expand_steps = self.args.expand_steps
        for i in range(expand_steps):
            if i == 0:
                tr = 0
                while True:
                    try:
                        tr += 1
                        node_a, node_b, target_rel = self.set_target()
                        body, rule = self.sample_body(target_rel)
                        break
                    except Exception:
                        # sampling failures are retried; Exception (not a bare
                        # except) keeps Ctrl-C / SystemExit working
                        if tr > 5:
                            # too many tries, leave generation
                            return False
                        continue
            else:
                tr = 0
                while True:
                    try:
                        tr += 1
                        edge = self.sample_edge()
                        body, rule = self.sample_body(self.graph[edge])
                        break
                    except Exception:
                        if tr > 5:
                            # too many tries, leave generation
                            return False
                        continue
                self.remove_both_edges(edge)
                node_a, node_b = edge
            self.done_nodes.add(node_a)
            self.done_nodes.add(node_b)
            available_nodes = list(set(self.nodes) - self.done_nodes)
            if len(available_nodes) > 0:
                # add
                new_node = random.choice(available_nodes)
                self.done_nodes.add(new_node)
                self.add_edge((node_a, new_node), body[0])
                self.add_edge((new_node, node_b), body[1])
                self.used_rules.append(rule)
                if self.args.bidirectional:
                    # add inverses
                    if body[0] in self.unit_rules:
                        self.add_edge((new_node, node_a), self.unit_rules[body[0]])
                    if body[1] in self.unit_rules:
                        self.add_edge((node_b, new_node), self.unit_rules[body[1]])
            else:
                # no more nodes left, expansion couldn't succeed
                return False
        # print("Generation done")
        return True

    def gen_world_graph(self, num_cycles=50):
        """
        Generate the "world" graph
        - With the same rule set, generate a graph where we do not remove the answers
        - should set the number of nodes quite high, > 1000
        - generate as many cycles till the usage of the rules hit the number of unique rules available

        Each pass of the outer loop is one cycle of ``expand_steps``
        expansions; ``expand_steps`` starts at ``args.expand_steps`` and is
        decreased / increased by one after every cycle.
        :param num_cycles: not used by the current implementation
        :return: True once every compositional body reported by
            ``rule_gen.get_compositional_bodies()`` has been used, False when
            the very first target could not be sampled
        """
        pt = 0  # expansions done over all cycles
        expand_steps = self.args.expand_steps
        num_unique_rules = len(self.rule_gen.get_compositional_bodies())
        change_mode = "down"  # up or down
        ni = 0  # cycles completed
        continue_main = False
        while len(self.get_unique_used_rules()) < num_unique_rules:
            # expand_steps = random.choice(range(2, self.args.expand_steps))
            cycle_edges = []  # edges generated in this cycle
            ct = 0  # expansions done in this cycle
            target_edge = None
            for i in range(expand_steps):
                if pt == 0:
                    # very first expansion: sample the target and expand it
                    tr = 0
                    while True:
                        try:
                            tr += 1
                            node_a, node_b, target_rel = self.set_target()
                            body, rule = self.sample_body(target_rel)
                            target_edge = ((node_a, node_b), target_rel)
                            break
                        except Exception:
                            # sampling failures are retried; Exception (not a
                            # bare except) keeps Ctrl-C / SystemExit working
                            if tr > 50:
                                # too many tries, leave generation
                                return False
                            continue
                else:
                    tr = 0
                    while True:
                        try:
                            tr += 1
                            if ct == 0:
                                # sample an edge from the entire graph to begin recursive generation
                                edge = self.sample_edge()
                            else:
                                # sample an edge from the current cycle
                                edge = random.choice(cycle_edges)
                            body, rule = self.sample_body(self.graph[edge])
                            break
                        except Exception:
                            if tr > 50:
                                # too many tries, leave generation
                                # return False
                                continue_main = True
                                break
                            continue
                    if continue_main:
                        # give up on this step only; the for loop moves on
                        continue_main = False
                        continue
                    if ct > 0:
                        # remove the edges from the second creation
                        self.remove_both_edges(edge)
                        cycle_edges.remove(edge)
                    node_a, node_b = edge
                self.done_nodes.add(node_a)
                self.done_nodes.add(node_b)
                pt += 1
                ct += 1
                # add
                avlb_nodes = list(set(self.nodes) - self.done_nodes)
                if len(avlb_nodes) == 0:
                    # every node is used: grow the node list by one
                    new_node = len(self.nodes)
                    self.nodes.append(new_node)
                else:
                    new_node = random.choice(avlb_nodes)
                self.done_nodes.add(new_node)
                self.add_edge((node_a, new_node), body[0])
                cycle_edges.append((node_a, new_node))
                self.add_edge((new_node, node_b), body[1])
                cycle_edges.append((new_node, node_b))
                self.used_rules.append(rule)
                # add inverses
                if body[0] in self.unit_rules:
                    self.add_edge((new_node, node_a), self.unit_rules[body[0]])
                    cycle_edges.append((new_node, node_a))
                if body[1] in self.unit_rules:
                    self.add_edge((node_b, new_node), self.unit_rules[body[1]])
                    cycle_edges.append((node_b, new_node))
            if ni == 0:
                # the world graph keeps the answer: add the target edge itself
                self.add_edge(target_edge[0], target_edge[1])
            ni += 1

            if expand_steps == 2:
                change_mode = "up"
            if expand_steps == self.args.expand_steps:
                change_mode = "down"
            if change_mode == "down":
                expand_steps -= 1
            else:
                expand_steps += 1
            # print(change_mode)
            # print(expand_steps)
            # print(len(self.get_unique_used_rules()))

        # print("Generation done")
        return True

    def gen_noise(self):
        """
        Generate noise - similar to CLUTRR
            - Supporting
            - Irrelevant
            - Disconnected

        A number of noise steps is drawn from ``2 .. args.expand_steps``; every
        round tries one noise edge of each kind and counts the ones that
        succeeded, until the steps are used up or a round adds nothing.
        :return:
        """

        def add_supporting_edge():
            """
            Add a cycle by adding two new nodes in the graph
            - Select any node in the graph
            - Add two nodes and form a cycle
            NOTE: the first edge is added before the rule body is sampled, so
            it stays in the graph when a later step fails.
            :return: True on success, False otherwise
            """
            try:
                available_nodes = list(set(self.nodes) - self.done_nodes)
                # explicit guard instead of an assert, which is stripped under -O
                if len(available_nodes) < 2:
                    return False
                extra_node = random.choice(available_nodes)
                graph_node = random.choice(list(self.done_nodes))
                edge = (graph_node, extra_node)
                flip = random.uniform(0, 1) > 0.5
                rel = self.sample_predicate(reverse=flip)
                self.add_edge((graph_node, extra_node), rel)
                body, rule = self.sample_body(self.graph[edge])
                node_a, node_b = edge
                self.done_nodes.add(extra_node)
                new_node = random.choice(list(set(self.nodes) - self.done_nodes))
                self.add_edge((node_a, new_node), body[0])
                self.add_edge((new_node, node_b), body[1])
                return True
            except Exception:
                # noise is best effort: a failed sample just means "not added"
                return False

        def add_irrelevant_edge():
            """
            Add a dangling edge to any node in the graph
            :return: True on success, False otherwise
            """
            try:
                available_nodes = list(set(self.nodes) - self.done_nodes)
                if len(available_nodes) < 1:
                    return False
                extra_node = random.choice(available_nodes)
                graph_node = random.choice(list(self.done_nodes))
                flip = random.uniform(0, 1) > 0.5
                rel = self.sample_predicate(reverse=flip)
                self.add_edge((graph_node, extra_node), rel)
                # consider the extra node as part of main graph
                self.done_nodes.add(extra_node)
                return True
            except Exception:
                return False

        def add_disconnected_edge():
            """
            Add separate disconnected edges in the graph
            (only the first of the two nodes is marked as done)
            :return: True on success, False otherwise
            """
            try:
                available_nodes = list(set(self.nodes) - self.done_nodes)
                if len(available_nodes) < 2:
                    return False
                node_a, node_b = random.sample(available_nodes, 2)
                flip = random.uniform(0, 1) > 0.5
                rel = self.sample_predicate(reverse=flip)
                self.add_edge((node_a, node_b), rel)
                self.done_nodes.add(node_a)
                return True
            except Exception:
                return False

        # max noisy edges should be equal to expand steps + 1
        noise_steps = random.choice(range(2, self.args.expand_steps + 1))
        while noise_steps > 0:
            s = add_supporting_edge()
            i = add_irrelevant_edge()
            d = add_disconnected_edge()
            if s or i or d:
                if s:
                    noise_steps -= 1
                if i:
                    noise_steps -= 1
                if d:
                    noise_steps -= 1
            else:
                # nothing could be added in this round: stop trying
                break

    def corrupt(self):
        """
        Maintain a mapping between the true predicates and their display
        predicates and rewrite ``self.rules`` through it.

        For every predicate a draw against ``rule_gen.corrupt_prob[pred]``
        decides whether it keeps mapping onto itself or onto another predicate
        of the same group (positive / negative) that is not already the
        display predicate of a corrupted one.
        - first create the mapping directory
        - then modify self.rules w.r.t the mapping directory
        :return:
        :raises AssertionError: if a predicate has no replacement candidate
        """
        mapping = {}
        for group in (self.pos_predicates, self.neg_predicates):
            for pred in group:
                taken = [shown for true, shown in mapping.items() if true != shown]
                candidates = [
                    other
                    for other in list(copy.deepcopy(group))
                    if other != pred and other not in taken
                ]
                if not candidates:
                    raise AssertionError("what")
                cor_prob = self.rule_gen.corrupt_prob[pred]
                keep = random.uniform(0, 1) > cor_prob
                # either keep the same or select one from the candidates
                mapping[pred] = pred if keep else random.choice(candidates)
        # now apply
        rewritten = {}
        for body, head in self.rules.items():
            shown_head = mapping[head]
            if type(body) is tuple:
                shown_body = (mapping[body[0]], mapping[body[1]])
            else:
                shown_body = mapping[body]
            rewritten[shown_body] = shown_head
        self.rules = rewritten
        self.corrupt_mapping = mapping
        self.num_cor_preds = sum(1 for true, shown in mapping.items() if true != shown)

    def contains_corrupted_predicate(self):
        """
        Tell whether any key part or value of ``self.graph`` is one of the
        predicates that ``corrupt_mapping`` maps onto a different predicate
        :return: True if such a predicate is found, False otherwise
        """
        corrupted = [true for true, shown in self.corrupt_mapping.items() if true != shown]
        if not self.num_cor_preds > 0:
            return False
        for key, value in self.graph.items():
            if value in corrupted:
                return True
            if type(key) is tuple:
                if key[0] in corrupted or key[1] in corrupted:
                    return True
            elif key in corrupted:
                return True
        return False


class BigGraphGen(BaseGraphGen):
    """ Generate a Massive Graph first, then sample smaller graphs from it
    """

    def __init__(
        self, rule_gen: RuleGen = None, args=None, id=0, done_nodes=None, gen_graph=True
    ):
        """
        Initialise the base generator state and optionally build the big graph.

        :param rule_gen: rule generator forwarded to ``BaseGraphGen``; may be None
        :param args: arguments forwarded to ``BaseGraphGen``
        :param id: identifier forwarded to ``BaseGraphGen``
        :param done_nodes: nodes that are done adding edges (forwarded)
        :param gen_graph: when True, ``gen_big_graph`` is run immediately
        """
        super().__init__(rule_gen, args, id, done_nodes, gen_graph)
        # iterator over targets, created lazily by get_next_target
        self.target_cycle = None
        if gen_graph:
            # holds whatever gen_big_graph returns; only set when gen_graph is True
            self.graph_generated = self.gen_big_graph()

    def get_common_nodes(self, edge):
        """
        Intersect the ``source_lookup`` entry of the edge's first node with the
        ``target_lookup`` entry of its last node.
        :param edge: sequence whose first and last items are node ids
        :return: list of the shared entries (order is not guaranteed)
        """
        from_first = set(self.source_lookup[edge[0]])
        into_last = set(self.target_lookup[edge[-1]])
        return list(from_first & into_last)

    def gen_big_graph(self, max_nodes=1000, max_cycles=10):
        """
        Generate the massive graph by repeatedly expanding sampled target edges.

        Every pass of the outer loop draws a number of expansion steps from
        ``range(2, self.args.expand_steps)``. The first step creates a fresh
        target edge through ``set_target``; each later step re-uses one of the
        edges added during this pass. In both cases a rule body is sampled for
        the edge's relation and, if an unused node is available, the edge
        ``(a, b)`` is resolved into ``(a, new)`` and ``(new, b)``.

        Generation stops once every compositional rule has been used in
        ``max_cycles`` rounds, or once more than ``max_nodes`` nodes are done.

        :param max_nodes: stop after more than this many nodes are done
        :max_cycles: max number of times each min rule is allowed to be present
        :return: None
        :raises IndexError: from ``random.choice`` when ``self.args.expand_steps``
            is 2 or less, because the step range is then empty
        """
        pb = tqdm(total=100)
        try:
            while True:
                break_generation = False
                break_cause = ""
                expand_steps = random.choice(range(2, self.args.expand_steps))
                added_edges = []
                # one cycle is one resolution step
                for i in range(expand_steps):
                    if i == 0:
                        tr = 0
                        while True:
                            try:
                                tr += 1
                                node_a, node_b, target_rel = self.set_target(
                                    from_available=False
                                )
                                self.add_edge((node_a, node_b), target_rel)
                                added_edges.append((node_a, node_b))
                                body, rule = self.sample_body(
                                    target_rel, with_replacement=False
                                )
                                break
                            except Exception as e:
                                if tr > 5:
                                    # too many tries, leave generation
                                    break_generation = True
                                    break_cause = repr(e)
                                    break
                                continue
                    else:
                        tr = 0
                        while True:
                            try:
                                tr += 1
                                edge = random.choice(added_edges)
                                body, rule = self.sample_body(
                                    self.graph[edge], with_replacement=False
                                )
                                break
                            # Exception rather than a bare except, so that
                            # KeyboardInterrupt / SystemExit still propagate
                            except Exception:
                                if tr > 5:
                                    # too many tries, leave generation
                                    break_generation = True
                                    break_cause = "too many tries 2"
                                    break
                                continue
                        node_a, node_b = edge
                    if break_generation:
                        break
                    self.done_nodes.add(node_a)
                    self.done_nodes.add(node_b)
                    available_nodes = list(set(self.nodes) - self.done_nodes)
                    if len(available_nodes) > 0:
                        # resolve (node_a, node_b) through a fresh middle node
                        new_node = random.choice(available_nodes)
                        self.done_nodes.add(new_node)
                        self.add_edge((node_a, new_node), body[0])
                        self.add_edge((new_node, node_b), body[1])
                        added_edges.append((node_a, new_node))
                        added_edges.append((new_node, node_b))
                        self.used_rules.append(rule)
                    else:
                        # no more nodes left, expansion couldn't succeed
                        break_cause = "no more nodes left"
                        break
                rules_used = len(self.get_unique_used_rules())
                total_rules = len(self.rule_gen.get_compositional_bodies())
                # progress is reported as the share of max_nodes already done
                pb.update(int((len(self.done_nodes) / max_nodes) * 100) - pb.n)
                pb.set_description(
                    "N = {} / {}, E = {}, Used Rules = {}, Total Rules = {}, Cycles Left : {}, Cause: {}".format(
                        len(self.done_nodes),
                        max_nodes,
                        len(self.graph),
                        rules_used,
                        total_rules,
                        max_cycles,
                        break_cause,
                    )
                )
                if rules_used == total_rules:
                    # a full round of rules is complete; start the next one
                    max_cycles -= 1
                    if max_cycles == 0:
                        break
                    self.rule_gen.copy_probs()
                    self.used_rules = []
                if len(self.done_nodes) > max_nodes:
                    break
                if rules_used == 0:
                    self.rule_gen.copy_probs()
        finally:
            # close the progress bar even when generation raises
            pb.close()
        print("Generation done")

    def make_almost_complete(self, nodes):
        """
        Densify the graph over ``nodes``.

        For every ordered pair without an edge, look at the intermediate nodes
        returned by ``get_common_nodes``; whenever the two relations along such
        a detour form a body in ``self.rule_gen.D``, its head is a candidate.
        One candidate is picked at random and added as the direct edge.
        :param nodes: nodes to connect
        """
        progress = tqdm(total=len(nodes))
        for src in nodes:
            for dst in nodes:
                if src == dst or (src, dst) in self.graph:
                    continue
                if src not in self.source_lookup or dst not in self.target_lookup:
                    continue
                heads = []
                for mid in self.get_common_nodes((src, dst)):
                    body = (self.graph[(src, mid)], self.graph[(mid, dst)])
                    if body in self.rule_gen.D:
                        heads.append(self.rule_gen.D[body])
                if heads:
                    self.add_edge((src, dst), random.choice(heads))
            progress.update(1)
            progress.set_description(
                "N = {}, E = {}".format(len(self.done_nodes), len(self.graph))
            )
        progress.close()

    def compute_target2edges(self):
        """Group the edges of ``self.graph`` by relation and store them in ``self.t2e``."""
        grouped = self.t2e = {}
        for edge, rel in self.graph.items():
            grouped.setdefault(rel, []).append(edge)

    def get_path_descriptor(self, path):
        """
        Build the comma-joined relation string for a node path.
        e.g. [1,2,3] -> "R_1+,R_0-"
        :param path: sequence of nodes; every consecutive pair must be an edge
            of ``self.graph``
        :return: relations of the consecutive edges joined by ","
        """
        hops = zip(path, path[1:])
        return ",".join(self.graph[hop] for hop in hops)

    def pre_compute_paths(self, max_len=10):
        """ Precompute all simple paths
        1. For every relation R, take each edge labelled R
        2. Compute all simple paths between its endpoints, up to ``max_len`` edges
        3. Save in the dictionary of R : { desc: [list of node paths] }

        Paths consisting of a single edge are skipped, and so are paths that
        mix positive and negative relations. The result is stored in
        ``self.path_dict``; ``self.G`` and ``self.t2e`` are refreshed as well.

        :param max_len: cutoff handed to ``nx.all_simple_paths``
        """
        self.G = self.convert_nx_graph()
        self.compute_target2edges()
        path_dict = {}
        # previously hard-coded to 10, which silently ignored max_len
        path_cutoff = max_len
        pb = tqdm(total=len(self.t2e))
        pb.set_description("C : {}".format(path_cutoff))
        for target, edges in self.t2e.items():
            by_desc = path_dict[target] = {}
            for edge in edges:
                for node_path in nx.all_simple_paths(
                    self.G, source=edge[0], target=edge[1], cutoff=path_cutoff
                ):
                    if len(node_path) - 1 == 1:  # skipping paths which is one edge
                        continue
                    path_desc = self.get_path_descriptor(node_path)
                    # only allow paths of the same sign; a relation is positive
                    # when its name contains "+", as in __init__. Taking
                    # split("_")[-1] would give e.g. "1+" for "R_1+" and so
                    # compare relation ids instead of signs.
                    signs = {
                        "+" if "+" in rel else "-" for rel in path_desc.split(",")
                    }
                    if len(signs) != 1:
                        continue
                    by_desc.setdefault(path_desc, []).append(node_path)
            pb.update(1)
        pb.close()
        self.path_dict = path_dict
        print("all paths computed")

    def split_train_test_descriptors(self, split_per=0.8):
        """ Split path descriptors in train, valid and test
        split such that the target is equally represented

        The descriptors of each target are split into (train + valid) / test
        and then train / valid, both with ``train_size=split_per``. Targets
        without descriptors are left out. The result is stored in
        ``self.descriptor_splits`` as target -> {"train", "valid", "test"}.

        ``train_test_split`` raises ``ValueError`` when a side of the split
        would be empty, which happens for very small descriptor sets. Such a
        split is skipped and the affected descriptors stay on the training
        side, so the remaining lists are simply empty.

        :param split_per: ``train_size`` passed to both splits
        """
        splits_d = {}
        for target, p in self.path_dict.items():
            desc = list(p.keys())
            if not desc:
                continue
            try:
                tv, test = train_test_split(desc, shuffle=True, train_size=split_per)
            except ValueError:
                # too few descriptors to carve out a test set
                tv, test = desc, []
            try:
                train, valid = train_test_split(tv, shuffle=True, train_size=split_per)
            except ValueError:
                # too few descriptors to carve out a validation set
                train, valid = tv, []
            splits_d[target] = {"train": train, "valid": valid, "test": test}
        self.descriptor_splits = splits_d

    def split_train_test_descriptors_clutrr(self, split_per=0.8):
        """Split path descriptors in train, valid and test, CLUTRR style.

        Per target, descriptors are grouped by their number of relations.
        All descriptors of length 2 go to train. Every other group with more
        than one descriptor is split into train / test with
        ``train_size=split_per``; valid is then sampled from that train part
        (the sampled descriptors also remain in train). Groups holding a single
        descriptor are not assigned to any split.
        The result is stored in ``self.descriptor_splits``.
        """
        splits_d = {}
        for target, desc_paths in self.path_dict.items():
            descriptors = list(desc_paths.keys())
            if not descriptors:
                continue
            by_len = {}  # length -> descriptors
            for descriptor in descriptors:
                by_len.setdefault(len(descriptor.split(",")), []).append(descriptor)
            bucket = {"train": [], "valid": [], "test": []}
            splits_d[target] = bucket
            for hops, group in by_len.items():
                if hops == 2:
                    bucket["train"].extend(group)
                    continue
                if len(group) <= 1:
                    continue
                train, test = train_test_split(
                    group, shuffle=True, train_size=split_per
                )
                valid = random.sample(train, int(len(train) * (1 - split_per)))
                bucket["train"].extend(train)
                bucket["valid"].extend(valid)
                bucket["test"].extend(test)
        self.descriptor_splits = splits_d

    def split_with_known_test_descriptors(self, split_per=0.8, test_descriptors=None):
        """
        Re-split all descriptors so that the given ones always land in test.

        Per target, descriptors are grouped by their number of relations.
        Length-2 descriptors all go to train. In every other group the
        descriptors listed in ``test_descriptors`` go to test; the remaining
        ones are split into train / valid with ``train_size=split_per`` when
        there is more than one, and go to train when there is exactly one.
        The result is stored in ``self.descriptor_splits``.

        :param split_per: ``train_size`` of the train / valid split
        :param test_descriptors: descriptors reserved for test; ``None`` is
            treated as "none reserved"
        """
        # a set gives constant-time membership tests and makes None usable
        known_test = set(test_descriptors) if test_descriptors is not None else set()
        splits_d = {}
        for target, p in self.path_dict.items():
            desc = list(p.keys())
            if not desc:
                continue
            desc_lens = {}  # length -> descriptors
            for d in desc:
                desc_lens.setdefault(len(d.split(",")), []).append(d)
            bucket = {"train": [], "valid": [], "test": []}
            splits_d[target] = bucket
            for dlen, len_ds in desc_lens.items():
                if dlen == 2:
                    bucket["train"].extend(len_ds)
                    continue
                test = [ld for ld in len_ds if ld in known_test]
                # filtered in place rather than through a set difference, so the
                # order handed to the splitter does not depend on string hashing
                rest = [ld for ld in len_ds if ld not in known_test]
                if len(rest) > 1:
                    train, valid = train_test_split(
                        rest, shuffle=True, train_size=split_per
                    )
                    bucket["train"].extend(train)
                    bucket["valid"].extend(valid)
                elif rest:
                    bucket["train"].extend(rest)
                bucket["test"].extend(test)
        self.descriptor_splits = splits_d

    def get_edge_dict_nx(self, nx_graph):
        """
        Look up the relation of every edge of ``nx_graph`` in ``self.graph``.
        :param nx_graph: object whose ``edges()`` yields edge keys of ``self.graph``
        :return: dict edge -> relation
        """
        return {edge: self.graph[edge] for edge in nx_graph.edges()}

    def get_next_target(self):
        """ Return the next target, cycling endlessly over the keys of
        ``self.descriptor_splits``; the cycle is created on first use
        """
        cycle = self.target_cycle
        if not cycle:
            cycle = itertools.cycle(self.descriptor_splits.keys())
            self.target_cycle = cycle
        return next(cycle)

    def get_next_sampled_graph(
        self,
        mode="train",
        pc=0.8,
        num_n_min=1,
        num_n_max=5,
        gamma=0.8,
        delete_reverse_paths=True,
        choose_used_descriptor=False,
    ):
        """
        Sample and return a path by randomly choosing a descriptor from the mode
        graph choosing strategy : 
            P_c = probability of choosing the neighbor of node
            num_n_min = min number of neighbors to choose from
            num_n_max = max number of neighbors to choose from
            gamma = value to decrease P_c with for the next branch
            choose_used_descriptor = make the task super easy, where during validation
            and testing, we choose descriptors which have been already used in training

        """
        if not hasattr(self, "G"):
            self.G = self.convert_nx_graph()
        # choose a target uniformly
        desc = []
        while len(desc) == 0:
            target = self.get_next_target()
            desc = self.descriptor_splits[target][mode]
        if (
            choose_used_descriptor
            and target in self.used_train_descriptors
            and len(self.used_train_descriptors[target]) > 0
        ):
            sampled_desc = random.choice(self.used_train_descriptors[target])
        else:
            sampled_desc = random.choice(desc)  # uniformly choose a descriptor
            if target not in self.used_train_descriptors:
                self.used_train_descriptors[target] = []
            self.used_train_descriptors[target].append(sampled_desc)
        path = random.choice(self.path_dict[target][sampled_desc])
        query = (path[0], path[-1])
        queue = []
        sampled_path = []
        b_pc = pc
        for node in path:
            pc = b_pc
            queue.append(node)
            while len(queue) > 0:
                nd = queue.pop(0)
                sampled_path.append(nd)
                if random.uniform(0, 1) < pc:
                    nbrs = [
                        nbr
                        for nbr in self.G[nd]
                        if (nbr not in path) or (nbr not in sampled_path)
                    ]
                    num_nbrs_to_choose = random.choice(range(num_n_min, num_n_max))
                    num_nbrs_to_choose = min(num_nbrs_to_choose, len(nbrs))
                    nbrs = random.sample(nbrs, num_nbrs_to_choose)
                    queue.extend(nbrs)
                pc = pc * gamma
        subgraph = nx.DiGraph(self.G.subgraph(sampled_path))
        # make sure to delete all possible paths of length k less than the resoluition path
        resolution_path_len = len(path)

        def del_paths(query, forward=False):
            all_simple_paths = nx.all_simple_paths(
                subgraph, source=query[0], target=query[1], cutoff=resolution_path_len
            )
            edges_to_del = []
            for sp in all_simple_paths:
                for spi in range(len(sp) - 1):
                    edges_to_del.append((sp[spi], sp[spi + 1]))
            edges_to_del = list(set(edges_to_del))
            path_edges = [(path[spi], path[spi + 1]) for spi in range(len(path) - 1)]
            edges_to_del = [e for e in edges_to_del if e not in path_edges]
            cutoff = 0
            if forward:
                cutoff = 1  # (keeping the resolution_path)
            while (
                len(
                    list(
                        nx.all_simple_paths(
                            subgraph,
                            source=query[0],
                            target=query[1],
                            cutoff=resolution_path_len,
                        )
                    )
                )
            ) > cutoff:
                edge2del = random.choice(edges_to_del)
                subgraph.remove_edge(edge2del[0], edge2del[1])
                edges_to_del.remove(edge2del)
                if len(edges_to_del) == 0:
                    break
            # subgraph.remove_edges_from(edges_to_del)

        del_paths(query, forward=True)
        reverse_query = (query[-1], query[0])
        if delete_reverse_paths:
            del_paths(reverse_query)
        sampled_graph_dict = self.get_edge_dict_nx(subgraph)
        # drop the target edge and its reverse
        if query in sampled_graph_dict:
            del sampled_graph_dict[query]
        if reverse_query in sampled_graph_dict:
            del sampled_graph_dict[reverse_query]
        if len(sampled_graph_dict) == 0:
            import ipdb

            ipdb.set_trace()
        sampled_graph = self.copy_edges(
            sampled_graph_dict, target_edge=query, resolution_path=path
        )

        return sampled_graph

    def get_world_graph(self, max_edges_per=1.0):
        """
        The idea of the world graph is that it captures some representative sample of the world
        A fraction ``max_edges_per`` of the edges is sampled uniformly without
        replacement and the graph induced by those edges is returned; no
        connected-component filtering is applied.

        :param max_edges_per: fraction of ``self.graph`` edges to keep
        :return: result of ``copy_edges`` on the sampled edges
        """
        max_edges = int(len(self.graph) * max_edges_per)
        sampled_edges = random.sample(list(self.graph), max_edges)
        sG = nx.DiGraph()
        for edge in sampled_edges:
            sG.add_edge(edge[0], edge[1])
        # nx.info() no longer exists in NetworkX 3.x; report the size directly
        print(
            "DiGraph with {} nodes and {} edges".format(
                sG.number_of_nodes(), sG.number_of_edges()
            )
        )
        sampled_graph = self.copy_edges(self.get_edge_dict_nx(sG))
        return sampled_graph

    def stringify_keys(self, d):
        """
        Copy ``d`` with every tuple key replaced by the string "first:second"
        built from its first two items; other keys are kept as they are.
        :param d: dict to convert
        :return: new dict
        """
        return {
            ("{}:{}".format(key[0], key[1]) if type(key) is tuple else key): value
            for key, value in d.items()
        }

    def unstringify_keys(self, d):
        """
        Copy ``d`` with every key containing ":" turned back into a 2-tuple.

        The key is split on ":", all-digit pieces become ints, and the first two
        pieces form the tuple. Keys without ":" are kept as they are.
        :param d: dict with string keys
        :return: new dict
        """

        def _coerce(token):
            return int(token) if token.isdigit() else token

        restored = {}
        for key, value in d.items():
            if ":" not in key:
                restored[key] = value
                continue
            pieces = list(map(_coerce, key.split(":")))
            restored[(pieces[0], pieces[1])] = value
        return restored

    def save(self, save_path=""):
        save_dict = {
            "args": self.args,
            "forward_edges": self.forward_edges,
            "backward_edges": self.backward_edges,
            "descriptor_splits": self.descriptor_splits,
            "done_nodes": list(self.done_nodes),
            "graph": self.stringify_keys(self.graph),
            "id": self.id,
            "pos_predicates": list(self.pos_predicates),
            "neg_predicates": list(self.neg_predicates),
            "path_dict": self.path_dict,
            "rules": self.stringify_keys(self.rules),
            "rule_prob": self.stringify_keys(self.rule_prob),
            "t2e": self.t2e,
            "unit_rules": self.unit_rules,
            "used_rules": self.used_rules,
            "world_mode": self.world_mode,
        }
        json.dump(save_dict, open(path.join(save_path, "graph_prop.json"), "w"))

    def load(self, load_path=""):
        """
        Restore the generator state from ``graph_prop.json`` inside ``load_path``.

        Reverses ``save``: the ":"-joined keys of ``graph``, ``rules`` and
        ``rule_prob`` become tuples again and the stored lists of nodes and
        predicates become sets. Other values are taken as decoded from JSON.
        :param load_path: directory containing the file
        """
        # the context manager closes the file once it has been parsed
        with open(path.join(load_path, "graph_prop.json")) as fp:
            load_dict = json.load(fp)
        self.args = Dict(load_dict["args"])
        self.forward_edges = load_dict["forward_edges"]
        self.backward_edges = load_dict["backward_edges"]
        self.descriptor_splits = load_dict["descriptor_splits"]
        self.done_nodes = set(load_dict["done_nodes"])
        self.graph = self.unstringify_keys(load_dict["graph"])
        self.id = load_dict["id"]
        self.pos_predicates = set(load_dict["pos_predicates"])
        self.neg_predicates = set(load_dict["neg_predicates"])
        self.path_dict = load_dict["path_dict"]
        self.rules = self.unstringify_keys(load_dict["rules"])
        self.rule_prob = self.unstringify_keys(load_dict["rule_prob"])
        self.t2e = load_dict["t2e"]
        self.unit_rules = load_dict["unit_rules"]
        self.used_rules = load_dict["used_rules"]
        self.world_mode = load_dict["world_mode"]
        # an existing cycle iterates the previous descriptor_splits; drop it so
        # that get_next_target rebuilds it from the loaded targets
        self.target_cycle = None


def _percentage(part, whole):
    """Return ``part / whole`` as a percentage, or 0.0 when ``whole`` is zero."""
    return (part / whole) * 100 if whole else 0.0


def stress_test():
    """
    Generate random rules and random graphs and solve them

    Every combination of the relation count, inverse ratio, node count and
    expansion steps listed below is run ``simulations`` times. For each
    successfully generated graph noise is added and the graph is solved;
    aggregate counts are printed at the end.
    :return:
    """
    num_rels = [4, 5, 6, 7, 8]
    per_inverses = [0.2, 0.5, 0.7]
    num_nodes = [10, 15, 20]
    expand_steps = [5, 6, 7, 8]
    simulations = 10
    gt = 0  # graphs counted
    nt = 0  # nodes over all counted graphs
    et = 0  # edges over all counted graphs
    rl = 0  # rules over all worlds
    ws = 0  # worlds generated
    preds = 0
    cor_preds = 0
    cor_graphs = 0
    missed = 0
    uniq_graphs = set()
    for _ in range(simulations):
        for num_rel, per_inverse, num_node, expand_step in itertools.product(
            num_rels, per_inverses, num_nodes, expand_steps
        ):
            args = get_args(
                "--num_rel {} --per_inverse {} --num_nodes {} --expand_steps {}".format(
                    num_rel, per_inverse, num_node, expand_step
                )
            )
            rg = RuleGen(args)
            rl += len(rg.D)
            ws += 1
            gg = GraphGen(rg, args)
            if gg.graph_generated:
                gg.gen_noise()
                gg.solve_graph()
                if gg not in uniq_graphs:
                    gt += 1
                    nt += len(gg.get_all_nodes())
                    et += len(gg.graph)
                    preds += len(gg.pos_predicates) + len(gg.neg_predicates)
                    cor_preds += gg.num_cor_preds
                    if gg.contains_corrupted_predicate():
                        cor_graphs += 1
                    uniq_graphs.add(gg)
            else:
                missed += 1
            print("{}\r\n".format(gt))

    print("{} simulations complete".format(simulations))
    print("{} Rules generated".format(rl))
    print("{} worlds generated".format(ws))
    print("{} graphs generated with {} nodes and {} edges".format(gt, nt, et))
    print("{} unique graphs generated".format(len(uniq_graphs)))
    print("{} graphs could not be generated".format(missed))
    # both ratios are guarded: nothing may have been generated at all
    print("{} percentage predicates corrupted".format(_percentage(cor_preds, preds)))
    print(
        "{} percentage of graphs contain at least one corrupted rule".format(
            _percentage(cor_graphs, gt)
        )
    )


# Script entry point: load a previously saved big graph from the current
# directory and sample graphs from it.
if __name__ == "__main__":
    args = get_args("--num_rel 10 --num_splits 2 --per_inverse 0.1 --num_nodes 100000")
    # Generation pipeline that produces the file read by load() below
    # (kept disabled):
    # rg = RuleGen(args)
    # rga = rg.create_sorted_bfs_kb()
    # args.expand_steps = 10
    # gg = BigGraphGen(rga, args, id=0, gen_graph=False)
    # gg.gen_big_graph(max_nodes=2000)
    # gg.pre_compute_paths()
    # gg.split_train_test_descriptors()
    # gg.save()
    # no rule generator and no generation: the state comes from load()
    gga = BigGraphGen(None, args, id=0, gen_graph=False)
    gga.load()
    num_graphs = 5000
    pb = tqdm(total=num_graphs)
    graphs = []
    # sample with the default settings and keep the jsonify() output of each graph
    for i in range(num_graphs):
        graphs.append(gga.get_next_sampled_graph().jsonify())
        pb.update(1)
    pb.close()
    world_graph = gga.get_world_graph()
    # Disabled debugging / plotting helpers:
    # G = gg.convert_nx_graph()
    # f = plt.figure()
    # nx.draw(G, ax=f.add_subplot(111))
    # f.savefig("graph.png")
    # print(G)
    # print(gg.target)
    # gg.print_graph()
    # print('solving')
    # gg.solve_graph()
    # stress_test()
