import copy

import numpy as np
import torch
from torchvision import transforms
from torch.utils.data import DataLoader

from .CustomDataset import CustomDataset
from .TruncatedDataset import TruncatedDataset


def process_mat_dataset(data):
    """
    Split a labelled array into a normal-only training set and a mixed test set.

    The last column of data is read as a binary label (0 = normal, 1 = anomaly).
    Normal rows are shuffled; the first ``n_normals // 2 + 1`` of them form the
    training set, and the remaining normal rows are concatenated with every
    anomaly row and shuffled again to form the test set.

    :param data: 2-D array whose last column is the binary label
    :return: (train, test, test_classes) - train and test features with the
        label column removed, plus the label column of the test rows
    """
    label_column = data[:, -1]
    normal_rows = data[label_column == 0]
    anomaly_rows = data[label_column == 1]

    normal_count = normal_rows.shape[0]
    split_at = normal_count // 2 + 1
    shuffled_normals = normal_rows[np.random.permutation(normal_count)]

    # Held-out normals and all anomalies are mixed so the test order is random.
    held_out = np.concatenate([shuffled_normals[split_at:], anomaly_rows])
    held_out = held_out[np.random.permutation(held_out.shape[0])]

    return shuffled_normals[:split_at, :-1], held_out[:, :-1], held_out[:, -1]


def get_target_label_idx(labels, targets):
    """
    Find the positions in labels whose value appears in targets.
    :param labels: array of labels
    :param targets: list/tuple of target labels
    :return: list with indices of target labels
    """
    is_target = np.isin(labels, targets)
    positions = np.argwhere(is_target)
    return positions.reshape(-1).tolist()


def convert_multiclass(train_data, normal_class, test_data=None):
    """
    Turn multiclass labels into binary anomaly labels.

    The last column of each array is cast to int; rows whose label equals
    normal_class get label 0 and every other row gets label 1. The inputs are
    not modified: the relabelling is done on deep copies.

    :param train_data: 2-D array whose last column is the class label, or None
    :param normal_class: label value to be treated as the normal class
    :param test_data: optional 2-D array laid out like train_data
    :return: (train_data, test_data) relabelled copies; an input that was None
        is returned as None
    """
    def _binarize(data):
        # Relabel a copy of one array; None is passed through unchanged.
        if data is None:
            return None
        relabelled = copy.deepcopy(data)
        class_ids = relabelled[:, -1].astype(int)
        # Vectorised comparison: no per-row Python call, and unlike
        # np.apply_along_axis it also accepts arrays with zero rows.
        relabelled[:, -1] = np.where(class_ids == normal_class, 0, 1)
        return relabelled

    return _binarize(train_data), _binarize(test_data)


def load_dataset(root, name):
    """
    Load a dataset stored as .npy file(s) under ``root/name/``.

    Single-file datasets are read from ``<name>.npy``; datasets that ship with
    their own test split are read from ``<name>_train.npy`` and
    ``<name>_test.npy``.

    :param root: directory (string) holding one sub-directory per dataset
    :param name: dataset name, one of the names listed in this function
    :return: (train_data, test_data, classes); test_data is None for
        single-file datasets, classes lists the distinct integer labels found
        in the last column of the (training) array
    :raises NotImplementedError: if name is not a known dataset
    """
    single_file_names = (
        'arrhythmia', 'wine', 'lympho', 'glass', 'vertebral', 'wbc', 'ecoli', 'ionosphere', 'breastw', 'pima',
        'vowels', 'letter', 'cardio', 'seismic', 'musk', 'speech', 'abalone', 'pendigits', 'mammography',
        'mulcross', 'forest_cover',
    )
    # datasets that come with a dedicated test file
    split_file_names = ('thyroid', 'optdigits', 'satimage', 'shuttle', 'kdd')

    def read_array(path):
        with open(path, 'rb') as handle:
            return np.load(handle)

    def distinct_labels(array):
        return list(set(array[:, -1].astype(int)))

    if name in single_file_names:
        data = read_array(root + '/' + name + '/' + name + '.npy')
        return data, None, distinct_labels(data)

    if name in split_file_names:
        prefix = root + '/' + name + '/' + name
        train_data = read_array(prefix + '_train.npy')
        test_data = read_array(prefix + '_test.npy')
        return train_data, test_data, distinct_labels(train_data)

    raise NotImplementedError


