import copy

import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

from .CustomDataset import CustomDataset
from .TruncatedDataset import TruncatedDataset


def get_cifar10_dataset(root='./data'):
    """
    Build the CIFAR-10 train and test datasets.

    Both datasets are requested with ``download=True`` and share a single
    transform that passes each sample to ``torch.Tensor``. No augmentation
    or normalization is configured here.
    :param root: directory handed to torchvision as the dataset root
    :return: tuple (trainset, testset) of ``torchvision.datasets.CIFAR10``
    """
    to_tensor = transforms.Lambda(lambda x: torch.Tensor(x))

    trainset = torchvision.datasets.CIFAR10(
        root=root, train=True, download=True, transform=to_tensor)
    testset = torchvision.datasets.CIFAR10(
        root=root, train=False, download=True, transform=to_tensor)
    return trainset, testset


def get_fmnist_dataloader(root='./data'):
    """
    Create the Fashion-MNIST train and test datasets.

    Despite its name, this returns dataset objects, not data loaders. Each
    split gets its own ``ToTensor`` pipeline and is requested with
    ``download=True``.
    :param root: directory handed to torchvision as the dataset root
    :return: tuple (trainset, testset)
    """
    splits = []
    for is_train in (True, False):
        pipeline = transforms.Compose([transforms.ToTensor()])
        splits.append(torchvision.datasets.FashionMNIST(
            root=root, train=is_train, download=True, transform=pipeline))
    trainset, testset = splits
    return trainset, testset



def process_mat_dataset(data):
    """
    Split a labelled array into a normal-only train set and a mixed test set.

    The last column of ``data`` is the binary label (0 = normal, 1 = anomaly).
    Normal rows are shuffled with the global ``np.random`` state; the first
    ``n_normals // 2 + 1`` of them form the train set. The remaining normal
    rows plus all anomalies are shuffled together to form the test set.
    :param data: 2-D array whose last column is the binary label
    :return: tuple (train, test, test_classes) where train and test have the
        label column removed and test_classes holds the test labels
    """
    labels = data[:, -1]
    normal_rows = data[labels == 0]
    anomaly_rows = data[labels == 1]

    normal_count = normal_rows.shape[0]
    shuffled_normals = normal_rows[np.random.permutation(normal_count)]
    cut = normal_count // 2 + 1

    held_out = np.concatenate([shuffled_normals[cut:], anomaly_rows])
    held_out = held_out[np.random.permutation(held_out.shape[0])]

    train_features = shuffled_normals[:cut][:, :-1]
    return train_features, held_out[:, :-1], held_out[:, -1]


def get_target_label_idx(labels, targets):
    """
    Find the positions of the labels that belong to targets.
    :param labels: array of labels
    :param targets: list/tuple of target labels
    :return: list with indices of target labels
    """
    is_target = np.isin(labels, targets)
    positions = np.argwhere(is_target)
    return positions.flatten().tolist()


def convert_multiclass(train_data, normal_class, test_data=None):
    """
    Collapse multiclass labels into binary anomaly labels.

    The last column of each array is read as the class label (cast to int).
    Rows whose label equals ``normal_class`` get 0, every other row gets 1.
    Each input is deep-copied before relabelling, so the caller's arrays are
    left untouched.
    :param train_data: 2-D array whose last column holds the class label,
        or None
    :param normal_class: label value to treat as the normal class
    :param test_data: optional array with the same layout, or None
    :return: tuple (train, test) of relabelled copies; an entry is None when
        the corresponding input is None
    """
    converted = []
    for source in (train_data, test_data):
        if source is None:
            converted.append(None)
            continue
        data = copy.deepcopy(source)
        labels = data[:, -1].astype(int)
        # Vectorized relabelling: 0 for the normal class, 1 for everything else.
        data[:, -1] = np.where(labels == normal_class, 0, 1)
        converted.append(data)
    return converted[0], converted[1]


