"""
code for converting a json file to data that can be used by GNNs
"""
import json
from os import listdir
import pickle, os
import numpy as np
from types import SimpleNamespace  


# Node types, stored in column 0 of each node feature row [node_type, key_id, scalar_val].
ntype = SimpleNamespace(noexist   =0, # the node does not appear in the current json object; its row is left as initialised
                        string    =1, # a string in the string dictionary
                        binary    =2, # node value is a boolean
                        numerical =3, # node value is a number, integer or float
                        strval    =4, # node value is a string; the node is linked to that string node
                        strlist   =5, # node value is a list of strings; the node is linked to each string node
                        dict      =6, # node value is a dictionary, will be expanded
                        array     =7, # node value is a numeric array in the form of a list, will be expanded
                        list      =8, # node value is a list of other types, will be expanded
                        none      =9  # node value is None (json null)
                       )

# Edge types, stored as the third element of each (src_id, dst_id, edge_type) tuple.
etype = SimpleNamespace(noexist   =0, # reserved
                        tree      =1, # parent node -> child node edge of the json tree
                        key       =2, # node -> string edge. The string is the key of the node
                        value     =3  # node -> string edge. The string is either an entry of the string-list node or the value of the node
                       )
         

def is_str_list(lst):
    """
    Return True when every element of `lst` is a str (vacuously True for []).

    Raises AssertionError if `lst` is not a list.
    """
    assert isinstance(lst, list)
    return all(isinstance(item, str) for item in lst)
 
def is_vector(lst):
    """
    Return True when `lst` holds only int/float values, so it can be treated
    as a numeric vector. bool counts as int here, and [] yields True.

    Raises AssertionError if `lst` is not a list.
    """
    assert isinstance(lst, list)

    for entry in lst:
        if not isinstance(entry, (int, float)):
            return False
    return True
 

def is_matrix(lst):
    """
    Return True when `lst` looks like a 2-D numeric array.

    Every element must be a list or np.ndarray, all elements must have the
    same length, and every list element must be numeric (see is_vector);
    ndarray rows are not type-checked. An empty list yields True.

    Raises AssertionError if `lst` is not a list.
    TODO: check this function
    """
    assert isinstance(lst, list)

    expected_len = None
    for row in lst:
        if not isinstance(row, (list, np.ndarray)):
            return False

        # the first row fixes the length that every later row must match
        if expected_len is None:
            expected_len = len(row)
        elif len(row) != expected_len:
            return False

        if isinstance(row, list) and not is_vector(row):
            return False

    return True

def traverse_flatten(key, value, parent_name, str_dict, node_list):
    """
    Recursively collect path names and strings from a json (sub)tree.

    Every visited node gets a path name of the form
    '->root->parent_name->child_name', appended to `node_list` in visiting
    order. Keys are added to `str_dict`, except the synthetic 'LIST<i>' keys
    used for list elements. String values and the entries of string lists are
    added too. Each new string receives the next free integer id.

    Dicts and non-string lists are expanded recursively. Only the first 10
    elements of a list are visited.

    input:
        key: key of the current node ('root' for the top-level call)
        value: json value stored under `key`
        parent_name: path name of the parent node, or None at the root
        str_dict: dict string -> id, updated in place
        node_list: list of path names, appended in place
    """
    prefix = '' if parent_name is None else parent_name
    node_name = prefix + '->' + key
    node_list.append(node_name)

    # setdefault assigns len(str_dict) (the next free id) only to unseen strings
    if not key.startswith('LIST'):  # NOTE: 'LIST<i>' keys are synthetic list-element keys
        str_dict.setdefault(key, len(str_dict))

    is_strings = isinstance(value, list) and is_str_list(value)

    if isinstance(value, str):
        str_dict.setdefault(value, len(str_dict))
    elif is_strings:
        for text in value:
            str_dict.setdefault(text, len(str_dict))

    # a list of strings stays a single node linked to its string nodes, so only
    # dicts and other lists have children
    if isinstance(value, dict):
        children = value.items()
    elif isinstance(value, list) and not is_strings:
        # NOTE: using our knowledge, at most 10 list elements are used
        children = (('LIST' + str(idx), item) for idx, item in enumerate(value[:10]))
    else:
        return

    for child_key, child_value in children:
        traverse_flatten(child_key, child_value, node_name, str_dict, node_list)


