import numpy as np
import torch
import torch_geometric.data as data
import torch_geometric as pyg
import os
import pickle
from dataset import json_graph
from dataset.json_graph import JsonToGraph, NewNodeException
import pandas as pd
import networkx as nx
from sklearn.preprocessing import MinMaxScaler

def make_sensor_datasets(name,**kwargs):
    '''
    Build the (train, validation, test) datasets for the sensor benchmark ``name``.

    Args:
        name: one of 'wadi', 'swat', 'gecco', 'swan_sf', 'smap', 'msl'.
        **kwargs: forwarded unchanged to the dataset-specific factory
            (e.g. ``concat_steps``, ``train_val_split``, ``subsample``).

    Returns:
        The ``(train_set, val_set, test_set)`` tuple produced by the
        dataset-specific factory.

    Raises:
        ValueError: if ``name`` is not one of the known datasets. (Previously
            an unknown name silently returned ``None``, which only failed
            later when the caller tried to unpack the result.)
    '''
    if name == 'wadi':
        return make_wadi_datasets(**kwargs)
    elif name == 'swat':
        return make_swat_datasets(**kwargs)
    elif name == 'gecco':
        return make_gecco_datasets(**kwargs)
    elif name == 'swan_sf':
        return make_swan_sf_datasets(**kwargs)
    elif name == 'smap':
        return make_smap_datasets(**kwargs)
    elif name == 'msl':
        return make_msl_datasets(**kwargs)
    raise ValueError(
        f"Unknown sensor dataset {name!r}; expected one of: "
        "wadi, swat, gecco, swan_sf, smap, msl"
    )

def subsample(skip_t,traindata,trainlabels,testdata,testlabels):
    '''
    Keep every ``skip_t``-th row of the train and test arrays.

    Labels are reduced window-wise with ``get_spaced_labels``: the label of a
    kept row is 1 when any label in its ``skip_t``-long window is truthy.

    Returns:
        (traindata, trainlabels, testdata, testlabels) after thinning; the
        label outputs are numpy arrays.
    '''
    def _thin(rows, row_labels):
        # indices of the first row of each skip_t-long window
        keep = np.arange(0, rows.shape[0], skip_t)
        thinned_labels = np.array(get_spaced_labels(row_labels, skip_t))
        return rows[keep], thinned_labels

    traindata, trainlabels = _thin(traindata, trainlabels)
    testdata, testlabels = _thin(testdata, testlabels)
    return traindata, trainlabels, testdata, testlabels

def get_spaced_labels(labels,spacing=10):
    '''
    Collapse ``labels`` into one flag per consecutive window of ``spacing`` items.

    Returns:
        list of int: 1 for a window that contains at least one truthy label,
        otherwise 0. The final window may be shorter than ``spacing``.
    '''
    window_starts = range(0, len(labels), spacing)
    return [int(any(labels[start:start + spacing])) for start in window_starts]

