import numpy as np
import torch
import torch_geometric.data  as data
import os
import pickle
import json_graph
from json_graph import JsonToGraph

class GameDataset(data.Dataset):
    """Dataset of stacked game-state graphs.

    Two modes, selected by ``training``:

    * ``training=True``: every step is loaded from ``gridworldtrain.pkl`` and
      samples are served by index through ``__getitem__`` / ``__len__``.
    * ``training=False``: nothing is preloaded; steps are pushed one at a time
      through ``json_to_graph_stack``, which keeps a rolling buffer.
      ``__getitem__`` / ``__len__`` are not usable in this mode because no
      sample index is built.

    In both modes one sample combines the node features of ``concat_steps``
    consecutive steps with the edges of a single graph.
    """

    def __init__(self, data_path, concat_steps, training=False):
        """
        Load the pickled graph converter and, for training, the step data.

        Args:
            data_path: directory containing ``jgraph.pkl`` and, when
                ``training`` is true, ``gridworldtrain.pkl``.
            concat_steps: number of consecutive steps whose node features
                make up one sample.
            training: when true, load the training pickle and build the
                index of valid samples.

        The training pickle is read as a dict with the keys:

        * ``"node_feat"``: array indexed by step along axis 0.
        * ``"graph_lists"``: sequence indexed by step; each entry is an
          attributed edge list (see ``_generate_edge_list``).
        * ``"epi_length"``: sequence with the number of steps per episode.
        """

        self.concat_steps = concat_steps
        self.node_feature_dim = concat_steps
        # NOTE(review): ``__dict__`` may contain more than the edge-type
        # members (e.g. dunder entries), depending on how ``etype`` is
        # defined -- confirm this is the intended one-hot width.
        self.edge_feature_dim = len(json_graph.etype.__dict__)

        # SECURITY: pickle.load executes arbitrary code embedded in the file;
        # only point ``data_path`` at trusted data.
        with open(os.path.join(data_path, 'jgraph.pkl'), 'rb') as file:
            self.jgraph = pickle.load(file)

        self.node_type = self.jgraph.node_type
        self.cat_ranges = self.jgraph.cat_ranges

        # Rolling buffers used by json_to_graph_stack. Always initialised so
        # that method works regardless of the mode.
        self.buffer_node_feat = []
        self.buffer_graphs = []

        if training:

            # Named ``payload`` rather than ``data`` so the module alias
            # ``data`` (torch_geometric.data) is not shadowed.
            with open(os.path.join(data_path, 'gridworldtrain.pkl'), 'rb') as file:
                payload = pickle.load(file)

            self.all_node_array = payload["node_feat"]
            self.all_graphs = payload["graph_lists"]
            epi_sizes = payload["epi_length"]

            # Flag the first (concat_steps - 1) steps of every episode: they
            # do not have enough preceding steps in the same episode to fill
            # a full window of concat_steps steps.
            num_total = self.all_node_array.shape[0]
            is_warmup_step = np.zeros(num_total, dtype=bool)

            episode_start = 0
            for episode_size in epi_sizes:
                is_warmup_step[episode_start : (episode_start + self.concat_steps - 1)] = True
                episode_start = episode_start + episode_size

            # Dataset index i maps to global step real_index[i]; flagged
            # steps are excluded.
            self.real_index = np.flatnonzero(~is_warmup_step)

    def _generate_edge_list(self, attributed_edge_list):
        """
        Split an attributed edge list into an edge index and one-hot types.

        Args:
            attributed_edge_list: sequence of edges, each indexable as
                ``[source, target, edge_type, ...]`` with integer entries.

        Returns:
            ``(edges, onehot)`` where ``edges`` has shape ``(2, num_edges)``
            and ``onehot`` has shape ``(num_edges, num_edge_types)`` with a
            single 1 per row at column ``edge_type``.
        """

        num_edge_types = len(json_graph.etype.__dict__)
        attributed_edges = np.array(attributed_edge_list)

        # A graph without edges yields a 1-D empty array, which cannot be
        # sliced as 2-D below; return correctly shaped empty outputs instead.
        if attributed_edges.size == 0:
            empty_edges = np.zeros((2, 0), dtype=np.int64)
            empty_onehot = np.zeros((0, num_edge_types))
            return empty_edges, empty_onehot

        attributed_edges = attributed_edges.transpose()

        edges = attributed_edges[0:2, :]
        atts = attributed_edges[2, :]

        # convert edge types to one-hot encoding
        onehot = np.zeros((atts.shape[0], num_edge_types))
        onehot[np.arange(atts.shape[0]), atts] = 1

        return edges, onehot

    def _make_graph_data(self, node_feats, attributed_edge_list):
        """Build one ``Data`` sample from stacked node features and one edge list."""

        edges, atts = self._generate_edge_list(attributed_edge_list)

        # edge_attr is float32 on both the training and the streaming path
        # so the two modes hand identical dtypes to downstream code.
        return data.Data(
            x=node_feats,
            edge_index=torch.tensor(edges, dtype=torch.long),
            edge_attr=atts.astype(np.float32),
        )

    def __getitem__(self, index):
        """
        Return the training sample at dataset position ``index``.

        Node features cover the ``concat_steps`` steps ending at the mapped
        global step. Only available when constructed with ``training=True``.
        """

        real_ind = self.real_index[index]

        node_feats = self.all_node_array[(real_ind - self.concat_steps + 1) : (real_ind + 1)]

        # Only one graph is used: the one of the step *before* the newest
        # step in the window (real_ind - 1), matching json_to_graph_stack.
        # NOTE(review): the original comment said "last graph"; confirm the
        # one-step offset is intended. With concat_steps == 1 and
        # real_ind == 0 the index -1 wraps to the final graph.
        return self._make_graph_data(node_feats, self.all_graphs[real_ind - 1])

    def json_to_graph_stack(self, json_obj, episode_begin=False):
        """
        Push one JSON step into the rolling buffer and emit a sample if ready.

        Args:
            json_obj: step description passed to ``self.jgraph.process``; the
                result is read through its ``"node_value"`` and
                ``"edge_list"`` entries.
            episode_begin: when true, clear the buffer first so steps from a
                previous episode are not mixed into this one.

        Returns:
            A ``Data`` sample once at least ``concat_steps`` steps are
            buffered, otherwise ``None``. Here ``x`` is a list of the
            buffered per-step node values (not a stacked array as in
            ``__getitem__``).
        """

        if episode_begin:
            self.buffer_node_feat = []
            self.buffer_graphs = []

        new_graph = self.jgraph.process(json_obj)

        self.buffer_node_feat.append(new_graph["node_value"])
        self.buffer_graphs.append(new_graph["edge_list"])

        if len(self.buffer_node_feat) < self.concat_steps:
            return None

        node_feats = self.buffer_node_feat[-self.concat_steps : ]

        # Graph of the step before the newest one, mirroring __getitem__.
        # NOTE(review): this raises IndexError when concat_steps == 1 and
        # only one step is buffered.
        graph = self.buffer_graphs[-2]

        # Drop the oldest step so the buffer holds concat_steps - 1 entries
        # when the next step arrives.
        self.buffer_node_feat.pop(0)
        self.buffer_graphs.pop(0)

        return self._make_graph_data(node_feats, graph)

    def __len__(self):
        """Return the number of training samples (requires ``training=True``)."""

        return self.real_index.shape[0]

        
if __name__ == "__main__":

    # Smoke test 1: load the training pickle from the current directory and
    # fetch one sample by index.
    train_set = GameDataset(".", concat_steps=3, training=True)

    graph_data = train_set[2]
    print(graph_data.x.shape)

    # Smoke test 2: stream raw JSON steps through the rolling buffer.
    test_set = GameDataset(".", concat_steps=3, training=False)

    import json

    datapath = "/home/liulp/data/gridworldsData/new_train_data2/"

    for epi in range(2):
        episode_begin = True
        for step in range(20):
            # Step files are named "<episode>_<step>.json", both 1-based.
            fname = f"{epi + 1}_{step + 1}.json"

            with open(os.path.join(datapath, fname)) as file:
                json_obj = json.load(file)

            json_obj = test_set.jgraph.prune_json(json_obj, task="gridworld")
            json_obj = test_set.jgraph.mark_json(json_obj, task="gridworld")

            # Pass the flag through so the buffer is reset on the first step
            # of each episode; without it, steps from the previous episode
            # would be stacked with the first steps of this one.
            graph_data = test_set.json_to_graph_stack(json_obj, episode_begin=episode_begin)

            episode_begin = False


