import copy
import queue
import types
import numpy as np

# Node type codes. JsonToGraph stores one of these per node id in `node_type`.
ntype = types.SimpleNamespace(
    noexist=0,    # the node does not exist in the original graph.
    string=1,     # a string in the string dictionary
    binary=2,     # node value is binary
    numerical=3,  # node value is a number, integer or float
    strval=4,     # node value is a string
    strlist=5,    # node value is a list of strings.
    dict=6,       # node value is a dictionary, will be expanded
    list=7,       # node value is a list of other types, will be expanded
    category=8,   # node value is an integer that the key marker flags as categorical
    none=9,       # presumably: node value is None -- not assigned anywhere in this file, TODO confirm
)

# Edge type codes, stored as the third element of every tuple in a graph's edge list.
etype = types.SimpleNamespace(
    noexist=0,  # reserved
    tree=1,     # dictionary - value edge
    key=2,      # node - string edge. The string is key of the node
    value=3,    # node - string edge. The string is either an entry of the list node or the value of the node
)


# Default inventory entry for every fixed slot (count 0). normalize_polycraft_json
# deep-copies this table and overwrites the slots of the items actually held.
# NOTE(review): slot '14' is the crafting table here, while inverse_inventory maps
# 'minecraft:air' to '14' and the crafting table to '15' -- confirm which is intended.
item_entries = {
    '0': {'item': 'minecraft:iron_pickaxe', 'count': 0, 'damage': 0, 'maxdamage': 250},
    '1': {'item': 'minecraft:diamond_block', 'count': 0, 'damage': 0, 'maxdamage': 0},
    '2': {'item': 'minecraft:diamond', 'count': 0, 'damage': 0, 'maxdamage': 0},
    '3': {'item': 'polycraft:tree_tap', 'count': 0, 'damage': 0, 'maxdamage': 0,
          'enabled': 'true', 'facing': 'down'},
    '4': {'item': 'polycraft:sack_polyisoprene_pellets', 'count': 0, 'damage': 0, 'maxdamage': 0},
    '5': {'item': 'polycraft:wooden_pogo_stick', 'count': 0, 'damage': 0, 'maxdamage': 64},
    '6': {'item': 'minecraft:planks', 'count': 0, 'damage': 0, 'maxdamage': 0,
          'variant': 'oak'},
    '7': {'item': 'minecraft:log', 'count': 0, 'damage': 0, 'maxdamage': 0,
          'axis': 'y', 'variant': 'oak'},
    '8': {'item': 'polycraft:key', 'count': 0, 'damage': 4, 'maxdamage': 0,
          'color': 'blue'},
    '9': {'item': 'minecraft:stick', 'count': 0, 'damage': 0, 'maxdamage': 0},
    '10': {'item': 'minecraft:sapling', 'count': 0, 'damage': 0, 'maxdamage': 0,
           'stage': '0', 'type': 'oak'},
    '11': {'item': 'polycraft:block_of_platinum', 'count': 0, 'damage': 0, 'maxdamage': 0},
    '12': {'item': 'polycraft:block_of_titanium', 'count': 0, 'damage': 0, 'maxdamage': 0},
    '13': {'item': 'minecraft:bedrock', 'count': 0, 'damage': 0, 'maxdamage': 0},
    '14': {'item': 'minecraft:crafting_table', 'count': 0, 'damage': 0, 'maxdamage': 0},
}

# Fixed slot id (as a string) -> item name. The slot id is the position of the
# item in the tuple below.
fxd_inventory = {
    str(slot): item
    for slot, item in enumerate((
        'minecraft:iron_pickaxe',
        'minecraft:diamond_block',
        'minecraft:diamond',
        'polycraft:tree_tap',
        'polycraft:sack_polyisoprene_pellets',
        'polycraft:wooden_pogo_stick',
        'minecraft:planks',
        'minecraft:log',
        'polycraft:key',
        'minecraft:stick',
        'minecraft:sapling',
        'polycraft:block_of_platinum',
        'polycraft:block_of_titanium',
        'minecraft:bedrock',
        'minecraft:air',
        'minecraft:crafting_table',
    ))
}