def make_wadi_datasets(concat_steps=5,train_val_split =0.1, **kwargs):
    '''
    Create train, validation and test dataset objects for WADI data.

    Reads ``../dataset/wadi/{train,test}.csv`` (relative to the working
    directory), drops the ``attack`` label column and the ``Unnamed: 0``
    index column, keeps every ``subsample``-th row, then carves the last
    ``train_val_split`` fraction of the training rows off as validation.

    Args:
        concat_steps: window length passed to ``SensorDataset``.
        train_val_split: fraction of (subsampled) training rows used for
            validation.
        **kwargs: ``subsample`` (int, default 1) -- keep every n-th row;
            labels are reduced window-wise by ``subsample``.

    Returns:
        (train_set, val_set, test_set) ``SensorDataset`` objects. Train and
        validation labels are replaced by zeros; test labels come from the
        ``attack`` column.
    '''
    traindata = pd.read_csv("../dataset/wadi/train.csv")
    testdata = pd.read_csv("../dataset/wadi/test.csv")

    # The feature list was read from "../dataset/swat/wadi/list.txt", which
    # looks like a copy-paste slip from the SWaT loader (every other WADI file
    # lives in ../dataset/wadi/). Prefer the WADI directory and fall back to
    # the original path so existing checkouts keep working.
    list_path = "../dataset/wadi/list.txt"
    if not os.path.exists(list_path):
        list_path = "../dataset/swat/wadi/list.txt"
    with open(list_path, 'r') as f:
        feature_list = f.readlines()

    trainlabels = traindata['attack'].to_numpy()
    testlabels = testdata['attack'].to_numpy()
    dropped = ["attack", 'Unnamed: 0']  # label column and saved index column
    traindata = traindata.drop(dropped, axis=1).to_numpy()
    testdata = testdata.drop(dropped, axis=1).to_numpy()

    # 'subsample' used to be mandatory (KeyError when omitted); 1 keeps every row
    spacing = kwargs.get('subsample', 1)
    traindata, trainlabels, testdata, testlabels = subsample(spacing, traindata, trainlabels, testdata, testlabels)

    # every feature starts as numerical; columns flagged by the detector become binary
    feature_types = np.zeros(traindata.shape[-1]) + json_graph.ntype.numerical
    bin_flag = find_binary_vars_from_data(traindata)
    feature_types[bin_flag] = json_graph.ntype.binary

    # the last train_val_split fraction of the training rows becomes validation
    val_split_point = int(traindata.shape[0] * (1 - train_val_split))
    valdata = traindata[val_split_point:]
    traindata = traindata[:val_split_point]
    vallabels = np.zeros(valdata.shape[0])
    trainlabels = np.zeros(traindata.shape[0])

    train_set = SensorDataset(traindata, trainlabels, feature_types, concat_steps=concat_steps, feature_list=feature_list)
    # training-set statistics are handed to the val/test sets via data_info
    data_info = {"normalizer": train_set.normalizer, "node_info": train_set.node_info}
    val_set = SensorDataset(valdata, vallabels, feature_types, concat_steps=concat_steps, data_info=data_info, feature_list=feature_list)
    test_set = SensorDataset(testdata, testlabels, feature_types, concat_steps=concat_steps, data_info=data_info, feature_list=feature_list)

    return train_set, val_set, test_set


def make_swat_datasets(concat_steps=5,train_val_split =0.1,**kwargs):
    '''
    Create train, validation and test dataset objects for SWaT data.

    Reads ``../dataset/swat/{train,test}.csv`` and the feature list
    ``../dataset/swat/list.txt`` (relative to the working directory), drops
    the ``attack`` label column and the ``Unnamed: 0`` index column, keeps
    every ``subsample``-th row, then carves the last ``train_val_split``
    fraction of the training rows off as validation.

    Args:
        concat_steps: window length passed to ``SensorDataset``.
        train_val_split: fraction of (subsampled) training rows used for
            validation.
        **kwargs: ``subsample`` (int, default 1) -- keep every n-th row;
            labels are reduced window-wise by ``subsample``.

    Returns:
        (train_set, val_set, test_set) ``SensorDataset`` objects. Train and
        validation labels are replaced by zeros; test labels come from the
        ``attack`` column.
    '''
    traindata = pd.read_csv("../dataset/swat/train.csv")
    testdata = pd.read_csv("../dataset/swat/test.csv")
    with open("../dataset/swat/list.txt", 'r') as f:
        feature_list = f.readlines()
    # 'subsample' used to be mandatory (KeyError when omitted); 1 keeps every row
    spacing = kwargs.get('subsample', 1)

    trainlabels = traindata['attack'].to_numpy()
    testlabels = testdata['attack'].to_numpy()
    dropped = ["attack", 'Unnamed: 0']  # label column and saved index column
    traindata = traindata.drop(dropped, axis=1).to_numpy()
    testdata = testdata.drop(dropped, axis=1).to_numpy()

    traindata, trainlabels, testdata, testlabels = subsample(spacing, traindata, trainlabels, testdata, testlabels)

    # every feature starts as numerical; columns flagged by the detector become binary
    feature_types = np.zeros(traindata.shape[-1]) + json_graph.ntype.numerical
    bin_flag = find_binary_vars_from_data(traindata)
    feature_types[bin_flag] = json_graph.ntype.binary

    # the last train_val_split fraction of the training rows becomes validation
    val_split_point = int(traindata.shape[0] * (1 - train_val_split))
    valdata = traindata[val_split_point:]
    traindata = traindata[:val_split_point]
    vallabels = np.zeros(valdata.shape[0])
    trainlabels = np.zeros(traindata.shape[0])

    train_set = SensorDataset(traindata, trainlabels, feature_types, concat_steps=concat_steps, feature_list=feature_list)
    # training-set statistics are handed to the val/test sets via data_info
    data_info = {"normalizer": train_set.normalizer, "node_info": train_set.node_info}
    val_set = SensorDataset(valdata, vallabels, feature_types, concat_steps=concat_steps, data_info=data_info, feature_list=feature_list)
    test_set = SensorDataset(testdata, testlabels, feature_types, concat_steps=concat_steps, data_info=data_info, feature_list=feature_list)

    return train_set, val_set, test_set

