from ast import Raise
import numpy as np
from pandas import concat
import torch.utils.data as data
import torch
import torch_geometric.data as data
import os
import pickle
#from dataset import json_graph
#from dataset.json_graph import JsonToGraph, NewNodeException
import sys
#import json2vec 
import dataset.gridworld_json2vec as gridworld_json2vec
import dataset.gridworld_novelties as gridworld_novelties
import copy 

class GameDataset(data.Dataset):
    """Sliding-window dataset of gridworld game states.

    Each item pairs the ``concat_steps - 1`` steps that precede a target step
    (``x``) with the target step itself (``last_state``). The first
    ``concat_steps - 1`` steps of every episode are never used as targets, so a
    window does not reach back into the previous episode.

    In ``'training'`` / ``'validation'`` mode all steps are loaded from a pickle
    up front. Any other mode initializes an empty, test-time dataset that grows
    as JSON observations arrive through ``receive_json_obj``. With
    ``mode == 'test'`` items additionally go through novelty injection.

    NOTE: at this point ``data`` refers to ``torch_geometric.data``, because
    ``import torch_geometric.data as data`` shadows the earlier
    ``import torch.utils.data as data``.
    """

    def __init__(self, data_path, concat_steps, mode='training', ignore_intermediate_nodes=False,
                 novelty_prob=1.0):
        """
        Args:
            data_path: Path to a pickle file holding a dict.
                In training/validation mode it must provide ``'data'``, an array
                of steps whose last axis is the node axis (presumably shaped
                ``(num_steps, num_nodes)``; verify against the producer). It
                must also provide ``'ep_lengths'``, the list of episode lengths
                in storage order.
                In any other mode it must provide ``'scaler'``, used to
                (de)normalize numerical node entries, and ``'data'``, of which
                only the last-axis size is read.
            concat_steps: Window length. ``x`` holds ``concat_steps - 1`` steps.
            mode: ``'training'`` or ``'validation'`` load the full dataset. Any
                other value initializes test mode. Only the literal ``'test'``
                enables the test-time behavior of ``__getitem__``.
            ignore_intermediate_nodes: Currently unused; kept for API
                compatibility.
            novelty_prob: In ``'test'`` mode, probability that an item is passed
                through novelty injection. The default of 1.0 injects into
                every item.
        """
        self.node_feature_dim = concat_steps - 1
        self.mode = mode
        self.feature_list = []
        self.concat_steps = concat_steps
        self.ignore_intermediate_nodes = ignore_intermediate_nodes
        self.novelty_prob = novelty_prob

        # Index arrays describing which vector entries are categorical,
        # numerical, binary, ...
        info = gridworld_json2vec.data_info
        self.node_info = {
            "cat_nodes": np.array(info['cat_entries']),
            "cat_ranges": np.array(info['cat_ranges']),
            'nume_nodes': np.array(info['nume_entries']),
            'bin_nodes': np.array(info['bin_entries']),
            'random_nodes': info['random_entries'],
        }

        dataset = self._load_pickle(data_path)

        if mode == 'training' or mode == "validation":
            self.all_node_array = dataset['data']
            self.epi_sizes = dataset["ep_lengths"]
            self.num_nodes = self.all_node_array.shape[-1]
            self._create_real_index(self.epi_sizes)
        else:
            # Test mode: the pickle only supplies the scaler and the node count;
            # steps are appended later by receive_json_obj.
            self.scaler = dataset['scaler']
            self.num_nodes = dataset['data'].shape[-1]
            self.all_node_array = None
            self.epi_sizes = [0]
            self.real_index = None

            self.buffer_node_feat = []
            self.buffer_graphs = []
            self.labels = []
            self.node_labels = []

    @staticmethod
    def _load_pickle(path):
        """Return the object stored in the pickle file at *path*."""
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # trusted, locally produced files here.
        with open(path, 'rb') as file:
            return pickle.load(file)

    def receive_json_obj(self, json_obj, new_episode=False):
        '''
        Vectorize one JSON observation, normalize it and append it as a new step.

            json_obj: A parsed JSON object describing one game step.
            new_episode: If True the step starts a new episode; otherwise it
                extends the current (last) episode.
        '''
        x = gridworld_json2vec.process_no_onehot(json_obj)
        x = self._normalize_numerical_nodes(x)

        # NOTE: re-concatenating copies the whole history on every call, which
        # is quadratic over a long run; fine for short online episodes.
        row = np.expand_dims(x, axis=0)
        if self.all_node_array is None:
            self.all_node_array = row
        else:
            self.all_node_array = np.concatenate([self.all_node_array, row], axis=0)

        if new_episode:
            self.epi_sizes.append(1)
        else:
            self.epi_sizes[-1] += 1

        # Rebuild the window index so the new step becomes retrievable.
        self._create_real_index(self.epi_sizes)

    def _normalize_numerical_nodes(self, x):
        '''
        Scale the numerical entries of ``x`` in place with ``self.scaler`` and return ``x``.

        ``x`` is either a single step (1-D, entries indexed directly) or a 2-D
        array whose columns are the node entries.
        '''
        nume = self.node_info['nume_nodes']
        if len(x.shape) == 2:
            x[:, nume] = self.scaler.transform(x[:, nume])
        else:
            # The scaler expects a 2-D (samples, features) input.
            x[nume] = self.scaler.transform(x[nume].reshape(1, -1)).reshape(-1)
        return x

    def _invert_normalize_numerical_nodes(self, x):
        '''
        Undo ``_normalize_numerical_nodes`` in place on ``x`` and return ``x``.

        The restored numerical values are rounded and written back as float32
        torch tensors. ``__getitem__`` calls this with torch tensors.
        '''
        nume = self.node_info['nume_nodes']
        if len(x.shape) == 2:
            restored = self.scaler.inverse_transform(x[:, nume])
            x[:, nume] = torch.tensor(np.round(restored)).float()
        else:
            restored = self.scaler.inverse_transform(x[nume].reshape(1, -1))
            x[nume] = torch.tensor(np.round(restored.reshape(-1))).float()
        return x

    def _create_real_index(self, epi_sizes):
        '''
        Build ``self.real_index``, the global step indices usable as window targets.

        Every step qualifies except the first ``concat_steps - 1`` steps of each
        episode, which lack enough history inside their own episode.

            epi_sizes: List of episode lengths, in storage order.
        '''
        num_total = self.all_node_array.shape[0]

        # Mark the warm-up steps at the beginning of every episode. Slicing clips
        # at num_total. A range that overruns a short episode only covers steps
        # that the following episode marks as warm-up anyway.
        is_warmup = np.zeros(num_total, dtype=bool)
        start = 0
        for size in epi_sizes:
            is_warmup[start:start + self.concat_steps - 1] = True
            start += size

        self.real_index = np.flatnonzero(~is_warmup)

    def __getitem__(self, index):
        '''
        Return the window whose target step is ``self.real_index[index]``.

        Returns a torch_geometric ``Data`` object with these fields:

        - ``x``: the ``concat_steps - 1`` preceding steps, transposed so the
          step axis is last, plus a leading axis of size 1.
        - ``last_state``: the target step, plus a leading axis of size 1.
        - ``edge_index`` / ``edge_attr``: empty.
        - ``labels``: all zeros, shaped like the target step.

        In ``'test'`` mode, returns None while the current episode is shorter
        than ``concat_steps``.
        '''
        if self.mode == 'test' and self.epi_sizes[-1] < self.concat_steps:
            return None

        real_ind = self.real_index[index]

        # Assuming all_node_array is (steps, num_nodes), node_feats is
        # (num_nodes, concat_steps - 1).
        node_feats = self.all_node_array[real_ind - self.concat_steps + 1:real_ind].transpose()
        last_state = self.all_node_array[real_ind]

        graph_data = data.Data(x=torch.tensor(node_feats, dtype=torch.float32).unsqueeze(0),
                               edge_index=torch.tensor([], dtype=torch.long),
                               edge_attr=torch.tensor([], dtype=torch.float32),
                               last_state=torch.tensor(last_state, dtype=torch.float32).unsqueeze(0),
                               labels=torch.tensor(np.zeros_like(last_state)))

        # np.random.rand() is in [0, 1), so the default novelty_prob=1.0 always
        # injects. The draw is kept so the global numpy RNG stream is unchanged.
        if self.mode == 'test' and np.random.rand() < self.novelty_prob:
            graph_data = self._inject_novelty(graph_data)

        return graph_data

    def _inject_novelty(self, graph_data):
        '''
        Run ``gridworld_novelties.inject_novelty`` on ``graph_data`` in raw units.

        Numerical entries of ``x`` and ``last_state`` are de-normalized (and
        rounded) before the injection and re-normalized afterwards.
        '''
        graph_data.x = self._invert_normalize_numerical_nodes(graph_data.x[0].T).T.unsqueeze(0)
        graph_data.last_state = self._invert_normalize_numerical_nodes(graph_data.last_state)
        graph_data = gridworld_novelties.inject_novelty(graph_data)
        graph_data.x = torch.tensor(self._normalize_numerical_nodes(graph_data.x[0].T.numpy())).T.unsqueeze(0)
        graph_data.last_state = torch.tensor(self._normalize_numerical_nodes(graph_data.last_state.numpy()))
        return graph_data

    def __len__(self):
        '''Return the number of retrievable windows (0 before any step has arrived in test mode).'''
        if self.real_index is None:
            return 0
        return self.real_index.shape[0]

        
if __name__ == "__main__":

    from pathlib import Path

    try:
        # PyG >= 2.0 location. torch's stock DataLoader is not usable here:
        # its default collate_fn cannot batch torch_geometric ``Data`` items.
        from torch_geometric.loader import DataLoader
    except ImportError:
        # Older PyG releases exported the loader from torch_geometric.data.
        from torch_geometric.data import DataLoader

    # Smoke test: load the training split and print one mini-batch.
    # A different dataset root can be passed as the first command-line argument.
    DATASET_ROOT = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("/home/plymper/data/polycraftv2")
    train_set = GameDataset(data_path=DATASET_ROOT / "json_normal_train_data.pkl", concat_steps=5, mode="training")

    loader = DataLoader(train_set, batch_size=2)
    batch = next(iter(loader))

    print(batch)


