from torch.utils.data import Dataset
import torch, random
import os
import pickle
#from problems.toposort.state_toposort import StateTopoSort
#from utils.beam_search import beam_search
#from utils import orderCheck, deep_sort_x, level_sorting, level_sorting_xy_pairs, order_check, graph_sorting_DAG

from collections import defaultdict
from itertools import combinations
import networkx as nx
import networkx.algorithms.isomorphism as iso

import numpy as np
from torch.utils.data import DataLoader, SubsetRandomSampler
from tqdm import tqdm

class TopoSortDataset(Dataset):
    """
    Dataset of DAG node embeddings paired with per-node labels.

    Two construction modes:

    * ``filename`` given: load one pickled networkx graph and convert it with
      ``_embedding_transformer``; the single sample's label is the shuffled
      node list returned by that method (position = node id).
    * ``filename`` is None: generate ``num_samples`` random DAGs with
      ``_dag_generator``, schedule them with ``_scheduling`` and label every
      node with the 1-based step in which it is scheduled.

    Each embedding row is laid out as
    ``[level, parent levels..., parent ids..., node id, weight]`` with
    ``in_degree_total`` parent-level and parent-id slots (empty slots hold
    level 0 / id -1). The last entry is a random weight for generated graphs
    and the node's 'param_bytes' attribute (default 0) for loaded graphs.

    Cannot support multiple operator types currently: ``_scheduling`` only
    consults the operator of node 0.
    """
    def __init__(self, filename=None, size=25, num_samples=1024, offset=0, in_degree_fixed=3, in_degree_total=6, resource_constraint_level=None, level_range=None, weight_multiply=10., weight_constraint=15., shift=0., distribution=None, seed=0):
        """
        filename: path of a pickled networkx graph; when given, the generator arguments are ignored
        size: number of nodes per generated graph
        num_samples: number of graphs to generate
        offset, distribution: currently unused (kept for interface compatibility)
        in_degree_fixed: number of parent slots sampled per generated node
        in_degree_total: width of the parent-level / parent-id slots of every embedding
            (presumably expected to be >= in_degree_fixed -- TODO confirm)
        resource_constraint_level: operator id -> max nodes per schedule step (default {-1: 4})
        level_range: inclusive [min, max] number of levels per generated graph (default [4, 10])
        weight_multiply, shift: generated node weights are random.random() * weight_multiply + shift
        weight_constraint: max summed node weight per schedule step
        seed: when > 0, reseeds the global ``random`` module before generating
        """
        super(TopoSortDataset, self).__init__()

        # None sentinels instead of shared mutable default arguments.
        if resource_constraint_level is None:
            resource_constraint_level = {-1: 4}
        if level_range is None:
            level_range = [4, 10]

        self.data = []
        self.label = []

        if filename is not None:
            G = self._load_graph(filename)
            graph, layer_index = self._embedding_transformer(G, in_degree_total)
            random.shuffle(graph)

            self.data.append(torch.FloatTensor(graph))
            self.label.append(layer_index)
        else:
            if seed > 0:
                random.seed(seed)

            for _ in range(num_samples):
                graph, D = self._dag_generator(size, in_degree_fixed, in_degree_total, len(resource_constraint_level), level_range, weight_multiply, shift)

                schedule = self._scheduling(D, D.number_of_nodes(), resource_constraint_level, weight_constraint)

                # level_index[node] = 1-based schedule step in which node runs
                level_index = [0] * size
                for step, nodes in enumerate(schedule, start=1):
                    for node in nodes:
                        level_index[node] = step

                # present the nodes in random order; embedding[-2] is the node id
                order = list(range(size))
                random.shuffle(order)

                graph_new = [graph[i] for i in order]
                label_new = [level_index[embedding[-2]] for embedding in graph_new]

                self.data.append(torch.FloatTensor(graph_new))
                self.label.append(torch.FloatTensor(label_new))

        self.size = len(self.data)

    @staticmethod
    def _load_graph(filename):
        """
        Load a pickled networkx graph from filename.

        ``nx.read_gpickle`` was removed in networkx 3.0; fall back to the
        ``pickle`` module (which is what read_gpickle used) when it is missing.
        NOTE: unpickling can execute arbitrary code -- only load trusted files.
        """
        reader = getattr(nx, 'read_gpickle', None)
        if reader is not None:
            return reader(filename)
        with open(filename, 'rb') as f:
            return pickle.load(f)

    def _embedding_transformer(self, G, in_degree_total):
        """
        Convert an existing networkx DiGraph into node embeddings.

        Nodes get integer ids from their position in a randomly shuffled node
        list, and levels by repeatedly peeling off the in-degree-0 nodes.
        Each embedding is
        [level, parent levels (0-padded), parent ids (-1-padded), node id, param_bytes]
        with padding up to in_degree_total slots; param_bytes defaults to 0
        when the node has no 'param_bytes' attribute.

        NOTE: G is consumed -- all of its nodes are removed.
        return: (list of embeddings, shuffled node list mapping id -> node)
        raises: ValueError if G contains a cycle.
        """
        node_memory = nx.get_node_attributes(G, 'param_bytes')
        nodes_collection = list(G.nodes())
        random.shuffle(nodes_collection)
        # O(1) node -> id lookup instead of repeated list.index scans
        node_to_id = {node: idx for idx, node in enumerate(nodes_collection)}
        node_input = defaultdict(list)
        for start_p, end_p in G.edges():
            node_input[node_to_id[end_p]].append(node_to_id[start_p])

        graph = []
        node_to_level = defaultdict(int)
        level = 0

        while G.number_of_nodes() > 0:
            level += 1
            nodes_in_processing = [head for head in G.nodes() if G.in_degree(head) == 0]
            if not nodes_in_processing:
                # without this the loop would never shrink G
                raise ValueError("graph contains a cycle; cannot assign levels")

            for node in nodes_in_processing:
                node_id = node_to_id[node]
                parents_id = node_input[node_id]
                padding = in_degree_total - len(parents_id)

                embedding = ([level]
                             + [node_to_level[p] for p in parents_id] + [0] * padding
                             + parents_id + [-1] * padding
                             + [node_id, node_memory.get(node, 0)])
                node_to_level[node_id] = level
                graph.append(embedding)
            G.remove_nodes_from(nodes_in_processing)

        return graph, nodes_collection

    def _embedding_generator(self, size, in_degree_fixed, in_degree_total, level_range, weight_multiply, shift):
        """
        Generate a random layered graph as node embeddings plus its DiGraph.

        size: num of nodes per graph
        in_degree_fixed: num of parent slots sampled per node
        in_degree_total: total parent slots per embedding (extra slots are padding)
        level_range: inclusive [min, max] range for the random number of levels
        weight_multiply: scale of the random node weight
        shift: offset of the random node weight (weight = random() * weight_multiply + shift)
        return: (list of embeddings, nx.DiGraph built from the sampled edges)

        Each embedding is [level, parent levels..., parent ids..., node label, weight].
        A parent level of 0 / parent id of -1 marks an empty slot. The DiGraph
        only contains nodes that appear in at least one edge, so callers must
        check its size.
        """
        num_level = random.randint(level_range[0], level_range[1]) # pick number of levels randomly in level_range
        level = [1 for _ in range(num_level)] # number of nodes per level, at least one each
        remaining = size - num_level # nodes still to distribute after giving every level one node
        traverse = 0 # round-robin pointer over the levels

        while remaining > 0:
            addition = random.randint(0, remaining) # random number of extra nodes for level(traverse)
            level[traverse % num_level] += addition
            traverse += 1
            remaining -= addition
        # 1-based levels holding fewer nodes than in_degree_fixed: a child must not draw more parents from them than they hold
        caution = [i+1 for i, val in enumerate(level) if val < in_degree_fixed]

        labels = [i for i in range(size)] # node labels 0..size-1
        random.shuffle(labels)
        level_to_nodes = {} # 1-based level -> labels on that level (sampled as parents)
        labels_distribution = [] # 0-based level -> labels not yet given to a node
        distributor = 0
        for i in range(num_level):
            level_to_nodes[i+1] = labels[distributor:(distributor+level[i])]
            labels_distribution.append(labels[distributor:(distributor+level[i])])
            distributor += level[i]

        graph = []
        graph_edges = []

        for i in range(num_level):
            for _ in range(level[i]):
                embedding = [i+1] # level of current node
                embedding.extend([random.randint(0, i) for _ in range(in_degree_fixed)]) # random parent levels; 0 = no parent
                # ensure one parent on the level directly above the current node
                if max(embedding[1:]) < i:
                    embedding[random.randint(1, in_degree_fixed)] = i
                # ensure a small level is not used as parent level more often than it has nodes
                for constraint in caution:
                    while embedding[1:].count(constraint) > level[constraint-1]:
                        embedding[embedding.index(constraint)] = 0

                embedding.extend([0 for _ in range(in_degree_total-in_degree_fixed)]) # pad parent levels to in_degree_total

                # pick parent labels for the selected parent levels
                # NOTE(review): ids are appended grouped by parent level in first-occurrence
                # order, so the k-th parent-id slot does not necessarily belong to the k-th
                # parent-level slot (_embedding_transformer keeps them aligned). Left as-is to
                # keep generated data stable -- confirm whether consumers rely on alignment.
                nodes_to_be_assigned = embedding[1:]
                while len(nodes_to_be_assigned) > 0:
                    node = nodes_to_be_assigned[0]
                    occurrences = nodes_to_be_assigned.count(node)
                    if node > 0:
                        embedding.extend(random.sample(level_to_nodes[node], occurrences))
                    else:
                        embedding.extend([-1 for _ in range(occurrences)])
                    nodes_to_be_assigned = [element for element in nodes_to_be_assigned if element != node]

                # give the current node a not-yet-used label of its own level
                label_to_current_node = random.sample(labels_distribution[i], 1)
                labels_distribution[i].remove(label_to_current_node[0])
                embedding.append(label_to_current_node[0])

                embedding.append(random.random() * weight_multiply + shift)

                # collect directed edges parent -> current node for graph building
                if i > 0:
                    for predecessor in embedding[-2-in_degree_total:-2]:
                        if predecessor > -1:
                            graph_edges.append((predecessor, label_to_current_node[0]))

                graph.append(embedding)

        G = nx.DiGraph()
        G.add_edges_from(graph_edges)

        return graph, G

    def _dag_generator(self, size, in_degree_fixed, in_degree_total, num_operators, level_range, weight_multiply, shift):
        """
        Draw random graphs until one is weakly connected, acyclic and has
        exactly `size` nodes, then attach node attributes and priorities.

        Node attributes: 'operator' (random int in [-num_operators, -1]),
        'priority' (set by _priority_sorting), 'label' (node id) and
        'weight' (last embedding entry).
        return: (embeddings, attributed DiGraph)
        """
        while True:
            graph, D = self._embedding_generator(size, in_degree_fixed, in_degree_total, level_range, weight_multiply, shift)
            # ensure graph is connected, acyclic and node size aligned
            if nx.is_connected(D.to_undirected()) and nx.is_directed_acyclic_graph(D) and D.number_of_nodes() == size:
                break

        attributes = {graph[i][-2]: {'operator': random.randint(-num_operators, -1), 'priority': 0, 'label': graph[i][-2], 'weight': graph[i][-1]} for i in range(size)}
        nx.set_node_attributes(D, attributes)
        # confirm priority of each node for label design
        DAG = self._priority_sorting(D)

        return graph, DAG

    def _scheduling(self, D, size, resource_constraint_level, weight_constraint):
        """
        Greedy list scheduling of D under resource and weight constraints.

        At every step the ready nodes (in-degree 0) are ordered by descending
        'priority', ties broken by the smaller node label, and the first
        resource_constraint of them are taken. Trailing nodes are then dropped
        until the summed 'weight' of the step is <= weight_constraint.
        If that leaves the step empty (a single ready node heavier than
        weight_constraint, or resource_constraint == 0), the top-ranked ready
        node is scheduled alone so the schedule always makes progress.

        D: graph with 'operator', 'priority' and 'weight' node attributes;
           it is consumed (all nodes are removed).
        size: unused, kept for interface compatibility.
        resource_constraint_level: operator id -> max nodes per step; only the
           operator of node 0 is consulted.
        return: list of steps, each a list of node labels.
        raises: ValueError if nodes remain but none of them is ready (cycle).
        """
        op = D.nodes[0]['operator']
        resource_constraint = resource_constraint_level[op]

        def priority_key(node):
            return (D.nodes[node]['priority'], -node)

        # iterative (not recursive) so large graphs do not hit the recursion limit
        schedule = []
        while D.number_of_nodes() > 0:
            candidates = sorted([n for n, d in D.in_degree() if d == 0], key=priority_key, reverse=True)
            if not candidates:
                raise ValueError("graph contains a cycle; no schedulable node left")
            step = candidates[:resource_constraint]
            # drop the lowest-ranked nodes until the step fits the weight budget
            while step and sum(D.nodes[node]['weight'] for node in step) > weight_constraint:
                step.pop()
            if not step:
                step = candidates[:1]
            D.remove_nodes_from(step)
            schedule.append(step)
        return schedule

    def _priority_sorting(self, D):
        """
        Set D.nodes[n]['priority'] in place (list-scheduling style) and return D.

        Every sink (out-degree 0) gets priority 1 and each predecessor is
        raised to one more than its child, keeping the maximum, so a node's
        priority is the number of nodes on the longest path from it to a sink.
        Assumes node labels are 0..D.number_of_nodes()-1 and priorities start
        below 1 (_dag_generator initializes them to 0). Recursion depth grows
        with the longest path length.
        """
        leaf = set([node for node in range(D.number_of_nodes()) if D.out_degree(node) == 0])
        def priority_rewrite(nodes, level):
            visited = set()
            for node in nodes:
                # already reached through a path at least this long
                if D.nodes[node]['priority'] >= level:
                    continue
                D.nodes[node]['priority'] = level
                # parents already pushed by an earlier sibling in this call get the same level anyway
                upper_level = set(D.predecessors(node)).difference(visited)
                priority_rewrite(upper_level, level+1)
                visited = visited.union(upper_level)
            return
        priority_rewrite(leaf, 1)
        return D

    def __len__(self):
        """Number of samples (graphs) in the dataset."""
        return self.size

    def __getitem__(self, idx):
        """Return the (embeddings, labels) pair of sample idx."""
        return self.data[idx], self.label[idx]

if __name__ == '__main__':
    # Generate evaluation sets of 50-node DAGs (16-40 levels, node weights
    # random.random() * 5., at most 35 total weight and 50 nodes per schedule
    # step) with 3 and 5 sampled parents per node, padded to 6 parent slots.
    eval_lvl = {-1: 50}
    out_dir = "eval_dataset/operator_type_1/scheduling_level_aligned_resource_free/graph_large_memory_weight35"
    # Create the output directory up front so a missing path fails fast
    # instead of after the (slow) dataset generation.
    os.makedirs(out_dir, exist_ok=True)

    for in_degree_fixed in (3, 5):
        myDataset = TopoSortDataset(size=50, num_samples=10240, in_degree_fixed=in_degree_fixed, in_degree_total=6, resource_constraint_level=eval_lvl, level_range=[16, 40], weight_multiply=5., weight_constraint=35.)
        torch.save(myDataset, out_dir + "/TopoSort50_Dataset_Eval_" + str(in_degree_fixed) + "_in_degree_6_in_total_10K_16to40_weight35.pt")