def make_yahoo_datasets(concat_steps=5,train_val_split =0.1, **kwargs):
    '''
    Create train, validation and test dataset objects for Yahoo data.

    The previous implementation was an unfinished stub. It read labels from
    the ``ground_truth`` column but tried to drop ``attack`` (the SWaT label
    name) instead, and it always returned ``None``. It now follows the same
    pipeline as ``make_swat_datasets``.

    NOTE(review): the column layout of the Yahoo CSVs is assumed. Every
    column except ``ground_truth`` / ``attack`` / ``Unnamed: 0`` is treated
    as a sensor feature -- confirm against the files (e.g. a timestamp
    column would need dropping too).

    Args:
        concat_steps: window length passed to ``SensorDataset``.
        train_val_split: fraction of (subsampled) training rows used for
            validation.
        **kwargs: ``subsample`` (int, default 1) -- keep every n-th row.

    Returns:
        (train_set, val_set, test_set) ``SensorDataset`` objects. Train and
        validation labels are replaced by zeros; test labels come from
        ``ground_truth``.
    '''
    traindata = pd.read_csv("/home/plymper/data/sensor/yahoo/train.csv")
    testdata = pd.read_csv("/home/plymper/data/sensor/yahoo/test.csv")

    trainlabels = traindata['ground_truth'].to_numpy()
    testlabels = testdata['ground_truth'].to_numpy()

    def _feature_matrix(frame):
        # drop the label column plus any leftover label/index columns present
        dropped = [c for c in ('ground_truth', 'attack', 'Unnamed: 0') if c in frame.columns]
        return frame.drop(dropped, axis=1).to_numpy()

    traindata = _feature_matrix(traindata)
    testdata = _feature_matrix(testdata)

    spacing = kwargs.get('subsample', 1)
    traindata, trainlabels, testdata, testlabels = subsample(spacing, traindata, trainlabels, testdata, testlabels)

    feature_types = np.zeros(traindata.shape[-1]) + json_graph.ntype.numerical
    bin_flag = find_binary_vars_from_data(traindata)
    feature_types[bin_flag] = json_graph.ntype.binary

    val_split_point = int(traindata.shape[0] * (1 - train_val_split))
    valdata = traindata[val_split_point:]
    traindata = traindata[:val_split_point]
    vallabels = np.zeros(valdata.shape[0])
    trainlabels = np.zeros(traindata.shape[0])

    train_set = SensorDataset(traindata, trainlabels, feature_types, concat_steps=concat_steps)
    data_info = {"normalizer": train_set.normalizer, "node_info": train_set.node_info}
    val_set = SensorDataset(valdata, vallabels, feature_types, concat_steps=concat_steps, data_info=data_info)
    test_set = SensorDataset(testdata, testlabels, feature_types, concat_steps=concat_steps, data_info=data_info)

    return train_set, val_set, test_set