def traverse_construct(key, value, parent_name, node_name_dict, edge_list, node_array):
    """
    Recursively fill graph data (edges and node features) for a json (sub)tree.

    Path names are built the same way as in traverse_flatten. A node whose path
    name is missing from `node_name_dict` is skipped together with its subtree.
    For every other node this function:
      * adds a tree edge (parent_id, node_id, etype.tree), unless it is the root;
      * adds a key edge (node_id, key_id, etype.key), unless the key is a
        synthetic 'LIST<i>' list-element key (then key_id is -1);
      * adds value edges to the string nodes of a string value or string list;
      * writes the feature row [node_type, key_id, scalar_val] into node_array.
    Dicts and non-string lists are expanded; at most the first 10 list
    elements are visited (the same limit as traverse_flatten).

    input:
        key: key of the current node ('root' for the top-level call)
        value: json value stored under `key`
        parent_name: path name of the parent node, or None at the root
        node_name_dict: dict mapping strings and path names to node ids
        edge_list: list of (src_id, dst_id, edge_type) tuples, appended in place
        node_array: array indexed by node id with 3 columns, written in place
    raises:
        TypeError: if `value` has a type that is not handled
        KeyError: if the key or a string value of a kept node is not in node_name_dict
    """
    # get the current path name
    if parent_name is None:
        node_name = '->' + key
    else:
        node_name = parent_name + '->' + key

    if node_name not in node_name_dict:
        # unknown path: behave as if the node does not exist
        return

    node_id = node_name_dict[node_name]

    # edge from the parent in the json tree
    if parent_name is not None:
        parent_id = node_name_dict[parent_name]
        edge_list.append((parent_id, node_id, etype.tree))

    # edge to the string node of the key; 'LIST<i>' keys have no string node
    if not key.startswith('LIST'):
        key_id = node_name_dict[key]
        edge_list.append((node_id, key_id, etype.key))
    else:
        key_id = -1

    # node type and a scalar summary of the value
    if isinstance(value, dict):
        node_type = ntype.dict
        scalar_val = 0

    elif isinstance(value, list) and is_str_list(value):
        node_type = ntype.strlist
        scalar_val = len(value)
        for s in value:
            edge_list.append((node_id, node_name_dict[s], etype.value))

    elif isinstance(value, list) and (is_vector(value) or is_matrix(value)):
        # numeric vector, or list of equal-length numeric rows
        node_type = ntype.array
        scalar_val = len(value)  # TODO: compute the mean of the array

    elif isinstance(value, list):
        # any other list, e.g. a list of dicts or of mixed types
        node_type = ntype.list
        scalar_val = len(value)

    elif isinstance(value, str):
        # the scalar feature is the id of the string node the value links to
        node_type = ntype.strval
        scalar_val = node_name_dict[value]
        edge_list.append((node_id, scalar_val, etype.value))

    elif isinstance(value, bool):
        # must be tested before int: bool is a subclass of int
        node_type = ntype.binary
        scalar_val = value

    elif isinstance(value, (int, float)):
        node_type = ntype.numerical
        scalar_val = value

    elif value is None:
        node_type = ntype.none
        scalar_val = 0

    else:
        # previously referenced the not-yet-bound loop variable `v`, which masked
        # this error with an UnboundLocalError
        raise TypeError('The value type is not considered yet: %r (%s)'
                        % (value, type(value).__name__))

    # store node features; rows of nodes that are never visited keep the values
    # the caller initialised them with
    node_array[node_id] = [node_type, key_id, scalar_val]

    # recursion: expand children of dicts and non-string lists
    if isinstance(value, dict):
        children = value.items()
    elif isinstance(value, list) and not is_str_list(value):
        # NOTE: use at most 10 elements from the list
        children = (('LIST' + str(i), item) for i, item in enumerate(value[:10]))
    else:
        # no need to expand
        return

    for child_key, child_value in children:
        traverse_construct(child_key, child_value, node_name, node_name_dict, edge_list, node_array)
