from scipy.io.matlab.mio5 import varmats_from_mat
import torch

import scipy
import numpy as np
import torch_geometric.data as data
import scipy.io
import pickle as pkl
import pandas as pd


def prepareKD99rev():
    """Load the 10% KDD Cup 99 file in its "reversed" labelling.

    Rows labelled ``'normal.'`` get label 0 and every other row gets label 1.
    Attack rows are then randomly downsampled so that the result holds
    roughly four normal rows for every attack row.

    Feature type codes used in ``node_type``: 8 for symbolic (categorical)
    columns, 3 for numeric columns, 2 for columns whose values are all 0/1.
    Categorical values are replaced by integer ids that are unique across
    all categorical columns.

    Returns:
        tuple: ``(data, data_info, node_type)`` where ``data`` is
        ``{"X": float array, "y": int labels}`` (rows aligned),
        ``data_info`` is ``{"cat_nodes": column indices of categorical
        features, "cat_ranges": per-column category ids padded with -1}``
        and ``node_type`` is the per-column type-code array.

    Note:
        The downsampling draws from the global NumPy RNG, so the selected
        attack rows depend on the RNG state at call time.
    """
    df = pd.read_csv("/home/plymper/data/KDDCup/kddcup.data_10_percent", header=None)

    # Column 41 holds the class name; 'normal.' -> 0, any attack -> 1.
    labels = np.where(df[41].to_numpy() == 'normal.', 0, 1)
    df = df.drop([0, 41], axis=1)
    data_X = df.to_numpy()

    # assumes each line of vars.txt looks like "<name>: <type>." and that
    # skipping the first line leaves one entry per remaining data column
    # -- TODO confirm against the file.
    with open("/home/plymper/data/KDDCup/vars.txt", 'r') as f:
        var_lines = [line for line in f.readlines()[1:] if line.strip()]

    # strip() so a final line without a trailing newline is still recognised.
    node_type = np.array(
        [8 if line.split(":")[1].strip() == 'symbolic.' else 3 for line in var_lines]
    )
    # Columns containing only 0/1 are treated as binary, whatever vars.txt says.
    node_type[find_binary_vars_from_data(data_X)] = 2

    cat_nodes = np.where(node_type == 8)[0]

    # Sort the distinct values so category ids are reproducible across runs
    # (plain set iteration order changes with string hash randomisation).
    all_categories = [sorted(set(data_X[:, col]), key=str) for col in cat_nodes]

    # One global id per category value, numbered consecutively over columns.
    flat_categories = [value for values in all_categories for value in values]
    cat_inds = {value: idx for idx, value in enumerate(flat_categories)}

    # Per-column list of ids, right-padded with -1 to a rectangular array.
    cat_ranges = [[cat_inds[value] for value in values] for values in all_categories]
    max_cat = max((len(ids) for ids in cat_ranges), default=0)
    cat_ranges = [ids + [-1] * (max_cat - len(ids)) for ids in cat_ranges]

    # Replace every categorical value by its id, one column at a time.
    for col in cat_nodes:
        data_X[:, col] = [float(cat_inds[value]) for value in data_X[:, col]]

    data_info = {'cat_nodes': cat_nodes, 'cat_ranges': np.array(cat_ranges)}

    # Downsample attacks to get a 4:1 normal:attack ratio.
    normal_inds = np.where(labels == 0)[0]
    attack_pool = np.where(labels == 1)[0]
    n_attack = min(normal_inds.shape[0] // 4, attack_pool.shape[0])
    # replace=False: a downsample must not pick the same attack row twice.
    attack_inds = np.random.choice(attack_pool, size=n_attack, replace=False)
    all_inds = np.concatenate((normal_inds, attack_inds))

    # Subset features and labels together so X and y stay row-aligned.
    data_X = data_X[all_inds]
    labels = labels[all_inds]

    dataset = {"X": np.array(data_X, float), "y": labels}

    return dataset, data_info, node_type

def prepareKD99():
    """Load the 10% KDD Cup 99 file and encode it as a numeric table.

    Rows labelled ``'normal.'`` get label 1 and every other row gets label 0.

    Feature type codes used in ``node_type``: 8 for symbolic (categorical)
    columns, 3 for numeric columns, 2 for columns whose values are all 0/1.
    Categorical values are replaced by integer ids that are unique across
    all categorical columns.

    Returns:
        tuple: ``(data, data_info, node_type)`` where ``data`` is
        ``{"X": float array, "y": int labels}``, ``data_info`` is
        ``{"cat_nodes": column indices of categorical features,
        "cat_ranges": per-column category ids padded with -1}`` and
        ``node_type`` is the per-column type-code array.
    """
    df = pd.read_csv("/home/plymper/data/KDDCup/kddcup.data_10_percent", header=None)

    # Column 41 holds the class name; 'normal.' -> 1, any attack -> 0.
    labels = np.where(df[41].to_numpy() == 'normal.', 1, 0)
    df = df.drop([0, 41], axis=1)
    data_X = df.to_numpy()

    # assumes each line of vars.txt looks like "<name>: <type>." and that
    # skipping the first line leaves one entry per remaining data column
    # -- TODO confirm against the file.
    with open("/home/plymper/data/KDDCup/vars.txt", 'r') as f:
        var_lines = [line for line in f.readlines()[1:] if line.strip()]

    # strip() so a final line without a trailing newline is still recognised.
    node_type = np.array(
        [8 if line.split(":")[1].strip() == 'symbolic.' else 3 for line in var_lines]
    )
    # Columns containing only 0/1 are treated as binary, whatever vars.txt says.
    node_type[find_binary_vars_from_data(data_X)] = 2

    cat_nodes = np.where(node_type == 8)[0]

    # Sort the distinct values so category ids are reproducible across runs
    # (plain set iteration order changes with string hash randomisation).
    all_categories = [sorted(set(data_X[:, col]), key=str) for col in cat_nodes]

    # One global id per category value, numbered consecutively over columns.
    flat_categories = [value for values in all_categories for value in values]
    cat_inds = {value: idx for idx, value in enumerate(flat_categories)}

    # Per-column list of ids, right-padded with -1 to a rectangular array.
    cat_ranges = [[cat_inds[value] for value in values] for values in all_categories]
    max_cat = max((len(ids) for ids in cat_ranges), default=0)
    cat_ranges = [ids + [-1] * (max_cat - len(ids)) for ids in cat_ranges]

    # Replace every categorical value by its id, one column at a time.
    for col in cat_nodes:
        data_X[:, col] = [float(cat_inds[value]) for value in data_X[:, col]]

    data_info = {'cat_nodes': cat_nodes, 'cat_ranges': np.array(cat_ranges)}

    # The builtin float replaces the np.float alias removed in NumPy 1.24.
    dataset = {"X": np.array(data_X, float), "y": labels}

    return dataset, data_info, node_type



def find_binary_vars_from_data(x):
    """Return a boolean mask marking the columns of ``x`` that are binary.

    A column is flagged when every one of its entries compares equal to
    0 or to 1. Columns that are constantly 0 or constantly 1 are flagged
    too, and so is every column of an input with no rows.

    Args:
        x: 2-D array; columns are inspected one at a time.

    Returns:
        Boolean NumPy array with one entry per column of ``x``.
    """
    n_cols = x.shape[1]
    # An entry cannot equal both 0 and 1, so xor here means "is 0 or is 1".
    column_is_binary = [
        bool(np.all((x[:, col] == 0) ^ (x[:, col] == 1)))
        for col in range(n_cols)
    ]
    return np.array(column_is_binary, dtype=bool)

def make_tabular_datasets(name='arrhythmia',train_test_split=0.5, train_val_split = 0.1):
    """Build train / validation / test ``TabularDataset`` objects.

    The rows are shuffled with a fixed seed and split into a training pool
    and a test set. Only training-pool rows whose label is 0 are kept; the
    first ``train_val_split`` fraction of those becomes the validation set
    and the remainder the training set. The test set keeps all of its rows
    and their labels.

    Args:
        name: One of ``'arrhythmia'``, ``'thyroid'``, ``'blob'``, ``'kdd99'``.
        train_test_split: Fraction of the shuffled rows assigned to the
            training pool; the remainder forms the test set.
        train_val_split: Fraction of the label-0 training rows held out
            for validation.

    Returns:
        tuple: ``(train_dataset, val_dataset, test_dataset)``.

    Raises:
        ValueError: If ``name`` is not a known dataset.

    Note:
        Seeds the global NumPy RNG with 0 and prints the three split shapes.
    """
    mat_files = {
        'arrhythmia': "/home/plymper/data/arrhythmia/arrhythmia.mat",
        'thyroid': "/home/plymper/data/thyroid/thyroid.mat",
    }
    known_names = tuple(mat_files) + ('blob', 'kdd99')
    # Fail early with a clear message instead of an UnboundLocalError below.
    if name not in known_names:
        raise ValueError(
            f"Unknown dataset name {name!r}; expected one of {known_names}"
        )

    np.random.seed(0)
    if name in mat_files:
        data = scipy.io.loadmat(mat_files[name])
        feature_types = np.zeros(data['X'].shape[1])+3 #all features numerical
        # ...except columns holding only 0/1, which are marked binary (2).
        feature_types[find_binary_vars_from_data(data['X'])] = 2
        data_info = None
    elif name == 'blob':
        # NOTE: pickle can execute arbitrary code on load; only use trusted files.
        with open("/home/plymper/data/synthetic/blob.pkl", 'rb') as f:
            data = pkl.load(f)
        data_info = None
        feature_types = np.zeros(data['X'].shape[1])+3 #all features numerical
    else:  # name == 'kdd99'
        data,data_info,feature_types = prepareKD99()

    # Shuffle rows and labels with the same permutation.
    x,y = data['X'], data['y']
    inds = np.random.permutation(np.arange(x.shape[0]))
    x,y = x[inds], y[inds].flatten()

    n_train = int(x.shape[0]*train_test_split)
    x_train, x_test = x[:n_train], x[n_train:]
    y_train, y_test = y[:n_train], y[n_train:]

    # Train and validate on label-0 rows only.
    x_train = x_train[y_train == 0]

    n_val = int(x_train.shape[0]*train_val_split)
    x_val, x_train = x_train[:n_val], x_train[n_val:]

    # Every remaining train/val row has label 0 by construction.
    y_train = np.zeros((x_train.shape[0],))
    y_val = np.zeros((x_val.shape[0],))

    print(x_train.shape, x_val.shape, x_test.shape)
    train_dataset= TabularDataset(x_train,y_train,feature_types, data_info=data_info)

    # Hand the training set's node info and normalisation statistics to the
    # validation and test constructors.
    data_info = {"normalizer_dict":train_dataset.normalizer,'node_info':train_dataset.node_info}
    val_dataset = TabularDataset(x_val,y_val,feature_types, data_info=data_info)
    test_dataset = TabularDataset(x_test,y_test,feature_types, data_info=data_info)

    return train_dataset,val_dataset,test_dataset

class TabularDataset(data.Dataset):
    """Tabular rows exposed as graphs with one scalar feature per node.

    Each column of the table becomes a node. ``feature_types`` gives a type
    code per column: 3 for numeric, 2 for binary; categorical columns are
    described through ``data_info`` instead. Numeric columns are
    standardised on construction.

    Attributes set on the instance: ``data``, ``labels``, ``node_type``,
    ``num_nodes``, ``node_feature_dim``, ``all_node_array``, ``node_info``
    and ``normalizer``.
    """

    def __init__(self, data,labels, feature_types, data_info = None):
        """Store the table and normalise its numeric columns.

        Args:
            data: Array holding ``feature_types.shape[0]`` values per sample.
            labels: Per-sample labels; stored but not used by this class.
            feature_types: Array of per-column type codes.
            data_info: Optional dict. If it contains both ``'node_info'``
                and ``'normalizer_dict'`` (e.g. taken from a training
                dataset) they are reused as-is. Otherwise node info and
                normalisation statistics are computed from ``data``, with
                ``'cat_nodes'`` / ``'cat_ranges'`` copied from ``data_info``
                when present.
        """
        self.data = data
        self.labels = labels

        self.node_type = feature_types
        self.num_nodes = feature_types.shape[0]
        self.node_feature_dim = 1

        # reshape into graph-like array: (n_samples, num_nodes, 1)
        # NOTE: reshape may return a view of `data`, in which case the
        # in-place normalisation below also rewrites the caller's array.
        self.all_node_array = self.data.reshape((-1,self.num_nodes,1))

        # Reuse precomputed statistics only when both pieces are supplied;
        # these are exactly the two keys read in the reuse branch.
        has_precomputed = (
            data_info is not None
            and 'node_info' in data_info
            and 'normalizer_dict' in data_info
        )
        if has_precomputed:
            self.node_info = data_info['node_info']
            self.normalizer = data_info['normalizer_dict']
        else:
            self._gather_node_info(data_info)
            self._make_normalizer_dict()

        self._normalize_numerical_nodes()

    def __getitem__(self, index):
        """Return sample ``index`` as a graph ``Data`` object.

        ``x`` and ``last_state`` are two separate float32 tensors holding
        the same node features.
        """
        node_feats = self.all_node_array[index]
        graph_data = data.Data(
            x = torch.tensor(node_feats, dtype=torch.float32),
            last_state = torch.tensor(node_feats, dtype=torch.float32),
        )
        return graph_data

    def _gather_node_info(self, data_info):
        '''
        Creates member node_info dictionary storing the indices of the
        numeric (type 3) and binary (type 2) nodes, plus the categorical
        node indices and id ranges taken from data_info when it provides
        them (empty arrays otherwise).
        '''
        print(self.all_node_array.shape)

        self.node_info = {}
        # The builtin int replaces the np.int alias removed in NumPy 1.24.
        self.node_info["nume_nodes"] = np.array(np.where(self.node_type == 3)[0], int)
        self.node_info["bin_nodes"] = np.array(np.where(self.node_type == 2)[0], int)

        if data_info is not None and 'cat_nodes' in data_info:
            self.node_info['cat_nodes'] = data_info['cat_nodes']
            self.node_info['cat_ranges'] = data_info['cat_ranges']
        else:
            # Always define the categorical keys so later lookups cannot fail.
            self.node_info["cat_nodes"] = np.array([])
            self.node_info["cat_ranges"] = np.array([])

    def _make_normalizer_dict(self):
        '''
        Gathers per-node mean and standard deviation of the numeric nodes
        into the member normalizer dictionary.
        '''
        # getting data statistics for normalization
        numerical_feats = self.all_node_array[:, self.node_info["nume_nodes"]]
        print(numerical_feats.shape)
        self.normalizer = {}
        self.normalizer["mean"] = np.mean(numerical_feats, axis=0)
        # Small offset keeps the later division finite for constant columns.
        self.normalizer["std"] = np.std(numerical_feats, axis=0) + 0.001

    def _normalize_numerical_nodes(self):
        '''
        Standardises the numeric nodes of all_node_array in place using
        the stored mean and std.
        '''
        nume_nodes = self.node_info["nume_nodes"]
        numerical_feats = self.all_node_array[:, nume_nodes]
        # normalize numerical nodes
        mean = self.normalizer["mean"][None, :]
        std = self.normalizer["std"][None, :]
        self.all_node_array[:, nume_nodes] = (numerical_feats - mean) / std

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)