def make_gecco_datasets(concat_steps=5, train_test_split = 0.5,train_val_split =0.1,**kwargs):
    '''
    Create train, validation and test dataset objects for the GECCO water-quality data.

    The single CSV is split by row order. The first ``1 - train_test_split``
    fraction of rows is the training pool and the rest is the test set.
    Rows with ``label != 0`` are removed from the training pool. Both sets
    are then subsampled, and the last ``train_val_split`` fraction of the
    remaining training rows becomes the validation set.

    Args:
        concat_steps: window length passed to ``SensorDataset``.
        train_test_split: fraction of rows used as the test set.
        train_val_split: fraction of (filtered, subsampled) training rows
            used for validation.
        **kwargs: ``subsample`` (int, default 1) -- keep every n-th row.

    Returns:
        (train_set, val_set, test_set) ``SensorDataset`` objects. Train and
        validation labels are zeros; test labels come from ``label``.
    '''
    # (a leftover debug print("here") to stderr was removed)
    # named `frame`/`values` so the torch_geometric `data` module is not shadowed
    frame = pd.read_csv("/home/plymper/data/gecco/water_quality.csv")
    labels = frame['label'].to_numpy()
    values = frame.drop(["label"], axis=1).to_numpy()

    test_split_point = int(values.shape[0] * (1 - train_test_split))
    traindata, testdata = values[:test_split_point], values[test_split_point:]
    trainlabels, testlabels = labels[:test_split_point], labels[test_split_point:]

    # Keep only normal rows for training. Filter the labels with the same mask
    # so data and labels stay the same length inside subsample() (previously
    # only the data was filtered).
    normal = trainlabels == 0
    traindata, trainlabels = traindata[normal], trainlabels[normal]

    spacing = kwargs.get('subsample', 1)
    traindata, trainlabels, testdata, testlabels = subsample(spacing, traindata, trainlabels, testdata, testlabels)

    val_split_point = int(traindata.shape[0] * (1 - train_val_split))
    valdata = traindata[val_split_point:]
    traindata = traindata[:val_split_point]
    trainlabels = np.zeros(traindata.shape[0])
    vallabels = np.zeros(valdata.shape[0])

    feature_types = np.zeros(traindata.shape[-1]) + json_graph.ntype.numerical
    bin_flag = find_binary_vars_from_data(traindata)
    feature_types[bin_flag] = json_graph.ntype.binary

    train_set = SensorDataset(traindata, trainlabels, feature_types, concat_steps=concat_steps)
    data_info = {"normalizer": train_set.normalizer, "node_info": train_set.node_info}
    val_set = SensorDataset(valdata, vallabels, feature_types, concat_steps=concat_steps, data_info=data_info)
    test_set = SensorDataset(testdata, testlabels, feature_types, concat_steps=concat_steps, data_info=data_info)

    return train_set, val_set, test_set