def load_adbench_dataset(root, name):
    """
    Load one ADBench ``.npz`` archive into a single labelled array.

    Known tabular names are read from ``<root>/Classical/<name>.npz`` and the
    ``CIFAR10_0`` .. ``CIFAR10_9`` names from ``<root>/CV_by_ViT/<name>.npz``.
    The archive's ``X`` and ``y`` entries are stacked so that ``y`` becomes the
    last column.
    :param root: directory containing the ``Classical`` / ``CV_by_ViT`` folders
    :param name: dataset name without the ``.npz`` extension
    :return: tuple (data, None, classes) where classes lists the distinct
        values found in ``y``; the middle entry is always None
    :raises NotImplementedError: if ``name`` is not a known dataset
    """
    classical_names = (
        '10_cover', '11_donors', '12_fault', '13_fraud', '14_glass', '15_Hepatitis', '16_http',
        '17_InternetAds', '18_Ionosphere', '19_landsat',
        '1_ALOI', '20_letter', '21_Lymphography', '22_magic.gamma', '23_mammography', '24_mnist',
        '25_musk', '26_optdigits', '27_PageBlocks', '28_pendigits', '29_Pima', '2_annthyroid',
        '30_satellite', '31_satimage-2', '32_shuttle', '33_skin', '34_smtp', '35_SpamBase', '36_speech',
        '37_Stamps', '38_thyroid', '39_vertebral', '3_backdoor', '40_vowels', '41_Waveform', '42_WBC',
        '43_WDBC', '44_Wilt', '45_wine', '46_WPBC', '47_yeast', '4_breastw', '5_campaign', '6_cardio',
        '7_Cardiotocography', '8_celeba', '9_census', '48_arrhythmia',
    )
    cv_names = (
        'CIFAR10_0', 'CIFAR10_1', 'CIFAR10_2', 'CIFAR10_3', 'CIFAR10_4',
        'CIFAR10_5', 'CIFAR10_6', 'CIFAR10_7', 'CIFAR10_8', 'CIFAR10_9',
    )

    if name in classical_names:
        subdir = 'Classical'
    elif name in cv_names:
        subdir = 'CV_by_ViT'
    else:
        raise NotImplementedError(f'unknown ADBench dataset: {name!r}')

    file_name = f'{root}/{subdir}/{name}.npz'
    # NOTE(review): allow_pickle=True can execute arbitrary code while
    # unpickling; only point this at archives from a trusted source.
    # The context manager closes the archive's file handle once the arrays
    # have been read into memory.
    with np.load(file_name, allow_pickle=True) as archive:
        X, y = archive['X'], archive['y']

    data = np.hstack([X, y.reshape(-1, 1)])
    classes = list(set(y))
    return data, None, classes



def load_img_dataset(root, name):
    """
    Fetch an image dataset by name together with its class indices.
    :param root: dataset root directory forwarded to the dataset builder
    :param name: either 'cifar10' or 'fmnist'
    :return: tuple (train_set, test_set, classes) where classes is
        ``[0, 1, ..., n_classes - 1]`` derived from ``train_set.classes``
    :raises NotImplementedError: for any other name
    """
    if name == 'cifar10':
        build = get_cifar10_dataset
    elif name == 'fmnist':
        build = get_fmnist_dataloader
    else:
        raise NotImplementedError

    train_set, test_set = build(root)
    classes = list(range(len(train_set.classes)))
    return train_set, test_set, classes