def process_dataset(train_data, test_data, classes, normal_class, b_size, normalize=False):
    """
    Build train/test data loaders from arrays whose last column is the label.

    For a multiclass dataset (more than two entries in classes) the labels are
    first binarised around normal_class; if normal_class has 20 or fewer
    training rows, nothing is built and six Nones are returned. When no
    separate test array is given, the single array is split by
    process_mat_dataset; otherwise training keeps only the rows labelled 0.

    :param train_data: 2-D array, last column is the label
    :param test_data: 2-D array laid out like train_data, or None
    :param classes: list of distinct labels of the dataset
    :param normal_class: label treated as normal for multiclass datasets
    :param b_size: batch size used for both loaders
    :param normalize: if True, standardise features with the training mean/std
    :return: (train_loader, test_loader, mu, std, n_dim, n_sample), or a tuple
        of six Nones when the normal class is too small
    """
    is_multiclass = len(classes) > 2
    if is_multiclass:
        raw_labels = train_data[:, -1].astype(int)
        if sum(raw_labels == normal_class) <= 20:
            # not enough normal rows to train on
            return (None,) * 6
        train_data, test_data = convert_multiclass(train_data, normal_class, test_data)

    if test_data is None:
        # no separate test array: carve train and test out of the single array
        train_x, test_x, test_y = process_mat_dataset(train_data)
        train_y = np.zeros(train_x.shape[0])
    else:
        # a separate test array exists: train on the rows labelled 0 only
        normal_rows = train_data[train_data[:, -1] == 0]
        train_x = normal_rows[:, :-1]
        train_y = normal_rows[:, -1].astype(int)
        test_x = test_data[:, :-1]
        test_y = test_data[:, -1].astype(int)

    # statistics come from the un-normalised training features
    mu = np.mean(train_x, axis=0)
    std = np.std(train_x, axis=0)
    n_sample, n_dim = train_x.shape

    if normalize:
        scale = std + 1e-5  # small offset avoids division by zero
        train_x = (train_x - mu) / scale
        test_x = (test_x - mu) / scale

    to_tensor = transforms.Lambda(lambda x: torch.Tensor(x))
    train_dataset = CustomDataset(train_x, train_y, transform=to_tensor, target_transform=None)
    test_dataset = CustomDataset(test_x, test_y, transform=to_tensor, target_transform=None)

    train_loader = DataLoader(dataset=train_dataset, batch_size=b_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(dataset=test_dataset, batch_size=b_size, shuffle=False, num_workers=0)
    return train_loader, test_loader, mu, std, n_dim, n_sample


def load_mat_dataset(root, name, b_size, normalize=False):
    """
    Load ``root/<name>.npy``, split it with process_mat_dataset and wrap the
    result in data loaders.

    :param root: directory (string) containing the .npy file
    :param name: 'arrhythmia', 'thyroid_train_only' or 'thyroid_train_test'
    :param b_size: batch size used for both loaders
    :param normalize: if True, standardise features with the training mean/std
    :return: (train_loader, test_loader, mu, std)
    :raises NotImplementedError: if name is not one of the supported names
    """
    if name not in ('arrhythmia', 'thyroid_train_only', 'thyroid_train_test'):
        raise NotImplementedError

    with open(root + '/' + name + '.npy', 'rb') as handle:
        raw = np.load(handle)

    train_x, test_x, test_y = process_mat_dataset(raw)
    # every training row is a normal sample, hence all-zero labels
    train_y = np.zeros(train_x.shape[0])

    # statistics come from the un-normalised training features
    mu = np.mean(train_x, axis=0)
    std = np.std(train_x, axis=0)

    if normalize:
        scale = std + 1e-5  # small offset avoids division by zero
        train_x = (train_x - mu) / scale
        test_x = (test_x - mu) / scale

    to_tensor = transforms.Lambda(lambda x: torch.Tensor(x))
    train_dataset = CustomDataset(train_x, train_y, transform=to_tensor, target_transform=None)
    test_dataset = CustomDataset(test_x, test_y, transform=to_tensor, target_transform=None)

    train_loader = DataLoader(dataset=train_dataset, batch_size=b_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(dataset=test_dataset, batch_size=b_size, shuffle=False, num_workers=0)
    return train_loader, test_loader, mu, std


def load_thyroid_dataset(root, b_size, normalize=False):
    """
    Load ``thyroid_train.npy`` / ``thyroid_test.npy`` from root and wrap them
    in data loaders.

    The last column of each array is used as the integer label. No rows are
    filtered by label here: the training loader sees the full training array.

    :param root: directory (string) containing the two .npy files
    :param b_size: batch size used for both loaders
    :param normalize: if True, standardise features with the training mean/std
    :return: (train_loader, test_loader, mu, std)
    """
    with open(root + '/thyroid_train.npy', 'rb') as handle:
        train_raw = np.load(handle)
    with open(root + '/thyroid_test.npy', 'rb') as handle:
        test_raw = np.load(handle)

    # split features from the trailing label column
    train_x = train_raw[:, :-1]
    train_y = train_raw[:, -1].astype(int)
    test_x = test_raw[:, :-1]
    test_y = test_raw[:, -1].astype(int)

    # statistics come from the un-normalised training features
    mu = np.mean(train_x, axis=0)
    std = np.std(train_x, axis=0)

    if normalize:
        scale = std + 1e-5  # small offset avoids division by zero
        train_x = (train_x - mu) / scale
        test_x = (test_x - mu) / scale

    to_tensor = transforms.Lambda(lambda x: torch.Tensor(x))
    train_dataset = CustomDataset(train_x, train_y, transform=to_tensor, target_transform=None)
    test_dataset = CustomDataset(test_x, test_y, transform=to_tensor, target_transform=None)

    train_loader = DataLoader(dataset=train_dataset, batch_size=b_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(dataset=test_dataset, batch_size=b_size, shuffle=False, num_workers=0)
    return train_loader, test_loader, mu, std


def load_kdd_cup_dataset(root, normal_class, b_size, normalize=True):
    """
    Load ``kdd-cup99_train.npy`` / ``kdd-cup99_test.npy`` from root and build
    loaders for one-class training on normal_class.

    Labels are mapped to 0 for normal_class and 1 for any other class in
    0..39. The training loader is restricted to the rows whose original label
    equals normal_class; the test loader covers the full test array.

    :param root: directory (string) containing the two .npy files
    :param normal_class: class id treated as normal; must be in 0..39,
        otherwise list.remove raises ValueError
    :param b_size: batch size used for both loaders
    :param normalize: if True, standardise features with the training mean/std
    :return: (train_loader, test_loader, mu, std)
    """
    with open(root + '/kdd-cup99_train.npy', 'rb') as handle:
        train_raw = np.load(handle)
    with open(root + '/kdd-cup99_test.npy', 'rb') as handle:
        test_raw = np.load(handle)

    # split features from the trailing label column
    train_x = train_raw[:, :-1]
    train_y = train_raw[:, -1].astype(int)
    test_x = test_raw[:, :-1]
    test_y = test_raw[:, -1].astype(int)

    # statistics come from the un-normalised, unfiltered training features
    mu = np.mean(train_x, axis=0)
    std = np.std(train_x, axis=0)

    if normalize:
        scale = std + 1e-5  # small offset avoids division by zero
        train_x = (train_x - mu) / scale
        test_x = (test_x - mu) / scale

    to_tensor = transforms.Lambda(lambda x: torch.Tensor(x))

    # every class id in 0..39 except the normal one counts as an outlier
    outlier_classes = list(range(0, 40))
    outlier_classes.remove(normal_class)
    normal_train_idx = get_target_label_idx(train_y, [normal_class])

    to_binary_label = transforms.Lambda(lambda x: int(x in outlier_classes))

    full_train_dataset = CustomDataset(train_x, train_y, transform=to_tensor,
                                       target_transform=to_binary_label)
    # NOTE(review): the same transforms are handed to both the wrapped dataset
    # and the TruncatedDataset wrapper; whether they end up applied twice
    # depends on TruncatedDataset - confirm against its implementation.
    normal_train_dataset = TruncatedDataset(full_train_dataset, dataidxs=normal_train_idx,
                                            transform=to_tensor, target_transform=to_binary_label)

    test_dataset = CustomDataset(test_x, test_y, transform=to_tensor,
                                 target_transform=to_binary_label)

    train_loader = DataLoader(dataset=normal_train_dataset, batch_size=b_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(dataset=test_dataset, batch_size=b_size, shuffle=False, num_workers=0)
    return train_loader, test_loader, mu, std