def make_swan_sf_datasets(concat_steps=5, train_test_split = 0.5,train_val_split =0.1, **kwargs):
    '''
    Create train, validation and test dataset objects for SWAN-SF data.

    The single CSV is split by row order. The first ``1 - train_test_split``
    fraction of rows is the training pool and the rest is the test set.
    Rows with ``label != 0`` are removed from the training pool, and the last
    ``train_val_split`` fraction of the remaining training rows becomes the
    validation set.

    Args:
        concat_steps: window length passed to ``SensorDataset``.
        train_test_split: fraction of rows used as the test set.
        train_val_split: fraction of training rows used for validation.
        **kwargs: optional ``subsample`` (int, default 1). When it is not 1,
            every n-th row is kept, as in the other loaders. Accepting
            ``**kwargs`` also stops ``make_sensor_datasets`` from raising
            ``TypeError`` when it forwards options such as ``subsample``.

    Returns:
        (train_set, val_set, test_set) ``SensorDataset`` objects. Train and
        validation labels are zeros; test labels come from ``label``.
    '''
    # named `frame`/`values` so the torch_geometric `data` module is not shadowed
    frame = pd.read_csv("/home/plymper/data/swan_sf/swan_sf.csv")
    labels = frame['label'].to_numpy()
    values = frame.drop(["label"], axis=1).to_numpy()

    test_split_point = int(values.shape[0] * (1 - train_test_split))
    traindata, testdata = values[:test_split_point], values[test_split_point:]
    trainlabels, testlabels = labels[:test_split_point], labels[test_split_point:]

    # keep only normal rows for training; labels filtered with the same mask
    normal = trainlabels == 0
    traindata, trainlabels = traindata[normal], trainlabels[normal]

    spacing = kwargs.get('subsample', 1)
    if spacing != 1:
        traindata, trainlabels, testdata, testlabels = subsample(spacing, traindata, trainlabels, testdata, testlabels)

    val_split_point = int(traindata.shape[0] * (1 - train_val_split))
    valdata = traindata[val_split_point:]
    traindata = traindata[:val_split_point]
    trainlabels = np.zeros(traindata.shape[0])
    vallabels = np.zeros(valdata.shape[0])

    feature_types = np.zeros(traindata.shape[-1]) + json_graph.ntype.numerical
    bin_flag = find_binary_vars_from_data(traindata)
    feature_types[bin_flag] = json_graph.ntype.binary

    train_set = SensorDataset(traindata, trainlabels, feature_types, concat_steps=concat_steps)
    data_info = {"normalizer": train_set.normalizer, "node_info": train_set.node_info}
    val_set = SensorDataset(valdata, vallabels, feature_types, concat_steps=concat_steps, data_info=data_info)
    test_set = SensorDataset(testdata, testlabels, feature_types, concat_steps=concat_steps, data_info=data_info)

    return train_set, val_set, test_set

def make_smap_datasets(concat_steps=5, train_val_split =0.1, **kwargs):
    '''
    Create train, validation and test dataset objects for SMAP data.

    Loads the pickled SMAP train array, test array and test labels from
    ``/home/plymper/data/sensor/nasa/processed``. The last
    ``train_val_split`` fraction of the training rows becomes the validation
    set.

    Args:
        concat_steps: window length passed to ``SensorDataset``.
        train_val_split: fraction of training rows used for validation.
        **kwargs: optional ``subsample`` (int, default 1). When it is not 1,
            every n-th row is kept, as in the other loaders. Accepting
            ``**kwargs`` also stops ``make_sensor_datasets`` from raising
            ``TypeError`` when it forwards options such as ``subsample``.

    Returns:
        (train_set, val_set, test_set) ``SensorDataset`` objects. Train and
        validation labels are zeros; test labels come from the label pickle.
    '''
    import pickle as pkl

    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open("/home/plymper/data/sensor/nasa/processed/SMAP_train.pkl", 'rb') as f:
        traindata = pkl.load(f)
    with open("/home/plymper/data/sensor/nasa/processed/SMAP_test.pkl", 'rb') as f:
        testdata = pkl.load(f)
    with open("/home/plymper/data/sensor/nasa/processed/SMAP_test_label.pkl", 'rb') as f:
        testlabels = pkl.load(f)
    trainlabels = np.zeros(traindata.shape[0])

    spacing = kwargs.get('subsample', 1)
    if spacing != 1:
        traindata, trainlabels, testdata, testlabels = subsample(spacing, traindata, trainlabels, testdata, testlabels)

    feature_types = np.zeros(traindata.shape[-1]) + json_graph.ntype.numerical
    bin_flag = find_binary_vars_from_data(traindata)
    feature_types[bin_flag] = json_graph.ntype.binary

    val_split_point = int(traindata.shape[0] * (1 - train_val_split))
    valdata = traindata[val_split_point:]
    traindata = traindata[:val_split_point]
    vallabels = np.zeros(valdata.shape[0])
    trainlabels = np.zeros(traindata.shape[0])

    train_set = SensorDataset(traindata, trainlabels, feature_types, concat_steps=concat_steps)
    data_info = {"normalizer": train_set.normalizer, "node_info": train_set.node_info}
    val_set = SensorDataset(valdata, vallabels, feature_types, concat_steps=concat_steps, data_info=data_info)
    test_set = SensorDataset(testdata, testlabels, feature_types, concat_steps=concat_steps, data_info=data_info)

    return train_set, val_set, test_set