# Item name -> fixed slot id; the exact inverse of fxd_inventory.
inverse_inventory = {item: slot for slot, item in fxd_inventory.items()}

class NewNodeException(Exception):
    """Exception type declared for this module; it is not raised anywhere in this file.

    NOTE(review): presumably meant to signal a node that was not seen in
    constructing mode -- confirm against callers outside this file.
    """



class JsonToGraph:
    
    def __init__(self):
        """
        Create the converter. A new object always starts in constructing mode.
        """

        # prefix that tags key strings carrying handling options (see _mark_str)
        self.str_marker = "8UNQ5RVoTP"

        # working mode: constructing mode or processing mode.
        # switch_mode sets the flag again and initializes the node dictionaries.
        self.constructing_mode = True
        self.switch_mode("constructing")

    def switch_mode(self, mode=None):
        """
        Select one of the two working modes: "constructing" or "processing".

        Switching to constructing mode (re)initializes the node dictionaries and the node id counter,
        discarding anything collected before. Switching to processing mode only flips the flag.
        Any other value of `mode` (including the default None) raises an Exception.
        """

        if mode not in ("constructing", "processing"):
            raise Exception("No such working mode", mode)

        self.constructing_mode = (mode == "constructing")

        if not self.constructing_mode:
            return

        # initialize all dictionaries.
        # node 0 always represent the entire json object.
        self.global_id = 0
        self.node_id_dict = {0: dict()}
        self.node_type = {0: ntype.dict}

    def fill_missing_object_properties(self,entry):
        '''
        Give `entry` every optional object property, in place, and return it.

        A property that is absent gets a placeholder: 'y' for 'axis' and 'none' for all the others.
        An existing 'facing' value is upper-cased; other existing values are kept as they are.
        '''
        # normalize the one property whose existing value is rewritten
        if 'facing' in entry:
            entry['facing'] = entry['facing'].upper()

        # (property, placeholder) pairs, in the order the keys are added
        placeholders = (
            ('axis', 'y'),
            ('variant', 'none'),
            ('facing', 'none'),
            ('hinge', 'none'),
            ('powered', 'none'),
            ('half', 'none'),
            ('enabled', 'none'),
            ('open', 'none'),
            ('color', 'none'),
        )
        for prop, placeholder in placeholders:
            entry.setdefault(prop, placeholder)

        return entry

    def remove_partial_object_properties(self,entry):
        '''
        Strip the optional object properties from `entry`, in place, and return it.

        These are the properties that only some objects carry ('axis', 'variant', 'facing', ...);
        keys that are absent are simply skipped and every other key is left untouched.
        '''
        optional_props = ('axis', 'variant', 'facing', 'hinge', 'powered',
                          'half', 'enabled', 'open', 'color')
        for prop in optional_props:
            entry.pop(prop, None)
        return entry


    def distance(self,pos1,pos2):
        '''
        Return, as a float, the sum over coordinates of sqrt((pos1 - pos2)**2), i.e. the sum of the
        absolute coordinate differences between the two positions.
        '''
        # NOTE(review): this is the Manhattan rather than the Euclidean distance -- confirm it is intended
        delta = np.array(pos1) - np.array(pos2)
        per_axis = np.sqrt(delta ** 2)
        return float(np.sum(per_axis))

    def normalize_polycraft_json(self,jsondict, remove_map = False):
        '''
        Normalize the json structure so that the nodes in the resulting graph are always the same
        (only attributes vary).

        `jsondict` is modified in place and a deep copy of the modified object is returned:
          * optional object properties are stripped from 'blockInFront' and, unless `remove_map`,
            from every map cell; map cells are re-keyed "map_0", "map_1", ... in their original order
          * positions of the player and of the kept entities are replaced by the new map key of their
            cell, or by 'none' when `remove_map` is on
          * the inventory becomes a copy of `item_entries` with the held items written into their
            fixed slots ('selectedItem' is dropped)
          * only the oak sapling item and the Pogoist entity are kept in 'entities', under fixed ids,
            each with a 'dist' field measured from the player's original position
          * with `remove_map`, 'map' and 'destinationPos' are deleted
        '''
        # strip optional properties from the block in front of the player
        front_block = copy.deepcopy(jsondict['blockInFront'])
        jsondict['blockInFront'] = self.remove_partial_object_properties(front_block)

        # fix map: rename every cell key to map_<index> and remember the renaming
        pos_map = {}
        if not remove_map:
            for index, pos in enumerate(list(jsondict['map'])):
                new_key = f"map_{index}"
                pos_map[pos] = new_key
                cell = copy.deepcopy(jsondict['map'].pop(pos))
                jsondict['map'][new_key] = self.remove_partial_object_properties(cell)

        def cell_label(pos):
            # new map key of the cell at `pos`; 'none' when there is no map to refer to
            if remove_map:
                return 'none'
            return pos_map[f"{str(pos[0])},{str(pos[1])},{str(pos[2])}"]

        # keep the raw player position: entity distances below are measured from it
        player = jsondict['player']
        player_pos = player['pos']
        player['pos'] = cell_label(player_pos)

        # normalize inventory
        held = jsondict['inventory']
        del held['selectedItem']
        slots = copy.deepcopy(item_entries)
        for stack in held.values():
            slots[inverse_inventory[stack['item']]] = stack
        jsondict['inventory'] = slots

        ###################################################
        # normalize entities: (entity name, fixed id) pairs of the entities that are kept
        tracked = (
            ('item.tile.sapling.oak', 'sapling_id'),
            ("entity.polycraft.Pogoist.name", 'pogoist_id'),
        )
        fixed_entities = {}
        for raw_entity in list(jsondict['entities'].values()):
            for entity_name, fixed_id in tracked:
                if raw_entity['name'] == entity_name:
                    entity = copy.deepcopy(raw_entity)
                    entity['id'] = fixed_id
                    entity['dist'] = self.distance(player_pos, entity['pos'])
                    entity['pos'] = cell_label(entity['pos'])
                    fixed_entities[fixed_id] = entity

        jsondict['entities'] = fixed_entities

        if remove_map:
            del jsondict['map']
            del jsondict['destinationPos']

        return copy.deepcopy(jsondict)


    def normalize_gridworld_json(self, json_obj, remove_map = False):
        '''
        Re-key the inventory of a gridworld state by a fixed item order, in place, and return `json_obj`.

        Every inventory entry whose current key is one of "0".."8" is deep-copied and stored under the
        fixed slot of its item; all other entries are dropped. 'step' is reset to 0 and, with
        `remove_map`, 'map' and 'destinationPos' are deleted.

        raises:
            KeyError: if a kept entry holds an item that has no fixed slot, or if `remove_map` is on
                      and 'map' / 'destinationPos' is missing.
        '''
        inventory_order = {"0": "minecraft:planks", "1": "polycraft:wooden_pogo_stick", "2":"minecraft:bedrock", "3": "minecraft:air","4": "polycraft:sack_polyisoprene_pellets", "5": "minecraft:crafting_table","6":  "polycraft:tree_tap", "7":"minecraft:stick","8": "minecraft:log"}
        inverse_order = {item: slot for slot, item in inventory_order.items()}

        # NOTE(review): the filter tests the *current* key of the entry against the fixed slot ids,
        # which also drops non-slot keys -- confirm that this is the intended filter
        fixed_inventory = {
            inverse_order[entry['item']]: copy.deepcopy(entry)
            for key, entry in json_obj['inventory'].items()
            if key in inventory_order
        }

        # The entries are already deep copies, so no second copy is needed. The former unused
        # lookup of slot '0' is gone: it raised KeyError whenever no planks entry was present.
        json_obj['inventory'] = fixed_inventory

        json_obj['step'] = 0

        if remove_map:
            del json_obj['map']
            del json_obj['destinationPos']

        return json_obj


    def prune_json(self, json_obj, task):
        """
        Remove the parts of a json object that should not become graph nodes and return the result.

        task == 'monopoly': the listed top-level, 'prev_state' and per-player entries are deleted in
            place (players 'player_1' .. 'player_4') and the same object is returned.
        task == 'polycraft' or 'gridworld': `json_obj['goal']['Distribution']` is overwritten in place,
            then a normalized deep copy produced by normalize_polycraft_json (map removed) is returned.
        any other task: `json_obj` is returned unchanged.

        raises:
            KeyError: if an entry that is expected to be deleted is missing
                      ('full_color_sets_possessed' is optional).
        """

        if task == 'monopoly':
            for name in ('true_next_state', 'actions_and_params'):
                del json_obj[name]

            prev_state = json_obj['prev_state']
            for name in ('history', 'cards', 'die_sequence', 'locations', 'location_sequence'):
                del prev_state[name]

            dropped_player_fields = (
                'outstanding_trade_offer',
                'outstanding_property_offer',
                'mortgaged_assets',
                'assets',
                'currently_in_jail',
                'option_to_buy',
                'is_property_offer_outstanding',
                'is_trade_offer_outstanding',
                'num_railroads_possessed',
                'num_utilities_possessed',
                'num_total_houses',
                'num_total_hotels',
            )
            for i in range(1, 5):
                player = prev_state['players']['player_' + str(i)]
                for name in dropped_player_fields:
                    del player[name]
                # this field is not present for every player
                player.pop('full_color_sets_possessed', None)

            # The pruned object is returned like in every other branch. The former debugging dump
            # to stderr followed by exit() terminated the whole process on the first monopoly object.

        elif task == "polycraft" or task == 'gridworld':
            # gridworld states go through the polycraft normalization; the branch that called
            # normalize_gridworld_json was disabled with "and False" and could never run

            json_obj['goal']['Distribution'] = "Uniformed"

            json_obj = self.normalize_polycraft_json(copy.deepcopy(json_obj), remove_map=True)

        return json_obj

    def mark_json(self, json_obj, task, marker=None):
        """
        Hook for marking the keys of a json object so they are handled specially later on (see _mark_str).

        When `marker` is a str it replaces the marker prefix of this object. At present no task marks
        anything: `json_obj` is returned untouched whatever `task` is.
        """

        if type(marker) is str:
            self.str_marker = marker

        # This part of the code is task related. Neither "monopoly" nor "gridworld" needs marking
        # right now. An earlier gridworld experiment replaced each map key "i,4,j" by its marked
        # version (self._mark_str(pos)) and stored only the cell name under it.
        return json_obj


    
    def _mark_str(self, st, use_keyid=False, neglect_str=False, no_parent=False, is_cateogry=False):
        """
            We mark a key string st by "8UNQ5RVoTP!0000!" + st (the prefix is self.str_marker). Here "0000" are bits
            indicating the handling of key-value pair. The bits indicate, in this order:
            (USEKEYID, NEGLECTSTR, NOPARENT, ISCATEGORY)
            use_keyid: if on, use the key as the id of the node, irrespective of the json structure leading to the key.
                         NOTE: if a key string is used in multiple nodes, then only the last node decides the attribute
                         for the node created from the key.
            neglect_str: if off, the key string will be a string node (created before or by this string) and this node
                        is connected to the string; otherwise, do not create a node for the string. Likely the string
                        (e.g. a hash encoding) only appears once in the json file, then there is no need to create a
                        string node and create a link to this node.
            no_parent: if on, do not connect this node to its parent node in the json structure.
            is_cateogry: this marker indicates that the value of this node is categorical. Then the value will be
                        encoded as one-hot encoding. (The parameter name keeps its historical spelling so that
                        keyword callers keep working.)

            returns: the marked string, which _read_masked_str decodes again.
        """

        # One character per flag. The fourth bit used to repeat use_keyid, so is_cateogry was never
        # encoded and every use_keyid key was wrongly flagged as categorical.
        flags = (use_keyid, neglect_str, no_parent, is_cateogry)
        bits = "".join(str(int(flag)) for flag in flags)

        marked_str = self.str_marker + '!' + bits + '!' + st

        return marked_str

    def _read_masked_str(self, marked_str):
        """
        Decode a key string produced by _mark_str.

        returns: (st, use_keyid, neglect_str, no_parent, is_category), where `st` is the key without
                 the marker prefix and flag bits. A string that does not start with self.str_marker
                 is returned unchanged with all four flags False.
        """

        if not marked_str.startswith(self.str_marker):
            return marked_str, False, False, False, False

        # layout: <marker> '!' <4 flag bits> '!' <original string>
        bits_start = len(self.str_marker) + 1
        bits = marked_str[bits_start: bits_start + 4]
        st = marked_str[bits_start + 5:]

        use_keyid, neglect_str, no_parent, is_category = (bool(int(bits[i])) for i in range(4))

        return st, use_keyid, neglect_str, no_parent, is_category

    def _str_marked(self, st):
        """
        Tell whether `st` begins with the marker prefix of this object.
        """
        marker = self.str_marker
        return st.startswith(marker)


    def _check_node(self, key, parent_id=None, constructing_mode=False):
        """
        Check the node in the dictionary and return the node id. If the node is not in the dict, it adds a new node
        and assigns a new node id if `constructing_mode` is on; otherwise it reports the miss and returns -1.

        When checking the node id, if parent_id is not None, then the node id is decided by (parent_id, key) and is
        looked up in `node_id_dict[parent_id]`; otherwise, it is decided by key and looked up in `node_id_dict` itself.

        args:
            key: key of the node; a string, or an integer for the values of categorical nodes
            parent_id: id of the parent node, or None
            constructing_mode: whether an unknown node may be created
        returns:
            (node_id, is_new): `is_new` is True only when the node was created by this call;
            (-1, False) when the node is unknown and may not be created.
        raises:
            KeyError: if `parent_id` is given but has no dictionary in `node_id_dict`.
        """

        # decide which dictionary to use
        if parent_id is None:
            the_dict = self.node_id_dict
        else:
            the_dict = self.node_id_dict[parent_id]

        # the node exists
        if key in the_dict:
            return the_dict[key], False

        if constructing_mode is True:
            # need to create the new node
            # the global id increases by one as the new node id. NO other function should touch self.global_id
            self.global_id += 1
            the_dict[key] = self.global_id
            return self.global_id, True

        # in processing mode, but the key is not in the dict.
        # The message is formatted rather than concatenated: concatenation raised TypeError for
        # non-string keys, such as the integer values of categorical nodes.
        print(f"Cannot find key in the node dict. Parent id and key string are: {parent_id}, {key}")
        return -1, False


    def _process_node(self, parent_id, key, value, constructing_mode=False, graph=None):

        """
        process a node in the form "parent: {key: value}". parent_id is an integer id for the dictionary; `key` is a
        key in the dictionary (possibly marked, see _mark_str); and `value` is the value corresponding to the `key`
        in the dictionary.

        In construction mode, a new node will be created for (parent, key), and the node id is stored at
        `node_id_dict[parent][key]`. If use_keyid, then the node id is stored at `node_id_dict[key]`.
        In processing mode nothing is created; edges and node values are written into `graph` instead
        (`graph["edge_list"]` and `graph["node_value"]`).

        If `value` is a dictionary, then create a new entry `node_id_dict[node_id] = dict()`, which is ready to store
        entries in value.

        If `value` is category, then create a child node for each possible category value, and connect the node to
        the child node. At the same time, `node_id_dict[node_id][category] = child_node_id` stores possible values.

        If `value` is string, connect the node to the string node. At the same time,
        `node_id_dict[node_id][string] = string_node_id` stores all possible string values.

        If `value` is list of strings, connect the node to every string node of a string in the list.
        Non-string list entries are skipped.

        If `value` is a bool or a number, its float value becomes the node value.

        returns:
            the node id; -1 when the node is not known in processing mode. A tuple (-1, key) is returned when
            _check_node reports a new node in processing mode.
        """

        # get options from the key string
        key, use_keyid, neglect_str, no_parent, is_category = self._read_masked_str(key)

        if is_category and (type(value) is not int):
            is_category = False
            # str(value): the value is by definition not an int here and may not be a str either,
            # in which case plain concatenation raised TypeError
            print("The value of this categorical node is not an integer: " + key + ": " + str(value) + ". Neglected")


        # add the node into the dictionary
        # try to find the node by checking the key if use_keyid, or check (parent_id, key) otherwise
        if use_keyid:
            node_id, is_new = self._check_node(key, parent_id=None, constructing_mode=constructing_mode)
        else:
            node_id, is_new = self._check_node(key, parent_id, constructing_mode=constructing_mode)

        if not constructing_mode and is_new:
            print(key,value)
            return -1,key

        # if the node is not present in "training" data, then neglect it.
        if node_id == -1:
            return node_id

        # connect this node with its parent when there is such an option
        if (not constructing_mode) and (not no_parent):
            graph["edge_list"].append((node_id, parent_id, etype.tree))


        # process key string if there is such an option
        if (not use_keyid) and (not neglect_str):
            str_id, _ = self._check_node(key, parent_id=None, constructing_mode=constructing_mode)
            # str_id cannot be -1 because it has been seen when checking node_id

            if constructing_mode:
                self.node_type[str_id] = ntype.string

            # if processing mode, connect the node with the key string
            if not constructing_mode:
                graph["edge_list"].append((node_id, str_id, etype.key))


        # the dict value will be recursively processed later. the node is internal node
        if type(value) is dict:

            if constructing_mode and is_new:
                # Make the dict ready for later processing
                self.node_id_dict[node_id] = dict()
                self.node_type[node_id] = ntype.dict


        # create a string node if the value is a new string
        elif type(value) is str:
            str_id, _ = self._check_node(value, parent_id=None, constructing_mode=constructing_mode)

            # record all possible strings as values of this node
            if constructing_mode:
                if is_new:
                    self.node_id_dict[node_id] = dict()

                self.node_id_dict[node_id][value] = str_id

                self.node_type[str_id] = ntype.string
                self.node_type[node_id] = ntype.strval

            if not constructing_mode:
                # add an edge between the node and a string node if the value is a string
                graph["node_value"][node_id] = str_id
                graph["edge_list"].append((node_id, str_id, etype.value))


        # If categorical, create nodes for possible values
        elif (type(value) is int) and is_category:
            # The table of category values must exist before _check_node looks into it. It is created
            # only once, so values recorded from earlier objects are kept. (It used to be created after
            # the lookup, which raised KeyError, and was reset whenever a new category value showed up.)
            if constructing_mode and (node_id not in self.node_id_dict):
                self.node_id_dict[node_id] = dict()

            # create a new node (node_id, value) with value being an integer.
            cat_id, _ = self._check_node(value, parent_id=node_id, constructing_mode=constructing_mode)

            if constructing_mode:
                self.node_id_dict[node_id][value] = cat_id
                self.node_type[cat_id] = ntype.string # a category is just like a constant string
                self.node_type[node_id] = ntype.category

            if not constructing_mode:
                # add an edge between the node and the node of its category value
                graph["node_value"][node_id] = cat_id
                graph["edge_list"].append((node_id, cat_id, etype.value))

        # if value is a list, and an string entry in the list is new, then create a string node
        # for every entry in the list,
        elif type(value) is list:

            if constructing_mode:
                self.node_type[node_id] = ntype.list

            for entry in value:
                if (type(entry) is str):
                    # check the string from the dictionary
                    str_id, _ = self._check_node(entry, parent_id=None, constructing_mode=constructing_mode)

                    if constructing_mode:
                        self.node_type[str_id] = ntype.string
                    if not constructing_mode:
                        # add an edge between the node and a string node if the entry is a string
                        graph["edge_list"].append((node_id, str_id, etype.value))


        elif type(value) is bool:
            if constructing_mode:
                self.node_type[node_id] = ntype.binary
            if not constructing_mode:
                graph["node_value"][node_id] = float(value)

        elif type(value) in [int, float]:
            if constructing_mode:
                self.node_type[node_id] = ntype.numerical

            if not constructing_mode:
                graph["node_value"][node_id] = float(value)
        else:
            # no node type is recorded for any other kind of value
            print("Node type is ", type(value))

        return node_id
        
    def process(self, json_obj, task):
        """
        Process one json object. In the constructing mode the node dictionaries of this object are
        updated and None is returned. In the processing mode a graph is extracted and returned.
        args:
            json_obj: a dict; it should be a nested dictionary
            task: name of the task; selects the preprocessing in prune_json / mark_json
        returns:
            None in constructing mode, otherwise a dict with
            node_value: array with one value per known node (zero where nothing was written)
            edge_list: a list of tuples, each tuple (i, j, att_ij ) contains the edge and its type.
        raises:
            Exception: if _process_node reports a new node by returning a tuple.
        """

        # do preprocessing
        json_obj = self.mark_json(self.prune_json(json_obj, task=task), task=task)

        # either construct new graph nodes (no graph needed) or extract a graph
        graph = None
        if not self.constructing_mode:
            graph = dict(node_value=np.zeros(self.global_id + 1), edge_list=[])

        # breadth-first walk over the nested dictionaries; node 0 is the whole json object
        pending = queue.Queue()
        pending.put((json_obj, 0))

        while not pending.empty():
            current, current_id = pending.get()  # current is a dictionary

            for child_key, child_value in current.items():
                result = self._process_node(current_id, child_key, child_value,
                                            constructing_mode=self.constructing_mode, graph=graph)

                if type(result) is tuple:
                    _, new_key = result
                    raise Exception(f"New:{new_key}")

                # only known dictionary children are expanded further
                if (result >= 0) and (type(child_value) is dict):
                    pending.put((child_value, result))

        return graph


    
    def tidyup(self):
        """
        Finalize the object after construction. It sets
            num_nodes:  number of nodes (global_id + 1)
            node_type:  replaced by an integer array indexed by node id
            cat_ranges: dict(nodes=..., ranges=...) listing every strval / category node with other than
                        one recorded value, together with the ids of its value nodes, padded with -1
                        to a common length
            node_table: list indexed by node id; for an entry of a dictionary node the list of keys
                        leading to it, for a string node the string itself
        """

        self.num_nodes = self.global_id + 1

        # dense array of node types, indexed by node id
        self.node_type = np.array([self.node_type[nid] for nid in range(self.num_nodes)], dtype=int)

        # nodes whose value is one out of several recorded value nodes
        cat_nodes = []
        ranges = []
        for nid in range(self.num_nodes):
            if self.node_type[nid] not in [ntype.strval, ntype.category]:
                continue

            value_node_ids = list(self.node_id_dict[nid].values())
            if len(value_node_ids) == 1:
                # a node with a single recorded value is not treated as categorical
                continue

            cat_nodes.append(nid)
            ranges.append(value_node_ids)

        # pad every range with -1 up to the longest one
        max_cat = max((len(ids) for ids in ranges), default=0)
        ranges = [ids + [-1] * (max_cat - len(ids)) for ids in ranges]

        self.cat_ranges = dict(nodes=np.array(cat_nodes), ranges=np.array(ranges))

        # create a reverse table for node_id_dict
        # with id as the key, we get the list of keys leading to a json entry or a string
        self.node_table = [None] * self.num_nodes
        self.node_table[0] = list()

        for owner, content in self.node_id_dict.items():
            if (type(owner) is int) and (self.node_type[owner] == ntype.dict):
                # only consider keys that refer to internal nodes
                # their children (possibly leaf nodes) are recorded in the table here
                assert(type(content) is dict)
                for child_key, child_id in content.items():
                    assert(type(child_id) is int)

                    path = self.node_table[owner].copy()
                    path.append(child_key)

                    self.node_table[child_id] = path

            elif type(owner) is str:
                assert(type(content) is int)
                self.node_table[content] = owner
        
