import os
import math
import pickle
import random
import collections
import numpy as np
import networkx as nx
import torch
from loguru import logger
import torch.nn.functional as F
from torch_geometric.data import Data
from tqdm import tqdm

# Split identifiers accepted as the ``mode`` argument of the dataset classes.
TRAIN_MODE = "train"
VAL_MODE = "val"
TEST_MODE = "test"

# Flat, batched representation of several graphs, as produced by the
# ``_pack_batch*`` helpers below:
#   from_idx / to_idx - edge endpoints, offset so node ids are unique per batch
#   node_features     - all-ones tensor with one row per node
#   edge_features     - all-ones tensor with one row per edge
#   graph_idx         - for every node, the position of its graph in the batch
#   num_graphs        - number of graphs packed into the batch
GraphCollection = collections.namedtuple(
    'GraphCollection', 
    ['from_idx', 'to_idx', 'node_features', 'edge_features', 'graph_idx', 'num_graphs']
)

class DatasetBase:
    def __init__(self, conf,  mode):
        """Set up the dataset for one split and load/preprocess its graphs.

        Args:
            conf: configuration object; this constructor reads
                ``conf.dataset.name``, ``conf.dataset.data_type``,
                ``conf.training.batch_size`` and ``conf.training.device``.
                It is also forwarded to the subclass hooks.
            mode: one of ``TRAIN_MODE``, ``VAL_MODE`` or ``TEST_MODE``.

        The subclass hooks run in a fixed order: ``init_dataset_stats``,
        then ``load_graphs``, then PyG conversion, then stats logging.
        """
        assert mode in [TRAIN_MODE, VAL_MODE, TEST_MODE]
        self.mode = mode
        # self.conf = conf
        self.dataset_name = conf.dataset.name
        # Hook: expected to set max_node_set_size / max_edge_set_size
        # (both are read later by print_dataset_stats).
        self.init_dataset_stats(conf)
        # Selects the batch representation used by fetch_batch_by_id.
        self.data_type = conf.dataset.data_type

        self.batch_size = conf.training.batch_size
        self.device = conf.training.device
        
        # Hook: expected to populate query_graphs / corpus_graphs and the
        # relationship / pair attributes used by the methods below.
        self.load_graphs(conf)
        # self.build_adjacency_info()
        self.preprocess_subgraphs_to_pyG_data()
        self.print_dataset_stats()
        
        # Cache of already-built batches, keyed by batch index.
        self.memo_batch = {}

    def init_dataset_stats(self, conf):
        """Subclass hook, called by ``__init__`` before ``load_graphs``.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def load_graphs(self, conf):
        """Subclass hook, called by ``__init__`` to load the graphs.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError
    
    
    def create_pyG_object(self, graph):
        """Convert one graph into a PyG ``Data`` object with constant features.

        Args:
            graph: object exposing ``number_of_nodes()`` and an iterable
                ``edges`` of ``(u, v)`` pairs (presumably a networkx graph
                with integer node labels; confirm against the loaders).

        Returns:
            Tuple ``(data, num_nodes)``. ``data.x`` is an all-ones
            ``(num_nodes, 1)`` float tensor and ``data.edge_index`` is an
            int64 tensor of shape ``(2, num_edges)``, both on ``self.device``.
        """
        num_nodes = graph.number_of_nodes()
        features = torch.ones(num_nodes, 1, dtype=torch.float, device=self.device)

        # Each edge is stored once, in the direction the graph reports it;
        # reverse edges are deliberately NOT added.
        # reshape(-1, 2) keeps a well-formed (0, 2) array for edgeless graphs,
        # so edge_index is always (2, E) rather than a 1-D empty tensor.
        edge_array = np.array(
            [[src, dst] for (src, dst) in graph.edges], dtype=np.int64
        ).reshape(-1, 2)
        edge_index = torch.tensor(edge_array.T, dtype=torch.int64, device=self.device)
        return Data(x=features, edge_index=edge_index), num_nodes

    def preprocess_subgraphs_to_pyG_data(self):
        """Convert every query and corpus graph with ``create_pyG_object``.

        Sets ``query_graph_data`` / ``corpus_graph_data`` (lists of converted
        objects) and ``query_graph_sizes`` / ``corpus_graph_sizes`` (tuples of
        node counts). Empty graph lists produce an empty list and an empty
        tuple instead of failing on unpacking.
        """
        def _convert_all(graphs):
            # Build the (data, num_nodes) pairs once, then split them by hand:
            # zip(*pairs) cannot be unpacked into two names when pairs is empty.
            converted = [self.create_pyG_object(graph) for graph in graphs]
            data = [pyg_data for pyg_data, _ in converted]
            sizes = tuple(num_nodes for _, num_nodes in converted)
            return data, sizes

        self.query_graph_data, self.query_graph_sizes = _convert_all(self.query_graphs)
        self.corpus_graph_data, self.corpus_graph_sizes = _convert_all(self.corpus_graphs)
        
    
    # def build_adjacency_info(self):
    #     def adj_list_from_graph_list(graphs):
    #         adj_list = []
    #         for graph in graphs:
    #             unpadded_adj = torch.tensor(nx.adjacency_matrix(graph).todense(), dtype=torch.float, device=self.device)
    #             assert unpadded_adj.shape[0] == unpadded_adj.shape[1]
    #             num_nodes = len(unpadded_adj)
    #             padded_adj = F.pad(unpadded_adj, pad = (0, self.max_node_set_size - num_nodes, 0, self.max_node_set_size - num_nodes))
    #             adj_list.append(padded_adj)
    #         return adj_list

    #     self.query_adj_list = adj_list_from_graph_list(self.query_graphs)
    #     self.corpus_adj_list = adj_list_from_graph_list(self.corpus_graphs)
    
    def _pack_batch_1d(self, graphs):
        """Pack a flat sequence of graphs into a single ``GraphCollection``.

        Node ids of each graph are shifted by the number of nodes in the
        graphs that precede it, so ids are unique across the batch.

        Args:
            graphs: iterable of graphs exposing ``number_of_nodes()``,
                ``number_of_edges()`` and ``edges()``.

        Returns:
            GraphCollection with int64 ``from_idx`` / ``to_idx`` /
            ``graph_idx`` tensors and all-ones float feature tensors, all on
            ``self.device``. Edgeless graphs and an empty ``graphs`` sequence
            are handled and contribute zero-length index arrays.
        """
        all_graphs = list(graphs)
        from_idx, to_idx, graph_idx = [], [], []

        total_nodes, total_edges = 0, 0
        for idx, graph in enumerate(all_graphs):
            num_nodes = graph.number_of_nodes()
            num_edges = graph.number_of_edges()
            # reshape keeps a (0, 2) array for graphs without edges, so the
            # column slicing below never sees a 1-D array.
            edges = np.array(graph.edges(), dtype=np.int64).reshape(-1, 2)

            from_idx.append(edges[:, 0] + total_nodes)
            to_idx.append(edges[:, 1] + total_nodes)
            graph_idx.append(np.ones(num_nodes, dtype=np.int64) * idx)

            total_nodes += num_nodes
            total_edges += num_edges

        def _as_index_tensor(chunks):
            # np.concatenate rejects an empty list, so special-case it.
            flat = np.concatenate(chunks, axis=0) if chunks else np.zeros(0, dtype=np.int64)
            return torch.tensor(flat, dtype=torch.int64, device=self.device)

        return GraphCollection(
            from_idx=_as_index_tensor(from_idx),
            to_idx=_as_index_tensor(to_idx),
            graph_idx=_as_index_tensor(graph_idx),
            num_graphs=len(all_graphs),
            node_features=torch.ones(total_nodes, 1, dtype=torch.float, device=self.device),
            edge_features=torch.ones(total_edges, 1, dtype=torch.float, device=self.device),
        )
 
    def _pack_batch(self, graphs):
        """Pack an iterable of graph tuples into a single ``GraphCollection``.

        The tuples are flattened in order (tuple by tuple, element by
        element) and the resulting flat list is packed by
        ``_pack_batch_1d``, which implements the same packing logic.

        Args:
            graphs: iterable of tuples of graphs, e.g. ``(query, corpus)``
                pairs.

        Returns:
            GraphCollection whose ``num_graphs`` is the total number of
            individual graphs across all tuples.
        """
        flat_graphs = [graph for graph_tuple in graphs for graph in graph_tuple]
        return self._pack_batch_1d(flat_graphs)

    def _pack_batch_from_idx(self, batch_idx):
        all_gidx =  [gidx for gidx_tuple in batch_idx for gidx in gidx_tuple]
        from_idx = []
        to_idx = []
        graph_idx = []    
        total_nodes, total_edges = 0, 0
        for list_idx, gidx in enumerate(all_gidx):
            if list_idx%2==0:
                num_nodes =  self.query_graph_node_sizes[gidx] 
                num_edges =  self.query_graph_edge_sizes[gidx]
            #     from_idx.append(self.query_graph_edges[gidx][:, 0] + total_nodes)
            #     to_idx.append(self.query_graph_edges[gidx][:, 1] + total_nodes)
                
            else:
                num_nodes =  self.corpus_graph_node_sizes[gidx]
                num_edges =  self.corpus_graph_edge_sizes[gidx]
                # from_idx.append(self.corpus_graph_edges[gidx][:, 0] + total_nodes)
                # to_idx.append(self.corpus_graph_edges[gidx][:, 1] + total_nodes)
                
                
            # graph_idx.append(torch.ones(num_nodes, dtype=torch.int64, device=self.device) * list_idx)
            total_nodes += num_nodes
            total_edges += num_edges
                
        # return GraphCollection(
        #     from_idx      = torch.cat(from_idx, dim=0),
        #     to_idx        = torch.cat(to_idx, dim=0),
        #     graph_idx     = torch.cat(graph_idx, dim=0),
        #     num_graphs    = len(batch_idx),
        #     node_features = torch.ones(total_nodes, 1, dtype=torch.float, device=self.device),
        #     edge_features = torch.ones(total_edges, 1, dtype=torch.float, device=self.device)
        # )
            
    def create_stratified_batches(self):
        """Shuffle the pairs and build batches that keep the pos/neg ratio.

        Every batch is the concatenation of a chunk of positive pairs and a
        chunk of negative pairs; ``self.labels`` holds the matching 1.0 / 0.0
        targets. Chunks of the more plentiful class that have no partner
        chunk are dropped.

        Sets ``batch_setting``, ``num_batches``, ``batches`` and ``labels``.

        Returns:
            int: number of batches created.

        Raises:
            ValueError: if there are no positive or no negative pairs, or if
                ``batch_size`` is smaller than 2 (a batch needs at least one
                pair of each class).
        """
        self.batch_setting = 'stratified'
        if not self.pos_pairs or not self.neg_pairs:
            raise ValueError(
                "Stratified batching needs at least one positive and one negative pair"
            )
        if self.batch_size < 2:
            raise ValueError("Stratified batching needs batch_size >= 2")

        random.shuffle(self.pos_pairs)
        random.shuffle(self.neg_pairs)
        pos_to_neg_ratio = len(self.pos_pairs) / len(self.neg_pairs)

        num_pos_per_batch = math.ceil(pos_to_neg_ratio / (1 + pos_to_neg_ratio) * self.batch_size)
        # With a heavily skewed ratio the ceil above can claim the whole batch,
        # leaving zero negatives (and a zero step for range below). Clamp so
        # that both classes always get at least one slot.
        num_pos_per_batch = min(max(num_pos_per_batch, 1), self.batch_size - 1)
        num_neg_per_batch = self.batch_size - num_pos_per_batch

        # Slicing past the end simply yields a shorter final chunk.
        batches_pos = [
            self.pos_pairs[start:start + num_pos_per_batch]
            for start in range(0, len(self.pos_pairs), num_pos_per_batch)
        ]
        batches_neg = [
            self.neg_pairs[start:start + num_neg_per_batch]
            for start in range(0, len(self.neg_pairs), num_neg_per_batch)
        ]

        # zip stops at the shorter list, which discards the unmatched chunks.
        self.num_batches = min(len(batches_pos), len(batches_neg))
        self.batches = [pos + neg for pos, neg in zip(batches_pos, batches_neg)]
        self.labels = [
            [1.0] * len(pos) + [0.0] * len(neg)
            for pos, neg in zip(batches_pos, batches_neg)
        ]

        return self.num_batches

    def create_eval_batches(self, pair_list):
        """Chunk ``pair_list`` into evaluation batches.

        Each batch holds ``self.num_corpus_graphs`` consecutive pairs (the
        last one may be shorter). Sets ``batch_setting``, ``eval_batches``
        and ``num_eval_batches``.

        Returns:
            int: number of evaluation batches.
        """
        self.batch_setting = 'eval'
        self.eval_batches = []
        chunk_size = self.num_corpus_graphs
        self.eval_batches.extend(
            pair_list[start:start + chunk_size]
            for start in range(0, len(pair_list), chunk_size)
        )
        self.num_eval_batches = len(self.eval_batches)
        return self.num_eval_batches

    def fetch_batch_by_id(self, idx):
        if self.batch_setting == 'stratified':
            assert idx < self.num_batches
            batch = self.batches[idx]
        elif self.batch_setting == 'eval':
            assert idx < self.num_eval_batches
            batch = self.eval_batches[idx]

        query_graph_idxs, corpus_graph_idxs = zip(*batch)
        
        if self.data_type == "gmn":
            query_graphs = [self.query_graphs[idx] for idx in query_graph_idxs]
            corpus_graphs = [self.corpus_graphs[idx] for idx in corpus_graph_idxs]
            all_graphs = self._pack_batch(zip(query_graphs, corpus_graphs))
        elif self.data_type == "gmn_from_idx":
            all_graphs = self._pack_batch_from_idx(batch)
        elif self.data_type == "pyg":
            query_graphs = [self.query_graph_data[idx] for idx in query_graph_idxs]
            corpus_graphs = [self.corpus_graph_data[idx] for idx in corpus_graph_idxs]
            all_graphs = query_graphs + corpus_graphs #list(zip(query_graphs, corpus_graphs))    

        # query_graph_sizes = [self.query_graph_sizes[idx] for idx in query_graph_idxs]
        # corpus_graph_sizes = [self.corpus_graph_sizes[idx] for idx in corpus_graph_idxs]
        # all_graph_sizes = list(zip(query_graph_sizes, corpus_graph_sizes))

        query_graph_node_sizes = self.query_graph_node_sizes[list(query_graph_idxs)]
        corpus_graph_node_sizes = self.corpus_graph_node_sizes[list(corpus_graph_idxs)]
        all_graph_node_sizes = torch.dstack((query_graph_node_sizes, corpus_graph_node_sizes)).flatten()
        query_graph_edge_sizes = self.query_graph_edge_sizes[list(query_graph_idxs)]
        corpus_graph_edge_sizes = self.corpus_graph_edge_sizes[list(corpus_graph_idxs)]
        all_graph_edge_sizes = torch.dstack((query_graph_edge_sizes, corpus_graph_edge_sizes)).flatten()

        # query_graph_adjs = [self.query_adj_list[idx] for idx in query_graph_idxs]
        # corpus_graph_adjs = [self.corpus_adj_list[idx] for idx in corpus_graph_idxs]
        # all_graph_adjs = list(zip(query_graph_adjs, corpus_graph_adjs))

        if self.batch_setting == 'stratified':
            target = torch.tensor(np.array(self.labels[idx]), dtype=torch.float, device=self.device)
            return all_graphs, all_graph_node_sizes, all_graph_edge_sizes, target #, all_graph_adjs
        elif self.batch_setting == 'eval':
            return all_graphs, all_graph_node_sizes, all_graph_edge_sizes, None #, all_graph_adjs
        else:
            raise NotImplementedError
    
    def insert_all_batches(self):
        """Build every batch of the current setting and store it in the memo.

        Batches are produced by ``fetch_batch_by_id``; the size tensors and
        the target tensor (when present) are moved to shared memory before
        the tuple is stored in ``self.memo_batch`` under its batch index.

        Raises:
            NotImplementedError: for an unknown batch setting.
        """
        # Use the batch count that belongs to the active setting; eval batches
        # are counted separately from the stratified ones.
        if self.batch_setting == 'stratified':
            batch_count = self.num_batches
        elif self.batch_setting == 'eval':
            batch_count = self.num_eval_batches
        else:
            raise NotImplementedError

        for idx in tqdm(range(batch_count)):
            entry = self.fetch_batch_by_id(idx)
            _, node_sizes, edge_sizes, target = entry

            # The graph container itself is intentionally left as is.
            node_sizes.share_memory_()
            edge_sizes.share_memory_()
            if target is not None:  # eval batches have no target
                target.share_memory_()

            self.memo_batch[idx] = entry

    # In-memory memoization of built batches (kept in self.memo_batch; nothing is written to disk)
    def memo_fetch_batch_by_id(self, idx):
        if idx in self.memo_batch.keys():
            # logger.info(f"Fetching from memo for {idx}")
            return self.memo_batch[idx]
        
        logger.debug(f"Memo miss for {idx}")
        rv = self.fetch_batch_by_id(idx)
        rv[1].share_memory_()
        rv[2].share_memory_()
        rv[3].share_memory_()
        self.memo_batch[idx] = rv
        return rv
    
    def print_dataset_stats(self):
        """Log the dataset name, mode, graph counts, size maxima and ratios.

        Two positive-to-negative ratios are reported: the mean of the
        per-query ratios and the ratio over all pairs pooled together.
        """
        log = logger.info
        log(f"Dataset: {self.dataset_name}")
        log(f"Mode: {self.mode}")
        log(f"Number of query graphs: {len(self.query_graphs)}")
        log(f"Number of corpus graphs: {len(self.corpus_graphs)}")
        # log(f"Number of positive pairs: {len(self.pos_pairs)}")
        # log(f"Number of negative pairs: {len(self.neg_pairs)}")
        log(f"Maximum node set size: {self.max_node_set_size}")
        log(f"Maximum edge set size: {self.max_edge_set_size}")

        per_query_ratios = [
            len(rel['pos']) / len(rel['neg'])
            for rel in self.relationships.values()
        ]
        log(f"Per Query Average pos to neg ratio: {np.mean(per_query_ratios)}")

        pooled_ratio = len(self.pos_pairs) / len(self.neg_pairs)
        log(f"Macro Average pos to neg ratio: {pooled_ratio}")
        
        
    def run_assertion_checks(self):
        """Subclass hook for dataset sanity checks.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    
class SubgraphIsomorphismDataset(DatasetBase):
    def __init__(self, conf,  mode):
        """Build the dataset for ``mode``; all work is done by ``DatasetBase.__init__``."""
        super().__init__(conf, mode)
        
        
    def init_dataset_stats(self, conf):
        """Record the largest node and edge counts over the full dataset.

        Loads every query and corpus graph (all splits together) from the
        ``preprocessed`` directory, stores ``max_edge_set_size`` and
        ``max_node_set_size`` on ``self`` and lets the graphs go out of scope.

        Args:
            conf: configuration; reads ``conf.base_dir`` and, under
                ``conf.dataset``, ``name``, ``aug_num_cgraphs``,
                ``min_cgraph_nodes`` and ``max_cgraph_nodes``.
        """
        # NOTE: this loads all graphs (tr+te+val). Deliberately not storing the graphs. Only size stats are stored.
        # NOTE(review): the path uses a hard-coded "data/" segment, whereas
        # load_graphs uses conf.dataset.path. Confirm both name the same tree.
        directory = f"{conf.base_dir}data/{conf.dataset.name}/preprocessed"
        q_fname = f"{directory}/relabeled_queries_for_{conf.dataset.aug_num_cgraphs}_corpus_subgraphs_min{conf.dataset.min_cgraph_nodes}_max{conf.dataset.max_cgraph_nodes}_relabeled.pkl"
        c_fname = f"{directory}/{conf.dataset.aug_num_cgraphs}_corpus_subgraphs_min{conf.dataset.min_cgraph_nodes}_max{conf.dataset.max_cgraph_nodes}_relabeled.pkl"

        # Context managers close the files promptly instead of leaking handles.
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(q_fname, 'rb') as q_file:
            all_query_graphs = pickle.load(q_file)
        with open(c_fname, 'rb') as c_file:
            all_corpus_graphs = pickle.load(c_file)

        self.max_edge_set_size = max(
            max(graph.number_of_edges() for graph in all_query_graphs),
            max(graph.number_of_edges() for graph in all_corpus_graphs),
        )
        self.max_node_set_size = max(
            max(graph.number_of_nodes() for graph in all_query_graphs),
            max(graph.number_of_nodes() for graph in all_corpus_graphs),
        )
        

    def load_graphs(self, conf):
        """Load this mode's queries, the shared corpus and their relationships.

        If the query or relationship file for ``self.mode`` is missing, all
        three splits are (re)generated first: queries whose pos/neg ratio
        lies in ``[MinR, MaxR]`` are kept, shuffled, split 60/20/20 into
        train/test/val and written to disk.

        NOTE(review): the shuffle uses the global ``random`` state and the
        regeneration rewrites every split, so losing a single split file
        reshuffles all of them. Confirm this is acceptable.

        Afterwards the following attributes are populated: ``query_graphs``,
        ``corpus_graphs``, ``relationships``, the graph counts, per-graph
        node/edge size tensors, per-graph edge tensors, ``pos_pairs`` /
        ``neg_pairs`` and the evaluation helpers ``all_pairs`` / ``all_gt``.

        Args:
            conf: configuration; reads ``conf.base_dir`` and, under
                ``conf.dataset``, ``path``, ``name``, ``aug_num_cgraphs``,
                ``min_cgraph_nodes``, ``max_cgraph_nodes``, ``MinR`` and
                ``MaxR``.

        Raises:
            FileNotFoundError: if the corpus graph file does not exist.
        """
        def _load_pickle(path):
            # NOTE: unpickling can execute arbitrary code; only load trusted files.
            with open(path, 'rb') as handle:
                return pickle.load(handle)

        def _dump_pickle(obj, path):
            # Closing the handle guarantees the data is flushed before the
            # file is read back below.
            with open(path, 'wb') as handle:
                pickle.dump(obj, handle)

        ds = conf.dataset
        directory = f"{conf.base_dir}{ds.path}/{ds.name}/preprocessed/splits"
        os.makedirs(directory, exist_ok=True)

        # Filename fragments shared by every file of this configuration.
        corpus_tag = f"{ds.aug_num_cgraphs}_corpus_subgraphs_min{ds.min_cgraph_nodes}_max{ds.max_cgraph_nodes}"
        range_tag = f"range_{ds.MinR}_to_{ds.MaxR}"

        def _query_path(mode):
            return f"{directory}/{mode}/{mode}_queries_for_{corpus_tag}_relabeled_{range_tag}.pkl"

        def _rel_path(mode):
            return f"{directory}/{mode}/{mode}_rel_dict_for_{corpus_tag}_{range_tag}.pkl"

        q_fname = _query_path(self.mode)
        rel_fname = _rel_path(self.mode)
        c_fname = f"{directory}/../{corpus_tag}_relabeled.pkl"

        if not os.path.exists(c_fname):
            raise FileNotFoundError(f"Corpus graphs not found at {c_fname}")

        if not os.path.exists(q_fname) or not os.path.exists(rel_fname):
            # Generate splits for the specified range
            all_q_fname = f"{directory}/../relabeled_queries_for_{corpus_tag}_relabeled.pkl"
            all_rel_fname = f"{directory}/../rel_dict_for_{corpus_tag}.pkl"

            all_queries = _load_pickle(all_q_fname)
            all_rels = _load_pickle(all_rel_fname)
            logger.info(f"Loaded all queries from {all_q_fname} for the purpose of creating splits")
            logger.info(f"Loaded all relationships from {all_rel_fname}  for the purpose of creating splits")

            logger.info(f"Filtering queries for required query ratio range in {ds.MinR} to {ds.MaxR}")
            relevant_queries = []
            relevant_rels = {}
            for idx, rel in all_rels.items():
                # Queries without negatives have no defined ratio and are dropped.
                if len(rel['neg']) > 0 and ds.MinR <= len(rel['pos']) / len(rel['neg']) <= ds.MaxR:
                    # Re-key the relationship by the query's new position.
                    relevant_rels[len(relevant_queries)] = rel
                    relevant_queries.append(all_queries[idx])

            logger.info(f"Found {len(relevant_queries)} within specified range in {ds.MinR} to {ds.MaxR}")
            logger.info("Splitting into tr-test-val - 60:20:20")
            all_qidx = list(range(len(relevant_queries)))
            random.shuffle(all_qidx)
            train_end = int(0.6 * len(all_qidx))
            test_end = int(0.8 * len(all_qidx))
            split_indices = {
                "train": all_qidx[:train_end],
                "test": all_qidx[train_end:test_end],
                "val": all_qidx[test_end:],
            }

            logger.info(f"Saving tr-test-val queries to {directory}")
            for split_mode, indices in split_indices.items():
                os.makedirs(f"{directory}/{split_mode}", exist_ok=True)
                _dump_pickle(
                    [relevant_queries[i] for i in indices],
                    _query_path(split_mode),
                )
                # Relationships are re-keyed 0..n-1 within each split.
                _dump_pickle(
                    {new_idx: relevant_rels[i] for new_idx, i in enumerate(indices)},
                    _rel_path(split_mode),
                )

        self.query_graphs = _load_pickle(q_fname)
        self.corpus_graphs = _load_pickle(c_fname)
        self.relationships = _load_pickle(rel_fname)

        self.num_query_graphs = len(self.query_graphs)
        self.num_corpus_graphs = len(self.corpus_graphs)
        assert self.num_corpus_graphs == conf.dataset.aug_num_cgraphs, f"{self.num_corpus_graphs} and {conf.dataset.aug_num_cgraphs} mismatch"

        self.query_graph_node_sizes = torch.tensor(
            [g.number_of_nodes() for g in self.query_graphs], device=self.device)
        self.corpus_graph_node_sizes = torch.tensor(
            [g.number_of_nodes() for g in self.corpus_graphs], device=self.device)

        self.query_graph_edge_sizes = torch.tensor(
            [g.number_of_edges() for g in self.query_graphs], device=self.device)
        self.corpus_graph_edge_sizes = torch.tensor(
            [g.number_of_edges() for g in self.corpus_graphs], device=self.device)

        def _edge_tensor(graph):
            # reshape keeps a (0, 2) layout for graphs without edges.
            edge_array = np.array(graph.edges()).reshape(-1, 2)
            return torch.tensor(edge_array, dtype=torch.int64, device=self.device)

        self.query_graph_edges = [_edge_tensor(g) for g in self.query_graphs]
        self.corpus_graph_edges = [_edge_tensor(g) for g in self.corpus_graphs]

        assert sorted(list(self.relationships.keys())) == list(range(self.num_query_graphs)) , f"{sorted(list(self.relationships.keys()))}, and {list(range(self.num_query_graphs))}"
        logger.info(f"Loaded {len(self.query_graphs)} query graphs from {q_fname}")
        logger.info(f"Loaded {len(self.corpus_graphs)} corpus graphs from {c_fname}")
        logger.info(f"Loaded relationships from {rel_fname} for the purpose of creating splits")

        self.pos_pairs, self.neg_pairs = [], []
        for query_idx in range(self.num_query_graphs):
            rel = self.relationships[query_idx]
            self.pos_pairs.extend((query_idx, corpus_idx) for corpus_idx in rel['pos'])
            self.neg_pairs.extend((query_idx, corpus_idx) for corpus_idx in rel['neg'])

        # For evaluation: every (query, corpus) pair, query-major, plus a flat
        # 0/1 ground-truth vector in the same order.
        self.all_gt = []
        self.all_pairs = []
        for qidx in range(self.num_query_graphs):
            gt = torch.zeros(self.num_corpus_graphs)
            gt[self.relationships[qidx]['pos']] = 1
            self.all_gt.append(gt)
            self.all_pairs.extend((qidx, cidx) for cidx in range(self.num_corpus_graphs))
        self.all_gt = torch.cat(self.all_gt)


   




        
    def run_assertion_checks(self):
        """Sanity-check the loaded graphs and relationship dictionary.

        Asserts that:
          1. every graph's nodes are labelled ``0 .. n-1`` in order;
          2. no two graphs share a Weisfeiler-Lehman hash (used as a fast
             stand-in for a pairwise isomorphism test);
          3. for every query, the stored ``'pos'`` / ``'neg'`` corpus indices
             equal, in order, the indices where ``check_subiso`` reports
             true / false.

        Raises:
            AssertionError: if any check fails.
        """
        all_g = self.query_graphs + self.corpus_graphs

        # check nodes are numbered 0 ... n-1
        for g in all_g:
            assert list(g.nodes) == list(range(g.number_of_nodes()))

        # check no two graphs are isomorphic

        # Slow check : will take time, can perhaps be parallelized
        # for i in tqdm.tqdm(range(len(all_g))):
        #     for j in range(i+1,len(all_g)):
        #         assert not nx.is_isomorphic(all_g[i],all_g[j])

        # Fast check: compare the WL hashes of all graphs.
        wl_hash = [
            nx.algorithms.weisfeiler_lehman_graph_hash(g, iterations=5)
            for g in all_g
        ]
        assert len(set(wl_hash)) == len(wl_hash)

        # check that the rel dict matches the subgraph-isomorphism relationship
        # between query and corpus graphs
        from src.graph_data_generator import run_parallel_pool, check_subiso
        import itertools
        import tqdm
        for qidx in tqdm.tqdm(range(len(self.query_graphs))):
            res = run_parallel_pool(check_subiso, list(zip(self.corpus_graphs, itertools.repeat(self.query_graphs[qidx]))))
            is_pos = np.asarray(res, dtype=bool)
            rel = self.relationships[qidx]
            # array_equal also handles length mismatches, where an elementwise
            # `==` cannot be relied on to yield a boolean array.
            assert np.array_equal(np.where(is_pos)[0], np.array(rel['pos'])), f"positive set mismatch for query {qidx}"
            assert np.array_equal(np.where(~is_pos)[0], np.array(rel['neg'])), f"negative set mismatch for query {qidx}"


class GraphEditDistanceDataset(DatasetBase):
    """Query/corpus graph dataset with binary relevance labels.

    Graphs and a per-query ground-truth dict are read from pickles under
    ``{conf.base_dir}ged_data/{conf.dataset.name}/preprocessed``. The
    ground-truth values (presumably graph edit distances -- confirm against
    the data generator) are thresholded per query so that the fraction of
    relevant corpus graphs does not exceed ``conf.dataset.MaxSkew``. Train /
    test / val query splits are created on first use and cached under
    ``preprocessed/splits``.
    """

    # Directory (relative to conf.base_dir) that holds this dataset family.
    _DATA_SUBDIR = "ged_data"

    def __init__(self, conf, mode):
        super().__init__(conf, mode)

    @staticmethod
    def _read_pickle(path):
        """Return the object unpickled from `path`, closing the file afterwards."""
        # NOTE: unpickling can execute arbitrary code; only point this at
        # trusted, locally generated dataset files.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    @staticmethod
    def _write_pickle(obj, path):
        """Pickle `obj` to `path`, closing (and thereby flushing) the file."""
        with open(path, 'wb') as handle:
            pickle.dump(obj, handle)

    @staticmethod
    def _corpus_tag(conf):
        """Return the filename fragment shared by all files of one corpus config."""
        return (
            f"{conf.dataset.aug_num_cgraphs}_corpus_subgraphs"
            f"_min{conf.dataset.min_graph_nodes}_max{conf.dataset.max_graph_nodes}"
        )

    @staticmethod
    def _split_fnames(directory, split, tag, max_skew):
        """Return ``(queries_path, rel_dict_path)`` of the cached files for `split`."""
        suffix = f"{tag}_relabeled_MaxSkew_{max_skew}.pkl"
        return (
            f"{directory}/{split}/{split}_queries_for_{suffix}",
            f"{directory}/{split}/{split}_rel_dict_for_{suffix}",
        )

    @staticmethod
    def _relevance_threshold(gt_list, max_skew):
        """Return the cut-off below or at which a corpus graph counts as relevant.

        The result is the largest distinct (integer-truncated) value in
        `gt_list` for which the fraction of entries ``<=`` that value does not
        exceed `max_skew`. If even the smallest value exceeds `max_skew`, a
        number below every candidate is returned so nothing is relevant.

        Args:
            gt_list: 1-D tensor of ground-truth values for one query.
            max_skew: maximum allowed fraction of relevant corpus graphs.
        """
        candidates = sorted(set(gt_list.int().tolist()))
        if not candidates:
            return -1
        total = gt_list.shape[0]
        threshold = candidates[0] - 1
        for value in candidates:
            if (gt_list <= value).sum().item() / total > max_skew:
                break
            # Keep the ground-truth *value* (not its position in the sorted
            # list): the two only coincide when the values are 0, 1, 2, ...
            threshold = value
        return threshold

    def init_dataset_stats(self, conf):
        """Record the largest node and edge counts over ALL query and corpus graphs.

        Sets ``self.max_edge_set_size`` and ``self.max_node_set_size``.
        """
        # NOTE: this loads all graphs (tr+te+val). Deliberately not storing the
        # graphs. Only size stats are stored.
        directory = f"{conf.base_dir}{self._DATA_SUBDIR}/{conf.dataset.name}/preprocessed"
        tag = self._corpus_tag(conf)
        all_query_graphs = self._read_pickle(f"{directory}/relabeled_queries_for_{tag}_relabeled.pkl")
        all_corpus_graphs = self._read_pickle(f"{directory}/{tag}_relabeled.pkl")
        all_graphs = list(all_query_graphs) + list(all_corpus_graphs)
        self.max_edge_set_size = max(graph.number_of_edges() for graph in all_graphs)
        self.max_node_set_size = max(graph.number_of_nodes() for graph in all_graphs)

    def _create_splits(self, conf, directory, tag):
        """Label all queries, split them 60:20:20 and cache every split on disk.

        Writes, for each of train/test/val, a list of query graphs and a dict
        ``{new_query_idx: {'pos': [...], 'neg': [...]}}`` of corpus indices.
        The shuffle uses the global ``random`` state, so seed it beforehand if
        reproducible splits are required.
        """
        max_skew = conf.dataset.MaxSkew
        all_q_fname = f"{directory}/../relabeled_queries_for_{tag}_relabeled.pkl"
        all_gt_fname = f"{directory}/../gt_dict_for_{tag}.pkl"
        all_queries = self._read_pickle(all_q_fname)
        all_gt = self._read_pickle(all_gt_fname)
        logger.info(f"Loaded all queries from {all_q_fname} for the purpose of creating splits")
        logger.info(f"Loaded all relationships from {all_gt_fname}  for the purpose of creating splits")

        logger.info(f"Assigning binary relevance labels  for required maximum skew ratio {conf.dataset.MaxSkew}")
        all_rels = {}
        for qidx, gt_list in all_gt.items():
            threshold = self._relevance_threshold(gt_list, max_skew)
            all_rels[qidx] = {
                'pos': torch.where(gt_list <= threshold)[0].tolist(),
                'neg': torch.where(gt_list > threshold)[0].tolist(),
            }

        # A query without negatives has an unbounded pos/neg ratio; report inf
        # rather than dividing by zero.
        final_skew = np.mean([
            len(rel['pos']) / len(rel['neg']) if rel['neg'] else float('inf')
            for rel in all_rels.values()
        ])
        logger.info(f"Found final skew ratio to be {final_skew} which is less than {conf.dataset.MaxSkew}")

        logger.info("Splitting into tr-test-val - 60:20:20")
        all_qidx = list(range(len(all_queries)))
        random.shuffle(all_qidx)
        first_cut = int(0.6 * len(all_qidx))
        second_cut = int(0.8 * len(all_qidx))
        split_indices = {
            'train': all_qidx[:first_cut],
            'test': all_qidx[first_cut:second_cut],
            'val': all_qidx[second_cut:],
        }

        logger.info(f"Saving tr-test-val queries to {directory}")
        for split, indices in split_indices.items():
            os.makedirs(f"{directory}/{split}", exist_ok=True)
            split_q_fname, split_rel_fname = self._split_fnames(directory, split, tag, max_skew)
            self._write_pickle([all_queries[idx] for idx in indices], split_q_fname)
            # Relationship keys are re-indexed to positions within the split.
            self._write_pickle(
                {new_idx: all_rels[idx] for new_idx, idx in enumerate(indices)},
                split_rel_fname,
            )

    def load_graphs(self, conf):
        """Load the query graphs, corpus graphs and relationships for ``self.mode``.

        Creates and caches the train/test/val splits first if the files for the
        requested mode are missing. Populates the graph lists, their node/edge
        size tensors, per-graph edge tensors, ``pos_pairs`` / ``neg_pairs`` and
        the flattened evaluation targets ``all_gt`` / ``all_pairs``.

        Raises:
            FileNotFoundError: if the corpus pickle does not exist.
            AssertionError: if the corpus size differs from
                ``conf.dataset.aug_num_cgraphs`` or the relationship keys are
                not exactly ``0 .. num_query_graphs-1``.
        """
        directory = f"{conf.base_dir}{self._DATA_SUBDIR}/{conf.dataset.name}/preprocessed/splits"
        # Must exist before the checks below: the corpus path goes through
        # "splits/..", which cannot be resolved if "splits" is missing.
        os.makedirs(directory, exist_ok=True)

        tag = self._corpus_tag(conf)
        q_fname, rel_fname = self._split_fnames(directory, self.mode, tag, conf.dataset.MaxSkew)
        c_fname = f"{directory}/../{tag}_relabeled.pkl"

        if not os.path.exists(c_fname):
            raise FileNotFoundError(f"Corpus graphs not found at {c_fname}")

        if not os.path.exists(q_fname) or not os.path.exists(rel_fname):
            # Generate splits for the specified range
            self._create_splits(conf, directory, tag)

        self.query_graphs = self._read_pickle(q_fname)
        self.corpus_graphs = self._read_pickle(c_fname)
        self.relationships = self._read_pickle(rel_fname)

        self.num_query_graphs = len(self.query_graphs)
        self.num_corpus_graphs = len(self.corpus_graphs)
        assert self.num_corpus_graphs == conf.dataset.aug_num_cgraphs, f"{self.num_corpus_graphs} and {conf.dataset.aug_num_cgraphs} mismatch"

        self.query_graph_node_sizes = torch.tensor(
            [g.number_of_nodes() for g in self.query_graphs], device=self.device)
        self.corpus_graph_node_sizes = torch.tensor(
            [g.number_of_nodes() for g in self.corpus_graphs], device=self.device)

        self.query_graph_edge_sizes = torch.tensor(
            [g.number_of_edges() for g in self.query_graphs], device=self.device)
        self.corpus_graph_edge_sizes = torch.tensor(
            [g.number_of_edges() for g in self.corpus_graphs], device=self.device)

        self.query_graph_edges = [
            torch.tensor(np.array(g.edges()), dtype=torch.int64, device=self.device)
            for g in self.query_graphs
        ]
        self.corpus_graph_edges = [
            torch.tensor(np.array(g.edges()), dtype=torch.int64, device=self.device)
            for g in self.corpus_graphs
        ]

        assert sorted(list(self.relationships.keys())) == list(range(self.num_query_graphs)) , f"{sorted(list(self.relationships.keys()))}, and {list(range(self.num_query_graphs))}"
        logger.info(f"Loaded {len(self.query_graphs)} query graphs from {q_fname}")
        logger.info(f"Loaded {len(self.corpus_graphs)} corpus graphs from {c_fname}")
        logger.info(f"Loaded relationships from {rel_fname} for the purpose of creating splits")

        # (query_idx, corpus_idx) pairs, grouped by query.
        self.pos_pairs = [
            (query_idx, corpus_idx)
            for query_idx in range(self.num_query_graphs)
            for corpus_idx in self.relationships[query_idx]['pos']
        ]
        self.neg_pairs = [
            (query_idx, corpus_idx)
            for query_idx in range(self.num_query_graphs)
            for corpus_idx in self.relationships[query_idx]['neg']
        ]

        # For evaluation: one 0/1 target per (query, corpus) pair, query-major,
        # in the same order as all_pairs.
        per_query_gt = []
        for qidx in range(self.num_query_graphs):
            gt = torch.zeros(self.num_corpus_graphs)
            gt[self.relationships[qidx]['pos']] = 1
            per_query_gt.append(gt)
        self.all_pairs = [
            (qidx, cidx)
            for qidx in range(self.num_query_graphs)
            for cidx in range(self.num_corpus_graphs)
        ]
        self.all_gt = torch.cat(per_query_gt)

    def run_assertion_checks(self):
        """Sanity-check node numbering and (approximate) uniqueness of the graphs.

        Raises:
            AssertionError: if a graph's nodes are not ``0 .. n-1`` in order,
                if two query graphs share a Weisfeiler-Lehman hash, or if 1% or
                more of the corpus graphs' hashes are duplicates.
        """
        all_g = self.query_graphs + self.corpus_graphs

        # Check nodes are numbered 0 ... n-1.
        for g in all_g:
            assert list(g.nodes) == list(range(g.number_of_nodes()))

        # Check no two graphs are isomorphic. An exact pairwise
        # nx.is_isomorphic sweep is slow, so WL hashes are compared instead.
        wl_hash = [
            nx.algorithms.weisfeiler_lehman_graph_hash(g, iterations=20)
            for g in self.query_graphs
        ]
        print(f"In queries len(set(wl_hash)) = {len(set(wl_hash))} and len(wl_hash) = {len(wl_hash)}")
        assert len(set(wl_hash)) == len(wl_hash)

        # The corpus is allowed a small fraction (< 1%) of repeated hashes.
        wl_hash = [
            nx.algorithms.weisfeiler_lehman_graph_hash(g, iterations=20)
            for g in self.corpus_graphs
        ]
        print(f"In corpus len(set(wl_hash)) = {len(set(wl_hash))} and len(wl_hash) = {len(wl_hash)}")
        assert len(set(wl_hash))/len(wl_hash)>0.99

        #Too slow to check GED gt, but maybe someday we can check that too


#TODO: We can refactor and remove this massive duplication
class UneqGraphEditDistanceDataset(DatasetBase):
    """Query/corpus graph dataset with binary relevance labels (``uneq_ged_data``).

    Graphs and a per-query ground-truth dict are read from pickles under
    ``{conf.base_dir}uneq_ged_data/{conf.dataset.name}/preprocessed``. The
    ground-truth values (presumably graph edit distances -- confirm against
    the data generator) are thresholded per query so that the fraction of
    relevant corpus graphs does not exceed ``conf.dataset.MaxSkew``. Train /
    test / val query splits are created on first use and cached under
    ``preprocessed/splits``.
    """

    # Directory (relative to conf.base_dir) that holds this dataset family.
    _DATA_SUBDIR = "uneq_ged_data"

    def __init__(self, conf, mode):
        super().__init__(conf, mode)

    @staticmethod
    def _read_pickle(path):
        """Return the object unpickled from `path`, closing the file afterwards."""
        # NOTE: unpickling can execute arbitrary code; only point this at
        # trusted, locally generated dataset files.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    @staticmethod
    def _write_pickle(obj, path):
        """Pickle `obj` to `path`, closing (and thereby flushing) the file."""
        with open(path, 'wb') as handle:
            pickle.dump(obj, handle)

    @staticmethod
    def _corpus_tag(conf):
        """Return the filename fragment shared by all files of one corpus config."""
        return (
            f"{conf.dataset.aug_num_cgraphs}_corpus_subgraphs"
            f"_min{conf.dataset.min_graph_nodes}_max{conf.dataset.max_graph_nodes}"
        )

    @staticmethod
    def _split_fnames(directory, split, tag, max_skew):
        """Return ``(queries_path, rel_dict_path)`` of the cached files for `split`."""
        suffix = f"{tag}_relabeled_MaxSkew_{max_skew}.pkl"
        return (
            f"{directory}/{split}/{split}_queries_for_{suffix}",
            f"{directory}/{split}/{split}_rel_dict_for_{suffix}",
        )

    @staticmethod
    def _relevance_threshold(gt_list, max_skew):
        """Return the cut-off below or at which a corpus graph counts as relevant.

        The result is the largest distinct (integer-truncated) value in
        `gt_list` for which the fraction of entries ``<=`` that value does not
        exceed `max_skew`. If even the smallest value exceeds `max_skew`, a
        number below every candidate is returned so nothing is relevant.

        Args:
            gt_list: 1-D tensor of ground-truth values for one query.
            max_skew: maximum allowed fraction of relevant corpus graphs.
        """
        candidates = sorted(set(gt_list.int().tolist()))
        if not candidates:
            return -1
        total = gt_list.shape[0]
        threshold = candidates[0] - 1
        for value in candidates:
            if (gt_list <= value).sum().item() / total > max_skew:
                break
            # Keep the ground-truth *value* (not its position in the sorted
            # list): the two only coincide when the values are 0, 1, 2, ...
            threshold = value
        return threshold

    def init_dataset_stats(self, conf):
        """Record the largest node and edge counts over ALL query and corpus graphs.

        Sets ``self.max_edge_set_size`` and ``self.max_node_set_size``.
        """
        # NOTE: this loads all graphs (tr+te+val). Deliberately not storing the
        # graphs. Only size stats are stored.
        directory = f"{conf.base_dir}{self._DATA_SUBDIR}/{conf.dataset.name}/preprocessed"
        tag = self._corpus_tag(conf)
        all_query_graphs = self._read_pickle(f"{directory}/relabeled_queries_for_{tag}_relabeled.pkl")
        all_corpus_graphs = self._read_pickle(f"{directory}/{tag}_relabeled.pkl")
        all_graphs = list(all_query_graphs) + list(all_corpus_graphs)
        self.max_edge_set_size = max(graph.number_of_edges() for graph in all_graphs)
        self.max_node_set_size = max(graph.number_of_nodes() for graph in all_graphs)

    def _create_splits(self, conf, directory, tag):
        """Label all queries, split them 60:20:20 and cache every split on disk.

        Writes, for each of train/test/val, a list of query graphs and a dict
        ``{new_query_idx: {'pos': [...], 'neg': [...]}}`` of corpus indices.
        The shuffle uses the global ``random`` state, so seed it beforehand if
        reproducible splits are required.
        """
        max_skew = conf.dataset.MaxSkew
        all_q_fname = f"{directory}/../relabeled_queries_for_{tag}_relabeled.pkl"
        all_gt_fname = f"{directory}/../gt_dict_for_{tag}.pkl"
        all_queries = self._read_pickle(all_q_fname)
        all_gt = self._read_pickle(all_gt_fname)
        logger.info(f"Loaded all queries from {all_q_fname} for the purpose of creating splits")
        logger.info(f"Loaded all relationships from {all_gt_fname}  for the purpose of creating splits")

        logger.info(f"Assigning binary relevance labels  for required maximum skew ratio {conf.dataset.MaxSkew}")
        all_rels = {}
        for qidx, gt_list in all_gt.items():
            threshold = self._relevance_threshold(gt_list, max_skew)
            all_rels[qidx] = {
                'pos': torch.where(gt_list <= threshold)[0].tolist(),
                'neg': torch.where(gt_list > threshold)[0].tolist(),
            }

        # A query without negatives has an unbounded pos/neg ratio; report inf
        # rather than dividing by zero.
        final_skew = np.mean([
            len(rel['pos']) / len(rel['neg']) if rel['neg'] else float('inf')
            for rel in all_rels.values()
        ])
        logger.info(f"Found final skew ratio to be {final_skew} which is less than {conf.dataset.MaxSkew}")

        logger.info("Splitting into tr-test-val - 60:20:20")
        all_qidx = list(range(len(all_queries)))
        random.shuffle(all_qidx)
        first_cut = int(0.6 * len(all_qidx))
        second_cut = int(0.8 * len(all_qidx))
        split_indices = {
            'train': all_qidx[:first_cut],
            'test': all_qidx[first_cut:second_cut],
            'val': all_qidx[second_cut:],
        }

        logger.info(f"Saving tr-test-val queries to {directory}")
        for split, indices in split_indices.items():
            os.makedirs(f"{directory}/{split}", exist_ok=True)
            split_q_fname, split_rel_fname = self._split_fnames(directory, split, tag, max_skew)
            self._write_pickle([all_queries[idx] for idx in indices], split_q_fname)
            # Relationship keys are re-indexed to positions within the split.
            self._write_pickle(
                {new_idx: all_rels[idx] for new_idx, idx in enumerate(indices)},
                split_rel_fname,
            )

    def load_graphs(self, conf):
        """Load the query graphs, corpus graphs and relationships for ``self.mode``.

        Creates and caches the train/test/val splits first if the files for the
        requested mode are missing. Populates the graph lists, their node/edge
        size tensors, per-graph edge tensors, ``pos_pairs`` / ``neg_pairs`` and
        the flattened evaluation targets ``all_gt`` / ``all_pairs``.

        Raises:
            FileNotFoundError: if the corpus pickle does not exist.
            AssertionError: if the corpus size differs from
                ``conf.dataset.aug_num_cgraphs`` or the relationship keys are
                not exactly ``0 .. num_query_graphs-1``.
        """
        directory = f"{conf.base_dir}{self._DATA_SUBDIR}/{conf.dataset.name}/preprocessed/splits"
        # Must exist before the checks below: the corpus path goes through
        # "splits/..", which cannot be resolved if "splits" is missing.
        os.makedirs(directory, exist_ok=True)

        tag = self._corpus_tag(conf)
        q_fname, rel_fname = self._split_fnames(directory, self.mode, tag, conf.dataset.MaxSkew)
        c_fname = f"{directory}/../{tag}_relabeled.pkl"

        if not os.path.exists(c_fname):
            raise FileNotFoundError(f"Corpus graphs not found at {c_fname}")

        if not os.path.exists(q_fname) or not os.path.exists(rel_fname):
            # Generate splits for the specified range
            self._create_splits(conf, directory, tag)

        self.query_graphs = self._read_pickle(q_fname)
        self.corpus_graphs = self._read_pickle(c_fname)
        self.relationships = self._read_pickle(rel_fname)

        self.num_query_graphs = len(self.query_graphs)
        self.num_corpus_graphs = len(self.corpus_graphs)
        assert self.num_corpus_graphs == conf.dataset.aug_num_cgraphs, f"{self.num_corpus_graphs} and {conf.dataset.aug_num_cgraphs} mismatch"

        self.query_graph_node_sizes = torch.tensor(
            [g.number_of_nodes() for g in self.query_graphs], device=self.device)
        self.corpus_graph_node_sizes = torch.tensor(
            [g.number_of_nodes() for g in self.corpus_graphs], device=self.device)

        self.query_graph_edge_sizes = torch.tensor(
            [g.number_of_edges() for g in self.query_graphs], device=self.device)
        self.corpus_graph_edge_sizes = torch.tensor(
            [g.number_of_edges() for g in self.corpus_graphs], device=self.device)

        self.query_graph_edges = [
            torch.tensor(np.array(g.edges()), dtype=torch.int64, device=self.device)
            for g in self.query_graphs
        ]
        self.corpus_graph_edges = [
            torch.tensor(np.array(g.edges()), dtype=torch.int64, device=self.device)
            for g in self.corpus_graphs
        ]

        assert sorted(list(self.relationships.keys())) == list(range(self.num_query_graphs)) , f"{sorted(list(self.relationships.keys()))}, and {list(range(self.num_query_graphs))}"
        logger.info(f"Loaded {len(self.query_graphs)} query graphs from {q_fname}")
        logger.info(f"Loaded {len(self.corpus_graphs)} corpus graphs from {c_fname}")
        logger.info(f"Loaded relationships from {rel_fname} for the purpose of creating splits")

        # (query_idx, corpus_idx) pairs, grouped by query.
        self.pos_pairs = [
            (query_idx, corpus_idx)
            for query_idx in range(self.num_query_graphs)
            for corpus_idx in self.relationships[query_idx]['pos']
        ]
        self.neg_pairs = [
            (query_idx, corpus_idx)
            for query_idx in range(self.num_query_graphs)
            for corpus_idx in self.relationships[query_idx]['neg']
        ]

        # For evaluation: one 0/1 target per (query, corpus) pair, query-major,
        # in the same order as all_pairs.
        per_query_gt = []
        for qidx in range(self.num_query_graphs):
            gt = torch.zeros(self.num_corpus_graphs)
            gt[self.relationships[qidx]['pos']] = 1
            per_query_gt.append(gt)
        self.all_pairs = [
            (qidx, cidx)
            for qidx in range(self.num_query_graphs)
            for cidx in range(self.num_corpus_graphs)
        ]
        self.all_gt = torch.cat(per_query_gt)

    def run_assertion_checks(self):
        """Sanity-check node numbering and (approximate) uniqueness of the graphs.

        Raises:
            AssertionError: if a graph's nodes are not ``0 .. n-1`` in order,
                if two query graphs share a Weisfeiler-Lehman hash, or if 1% or
                more of the corpus graphs' hashes are duplicates.
        """
        all_g = self.query_graphs + self.corpus_graphs

        # Check nodes are numbered 0 ... n-1.
        for g in all_g:
            assert list(g.nodes) == list(range(g.number_of_nodes()))

        # Check no two graphs are isomorphic. An exact pairwise
        # nx.is_isomorphic sweep is slow, so WL hashes are compared instead.
        wl_hash = [
            nx.algorithms.weisfeiler_lehman_graph_hash(g, iterations=20)
            for g in self.query_graphs
        ]
        print(f"In queries len(set(wl_hash)) = {len(set(wl_hash))} and len(wl_hash) = {len(wl_hash)}")
        assert len(set(wl_hash)) == len(wl_hash)

        # The corpus is allowed a small fraction (< 1%) of repeated hashes.
        wl_hash = [
            nx.algorithms.weisfeiler_lehman_graph_hash(g, iterations=20)
            for g in self.corpus_graphs
        ]
        print(f"In corpus len(set(wl_hash)) = {len(set(wl_hash))} and len(wl_hash) = {len(wl_hash)}")
        assert len(set(wl_hash))/len(wl_hash)>0.99

        #Too slow to check GED gt, but maybe someday we can check that too
