import queue
import types
import numpy as np

# Node types: the value stored in JsonToGraph.node_type for every node id.
ntype = types.SimpleNamespace(
    noexist=0,     # the node does not exist in the original graph
    string=1,      # a string in the string dictionary (key strings, string values, string list entries)
    binary=2,      # node value is a boolean
    numerical=3,   # node value is a number, integer or float
    strval=4,      # node value is a string
    strlist=5,     # node value is a list of strings
    dict=6,        # node value is a dictionary; its entries become child nodes
    list=7,        # node value is a list; its string entries are linked to string nodes
    category=8,    # node value is an integer category; each possible value gets its own node
    none=9,        # presumably a null (None) node value
)

# Edge types: the third element of every (i, j, type) tuple in a graph's edge list.
etype = types.SimpleNamespace(
    noexist=0,     # reserved
    tree=1,        # node - parent edge following the json nesting (dictionary - value)
    key=2,         # node - string edge: the string is the key of the node
    value=3,       # node - string edge: the string is the value of the node or an entry of its list value
)



class JsonToGraph:
    """
    Convert nested json objects into graphs.

    The object works in two modes:
      * constructing mode: every json object passed to `process` extends the node
        vocabulary. Each (parent, key) pair, key string, string value and category value
        gets a node id; ids are kept in `node_id_dict` and node types in `node_type`.
      * processing mode: the vocabulary is read only. `process` returns a graph made of a
        node-value vector and an edge list; keys/values unseen during construction are
        reported and skipped.
    """

    def __init__(self):
        """
        Initialize the object in constructing mode with an empty node vocabulary.
        """

        # working mode: constructing mode or processing mode
        self.constructing_mode = True

        # prefix used to mark key strings (see _mark_str)
        self.str_marker = "8UNQ5RVoTP"

        self.switch_mode("constructing")

    def switch_mode(self, mode=None):
        """
        Switch between the two working modes, "constructing" and "processing".

        Switching to "constructing" resets the vocabulary: `node_type` and `node_id_dict` are
        emptied and `global_id` restarts at 0, node 0 standing for the whole json object.
        Switching to "processing" keeps the vocabulary built so far.

        Raises:
            Exception: if `mode` is neither "constructing" nor "processing".
        """

        if mode == "constructing":
            self.constructing_mode = True
        elif mode == "processing":
            self.constructing_mode = False
        else:
            raise Exception("No such working mode", mode)

        # initialize all dictionaries
        if self.constructing_mode:

            self.node_type = dict()
            self.node_id_dict = dict()

            # node 0 always represents the entire json object.
            self.global_id = 0
            self.node_id_dict[0] = dict()
            self.node_type[0] = ntype.dict

    def prune_json(self, json_obj, task):
        """
        Task-specific trimming of a raw json object before it is converted.

        monopoly: only json_obj['prev_state'] is kept.
        gridworld: json_obj['goal']['Distribution'] is overwritten (in place) with "Uniformed".
        Any other task: json_obj is returned unchanged.
        """

        if task == 'monopoly':
            json_obj = json_obj['prev_state']
        elif task == "gridworld":
            json_obj['goal']['Distribution'] = "Uniformed"

        return json_obj

    def mark_json(self, json_obj, task, marker=None):
        """
        Mark keys of a json object (see _mark_str) so they are handled specially later.
        The purpose is to simplify the json object.

        Args:
            json_obj: the json object; it is modified in place and also returned.
            task: task name. Only "gridworld" marks anything.
            marker: if a str, it replaces self.str_marker for this and all later calls.
        """

        if type(marker) is str:
            self.str_marker = marker

        if task == "gridworld":  # this part of code is task related
            # expects json_obj["map"] to contain every key "i,4,j" for i, j in 0..9
            for i in range(10):
                for j in range(10):
                    pos = str(i) + ",4," + str(j)
                    # replace the entry `pos` by a marked key whose value is the entry's "name".
                    # NOTE(review): _mark_str is called with default flags (all bits 0), although an
                    # older comment here mentioned USEKEYID -- confirm which was intended.
                    new_key = self._mark_str(pos)
                    json_obj["map"][new_key] = json_obj["map"].pop(pos)["name"]

        return json_obj

    def _mark_str(self, st, use_keyid=False, neglect_str=False, no_parent=False, is_cateogry=False):
        """
        Mark a key string so that _process_node handles its key-value pair specially.

        The result is self.str_marker + '!' + bits + '!' + st, where bits are four '0'/'1'
        flags in the order (USEKEYID, NEGLECTSTR, NOPARENT, ISCATEGORY); _read_masked_str
        decodes them.

        use_keyid: identify the node by the key string alone, regardless of the json path
            leading to it. NOTE: if the key string is used in several places, they all map to
            the same node, and the last occurrence decides the node's type.
        neglect_str: if off, the node is linked to the string node of its key (created on
            first use); if on, no string node is created for the key, e.g. for a hash that
            appears only once in the json file.
        no_parent: if on, the node is not connected to its parent in the json structure.
        is_cateogry: the value of the node is categorical (an int); every possible value gets
            its own node. (The misspelled name is kept for compatibility.)
        """

        flags = (use_keyid, neglect_str, no_parent, is_cateogry)
        bits = "".join(str(int(flag)) for flag in flags)

        return self.str_marker + '!' + bits + '!' + st

    def _read_masked_str(self, marked_str):
        """
        Decode a string produced by _mark_str.

        Returns:
            (st, use_keyid, neglect_str, no_parent, is_category): the original string and its
            four flags. A string that does not start with self.str_marker is returned
            unchanged with all flags False.
        """

        use_keyid, neglect_str, no_parent, is_category = False, False, False, False
        st = marked_str

        if st.startswith(self.str_marker):
            # layout: marker + '!' + 4 flag bits + '!' + original string
            head = len(self.str_marker)
            bits = marked_str[head + 1: head + 5]
            st = marked_str[head + 6:]

            use_keyid = bool(int(bits[0]))
            neglect_str = bool(int(bits[1]))
            no_parent = bool(int(bits[2]))
            is_category = bool(int(bits[3]))

        return st, use_keyid, neglect_str, no_parent, is_category

    def _str_marked(self, st):
        """Return True if `st` starts with the current string marker."""
        return st.startswith(self.str_marker)

    def _check_node(self, key, parent_id=None, constructing_mode=False):
        """
        Look up a node and return (node_id, is_new).

        If parent_id is not None the node is identified by (parent_id, key) and looked up in
        node_id_dict[parent_id]; otherwise it is identified by key alone and looked up in
        node_id_dict itself.

        In constructing mode a missing node gets the next global id (the parent's lookup dict
        is created if needed). In processing mode nothing is created: a missing node is
        reported and (-1, False) is returned.
        """

        # decide which dictionary to use
        if parent_id is None:
            the_dict = self.node_id_dict
        elif constructing_mode:
            the_dict = self.node_id_dict.setdefault(parent_id, dict())
        else:
            # read only: a parent without children cannot contain the key
            the_dict = self.node_id_dict.get(parent_id, {})

        if key in the_dict:
            # the node exists
            return the_dict[key], False

        if constructing_mode:
            # the global id increases by one as the new node id. NO other function should touch self.global_id
            self.global_id += 1
            the_dict[key] = self.global_id
            return self.global_id, True

        # in processing mode, but the key is not in the dict; key may be an int (category value)
        print("Cannot find key in the node dict. Parent id and key string are: " + str(parent_id) + ', ' + str(key))
        return -1, False

    def _process_node(self, parent_id, key, value, constructing_mode=False, graph=None):
        """
        Process one entry "parent: {key: value}" of a json dictionary.

        `parent_id` is the node id of the dictionary, `key` a (possibly marked, see _mark_str)
        key of that dictionary and `value` the corresponding value.

        In constructing mode a node is created for (parent, key) and its id is stored at
        `node_id_dict[parent][key]`; with use_keyid it is stored at `node_id_dict[key]`
        instead. Node types are recorded in `node_type`. In processing mode nothing is
        created: edges are appended to graph["edge_list"] and node values are written to
        graph["node_value"].

        Depending on `value`:
          * dict: the node is internal; `node_id_dict[node_id]` holds its children, which
            `process` visits later.
          * str: the node is linked to the string node of `value`; `node_id_dict[node_id]`
            records every string seen as its value. Its node value is the string node id.
          * int with the category flag: the node is linked to a node for (node, value);
            `node_id_dict[node_id]` records every category seen. Its node value is the
            category node id.
          * list: the node is linked to the string node of every string entry; other entries
            are ignored.
          * bool / int / float: the node value is float(value).
          * None: the node gets type ntype.none unless it already has a type.

        Args:
            graph: the graph dict built by `process`; only used in processing mode.
        Returns:
            the node id of (parent, key), or -1 if the node is unknown in processing mode.
        """

        # get options from the key string
        key, use_keyid, neglect_str, no_parent, is_category = self._read_masked_str(key)

        if is_category and (type(value) is not int):
            is_category = False
            print("The value of this categorical node is not an integer: " + key + ": " + str(value) + ". Neglected")

        # find (or create) the node: by key alone if use_keyid, by (parent_id, key) otherwise
        lookup_parent = None if use_keyid else parent_id
        node_id, is_new = self._check_node(key, lookup_parent, constructing_mode=constructing_mode)

        # if the node is not present in "training" data, then neglect it.
        if node_id == -1:
            return node_id

        # connect this node with its parent unless the key asks not to
        if (not constructing_mode) and (not no_parent):
            graph["edge_list"].append((node_id, parent_id, etype.tree))

        # link the node to the string node of its key
        if (not use_keyid) and (not neglect_str):
            str_id, _ = self._check_node(key, parent_id=None, constructing_mode=constructing_mode)

            if constructing_mode:
                self.node_type[str_id] = ntype.string
            elif str_id >= 0:
                graph["edge_list"].append((node_id, str_id, etype.key))

        if type(value) is dict:
            # internal node; its entries are processed later by `process`
            if constructing_mode:
                self.node_id_dict.setdefault(node_id, dict())
                if is_new:
                    self.node_type[node_id] = ntype.dict

        elif type(value) is str:
            str_id, _ = self._check_node(value, parent_id=None, constructing_mode=constructing_mode)

            if constructing_mode:
                # record all possible strings as values of this node (also for a node that
                # already exists but had no value dict yet)
                self.node_id_dict.setdefault(node_id, dict())[value] = str_id
                self.node_type[str_id] = ntype.string
                self.node_type[node_id] = ntype.strval
            elif str_id >= 0:
                graph["node_value"][node_id] = str_id
                graph["edge_list"].append((node_id, str_id, etype.value))

        elif (type(value) is int) and is_category:
            # node (node_id, value); _check_node records it in node_id_dict[node_id] and, in
            # constructing mode, creates that dict on first use
            cat_id, _ = self._check_node(value, parent_id=node_id, constructing_mode=constructing_mode)

            if constructing_mode:
                self.node_type[cat_id] = ntype.string  # a category is just like a constant string
                self.node_type[node_id] = ntype.category
            elif cat_id >= 0:
                graph["node_value"][node_id] = cat_id
                graph["edge_list"].append((node_id, cat_id, etype.value))

        elif type(value) is list:
            if constructing_mode:
                self.node_type[node_id] = ntype.list

            for entry in value:
                if type(entry) is not str:
                    continue  # only string entries are linked

                str_id, _ = self._check_node(entry, parent_id=None, constructing_mode=constructing_mode)

                if constructing_mode:
                    self.node_type[str_id] = ntype.string
                elif str_id >= 0:
                    graph["edge_list"].append((node_id, str_id, etype.value))

        elif type(value) is bool:
            if constructing_mode:
                self.node_type[node_id] = ntype.binary
            else:
                graph["node_value"][node_id] = float(value)

        elif type(value) in (int, float):
            if constructing_mode:
                self.node_type[node_id] = ntype.numerical
            else:
                graph["node_value"][node_id] = float(value)

        elif value is None:
            # keep a type seen in other samples; only a node that is always null gets ntype.none,
            # so that every node id has a type when tidyup runs
            if constructing_mode:
                self.node_type.setdefault(node_id, ntype.none)

        else:
            print("Node type is ", type(value))

        return node_id

    def process(self, json_obj, task):
        """
        Convert one json object.

        In constructing mode the vocabulary (`node_id_dict`, `node_type`, `global_id`) is
        extended with every node found in `json_obj`. In processing mode the vocabulary is
        read only and a graph is built.

        Args:
            json_obj: a (nested) dictionary parsed from a json file; prune_json/mark_json may
                modify it in place.
            task: task name forwarded to prune_json and mark_json.
        Returns:
            None in constructing mode; otherwise a dict with
                node_value: numpy array of length global_id + 1, one value per node
                edge_list: list of tuples (i, j, edge_type), edge_type taken from `etype`.
        """

        # do preprocessing
        json_obj = self.prune_json(json_obj, task=task)
        json_obj = self.mark_json(json_obj, task=task)

        # either construct new graph nodes or extract a graph
        if self.constructing_mode:
            graph = None
        else:
            graph = dict(node_value=np.zeros(self.global_id + 1), edge_list=[])

        # breadth-first walk over nested dictionaries; node 0 is the whole object
        pending = queue.Queue()
        pending.put((json_obj, 0))

        while not pending.empty():
            node, node_id = pending.get()  # `node` is always a dictionary

            for k, v in node.items():
                child_id = self._process_node(node_id, k, v, constructing_mode=self.constructing_mode, graph=graph)

                if child_id >= 0 and type(v) is dict:
                    pending.put((v, child_id))

        return graph

    def tidyup(self):
        """
        Convert the vocabulary into arrays before the object is stored.

        Sets:
            num_nodes: global_id + 1.
            node_type: int numpy array of length num_nodes (replaces the dict).
            cat_ranges: dict(nodes=..., ranges=...) for strval/category nodes that have more
                than one recorded value. `nodes` holds their ids; row r of the 2-D `ranges`
                lists the value-node ids of nodes[r], right-padded with -1.
        """

        self.num_nodes = self.global_id + 1

        self.node_type = np.array([self.node_type[i] for i in range(self.num_nodes)], dtype=int)

        cat_nodes = []
        ranges = []
        for i in range(self.num_nodes):
            if self.node_type[i] not in (ntype.strval, ntype.category):
                continue

            value_node_ids = list(self.node_id_dict[i].values())

            # nodes with a single recorded value are skipped
            if len(value_node_ids) == 1:
                continue

            cat_nodes.append(i)
            ranges.append(value_node_ids)

        max_cat = max((len(r) for r in ranges), default=0)
        ranges = np.array([r + [-1] * (max_cat - len(r)) for r in ranges])
        cat_nodes = np.array(cat_nodes)

        self.cat_ranges = dict(nodes=cat_nodes, ranges=ranges)


