import copy

import numpy as np


def process_mat_dataset(data, n_noise=0.05):
    """
    Split a labelled array into train / validation / test feature sets.

    The last column of ``data`` must hold the binary label (0 = normal,
    1 = anomaly). The first ``n_normals // 2 + 1`` normal rows, followed by
    ``n_noise`` randomly chosen anomaly rows, form the training pool; the
    first ``int(0.3 * pool_size) + 1`` rows of that pool become the validation
    set and the rest the training set. The remaining normals and the
    remaining anomalies form the shuffled test set.

    :param data: 2-D array whose last column is the binary label
    :param n_noise: number of anomalies to mix into the training pool; a value
        below 1 is read as a fraction of the total number of normal rows
    :return: tuple ``(train, val, test, test_classes)`` where the first three
        are feature arrays (label column removed) and ``test_classes`` is the
        label column of the shuffled test set
    """
    labels = data[:, -1]
    normals = data[labels == 0]
    anomalies = data[labels == 1]

    n_normals = normals.shape[0]
    # Normals keep their input order on purpose: per the original note, their
    # permutation is fixed outside this function so that different methods
    # see the same split.

    # Shuffle the anomalies so the contaminating "noise" rows are random.
    anomalies = anomalies[np.random.permutation(anomalies.shape[0])]
    if n_noise < 1:
        n_noise = n_normals * n_noise
    # Always coerce to int: a count passed as a float (e.g. 2.0) would
    # otherwise be used as a slice bound and raise TypeError.
    n_noise = int(n_noise)
    noise = anomalies[:n_noise]

    # TODO: add true noise

    n_train_normals = n_normals // 2 + 1
    pool = np.concatenate([normals[:n_train_normals], noise])

    # The validation rows come from the front of the pool, i.e. from the
    # normal rows, because the noise rows were appended at the end.
    n_val = int(pool.shape[0] * 0.3) + 1
    val = pool[:n_val]
    train = pool[n_val:]

    test = np.concatenate([normals[n_train_normals:], anomalies[n_noise:]])
    test = test[np.random.permutation(test.shape[0])]
    test_classes = test[:, -1]

    # Strip the label column from every feature set.
    return train[:, :-1], val[:, :-1], test[:, :-1], test_classes


def get_target_label_idx(labels, targets):
    """
    Find the positions of every label that belongs to ``targets``.
    :param labels: array of labels
    :param targets: list/tuple of target labels
    :return: list with indices of target labels
    """
    # Boolean mask marking the entries of `labels` that occur in `targets`.
    is_target = np.isin(labels, targets)
    # argwhere yields one row of coordinates per hit; flatten them to a list.
    hit_positions = np.argwhere(is_target)
    return hit_positions.ravel().tolist()


def convert_multiclass(train_data, normal_class, test_data=None):
    """
    Relabel multiclass data as binary: normal (0) versus anomaly (1).

    The last column of each array is read as the class label (cast to int).
    Rows whose label equals ``normal_class`` get label 0, all others label 1.
    The inputs are deep-copied first, so the caller's arrays are not modified.

    :param train_data: 2-D array whose last column is the class label, or None
    :param normal_class: the label value to treat as the normal class
    :param test_data: optional second array handled the same way, or None
    :return: tuple ``(train_data, test_data)`` of relabelled copies; an entry
        is None when the matching input was None
    """
    def _binarize(data):
        # Return a relabelled copy of `data`, or None when there is no data.
        if data is None:
            return None
        relabelled = copy.deepcopy(data)
        y = relabelled[:, -1].astype(int)
        # Vectorised comparison instead of a per-row Python callback; it also
        # works on arrays with zero rows, where apply_along_axis raises.
        relabelled[:, -1] = y != normal_class
        return relabelled

    return _binarize(train_data), _binarize(test_data)


