import os
import random
import networkx
import json
import numpy


class ListDataLoader:
    """Iterate over ``task_list`` in consecutive slices of ``batch_size`` items.

    Only complete batches are produced: ``batch_num`` is the floor of
    ``len(task_list) / batch_size``, so a trailing remainder shorter than
    ``batch_size`` is never returned. Calling ``iter()`` rewinds the loader to
    its first batch, so the same instance can be iterated repeatedly.
    """

    def __init__(self, task_list, batch_size=64):
        self.task_list = task_list
        self.batch_size = batch_size
        # Count of complete batches; any partial tail batch is dropped.
        complete_batches = numpy.floor(len(task_list) / batch_size)
        self.batch_num = int(complete_batches)
        self.current_batch = 0

    def __iter__(self):
        # Rewind so a fresh `for` loop starts again from the first batch.
        self.current_batch = 0
        return self

    def __next__(self):
        if self.current_batch == self.batch_num:
            raise StopIteration
        start = self.current_batch * self.batch_size
        stop = start + self.batch_size
        batch = self.task_list[start:stop]
        self.current_batch += 1
        return batch


# The code below writes a new "gqa_info.json" file restricted to the top-k names, attributes, and relations.
# For now, the alias field of the original "gqa_info.json" file is not carried over (an empty alias is
# written for attributes and relations), because its usage is unclear.

# TODO: extend this code to filter out the "left" and "right" relation types.
# The extension must also account for the scene_graph field, because relation ids there use 0-based numbering.
def create_topk_meta_info(
    meta_info, topk_names, topk_attrs, topk_relas, meta_info_output
):
    """Write a truncated copy of ``meta_info`` to ``meta_info_output`` as JSON.

    For each of the sections ``"name"``, ``"attr"`` and ``"rel"``, only the
    first ``topk_names`` / ``topk_attrs`` / ``topk_relas`` keys of that
    section's ``"freq"`` mapping (in its iteration order) are kept. Each output
    section contains, in this key order:

    * ``"freq"``: the kept keys with their original ``freq`` values;
    * ``"idx"``: the kept keys renumbered ``0 .. k-1`` in that same order;
    * ``"canon"``: the kept keys with their original ``canon`` values;
    * ``"alias"``: an empty dict, present only for ``"attr"`` and ``"rel"``;
    * ``"num"``: the number of kept keys.

    NOTE(review): "top-k" presumably relies on ``freq`` being ordered by
    descending frequency, with that order matching the ids used in the scene
    graphs. Confirm this against the source ``gqa_info.json``.

    Args:
        meta_info: parsed meta-info dict; each section needs ``"freq"`` and
            ``"canon"`` mappings keyed by the same words.
        topk_names, topk_attrs, topk_relas: number of keys to keep per section.
        meta_info_output: path of the JSON file to (over)write.
    """
    sections = (("name", topk_names), ("attr", topk_attrs), ("rel", topk_relas))

    truncated = {}
    for section, limit in sections:
        frequencies = meta_info[section]["freq"]
        canonical = meta_info[section]["canon"]

        kept = [word for rank, word in enumerate(frequencies) if rank < limit]

        entry = {
            "freq": {word: frequencies[word] for word in kept},
            "idx": {word: rank for rank, word in enumerate(kept)},
            "canon": {word: canonical[word] for word in kept},
        }
        # The original alias tables are not carried over (their usage is
        # unclear); an empty placeholder keeps the section layout intact.
        if section in ("attr", "rel"):
            entry["alias"] = {}
        entry["num"] = len(entry["idx"])
        truncated[section] = entry

    with open(meta_info_output, "w") as handle:
        json.dump(truncated, handle)

def maintain_topk_properties(scene, topk_names, topk_attrs, topk_relas):
    """Restrict a scene's names, attributes and relations to the top-k ids.

    Args:
        scene: dict with ``"names"`` (object id -> type id), ``"attributes"``
            (object id -> iterable of attribute ids) and ``"relations"``
            (object id -> {other object id -> relation id}) mappings.
        topk_names, topk_attrs, topk_relas: exclusive upper bounds on the
            type, attribute and relation ids that are kept.

    Returns:
        ``(object2type, object2attrs, object2rels)``, where:

        * ``object2type`` keeps the objects whose type id is below ``topk_names``;
        * ``object2attrs`` maps each kept object to its attribute ids below
          ``topk_attrs``, omitting objects left with no attributes;
        * ``object2rels`` maps each kept object to the relations (id below
          ``topk_relas`` and not ``-1``) whose other endpoint is also a kept
          object, omitting objects left with no relations.
    """
    kept_types = {}
    for obj, type_id in scene["names"].items():
        if type_id < topk_names:
            kept_types[obj] = type_id

    kept_attrs = {}
    for obj, attr_ids in scene["attributes"].items():
        surviving = [attr_id for attr_id in attr_ids if attr_id < topk_attrs]
        if obj in kept_types and surviving:
            kept_attrs[obj] = surviving

    kept_rels = {}
    for obj, neighbours in scene["relations"].items():
        surviving = {}
        for other, rel_id in neighbours.items():
            # -1 (presumably a "no relation" marker — confirm) must be excluded
            # explicitly because it also passes `< topk_relas` for any
            # non-negative bound.
            if other in kept_types and rel_id < topk_relas and rel_id != -1:
                surviving[other] = rel_id
        if obj in kept_types and surviving:
            kept_rels[obj] = surviving

    return kept_types, kept_attrs, kept_rels


