import copy

import torch
import numpy as np
import random
import pandas as pd
import scipy.io
from scipy.io import arff
import os
import sys
import zipfile


def process_mat_dataset(data):
    """Split a labelled matrix into a normal-only train set and a mixed test set.

    The last column of ``data`` is read as the binary label (0 = normal,
    1 = anomaly); rows with any other label end up in neither output.

    :param data: 2-D numpy array, features followed by one label column
    :return: ``(train, test, test_classes)`` where ``train`` holds the features
        of the first ``n_normals // 2 + 1`` shuffled normal rows, ``test`` the
        features of the remaining normals plus all anomalies (shuffled), and
        ``test_classes`` the labels matching the rows of ``test``
    """
    labels = data[:, -1]
    normal_rows = data[labels == 0]
    anomaly_rows = data[labels == 1]

    # Shuffle the normals before cutting so the train half is a random sample.
    normal_count = normal_rows.shape[0]
    normal_rows = normal_rows[np.random.permutation(normal_count)]
    cut = normal_count // 2 + 1
    kept_for_train, kept_for_test = normal_rows[:cut], normal_rows[cut:]

    # Mix the held-out normals with the anomalies, labels still attached.
    mixed = np.concatenate([kept_for_test, anomaly_rows])
    mixed = mixed[np.random.permutation(mixed.shape[0])]

    return kept_for_train[:, :-1], mixed[:, :-1], mixed[:, -1]


def get_target_label_idx(labels, targets):
    """
    Find the positions in labels whose value is one of the targets.
    :param labels: array of labels
    :param targets: list/tuple of target labels
    :return: list with indices of target labels
    """
    membership = np.isin(labels, targets)
    positions = np.argwhere(membership)
    return positions.ravel().tolist()


def convert_multiclass(train_data, normal_class, test_data=None):
    """Collapse multiclass labels into a binary normal/anomaly label.

    The last column of each array is cast to ``int`` and rewritten in a copy:
    rows whose label equals ``normal_class`` get 0, every other row gets 1.
    The inputs themselves are left untouched.

    :param train_data: 2-D numpy array with the class label in the last column
    :param normal_class: label value to be treated as the normal class
    :param test_data: optional array laid out like ``train_data``
    :return: ``(train_copy, test_copy)``; ``test_copy`` is ``None`` when
        ``test_data`` is ``None``
    """
    def _binarize(data):
        if data is None:
            return None
        converted = copy.deepcopy(data)
        labels = converted[:, -1].astype(int)
        # Vectorized relabelling: one pass in numpy, and it also works on
        # arrays with zero rows.
        converted[:, -1] = np.where(labels == normal_class, 0, 1)
        return converted

    return _binarize(train_data), _binarize(test_data)


def load_dataset(root, name):
    """Load a dataset stored as ``.npy`` file(s) under ``<root>/<name>/``.

    Datasets shipped as a single file are read from ``<name>.npy``; datasets
    with a predefined split are read from ``<name>_train.npy`` and
    ``<name>_test.npy``.

    :param root: directory containing one sub-directory per dataset
    :param name: dataset name, which selects both the sub-directory and the
        file layout
    :return: ``(train_data, test_data, classes)``; ``test_data`` is ``None``
        for single-file datasets, and ``classes`` lists the distinct integer
        labels found in the last column of ``train_data``
    :raises NotImplementedError: if ``name`` is not a known dataset
    """
    single_file = ('arrhythmia', 'wine', 'lympho', 'glass', 'vertebral', 'wbc', 'ecoli', 'ionosphere', 'breastw',
                   'pima', 'vowels', 'letter', 'cardio', 'seismic', 'musk', 'speech', 'abalone', 'pendigits',
                   'mammography', 'mulcross', 'forest_cover')
    pre_split = ('thyroid', 'optdigits', 'satimage', 'shuttle', 'kdd')  # datasets shipped with a test set

    def _read(stem):
        with open(os.path.join(root, name, stem + '.npy'), 'rb') as f:
            return np.load(f)

    if name in single_file:
        train_data, test_data = _read(name), None
    elif name in pre_split:
        train_data, test_data = _read(name + '_train'), _read(name + '_test')
    else:
        raise NotImplementedError('Unsupported dataset name: {!r}'.format(name))

    classes = list(set(train_data[:, -1].astype(int)))
    return train_data, test_data, classes