def load_dataset(root, name):
    """
    Load a named dataset stored as ``.npy`` file(s) under ``root``.

    Two layouts are supported. Single-file datasets are read from
    ``<root>/<name>/<name>.npy``. Datasets that ship with their own test set
    are read from ``<root>/<name>/<name>_train.npy`` and
    ``<root>/<name>/<name>_test.npy``. In both cases the last column is read
    as the class label.

    :param root: directory containing one sub-directory per dataset
        (a string or any path-like object)
    :param name: dataset name; must be one of the names listed below
    :return: tuple ``(train_data, test_data, classes)``; ``test_data`` is None
        for single-file datasets and ``classes`` lists the distinct integer
        labels found in the (training) data, in no guaranteed order
    :raises NotImplementedError: if ``name`` is not a known dataset
    """
    single_file_names = [
        'arrhythmia', 'wine', 'lympho', 'glass', 'vertebral', 'wbc', 'ecoli', 'ionosphere', 'breastw', 'pima',
        'vowels', 'letter', 'cardio', 'seismic', 'musk', 'speech', 'abalone', 'pendigits', 'mammography',
        'mulcross', 'forest_cover',
    ]
    # Datasets that ship with a separate test set.
    pre_split_names = ['thyroid', 'optdigits', 'satimage', 'shuttle', 'kdd-small']

    def _load_npy(file_name):
        # Read one .npy array, closing the file handle straight away.
        with open(file_name, 'rb') as f:
            return np.load(f)

    if name in single_file_names:
        data = _load_npy(f'{root}/{name}/{name}.npy')
        classes = list(set(data[:, -1].astype(int)))
        return data, None, classes

    if name in pre_split_names:
        train_data = _load_npy(f'{root}/{name}/{name}_train.npy')
        test_data = _load_npy(f'{root}/{name}/{name}_test.npy')
        # Classes are taken from the training labels only.
        classes = list(set(train_data[:, -1].astype(int)))
        return train_data, test_data, classes

    raise NotImplementedError(f'Unknown dataset name: {name!r}')


def load_adbench_dataset(root, name):
    """
    Load a named dataset from ``<root>/Classical/<name>.npz``.

    The archive must contain the arrays ``X`` (features) and ``y`` (labels).
    They are combined into one 2-D array with the label as the last column.

    :param root: directory that contains the ``Classical`` folder
        (a string or any path-like object)
    :param name: dataset name; must be one of the names listed below
    :return: tuple ``(data, None, classes)`` where ``data`` is ``X`` with ``y``
        appended as the last column and ``classes`` lists the distinct values
        of ``y``, in no guaranteed order; the middle entry is always None
    :raises NotImplementedError: if ``name`` is not a known dataset
    """
    known_names = [
        '48_arrhythmia', '49_shuttle',
        '10_cover', '11_donors', '12_fault', '13_fraud', '14_glass', '15_Hepatitis', '16_http',
        '17_InternetAds', '18_Ionosphere', '19_landsat',
        '1_ALOI', '20_letter', '21_Lymphography', '22_magic.gamma', '23_mammography', '24_mnist',
        '25_musk', '26_optdigits', '27_PageBlocks', '28_pendigits', '29_Pima', '2_annthyroid',
        '30_satellite', '31_satimage-2', '32_shuttle', '33_skin', '34_smtp', '35_SpamBase', '36_speech',
        '37_Stamps', '38_thyroid', '39_vertebral', '3_backdoor', '40_vowels', '41_Waveform', '42_WBC',
        '43_WDBC', '44_Wilt', '45_wine', '46_WPBC', '47_yeast', '4_breastw', '5_campaign', '6_cardio',
        '7_Cardiotocography', '8_celeba', '9_census',
    ]
    if name not in known_names:
        raise NotImplementedError(f'Unknown dataset name: {name!r}')

    file_name = f'{root}/Classical/{name}.npz'
    # NOTE(security): allow_pickle=True lets np.load unpickle object arrays,
    # which can run arbitrary code. It is kept because existing archives may
    # rely on it (TODO confirm) -- only load files from a trusted source.
    # The context manager closes the archive's file handle once both arrays
    # have been read; a bare np.load() would leave it open.
    with np.load(file_name, allow_pickle=True) as archive:
        X, y = archive['X'], archive['y']

    data = np.hstack([X, y.reshape(-1, 1)])
    classes = list(set(y))
    return data, None, classes