def make_msl_datasets(concat_steps=5, train_val_split =0.1, **kwargs):
    '''
    Create train, validation and test dataset objects for MSL data.

    Loads the pickled MSL train array, test array and test labels from
    ``/home/plymper/mtad-gat-pytorch/datasets/data/processed``. The last
    ``train_val_split`` fraction of the training rows becomes the validation
    set.

    Args:
        concat_steps: window length passed to ``SensorDataset``.
        train_val_split: fraction of training rows used for validation.
        **kwargs: optional ``subsample`` (int, default 1). When it is not 1,
            every n-th row is kept, as in the other loaders. Accepting
            ``**kwargs`` also stops ``make_sensor_datasets`` from raising
            ``TypeError`` when it forwards options such as ``subsample``.

    Returns:
        (train_set, val_set, test_set) ``SensorDataset`` objects. Train and
        validation labels are zeros; test labels come from the label pickle.
    '''
    import pickle as pkl

    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open("/home/plymper/mtad-gat-pytorch/datasets/data/processed/MSL_train.pkl", 'rb') as f:
        traindata = pkl.load(f)
    with open("/home/plymper/mtad-gat-pytorch/datasets/data/processed/MSL_test.pkl", 'rb') as f:
        testdata = pkl.load(f)
    with open("/home/plymper/mtad-gat-pytorch/datasets/data/processed/MSL_test_label.pkl", 'rb') as f:
        testlabels = pkl.load(f)
    trainlabels = np.zeros(traindata.shape[0])

    spacing = kwargs.get('subsample', 1)
    if spacing != 1:
        traindata, trainlabels, testdata, testlabels = subsample(spacing, traindata, trainlabels, testdata, testlabels)

    feature_types = np.zeros(traindata.shape[-1]) + json_graph.ntype.numerical
    bin_flag = find_binary_vars_from_data(traindata)
    feature_types[bin_flag] = json_graph.ntype.binary

    val_split_point = int(traindata.shape[0] * (1 - train_val_split))
    valdata = traindata[val_split_point:]
    traindata = traindata[:val_split_point]
    vallabels = np.zeros(valdata.shape[0])
    trainlabels = np.zeros(traindata.shape[0])

    train_set = SensorDataset(traindata, trainlabels, feature_types, concat_steps=concat_steps)
    data_info = {"normalizer": train_set.normalizer, "node_info": train_set.node_info}
    val_set = SensorDataset(valdata, vallabels, feature_types, concat_steps=concat_steps, data_info=data_info)
    test_set = SensorDataset(testdata, testlabels, feature_types, concat_steps=concat_steps, data_info=data_info)

    return train_set, val_set, test_set

def find_binary_vars_from_data(x, detect=False):
    '''
    Flag which columns of ``x`` should be treated as binary features.

    Binary detection is disabled by default. With ``detect=False`` (the
    historical behaviour) every column is reported as non-binary. The old
    loop still built per-column ``== 0`` / ``== 1`` masks on every call and
    then discarded them; that wasted pass over the data is gone.

    Args:
        x: 2-D array of shape (num_rows, num_features).
        detect: if True, flag a column as binary when every value in it is
            exactly 0 or 1 (the check that had been commented out).

    Returns:
        np.ndarray of bool with one entry per column of ``x``.
    '''
    x = np.asarray(x)
    if not detect:
        return np.zeros(x.shape[1], dtype=bool)
    # a column is binary when each of its values is 0 or 1
    return np.all((x == 0) | (x == 1), axis=0)

