"""Prepare and export data here"""
import random
import os
from os.path import expanduser
import shutil

from lgw.graph_generator import GraphGen, BigGraphGen
from lgw.rule_generator import RuleGen
from lgw.args import get_args
from tqdm import tqdm
from natsort import natsorted
import json
import glob
from addict import Dict
import copy
import numpy as np
import pickle
import ray

# Import-time side effect: drop any Ray runtime this process is already attached
# to, then initialise Ray again so the @ray.remote methods defined below have a
# runtime to submit tasks to. NOTE(review): this runs on every import of the
# module, not only when it is executed as a script.
ray.shutdown()
ray.init()


def create_dir(dirname):
    """Create a directory (and any missing parents) if it does not exist already

    The creation is done in a single call instead of "check, then mkdir" so that
    several workers creating the same folder at the same time cannot trip over
    each other.

    :param dirname: path of the directory to create
    :raises FileExistsError: if ``dirname`` exists but is not a directory
    """
    os.makedirs(dirname, exist_ok=True)


def create_new_dir(dirname):
    """Recreate ``dirname`` as an empty directory, wiping any previous content"""
    stale = os.path.exists(dirname)
    if stale:
        # remove the old tree first so the directory starts out empty
        shutil.rmtree(dirname)
    os.mkdir(dirname)


