import os
import math
import pickle
import random
import collections
import numpy as np
import networkx as nx
import torch
from loguru import logger
import torch.nn.functional as F
from torch_geometric.data import Data
from utils.utils import *

## MINI-dataset loader only to be used for the distribution experiment ##

## random sampling
# The only dataset mode this loader accepts.
RANDOM_MODE = "random"

# Container for a batch of graphs packed into a single set of index tensors.
# The field order is part of the interface (positional construction/unpacking).
GraphCollection = collections.namedtuple(
    "GraphCollection",
    "from_idx to_idx node_features edge_features graph_idx num_graphs",
)

class DatasetBase:
    def __init__(self, conf,  mode):
        """Read the settings from ``conf`` and eagerly load and preprocess the graphs.

        Args:
            conf: configuration object; this constructor reads
                ``conf.dataset.name``, ``conf.dataset.data_type``,
                ``conf.training.batch_size`` and ``conf.training.device`` and
                forwards ``conf`` to ``init_dataset_stats`` and ``load_graphs``.
            mode: must equal ``RANDOM_MODE``.
        """
        assert mode == RANDOM_MODE
        self.mode, self.batch_setting = mode, None

        # Size statistics are gathered before the remaining settings are read.
        self.dataset_name = conf.dataset.name
        self.init_dataset_stats(conf)
        self.data_type = conf.dataset.data_type

        self.batch_size = conf.training.batch_size
        self.device = conf.training.device

        # Subclasses populate the graph lists; PyG objects are derived from them.
        self.load_graphs(conf)
        self.preprocess_subgraphs_to_pyG_data()

    def init_dataset_stats(self, conf):
        """Hook for subclasses to compute dataset-wide size statistics.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def load_graphs(self, conf):
        """Hook for subclasses to load the query/corpus graphs and their labels.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
    
    
    def create_pyG_object(self, graph):
        """Convert a graph into a PyG ``Data`` object with constant node features.

        Args:
            graph: object exposing ``number_of_nodes()`` and ``edges``; the edge
                endpoints are used directly as node indices.

        Returns:
            Tuple ``(data, num_nodes)``. ``data.x`` is an all-ones
            ``(num_nodes, 1)`` float tensor and ``data.edge_index`` is a
            ``(2, num_edges)`` int64 tensor, both on ``self.device``.
        """
        num_nodes = graph.number_of_nodes()
        features = torch.ones(num_nodes, 1, dtype=torch.float, device=self.device)

        # Every edge is stored once, in the orientation the graph reports it;
        # reverse edges are not added.
        edge_pairs = [[src, dst] for (src, dst) in graph.edges]
        # reshape(-1, 2) keeps the array two-dimensional when there are no edges,
        # so edge_index comes out as (2, 0) rather than a 1-D empty tensor.
        edge_array = np.array(edge_pairs).reshape(-1, 2)
        edge_index = torch.tensor(edge_array.T, dtype=torch.int64, device=self.device)
        return Data(x=features, edge_index=edge_index), num_nodes

    def preprocess_subgraphs_to_pyG_data(self):
        """Build PyG objects and node counts for every query and corpus graph.

        Sets ``query_graph_data`` / ``corpus_graph_data`` (lists) and
        ``query_graph_sizes`` / ``corpus_graph_sizes`` (tuples), each aligned
        with ``self.query_graphs`` / ``self.corpus_graphs``.
        """
        converted_queries = [self.create_pyG_object(g) for g in self.query_graphs]
        query_data, self.query_graph_sizes = zip(*converted_queries)

        converted_corpus = [self.create_pyG_object(g) for g in self.corpus_graphs]
        corpus_data, self.corpus_graph_sizes = zip(*converted_corpus)

        # The data objects are kept in lists; the sizes stay tuples.
        self.query_graph_data = list(query_data)
        self.corpus_graph_data = list(corpus_data)
        
    
    def _pack_batch_1d(self, graphs):
        """Pack a flat sequence of graphs into one ``GraphCollection``.

        Node ids of each graph are shifted by the number of nodes of all
        preceding graphs, and ``graph_idx`` maps every node to the position of
        its graph inside ``graphs``.

        Args:
            graphs: sized sequence of graphs exposing ``number_of_nodes()``,
                ``number_of_edges()`` and ``edges()``.

        Returns:
            ``GraphCollection`` with int64 ``from_idx`` / ``to_idx`` /
            ``graph_idx`` tensors and all-ones node and edge features, all on
            ``self.device``.
        """
        from_idx, to_idx, graph_idx = [], [], []

        total_nodes, total_edges = 0, 0
        for idx, graph in enumerate(graphs):
            num_nodes = graph.number_of_nodes()
            num_edges = graph.number_of_edges()
            # reshape(-1, 2) keeps a (0, 2) layout for a graph without edges, so
            # the column slicing below does not fail on a 1-D empty array.
            edges = np.array(graph.edges(), dtype=np.int32).reshape(-1, 2)

            from_idx.append(edges[:, 0] + total_nodes)
            to_idx.append(edges[:, 1] + total_nodes)
            graph_idx.append(np.ones(num_nodes, dtype=np.int32) * idx)

            total_nodes += num_nodes
            total_edges += num_edges

        return GraphCollection(
            from_idx=torch.tensor(np.concatenate(from_idx, axis=0), dtype=torch.int64, device=self.device),
            to_idx=torch.tensor(np.concatenate(to_idx, axis=0), dtype=torch.int64, device=self.device),
            graph_idx=torch.tensor(np.concatenate(graph_idx, axis=0), dtype=torch.int64, device=self.device),
            num_graphs=len(graphs),
            node_features=torch.ones(total_nodes, 1, dtype=torch.float, device=self.device),
            edge_features=torch.ones(total_edges, 1, dtype=torch.float, device=self.device),
        )
 
    def _pack_batch(self, graphs):
        """Pack an iterable of graph tuples into one ``GraphCollection``.

        The tuples are flattened in order (all members of the first tuple, then
        the second, ...) and the flat list is packed by ``_pack_batch_1d``, so
        both packing paths share a single implementation.

        Args:
            graphs: iterable of tuples of graphs, e.g. ``zip(queries, corpus)``.

        Returns:
            The ``GraphCollection`` produced by ``_pack_batch_1d`` for the
            flattened graph list.
        """
        flat_graphs = [graph for graph_tuple in graphs for graph in graph_tuple]
        return self._pack_batch_1d(flat_graphs)

    def _pack_batch_from_idx(self, batch_idx):
        """Pack a batch given as ``(query_idx, corpus_idx)`` pairs.

        Works from the precomputed ``*_graph_node_sizes``, ``*_graph_edge_sizes``
        and ``*_graph_edges`` tensors instead of the graph objects. Graphs are
        laid out interleaved (query, corpus, query, corpus, ...), which is the
        same order ``_pack_batch`` produces for ``zip(queries, corpus)``.

        Args:
            batch_idx: iterable of ``(query_idx, corpus_idx)`` pairs.

        Returns:
            ``GraphCollection`` with int64 index tensors and all-ones node and
            edge features on ``self.device``.
        """
        all_gidx = [gidx for gidx_tuple in batch_idx for gidx in gidx_tuple]
        from_idx, to_idx, graph_idx = [], [], []
        total_nodes, total_edges = 0, 0
        for list_idx, gidx in enumerate(all_gidx):
            # Even positions hold query indices, odd positions corpus indices.
            if list_idx % 2 == 0:
                num_nodes = int(self.query_graph_node_sizes[gidx])
                num_edges = int(self.query_graph_edge_sizes[gidx])
                edges = self.query_graph_edges[gidx]
            else:
                num_nodes = int(self.corpus_graph_node_sizes[gidx])
                num_edges = int(self.corpus_graph_edge_sizes[gidx])
                edges = self.corpus_graph_edges[gidx]

            # An edge-less graph is stored as a 1-D empty tensor; force (E, 2).
            edges = edges.reshape(-1, 2)
            from_idx.append(edges[:, 0] + total_nodes)
            to_idx.append(edges[:, 1] + total_nodes)
            graph_idx.append(
                torch.full((num_nodes,), list_idx, dtype=torch.int64, device=self.device)
            )

            total_nodes += num_nodes
            total_edges += num_edges

        return GraphCollection(
            from_idx=torch.cat(from_idx).to(dtype=torch.int64, device=self.device),
            to_idx=torch.cat(to_idx).to(dtype=torch.int64, device=self.device),
            graph_idx=torch.cat(graph_idx),
            num_graphs=len(all_gidx),
            node_features=torch.ones(total_nodes, 1, dtype=torch.float, device=self.device),
            edge_features=torch.ones(total_edges, 1, dtype=torch.float, device=self.device),
        )
        
            
    def create_stratified_batches(self):
        """Shuffle the positive/negative pairs and build class-stratified batches.

        Each batch holds positives followed by negatives, in roughly the global
        positive-to-negative proportion, with at least one pair of each class
        per full batch. Batches are produced until either class runs out of
        chunks. Sets ``self.batches``, ``self.labels`` (1.0 for positives, 0.0
        for negatives) and ``self.num_batches``.

        Returns:
            Number of batches created; 0 when either class has no pairs.

        Raises:
            ValueError: if ``self.batch_size`` is below 2, since a batch then
                cannot contain both classes.
        """
        self.batch_setting = 'stratified'
        random.shuffle(self.pos_pairs)
        random.shuffle(self.neg_pairs)

        if not self.pos_pairs or not self.neg_pairs:
            # No stratified batch can be formed without both classes.
            self.batches, self.labels = [], []
            self.num_batches = 0
            return self.num_batches
        if self.batch_size < 2:
            raise ValueError("stratified batches need batch_size >= 2")

        pos_to_neg_ratio = len(self.pos_pairs) / len(self.neg_pairs)
        num_pos_per_batch = math.ceil(pos_to_neg_ratio / (1 + pos_to_neg_ratio) * self.batch_size)
        # Always leave room for one negative; a zero share would make the
        # chunking step below zero.
        num_pos_per_batch = min(num_pos_per_batch, self.batch_size - 1)
        num_neg_per_batch = self.batch_size - num_pos_per_batch

        pos_chunks = [
            self.pos_pairs[start:start + num_pos_per_batch]
            for start in range(0, len(self.pos_pairs), num_pos_per_batch)
        ]
        neg_chunks = [
            self.neg_pairs[start:start + num_neg_per_batch]
            for start in range(0, len(self.neg_pairs), num_neg_per_batch)
        ]

        # zip stops at the shorter list, i.e. when one class runs out of chunks.
        self.batches = [pos + neg for pos, neg in zip(pos_chunks, neg_chunks)]
        self.labels = [
            [1.0] * len(pos) + [0.0] * len(neg)
            for pos, neg in zip(pos_chunks, neg_chunks)
        ]
        self.num_batches = len(self.batches)

        return self.num_batches
    
    def create_batches(self,shuffle:bool=True):
        """Split all (query, corpus) pairs and their labels into fixed-size batches.

        Args:
            shuffle: when true, pairs and labels are shuffled together before
                batching. When it is ``False`` and regular batches already
                exist, those batches are reused unchanged.

        Returns:
            The number of batches.
        """
        if shuffle is False and self.batch_setting == 'regular':
            # Unshuffled regular batches were built earlier; keep them.
            assert(len(self.batches)==len(self.labels))
            return self.num_batches

        self.batch_setting = "regular"
        pairs = self.all_pairs.copy()
        targets = self.all_gt
        if shuffle:
            # Shuffle pairs and labels jointly so they stay aligned.
            joined = list(zip(pairs, targets))
            random.shuffle(joined)
            pairs, targets = zip(*joined)

        step = self.batch_size
        starts = range(0, len(pairs), step)
        self.batches = [pairs[start:start + step] for start in starts]
        self.labels = [targets[start:start + step] for start in starts]
        self.num_batches = len(self.batches)

        return self.num_batches
        
        
        

    def create_eval_batches(self, pair_list):
        """Chunk ``pair_list`` into evaluation batches of ``num_corpus_graphs`` pairs.

        Sets ``self.eval_batches`` and ``self.num_eval_batches`` and switches
        ``self.batch_setting`` to ``'eval'``.

        Args:
            pair_list: sliceable sequence of (query, corpus) index pairs.

        Returns:
            The number of evaluation batches.
        """
        self.batch_setting = 'eval'
        # One batch spans as many pairs as there are corpus graphs.
        chunk = self.num_corpus_graphs
        self.eval_batches = [
            pair_list[start:start + chunk]
            for start in range(0, len(pair_list), chunk)
        ]
        self.num_eval_batches = len(self.eval_batches)
        return self.num_eval_batches

    def fetch_batch_by_id(self, idx):
        """
        returns batch instances and labels by id
        """
        if self.batch_setting == 'stratified' or self.batch_setting == 'regular':
            assert idx < self.num_batches
            batch = self.batches[idx]
        elif self.batch_setting == 'eval':
            assert idx < self.num_eval_batches
            batch = self.eval_batches[idx]

        query_graph_idxs, corpus_graph_idxs = zip(*batch)
        
        if self.data_type == "gmn":
            query_graphs = [self.query_graphs[idx] for idx in query_graph_idxs]
            corpus_graphs = [self.corpus_graphs[idx] for idx in corpus_graph_idxs]
            all_graphs = self._pack_batch(zip(query_graphs, corpus_graphs))
        elif self.data_type == "gmn_from_idx":
            all_graphs = self._pack_batch_from_idx(batch)
        elif self.data_type == "pyg":
            query_graphs = [self.query_graph_data[idx] for idx in query_graph_idxs]
            corpus_graphs = [self.corpus_graph_data[idx] for idx in corpus_graph_idxs]
            all_graphs = query_graphs + corpus_graphs #list(zip(query_graphs, corpus_graphs))    


        query_graph_node_sizes = self.query_graph_node_sizes[list(query_graph_idxs)]
        corpus_graph_node_sizes = self.corpus_graph_node_sizes[list(corpus_graph_idxs)]
        all_graph_node_sizes = torch.dstack((query_graph_node_sizes, corpus_graph_node_sizes)).flatten()
        query_graph_edge_sizes = self.query_graph_edge_sizes[list(query_graph_idxs)]
        corpus_graph_edge_sizes = self.corpus_graph_edge_sizes[list(corpus_graph_idxs)]
        all_graph_edge_sizes = torch.dstack((query_graph_edge_sizes, corpus_graph_edge_sizes)).flatten()

        

        if self.batch_setting == 'stratified' or self.batch_setting == 'regular':
            target = torch.tensor(np.array(self.labels[idx]), dtype=torch.float, device=self.device)
            return all_graphs, all_graph_node_sizes, all_graph_edge_sizes, target #, all_graph_adjs
        elif self.batch_setting == 'eval':
            return all_graphs, all_graph_node_sizes, all_graph_edge_sizes, None #, all_graph_adjs
        else:
            raise NotImplementedError

    def fetch(self):
        
        query_graph_idxs, corpus_graph_idxs = zip(*self.all_pairs)
        
        
        if self.data_type == "gmn":
            query_graphs = [self.query_graphs[idx] for idx in query_graph_idxs]
            corpus_graphs = [self.corpus_graphs[idx] for idx in corpus_graph_idxs]
            all_graphs = self._pack_batch(zip(query_graphs, corpus_graphs))
        elif self.data_type == "gmn_from_idx":
            raise NotImplementedError
        elif self.data_type == "pyg":
            all_graphs = self.query_graphs + self.corpus_graphs
            
        query_graph_node_sizes = self.query_graph_node_sizes[list(query_graph_idxs)]
        corpus_graph_node_sizes = self.corpus_graph_node_sizes[list(corpus_graph_idxs)]
        all_graph_node_sizes = torch.dstack((query_graph_node_sizes, corpus_graph_node_sizes)).flatten()
        query_graph_edge_sizes = self.query_graph_edge_sizes[list(query_graph_idxs)]
        corpus_graph_edge_sizes = self.corpus_graph_edge_sizes[list(corpus_graph_idxs)]
        all_graph_edge_sizes = torch.dstack((query_graph_edge_sizes, corpus_graph_edge_sizes)).flatten()
        
        target = torch.tensor(np.array(self.all_gt), dtype=torch.float, device=self.device)
        return all_graphs, all_graph_node_sizes, all_graph_edge_sizes, target
    
    
    def print_dataset_stats(self):
        """Log dataset name, mode, graph counts, size maxima and pos/neg ratios.

        The ratio statistics are best-effort: if the relevance information is
        missing or a query has no negatives, they are skipped with a warning
        instead of aborting.
        """
        logger.info(f"Dataset: {self.dataset_name}")
        logger.info(f"Mode: {self.mode}")
        logger.info(f"Number of query graphs: {len(self.query_graphs)}")
        logger.info(f"Number of corpus graphs: {len(self.corpus_graphs)}")

        logger.info(f"Maximum node set size: {self.max_node_set_size}")
        logger.info(f"Maximum edge set size: {self.max_edge_set_size}")
        try:
            r_list = [
                len(rel['pos']) / len(rel['neg'])
                for rel in self.relationships.values()
            ]
            logger.info(f"Per Query Average pos to neg ratio: {np.mean(r_list)}")
            logger.info(f"Macro Average pos to neg ratio: {len(self.pos_pairs) / len(self.neg_pairs)}")
        except (AttributeError, KeyError, IndexError, ZeroDivisionError) as err:
            # These are the failures the lookups and divisions above can raise:
            # no relationships/pairs attribute, no pos/neg key, or no negatives.
            logger.warning(f"Skipping pos/neg ratio statistics: {err!r}")
        
        
    def run_assertion_checks(self):
        """Hook for subclasses to sanity-check the loaded dataset.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    
class SubgraphIsomorphismDataset(DatasetBase):
    def __init__(self, conf,  mode):
        """Build the subgraph-isomorphism dataset; all work happens in ``DatasetBase``."""
        super().__init__(conf=conf, mode=mode)
        
        
    def init_dataset_stats(self, conf):
        """Record the largest node and edge counts over all query and corpus graphs.

        Sets ``self.max_edge_set_size`` and ``self.max_node_set_size``.

        Args:
            conf: configuration; reads ``conf.base_dir``, ``conf.dataset.name``,
                ``conf.dataset.aug_num_cgraphs``, ``conf.dataset.min_cgraph_nodes``
                and ``conf.dataset.max_cgraph_nodes`` to locate the pickles.
        """
        #NOTE: this loads all graphs (tr+te+val). Deliberately not storing the graphs. Only size stats are stored.
        directory = f"{conf.base_dir}data/{conf.dataset.name}/preprocessed"
        corpus_tag = (
            f"{conf.dataset.aug_num_cgraphs}_corpus_subgraphs"
            f"_min{conf.dataset.min_cgraph_nodes}_max{conf.dataset.max_cgraph_nodes}_relabeled.pkl"
        )
        q_fname = f"{directory}/relabeled_queries_for_{corpus_tag}"
        c_fname = f"{directory}/{corpus_tag}"

        # NOTE: pickle executes arbitrary code on load; only use trusted files.
        # The context managers close the file handles once loading is done.
        with open(q_fname, 'rb') as q_file:
            all_query_graphs = pickle.load(q_file)
        with open(c_fname, 'rb') as c_file:
            all_corpus_graphs = pickle.load(c_file)

        self.max_edge_set_size = max(
            max(graph.number_of_edges() for graph in all_query_graphs),
            max(graph.number_of_edges() for graph in all_corpus_graphs),
        )
        self.max_node_set_size = max(
            max(graph.number_of_nodes() for graph in all_query_graphs),
            max(graph.number_of_nodes() for graph in all_corpus_graphs),
        )
        

    def load_graphs(self, conf):
        """Sample a random query/corpus subset and derive pairs and labels for it.

        Calls ``set_seed(conf.training.seed)`` (defined outside this file;
        presumably it seeds ``random`` so the sampling is reproducible), draws
        ``conf.expt.num_corpus`` corpus graphs and then
        ``conf.expt.num_queries`` query graphs without replacement, and
        re-indexes the relevance dictionary to positions inside the sampled
        corpus. Sets the graph lists, size and edge tensors, ``relationships``,
        ``pos_pairs`` / ``neg_pairs``, ``all_pairs`` and ``all_gt``.

        Raises:
            FileNotFoundError: if the corpus pickle does not exist.
        """
        directory = f"{conf.base_dir}{conf.dataset.path}/{conf.dataset.name}/preprocessed/splits"
        os.makedirs(directory, exist_ok=True)

        corpus_tag = (
            f"{conf.dataset.aug_num_cgraphs}_corpus_subgraphs"
            f"_min{conf.dataset.min_cgraph_nodes}_max{conf.dataset.max_cgraph_nodes}"
        )
        c_fname = f"{directory}/../{corpus_tag}_relabeled.pkl"
        if not os.path.exists(c_fname):
            raise FileNotFoundError(f"Corpus graphs not found at {c_fname}")

        ## for random model
        assert self.mode == RANDOM_MODE
        set_seed(conf.training.seed)

        all_q_fname = f"{directory}/../relabeled_queries_for_{corpus_tag}_relabeled.pkl"
        all_rel_fname = f"{directory}/../rel_dict_for_{corpus_tag}.pkl"

        # NOTE: pickle executes arbitrary code on load; only use trusted files.
        with open(all_q_fname, 'rb') as q_file:
            all_queries = pickle.load(q_file)
        with open(all_rel_fname, 'rb') as rel_file:
            all_rels = pickle.load(rel_file)
        with open(c_fname, 'rb') as c_file:
            corpus_graphs = pickle.load(c_file)

        ## pick corpus graphs randomly (corpus first, then queries: the order
        ## fixes which random draws each sample receives)
        sampled_corpus_ids = random.sample(range(len(corpus_graphs)), conf.expt.num_corpus)
        self.corpus_graphs = [corpus_graphs[i] for i in sampled_corpus_ids]

        ## pick query graphs randomly
        sampled_query_ids = random.sample(range(len(all_queries)), conf.expt.num_queries)
        self.query_graphs = [all_queries[i] for i in sampled_query_ids]

        # Original corpus id -> position inside the sampled corpus.
        new_corpus_pos = {old: new for new, old in enumerate(sampled_corpus_ids)}

        def reindex(rel_entry):
            # Keep only corpus graphs that were sampled, translated to new positions.
            return {
                label: [new_corpus_pos[int(x)] for x in rel_entry[label] if int(x) in new_corpus_pos]
                for label in ("pos", "neg")
            }

        self.relationships = {
            new_idx: reindex(all_rels[old_idx])
            for new_idx, old_idx in enumerate(sampled_query_ids)
        }

        self.num_query_graphs = len(self.query_graphs)
        self.num_corpus_graphs = len(self.corpus_graphs)

        self.query_graph_node_sizes = torch.tensor(
            [g.number_of_nodes() for g in self.query_graphs], device=self.device)
        self.corpus_graph_node_sizes = torch.tensor(
            [g.number_of_nodes() for g in self.corpus_graphs], device=self.device)

        self.query_graph_edge_sizes = torch.tensor(
            [g.number_of_edges() for g in self.query_graphs], device=self.device)
        self.corpus_graph_edge_sizes = torch.tensor(
            [g.number_of_edges() for g in self.corpus_graphs], device=self.device)

        self.query_graph_edges = [
            torch.tensor(np.array(g.edges()), dtype=torch.int64, device=self.device)
            for g in self.query_graphs
        ]
        self.corpus_graph_edges = [
            torch.tensor(np.array(g.edges()), dtype=torch.int64, device=self.device)
            for g in self.corpus_graphs
        ]

        ## not checking relevance ratios here

        self.pos_pairs, self.neg_pairs = [], []
        gt_rows = []
        self.all_pairs = []
        for qidx in range(self.num_query_graphs):
            rel = self.relationships[qidx]
            self.pos_pairs.extend((qidx, cidx) for cidx in rel['pos'])
            self.neg_pairs.extend((qidx, cidx) for cidx in rel['neg'])

            # Dense 0/1 relevance row; the explicit long index tensor also
            # handles a query without any positive.
            gt = torch.zeros(self.num_corpus_graphs)
            gt[torch.tensor(rel['pos'], dtype=torch.long)] = 1
            gt_rows.append(gt)
            self.all_pairs.extend((qidx, cidx) for cidx in range(self.num_corpus_graphs))
        self.all_gt = torch.cat(gt_rows)
            
        
        
    def run_assertion_checks(self):
        """Assert structural invariants of the loaded graphs and relevance labels.

        Checks that every graph's nodes are numbered ``0 .. n-1``, that all
        Weisfeiler-Lehman hashes (5 iterations) are distinct across queries and
        corpus, and that the stored pos/neg lists match a recomputed
        ``check_subiso`` result for every query.
        """
        all_g = self.query_graphs + self.corpus_graphs

        # Nodes must be numbered 0 ... n-1.
        for graph in all_g:
            assert list(graph.nodes) == list(range(graph.number_of_nodes()))

        # Cheap duplicate check: WL hashes of all graphs must be pairwise
        # distinct (a pairwise nx.is_isomorphic check would be far slower).
        wl_hash = [
            nx.algorithms.weisfeiler_lehman_graph_hash(graph, iterations=5)
            for graph in all_g
        ]
        assert len(set(wl_hash)) == len(wl_hash)

        # The relevance dict must agree with the recomputed subgraph-isomorphism
        # relation between each query and every corpus graph.
        from src.graph_data_generator import run_parallel_pool, check_subiso
        import itertools
        import tqdm
        for qidx in tqdm.tqdm(range(len(self.query_graphs))):
            jobs = list(zip(self.corpus_graphs, itertools.repeat(self.query_graphs[qidx])))
            res = run_parallel_pool(check_subiso, jobs)
            expected = self.relationships[qidx]
            assert (np.where(res)[0] == np.array(expected['pos'])).all()
            assert (np.where(~np.array(res))[0] == np.array(expected['neg'])).all()


class GraphEditDistanceDataset(DatasetBase):
    def __init__(self, conf,  mode):
        """Build the graph-edit-distance dataset; all work happens in ``DatasetBase``."""
        super().__init__(conf=conf, mode=mode)
        
    def init_dataset_stats(self, conf):
        """Record the largest node and edge counts over all query and corpus graphs.

        Sets ``self.max_edge_set_size`` and ``self.max_node_set_size``.

        Args:
            conf: configuration; reads ``conf.base_dir``, ``conf.dataset.name``,
                ``conf.dataset.aug_num_cgraphs``, ``conf.dataset.min_graph_nodes``
                and ``conf.dataset.max_graph_nodes`` to locate the pickles.
        """
        #NOTE: this loads all graphs (tr+te+val). Deliberately not storing the graphs. Only size stats are stored.
        directory = f"{conf.base_dir}ged_data/{conf.dataset.name}/preprocessed"
        corpus_tag = (
            f"{conf.dataset.aug_num_cgraphs}_corpus_subgraphs"
            f"_min{conf.dataset.min_graph_nodes}_max{conf.dataset.max_graph_nodes}_relabeled.pkl"
        )
        q_fname = f"{directory}/relabeled_queries_for_{corpus_tag}"
        c_fname = f"{directory}/{corpus_tag}"

        # NOTE: pickle executes arbitrary code on load; only use trusted files.
        # The context managers close the file handles once loading is done.
        with open(q_fname, 'rb') as q_file:
            all_query_graphs = pickle.load(q_file)
        with open(c_fname, 'rb') as c_file:
            all_corpus_graphs = pickle.load(c_file)

        self.max_edge_set_size = max(
            max(graph.number_of_edges() for graph in all_query_graphs),
            max(graph.number_of_edges() for graph in all_corpus_graphs),
        )
        self.max_node_set_size = max(
            max(graph.number_of_nodes() for graph in all_query_graphs),
            max(graph.number_of_nodes() for graph in all_corpus_graphs),
        )
        

    def load_graphs(self, conf):
        """Sample a random query/corpus subset and derive GED-based pairs and labels.

        Calls ``set_seed(conf.training.seed)`` (defined outside this file;
        presumably it seeds ``random`` so the sampling is reproducible), draws
        ``conf.expt.num_corpus`` corpus graphs and then
        ``conf.expt.num_queries`` query graphs without replacement, restricts
        each query's ground-truth vector to the sampled corpus, and splits the
        corpus into positives/negatives per query using a threshold bounded by
        ``conf.dataset.MaxSkew``. Sets the graph lists, size and edge tensors,
        ``relationships``, ``pos_pairs`` / ``neg_pairs``, ``all_pairs`` and
        ``all_gt``.

        Raises:
            FileNotFoundError: if the corpus pickle does not exist.
        """
        directory = f"{conf.base_dir}ged_data/{conf.dataset.name}/preprocessed/splits"
        os.makedirs(directory, exist_ok=True)

        corpus_tag = (
            f"{conf.dataset.aug_num_cgraphs}_corpus_subgraphs"
            f"_min{conf.dataset.min_graph_nodes}_max{conf.dataset.max_graph_nodes}"
        )
        c_fname = f"{directory}/../{corpus_tag}_relabeled.pkl"
        if not os.path.exists(c_fname):
            raise FileNotFoundError(f"Corpus graphs not found at {c_fname}")

        ## for random model
        assert self.mode == RANDOM_MODE
        set_seed(conf.training.seed)

        all_q_fname = f"{directory}/../relabeled_queries_for_{corpus_tag}_relabeled.pkl"
        all_gt_fname = f"{directory}/../gt_dict_for_{corpus_tag}.pkl"

        # NOTE: pickle executes arbitrary code on load; only use trusted files.
        with open(all_q_fname, 'rb') as q_file:
            all_queries = pickle.load(q_file)
        with open(all_gt_fname, 'rb') as gt_file:
            all_gt = pickle.load(gt_file)
        with open(c_fname, 'rb') as c_file:
            corpus_graphs = pickle.load(c_file)

        ## pick corpus graphs randomly (corpus first, then queries: the order
        ## fixes which random draws each sample receives)
        sampled_corpus_ids = random.sample(range(len(corpus_graphs)), conf.expt.num_corpus)
        self.corpus_graphs = [corpus_graphs[i] for i in sampled_corpus_ids]

        ## pick query graphs randomly
        sampled_query_ids = random.sample(range(len(all_queries)), conf.expt.num_queries)
        self.query_graphs = [all_queries[i] for i in sampled_query_ids]

        # Ground-truth vector of each sampled query, restricted to (and ordered
        # like) the sampled corpus.
        sampled_gt = {
            new_idx: all_gt[old_idx][sampled_corpus_ids]
            for new_idx, old_idx in enumerate(sampled_query_ids)
        }

        all_rels = {}
        for qidx, gt_list in sampled_gt.items():
            sorted_gt = sorted(set(gt_list.int().tolist()))
            #decide threshold: stop at the first candidate whose share of
            #"<= candidate" corpus graphs exceeds MaxSkew
            for i in range(len(sorted_gt)):
                possible_ratio = gt_list[gt_list <= sorted_gt[i]].shape[0] / gt_list.shape[0]
                if possible_ratio > conf.dataset.MaxSkew:
                    break
            # NOTE(review): the threshold is the loop *index* minus one, not the
            # value sorted_gt[i - 1]; the two coincide only when the distinct
            # values are 0, 1, 2, ... Kept as-is to preserve the labels;
            # confirm this is intended.
            threshold = i - 1

            all_rels[qidx] = {
                'pos': torch.where(gt_list <= threshold)[0].tolist(),
                'neg': torch.where(gt_list > threshold)[0].tolist(),
            }

        self.relationships = all_rels

        self.num_query_graphs = len(self.query_graphs)
        self.num_corpus_graphs = len(self.corpus_graphs)

        self.query_graph_node_sizes = torch.tensor(
            [g.number_of_nodes() for g in self.query_graphs], device=self.device)
        self.corpus_graph_node_sizes = torch.tensor(
            [g.number_of_nodes() for g in self.corpus_graphs], device=self.device)

        self.query_graph_edge_sizes = torch.tensor(
            [g.number_of_edges() for g in self.query_graphs], device=self.device)
        self.corpus_graph_edge_sizes = torch.tensor(
            [g.number_of_edges() for g in self.corpus_graphs], device=self.device)

        self.query_graph_edges = [
            torch.tensor(np.array(g.edges()), dtype=torch.int64, device=self.device)
            for g in self.query_graphs
        ]
        self.corpus_graph_edges = [
            torch.tensor(np.array(g.edges()), dtype=torch.int64, device=self.device)
            for g in self.corpus_graphs
        ]

        ## not checking relevance ratios here

        self.pos_pairs, self.neg_pairs = [], []
        gt_rows = []
        self.all_pairs = []
        for qidx in range(self.num_query_graphs):
            rel = self.relationships[qidx]
            self.pos_pairs.extend((qidx, cidx) for cidx in rel['pos'])
            self.neg_pairs.extend((qidx, cidx) for cidx in rel['neg'])

            # Dense 0/1 relevance row; the explicit long index tensor also
            # handles a query without any positive.
            gt = torch.zeros(self.num_corpus_graphs)
            gt[torch.tensor(rel['pos'], dtype=torch.long)] = 1
            gt_rows.append(gt)
            self.all_pairs.extend((qidx, cidx) for cidx in range(self.num_corpus_graphs))
        self.all_gt = torch.cat(gt_rows)
        
        
    def run_assertion_checks(self):
        """Assert structural invariants of the loaded query and corpus graphs.

        Checks that every graph's nodes are numbered ``0 .. n-1``, that all
        query graphs have distinct Weisfeiler-Lehman hashes (20 iterations),
        and that more than 99% of the corpus hashes are distinct. The GED
        ground truth itself is not verified.
        """
        all_g = self.query_graphs + self.corpus_graphs

        # Nodes must be numbered 0 ... n-1.
        for graph in all_g:
            assert list(graph.nodes) == list(range(graph.number_of_nodes()))

        # Cheap duplicate check via WL hashes (a pairwise nx.is_isomorphic
        # check would be far slower). Queries must be fully distinct.
        wl_hash = [
            nx.algorithms.weisfeiler_lehman_graph_hash(graph, iterations=20)
            for graph in self.query_graphs
        ]
        print(f"In queries len(set(wl_hash)) = {len(set(wl_hash))} and len(wl_hash) = {len(wl_hash)}")
        assert len(set(wl_hash)) == len(wl_hash)

        # The corpus tolerates a small share of hash collisions.
        wl_hash = [
            nx.algorithms.weisfeiler_lehman_graph_hash(graph, iterations=20)
            for graph in self.corpus_graphs
        ]
        print(f"In corpus len(set(wl_hash)) = {len(set(wl_hash))} and len(wl_hash) = {len(wl_hash)}")
        assert len(set(wl_hash))/len(wl_hash)>0.99

        #Too slow to check GED gt, but maybe someday we can check that too

class UneqGraphEditDistanceDataset(DatasetBase):
    def __init__(self, conf,  mode):
        """Build the unequal-cost GED dataset; all work happens in ``DatasetBase``."""
        super().__init__(conf=conf, mode=mode)
        
    def init_dataset_stats(self, conf):
        """Record the largest node and edge counts over all query and corpus graphs.

        Sets ``self.max_edge_set_size`` and ``self.max_node_set_size``.

        Args:
            conf: configuration; reads ``conf.base_dir``, ``conf.dataset.name``,
                ``conf.dataset.aug_num_cgraphs``, ``conf.dataset.min_graph_nodes``
                and ``conf.dataset.max_graph_nodes`` to locate the pickles.
        """
        #NOTE: this loads all graphs (tr+te+val). Deliberately not storing the graphs. Only size stats are stored.
        directory = f"{conf.base_dir}uneq_ged_data/{conf.dataset.name}/preprocessed"
        corpus_tag = (
            f"{conf.dataset.aug_num_cgraphs}_corpus_subgraphs"
            f"_min{conf.dataset.min_graph_nodes}_max{conf.dataset.max_graph_nodes}_relabeled.pkl"
        )
        q_fname = f"{directory}/relabeled_queries_for_{corpus_tag}"
        c_fname = f"{directory}/{corpus_tag}"

        # NOTE: pickle executes arbitrary code on load; only use trusted files.
        # The context managers close the file handles once loading is done.
        with open(q_fname, 'rb') as q_file:
            all_query_graphs = pickle.load(q_file)
        with open(c_fname, 'rb') as c_file:
            all_corpus_graphs = pickle.load(c_file)

        self.max_edge_set_size = max(
            max(graph.number_of_edges() for graph in all_query_graphs),
            max(graph.number_of_edges() for graph in all_corpus_graphs),
        )
        self.max_node_set_size = max(
            max(graph.number_of_nodes() for graph in all_query_graphs),
            max(graph.number_of_nodes() for graph in all_corpus_graphs),
        )
        

    def load_graphs(self, conf):
        """Sample a random query/corpus subset and derive GED-based pairs and labels.

        Calls ``set_seed(conf.training.seed)`` (defined outside this file;
        presumably it seeds ``random`` so the sampling is reproducible), draws
        ``conf.expt.num_corpus`` corpus graphs and then
        ``conf.expt.num_queries`` query graphs without replacement, restricts
        each query's ground-truth vector to the sampled corpus, and splits the
        corpus into positives/negatives per query using a threshold bounded by
        ``conf.dataset.MaxSkew``. Sets the graph lists, size and edge tensors,
        ``relationships``, ``pos_pairs`` / ``neg_pairs``, ``all_pairs`` and
        ``all_gt``.

        Raises:
            FileNotFoundError: if the corpus pickle does not exist.
        """
        directory = f"{conf.base_dir}uneq_ged_data/{conf.dataset.name}/preprocessed/splits"
        os.makedirs(directory, exist_ok=True)

        corpus_tag = (
            f"{conf.dataset.aug_num_cgraphs}_corpus_subgraphs"
            f"_min{conf.dataset.min_graph_nodes}_max{conf.dataset.max_graph_nodes}"
        )
        c_fname = f"{directory}/../{corpus_tag}_relabeled.pkl"
        if not os.path.exists(c_fname):
            raise FileNotFoundError(f"Corpus graphs not found at {c_fname}")

        ## for random model
        assert self.mode == RANDOM_MODE
        set_seed(conf.training.seed)

        all_q_fname = f"{directory}/../relabeled_queries_for_{corpus_tag}_relabeled.pkl"
        all_gt_fname = f"{directory}/../gt_dict_for_{corpus_tag}.pkl"

        # NOTE: pickle executes arbitrary code on load; only use trusted files.
        with open(all_q_fname, 'rb') as q_file:
            all_queries = pickle.load(q_file)
        with open(all_gt_fname, 'rb') as gt_file:
            all_gt = pickle.load(gt_file)
        with open(c_fname, 'rb') as c_file:
            corpus_graphs = pickle.load(c_file)

        ## pick corpus graphs randomly (corpus first, then queries: the order
        ## fixes which random draws each sample receives)
        sampled_corpus_ids = random.sample(range(len(corpus_graphs)), conf.expt.num_corpus)
        self.corpus_graphs = [corpus_graphs[i] for i in sampled_corpus_ids]

        ## pick query graphs randomly
        sampled_query_ids = random.sample(range(len(all_queries)), conf.expt.num_queries)
        self.query_graphs = [all_queries[i] for i in sampled_query_ids]

        # Ground-truth vector of each sampled query, restricted to (and ordered
        # like) the sampled corpus.
        sampled_gt = {
            new_idx: all_gt[old_idx][sampled_corpus_ids]
            for new_idx, old_idx in enumerate(sampled_query_ids)
        }

        all_rels = {}
        for qidx, gt_list in sampled_gt.items():
            sorted_gt = sorted(set(gt_list.int().tolist()))
            #decide threshold: stop at the first candidate whose share of
            #"<= candidate" corpus graphs exceeds MaxSkew
            for i in range(len(sorted_gt)):
                possible_ratio = gt_list[gt_list <= sorted_gt[i]].shape[0] / gt_list.shape[0]
                if possible_ratio > conf.dataset.MaxSkew:
                    break
            # NOTE(review): the threshold is the loop *index* minus one, not the
            # value sorted_gt[i - 1]; the two coincide only when the distinct
            # values are 0, 1, 2, ... Kept as-is to preserve the labels;
            # confirm this is intended.
            threshold = i - 1

            all_rels[qidx] = {
                'pos': torch.where(gt_list <= threshold)[0].tolist(),
                'neg': torch.where(gt_list > threshold)[0].tolist(),
            }

        self.relationships = all_rels

        self.num_query_graphs = len(self.query_graphs)
        self.num_corpus_graphs = len(self.corpus_graphs)

        self.query_graph_node_sizes = torch.tensor(
            [g.number_of_nodes() for g in self.query_graphs], device=self.device)
        self.corpus_graph_node_sizes = torch.tensor(
            [g.number_of_nodes() for g in self.corpus_graphs], device=self.device)

        self.query_graph_edge_sizes = torch.tensor(
            [g.number_of_edges() for g in self.query_graphs], device=self.device)
        self.corpus_graph_edge_sizes = torch.tensor(
            [g.number_of_edges() for g in self.corpus_graphs], device=self.device)

        self.query_graph_edges = [
            torch.tensor(np.array(g.edges()), dtype=torch.int64, device=self.device)
            for g in self.query_graphs
        ]
        self.corpus_graph_edges = [
            torch.tensor(np.array(g.edges()), dtype=torch.int64, device=self.device)
            for g in self.corpus_graphs
        ]

        ## not checking relevance ratios here

        self.pos_pairs, self.neg_pairs = [], []
        gt_rows = []
        self.all_pairs = []
        for qidx in range(self.num_query_graphs):
            rel = self.relationships[qidx]
            self.pos_pairs.extend((qidx, cidx) for cidx in rel['pos'])
            self.neg_pairs.extend((qidx, cidx) for cidx in rel['neg'])

            # Dense 0/1 relevance row; the explicit long index tensor also
            # handles a query without any positive.
            gt = torch.zeros(self.num_corpus_graphs)
            gt[torch.tensor(rel['pos'], dtype=torch.long)] = 1
            gt_rows.append(gt)
            self.all_pairs.extend((qidx, cidx) for cidx in range(self.num_corpus_graphs))
        self.all_gt = torch.cat(gt_rows)
        

        
    def run_assertion_checks(self):
        """Assert structural invariants of the loaded query and corpus graphs.

        Checks that every graph's nodes are numbered ``0 .. n-1``, that all
        query graphs have distinct Weisfeiler-Lehman hashes (20 iterations),
        and that more than 99% of the corpus hashes are distinct. The GED
        ground truth itself is not verified.
        """
        all_g = self.query_graphs + self.corpus_graphs

        # Nodes must be numbered 0 ... n-1.
        for graph in all_g:
            assert list(graph.nodes) == list(range(graph.number_of_nodes()))

        # Cheap duplicate check via WL hashes (a pairwise nx.is_isomorphic
        # check would be far slower). Queries must be fully distinct.
        wl_hash = [
            nx.algorithms.weisfeiler_lehman_graph_hash(graph, iterations=20)
            for graph in self.query_graphs
        ]
        print(f"In queries len(set(wl_hash)) = {len(set(wl_hash))} and len(wl_hash) = {len(wl_hash)}")
        assert len(set(wl_hash)) == len(wl_hash)

        # The corpus tolerates a small share of hash collisions.
        wl_hash = [
            nx.algorithms.weisfeiler_lehman_graph_hash(graph, iterations=20)
            for graph in self.corpus_graphs
        ]
        print(f"In corpus len(set(wl_hash)) = {len(set(wl_hash))} and len(wl_hash) = {len(wl_hash)}")
        assert len(set(wl_hash))/len(wl_hash)>0.99

        #Too slow to check GED gt, but maybe someday we can check that too