class JsonToGraph:
    """
    Convert json game states into graph data using a fixed node dictionary.

    The dictionary is read from 'node_name_dict.pkl' in the given directory, the
    file that get_all_strs_nodes writes. It maps strings and json path names
    (keys starting with '->') to node ids.
    """

    def __init__(self, path, game = 'monopoly'):
        """
        input:
            path: the directory containing the node dictionary 'node_name_dict.pkl'
            game: the name of the game, 'monopoly' or 'gridworld'
        """
        # load common path names and strings
        # NOTE: pickle.load can execute arbitrary code; only load dictionaries you created
        dict_file = os.path.join(path, 'node_name_dict.pkl')
        with open(dict_file, 'rb') as infile:
            node_name_dict = pickle.load(infile)

        # path names start with '->'; every other entry is a string node
        num_json_nodes = sum(1 for name in node_name_dict if name.startswith("->"))

        self.node_name_dict = node_name_dict
        self.num_json_nodes = num_json_nodes
        self.game = game
        print('The dictionary has ', len(node_name_dict), ' names, ', num_json_nodes, ' of which are actual json nodes.')

    @classmethod
    def prune_json(cls, json_obj, game):
        """
        Return the part of a raw json state that is turned into a graph.

        monopoly: the 'prev_state' sub-object.
        other games (gridworld): a copy of the state in which
        goal['Distribution'] is set to "Uniformed" and 'step' is set to 0, so
        these fields have the same value in every state.
        The caller's object is never modified.
        """
        if game == 'monopoly':
            return json_obj['prev_state']

        # copy the levels that are modified so the caller's object stays untouched
        pruned = dict(json_obj)
        pruned['goal'] = dict(json_obj['goal'])
        pruned['goal']['Distribution'] = "Uniformed"
        pruned['step'] = 0
        return pruned

    def process(self, json_obj):
        """
        Convert one raw json state into graph data.

        returns:
            node_array: float array with len(node_name_dict) rows and 3 columns;
                row i holds [node_type, key_id, scalar_val] of node i. String
                nodes get type ntype.string. Json nodes absent from this state
                keep type ntype.noexist (0).
            edge_list: list of (src_id, dst_id, edge_type) tuples
        """
        # prune the json object first
        json_obj = self.prune_json(json_obj, self.game)

        # extract a graph from the object
        edge_list = []
        node_array = np.zeros([len(self.node_name_dict), 3])

        # assumes string nodes occupy the first ids, as get_all_strs_nodes assigns them
        num_string_nodes = len(self.node_name_dict) - self.num_json_nodes
        node_array[0:num_string_nodes, 0] = ntype.string

        traverse_construct('root', json_obj, None, self.node_name_dict, edge_list, node_array)

        return node_array, edge_list
        


