import os
import sys
sys.path.insert(0, './')
import numpy as np

import torch
import torch.nn as nn
from torchvision import datasets, transforms

from .Utility import SubsetRandomSampler, SubsetSampler, MislabelDataset


class IndexedCIFAR10(datasets.CIFAR10):
    """torchvision CIFAR-10 dataset whose items also report their dataset index.

    ``__getitem__`` returns ``(img, target, index)`` instead of the usual
    ``(img, target)``; presumably so callers can attribute per-sample results
    back to specific samples.
    """

    def __init__(self, root, train, download, transform):
        # All arguments are forwarded unchanged; only the item format differs.
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        image, label = super().__getitem__(index)
        return image, label, index


class IndexedSynCIFAR10(torch.utils.data.Dataset):
    """Synthetic CIFAR-10 dataset (array-backed) whose items carry their local index.

    The source arrays are cut into consecutive chunks of ``chunk_size`` samples.
    Inside chunk number ``chunk``, the first ``train_size`` samples form the
    training split and the remaining ``chunk_size - train_size`` samples form the
    test split. With the defaults, every chunk yields 10000 training and 10000
    test samples.

    Args:
        data: mapping with keys "X" (images) and "Y" (labels). Each image is
            converted with ``torch.from_numpy(...).permute(2, 0, 1)``, so X is
            presumably a uint8 array of HWC images -- confirm against the data file.
        train: True selects the training split of the chunk, False the test split.
        transform: callable applied to the CHW tensor before it is scaled by 1/255.
        chunk: index of the chunk to use.
        train_size: number of training samples per chunk (default 10000).
        chunk_size: total number of samples per chunk (default 20000).

    Raises:
        ValueError: if ``train_size`` is not within ``[0, chunk_size]``.

    Items are ``(img, target, index)``, where ``img`` is ``transform(img).float() / 255``
    and ``index`` is the position inside this split (not inside the source arrays).
    """

    def __init__(self, data, train, transform, chunk=0, train_size=10000, chunk_size=20000):
        super(IndexedSynCIFAR10, self).__init__()
        if not 0 <= train_size <= chunk_size:
            raise ValueError('train_size must be between 0 and chunk_size (got %r and %r).'
                             % (train_size, chunk_size))
        self.train = train

        begin = chunk_size * chunk
        split = begin + train_size
        end = begin + chunk_size
        lo, hi = (begin, split) if train else (split, end)
        # np.array(...) copies the slice. A bare slice would be a view that keeps the
        # whole source array (every chunk) alive for as long as this dataset exists.
        self.images = np.array(data["X"][lo:hi])
        self.targets = np.array(data["Y"][lo:hi])

        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # HWC array -> CHW tensor (shares memory with self.images until transformed).
        img = torch.from_numpy(self.images[index]).permute(2, 0, 1)
        target = self.targets[index]
        img = self.transform(img)
        # Scaling happens after the transform, so transforms see the raw pixel values.
        img = img.float() / 255.0
        return img, target, index


_CIFAR10_CLASSES = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


def _check_split_flags(class_subset_path, member_train, member_test, nonmember_train, nonmember_test):
    '''Reject argument combinations that the CIFAR-10 loaders do not support.'''
    assert class_subset_path is None, 'Class subset is not supported for CIFAR-10 dataset.'
    if member_train and member_test:
        raise ValueError('member_train and member_test cannot be both True.')
    if nonmember_train and nonmember_test:
        raise ValueError('nonmember_train and nonmember_test cannot be both True.')


def _wrap_mislabel(trainset, validset, testset, mislabel_ratio, mislabel_seed):
    '''Wrap all three datasets in MislabelDataset when mislabel_ratio > 0.

    Returns (trainset, validset, testset), wrapped or unchanged.
    '''
    if mislabel_ratio > 0:
        # Wrapping order (train, test, valid) is kept from the original code, in case
        # MislabelDataset consumes global RNG state -- NOTE(review): confirm.
        trainset = MislabelDataset(trainset, num_class=10, mislabel_ratio=mislabel_ratio, mislabel_seed=mislabel_seed)
        testset = MislabelDataset(testset, num_class=10, mislabel_ratio=mislabel_ratio, mislabel_seed=mislabel_seed)
        validset = MislabelDataset(validset, num_class=10, mislabel_ratio=mislabel_ratio, mislabel_seed=mislabel_seed)
    return trainset, validset, testset


