# Benchmark designed for testing the curriculum learning and the pseudolabeling idea

import os
import pickle
import json

from models.idx2word import Idx2Word

import networkx
import random
import pickle
import copy
from utilities import (
    maintain_topk_properties,
    find_superclass,
    find_seed_classes,
    ListDataLoader,
)

class InstanceGenerator:
    def __init__(
        self,
        data_dir,
        tasks_filename,
        topk_names,
        topk_attrs,
        topk_relas,
        meta_f,
        is_a_filename,
        reasoning_directory,
    ):
        """Load the vocabularies, the is_a hierarchy and the task list.

        Also initialises the per-family query counters/stores, the per-label
        metadata index and the per-family sampling budgets.

        Args:
            data_dir: Directory that contains ``tasks_filename``.
            tasks_filename: Pickle file (relative to ``data_dir``) whose
                content is materialised into a list and wrapped in a
                ``ListDataLoader`` stored as ``self.data``.
            topk_names: Names whose index in ``meta_info["name"]["idx"]`` is
                strictly below this bound form ``self.object_domain``.
            topk_attrs: Stored as ``self.topk_attrs``.
            topk_relas: Stored as ``self.topk_relas``.
            meta_f: Path to the JSON file holding the ``"name"``, ``"attr"``
                and ``"rel"`` vocabularies, each with an ``"idx"`` mapping.
            is_a_filename: Tab-separated file; every non-blank line
                ``a<TAB>b`` adds the directed edge ``a -> b`` to
                ``self.is_a_graph``.
            reasoning_directory: Stored as ``self.reasoning_directory``.
        """
        # Load the JSON with the domain (vocabularies) of the networks; the
        # context manager closes the handle instead of leaking it.
        with open(meta_f, "r") as meta_file:
            self.meta_info = json.load(meta_file)
        self.idx2word = Idx2Word(self.meta_info)
        self.is_a_filename = is_a_filename

        self.engine = None
        self.reasoning_directory = reasoning_directory
        self.topk_names = topk_names
        self.topk_attrs = topk_attrs
        self.topk_relas = topk_relas

        # Datalog rules; the last one propagates name(N1,O) to name(N2,O)
        # along is_a(N1,N2).
        self.program = """name(X,Y) :- e_name(X,Y)
is_a(X,Y) :- e_is_a(X,Y)
name(N2,O) :- name(N1,O), is_a(N1,N2)
"""
        self.attr_canon = self.meta_info["attr"]["idx"].keys()
        self.object_domain = [
            obj
            for obj, idx in self.meta_info["name"]["idx"].items()
            if idx < topk_names
        ]
        self.rela_canon = self.meta_info["rel"]["idx"].keys()

        # Build the is_a graph: each line "a<TAB>b" becomes edge a -> b
        # (presumably subclass -> superclass, matching is_a(N1,N2) in
        # self.program — confirm against the data file).
        self.is_a_graph = networkx.DiGraph()
        with open(is_a_filename) as infile:
            for line in infile:
                line = line.rstrip("\n")
                if not line:
                    # Tolerate blank lines (e.g. a trailing empty line)
                    # instead of failing with IndexError on terms[1].
                    continue
                terms = line.split("\t")
                self.is_a_graph.add_edge(terms[0], terms[1])

        # Load all the training instances.
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(os.path.join(data_dir, tasks_filename), "rb") as tasks_file:
            tasks = list(pickle.load(tasks_file))
        self.data = ListDataLoader(tasks)

        # Per-family running counters (used in query ids) and query stores.
        self.sln_queries_id = 0
        self.sln_queries_dict = dict()

        self.sla_queries_id = 0
        self.sla_queries_dict = dict()

        self.slr_queries_id = 0
        self.slr_queries_dict = dict()

        # NOTE(review): starts at 1000 unlike every other counter — confirm
        # this offset is intentional.
        self.sli1_queries_id = 1000
        self.sli1_queries_dict = dict()

        self.sli3_queries_id = 0
        self.sli3_queries_dict = dict()

        self.sli5_queries_id = 0
        self.sli5_queries_dict = dict()

        self.sli5b_queries_id = 0
        self.sli5b_queries_dict = dict()

        self.sli6_name_queries_id = 0
        self.sli6_name_queries_dict = dict()

        self.sli6_attr_queries_id = 0
        self.sli6_attr_queries_dict = dict()

        self.slii3_queries_id = 0
        self.slii3_queries_dict = dict()

        self.name_anno_id = 0
        self.attr_anno_id = 0
        self.rela_anno_id = 0

        self.block_self_relations = True

        # metadata[section][label][family] -> list of query ids mentioning
        # that label; every label gets its own fresh lists.
        query_families = (
            "sln",
            "sla",
            "slr",
            "i1",
            "i3",
            "i5",
            "i5b",
            "i6_name",
            "i6_attr",
            "ii3",
        )
        self.metadata = {
            section: {
                label: {family: [] for family in query_families}
                for label in labels
            }
            for section, labels in (
                ("name", self.object_domain),
                ("attr", self.attr_canon),
                ("rela", self.rela_canon),
            )
        }

        # Number of supervised samples per name type
        # (combination for structural pruning: 4).
        self.sln_queries_num = 0
        # Number of supervised samples per attribute type.
        self.sla_queries_num = 0
        # Number of supervised samples per relation type.
        self.slr_queries_num = 0
        # Total budgets for the pooled (weakly supervised) families.
        self.sli1_queries_num = 0  # previously tried: 30000
        self.sli3_queries_num = 0  # previously tried: 5000
        self.sli5_queries_num = 5000  # also tried 10_000; combination for structural pruning
        self.sli5b_queries_num = 0
        self.sli6_name_queries_num = 0
        self.sli6_attr_queries_num = 0
        self.slii3_queries_num = 0

    def choose_training_instances(
        self,
        vqar_instances_file,
        vqar_statistics_file,
        outputfilename,
    ):
        """Sample a training subset of generated queries and pickle it.

        Supervised families (``sln``/``sla``/``slr``) are capped per label: a
        label with more than ``self.sl*_queries_num`` query ids contributes a
        uniformly random subset of exactly that size (drawn without
        replacement); otherwise all of its ids are kept.

        The remaining families are pooled across all labels, de-duplicated,
        shuffled and truncated to their ``self.sli*_queries_num`` budget.

        Args:
            vqar_instances_file: Pickle mapping family -> {query_id: instance}.
            vqar_statistics_file: Pickle mapping "name"/"attr"/"rela" ->
                label -> family -> list of query ids (presumably a pickled
                ``self.metadata`` — confirm).
            outputfilename: Destination pickle; receives {query_id: instance}.
        """
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(vqar_instances_file, "rb") as instances_file:
            vqar_instances = pickle.load(instances_file)
        with open(vqar_statistics_file, "rb") as metadata_file:
            vqar_statistics = pickle.load(metadata_file)

        random_instances = dict()

        # Supervised families: cap the number of query ids per label.
        supervised = (
            ("name", self.object_domain, "sln", self.sln_queries_num),
            ("attr", self.attr_canon, "sla", self.sla_queries_num),
            ("rela", self.rela_canon, "slr", self.slr_queries_num),
        )
        for section, labels, family, budget in supervised:
            for label in labels:
                candidates = vqar_statistics[section][label][family]
                if len(candidates) > budget:
                    # Without replacement: sampling with replacement produced
                    # duplicates that collapsed in the output dict, leaving
                    # the label with fewer than `budget` queries.
                    candidates = random.sample(candidates, budget)
                for candidate in candidates:
                    random_instances[candidate] = vqar_instances[family][candidate]

        # Pooled families, and the statistics sections each one is read from.
        pooled_families = {
            "name": ("i1", "i3", "i5", "i5b", "i6_name", "ii3"),
            "attr": ("i1", "i3", "i5", "i6_attr", "ii3"),
            "rela": ("i1", "i3", "i5", "i5b", "i6_name", "i6_attr", "ii3"),
        }
        labels_by_section = {
            "name": self.object_domain,
            "attr": self.attr_canon,
            "rela": self.rela_canon,
        }
        budgets = {
            "i1": self.sli1_queries_num,
            "i3": self.sli3_queries_num,
            "i5": self.sli5_queries_num,
            "i5b": self.sli5b_queries_num,
            "i6_name": self.sli6_name_queries_num,
            "i6_attr": self.sli6_attr_queries_num,
            "ii3": self.slii3_queries_num,
        }

        pools = {family: [] for family in budgets}
        for section, families in pooled_families.items():
            for label in labels_by_section[section]:
                label_stats = vqar_statistics[section][label]
                for family in families:
                    pools[family].extend(label_stats[family])

        for family, budget in budgets.items():
            # A query id is registered under every label it mentions (e.g. an
            # i5 query under its name, relation and attribute), so the raw
            # pool repeats ids. De-duplicate (order-preserving) before
            # shuffling so the budget counts distinct queries.
            pool = list(dict.fromkeys(pools[family]))
            random.shuffle(pool)
            for candidate in pool[:budget]:
                random_instances[candidate] = vqar_instances[family][candidate]

        with open(outputfilename, "wb") as file:
            pickle.dump(random_instances, file, protocol=pickle.HIGHEST_PROTOCOL)

    def create_sl_name_instances(self, image_id, object_ids, object2type, idx2word):
        """Emit one supervised name query per box in ``object_ids``.

        For box ``obj`` with gold name ``n`` the query ``Q(O) :- name(n,O)``
        is stored in ``self.sln_queries_dict`` under
        ``"{image_id}_sln_{counter}"`` as
        ``(image_id, query, (obj, [["name(n,obj)"]], [1]), [obj], [])``.
        The lineage has a single proof, labelled true. The id is also
        appended to ``self.metadata["name"][n]["sln"]``.
        """
        for box in object_ids:
            label = idx2word.idx_to_name(object2type[box])
            qid = f"{image_id}_sln_{self.sln_queries_id}"
            self.sln_queries_dict[qid] = (
                image_id,
                f"Q(O) :- name({label},O)",
                (box, [[f"name({label},{box})"]], [1]),
                [box],
                [],
            )
            self.sln_queries_id += 1
            self.metadata["name"][label]["sln"].append(qid)

    def create_sl_attr_instances(self, image_id, object_ids, object2attrs, idx2word):
        """Emit one supervised attribute query per (box, gold attribute) pair.

        For box ``obj`` carrying attribute ``a`` the query
        ``Q(O) :- attr(a,O)`` is stored in ``self.sla_queries_dict`` under
        ``"{image_id}_sla_{counter}"`` as
        ``(image_id, query, (obj, [["attr(a,obj)"]], [1]), [obj], [])``.
        The id is also appended to ``self.metadata["attr"][a]["sla"]``.
        """
        box_attr_pairs = (
            (box, attr_id) for box in object_ids for attr_id in object2attrs[box]
        )
        for box, attr_id in box_attr_pairs:
            label = idx2word.idx_to_attr(attr_id)
            qid = f"{image_id}_sla_{self.sla_queries_id}"
            self.sla_queries_dict[qid] = (
                image_id,
                f"Q(O) :- attr({label},O)",
                (box, [[f"attr({label},{box})"]], [1]),
                [box],
                [],
            )
            self.sla_queries_id += 1
            self.metadata["attr"][label]["sla"].append(qid)

    def create_sl_rel_instances(self, image_id, object_ids, object2rels, idx2word):
        """Emit one supervised relation query per related box pair.

        ``object2rels[obj][sub]`` gives the relation id linking ``obj`` to
        ``sub``. Each pair yields ``Q(O1,O2) :- rela(r,O1,O2)``, stored in
        ``self.slr_queries_dict`` under ``"{image_id}_slr_{counter}"`` as
        ``(image_id, query, ((obj, sub), [["rela(r,obj,sub)"]], [1]), [],
        [(obj, sub)])``. The id is also appended to
        ``self.metadata["rela"][r]["slr"]``.

        ``object_ids`` is accepted but not used here.
        """
        for obj, relations in object2rels.items():
            for sub, rel_id in relations.items():
                label = idx2word.idx_to_rela(rel_id)
                pair = (obj, sub)
                qid = f"{image_id}_slr_{self.slr_queries_id}"
                self.slr_queries_dict[qid] = (
                    image_id,
                    f"Q(O1,O2) :- rela({label},O1,O2)",
                    (pair, [[f"rela({label},{obj},{sub})"]], [1]),
                    [],
                    [pair],
                )
                self.slr_queries_id += 1
                self.metadata["rela"][label]["slr"].append(qid)

    # I1. Create training samples based on queries of the form Q(O) :- name(n,O), where n is a superclass different from object and thing
    # One such training sample is created per input bounding box  
    def create_i1_instances(
        self, image_id, object_ids, object2type, idx2word, is_a_graph, object_domain
    ):
        """Emit one I1 query ``Q(O) :- name(s,O)`` per box in ``object_ids``.

        ``s`` is ``find_superclass(gold_name, is_a_graph)``; per the section
        comment it is presumably a superclass other than ``object``/``thing``
        (confirm in ``utilities``). The lineage has one single-fact proof
        ``[name(alt,obj)]`` per candidate ``alt`` returned by
        ``find_seed_classes``; only the gold name is labelled 1.

        Stored in ``self.sli1_queries_dict`` as
        ``(image_id, query, (obj, proofs, gt), [obj], [])``. The id is also
        indexed under the gold name in ``self.metadata``.

        Raises:
            AssertionError: if the gold name is not among the candidates.
        """
        for box in object_ids:
            gold_name = idx2word.idx_to_name(object2type[box])
            superclass = find_superclass(gold_name, is_a_graph)
            # Candidate names that connect to `superclass` in the is_a graph.
            alternatives = find_seed_classes(superclass, is_a_graph, object_domain)

            proofs = [[f"name({alternative},{box})"] for alternative in alternatives]
            gt = [1 if alternative == gold_name else 0 for alternative in alternatives]
            # Sanity check: at least one proof (the gold name) must be true.
            assert 1 in gt

            qid = f"{image_id}_sli1_{self.sli1_queries_id}"
            self.sli1_queries_dict[qid] = (
                image_id,
                f"Q(O) :- name({superclass},O)",
                (box, proofs, gt),
                [box],
                [],
            )
            self.sli1_queries_id += 1
            self.metadata["name"][gold_name]["i1"].append(qid)

    # I3. Create training samples based on queries of the form Q(O1) :- name(n,O1), rela(r,O1,O2), where n is a superclass different from object and thing.
    # One training sample is created per (subject box, related box) entry of object2rels.
    def create_i3_instances(
        self,
        image_id,
        object_ids,
        object2type,
        object2rels,
        idx2word,
        is_a_graph,
        object_domain,
    ):
        """Emit I3 queries ``Q(O1) :- name(s,O1),rela(r,O1,O2)``.

        One query is produced per (subject ``obj``, related box) entry of
        ``object2rels``. ``s`` is the superclass of ``obj``'s gold name
        returned by ``find_superclass``.

        The lineage has one proof ``[name(alt,obj), rela(r,obj,box2)]`` for
        every candidate ``alt`` from ``find_seed_classes`` and every box
        ``box2`` in ``object_ids``. A proof is labelled 1 only when ``alt`` is
        the gold name of ``obj`` and ``object2rels[obj]`` maps ``box2`` to
        relation ``r``.

        Stored in ``self.sli3_queries_dict`` as ``(image_id, query,
        (obj, proofs, gt), [obj], [(obj, box2) for box2 in object_ids])``.
        The id is also indexed under the gold name and the relation in
        ``self.metadata``.

        Raises:
            AssertionError: if no proof is labelled true.
        """
        for obj in object2rels.keys():
            gold_name = idx2word.idx_to_name(object2type[obj])
            superclass = find_superclass(gold_name, is_a_graph)
            alternatives = find_seed_classes(superclass, is_a_graph, object_domain)

            rels = object2rels[obj]
            for _sub, rel_id in rels.items():
                rel_name = idx2word.idx_to_rela(rel_id)
                query = f"Q(O1) :- name({superclass},O1),rela({rel_name},O1,O2)"

                proofs, gt = [], []
                for alternative in alternatives:
                    for box2 in object_ids:
                        proofs.append(
                            [
                                f"name({alternative},{obj})",
                                f"rela({rel_name},{obj},{box2})",
                            ]
                        )
                        holds = (box2, rel_id) in rels.items() and alternative == gold_name
                        gt.append(1 if holds else 0)
                assert 1 in gt

                qid = f"{image_id}_sli3_{self.sli3_queries_id}"
                self.sli3_queries_dict[qid] = (
                    image_id,
                    query,
                    (obj, proofs, gt),
                    [obj],
                    [(obj, box2) for box2 in object_ids],
                )
                self.sli3_queries_id += 1
                self.metadata["name"][gold_name]["i3"].append(qid)
                self.metadata["rela"][rel_name]["i3"].append(qid)

    # I5. Create training samples based on queries of the form Q(O1) :- name(n,O1), rela(r,O1,O2), attr(a2,O2), where n is a superclass different from object and thing
    # The lineage per answer o1 is a disjunction over candidate names n_i and boxes o_j:
    # \bigvee_{i,j} name(n_i,o1) \wedge rela(r,o1,o_j) \wedge attr(a2,o_j)
    def create_i5_instances(
        self,
        image_id,
        object_ids,
        object2type,
        object2attrs,
        object2rels,
        idx2word,
        is_a_graph,
        object_domain,
    ):
        """Emit I5 queries ``Q(O1) :- name(s,O1),rela(r,O1,O2),attr(a,O2)``.

        One query is produced per (subject ``obj``, related box ``sub``,
        attribute ``a`` of ``sub``). Related boxes absent from
        ``object2attrs`` produce nothing.

        The lineage has one proof ``[name(alt,obj), rela(r,obj,box2),
        attr(a,box2)]`` per candidate ``alt`` and per box ``box2`` in
        ``object_ids``. A proof is labelled 1 when all of these hold:
        ``box2`` has attribute ``a``, ``object2rels[obj]`` maps ``box2`` to
        ``r``, and ``alt`` is the gold name of ``obj``.

        Stored in ``self.sli5_queries_dict`` as ``(image_id, query,
        (obj, proofs, gt), object_ids, [(obj, box2) for box2 in
        object_ids])``. The id is also indexed under the name, relation and
        attribute in ``self.metadata``.

        Raises:
            AssertionError: if no proof is labelled true.
        """
        for obj in object2rels.keys():
            gold_name = idx2word.idx_to_name(object2type[obj])
            superclass = find_superclass(gold_name, is_a_graph)
            alternatives = find_seed_classes(superclass, is_a_graph, object_domain)

            rels = object2rels[obj]
            for sub, rel_id in rels.items():
                rel_name = idx2word.idx_to_rela(rel_id)
                if sub not in object2attrs:
                    continue
                for attr_id in object2attrs[sub]:
                    attr_name = idx2word.idx_to_attr(attr_id)
                    query = (
                        f"Q(O1) :- name({superclass},O1),"
                        f"rela({rel_name},O1,O2),attr({attr_name},O2)"
                    )

                    proofs, gt = [], []
                    for alternative in alternatives:
                        for box2 in object_ids:
                            proofs.append(
                                [
                                    f"name({alternative},{obj})",
                                    f"rela({rel_name},{obj},{box2})",
                                    f"attr({attr_name},{box2})",
                                ]
                            )
                            holds = (
                                box2 in object2attrs
                                and attr_id in object2attrs[box2]
                                and (box2, rel_id) in rels.items()
                                and alternative == gold_name
                            )
                            gt.append(1 if holds else 0)
                    assert 1 in gt

                    qid = f"{image_id}_sli5_{self.sli5_queries_id}"
                    self.sli5_queries_dict[qid] = (
                        image_id,
                        query,
                        (obj, proofs, gt),
                        object_ids,
                        [(obj, box2) for box2 in object_ids],
                    )
                    self.sli5_queries_id += 1
                    for section, label in (
                        ("name", gold_name),
                        ("rela", rel_name),
                        ("attr", attr_name),
                    ):
                        self.metadata[section][label]["i5"].append(qid)

    # I5b. Create training samples based on queries of the form Q(O1) :- name(n1,O1), rel(r,O1,O2), name(n2,O2), where n1, n2 are superclasses different from object and thing
    def create_i5b_instances(
        self,
        image_id,
        object_ids,
        object2type,
        object2rels,
        idx2word,
        is_a_graph,
        object_domain,
    ):
        """Emit I5b queries ``Q(O1) :- name(s1,O1),rela(r,O1,O2),name(s2,O2)``.

        One query is produced per (subject ``obj``, related box ``sub``)
        entry of ``object2rels``. ``s1`` and ``s2`` are the
        ``find_superclass`` results for the gold names of ``obj`` and ``sub``.

        The lineage has one proof ``[name(alt,obj), rela(r,obj,box2),
        name(sub_alt,box2)]`` for every pair of candidates (``alt`` for
        ``s1``, ``sub_alt`` for ``s2``) and every box ``box2`` in
        ``object_ids``. A proof is labelled 1 when all of these hold:
        ``object2rels[obj]`` maps ``box2`` to ``r``, ``alt`` is ``obj``'s gold
        name, and ``sub_alt`` is ``box2``'s gold name.

        Stored in ``self.sli5b_queries_dict`` as ``(image_id, query,
        (obj, proofs, gt), object_ids, [(obj, box2) for box2 in
        object_ids])``. The id is also indexed under both names and the
        relation in ``self.metadata``.

        Raises:
            AssertionError: if no proof is labelled true.
        """
        for obj in object2rels.keys():
            gold_name = idx2word.idx_to_name(object2type[obj])
            superclass = find_superclass(gold_name, is_a_graph)
            alternatives = find_seed_classes(superclass, is_a_graph, object_domain)

            rels = object2rels[obj]
            for sub, rel_id in rels.items():
                rel_name = idx2word.idx_to_rela(rel_id)

                sub_name = idx2word.idx_to_name(object2type[sub])
                sub_superclass = find_superclass(sub_name, is_a_graph)
                sub_alternatives = find_seed_classes(
                    sub_superclass, is_a_graph, object_domain
                )

                query = (
                    f"Q(O1) :- name({superclass},O1),"
                    f"rela({rel_name},O1,O2),name({sub_superclass},O2)"
                )

                proofs, gt = [], []
                for alternative in alternatives:
                    for sub_alternative in sub_alternatives:
                        for box2 in object_ids:
                            proofs.append(
                                [
                                    f"name({alternative},{obj})",
                                    f"rela({rel_name},{obj},{box2})",
                                    f"name({sub_alternative},{box2})",
                                ]
                            )
                            # Keep the short-circuit order: object2type[box2]
                            # is only looked up for boxes actually related to obj.
                            holds = (
                                (box2, rel_id) in rels.items()
                                and alternative == gold_name
                                and sub_alternative
                                == idx2word.idx_to_name(object2type[box2])
                            )
                            gt.append(1 if holds else 0)
                assert 1 in gt

                qid = f"{image_id}_sli5b_{self.sli5b_queries_id}"
                self.sli5b_queries_dict[qid] = (
                    image_id,
                    query,
                    (obj, proofs, gt),
                    object_ids,
                    [(obj, box2) for box2 in object_ids],
                )
                self.sli5b_queries_id += 1
                for section, label in (
                    ("name", gold_name),
                    ("rela", rel_name),
                    ("name", sub_name),
                ):
                    self.metadata[section][label]["i5b"].append(qid)

    # I6. Create queries of the form Q(O1) :- rel(r,O1,O2), attr(a2,O2)
    # The lineage per answer o1 is of the form:
    # \bigvee_{j} rela(r,o1,o_j) \wedge attr(a2,o_j)
    def create_i6_instances_attr(
        self, image_id, object_ids, object2attrs, object2rels, idx2word
    ):
        """Emit I6 queries ``Q(O1) :- rela(r,O1,O2),attr(a,O2)``.

        One query is produced per (subject ``obj``, related box ``sub``,
        attribute ``a`` of ``sub``). Related boxes absent from
        ``object2attrs`` produce nothing.

        The lineage has one proof ``[rela(r,obj,box2), attr(a,box2)]`` per
        box ``box2`` in ``object_ids``. A proof is labelled 1 when ``box2``
        has attribute ``a`` and ``object2rels[obj]`` maps ``box2`` to ``r``.

        Stored in ``self.sli6_attr_queries_dict`` as ``(image_id, query,
        (obj, proofs, gt), object_ids, [(obj, box2) for box2 in
        object_ids])``. The id is also indexed under the relation and the
        attribute in ``self.metadata``.

        Raises:
            AssertionError: if no proof is labelled true.
        """
        for obj, rels in object2rels.items():
            for sub, rel_id in rels.items():
                rel_name = idx2word.idx_to_rela(rel_id)
                if sub not in object2attrs:
                    continue
                for attr_id in object2attrs[sub]:
                    attr_name = idx2word.idx_to_attr(attr_id)
                    query = f"Q(O1) :- rela({rel_name},O1,O2),attr({attr_name},O2)"

                    proofs, gt = [], []
                    for box2 in object_ids:
                        proofs.append(
                            [
                                f"rela({rel_name},{obj},{box2})",
                                f"attr({attr_name},{box2})",
                            ]
                        )
                        holds = (
                            box2 in object2attrs
                            and attr_id in object2attrs[box2]
                            and (box2, rel_id) in rels.items()
                        )
                        gt.append(1 if holds else 0)
                    assert 1 in gt

                    qid = f"{image_id}_sli6_attr_{self.sli6_attr_queries_id}"
                    self.sli6_attr_queries_dict[qid] = (
                        image_id,
                        query,
                        (obj, proofs, gt),
                        object_ids,
                        [(obj, box2) for box2 in object_ids],
                    )
                    self.sli6_attr_queries_id += 1
                    self.metadata["rela"][rel_name]["i6_attr"].append(qid)
                    self.metadata["attr"][attr_name]["i6_attr"].append(qid)

    # I6. Create queries of the form Q(O1) :- rel(r,O1,O2), name(n,O2), where n is a superclass different from object and thing
    # The lineage per answer o1 is of the form:
    # \bigvee_{i,j} rela(r,o1,o_j) \wedge name(n_i,o_j), where n_i ranges over the seed classes of n
    def create_i6_instances_name(
        self, image_id, object_ids, object2type, object2rels, idx2word, is_a_graph, object_domain,
    ):
        """Build I6 (name) queries of the form Q(O1) :- rela(r,O1,O2), name(n,O2).

        One query is emitted per relation edge ``obj -> sub`` in ``object2rels``.
        ``n`` is the superclass of ``sub``'s gold name as returned by
        ``find_superclass``. The lineage of the answer ``obj`` enumerates every
        seed class of ``n`` (``find_seed_classes``) crossed with every box in
        ``object_ids``. Each proof is ``[rela(r,obj,o2), name(seed,o2)]`` and is
        labelled 1 when the scene graph contains ``obj -r-> o2`` and ``o2``'s
        gold name equals ``seed``.

        Side effects: stores ``(image_id, query, lineage, object_ids, pairs)``
        in ``self.sli6_name_queries_dict`` and advances
        ``self.sli6_name_queries_id``. The query id is registered under
        ``self.metadata["rela"][r]["i6_name"]`` and
        ``self.metadata["name"][<gold name of sub>]["i6_name"]``.

        Raises:
            AssertionError: if no proof of a query is labelled positive.
        """
        for obj, rel_map in object2rels.items():
            for sub, rel_id in rel_map.items():
                rel_name = idx2word.idx_to_rela(rel_id)
                sub_name = idx2word.idx_to_name(object2type[sub])
                superclass = find_superclass(sub_name, is_a_graph)
                # Concrete classes (within object_domain) that connect to
                # `superclass` per find_seed_classes; each is a candidate name for O2.
                seeds = find_seed_classes(superclass, is_a_graph, object_domain)

                query = f"Q(O1) :- rela({rel_name},O1,O2),name({superclass},O2)"

                # A single rewriting over the trainable predicates exists, so the
                # lineage is just the list of substitutions plus their labels.
                proofs, gt = [], []
                for seed in seeds:
                    for box2 in object_ids:
                        proofs.append(
                            [f"rela({rel_name},{obj},{box2})", f"name({seed},{box2})"]
                        )
                        matches = (box2, rel_id) in rel_map.items() and seed == idx2word.idx_to_name(
                            object2type[box2]
                        )
                        gt.append(1 if matches else 0)
                assert 1 in gt

                # Query ids are unique per image, query family and running counter.
                query_id = f"{image_id}_sli6_name_{self.sli6_name_queries_id}"
                self.sli6_name_queries_dict[query_id] = (
                    image_id,
                    query,
                    (obj, proofs, gt),
                    object_ids,
                    [(obj, box2) for box2 in object_ids],
                )
                self.sli6_name_queries_id += 1
                self.metadata["rela"][rel_name]["i6_name"].append(query_id)
                self.metadata["name"][sub_name]["i6_name"].append(query_id)

    # II3. Create queries of the form Q(O1) :- rel(r,O1,O2), name(n2,O2), attr(a2,O2), where n2 is a superclass different from object and thing
    def create_ii3_instance(
        self,
        image_id,
        object_ids,
        object2type,
        object2attrs,
        object2rels,
        idx2word,
        is_a_graph,
        object_domain,
    ):
        """Build II3 queries of the form Q(O1) :- rela(r,O1,O2), name(n2,O2), attr(a2,O2).

        One query is emitted for every relation edge ``obj -> sub`` with
        ``obj != sub`` and every gold attribute ``a2`` of ``sub``. Edges whose
        ``sub`` has no entry in ``object2attrs`` produce nothing. ``n2`` is the
        superclass of ``sub``'s gold name as returned by ``find_superclass``.

        The lineage of the answer ``obj`` enumerates every seed class of ``n2``
        (``find_seed_classes``) crossed with every box other than ``obj``. Each
        proof is ``[rela(r,obj,o2), name(seed,o2), attr(a2,o2)]``. A proof is
        labelled 1 only when all three facts hold in the filtered scene graph.

        Side effects: stores ``(image_id, query, lineage, object_ids, pairs)``
        in ``self.slii3_queries_dict`` and advances ``self.slii3_queries_id``.
        The query id is registered under ``self.metadata["name"][...]["ii3"]``,
        ``self.metadata["rela"][...]["ii3"]`` and
        ``self.metadata["attr"][...]["ii3"]``.

        Raises:
            AssertionError: if no proof of a query is labelled positive.
        """
        for obj, rel_map in object2rels.items():
            for sub, gold_rel_id in rel_map.items():
                # O1 and O2 must be distinct boxes, and the query needs an
                # attribute of O2, so skip self-loops and attribute-less subjects.
                if obj == sub or sub not in object2attrs:
                    continue

                gold_sub_type_name = idx2word.idx_to_name(object2type[sub])
                # Superclass of the gold name, used as n2 in the query.
                gold_sub_type_superclass = find_superclass(gold_sub_type_name, is_a_graph)
                # Every concrete class (within object_domain) that connects to
                # the superclass is a possible grounding of name(n2,O2).
                gold_sub_type_name_alternatives = find_seed_classes(
                    gold_sub_type_superclass, is_a_graph, object_domain
                )
                gold_rel_name = idx2word.idx_to_rela(gold_rel_id)
                # Candidate O2 boxes: every object except obj itself.
                boxes2 = [box for box in object_ids if box != obj]

                for gold_sub_attr_id in object2attrs[sub]:
                    gold_sub_attr_name = idx2word.idx_to_attr(gold_sub_attr_id)
                    query = f"Q(O1) :- rela({gold_rel_name},O1,O2), name({gold_sub_type_superclass},O2), attr({gold_sub_attr_name},O2)"

                    proofs = []
                    gt = []
                    for alternative in gold_sub_type_name_alternatives:
                        for box2 in boxes2:
                            proofs.append(
                                [
                                    f"rela({gold_rel_name},{obj},{box2})",
                                    f"name({alternative},{box2})",
                                    f"attr({gold_sub_attr_name},{box2})",
                                ]
                            )
                            # The proof holds only if all three of its facts hold;
                            # the attribute fact must be checked as well.
                            holds = (
                                (box2, gold_rel_id) in rel_map.items()
                                and alternative == idx2word.idx_to_name(object2type[box2])
                                and gold_sub_attr_id in object2attrs.get(box2, ())
                            )
                            gt.append(1 if holds else 0)
                    assert 1 in gt, f"no positive proof for {query} on image {image_id}"

                    # Query ids are unique per image, query family and running counter.
                    query_id = f"{image_id}_slii3_{self.slii3_queries_id}"
                    self.slii3_queries_dict[query_id] = (
                        image_id,
                        query,
                        (obj, proofs, gt),
                        object_ids,
                        [(obj, box2) for box2 in object_ids],
                    )
                    self.slii3_queries_id = self.slii3_queries_id + 1
                    self.metadata["name"][gold_sub_type_name]["ii3"].append(query_id)
                    self.metadata["rela"][gold_rel_name]["ii3"].append(query_id)
                    self.metadata["attr"][gold_sub_attr_name]["ii3"].append(query_id)

    def generate_training_instances(self, filename1, filename2):
        """Generate every enabled family of training queries and pickle them.

        Iterates over ``self.data`` (an iterable of batches of datapoints) and
        processes each distinct ``image_id`` exactly once. The scene graph is
        filtered with ``maintain_topk_properties``. Images that keep fewer than
        two objects are skipped. For every query family whose
        ``self.*_queries_num`` is non-zero, the matching ``create_*`` method is
        called. Finally the collected queries and metadata are written via
        ``_materialize_instances_and_metadata``.

        Args:
            filename1: output path for the pickled per-family query dictionaries.
            filename2: output path for the pickled ``self.metadata``.
        """
        done = set()
        for batch in self.data:
            for datapoint in batch:
                image_id = datapoint["image_id"]
                if image_id in done:
                    continue
                # Mark the image up front so that images skipped below are not
                # re-filtered (and re-printed) each time they reappear.
                done.add(image_id)
                print(image_id)

                # Keep only the desired (top-k) names, attributes and relations.
                object2type, object2attrs, object2rels = maintain_topk_properties(
                    datapoint["scene_graph"], self.topk_names, self.topk_attrs, self.topk_relas
                )
                object_ids = list(object2type)
                # Queries need at least two boxes; empty and single-object scenes
                # carry no usable relational information.
                if len(object_ids) < 2:
                    continue

                # Supervised instances for the name classifier
                if self.sln_queries_num != 0:
                    self.create_sl_name_instances(
                        image_id, object_ids, object2type, self.idx2word
                    )

                # Supervised instances for the attr classifier
                if self.sla_queries_num != 0:
                    self.create_sl_attr_instances(
                        image_id, object_ids, object2attrs, self.idx2word
                    )

                # Supervised instances for the rel classifier
                if self.slr_queries_num != 0:
                    self.create_sl_rel_instances(
                        image_id, object_ids, object2rels, self.idx2word
                    )

                # I1. Q(O) :- name(n,O), where n is a superclass different from object and thing
                if self.sli1_queries_num != 0:
                    self.create_i1_instances(
                        image_id,
                        object_ids,
                        object2type,
                        self.idx2word,
                        self.is_a_graph,
                        self.object_domain,
                    )

                # I3. Q(O1) :- name(n,O1), rel(r,O1,O2)
                if self.sli3_queries_num != 0:
                    self.create_i3_instances(
                        image_id,
                        object_ids,
                        object2type,
                        object2rels,
                        self.idx2word,
                        self.is_a_graph,
                        self.object_domain,
                    )

                # I5. Q(O1) :- name(n,O1), rel(r,O1,O2), attr(a2,O2)
                if self.sli5_queries_num != 0:
                    self.create_i5_instances(
                        image_id,
                        object_ids,
                        object2type,
                        object2attrs,
                        object2rels,
                        self.idx2word,
                        self.is_a_graph,
                        self.object_domain,
                    )

                # I5b. Q(O1) :- name(n1,O1), rel(r,O1,O2), name(n2,O2)
                if self.sli5b_queries_num != 0:
                    self.create_i5b_instances(
                        image_id,
                        object_ids,
                        object2type,
                        object2rels,
                        self.idx2word,
                        self.is_a_graph,
                        self.object_domain,
                    )

                # I6 (attr). Q(O1) :- rel(r,O1,O2), attr(a,O2)
                if self.sli6_attr_queries_num != 0:
                    self.create_i6_instances_attr(
                        image_id,
                        object_ids,
                        object2attrs,
                        object2rels,
                        self.idx2word,
                    )

                # I6 (name). Q(O1) :- rel(r,O1,O2), name(n,O2)
                if self.sli6_name_queries_num != 0:
                    self.create_i6_instances_name(
                        image_id,
                        object_ids,
                        object2type,
                        object2rels,
                        self.idx2word,
                        self.is_a_graph,
                        self.object_domain,
                    )

                # II3. Q(O1) :- rel(r,O1,O2), name(n2,O2), attr(a2,O2)
                if self.slii3_queries_num != 0:
                    self.create_ii3_instance(
                        image_id,
                        object_ids,
                        object2type,
                        object2attrs,
                        object2rels,
                        self.idx2word,
                        self.is_a_graph,
                        self.object_domain,
                    )

        # Per-image data (scene graph, features, ...) is stored elsewhere; only
        # the queries, their lineages and the metadata are pickled here.
        self._materialize_instances_and_metadata(filename1, filename2)

    def _materialize_instances_and_metadata(self, filename1, filename2):
        subdictionary = dict()
        subdictionary["sln"] = self.sln_queries_dict
        subdictionary["sla"] = self.sla_queries_dict
        subdictionary["slr"] = self.slr_queries_dict
        subdictionary["i1"] = self.sli1_queries_dict
        subdictionary["i3"] = self.sli3_queries_dict
        subdictionary["i5"] = self.sli5_queries_dict
        subdictionary["i5b"] = self.sli5b_queries_dict
        subdictionary["i6_name"] = self.sli6_name_queries_dict
        subdictionary["i6_attr"] = self.sli6_attr_queries_dict
        subdictionary["ii3"] = self.slii3_queries_dict
        with open(filename1, "wb") as file:
            pickle.dump(subdictionary, file, protocol=pickle.HIGHEST_PROTOCOL)

        with open(filename2, "wb") as file:
            pickle.dump(self.metadata, file, protocol=pickle.HIGHEST_PROTOCOL)

    def create_name_annotations(
        self, image_id, object_ids, object2type, idx2word, number_of_instances_per_class
    ):
        """Sample name test annotations for one image, grouped by gold class.

        Every object in ``object_ids`` gets a query id
        ``"<image_id>_name_anno_<n>"``, where ``n`` comes from
        ``self.name_anno_id``, which is advanced once per object. Objects are
        grouped by gold name. Each group is shuffled with ``random.shuffle``,
        and at most ``number_of_instances_per_class`` entries per group are
        kept. Groups smaller than the limit are kept whole instead of raising
        ``IndexError``.

        Returns:
            dict mapping query id to ``(image_id, "name", obj, gold_type_name)``.
        """
        test_map = dict()
        for obj in object_ids:
            gold_type_name = idx2word.idx_to_name(object2type[obj])
            query_id = f"{image_id}_name_anno_{self.name_anno_id}"
            test_map.setdefault(gold_type_name, []).append(
                (query_id, image_id, "name", obj, gold_type_name)
            )
            self.name_anno_id = self.name_anno_id + 1

        # Shuffle every class before sampling so the RNG is consumed in the
        # same order regardless of how many entries are taken.
        for entries in test_map.values():
            random.shuffle(entries)

        test_instances = dict()
        for entries in test_map.values():
            for index, entry in enumerate(entries):
                if index >= number_of_instances_per_class:
                    break
                test_instances[entry[0]] = entry[1:]
        return test_instances

    def create_attr_annotations(self, image_id, object_ids, object2attrs, idx2word):
        """Build one attribute test annotation per object that has attributes.

        Objects missing from ``object2attrs`` are skipped instead of raising
        ``KeyError``. This matches the rest of this class, which treats
        ``object2attrs`` as covering only the objects that have attributes.
        ``self.attr_anno_id`` is advanced once per emitted annotation.

        Returns:
            dict mapping ``"<image_id>_attr_anno_<n>"`` to
            ``(image_id, "attr", obj, [attribute names])``.
        """
        test_instances = dict()
        for obj in object_ids:
            if obj not in object2attrs:
                continue
            gold_attr_names = [
                idx2word.idx_to_attr(gold_attr_id) for gold_attr_id in object2attrs[obj]
            ]
            query_id = f"{image_id}_attr_anno_{self.attr_anno_id}"
            test_instances[query_id] = (image_id, "attr", obj, gold_attr_names)
            self.attr_anno_id = self.attr_anno_id + 1
        return test_instances

    def create_rel_annotations(self, image_id, object2rels, idx2word):
        """Build one relation test annotation per (subject, object) edge.

        Returns a dict mapping ``"<image_id>_rela_anno_<n>"`` to
        ``(image_id, "rela", (obj, sub), relation_name)``, in the iteration
        order of ``object2rels``. ``n`` is read from ``self.rela_anno_id``,
        which is advanced once per edge.
        """
        annotations = {}
        for source, targets in object2rels.items():
            for target, rel_idx in targets.items():
                relation = idx2word.idx_to_rela(rel_idx)
                key = f"{image_id}_rela_anno_{self.rela_anno_id}"
                annotations[key] = (image_id, "rela", (source, target), relation)
                self.rela_anno_id += 1
        return annotations

    def generate_test_instances(
        self,
        outputfilename,
        idx2word,
        name_inst=True,
        attr_inst=True,
        rela_inst=True,
        number_of_instances_per_class=200,
    ):
        """Sample per-class test annotations for names, attributes and relations.

        Each distinct ``image_id`` in ``self.data`` is processed once. Its scene
        graph is filtered with ``maintain_topk_properties``, and candidates are
        collected grouped by gold class:

        * ``name_inst``: one candidate per object, keyed by its gold name.
        * ``attr_inst``: one candidate per (object, attribute) pair, keyed by
          the attribute; the gold label is the object's full attribute list.
        * ``rela_inst``: one candidate per relation edge, keyed by relation name.

        Each group is shuffled, and at most ``number_of_instances_per_class``
        entries per group are kept. The merged mapping
        ``{query_id: (image_id, kind, object(s), gold)}`` is pickled to
        ``outputfilename``. Query ids come from ``self.name_anno_id``,
        ``self.attr_anno_id`` and ``self.rela_anno_id``, which are advanced per
        candidate.

        Args:
            outputfilename: path of the pickle file to write.
            idx2word: index-to-word mapper used to name classes.
            name_inst / attr_inst / rela_inst: enable each annotation family.
            number_of_instances_per_class: per-class sampling cap.
        """
        name_test_map = dict()
        attr_test_map = dict()
        rela_test_map = dict()
        done = set()
        for batch in self.data:
            for datapoint in batch:
                image_id = datapoint["image_id"]
                if image_id in done:
                    continue
                done.add(image_id)

                # Keep only the top-k names, attributes and relations.
                object2type, object2attrs, object2rels = maintain_topk_properties(
                    datapoint["scene_graph"], self.topk_names, self.topk_attrs, self.topk_relas
                )

                if name_inst:
                    # Candidates for the name classifier
                    for obj, gold_type_id in object2type.items():
                        gold_type_name = idx2word.idx_to_name(gold_type_id)
                        query_id = f"{image_id}_name_anno_{self.name_anno_id}"
                        name_test_map.setdefault(gold_type_name, []).append(
                            (query_id, image_id, "name", obj, gold_type_name)
                        )
                        self.name_anno_id = self.name_anno_id + 1

                if attr_inst:
                    # Candidates for the attribute classifier
                    for obj in object2attrs:
                        gold_attr_names = [
                            idx2word.idx_to_attr(gold_attr_id)
                            for gold_attr_id in object2attrs[obj]
                        ]
                        for gold_attr_name in gold_attr_names:
                            # Use the attribute counter so every attribute
                            # annotation gets its own id (using rela_anno_id made
                            # ids collide and overwrite each other below).
                            query_id = f"{image_id}_attr_anno_{self.attr_anno_id}"
                            attr_test_map.setdefault(gold_attr_name, []).append(
                                (query_id, image_id, "attr", obj, list(gold_attr_names))
                            )
                            self.attr_anno_id = self.attr_anno_id + 1

                if rela_inst:
                    # Candidates for the relation classifier
                    for obj in object2rels:
                        for sub, gold_rel_id in object2rels[obj].items():
                            gold_rel_name = idx2word.idx_to_rela(gold_rel_id)
                            query_id = f"{image_id}_rela_anno_{self.rela_anno_id}"
                            rela_test_map.setdefault(gold_rel_name, []).append(
                                (query_id, image_id, "rela", (obj, sub), gold_rel_name)
                            )
                            self.rela_anno_id = self.rela_anno_id + 1

        def sample(test_map):
            # Shuffle each class in place and keep its first entries, mapping
            # query_id -> remaining tuple fields.
            sampled = dict()
            for entries in test_map.values():
                random.shuffle(entries)
                for index, entry in enumerate(entries):
                    if index >= number_of_instances_per_class:
                        break
                    sampled[entry[0]] = entry[1:]
            return sampled

        # Sampling in name -> attr -> rela order consumes the RNG in the same
        # sequence as shuffling all families first.
        name_test_instances = sample(name_test_map)
        attr_test_instances = sample(attr_test_map)
        rela_test_instances = sample(rela_test_map)

        with open(outputfilename, "wb") as file:
            pickle.dump(
                {**name_test_instances, **attr_test_instances, **rela_test_instances},
                file,
                protocol=pickle.HIGHEST_PROTOCOL,
            )
