from ast import Raise
import numpy as np
import torch
import torch_geometric.data as data
import os
import pickle
from dataset import json_graph
from dataset.json_graph import JsonToGraph, NewNodeException

import dataset.monopoly_novelties as monopoly_novelties



class GameDataset(data.Dataset):
    """Sliding-window dataset of game-state node features for torch_geometric.

    Each item pairs the ``concat_steps - 1`` node-feature vectors preceding a
    step (``x``) with the node-feature vector of that step (``last_state``).
    In ``'training'``/``'validation'`` mode every snapshot is loaded up front
    from ``<data_path>/<task>.pkl``; in any other mode (e.g. ``'test'``)
    snapshots are streamed in one JSON object at a time through
    :meth:`receive_json_obj` or :meth:`json_to_graph_stack`.
    """

    def __init__(self, data_path, concat_steps, mode='training',
                 ignore_intermediate_nodes=False, task='gridworld',
                 inject_novelty_type=None):
        """Load the graph schema and, depending on ``mode``, the episode data.

        Args:
            data_path: Directory containing ``jgraph.pkl`` and, for
                training/validation, ``<task>.pkl`` (a dict with keys
                ``node_feat``, ``graph_lists`` and ``epi_length``).
            concat_steps: Window length. Each item uses ``concat_steps - 1``
                history steps plus one target step.
            mode: ``'training'`` computes the normalizer and node info and
                writes them to ``gamedata_info.pkl`` in the current working
                directory. ``'validation'`` reads that file back. Any other
                value (e.g. ``'test'``) also reads it and prepares empty
                buffers that are filled as JSON objects arrive.
            ignore_intermediate_nodes: If true, keep only the numerical,
                categorical and binary node columns, in that order.
            task: ``'gridworld'`` or ``'monopoly'``. Selects the data file
                name and the JSON processing.
            inject_novelty_type: Novelty type injected in monopoly test mode.
                It is rejected for gridworld and required for monopoly test
                mode.

        Raises:
            ValueError: On an invalid ``task``/``mode``/``inject_novelty_type``
                combination.
        """
        if task == 'gridworld' and inject_novelty_type:
            raise ValueError("Cannot inject novelties to gridworld, use preprepared ones")
        if task == 'monopoly' and mode == 'test' and not inject_novelty_type:
            raise ValueError("Need to inject novelty to monopoly.")
        if task == 'monopoly' and mode == 'test' and inject_novelty_type:
            print('injecting novelty type', inject_novelty_type)

        self.inject_novelty_type = inject_novelty_type

        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with open(os.path.join(data_path, 'jgraph.pkl'), 'rb') as file:
            self.jgraph = pickle.load(file)

        self.node_feature_dim = concat_steps - 1
        self.edge_feature_dim = len(json_graph.etype.__dict__)
        self.ignore_intermediate_nodes = ignore_intermediate_nodes
        self.node_info = dict()
        self.node_info["cat_nodes"] = self.jgraph.cat_ranges["nodes"]
        self.node_info["cat_ranges"] = self.jgraph.cat_ranges["ranges"]

        self.node_type = self.jgraph.node_type
        self.num_nodes = self.node_type.shape[0]
        self.task = task
        self.normalizer = dict()

        self.mode = mode
        self.concat_steps = concat_steps
        self.labels = []
        self.node_labels = []
        self.feature_list = []

        if mode == 'training' or mode == 'validation':
            with open(os.path.join(data_path, f'{task}.pkl'), 'rb') as file:
                dataset = pickle.load(file)

            self.all_node_array = dataset["node_feat"]
            self.all_graphs = dataset["graph_lists"]
            epi_sizes = dataset["epi_length"]
            self.epi_sizes = epi_sizes

            # The training split defines the normalizer and the predicted
            # nodes. Every other split reuses them through this file.
            if mode == 'training':
                self._gather_node_info()
                self._make_normalizer_dict()
                with open("gamedata_info.pkl", 'wb') as f:
                    pickle.dump({"normalizer": self.normalizer, "node_info": self.node_info}, f)
            else:
                with open("gamedata_info.pkl", 'rb') as f:
                    tmp = pickle.load(f)
                self.normalizer = tmp['normalizer']
                self.node_info = tmp['node_info']

            self._normalize_numerical_nodes()

            # Make the negative class -1 for masking to work.
            bin_idx = self.node_info['bin_nodes']
            bin_nodes = self.all_node_array[:, bin_idx]
            bin_nodes[bin_nodes == 0] = -1
            self.all_node_array[:, bin_idx] = bin_nodes

            # Indices into the uncompacted node vector. __getitem__ passes them
            # to the novelty injector. convert_dataset overwrites them with a copy.
            self.old_node_info = self.node_info
            if ignore_intermediate_nodes:
                self.all_node_array, self.node_info = self.convert_dataset(self.all_node_array, self.node_info)
                self.num_nodes = self.all_node_array.shape[-1]

            self._create_real_index(epi_sizes)
        else:
            # Test mode: reuse the training statistics, then start empty.
            with open("gamedata_info.pkl", 'rb') as f:
                tmp = pickle.load(f)
            self.normalizer = tmp['normalizer']
            self.node_info = tmp['node_info']
            # Always defined so that __getitem__ works with or without compaction.
            self.old_node_info = self.node_info
            if self.ignore_intermediate_nodes:
                n_nume = len(self.node_info['nume_nodes'])
                n_cat = len(self.node_info['cat_nodes'])
                n_bin = len(self.node_info['bin_nodes'])
                self.node_info = self._compact_node_info(self.node_info['cat_ranges'], n_nume, n_cat, n_bin)
                self.num_nodes = n_nume + n_cat + n_bin

            # Filled in as JSON objects arrive.
            self.all_node_array = None
            self.all_graphs = []
            self.epi_sizes = [0]
            self.real_index = None
            self.buffer_node_feat = []
            self.buffer_graphs = []

    @staticmethod
    def _compact_node_info(cat_ranges, n_nume, n_cat, n_bin):
        """Return node info for a compact ``[numerical | categorical | binary]`` layout."""
        return {
            'nume_nodes': np.arange(0, n_nume),
            'cat_nodes': np.arange(n_nume, n_nume + n_cat),
            'bin_nodes': np.arange(n_nume + n_cat, n_nume + n_cat + n_bin),
            'cat_ranges': cat_ranges,
        }

    def convert_dataset(self, all_node_array, node_info):
        """Keep only numerical, categorical and binary columns, in that order.

        Stores a deep copy of the incoming ``node_info`` in
        ``self.old_node_info``.

        Args:
            all_node_array: 2-D array indexed as ``[step, node]``.
            node_info: Dict with ``nume_nodes``, ``cat_nodes``, ``bin_nodes``
                and ``cat_ranges``.

        Returns:
            ``(compact_array, compact_node_info)``.
        """
        import copy

        numerical = all_node_array[:, node_info['nume_nodes']]
        categorical = all_node_array[:, node_info['cat_nodes']]
        binary = all_node_array[:, node_info['bin_nodes']]
        compact = np.concatenate([numerical, categorical, binary], axis=1)

        self.old_node_info = copy.deepcopy(node_info)
        new_node_info = self._compact_node_info(
            node_info['cat_ranges'], numerical.shape[1], categorical.shape[1], binary.shape[1])
        return compact, new_node_info

    def convert_datapoint(self, x, node_info):
        """Single-step counterpart of :meth:`convert_dataset` (1-D ``x``)."""
        numerical = x[node_info['nume_nodes']]
        categorical = x[node_info['cat_nodes']]
        binary = x[node_info['bin_nodes']]
        return np.concatenate([numerical, categorical, binary], axis=0)

    def receive_json_obj(self, json_obj, new_episode=False):
        """Process one JSON snapshot and append it to the streaming dataset.

        Args:
            json_obj: A parsed JSON object, handed to ``self.jgraph.process``.
            new_episode: If true, the snapshot starts a new episode. Otherwise
                it extends the current one.
        """
        processed = self.jgraph.process(json_obj, self.task)
        x = processed['node_value']

        if self.ignore_intermediate_nodes:
            x = self.convert_datapoint(x, self.old_node_info)

        row = np.expand_dims(x, axis=0)
        # NOTE(review): concatenating copies the whole array on every call,
        # which is quadratic over a long stream.
        if self.all_node_array is None:
            self.all_node_array = row
        else:
            self.all_node_array = np.concatenate([self.all_node_array, row], axis=0)
        self.all_graphs.append(processed['edge_list'])

        if new_episode:
            self.epi_sizes.append(1)
        else:
            self.epi_sizes[-1] += 1

        # Normalize only the new row; earlier rows were normalized on arrival.
        # NOTE(review): if x has an integer dtype, this in-place update
        # truncates. Confirm what jgraph.process returns.
        nume = self.node_info["nume_nodes"]
        self.all_node_array[-1, nume] = (self.all_node_array[-1, nume] - self.normalizer["mean"]) / self.normalizer["std"]
        # NOTE(review): unlike the training path, binary 0 values are not
        # remapped to -1 here. Confirm whether test inputs should match.
        self._create_real_index(self.epi_sizes)

    def _normalize_numerical_nodes(self):
        """Standardize the numerical columns of ``all_node_array`` in place."""
        nume = self.node_info["nume_nodes"]
        self.all_node_array[:, nume] = (self.all_node_array[:, nume] - self.normalizer["mean"][None, :]) / (self.normalizer["std"][None, :])

    def _create_real_index(self, epi_sizes):
        """Build ``self.real_index``, the steps that have a full history window.

        The first ``concat_steps - 1`` steps of every episode lack enough
        preceding steps in the same episode, so they are excluded.

        Args:
            epi_sizes: Episode lengths, in the order the steps are stored.
        """
        num_total = self.all_node_array.shape[0]
        history = self.concat_steps - 1

        is_warmup = np.zeros(num_total, dtype=bool)
        start = 0
        for size in epi_sizes:
            is_warmup[start:start + history] = True
            start += size

        self.real_index = np.flatnonzero(~is_warmup)

    def _make_normalizer_dict(self):
        """Store per-column mean and std (plus 0.001) of the numerical nodes."""
        numerical_feats = self.all_node_array[:, self.node_info["nume_nodes"]]
        self.normalizer["mean"] = np.mean(numerical_feats, axis=0)
        # The offset avoids division by zero for near-constant columns.
        self.normalizer["std"] = np.std(numerical_feats, axis=0) + 0.001

    def _gather_node_info(self):
        """Record which numerical and binary nodes vary (std > 1e-6) over the data."""
        var_flag = np.std(self.all_node_array, axis=0) > 1e-6
        numerical_flag = np.logical_and(self.node_type == json_graph.ntype.numerical, var_flag)
        self.node_info["nume_nodes"] = np.where(numerical_flag)[0]

        binary_flag = np.logical_and(self.node_type == json_graph.ntype.binary, var_flag)
        self.node_info["bin_nodes"] = np.where(binary_flag)[0]

    def _generate_edge_list(self, attributed_edge_list):
        """Split ``[src, dst, type]`` rows into an edge index and one-hot attributes.

        Returns:
            ``(edges, onehot)``. ``edges`` has shape ``(2, E)`` and ``onehot``
            has shape ``(E, number of edge types)``. An empty input yields
            ``E == 0``.
        """
        num_etypes = len(json_graph.etype.__dict__)
        if len(attributed_edge_list) == 0:
            return np.zeros((2, 0), dtype=np.int64), np.zeros((0, num_etypes))

        attributed_edges = np.array(attributed_edge_list).transpose()
        edges = attributed_edges[0:2, :]
        # The type column must be integral to serve as a fancy index.
        atts = attributed_edges[2, :].astype(np.int64)

        onehot = np.zeros((atts.shape[0], num_etypes))
        onehot[np.arange(atts.shape[0]), atts] = 1
        return edges, onehot

    def __getitem__(self, index):
        """Return the window ending at the ``index``-th valid step.

        Returns ``None`` in test mode while the current episode is shorter
        than ``concat_steps``. Otherwise returns a ``Data`` with:

        * ``x``: shape ``(1, num_nodes, concat_steps - 1)``.
        * ``last_state``: shape ``(1, num_nodes)``.
        * Empty edges (graph edges are currently disabled).

        Side effect: appends to ``self.labels`` and ``self.node_labels``. In
        monopoly test mode, a novelty is injected into ``last_state`` with
        probability 0.5.
        """
        if self.epi_sizes[-1] < self.concat_steps and self.mode == 'test':
            return None

        real_ind = self.real_index[index]
        node_feats = self.all_node_array[(real_ind - self.concat_steps + 1):real_ind].transpose()
        last_state = self.all_node_array[real_ind]

        if self.mode == 'test' and self.task == 'monopoly' and np.random.rand() < 0.5:  # novelty condition
            last_state, label, node_labels = monopoly_novelties.inject_novelty(
                node_feats, last_state, task=self.task, novelty_type=self.inject_novelty_type,
                all_node_array=self.all_node_array, normalizer=self.normalizer,
                node_info=self.node_info, old_node_info=self.old_node_info)
            self.labels.append(label)
            self.node_labels.append(node_labels)
        else:
            self.labels.append(0)
            self.node_labels.append(np.zeros(last_state.shape))

        # Graph edges are intentionally disabled; previously this used
        # self._generate_edge_list(self.all_graphs[real_ind - 1]).
        edges, atts = [], []

        return data.Data(x=torch.tensor(node_feats, dtype=torch.float32).unsqueeze(0),
                         edge_index=torch.tensor(edges, dtype=torch.long),
                         edge_attr=torch.tensor(atts, dtype=torch.float32),
                         last_state=torch.tensor(last_state, dtype=torch.float32).unsqueeze(0))

    def json_to_graph_stack(self, json_obj, episode_begin=False):
        """Push one JSON snapshot into a sliding buffer and emit a window when ready.

        Args:
            json_obj: A parsed JSON object, handed to ``self.jgraph.process``.
            episode_begin: If true, clear the buffers before pushing.

        Returns:
            ``None`` until ``concat_steps - 1`` earlier steps of the episode
            are buffered. After that, returns a ``Data`` with:

            * ``x``: the buffered node features, shape ``(concat_steps - 1, num_nodes)``.
            * The previous step's edges.
            * ``last_state``: the new node features.
        """
        if episode_begin:
            self.buffer_node_feat = []
            self.buffer_graphs = []

        new_graph = self.jgraph.process(json_obj, self.task)
        node_value = new_graph["node_value"]
        # node_info indices refer to the compact layout when intermediate
        # nodes are dropped, so compact before normalizing.
        if self.ignore_intermediate_nodes:
            node_value = self.convert_datapoint(node_value, self.old_node_info)
        nume = self.node_info["nume_nodes"]
        node_value[nume] = (node_value[nume] - self.normalizer["mean"]) / self.normalizer["std"]

        history = self.concat_steps - 1
        graph_data = None
        if history > 0 and len(self.buffer_node_feat) >= history:
            node_feats = np.stack(self.buffer_node_feat[-history:])
            # Only the most recent buffered graph is used.
            edges, atts = self._generate_edge_list(self.buffer_graphs[-1])
            graph_data = data.Data(x=torch.tensor(node_feats, dtype=torch.float32),
                                   edge_index=torch.tensor(edges, dtype=torch.long),
                                   edge_attr=torch.tensor(atts, dtype=torch.float32),
                                   last_state=torch.tensor(node_value, dtype=torch.float32))

        # Always record the new step, then trim to the window the next call needs.
        self.buffer_node_feat.append(node_value)
        self.buffer_graphs.append(new_graph["edge_list"])
        while len(self.buffer_node_feat) > max(history, 0):
            self.buffer_node_feat.pop(0)
            self.buffer_graphs.pop(0)

        return graph_data

    def __len__(self):
        """Number of valid windows; 0 before any test-mode data has arrived."""
        if self.real_index is None:
            return 0
        return self.real_index.shape[0]

    def unnormalize(self, node_values):
        """Undo numerical-node standardization and round to ints.

        The input is not modified. ``node_values`` is indexed by node along
        axis 0. Any trailing axes (e.g. time) are broadcast over.
        """
        values = np.array(node_values, dtype=float)
        nume = self.node_info["nume_nodes"]
        bshape = (-1,) + (1,) * (values.ndim - 1)
        values[nume] = values[nume] * self.normalizer["std"].reshape(bshape) + self.normalizer["mean"].reshape(bshape)
        return np.array(np.round(values), int)
        