def get_all_strs_nodes(data_path, json_files, game = 'monopoly'):
    """
    Build the node dictionary shared by all json files, save it, and return it.

    Every file is pruned (JsonToGraph.prune_json) and flattened
    (traverse_flatten). The resulting dictionary maps:
      * every string (keys and string values of all files) to ids 0 .. S-1;
      * every path name seen in any file with at least 10 nodes to ids S and
        up, in order of first appearance (following the order of `json_files`).
    The dictionary is pickled to `data_path`/node_name_dict.pkl, overwriting
    any existing file.

    input:
        data_path: directory containing the json files; the dictionary is written here
        json_files: file names (relative to data_path) to scan
        game: the name of the game, 'monopoly' or 'gridworld'
    returns:
        node_name_dict: dict mapping strings and path names to node ids
    """
    # all files share the same string dict
    str_dict = {}
    # union of path names over all files. A dict is used as an insertion-ordered
    # set: iterating a set of str depends on the per-process hash seed, which
    # would give different node ids on every run.
    union = {}
    node_list = []
    batch_size = 1000

    for i, f in enumerate(json_files):
        if i % batch_size == 0:
            print('processing batch ', i // batch_size, '*', batch_size, '...')

        jfile_path = os.path.join(data_path, f)
        with open(jfile_path) as jfile:
            json_obj = json.load(jfile)

        #TODO: may need some fix
        json_obj = JsonToGraph.prune_json(json_obj, game=game)

        node_list = []
        traverse_flatten('root', json_obj, None, str_dict, node_list)

        # strings of small files were recorded above; only their paths are skipped
        if len(node_list) < 10:
            print('The file contains less than 10 nodes: ', f)
            continue

        union.update(dict.fromkeys(node_list))

    print('Find ', len(str_dict), 'unique strings')
    print('last node list has ', len(node_list), ' node paths')
    print('Find ', len(union), ' node paths')

    # path names get the ids after all strings
    node_name_dict = str_dict
    for node_name in union:
        node_name_dict[node_name] = len(node_name_dict)

    with open(os.path.join(data_path, 'node_name_dict.pkl'), 'wb') as output:  # Overwrites any existing file.
        pickle.dump(node_name_dict, output)

    return node_name_dict


def construct_graphs(data_path,json_files, game='monopoly'):
    """
    Convert json files to graph data grouped by episode, pickle it, and return it.

    The part of each file name before the first '.' must end in
    '<episode>_<step>' (e.g. 'json_msg_3_17.json'). Each episode must have as
    many files as its largest step number. This catches missing steps when
    steps are numbered from 1. Files are opened by their actual names, so any
    prefix or zero padding works.

    The node dictionary is loaded from `data_path` (see JsonToGraph). The
    result is written to `data_path`/processed_data.pkl, overwriting any
    existing file.

    input:
        data_path: directory containing the json files and the node dictionary
        json_files: file names (relative to data_path); the list is not modified
        game: the name of the game, 'monopoly' or 'gridworld'
    returns:
        data: one list per episode (ascending episode number), each holding the
            (node_array, edge_list) pairs returned by JsonToGraph.process in
            ascending step order
    raises:
        ValueError: if an episode's file count differs from its largest step
    """
    # group files by episode, keeping the real file name of every step
    episodes = {}
    for f in json_files:
        epi, step = f.split('.')[0].split('_')[-2:]
        episodes.setdefault(int(epi), []).append((int(step), f))

    # make sure the number of steps is the number of json files
    for epi, steps in episodes.items():
        steps.sort()
        last_step = steps[-1][0]
        if len(steps) != last_step:
            raise ValueError('episode %d has %d json files but its last step is %d'
                             % (epi, len(steps), last_step))

    json2graph = JsonToGraph(data_path, game)

    # collecting data from json files
    data = []
    print('Total ', len(episodes), 'episodes')

    for epi in sorted(episodes):
        print('processing episode ', epi)

        epi_data = []
        for _, f in episodes[epi]:
            print(f)
            with open(os.path.join(data_path, f)) as jfile:
                json_obj = json.load(jfile)

            epi_data.append(json2graph.process(json_obj))

        data.append(epi_data)

    # pickle the data
    with open(os.path.join(data_path, 'processed_data.pkl'), 'wb') as output:  # Overwrites any existing file.
        pickle.dump(data, output)

    return data

if __name__ == "__main__":
    """
    A small test of the function defined above
    """
    
    data_path = '/home/plymper/data/gridworldsData/novel_nonov'
    game = 'gridworld'

    json_files = [f for f in listdir(data_path) if f.endswith('json')]

    # if you want to try a small sample, you need to comment out the assertion around line 379,
    # which is used to check where an episode has consecutive time steps
    #json_files = json_files[:2000]

    print('Total ', len(json_files), ' json files')

    # if not os.path.isfile('node_name_dict.pkl'):
    #get_all_strs_nodes(data_path, json_files, game = game)

    construct_graphs(data_path,json_files,game = game)