class LogicalGraphWorld:
    """
    Main class to generate and save different rule based worlds
    """

    def __init__(self, args):
        """
        Store the run arguments and parse the comma separated choice lists
        :param args: parsed argument namespace; must expose ``save_path`` and the
            ``*_choices`` strings read below
        """
        self.args = Dict(vars(args))
        # paths that do not start with "/" are taken relative to the home folder;
        # startswith (rather than indexing [0]) also copes with an empty string
        if args.save_path.startswith("/"):
            self.save_path = args.save_path
        else:
            self.save_path = os.path.join(expanduser("~"), args.save_path)
        # each "*_choices" argument is a comma separated list of candidate values
        self.num_rel_n = [int(s) for s in args.num_rel_choices.split(",")]
        self.per_inverse_n = [float(s) for s in args.per_inverse_choices.split(",")]
        self.corrupt_eps_n = [float(s) for s in args.corrupt_eps_choices.split(",")]
        self.expand_steps_n = [int(s) for s in args.expand_steps_choices.split(",")]
        self.saved_configs = []
        self.max_rule_id = 0
        self.rule_heads = []  # store a copy of the generated rule heads
        self.all_worlds = []

    def get_rule_dirs(self, mode="train"):
        """
        List the rule sub-directories stored under ``save_path/mode``
        :param mode: name of the split folder to inspect
        :return: list of directory names, empty if the folder cannot be walked
        """
        mode_root = os.path.join(self.save_path, mode)
        # os.walk yields nothing for a missing or unreadable folder, so ask for
        # a default instead of catching every possible exception
        top_level = next(os.walk(mode_root), None)
        if top_level is None:
            return []
        return top_level[1]

    def get_max_rule_id(self):
        """
        Advance the in-memory rule counter by one and hand back the new value
        :return: the incremented ``max_rule_id``
        """
        # the counter lives on the instance; saved rule folders are not consulted
        next_id = self.max_rule_id + 1
        self.max_rule_id = next_id
        return next_id

    def get_rule_args(self, rule_dir, join=True):
        """
        Read the ``config.json`` stored in a rule directory
        :param rule_dir: rule directory holding ``config.json``
        :param join: if True, ``rule_dir`` is taken relative to ``self.save_path``
        :return: the parsed config wrapped in a ``Dict``
        """
        if join:
            rule_dir = os.path.join(self.save_path, rule_dir)
        # context manager so the handle is closed as soon as parsing is done
        with open(os.path.join(rule_dir, "config.json"), "r") as fp:
            return Dict(json.load(fp))

    def get_rule_config(self):
        """
        Draw one value per rule hyper-parameter from its list of choices
        :return: ``Dict`` with keys num_rel, per_inverse, corrupt_eps, expand_steps
        """
        # order matters: it fixes the sequence of draws from the random generator
        pools = (
            ("num_rel", self.num_rel_n),
            ("per_inverse", self.per_inverse_n),
            ("corrupt_eps", self.corrupt_eps_n),
            ("expand_steps", self.expand_steps_n),
        )
        return Dict({key: random.choice(pool) for key, pool in pools})

    def get_saved_rule_configs(self, mode="train"):
        """
        Load the config of every rule directory found for the given mode
        :param mode: split folder to scan
        :return: list of configs, one per rule directory (empty if none exist)
        """
        return [
            self.get_rule_args(os.path.join(mode, rule_dir))
            for rule_dir in self.get_rule_dirs(mode=mode)
        ]

    def has_duplicate_rule(self, rule_arg):
        """
        Look for a saved config with the same rule hyper-parameters
        :param rule_arg: candidate rule config
        :return: the ``rule_name`` of the first matching saved config, else False
        """
        compared_fields = ("num_rel", "per_inverse", "corrupt_eps", "expand_steps")
        for mode in ("train", "valid", "test"):
            for saved in self.get_saved_rule_configs(mode=mode):
                same = all(
                    getattr(saved, field) == getattr(rule_arg, field)
                    for field in compared_fields
                )
                if same:
                    return saved.rule_name
        return False

    def split_worlds(self, policy="rand"):
        """
        Split world ids into train/valid/test
        The same ratio (``args.world_train_val_test_split``) is applied twice:
        first to separate train+valid from test, then train from valid.
        :param policy: "fsrl_2" and "fsrl_2.1" split contiguously, anything else
            splits at random
        :return: (train ids, valid ids, test ids)
        """
        world_ids = list(range(self.args.num_worlds))
        ratio = self.args.world_train_val_test_split
        if policy in ["fsrl_2", "fsrl_2.1"]:
            num_train_val = int(len(world_ids) * ratio)
            num_train = int(num_train_val * ratio)
            train_wgis = world_ids[:num_train]
            val_wgis = world_ids[num_train:num_train_val]
            # slice from the front boundary: a negative-count slice such as
            # world_ids[-0:] would return every world when no test worlds remain
            test_wgis = world_ids[num_train_val:]
        else:
            # random policy
            train_val_wgis = random.sample(world_ids, int(len(world_ids) * ratio))
            in_train_val = set(train_val_wgis)
            test_wgis = [g for g in world_ids if g not in in_train_val]
            train_wgis = random.sample(
                train_val_wgis, int(len(train_val_wgis) * ratio)
            )
            in_train = set(train_wgis)
            val_wgis = [g for g in train_val_wgis if g not in in_train]
        return train_wgis, val_wgis, test_wgis

    def generate_worlds(self, num_worlds=1, policy="random"):
        """
        Create the rule worlds according to the requested policy
        :param num_worlds: number of worlds for the "random", "sanity" and
            "sanity_load_rule" policies (the split policies derive their own count)
        :param policy: defines the policy of rule creation
        :return: list of generated rule worlds
        """

        def fresh_rule_gen():
            # private copy of the args with a newly sampled rule config merged in
            merged = copy.deepcopy(self.args)
            merged.update(self.get_rule_config())
            return RuleGen(merged)

        # policies that carve one large rule base into several worlds:
        # policy -> (splitter method on RuleGen, keyword arguments)
        splitters = {
            "overlap": ("gen_overlap_rules", {}),
            "composition": ("gen_unique_combination_rules", {}),
            # n distinct rules, n = args.num_splits
            "fsrl_1": ("gen_overlap_rules", {"mode": "distinct"}),
            # n overlapping rules in a continual learning setting
            "fsrl_2": ("gen_overlap_rules", {"mode": "continual"}),
            # n overlapping rules with n overlap
            "fsrl_2.1": ("gen_overlap_rules", {"mode": "overlap"}),
        }

        generated_worlds = []
        if policy == "random":
            for world_id in range(num_worlds):
                merged = copy.deepcopy(self.args)
                sampled = self.get_rule_config()
                merged.update(sampled)
                if self.has_duplicate_rule(sampled):
                    raise AssertionError("rule cannot be generated")
                restrict = self.args.fix_num_relations and len(self.rule_heads) > 0
                rg = RuleGen(
                    merged, allowed_heads=self.rule_heads if restrict else None
                )
                rg.rule_name = "rule_{}".format(world_id)
                # only the first world contributes its heads, so every later
                # world has at most as many relations as the first one
                if world_id == 0:
                    self.rule_heads = rg.get_compositional_heads()
                generated_worlds.append(copy.deepcopy(rg))
        elif policy == "sanity_load_rule":
            for world_id in range(num_worlds):
                loaded = RuleGen(
                    self.get_rule_args(self.args.load_rule, join=False),
                    generate_rules=False,
                )
                loaded.rule_name = "rule_{}".format(world_id)
                generated_worlds.append(copy.deepcopy(loaded))
        elif policy == "sanity":
            template = fresh_rule_gen()
            for world_id in range(num_worlds):
                template.rule_name = "rule_{}".format(world_id)
                generated_worlds.append(copy.deepcopy(template))
        elif policy in splitters:
            method_name, kwargs = splitters[policy]
            massive_rule = fresh_rule_gen()
            pieces = getattr(massive_rule, method_name)(**kwargs)
            for world_id, piece in enumerate(pieces):
                piece.rule_name = "rule_{}".format(world_id)
                generated_worlds.append(piece)

        # report how many worlds ended up in each mode
        num_ct = {"train": 0, "valid": 0, "test": 0}
        for generated in generated_worlds:
            num_ct[generated.world_mode] += 1
        print(num_ct)
        return generated_worlds

    def generate_graphs(
        self,
        world: RuleGen,
        sanity_gen_mode=False,
        save_graphs=True,
        randomize_steps=False,
    ):
        """
        Generate the unique graphs of one rule world, plus its meta graph
        Candidates are drawn until ``args.graphs_per_world`` distinct graphs
        exist. After 1000 rejected candidates the expansion limit is raised
        by one (unless steps are randomized).
        :param world: the rule world
        :param sanity_gen_mode: not read by the current implementation
        :param save_graphs: split and save the graphs if True, else return them
        :param randomize_steps: redraw ``expand_steps`` after every candidate
        :return: True once saved, otherwise the list of generated graphs
        """
        graphs = Dict({"train": [], "valid": [], "test": []})
        args = copy.deepcopy(self.args)
        args.update(world.args)
        max_expand_steps = args.expand_steps
        args.expand_steps = 1
        unique_graphs = set()
        done_nodes = set()
        next_id = 0
        progress = tqdm(total=args.graphs_per_world)
        patience = 1000
        while len(unique_graphs) < args.graphs_per_world:
            candidate = GraphGen(world, args, id=next_id)
            next_id += 1
            if randomize_steps:
                args.expand_steps = random.choice(range(2, max_expand_steps))
            if not (candidate.graph_generated and candidate not in unique_graphs):
                patience -= 1
                if patience != 0:
                    continue
                # no new graph for a full patience window:
                # increase the expansion limit
                if args.expand_steps > max_expand_steps:
                    raise AssertionError(
                        "max expansion steps reached, try increasing the amount"
                    )
                if not randomize_steps:
                    args.expand_steps += 1
                patience = 1000
                continue
            if args.add_noise:
                candidate.gen_noise()
            candidate.solve_graph()
            # register previously unseen entities and relations on the world
            for node in candidate.get_all_nodes():
                node_id = candidate.get_node_id(node)
                if node_id not in world.ent2id:
                    world.ent2id[node_id] = len(world.ent2id)
            for _, rel in candidate.graph.items():
                if rel not in world.rel2id:
                    world.rel2id[rel] = len(world.rel2id)
            unique_graphs.add(candidate)
            done_nodes.update(candidate.done_nodes)
            progress.set_description("k = {} ".format(args.expand_steps))
            progress.update(1)
        progress.close()
        print("Generated {} unique graphs".format(len(unique_graphs)))
        graph_list = list(unique_graphs)
        print("Generating meta graph...")
        # the meta graph is built with the world's original expansion limit
        args.expand_steps = max_expand_steps
        rule_meta_graph = GraphGen(world, args, id=next_id, gen_graph=False)
        rule_meta_graph.gen_world_graph(num_cycles=args.meta_graph_cyles)
        if not save_graphs:
            return graph_list
        print("Splitting graphs for Train/Val/Test..")
        train_gis, val_gis, test_gis = self.split_graphs_by_rule(graph_list)
        graphs.train = [graph_list[i] for i in train_gis]
        graphs.valid = [graph_list[i] for i in val_gis]
        graphs.test = [graph_list[i] for i in test_gis]
        self.save_files_json(world, graphs, meta_graph=rule_meta_graph)
        print("Saving done")
        return True

    @ray.remote
    def generate_graphs_big(
        self, wi, sanity_gen_mode=False, save_graphs=True, randomize_steps=False
    ):
        """
        Build and save the big graph of one rule world (Ray task)
        :param wi: index of the world in ``self.all_worlds``
        :param sanity_gen_mode: not read by the current implementation
        :param save_graphs: not read by the current implementation
        :param randomize_steps: not read by the current implementation
        :return: True once the big graph object has been saved
        """
        world = self.all_worlds[wi]
        merged = copy.deepcopy(self.args)
        merged.update(world.args)
        # keep the requested node budget for gen_big_graph, then set num_nodes
        # to 1000000 on the args handed to the generator
        node_budget = merged.num_nodes
        merged.num_nodes = 1000000
        # generate the big world
        big_graph = BigGraphGen(world, merged, gen_graph=False)
        big_graph.gen_big_graph(
            max_nodes=node_budget, max_cycles=merged.gen_graph_cyles
        )
        big_graph.pre_compute_paths()
        big_graph.descriptor_splits = []
        print("Saving biggraph obj")
        folders = self.create_rule_folders(world.world_mode, world.rule_name)
        _, _, rule_folder = folders
        big_graph.save(rule_folder)
        return True

    @ray.remote
    def sample_graphs_from_big(self, rule_folder="", save_graphs=True):
        """
        Sample train/valid/test graphs from a saved big graph (Ray task)
        :param rule_folder: folder holding the saved big graph; its last two path
            components are used as world mode and rule name
        :param save_graphs: save the sampled graphs if True, else return them
        :return: True once saved, otherwise the ``Dict`` of sampled graphs
        """
        graphs = Dict({"train": [], "valid": [], "test": []})
        args = self.args
        big_g = BigGraphGen(None, args, gen_graph=False)
        big_g.load(rule_folder)
        path_parts = rule_folder.split("/")
        rule_name = path_parts[-1]
        world_mode = path_parts[-2]

        # training rows
        progress = tqdm(total=args.num_train_rows)
        for _ in range(args.num_train_rows):
            graphs["train"].append(big_g.get_next_sampled_graph(mode="train"))
            progress.update(1)
            progress.set_description("W: {}".format(rule_name))
        progress.close()

        # validation and test rows: in easy mode they are drawn from the train
        # mode, reusing already used descriptors
        for split, rows_key in (("valid", "num_valid_rows"), ("test", "num_test_rows")):
            progress = tqdm(total=getattr(args, rows_key))
            sample_mode = "train" if args.easy_mode else split
            for _ in range(getattr(args, rows_key)):
                sampled = big_g.get_next_sampled_graph(
                    mode=sample_mode, choose_used_descriptor=args.easy_mode
                )
                graphs[split].append(sampled)
                progress.update(1)
                progress.set_description(
                    "W: {}, Easy Mode : {}".format(rule_name, args.easy_mode)
                )
            progress.close()

        rule_meta_graph = big_g.get_world_graph(
            max_edges_per=args.world_graph_per_edges
        )
        if not save_graphs:
            return graphs
        self.save_files_json(
            big_g.rule_gen,
            graphs,
            graphGen=big_g,
            rule_name=rule_name,
            world_mode=world_mode,
            meta_graph=rule_meta_graph,
        )
        print("Saving done")
        return True

    def get_all_rule_folders(self):
        """
        Collect the rule folders of every mode under the data folder
        The data folder is ``<cwd up to the first "lgw">/lgw/<args.folder_name>``.
        Modes whose folder does not exist are skipped, mirroring
        ``get_rule_dirs`` which returns an empty list in that case.
        :return: list of rule folder paths, ordered train, valid, test
        """
        parent_folder = os.getcwd().split("lgw")[0]
        data_folder = os.path.join(parent_folder, "lgw", self.args.folder_name)
        all_folders = []
        for mode in ("train", "valid", "test"):
            mode_folder = os.path.join(data_folder, mode)
            if not os.path.isdir(mode_folder):
                # e.g. a run that produced no worlds for this mode
                continue
            candidates = (
                os.path.join(mode_folder, name) for name in os.listdir(mode_folder)
            )
            all_folders.extend(path for path in candidates if os.path.isdir(path))
        return all_folders

    @ray.remote
    def generate_graphs_parallel(
        self, wi, sanity_gen_mode=False, save_graphs=True, randomize_steps=False
    ):
        """
        Generate the unique graphs of one rule world as a Ray task
        Candidates are drawn until ``args.graphs_per_world`` distinct graphs
        exist. After 1000 rejected candidates the expansion limit is raised
        by one (unless steps are randomized).
        :param wi: index of the world in ``self.all_worlds``
        :param sanity_gen_mode: not read by the current implementation
        :param save_graphs: split and save the graphs if True, else return them
        :param randomize_steps: redraw ``expand_steps`` after every candidate
        :return: True once saved or when the expansion limit is exceeded,
            otherwise the list of generated graphs
        """
        world = self.all_worlds[wi]
        graphs = Dict({"train": [], "valid": [], "test": []})
        args = copy.deepcopy(self.args)
        args.update(world.args)
        max_expand_steps = args.expand_steps
        args.expand_steps = 1
        unique_graphs = set()
        done_nodes = set()
        next_id = 0
        progress = tqdm(total=args.graphs_per_world)
        patience = 1000
        while len(unique_graphs) < args.graphs_per_world:
            candidate = GraphGen(world, args, id=next_id)
            next_id += 1
            if randomize_steps:
                args.expand_steps = random.choice(range(2, max_expand_steps))
            if not (candidate.graph_generated and candidate not in unique_graphs):
                patience -= 1
                if patience != 0:
                    continue
                # no new graph for a full patience window:
                # increase the expansion limit
                if args.expand_steps > max_expand_steps:
                    print("error: exceeding max expand steps, killing job ...")
                    return True
                if not randomize_steps:
                    args.expand_steps += 1
                patience = 1000
                continue
            if args.add_noise:
                candidate.gen_noise()
            candidate.solve_graph()
            # register previously unseen entities and relations on the world
            for node in candidate.get_all_nodes():
                node_id = candidate.get_node_id(node)
                if node_id not in world.ent2id:
                    world.ent2id[node_id] = len(world.ent2id)
            for _, rel in candidate.graph.items():
                if rel not in world.rel2id:
                    world.rel2id[rel] = len(world.rel2id)
            unique_graphs.add(candidate)
            done_nodes.update(candidate.done_nodes)
            progress.set_description("k = {} ".format(args.expand_steps))
            progress.update(1)
        progress.close()
        print("Generated {} unique graphs".format(len(unique_graphs)))
        graph_list = list(unique_graphs)
        print("Generating meta graph...")
        # log the exact expand steps needed in this world
        world.args.expand_steps = args.expand_steps
        # the meta graph uses its own expansion limit and node count
        args.expand_steps = args.world_graph_expand_steps
        args.num_nodes = args.world_graph_num_nodes
        rule_meta_graph = GraphGen(world, args, id=next_id, gen_graph=False)
        rule_meta_graph.gen_world_graph(num_cycles=args.meta_graph_cyles)
        if not save_graphs:
            return graph_list
        print("Splitting graphs for Train/Val/Test..")
        train_gis, val_gis, test_gis = self.split_graphs_by_rule(graph_list)
        graphs.train = [graph_list[i] for i in train_gis]
        graphs.valid = [graph_list[i] for i in val_gis]
        graphs.test = [graph_list[i] for i in test_gis]
        self.save_files_json(world, graphs, meta_graph=rule_meta_graph)
        print("Saving done")
        return True

    def split_graphs_by_rule(self, world_graphs):
        """
        Split graph indices into train/valid/test, stratified by rule usage
        Graphs are bucketed by (number of used rules, target relation). Buckets
        whose graphs used 1 or 2 rules go to training entirely; every other
        bucket is split at random with ``args.train_test_split`` and then
        ``args.train_val_split``.
        NOTE(review): an earlier version of this docstring said lengths 2 and 3
        are always kept in training; the code keeps 1 and 2 - confirm intent.
        :param world_graphs: sequence of graphs exposing ``used_rules``/``target``
        :return: (train indices, valid indices, test indices)
        """
        rulen2gis = {}
        for gi, graph in enumerate(world_graphs):
            edge = list(graph.target.keys())[0]
            rule_t_marker = "{}_{}".format(len(graph.used_rules), graph.target[edge])
            rulen2gis.setdefault(rule_t_marker, []).append(gi)
        train_gis = []
        val_gis = []
        test_gis = []
        for rule_t_marker, gis in rulen2gis.items():
            # the rule count is the part of the marker before the first "_"
            if int(rule_t_marker.split("_")[0]) in (1, 2):
                train_gis.extend(gis)
                continue
            tv_gis = random.sample(gis, int(len(gis) * self.args.train_test_split))
            # sets give O(1) membership tests while the lists keep the order
            in_tv = set(tv_gis)
            test_gis_b = [g for g in gis if g not in in_tv]
            train_gis_b = random.sample(
                tv_gis, int(len(tv_gis) * self.args.train_val_split)
            )
            in_train = set(train_gis_b)
            val_gis_b = [g for g in tv_gis if g not in in_train]
            train_gis.extend(train_gis_b)
            val_gis.extend(val_gis_b)
            test_gis.extend(test_gis_b)
        return train_gis, val_gis, test_gis

    # def save_file_fbk15(graphs, e2id, rel2id, folder='lgw_1'):
    #     """
    #     Save the graphs in FBK15 format
    #     :param graphs:
    #     :param e2id: entity to id
    #     :param rel2id: relation to id
    #     :param folder:
    #     :return:
    #     """
    #     folder = os.path.join('data', folder)
    #     create_dir('data')
    #     create_dir(folder)
    #     with open(os.path.join(folder, 'relation2id.txt'), 'w') as fp:
    #         fp.write('{}\n'.format(len(rel2id)))
    #         for rel, id in sorted(rel2id.items(), key=lambda x: x[1]):
    #             fp.write('{} {}\n'.format(rel, id))
    #     with open(os.path.join(folder, 'entity2id.txt'), 'w') as fp:
    #         fp.write('{}\n'.format(len(ent2id)))
    #         for rel, id in sorted(ent2id.items(), key=lambda x: x[1]):
    #             fp.write('{} {}\n'.format(rel, id))
    #     for key in graphs:
    #         with open(os.path.join(folder, '{}2id.txt'.format(key)), 'w') as fp:
    #             fp.write('{}\n'.format(sum([len(graph.graph) for graph in graphs[key]])))
    #             for graph in graphs[key]:
    #                 for body, rel in graph.graph.items():
    #                     fp.write('{} {} {}\n'.format(e2id[graph.get_node_id(body[0])],
    #                                                  e2id[graph.get_node_id(body[1])],
    #                                                  rel2id[rel]))

    def save_files_text(self, rg: RuleGen, graphs):
        """
        Save every graph as plain text files, one sub-folder per mode
        For graph ``i`` of a mode, ``i.txt`` holds one "source dest relation-id"
        line per edge and ``i_query.txt`` holds the query edge with its label id.
        The rule folder is recreated from scratch and receives a ``config.json``
        with the run args and the rules.
        :param rg: rule world providing ``rule_name``, ``rel2id``, ``D`` and
            ``rule_prob``
        :param graphs: mapping of mode name to the list of graphs for that mode
        :return:
        """
        parent_folder = os.getcwd().split("lgw")[0]
        data_folder = os.path.join(parent_folder, "lgw", self.args.folder_name)
        create_dir(data_folder)
        rule_folder = os.path.join(data_folder, rg.rule_name)
        create_new_dir(rule_folder)

        for mode, gs in graphs.items():
            print("Saving {} graphs for mode {}".format(len(gs), mode))
            mode_folder = os.path.join(rule_folder, mode)
            create_dir(mode_folder)
            pb = tqdm(total=len(gs))
            for gi, graph in enumerate(gs):
                with open(os.path.join(mode_folder, "{}.txt".format(gi)), "w") as fp:
                    for body, rel in graph.graph.items():
                        fp.write("{} {} {}\n".format(body[0], body[1], rg.rel2id[rel]))
                with open(
                    os.path.join(mode_folder, "{}_query.txt".format(gi)), "w"
                ) as fp:
                    # source dest label
                    edge = list(graph.target.keys())[0]
                    label = graph.target[edge]
                    fp.write("{} {} {}\n".format(edge[0], edge[1], rg.rel2id[label]))
                pb.update(1)
            pb.close()
        config = copy.copy(self.args)
        config["rules"] = remap_keys(rg.D, rg.rule_prob)
        # context manager so the config is flushed and closed before returning
        with open(os.path.join(rule_folder, "config.json"), "w") as fp:
            json.dump(config, fp)
        # NOTE(review): the files are written under rg.rule_name, but the message
        # reports self.args.rule_name - confirm which one is meant
        print("Done writing files for {}".format(self.args.rule_name))

    def jsonify_graph(self, graph, is_meta_graph=False):
        """
        Turn a graph into a JSON serialisable dict
        :param graph: graph exposing ``graph``, ``used_rules`` and, for regular
            graphs, ``target`` and ``resolution_path``
        :param is_meta_graph: meta graphs carry no query: a [0, 0, 0] placeholder
            and an empty resolution path are written instead
        :return: dict with keys edges, query, rules, resolution_path
        """
        edges = [[body[0], body[1], rel] for body, rel in graph.graph.items()]
        if is_meta_graph:
            query = [0, 0, 0]
        else:
            # the query is the first target edge together with its label
            edge = list(graph.target.keys())[0]
            query = [edge[0], edge[1], graph.target[edge]]
        rules = graph.used_rules
        resolution_path = [] if is_meta_graph else graph.resolution_path
        return {
            "edges": edges,
            "query": query,
            "rules": rules,
            "resolution_path": resolution_path,
        }

    def create_rule_folders(self, world_mode, rule_name, folder_name=None):
        """
        Make sure the ``data/mode/rule`` folder chain exists
        :param world_mode: name of the mode sub-folder
        :param rule_name: name of the rule sub-folder
        :param folder_name: data folder to use; when empty it defaults to
            ``<cwd up to the first "lgw">/lgw/<args.folder_name>``
        :return: (data folder, mode folder, rule folder)
        """
        data_folder = folder_name or os.path.join(
            os.getcwd().split("lgw")[0], "lgw", self.args.folder_name
        )
        create_dir(data_folder)
        # each level is created before the next path is derived from it
        created = [data_folder]
        for child in (world_mode, rule_name):
            nested = os.path.join(created[-1], child)
            create_dir(nested)
            created.append(nested)
        return tuple(created)

    def save_files_json(
        self,
        rg: RuleGen,
        graphs,
        graphGen=None,
        rule_name="",
        world_mode="train",
        meta_graph=None,
        folder_name=None,
    ):
        """
        Save files in jsonlines format
        Each mode gets a ``<mode>.jsonl`` with one graph per line; the meta
        graph (if given) goes to ``meta_graph.jsonl`` and the config plus rules
        to ``config.json``.
        :param rg: rule world; supplies the rule name and the rules whenever
            they are not passed explicitly
        :param graphs: mapping of mode name to the list of graphs for that mode
        :param graphGen: graph generator whose ``args`` and ``rules`` go into the
            config; when omitted the run args and ``rg``'s rules are used
        :param rule_name: rule folder name; falls back to ``rg.rule_name``
        :param world_mode: name of the mode folder the rule is stored under
        :param meta_graph: optional meta graph of the world
        :param folder_name: optional data folder override
        :return:
        """
        if not rule_name:
            # an empty name would drop every world into the same mode folder
            rule_name = getattr(rg, "rule_name", "")
        data_folder, mode_folder, rule_folder = self.create_rule_folders(
            world_mode, rule_name, folder_name
        )
        for mode, gs in graphs.items():
            print("Saving {} graphs for mode {}".format(len(gs), mode))
            with open(os.path.join(rule_folder, "{}.jsonl".format(mode)), "w") as fp:
                pb = tqdm(total=len(gs))
                for graph in gs:
                    json.dump(self.jsonify_graph(graph), fp)
                    fp.write("\n")
                    pb.update(1)
                pb.close()
            if mode == "train" and meta_graph is not None:
                print("Saving meta graph")
                with open(os.path.join(rule_folder, "meta_graph.jsonl"), "w") as fp:
                    json.dump(self.jsonify_graph(meta_graph, is_meta_graph=True), fp)
        if graphGen is not None:
            config = copy.copy(graphGen.args)
            config["rules"] = remap_keys(graphGen.rules, {})
        else:
            # same config layout that save_files_text writes for a rule world
            config = copy.copy(self.args)
            config["rules"] = remap_keys(rg.D, rg.rule_prob)
        with open(os.path.join(rule_folder, "config.json"), "w") as fp:
            json.dump(config, fp)
        print("Done writing files in {}".format(rule_folder))

    def generate(self):
        """
        Generate n different worlds sequentially
        Graph generation is attempted for every world; after 5 consecutive
        failed attempts the loop gives up.
        :return:
        """
        max_takes = 5
        takes = max_takes
        num_worlds = self.args.num_worlds
        # generate worlds
        worlds = self.generate_worlds(num_worlds, policy=self.args.policy)
        pba = tqdm(total=num_worlds)
        for world in worlds:
            if self.generate_graphs(
                world,
                sanity_gen_mode=self.args.sanity,
                randomize_steps=self.args.randomize_steps,
            ):
                pba.update(1)
                # a success restores the full retry budget
                takes = max_takes
                continue
            takes -= 1
            if takes == 0:
                print("Too many tries done. Change configuration...")
                break
        pba.close()

    def generate_parallel(self):
        """
        Generate n different worlds in parallel with Ray
        Two independent stages, switched by the run args:
        ``sample_worlds`` builds and saves one big graph per generated world,
        ``sample_graphs`` samples graphs from every saved rule folder.
        :return:
        """
        # read the switches from the instance args; a module-level ``args`` only
        # exists when this file is run as a script
        sample_graphs = self.args.sample_graphs
        sample_worlds = self.args.sample_worlds
        num_worlds = self.args.num_worlds
        # generate worlds
        if sample_worlds:
            self.all_worlds = self.generate_worlds(num_worlds, policy=self.args.policy)
            futures = [
                self.generate_graphs_big.remote(
                    self,
                    i,
                    sanity_gen_mode=self.args.sanity,
                    randomize_steps=self.args.randomize_steps,
                )
                for i in range(len(self.all_worlds))
            ]
            # block until every world has been written
            ray.get(futures)
        if sample_graphs:
            rule_folders = self.get_all_rule_folders()
            print(rule_folders)
            print("Found {} rule_folders".format(len(rule_folders)))
            futures = [
                self.sample_graphs_from_big.remote(self, folder)
                for folder in rule_folders
            ]
            ray.get(futures)

    def pretty_rule(self, body, head):
        """
        Format a rule as a string
        :param body: a list (its first two items are printed) or a single value
        :param head: rule head
        :return: "(b0,b1) -> head" for list bodies, "(body) -> head" otherwise
        """
        if type(body) != list:
            return "({}) -> {}".format(body, head)
        return "({},{}) -> {}".format(body[0], body[1], head)

    def analyze_data(self):
        """
        Given a set of worlds and graphs, create the identity vector
        an Identity vector is of dimension of all possible rules in all
        our worlds; entry i is 1 when rule i is used by any graph of the world.
        The vector is stored under the "ID" key of each world's config.json.
        Rules are ordered alphabetically so the vector layout is the same on
        every run.
        :raises ValueError: if a graph uses a rule missing from the saved configs
        :return:
        """
        modes = ("train", "valid", "test")
        universe = set()
        for mode in modes:
            for sc in self.get_saved_rule_configs(mode=mode):
                for rule in sc["rules"]:
                    # only rules with a list body are part of the universe
                    if isinstance(rule["body"], list):
                        universe.add(self.pretty_rule(rule["body"], rule["head"]))
        all_rules = sorted(universe)
        # rule -> position lookup instead of a linear list.index per graph rule
        rule_index = {rule: pos for pos, rule in enumerate(all_rules)}
        print("Total rules in the universe : {}".format(len(all_rules)))
        for mode in modes:
            for rule_dir in self.get_rule_dirs(mode=mode):
                mode_rd = os.path.join(mode, rule_dir)
                world_id = np.zeros(len(all_rules))
                for gmode in modes:
                    graph_file = os.path.join(
                        self.save_path, mode_rd, "{}.jsonl".format(gmode)
                    )
                    with open(graph_file, "r") as fp:
                        for line in fp:
                            graph = json.loads(line)
                            for grule in graph["rules"]:
                                pr = self.pretty_rule(grule[0], grule[1])
                                if pr not in rule_index:
                                    raise ValueError(
                                        "{!r} is not in list".format(pr)
                                    )
                                world_id[rule_index[pr]] = 1
                config_path = os.path.join(self.save_path, mode_rd, "config.json")
                with open(config_path, "r") as fp:
                    config_mode = json.load(fp)
                config_mode["ID"] = world_id.tolist()
                # close the handle explicitly so the rewritten config is flushed
                with open(config_path, "w") as fp:
                    json.dump(config_mode, fp)