if __name__ == "__main__":

    # test the class

    import json
    import os 
    import pickle

    datapath ="/home/plymper/data/gridworldsData/new_train_data2/" 
    #datapath ="/home/liulp/data/monopoly_val/" 
    task = "gridworld"

    jgraph = JsonToGraph()

    num_epi_train = 100
    epi_files = [ list() for i in range(num_epi_train)]

    for fname in os.listdir(datapath):
        if fname.endswith(".json"):
            epi = int(fname[:fname.index('_')]) - 1
            step = int(fname[(fname.index('_') + 1):fname.index('.')]) - 1
            if epi < num_epi_train:
                epi_files[epi].append(step)
   
    for epi in range(len(epi_files)):
        epi_files[epi].sort()


    jgraph = JsonToGraph()

    for epi in range(len(epi_files)):
        for step in epi_files[epi]:
            fname = str(epi+1) + '_' + str(step+1) + ".json"
                
            with open(os.path.join(datapath, fname)) as file:
                json_obj = json.load(file)

            jgraph.process(json_obj, task)



    # switch to the processing mode and dump pkl data from json files
    jgraph.switch_mode("processing")

    node_feat = []
    graph_lists = []
    epi_length = []

    for epi in range(len(epi_files)):
        
        epi_length.append(len(epi_files[epi]))

        for step in epi_files[epi]:
            fname = str(epi+1) + '_' + str(step+1) + ".json"
                
            with open(os.path.join(datapath, fname)) as file:
                json_obj = json.load(file)

            graph = jgraph.process(json_obj, task)

            node_feat.append(graph["node_value"])
            graph_lists.append(graph["edge_list"])
        
    # (total_steps, num_nodes)
    node_feat = np.stack(node_feat, axis=0)

    data = dict(node_feat=node_feat, graph_lists=graph_lists, epi_length=epi_length, node_type=jgraph.node_type)
    


    with open("gridworldtrain.pkl", 'wb') as handle:
        pickle.dump(data, handle)


    # tidy up before storing the json object
    jgraph.tidyup()
    with open("jgraph.pkl", 'wb') as handle:
        pickle.dump(jgraph, handle)

   