def process_img_dataset(train_data, test_data, normal_class, b_size, normalize=True):
    """
    Build one-class train/test loaders from a pair of image datasets.

    The train loader only contains samples whose target equals
    ``normal_class``; the test loader contains every test sample, relabelled
    0 for ``normal_class`` and 1 otherwise.

    Side effects: ``train_data.data``, ``test_data.data`` and
    ``test_data.targets`` are overwritten on the objects passed in, and the
    number of normal training samples is printed.
    :param train_data: dataset object exposing ``.data`` and ``.targets``
    :param test_data: dataset object exposing ``.data`` and ``.targets``
    :param normal_class: target value treated as the normal class
    :param b_size: batch size for both loaders
    :param normalize: when True, standardize with the mean/std computed over
        the normal training samples (axis 0)
    :return: tuple (train_loader, test_loader, n_samples, img_size, n_channel)
    """
    # Positions of the training samples that belong to the normal class.
    train_data_normal_idx = np.where(np.array(train_data.targets) == normal_class)[0]
    # NOTE(review): this sampler is built but never passed to a loader
    # (the `sampler=` argument below is commented out).
    subset_sampler = torch.utils.data.SubsetRandomSampler(train_data_normal_idx)

    shape = train_data.data.shape
    if len(shape) == 3:
        # No channel axis: append one as axis 3. `.unsqueeze` is called on
        # `.data`, so it is presumably a torch tensor here - TODO confirm.
        n_channel = 1
        img_size = shape[1]
        train_data.data = train_data.data.unsqueeze(3)
        test_data.data = test_data.data.unsqueeze(3)
    else:
        # Channel count is read from axis 3, image size from axis 1.
        img_size = shape[1]
        n_channel = shape[3]

    # Statistics come from the normal training samples only.
    train_data_x = np.array(train_data.data[train_data_normal_idx])
    mu = np.mean(train_data_x, axis=0)
    std = np.std(train_data_x, axis=0)

    test_data_x = np.array(test_data.data)

    # Standardize both splits with the training statistics; the 1e-5 term
    # keeps the division finite where std is zero.
    if normalize:
        train_data_x = (train_data_x - mu) / (std + 1e-5)
        test_data_x = (test_data_x - mu) / (std + 1e-5)

    # Move axis 3 to position 1 (presumably channels-last -> channels-first).
    train_data_x = train_data_x.transpose((0, 3, 1, 2))
    train_data_y = np.array(train_data.targets)[train_data_normal_idx]

    test_data_x = test_data_x.transpose((0, 3, 1, 2))

    # Write the processed arrays back onto the caller's dataset objects.
    # train_data.data now holds only the normal-class subset.
    train_data.data = train_data_x
    test_data.data = test_data_x

    train_transform = transforms.Lambda(lambda x: torch.Tensor(x))
    train_data = CustomDataset(train_data_x, train_data_y, transform=train_transform,
                                target_transform=None)

    train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=b_size,
        shuffle=True,
        num_workers=0,
        # sampler=subset_sampler  (disabled; the data is already filtered above)
    )
    n_samples = len(train_data_normal_idx)
    print(n_samples)

    # Binary test labels: start from all ones, then zero the normal class.
    test_data_normal_idx = np.where(np.array(test_data.targets) == normal_class)
    converted_test_classes = np.ones_like(test_data.targets)
    converted_test_classes[test_data_normal_idx] = 0
    test_data.targets = converted_test_classes
    test_data_y = converted_test_classes

    test_data = CustomDataset(test_data_x, test_data_y, transform=train_transform,
                               target_transform=None)
    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=b_size,
        shuffle=False,
        num_workers=0)
    return train_loader, test_loader, n_samples, img_size, n_channel