def find_superclass(
    object_type, is_a_graph, stop_classes=("thing",), stop_prob=0.0, max_steps=None
):
    """Randomly walk up ``is_a_graph`` from ``object_type``; return where it stops.

    At each step, one outgoing edge of the current node is chosen uniformly at
    random and its destination becomes the new current node. The walk stops
    and returns the *current* node (not the destination) when any of these hold:

    * the current node has no outgoing edges;
    * the chosen destination is in ``stop_classes``;
    * a ``random.random()`` draw is below ``stop_prob`` (random early stop);
    * ``max_steps`` edges have already been followed (when not ``None``).

    Args:
        object_type: starting node.
        is_a_graph: graph exposing ``out_edges(node)`` that yields
            ``(src, dst)`` pairs (presumably a ``networkx.DiGraph``).
        stop_classes: destinations the walk refuses to step into. The default
            ``("thing",)`` matches the previously hard-coded check. Pass
            ``("thing", "object")`` to also stop before ``"object"``.
        stop_prob: per-step probability of stopping early. The default 0.0
            never stops early, like the former always-false
            ``random.random() > 1`` test.
        max_steps: optional cap on the number of edges followed. Use it to
            bound the walk on cyclic graphs; ``None`` means unbounded.

    Returns:
        The node at which the walk stopped.
    """
    superclass = object_type
    steps = 0
    while max_steps is None or steps < max_steps:
        outgoing_edges = list(is_a_graph.out_edges(superclass))
        if not outgoing_edges:
            break
        edge_id = random.randint(0, len(outgoing_edges) - 1)
        dest = outgoing_edges[edge_id][1]
        # random.random() is drawn whenever dest is not a stop class (even with
        # stop_prob == 0.0), so seeded runs consume the same random numbers as
        # the original implementation did.
        if dest in stop_classes or random.random() < stop_prob:
            break
        superclass = dest
        steps += 1
    return superclass


def find_seed_classes(superclass, is_a_graph, seed_classes):
    """Return the seed classes that reach ``superclass`` in ``is_a_graph``.

    If ``superclass`` is not a node of the graph, it is returned on its own as
    a single-element list. Otherwise, the result lists (in input order) every
    seed that is a graph node and has a path ``seed -> superclass``, as
    determined by ``networkx.has_path``.
    """
    if is_a_graph.has_node(superclass):
        matches = []
        for seed in seed_classes:
            # Skip seeds missing from the graph before asking for a path.
            if not is_a_graph.has_node(seed):
                continue
            if networkx.has_path(is_a_graph, seed, superclass):
                matches.append(seed)
        return matches
    return [superclass]


def create_database(
    image_id,
    object_ids,
    object2type,
    object2attrs,
    object2rels,
    idx2word,
    directory,
):
    """Write tab-separated fact files for one image under ``directory/<image_id>/``.

    The image directory (including missing parents) is created if needed. Then,
    for each mapping that is not ``None``, one file is (over)written:

    * ``e_name.csv``: one ``"<type name>"\\t<obj>`` line per id in
      ``object_ids``. Every id must be a key of ``object2type``, otherwise
      ``KeyError`` is raised.
    * ``attr.csv``: one ``"<attr name>"\\t<obj>`` line per attribute id in
      ``object2attrs[obj]``.
    * ``rela.csv``: one ``"<rel name>"\\t<obj>\\t<other>`` line per entry of
      ``object2rels[obj]``.

    Ids are turned into names with ``idx2word.idx_to_name``,
    ``idx2word.idx_to_attr`` and ``idx2word.idx_to_rela``.

    NOTE(review): names are wrapped in double quotes without escaping, so a
    name containing ``"`` would produce a malformed field.
    """
    image_dir = os.path.join(directory, str(image_id))
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(image_dir, exist_ok=True)

    # `with` guarantees each file is closed even if an idx_to_* lookup raises.
    if object2type is not None:
        with open(os.path.join(image_dir, "e_name.csv"), "w") as file:
            for obj in object_ids:
                gold_type_name = idx2word.idx_to_name(object2type[obj])
                file.write(f'"{gold_type_name}"\t{obj}\n')

    if object2attrs is not None:
        with open(os.path.join(image_dir, "attr.csv"), "w") as file:
            for obj, gold_sub_attr_ids in object2attrs.items():
                for gold_sub_attr_id in gold_sub_attr_ids:
                    gold_sub_attr_name = idx2word.idx_to_attr(gold_sub_attr_id)
                    file.write(f'"{gold_sub_attr_name}"\t{obj}\n')

    if object2rels is not None:
        with open(os.path.join(image_dir, "rela.csv"), "w") as file:
            for obj, gold_rel_dict in object2rels.items():
                for sub, gold_rel_id in gold_rel_dict.items():
                    gold_rel_name = idx2word.idx_to_rela(gold_rel_id)
                    file.write(f'"{gold_rel_name}"\t{obj}\t{sub}\n')