if __name__ == "__main__":

    # test the class: build the node dictionaries from a directory of json files, convert every
    # file into a graph, dump the results as pickles and run a sanity check on the last graph

    import json
    import os 
    import pickle

    # directory with the json files; the names are parsed below as "<episode>_<step>.json", both 1-based
    datapath ="/home/liulp/data/gridworldsData/new_train_data2/" 
    #datapath ="/home/liulp/data/monopoly_val/" 
    task = "gridworld"

    # NOTE(review): this instance is replaced by the second JsonToGraph() created further down
    jgraph = JsonToGraph()

    # only the first num_epi_train episodes are used
    num_epi_train = 100
    # epi_files[e] collects the 0-based step numbers found for 0-based episode e
    epi_files = [ list() for i in range(num_epi_train)]

    for fname in os.listdir(datapath):
        if fname.endswith(".json"):
            epi = int(fname[:fname.index('_')]) - 1
            step = int(fname[(fname.index('_') + 1):fname.index('.')]) - 1
            if epi < num_epi_train:
                epi_files[epi].append(step)
   
    # process the steps of every episode in increasing order
    for epi in range(len(epi_files)):
        epi_files[epi].sort()


    jgraph = JsonToGraph()

    # first pass, in constructing mode: collect the nodes of all files
    for epi in range(len(epi_files)):
        for step in epi_files[epi]:
            fname = str(epi+1) + '_' + str(step+1) + ".json"
                
            with open(os.path.join(datapath, fname)) as file:
                json_obj = json.load(file)

            jgraph.process(json_obj, task)



    # switch to the processing mode and dump pkl data from json files
    jgraph.switch_mode("processing")

    node_feat = []     # one node value vector per processed file
    graph_lists = []   # one edge list per processed file
    epi_length = []    # number of steps of every episode

    # second pass: extract one graph per file
    for epi in range(len(epi_files)):
        
        epi_length.append(len(epi_files[epi]))

        for step in epi_files[epi]:
            fname = str(epi+1) + '_' + str(step+1) + ".json"
                
            with open(os.path.join(datapath, fname)) as file:
                json_obj = json.load(file)

            graph = jgraph.process(json_obj, task)

            node_feat.append(graph["node_value"])
            graph_lists.append(graph["edge_list"])


    

        
    # (total_steps, num_nodes)
    node_feat = np.stack(node_feat, axis=0)

    # node_type is stored as it is before tidyup(), i.e. still the dict built in constructing mode
    data = dict(node_feat=node_feat, graph_lists=graph_lists, epi_length=epi_length, node_type=jgraph.node_type)
    


    with open("gridworldtrain.pkl", 'wb') as handle:
        pickle.dump(data, handle)


    # tidy up before storing the json object
    jgraph.tidyup()




    with open("jgraph.pkl", 'wb') as handle:
        pickle.dump(jgraph, handle)


   
    # sanity check: uses `graph` and `json_obj` of the last file processed above
    # NOTE(review): json_obj is the object as loaded from the file, while the graph was built from
    # the normalized deep copy made in prune_json -- confirm that the key paths still match
    print("Start a sanity check ... ")
    problematic = False 

    for nid in range(jgraph.num_nodes):

        if jgraph.node_type[nid] in [ntype.numerical, ntype.binary]:
            # the value is read here but not compared with anything
            val_vec = graph["node_value"][nid]
        elif jgraph.node_type[nid] == ntype.strval: 
            # the node value of a strval node is the id of its string node
            str_id = int(graph["node_value"][nid])

            val_vec = jgraph.node_table[str_id]

            # follow the recorded key path of the node inside the json object
            d = json_obj
            
            for key in jgraph.node_table[nid]: 
                d = d[key] 

            val_json = d

            if (val_json != val_vec):
                problematic = True
                print("Value is dfferent at ", jgraph.node_table[nid])
        
    if problematic: 
        print("Sanity check ends. Please check issues above")
    else: 
        print("Sanity check ends. Find no issues.")