def load_cifar10_ae_dataset(root, cls, net_name, normalize, b_size):
    """
    Load saved autoencoder features for one CIFAR-10 class into loaders.

    Reads ``<root>/cifar10-<cls>-<net_name>-128-AEfeature-100-0.pth.tar.pth.tar``
    with ``torch.load`` and uses its "train_features", "test_features" and
    "test_labels" entries. Every training sample is labelled 0.
    :param root: directory containing the saved feature file
    :param cls: class identifier inserted into the file name
    :param net_name: network name inserted into the file name
    :param normalize: when True, standardize both splits with the training
        mean/std
    :param b_size: batch size for both loaders
    :return: tuple (train_loader, test_loader, mu, std, n_dim, n_sample)
    """
    checkpoint_path = root + '/' + f'cifar10-{cls}-{net_name}-128-AEfeature-100-0.pth.tar.pth.tar'
    checkpoint = torch.load(checkpoint_path)

    train_x = checkpoint["train_features"].cpu().numpy()
    train_y = np.zeros(train_x.shape[0])
    test_x = checkpoint["test_features"].cpu().numpy()
    test_y = checkpoint["test_labels"].cpu().numpy()

    # Statistics are taken before any scaling and returned to the caller.
    mu = np.mean(train_x, axis=0)
    std = np.std(train_x, axis=0)
    n_sample, n_dim = train_x.shape

    if normalize:
        scale = std + 1e-5
        train_x = (train_x - mu) / scale
        test_x = (test_x - mu) / scale

    to_tensor = transforms.Lambda(lambda x: torch.Tensor(x))
    train_loader = DataLoader(
        dataset=CustomDataset(train_x, train_y, transform=to_tensor, target_transform=None),
        batch_size=b_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(
        dataset=CustomDataset(test_x, test_y, transform=to_tensor, target_transform=None),
        batch_size=b_size, shuffle=False, num_workers=0)

    return train_loader, test_loader, mu, std, n_dim, n_sample


def load_dataset(root, name):
    """
    Load a tabular ``.npy`` dataset stored under ``<root>/<name>/``.

    Single-file datasets are read from ``<name>.npy``; datasets that ship
    with their own test split are read from ``<name>_train.npy`` and
    ``<name>_test.npy``. The last column is treated as the integer label.
    :param root: directory containing one sub-folder per dataset
    :param name: dataset name
    :return: tuple (train_data, test_data, classes); test_data is None for
        single-file datasets and classes lists the distinct training labels
    :raises NotImplementedError: if ``name`` is not a known dataset
    """
    single_file_names = [
        'arrhythmia', 'wine', 'lympho', 'glass', 'vertebral', 'wbc', 'ecoli', 'ionosphere', 'breastw', 'pima',
        'vowels', 'letter', 'cardio', 'seismic', 'musk', 'speech', 'abalone', 'pendigits', 'mammography',
        'mulcross', 'forest_cover',
    ]
    pre_split_names = ['thyroid', 'optdigits', 'satimage', 'shuttle', 'kdd']

    def read_npy(path):
        with open(path, 'rb') as handle:
            return np.load(handle)

    def distinct_labels(array):
        return list(set(array[:, -1].astype(int)))

    prefix = root + '/' + name + '/' + name
    if name in single_file_names:
        data = read_npy(prefix + '.npy')
        return data, None, distinct_labels(data)
    if name in pre_split_names:
        train_data = read_npy(prefix + '_train.npy')
        test_data = read_npy(prefix + '_test.npy')
        return train_data, test_data, distinct_labels(train_data)
    raise NotImplementedError


def process_dataset(train_data, test_data, classes, normal_class, b_size, normalize=False):
    """
    Turn loaded tabular arrays into one-class train/test loaders.

    With more than two classes, labels are first collapsed to binary via
    ``convert_multiclass``; if ``normal_class`` has 20 or fewer training
    samples, six ``None`` values are returned instead. Without a separate
    test set the data is split by ``process_mat_dataset``; otherwise the
    training rows are filtered down to label 0.
    :param train_data: 2-D array whose last column is the label
    :param test_data: array with the same layout, or None
    :param classes: list of distinct labels in the dataset
    :param normal_class: label treated as normal for multiclass data
    :param b_size: batch size for both loaders
    :param normalize: when True, standardize both splits with the training
        mean/std
    :return: tuple (train_loader, test_loader, mu, std, n_dim, n_sample), or
        six Nones when there are too few normal samples
    """
    if len(classes) > 2:  # multiclass dataset
        labels = train_data[:, -1].astype(int)
        if np.count_nonzero(labels == normal_class) <= 20:
            return (None,) * 6
        train_data, test_data = convert_multiclass(train_data, normal_class, test_data)

    if test_data is not None:  # dataset with test set
        normal_rows = train_data[train_data[:, -1] == 0]
        train_x = normal_rows[:, :-1]
        train_y = normal_rows[:, -1].astype(int)
        test_x = test_data[:, :-1]
        test_y = test_data[:, -1].astype(int)
    else:
        train_x, test_x, test_y = process_mat_dataset(train_data)
        train_y = np.zeros(train_x.shape[0])

    # Statistics are taken before any scaling and returned to the caller.
    mu = np.mean(train_x, axis=0)
    std = np.std(train_x, axis=0)
    n_sample, n_dim = train_x.shape

    if normalize:
        scale = std + 1e-5
        train_x = (train_x - mu) / scale
        test_x = (test_x - mu) / scale

    to_tensor = transforms.Lambda(lambda x: torch.Tensor(x))
    train_loader = DataLoader(
        dataset=CustomDataset(train_x, train_y, transform=to_tensor, target_transform=None),
        batch_size=b_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(
        dataset=CustomDataset(test_x, test_y, transform=to_tensor, target_transform=None),
        batch_size=b_size, shuffle=False, num_workers=0)
    return train_loader, test_loader, mu, std, n_dim, n_sample


def load_mat_dataset(root, name, b_size, normalize=False):
    """
    Load ``<root>/<name>.npy`` and build one-class train/test loaders.

    The array is split by ``process_mat_dataset``; every training sample is
    labelled 0.
    :param root: directory containing the ``.npy`` file
    :param name: one of 'arrhythmia', 'thyroid_train_only',
        'thyroid_train_test'
    :param b_size: batch size for both loaders
    :param normalize: when True, standardize both splits with the training
        mean/std
    :return: tuple (train_loader, test_loader, mu, std)
    :raises NotImplementedError: for any other name
    """
    if name not in ['arrhythmia', 'thyroid_train_only', 'thyroid_train_test']:
        raise NotImplementedError

    with open(root + '/' + name + '.npy', 'rb') as handle:
        raw = np.load(handle)

    train_x, test_x, test_y = process_mat_dataset(raw)
    train_y = np.zeros(train_x.shape[0])

    # Statistics are taken before any scaling and returned to the caller.
    mu = np.mean(train_x, axis=0)
    std = np.std(train_x, axis=0)

    if normalize:
        scale = std + 1e-5
        train_x = (train_x - mu) / scale
        test_x = (test_x - mu) / scale

    to_tensor = transforms.Lambda(lambda x: torch.Tensor(x))
    train_loader = DataLoader(
        dataset=CustomDataset(train_x, train_y, transform=to_tensor, target_transform=None),
        batch_size=b_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(
        dataset=CustomDataset(test_x, test_y, transform=to_tensor, target_transform=None),
        batch_size=b_size, shuffle=False, num_workers=0)
    return train_loader, test_loader, mu, std


def load_thyroid_dataset(root, b_size, normalize=False):
    """
    Load the pre-split thyroid arrays and wrap them in data loaders.

    Reads ``<root>/thyroid_train.npy`` and ``<root>/thyroid_test.npy``; the
    last column of each array is the integer label.
    :param root: directory containing the two ``.npy`` files
    :param b_size: batch size for both loaders
    :param normalize: when True, standardize both splits with the training
        mean/std
    :return: tuple (train_loader, test_loader, mu, std)
    """
    def read_npy(path):
        with open(path, 'rb') as handle:
            return np.load(handle)

    train_split = read_npy(root + '/thyroid_train.npy')
    test_split = read_npy(root + '/thyroid_test.npy')

    # Last column is the label, everything before it is a feature.
    train_x = train_split[:, :-1]
    train_y = train_split[:, -1].astype(int)
    test_x = test_split[:, :-1]
    test_y = test_split[:, -1].astype(int)

    # Statistics are taken before any scaling and returned to the caller.
    mu = np.mean(train_x, axis=0)
    std = np.std(train_x, axis=0)

    if normalize:
        scale = std + 1e-5
        train_x = (train_x - mu) / scale
        test_x = (test_x - mu) / scale

    to_tensor = transforms.Lambda(lambda x: torch.Tensor(x))
    train_loader = DataLoader(
        dataset=CustomDataset(train_x, train_y, transform=to_tensor, target_transform=None),
        batch_size=b_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(
        dataset=CustomDataset(test_x, test_y, transform=to_tensor, target_transform=None),
        batch_size=b_size, shuffle=False, num_workers=0)
    return train_loader, test_loader, mu, std


def load_kdd_cup_dataset(root, normal_class, b_size, normalize=True):
    """
    Load the KDD Cup 99 arrays and build one-class train/test loaders.

    Reads ``<root>/kdd-cup99_train.npy`` and ``<root>/kdd-cup99_test.npy``;
    the last column of each array is the integer label. Labels in 0..39 other
    than ``normal_class`` are mapped to 1 by the target transform and
    ``normal_class`` to 0. The training set is restricted to the rows
    labelled ``normal_class``.
    :param root: directory containing the two ``.npy`` files
    :param normal_class: label treated as normal; must be in 0..39
    :param b_size: batch size for both loaders
    :param normalize: when True, standardize both splits with the mean/std of
        the full training array
    :return: tuple (train_loader, test_loader, mu, std)
    :raises ValueError: if ``normal_class`` is not in 0..39
    """
    def read_npy(path):
        with open(path, 'rb') as handle:
            return np.load(handle)

    train_split = read_npy(root + '/kdd-cup99_train.npy')
    test_split = read_npy(root + '/kdd-cup99_test.npy')

    # Last column is the label, everything before it is a feature.
    train_x = train_split[:, :-1]
    train_y = train_split[:, -1].astype(int)
    test_x = test_split[:, :-1]
    test_y = test_split[:, -1].astype(int)

    # Statistics cover all training rows, not just the normal class.
    mu = np.mean(train_x, axis=0)
    std = np.std(train_x, axis=0)

    if normalize:
        scale = std + 1e-5
        train_x = (train_x - mu) / scale
        test_x = (test_x - mu) / scale

    to_tensor = transforms.Lambda(lambda x: torch.Tensor(x))

    # Every label in 0..39 except the normal one counts as an outlier.
    outlier_classes = list(range(0, 40))
    outlier_classes.remove(normal_class)
    normal_positions = get_target_label_idx(train_y, [normal_class])
    to_binary = transforms.Lambda(lambda x: int(x in outlier_classes))

    full_train_set = CustomDataset(train_x, train_y, transform=to_tensor,
                                   target_transform=to_binary)
    normal_train_set = TruncatedDataset(full_train_set, dataidxs=normal_positions,
                                        transform=to_tensor, target_transform=to_binary)
    test_set = CustomDataset(test_x, test_y, transform=to_tensor,
                             target_transform=to_binary)

    train_loader = DataLoader(dataset=normal_train_set, batch_size=b_size, shuffle=True,
                              num_workers=0)
    test_loader = DataLoader(dataset=test_set, batch_size=b_size, shuffle=False,
                             num_workers=0)
    return train_loader, test_loader, mu, std