def _select_indices(num_train_data, num_test_data, train_subset, test_subset,
                    is_split, split_seed, is_shadow, shadow_ratio, shadow_seed,
                    member_train, member_test, nonmember_train, nonmember_test):
    '''Choose which training-set and test-set indices the loaders will draw from.

    Prints how many indices were picked from each set.

    Returns:
        (train_indices, test_indices)
    '''
    if is_split:
        # The global NumPy RNG is seeded, and the call order below is fixed on purpose,
        # so the chosen subsets stay reproducible across runs and versions of this code.
        np.random.seed(split_seed)
        train_subset = np.random.choice(num_train_data, size=num_train_data // 2, replace=False)
        np.random.seed(split_seed)
        test_subset = np.random.choice(num_test_data, size=num_test_data // 2, replace=False)
        if is_shadow:
            # Use the complement of the halves chosen above, then sample a
            # shadow_ratio fraction of each complement.
            train_subset = np.array(list(set(np.arange(num_train_data)) - set(train_subset)))
            test_subset = np.array(list(set(np.arange(num_test_data)) - set(test_subset)))
            np.random.seed(shadow_seed)
            train_subset = np.random.choice(train_subset, size=int(len(train_subset) * shadow_ratio), replace=False)
            np.random.seed(shadow_seed)
            test_subset = np.random.choice(test_subset, size=int(len(test_subset) * shadow_ratio), replace=False)

    if train_subset is None:
        train_indices = list(range(num_train_data))
    else:
        train_indices = np.random.permutation(train_subset)
    # member_train / member_test keep the first / second half of the training indices.
    if member_train:
        train_indices = train_indices[:len(train_indices) // 2]
    if member_test:
        train_indices = train_indices[len(train_indices) // 2:]
    print(f'{len(train_indices)} instances are picked from the training set')

    if test_subset is None:
        test_indices = list(range(num_test_data))
    else:
        test_indices = test_subset
    # nonmember_train / nonmember_test keep the first / second half of the test indices.
    if nonmember_train:
        test_indices = test_indices[:len(test_indices) // 2]
    if nonmember_test:
        test_indices = test_indices[len(test_indices) // 2:]
    print(f'{len(test_indices)} instances are picked from the test set')
    return train_indices, test_indices


def _build_loaders(trainset, validset, testset, train_indices, test_indices,
                   batch_size, valid_ratio, shuffle, num_worker):
    '''Wrap the datasets in DataLoaders restricted to the selected indices.

    Returns:
        (train_loader, valid_loader, test_loader); valid_loader is None unless
        valid_ratio is a positive number.
    '''
    test_sampler = SubsetSampler(test_indices)
    # Same check as before: only a value that compares equal to True enables random order.
    train_sampler_cls = SubsetRandomSampler if shuffle == True else SubsetSampler

    if valid_ratio is not None and valid_ratio > 0.:
        # The first valid_ratio fraction of the training indices becomes the validation
        # split. Validation is always iterated in order.
        split_pt = int(len(train_indices) * valid_ratio)
        train_idx, valid_idx = train_indices[split_pt:], train_indices[:split_pt]
        train_sampler, valid_sampler = train_sampler_cls(train_idx), SubsetSampler(valid_idx)
    else:
        train_sampler, valid_sampler = train_sampler_cls(train_indices), None

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_worker, pin_memory=True)
    train_loader = torch.utils.data.DataLoader(trainset, sampler=train_sampler, **loader_kwargs)
    valid_loader = None
    if valid_sampler is not None:
        valid_loader = torch.utils.data.DataLoader(validset, sampler=valid_sampler, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(testset, sampler=test_sampler, shuffle=False, **loader_kwargs)
    return train_loader, valid_loader, test_loader


def cifar10(batch_size, root='./data/cifar10', valid_ratio=None, shuffle=True, augmentation=True, train_subset=None, test_subset=None,
            mislabel_ratio=0., mislabel_seed=0, class_subset_path=None, is_split=False, split_seed=0, is_shadow=False, shadow_ratio=0.8, shadow_seed=0,
            member_train=False, member_test=False, nonmember_train=False, nonmember_test=False, num_worker=0):
    '''Build CIFAR-10 data loaders whose batches are (img, target, index).

    batch_size: batch size of every loader.
    root: directory where CIFAR-10 is stored (downloaded there if missing).
    valid_ratio: fraction of the selected training indices held out for validation,
        taken from the front of the index list; None or 0 means no validation loader.
    shuffle: if True, the training loader visits its indices in random order.
    augmentation: if True, training images get a random crop (padding 4) and a
        random horizontal flip.
    train_subset / test_subset: explicit index subsets of the training / test set;
        None uses the whole set. A given train_subset is permuted with NumPy's global
        RNG. Both are overridden when is_split is True.
    mislabel_ratio, mislabel_seed: when mislabel_ratio > 0, the train, valid and test
        datasets are all wrapped in MislabelDataset with these arguments.
    class_subset_path: not supported; must be None.
    is_split, split_seed: if is_split, seed NumPy's global RNG with split_seed and pick
        half of the training set and half of the test set.
    is_shadow, shadow_ratio, shadow_seed: together with is_split, use a shadow_ratio
        fraction (sampled after seeding with shadow_seed) of the complementary halves.
    member_train / member_test: keep only the first / second half of the training indices.
    nonmember_train / nonmember_test: keep only the first / second half of the test indices.
    num_worker: number of DataLoader worker processes.

    Returns:
        (train_loader, valid_loader, test_loader, classes); valid_loader is None when
        no validation split is requested.

    Raises:
        ValueError: if both member_* flags or both nonmember_* flags are True.
    '''
    _check_split_flags(class_subset_path, member_train, member_test, nonmember_train, nonmember_test)

    augment_ops = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()] if augmentation == True else []
    transform_train = transforms.Compose(augment_ops + [transforms.ToTensor()])
    transform_valid = transforms.Compose([transforms.ToTensor()])
    transform_test = transforms.Compose([transforms.ToTensor()])

    trainset = IndexedCIFAR10(root=root, train=True, download=True, transform=transform_train)
    validset = IndexedCIFAR10(root=root, train=True, download=True, transform=transform_valid)
    testset = IndexedCIFAR10(root=root, train=False, download=True, transform=transform_test)
    trainset, validset, testset = _wrap_mislabel(trainset, validset, testset, mislabel_ratio, mislabel_seed)

    train_indices, test_indices = _select_indices(
        len(trainset), len(testset), train_subset, test_subset,
        is_split, split_seed, is_shadow, shadow_ratio, shadow_seed,
        member_train, member_test, nonmember_train, nonmember_test)
    train_loader, valid_loader, test_loader = _build_loaders(
        trainset, validset, testset, train_indices, test_indices,
        batch_size, valid_ratio, shuffle, num_worker)
    return train_loader, valid_loader, test_loader, _CIFAR10_CLASSES


def syn_cifar10(batch_size, root="/data/cifar5m/part0.npz", valid_ratio=None, shuffle=True, augmentation=True, train_subset=None, test_subset=None,
            mislabel_ratio=0., mislabel_seed=0, class_subset_path=None, is_split=False, split_seed=0, is_shadow=False, shadow_ratio=0.8, shadow_seed=0,
            member_train=False, member_test=False, nonmember_train=False, nonmember_test=False, num_worker=0, chunk=0):
    '''Build synthetic CIFAR-10 data loaders whose batches are (img, target, index).

    Takes the same arguments as cifar10, with these differences:

    root: path to a .npz archive holding the arrays "X" (images) and "Y" (labels).
    chunk: forwarded to IndexedSynCIFAR10 to select which slice of X/Y is used (with its
        default sizes: 10000 training + 10000 test samples per 20000-sample chunk).
    Images are not passed through ToTensor; IndexedSynCIFAR10 converts and scales them.

    Returns:
        (train_loader, valid_loader, test_loader, classes); valid_loader is None when
        no validation split is requested.

    Raises:
        ValueError: if both member_* flags or both nonmember_* flags are True.
    '''
    _check_split_flags(class_subset_path, member_train, member_test, nonmember_train, nonmember_test)

    augment_ops = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()] if augmentation == True else []
    transform_train = transforms.Compose(augment_ops)
    transform_valid = transforms.Compose([])
    transform_test = transforms.Compose([])

    # Read each array from the archive exactly once, then close it. NpzFile keeps an open
    # file handle and decompresses a member on every data[key] access, so handing it to
    # three datasets used to read X and Y three times and leak the handle.
    with np.load(root) as npz:
        data = {"X": npz["X"], "Y": npz["Y"]}
    trainset = IndexedSynCIFAR10(data, train=True, transform=transform_train, chunk=chunk)
    validset = IndexedSynCIFAR10(data, train=True, transform=transform_valid, chunk=chunk)
    testset = IndexedSynCIFAR10(data, train=False, transform=transform_test, chunk=chunk)
    trainset, validset, testset = _wrap_mislabel(trainset, validset, testset, mislabel_ratio, mislabel_seed)

    train_indices, test_indices = _select_indices(
        len(trainset), len(testset), train_subset, test_subset,
        is_split, split_seed, is_shadow, shadow_ratio, shadow_seed,
        member_train, member_test, nonmember_train, nonmember_test)
    train_loader, valid_loader, test_loader = _build_loaders(
        trainset, validset, testset, train_indices, test_indices,
        batch_size, valid_ratio, shuffle, num_worker)
    return train_loader, valid_loader, test_loader, _CIFAR10_CLASSES


if __name__ == '__main__':
    # Smoke test: build shadow-split CIFAR-10 loaders and report how many samples each yields.
    batch_size = 100
    train_loader, valid_loader, test_loader, classes = cifar10(batch_size, root="/data", is_split=True, split_seed=42,
                                                               is_shadow=True, shadow_seed=0)
    # len(loader) counts batches, so batch_size * len(loader) over-counts whenever the last
    # batch is partial; the sampler's length is the exact number of samples.
    print(len(train_loader.sampler), len(test_loader.sampler))