if __name__ == "__main__":
    # Manual smoke test: build a training set from the CWD, then stream JSON
    # snapshots through a test-mode set.
    import json

    from torch_geometric.loader import DataLoader

    # Training mode writes gamedata_info.pkl to the current working directory.
    # GameDataset has no `training` keyword; the split is chosen with `mode`.
    train_set = GameDataset(data_path=".", concat_steps=5, mode='training')

    loader = DataLoader(train_set, batch_size=2)
    batch = next(iter(loader))
    print(batch)

    # Test mode reuses gamedata_info.pkl written above.
    test_set = GameDataset(".", concat_steps=3, mode='test')

    datapath = os.environ.get("GRIDWORLD_JSON_DIR", "/home/liulp/data/gridworldsData/new_train_data2/")

    for epi in range(2):
        episode_begin = True
        for step in range(20):
            fname = f"{epi + 1}_{step + 1}.json"
            with open(os.path.join(datapath, fname)) as file:
                json_obj = json.load(file)

            json_obj = test_set.jgraph.prune_json(json_obj, task="gridworld")
            json_obj = test_set.jgraph.mark_json(json_obj, task="gridworld")

            # Reset the sliding buffer at the first step of each episode.
            graph_data = test_set.json_to_graph_stack(json_obj, episode_begin=episode_begin)
            episode_begin = False