def process_dataset(train_data, test_data, classes, normal_class, n_noise, normalize=False):
    """
    Turn a loaded dataset into train / validation / test feature arrays.

    For multiclass data (more than two entries in ``classes``) the labels are
    first binarised with ``convert_multiclass`` (``normal_class`` -> 0,
    everything else -> 1). When ``test_data`` is None the whole split is
    delegated to ``process_mat_dataset``. Otherwise the training pool is all
    normal training rows followed by ``n_noise`` randomly chosen anomalous
    training rows, and the first ``int(0.3 * pool_size) + 1`` rows of that
    pool are held out as the validation set, as ``process_mat_dataset`` does.

    :param train_data: 2-D array whose last column is the label
    :param test_data: separate test array in the same layout, or None
    :param classes: list of the distinct labels in the dataset
    :param normal_class: label treated as normal for multiclass data
    :param n_noise: number of anomalies mixed into the training pool; a value
        below 1 is read as a fraction of the number of normal rows
    :param normalize: if True, standardise every feature array with the mean
        and standard deviation of the training features
    :return: tuple ``(train_x, val_x, test_x, test_y, mu, std)`` where
        ``test_y`` has shape ``(n, 1)`` and ``mu`` / ``std`` are the
        per-feature statistics of the training features. NOTE: for multiclass
        data with 20 or fewer normal samples the 3-tuple
        ``(None, None, None)`` is returned instead; that shape is kept
        unchanged for existing callers.
    """
    if len(classes) > 2:  # multiclass dataset
        y = train_data[:, -1].astype(int)
        n_normal_sample = (y == normal_class).sum()
        if n_normal_sample <= 20:
            # Too few normal samples for this class to be usable.
            return None, None, None
        train_data, test_data = convert_multiclass(train_data, normal_class, test_data)

    if test_data is None:
        train_data_x, val_data_x, test_data_x, test_data_y = process_mat_dataset(train_data, n_noise)
    else:  # dataset with test set
        anomalies = train_data[train_data[:, -1] == 1]
        normals = train_data[train_data[:, -1] == 0]

        # Shuffle the anomalies so the contaminating "noise" rows are random.
        anomalies = anomalies[np.random.permutation(anomalies.shape[0])]
        if n_noise < 1:
            n_noise = normals.shape[0] * n_noise
        # Coerce to int so a count passed as a float is a valid slice bound.
        noise = anomalies[:int(n_noise)]

        # TODO: add true noise

        pool = np.concatenate([normals, noise])

        # Hold out the leading rows as the validation set. This branch used to
        # leave val_data_x unassigned, so the function raised
        # UnboundLocalError for every dataset with its own test set.
        # NOTE(review): the 30% hold-out mirrors process_mat_dataset; confirm
        # that this is the intended validation policy for these datasets.
        n_val = int(pool.shape[0] * 0.3) + 1
        val_data_x = pool[:n_val, :-1]
        train_data_x = pool[n_val:, :-1]
        test_data_x, test_data_y = test_data[:, :-1], test_data[:, -1].astype(int)

    # Statistics come from the training features only, so the validation and
    # test sets are scaled with the same parameters.
    mu = np.mean(train_data_x, axis=0)
    std = np.std(train_data_x, axis=0)

    # normalize
    if normalize:
        scale = std + 1e-5  # small offset avoids division by zero
        train_data_x = (train_data_x - mu) / scale
        test_data_x = (test_data_x - mu) / scale
        val_data_x = (val_data_x - mu) / scale

    test_data_y = test_data_y.reshape(-1, 1)
    return train_data_x, val_data_x, test_data_x, test_data_y, mu, std