class SensorDataset(data.Dataset):
    '''
    Sliding-window dataset over a multivariate sensor time series.

    Used by every factory in this module (WADI, SWaT, Gecco, Swan_sf, SMAP,
    MSL). Each row of ``data`` is one time step and each column one sensor
    ("node"). Item ``i`` holds the ``concat_steps - 1`` steps that precede
    step ``real_index[i]`` as node features (``x``, one row per node) and the
    step itself as ``last_state``.
    '''

    def __init__(self, data,labels,feature_types, concat_steps, data_info = None, feature_list= None):
        """
        Args:
            data: array that is reshaped to (num_steps, num_nodes).
            labels: per-step labels; the first ``concat_steps - 1`` entries
                are dropped so ``self.labels`` lines up with ``self.real_index``.
            feature_types: 1-D array with one type code per node.
                ``_gather_node_info`` treats code 3 as numerical and 2 as
                binary (presumably ``json_graph.ntype.numerical`` /
                ``.binary`` -- confirm).
            concat_steps: window length (history steps + the predicted step).
            data_info: optional dict built from a training set. When it
                contains ``"node_info"`` and ``"normalizer"`` (or the legacy
                key ``"normalizer_dict"``), both are reused, so validation and
                test data are normalized with the training statistics. The old
                reuse check looked for ``nume_nodes``/``cat_nodes``/``bin_nodes``
                keys that the factories never supply, so every split silently
                normalized with its own statistics.
            feature_list: stored unchanged on ``self.feature_list``.
        """
        self.node_feature_dim = concat_steps - 1

        self.node_info = dict()
        self.node_info["cat_nodes"] = np.array([])
        self.node_info["cat_ranges"] = np.array([])

        self.node_type = feature_types
        self.num_nodes = feature_types.shape[0]
        self.labels = labels[concat_steps-1:]
        self.feature_list = feature_list

        # Reshape into a (num_steps, num_nodes) array. Normalization below
        # writes in place, so an integer array would silently truncate the
        # normalized values; promote non-floating input to float64 first.
        node_array = np.asarray(data)
        if not np.issubdtype(node_array.dtype, np.floating):
            node_array = node_array.astype(np.float64)
        self.all_node_array = node_array.reshape((-1, self.num_nodes))

        self.normalizer = dict()
        self.concat_steps = concat_steps

        normalizer_key = None
        if data_info is not None:
            normalizer_key = next((k for k in ("normalizer", "normalizer_dict") if k in data_info), None)
        if data_info is not None and "node_info" in data_info and normalizer_key is not None:
            # reuse the training set's node layout and normalization statistics
            self.node_info = data_info["node_info"]
            self.normalizer = data_info[normalizer_key]
        else:
            self._gather_node_info(data_info)
            self._make_normalizer_dict()

        self._normalize_numerical_nodes()
        # make negative class -1 for masking to work
        if 'bin_nodes' in self.node_info:
            bin_cols = self.node_info['bin_nodes']
            bin_vals = self.all_node_array[:, bin_cols]
            bin_vals[bin_vals == 0] = -1
            self.all_node_array[:, bin_cols] = bin_vals

        self._create_real_index()

        # No graph structure is used: every item gets an empty edge index.
        # (An unused nx.complete_graph with O(num_nodes^2) edges was built here.)
        self.edge_index = np.array([])

    def _normalize_numerical_nodes(self):
        '''Standardize numerical node columns in place with the stored mean/std.'''
        numerical_feats = self.all_node_array[:, self.node_info["nume_nodes"]]
        self.all_node_array[:, self.node_info["nume_nodes"]] = (numerical_feats - self.normalizer["mean"][None, :]) / (self.normalizer["std"][None, :])

    def _create_real_index(self):
        '''
        Create ``self.real_index``, the row indices that can be used as targets.

        The whole series is a single episode, so only its first
        ``concat_steps - 1`` rows are excluded: they lack a full history window.
        '''
        num_total = self.all_node_array.shape[0]
        flag_beginning_steps = np.zeros(num_total, dtype=bool)
        flag_beginning_steps[: self.concat_steps - 1] = True
        self.real_index = np.flatnonzero(~flag_beginning_steps)

    def _make_normalizer_dict(self):
        '''
        Compute per-column mean and std of the numerical nodes.

        0.001 is added to the std so constant columns do not divide by zero.
        '''
        numerical_feats = self.all_node_array[:, self.node_info["nume_nodes"]]
        self.normalizer["mean"] = np.mean(numerical_feats, axis=0)
        self.normalizer["std"] = np.std(numerical_feats, axis=0) + 0.001

    def _gather_node_info(self, data_info):
        '''
        Build ``self.node_info`` with the indices of numerical and binary nodes.

        Numerical nodes have type code 3 and binary nodes type code 2.
        Binary and categorical node lists are computed only when
        ``data_info`` is None; otherwise they are copied from ``data_info``
        when present there.
        '''
        self.node_info = {}

        # np.int was removed in NumPy 1.24; the builtin int gives the same dtype
        self.node_info["nume_nodes"] = np.flatnonzero(self.node_type == 3).astype(int)

        if data_info is None:
            self.node_info["bin_nodes"] = np.flatnonzero(self.node_type == 2).astype(int)
        elif 'bin_nodes' in data_info:
            self.node_info['bin_nodes'] = data_info['bin_nodes']

        if data_info is None:
            # categorical nodes are not used for sensor data
            self.node_info["cat_nodes"] = np.array([])
            self.node_info["cat_ranges"] = np.array([])
        elif 'cat_nodes' in data_info:
            self.node_info['cat_nodes'] = data_info['cat_nodes']
            self.node_info['cat_ranges'] = data_info['cat_ranges']

    def _generate_edge_list(self):
        '''Return ``(edge_index, edge_attributes)``; edges are empty and attributes None.'''
        return self.edge_index, None

    def __getitem__(self, index):
        '''
        Return one window as a ``torch_geometric`` ``Data`` object.

        ``x`` has shape (1, num_nodes, concat_steps - 1) and holds the steps
        strictly before ``real_index[index]``. ``last_state`` has shape
        (1, num_nodes) and holds the step at ``real_index[index]``.
        '''
        real_ind = self.real_index[index]

        # shape is (num_nodes, concat_steps - 1); the target step is excluded
        node_feats = self.all_node_array[(real_ind - self.concat_steps + 1) : real_ind].transpose()
        last_state = self.all_node_array[real_ind]

        edges, _ = self._generate_edge_list()

        graph_data = data.Data(x          = torch.tensor(node_feats, dtype=torch.float32).unsqueeze(0),
                               edge_index = torch.tensor(edges, dtype=torch.long),
                               last_state = torch.tensor(last_state, dtype=torch.float32).unsqueeze(0))

        return graph_data

    def __len__(self):
        return self.real_index.shape[0]

    def unnormalize(self, node_values):
        '''
        Undo the numerical-node standardization and round to integers.

        Nodes are indexed along axis 0 and the statistics broadcast over
        axis 1, so ``node_values`` must be 2-D with shape (num_nodes, T).
        NOTE: the numerical rows of ``node_values`` are overwritten in place
        before a rounded integer copy is returned.
        '''
        numerical_feats = node_values[self.node_info["nume_nodes"]]
        node_values[self.node_info["nume_nodes"]] = (numerical_feats * self.normalizer["std"].reshape(-1,1)) + self.normalizer["mean"].reshape(-1,1)
        return np.array(np.round(node_values), int)
        

        
if __name__ == "__main__":

    # Nothing runs when this module is executed directly; it only defines
    # the dataset factories and the SensorDataset class.
    pass