def process_dataset(train_data, test_data, classes, normal_class, b_size, normalize=False):
    """Turn raw labelled arrays into train/test tensors for anomaly detection.

    For multiclass data (more than two entries in ``classes``) the labels are
    first collapsed to 0 = ``normal_class`` / 1 = everything else. The training
    features returned contain normal rows only.

    :param train_data: 2-D numpy array with the label in the last column
    :param test_data: array laid out like ``train_data``, or ``None`` when the
        dataset has no separate test set (a test set is then carved out of
        ``train_data`` by ``process_mat_dataset``)
    :param classes: distinct labels of the dataset
    :param normal_class: label treated as normal for multiclass data; also
        echoed in the large-dataset message below
    :param b_size: unused by this function; kept for caller compatibility
    :param normalize: when true, standardize both feature sets with the mean
        and standard deviation of the training features
    :return: ``(train_x, test_x, test_y)`` as torch tensors, or
        ``(None, None, None)`` when a multiclass dataset has 20 or fewer
        training rows of ``normal_class``
    """
    if len(classes) > 2:  # multiclass dataset
        y = train_data[:, -1].astype(int)
        if np.count_nonzero(y == normal_class) <= 20:
            # Too few normal samples to train on.
            return None, None, None
        train_data, test_data = convert_multiclass(train_data, normal_class, test_data)

    if test_data is None:  # no separate test set: split train_data instead
        train_data_x, test_data_x, test_data_y = process_mat_dataset(train_data)
    else:
        train_data_x = train_data[train_data[:, -1] == 0][:, :-1]
        test_data_x, test_data_y = test_data[:, :-1], test_data[:, -1].astype(int)

    n_sample = train_data_x.shape[0]
    if n_sample > 8000:
        print(normal_class, n_sample)

    if normalize:
        # Statistics come from the training features only; the small epsilon
        # guards against division by zero on constant columns.
        mu = np.mean(train_data_x, axis=0)
        std = np.std(train_data_x, axis=0)
        train_data_x = (train_data_x - mu) / (std + 1e-5)
        test_data_x = (test_data_x - mu) / (std + 1e-5)

    return torch.tensor(train_data_x), torch.tensor(test_data_x), torch.tensor(test_data_y)