def remap_keys(mapping, prob):
    """Flatten a body -> head rule mapping into a list of JSON friendly records

    Each record has the keys "body", "head" and "p". The probability is looked
    up in ``prob`` for tuple bodies only (1.0 when absent); every other body
    gets 1.0.
    """
    records = []
    for body, head in mapping.items():
        weight = 1.0
        if type(body) == tuple and body in prob:
            weight = prob[body]
        records.append({"body": body, "head": head, "p": weight})
    return records


def debug_graph_gen(args, config_path):
    """Interactive debugging helper: load a saved config and build a world graph

    Loads the rule world from ``config_path`` without regenerating rules, drops
    into an ipdb session, then generates a world graph with 10 expansion steps
    and 1000 nodes.

    :param args: not read by the current implementation
    :param config_path: path to a saved ``config.json``
    """
    # context manager so the config file is closed before the debugger starts
    with open(config_path) as fp:
        config = Dict(json.load(fp))
    rg = RuleGen(config, generate_rules=False)
    import ipdb

    ipdb.set_trace()
    config.expand_steps = 10
    config.num_nodes = 1000
    g = GraphGen(rg, config, id=0, gen_graph=False)
    g.gen_world_graph(num_cycles=100)


if __name__ == "__main__":
    # Script entry point: parse the command line arguments, build the world
    # generator and run the Ray-parallel pipeline. ``args`` is bound at module
    # scope here. The commented-out calls are the alternative entry points
    # (sequential generation and the identity-vector analysis).
    args = get_args()
    lgw = LogicalGraphWorld(args)
    # lgw.generate()
    lgw.generate_parallel()
    # lgw.analyze_data()