class Data_Loader:

    def __init__(self, n_trains=None):
        """Create a loader and remember the optional training-size setting.

        :param n_trains: value stored on ``self.n_train``; none of the methods
            of this class visible in this file read it
        """
        self.n_train = n_trains

    def get_dataset(self, dataset_name):
        script_dir = os.path.dirname(__file__)
        rel_path = os.path.join("Data/",dataset_name)
        abs_file_path = os.path.join(script_dir, rel_path)
        mat_files=['annthyroid','arrhythmia','breastw','cardio','forest_cover','glass','ionosphere','letter','lympho','mammography','mnist','musk',
                   'optdigits','pendigits','pima','satellite','satimage','shuttle','speech','thyroid','vertebral','vowels','wbc','wine']
        if dataset_name in mat_files :
            print ('generic mat file')
            return self.build_train_test_generic_matfile(abs_file_path)

        if dataset_name == 'thyroid_train_test':
            print('seismic')
            return self.build_train_test_thyroid(script_dir+'/Data')

        if dataset_name == 'seismic':
            print('seismic')
            return self.build_train_test_seismic(abs_file_path+'.arff')

        if dataset_name == 'mulcross':
            print('mullcross')
            return self.build_train_test_mulcross(abs_file_path+'.arff')

        if dataset_name == 'abalone':
            print('abalone')
            return self.build_train_test_abalone(abs_file_path+'.data')

        if dataset_name == 'ecoli':
            print('ecoli')
            return self.build_train_test_ecoli(abs_file_path+'.data')

        if dataset_name == 'kdd':
            print ('kdd')
            return self.build_train_test_kdd(script_dir+'/Data/kddcup.data_10_percent_corrected.zip')

        if dataset_name == 'kdd-large':
            print ('kdd-large')
            return self.build_train_test_kdd_large(script_dir+'/Data')

        if dataset_name == 'kddrev':
            print ('kddrev')
            return self.build_train_test_kdd_rev(script_dir+'/Data/kddcup.data_10_percent_corrected.zip')

        sys.exit ('No such dataset!')

    def build_train_test_generic_matfile(self,name_of_file):  # takes a matrice that contains the entire dataset, and creates a trainset with 50% of the data of all normals, and the rest are test
        dataset = scipy.io.loadmat(name_of_file)
        X = dataset['X']
        classes = dataset['y']
        jointXY = torch.cat((torch.tensor(X,dtype=torch.double), torch.tensor(classes,dtype=torch.double)), dim=1)
        normals=jointXY[jointXY[:,-1]==0]
        anomalies=jointXY[jointXY[:,-1]==1]
        normals = normals[torch.randperm(normals.shape[0])]
        train, test_norm = torch.split(normals, int(normals.shape[0] / 2) + 1)
        test = torch.cat((test_norm, anomalies))
        test = test[torch.randperm(test.shape[0])]
        train = train[torch.randperm(train.shape[0])]
        test_classes = test[:, -1].view(-1, 1)
        train = train[:, 0:train.shape[1] - 1]
        test = test[:, 0:test.shape[1] - 1]
        return (train, test, test_classes)


    def build_train_test_seismic(self,name_of_file):  # takes a matrice that contains the entire dataset, and creates a trainset with 50% of the data of all normals, and the rest are test
        dataset, meta = arff.loadarff(name_of_file)
        dataset = pd.DataFrame(dataset)
        classes = dataset.iloc[:, -1]
        dataset = dataset.iloc[:, :-1]
        dataset = pd.get_dummies(dataset.iloc[:, :-1])
        dataset = pd.concat((dataset, classes), axis=1)
        normals = dataset[dataset.iloc[:, -1] == b'0'].values
        anomalies = dataset[dataset.iloc[:, -1] == b'1'].values
        normals = normals[torch.randperm(normals.shape[0])]
        normals = torch.tensor(normals[:, :-1].astype('float32'))
        anomalies = torch.tensor(anomalies[:, :-1].astype('float32'))
        normals = torch.cat((normals, torch.zeros(normals.shape[0]).view(-1, 1)), dim=1)
        anomalies = torch.cat((anomalies, torch.ones(anomalies.shape[0]).view(-1, 1)), dim=1)
        train, test_norm = torch.split(normals, int(normals.shape[0] / 2) + 1)
        test = torch.cat((test_norm, anomalies))
        test = test[torch.randperm(test.shape[0])]
        train = train[torch.randperm(train.shape[0])]
        test_classes = test[:, -1].view(-1, 1)
        train = train[:, 0:train.shape[1] - 1]
        test = test[:, 0:test.shape[1] - 1]
        return (train, test, test_classes)

    def build_train_test_mulcross(self,name_of_file):  # takes a matrice that contains the entire dataset, and creates a trainset with 50% of the data of all normals, and the rest are test
        dataset, _ = arff.loadarff(name_of_file)
        dataset = pd.DataFrame(dataset)
        normals = dataset[dataset.iloc[:, -1] == b'Normal'].values
        anomalies = dataset[dataset.iloc[:, -1] == b'Anomaly'].values
        normals = normals[torch.randperm(normals.shape[0])]
        normals = torch.tensor(normals[:, :-1].astype('float32'))
        anomalies = torch.tensor(anomalies[:, :-1].astype('float32'))
        normals = torch.cat((normals, torch.zeros(normals.shape[0]).view(-1, 1)), dim=1)
        anomalies = torch.cat((anomalies, torch.ones(anomalies.shape[0]).view(-1, 1)), dim=1)
        train, test_norm = torch.split(normals, int(normals.shape[0] / 2) + 1)
        test = torch.cat((test_norm, anomalies))
        test = test[torch.randperm(test.shape[0])]
        train = train[torch.randperm(train.shape[0])]
        test_classes = test[:, -1].view(-1, 1)
        train = train[:, 0:train.shape[1] - 1]
        test = test[:, 0:test.shape[1] - 1]
        return (train, test, test_classes)

    def build_train_test_ecoli(self,name_of_file):  # takes a matrice that contains the entire dataset, and creates a trainset with 50% of the data of all normals, and the rest are test
        dataset = pd.read_csv(name_of_file, header=None, sep='\s+')
        dataset = dataset.iloc[:, 1:]
        anomalies = np.array(
            dataset[(dataset.iloc[:, 7] == 'omL') | (dataset.iloc[:, 7] == 'imL') | (dataset.iloc[:, 7] == 'imS')])[:,
                    :-1]
        normals = np.array(dataset[(dataset.iloc[:, 7] == 'cp') | (dataset.iloc[:, 7] == 'im') | (
                    dataset.iloc[:, 7] == 'pp') | (dataset.iloc[:, 7] == 'imU') | (dataset.iloc[:, 7] == 'om')])[:, :-1]
        normals = torch.tensor(normals.astype('double'))
        anomalies = torch.tensor(anomalies.astype('double'))
        normals = torch.cat((normals, torch.zeros(normals.shape[0], 1,dtype=torch.double)), dim=1)
        anomalies = torch.cat((anomalies, torch.ones(anomalies.shape[0], 1,dtype=torch.double)), dim=1)
        normals = normals[torch.randperm(normals.shape[0])]
        anomalies = anomalies[torch.randperm(anomalies.shape[0])]
        train, test_norm = torch.split(normals, int(normals.shape[0] / 2) + 1)
        test = torch.cat((test_norm, anomalies))
        test = test[torch.randperm(test.shape[0])]
        test_classes = test[:, -1].view(-1, 1)
        train = train[:, :-1]
        test = test[:, :-1]
        return (train, test, test_classes)

    def build_train_test_abalone(self,path):  # takes a matrice that contains the entire dataset, and creates a trainset with 50% of the data of all normals, and the rest are test

        data = pd.read_csv(path, header=None, sep=',')
        data = data.rename(columns={8: 'y'})
        data['y'].replace([8, 9, 10], -1, inplace=True)
        data['y'].replace([3, 21], 0, inplace=True)
        data.iloc[:, 0].replace('M', 0, inplace=True)
        data.iloc[:, 0].replace('F', 1, inplace=True)
        data.iloc[:, 0].replace('I', 2, inplace=True)
        test = data[data['y'] == 0]
        normal = data[data['y'] == -1].sample(frac=1)
        num_normal_samples_test = normal.shape[0] // 2
        test_data = np.concatenate((test.drop('y', axis=1), normal[:num_normal_samples_test].drop('y', axis=1)), axis=0)
        train = normal[num_normal_samples_test:]
        train_data = train.drop('y', axis=1).values
        test_labels = np.concatenate((test['y'], normal[:num_normal_samples_test]['y'].replace(-1, 1)), axis=0)
        for i in range(test_labels.shape[0]):
            if test_labels[i] == 0:
                test_labels[i] = 1
            else:
                test_labels[i] = 0
        train_data=torch.tensor(train_data.astype('double'))
        test_data=torch.tensor(test_data.astype('double'))
        test_labels=torch.tensor(test_labels.astype('double'))
        return (train_data, test_data, test_labels)

    def build_train_test_kdd(self,name_of_file):  # takes a matrice that contains the entire dataset, and creates a trainset with 50% of the data of all normals, and the rest are test
        zf = zipfile.ZipFile(name_of_file)
        kdd_loader = pd.read_csv(zf.open('kddcup.data_10_percent_corrected'), delimiter=',')
        entire_set = np.array(kdd_loader)
        revised_pd = pd.DataFrame(entire_set)
        revised_pd = pd.concat((revised_pd, pd.get_dummies(revised_pd.iloc[:, 1], prefix='new1')), axis=1)
        revised_pd = pd.concat((revised_pd, pd.get_dummies(revised_pd.iloc[:, 2], prefix='new2')), axis=1)
        revised_pd = pd.concat((revised_pd, pd.get_dummies(revised_pd.iloc[:, 3], prefix='new3')), axis=1)
        revised_pd = pd.concat((revised_pd, pd.get_dummies(revised_pd.iloc[:, 6], prefix='new6')), axis=1)
        revised_pd = pd.concat((revised_pd, pd.get_dummies(revised_pd.iloc[:, 11], prefix='new11')), axis=1)
        revised_pd = pd.concat((revised_pd, pd.get_dummies(revised_pd.iloc[:, 21], prefix='new21')), axis=1)
        revised_pd.drop(revised_pd.columns[[1, 2, 3, 6, 11, 20, 21]], inplace=True, axis=1)
        new_columns = [0, 'new1_icmp', 'new1_tcp', 'new1_udp', 'new2_IRC', 'new2_X11', 'new2_Z39_50', 'new2_auth',
                       'new2_bgp',
                       'new2_courier', 'new2_csnet_ns', 'new2_ctf', 'new2_daytime', 'new2_discard', 'new2_domain',
                       'new2_domain_u', 'new2_echo', 'new2_eco_i', 'new2_ecr_i', 'new2_efs', 'new2_exec', 'new2_finger',
                       'new2_ftp', 'new2_ftp_data', 'new2_gopher', 'new2_hostnames', 'new2_http', 'new2_http_443',
                       'new2_imap4',
                       'new2_iso_tsap', 'new2_klogin', 'new2_kshell', 'new2_ldap', 'new2_link', 'new2_login', 'new2_mtp',
                       'new2_name', 'new2_netbios_dgm', 'new2_netbios_ns', 'new2_netbios_ssn', 'new2_netstat', 'new2_nnsp',
                       'new2_nntp', 'new2_ntp_u', 'new2_other', 'new2_pm_dump', 'new2_pop_2', 'new2_pop_3', 'new2_printer',
                       'new2_private', 'new2_red_i', 'new2_remote_job', 'new2_rje', 'new2_shell', 'new2_smtp',
                       'new2_sql_net',
                       'new2_ssh', 'new2_sunrpc', 'new2_supdup', 'new2_systat', 'new2_telnet', 'new2_tftp_u', 'new2_tim_i',
                       'new2_time', 'new2_urh_i', 'new2_urp_i', 'new2_uucp', 'new2_uucp_path', 'new2_vmnet', 'new2_whois',
                       'new3_OTH', 'new3_REJ', 'new3_RSTO', 'new3_RSTOS0', 'new3_RSTR', 'new3_S0', 'new3_S1', 'new3_S2',
                       'new3_S3', 'new3_SF', 'new3_SH', 4, 5, 'new6_0', 'new6_1', 7, 8, 9, 10, 'new11_0', 'new11_1', 12, 13,
                       14,
                       15, 16, 17, 18, 19, 'new21_0', 'new21_1', 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
                       35, 36, 37, 38, 39, 40, 41]
        revised_pd = revised_pd.reindex(columns=new_columns)
        revised_pd.loc[revised_pd[41] != 'normal.', 41] = 1.0  # seems a mistake
        revised_pd.loc[revised_pd[41] == 'normal.', 41] = 0.0
        kdd_normal = np.array(revised_pd.loc[revised_pd[41] == 0.0], dtype=np.double)
        kdd_anomaly = np.array(revised_pd.loc[revised_pd[41] == 1.0], dtype=np.double)
        kdd_normal = torch.tensor(kdd_normal)
        kdd_anomaly = torch.tensor(kdd_anomaly)
        kdd_normal = kdd_normal[
            torch.randperm(kdd_normal.shape[0])]
        kdd_anomaly = kdd_anomaly[torch.randperm(kdd_anomaly.shape[0])]
        train, test_norm = torch.split(kdd_normal, int(kdd_normal.shape[0] / 2) + 1)
        test = torch.cat((test_norm, kdd_anomaly))
        test = test[torch.randperm(test.shape[0])]
        test_classes = test[:, -1].view(-1, 1)
        train = train[:, 0:train.shape[1] - 1]
        test = test[:, 0:test.shape[1] - 1]
        return (train, test, test_classes)


    def build_train_test_kdd_large(self, root='./Data'):
        def get_target_label_idx(labels, targets):
            """
            Get the indices of labels that are included in targets.
            :param labels: array of labels
            :param targets: list/tuple of target labels
            :return: list with indices of target labels
            """
            return np.argwhere(np.isin(labels, targets)).flatten().tolist()

        with open(root + '/kdd-cup99_train.npy', 'rb') as f:
            train_data = np.load(f)

        with open(root + '/kdd-cup99_test.npy', 'rb') as f:
            test_data = np.load(f)

        train_data_x, train_data_y = train_data[:, :-1], train_data[:, -1].astype(int)

        train_idx_normal = get_target_label_idx(train_data_y, [0])

        test_data_x, test_data_y = test_data[:, :-1], test_data[:, -1].astype(int)

        kdd_normal = torch.tensor(train_data_x[train_idx_normal, :])
        kdd_normal = kdd_normal[
            torch.randperm(kdd_normal.shape[0])]
        test_data_y[test_data_y != 0] = 1

        test_data_x = torch.tensor(test_data_x)
        test_data_y = torch.tensor(test_data_y)
        return kdd_normal, test_data_x, test_data_y


    def build_train_test_kdd_rev(self,name_of_file):  # takes a matrice that contains the entire dataset, and creates a trainset with 50% of the data of all normals, and the rest are test
        zf = zipfile.ZipFile(name_of_file)
        kdd_loader = pd.read_csv(zf.open('kddcup.data_10_percent_corrected'), delimiter=',')
        entire_set = np.array(kdd_loader)
        revised_pd = pd.DataFrame(entire_set)
        revised_pd = pd.concat((revised_pd, pd.get_dummies(revised_pd.iloc[:, 1], prefix='new1')), axis=1)
        revised_pd = pd.concat((revised_pd, pd.get_dummies(revised_pd.iloc[:, 2], prefix='new2')), axis=1)
        revised_pd = pd.concat((revised_pd, pd.get_dummies(revised_pd.iloc[:, 3], prefix='new3')), axis=1)
        revised_pd = pd.concat((revised_pd, pd.get_dummies(revised_pd.iloc[:, 6], prefix='new6')), axis=1)
        revised_pd = pd.concat((revised_pd, pd.get_dummies(revised_pd.iloc[:, 11], prefix='new11')), axis=1)
        revised_pd = pd.concat((revised_pd, pd.get_dummies(revised_pd.iloc[:, 21], prefix='new21')), axis=1)
        revised_pd.drop(revised_pd.columns[[1, 2, 3, 6, 11, 20, 21]], inplace=True, axis=1)
        new_columns = [0, 'new1_icmp', 'new1_tcp', 'new1_udp', 'new2_IRC', 'new2_X11', 'new2_Z39_50', 'new2_auth',
                       'new2_bgp',
                       'new2_courier', 'new2_csnet_ns', 'new2_ctf', 'new2_daytime', 'new2_discard', 'new2_domain',
                       'new2_domain_u', 'new2_echo', 'new2_eco_i', 'new2_ecr_i', 'new2_efs', 'new2_exec', 'new2_finger',
                       'new2_ftp', 'new2_ftp_data', 'new2_gopher', 'new2_hostnames', 'new2_http', 'new2_http_443',
                       'new2_imap4',
                       'new2_iso_tsap', 'new2_klogin', 'new2_kshell', 'new2_ldap', 'new2_link', 'new2_login', 'new2_mtp',
                       'new2_name', 'new2_netbios_dgm', 'new2_netbios_ns', 'new2_netbios_ssn', 'new2_netstat', 'new2_nnsp',
                       'new2_nntp', 'new2_ntp_u', 'new2_other', 'new2_pm_dump', 'new2_pop_2', 'new2_pop_3', 'new2_printer',
                       'new2_private', 'new2_red_i', 'new2_remote_job', 'new2_rje', 'new2_shell', 'new2_smtp',
                       'new2_sql_net',
                       'new2_ssh', 'new2_sunrpc', 'new2_supdup', 'new2_systat', 'new2_telnet', 'new2_tftp_u', 'new2_tim_i',
                       'new2_time', 'new2_urh_i', 'new2_urp_i', 'new2_uucp', 'new2_uucp_path', 'new2_vmnet', 'new2_whois',
                       'new3_OTH', 'new3_REJ', 'new3_RSTO', 'new3_RSTOS0', 'new3_RSTR', 'new3_S0', 'new3_S1', 'new3_S2',
                       'new3_S3', 'new3_SF', 'new3_SH', 4, 5, 'new6_0', 'new6_1', 7, 8, 9, 10, 'new11_0', 'new11_1', 12, 13,
                       14,
                       15, 16, 17, 18, 19, 'new21_0', 'new21_1', 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
                       35, 36, 37, 38, 39, 40, 41]
        revised_pd = revised_pd.reindex(columns=new_columns)
        revised_pd.loc[revised_pd[41] != 'normal.', 41] = 1.0
        revised_pd.loc[revised_pd[41] == 'normal.', 41] = 0.0
        kdd_anomaly = np.array(revised_pd.loc[revised_pd[41] == 1.0], dtype=np.double)
        kdd_normal = np.array(revised_pd.loc[revised_pd[41] == 0.0], dtype=np.double)
        kdd_normal = torch.tensor(kdd_normal)
        kdd_anomaly = torch.tensor(kdd_anomaly)
        kdd_anomaly = kdd_anomaly[random.sample(range(kdd_anomaly.shape[0]), int(kdd_normal.shape[0] / 4)), :]
        kdd_normal = kdd_normal[ torch.randperm(kdd_normal.shape[0])]
        train, test_norm = torch.split(kdd_normal, int(kdd_normal.shape[0] / 2) + 1)
        test = torch.cat((test_norm, kdd_anomaly))
        test = test[torch.randperm(test.shape[0])]
        test_classes = test[:, -1].view(-1, 1)
        train = train[:, 0:train.shape[1] - 1]
        test = test[:, 0:test.shape[1] - 1]
        return (train, test, test_classes)

    def build_train_test_thyroid(self, root):
        with open(root + '/thyroid_train.npy', 'rb') as f:
            train_data = np.load(f)

        with open(root + '/thyroid_test.npy', 'rb') as f:
            test_data = np.load(f)

        # load data
        train_data_x, train_data_y = train_data[:, :-1], train_data[:, -1].astype(int)
        test_data_x, test_data_y = test_data[:, :-1], test_data[:, -1].astype(int)

        # normalize
        mu = np.mean(train_data_x, axis=0)
        std = np.std(train_data_x, axis=0)
        train_data_x = (train_data_x - mu) / (std + 1e-5)

        # mu = np.mean(test_data_x, axis=0)
        # std = np.std(test_data_x, axis=0)
        test_data_x = (test_data_x - mu) / (std + 1e-5)

        kdd_normal = torch.tensor(train_data_x[train_data_y == 0])
        kdd_normal = kdd_normal[
            torch.randperm(kdd_normal.shape[0])]

        test_data_x = torch.tensor(test_data_x)
        test_data_y = torch.tensor(test_data_y)
        return kdd_normal, test_data_x, test_data_y


