# import tensorflow as tf
import numpy as np
from torchvision.transforms.transforms import CenterCrop
import torch.nn.functional as F
# import sklearn
import torch
# import tensorflow_datasets as tfds
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import TensorDataset, sampler, DataLoader,Dataset
import urllib
import tarfile
import random
import os
# torch.multiprocessing.set_sharing_strategy('file_system')
# NOTE(review): BUFFER_SIZE and SIZE are not referenced anywhere in the visible
# part of this module; presumably leftovers of the commented-out TensorFlow
# pipeline -- confirm before removing.
BUFFER_SIZE = 10000
SIZE = 32

# getImagesDS = lambda X, n: np.concatenate([x[0].numpy()[None,] for x in X.take(n)])
# Per-channel (3 values each) mean / standard deviation tuples that the loaders
# in this module pass to transforms.Normalize.
CIFAR10_TRAIN_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_TRAIN_STD = (0.247, 0.243, 0.261)
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
FACESCRUB_TRAIN_MEAN = (0.5708, 0.5905, 0.4272)
FACESCRUB_TRAIN_STD = (0.2058, 0.2275, 0.2098)


class MySubset(Dataset):
    """Subset of a dataset whose items also carry their position in the subset.

    Thin wrapper around ``torch.utils.data.Subset``: indexing yields
    ``(data, target, index)`` instead of ``(data, target)``.  ``index`` is the
    position that was requested from this subset, not the position in the
    wrapped dataset.
    """

    def __init__(self, cifar10_training, index):
        # ``index`` lists the positions of ``cifar10_training`` to keep.
        self.cifar10_training_subset = torch.utils.data.Subset(cifar10_training, index)

    def __len__(self):
        """Number of samples selected by the subset."""
        return len(self.cifar10_training_subset)

    def __getitem__(self, index):
        """Return ``(data, target, index)`` for the ``index``-th selected sample."""
        sample, label = self.cifar10_training_subset[index]
        return sample, label, index


def getImagesDS(X, n):
    """Stack the inputs of the first ``n`` samples of ``X`` into one array.

    Args:
        X: indexable dataset whose items are sequences with a tensor (anything
            exposing ``.numpy()``) in position 0.
        n: number of leading samples to collect.
    Returns:
        NumPy array with a new leading axis of length ``n``.
    """
    # Give every sample a leading axis of size 1 so they can be joined along it.
    stacked = [X[position][0].numpy()[np.newaxis] for position in range(n)]
    return np.concatenate(stacked)

def remove_class_loader(some_dataset, label_class, batch_size=16, num_workers=2):
    """Split a dataset into two loaders: without and with only one class.

    Args:
        some_dataset: dataset exposing a ``targets`` sequence of labels.
        label_class: label to separate out.
        batch_size: batch size of both loaders.
        num_workers: worker processes of both loaders.
    Returns:
        ``(new_data_loader, excluded_data_loader)``: the first draws (in random
        order) from every sample whose label differs from ``label_class``, the
        second from the samples carrying ``label_class``.
    """
    kept_positions = []
    removed_positions = []
    for position, value in enumerate(some_dataset.targets):
        bucket = kept_positions if value != label_class else removed_positions
        bucket.append(position)

    make_sampler = torch.utils.data.sampler.SubsetRandomSampler
    # Ordering is decided by the sampler, hence shuffle stays False.
    common = dict(shuffle=False, num_workers=num_workers, batch_size=batch_size)

    new_data_loader = DataLoader(
        some_dataset, sampler=make_sampler(kept_positions), **common)
    excluded_data_loader = DataLoader(
        some_dataset, sampler=make_sampler(removed_positions), **common)

    return new_data_loader, excluded_data_loader


def load_mnist():
    """Load MNIST as 3-channel 32x32 tensors scaled to [-1, 1].

    The training split is downloaded to ``./data`` if needed.  Each image gets
    a channel axis, is replicated to 3 channels, resized to 32x32 with
    ``F.interpolate`` (default mode) and mapped from [0, 255] to [-1, 1].

    Returns:
        ``(xpriv, xpub)``: ``TensorDataset`` objects for the training and test
        split, each holding ``(images, int64 labels)``.
    """
    def as_tensor_dataset(split):
        pixels = np.array(split.data)
        classes = np.array(split.targets)

        # Insert a channel axis and repeat it three times.
        pixels = np.tile(pixels[:, None, :, :], (1, 3, 1, 1))

        images = F.interpolate(torch.Tensor(pixels), (32, 32))
        images = torch.clip(images / (255/2) - 1, -1., 1.)
        targets = torch.Tensor(classes).type(torch.LongTensor)
        # Need a different way to denormalize
        return TensorDataset(images, targets)

    train_split = datasets.MNIST(root='./data', train=True, download=True)
    test_split = datasets.MNIST(root='./data', train=False)

    xpriv = as_tensor_dataset(train_split)
    xpub = as_tensor_dataset(test_split)
    return xpriv, xpub

def load_mnist_membership():
    """Load MNIST plus member / non-member splits for membership inference.

    Images are preprocessed as 3-channel 32x32 tensors in [-1, 1] (channel
    axis added and tiled, ``F.interpolate`` to 32x32, ``x / 127.5 - 1``).

    Returns:
        Six ``TensorDataset`` objects, in this order:
        ``xpriv``        full training split,
        ``xpub``         full test split,
        ``xmem``         first half of the training split,
        ``xnomem``       first half of the test split,
        ``xmem_test``    second half of the training split,
        ``xnomem_test``  second half of the test split.
    """
    def preprocess(split):
        pixels = np.array(split.data)
        classes = np.array(split.targets)
        pixels = np.tile(pixels[:, None, :, :], (1, 3, 1, 1))
        images = F.interpolate(torch.Tensor(pixels), (32, 32))
        images = torch.clip(images / (255/2) - 1, -1., 1.)
        return images, torch.Tensor(classes).type(torch.LongTensor)

    train_split = datasets.MNIST(root='./data', train=True, download=True)
    test_split = datasets.MNIST(root='./data', train=False)

    x_train, y_train = preprocess(train_split)
    x_test, y_test = preprocess(test_split)

    # Each split is cut in the middle: the leading half is the data known to
    # the attacker, the trailing half is held out for evaluating the attack.
    train_mid = x_train.size(0) // 2
    test_mid = x_test.size(0) // 2

    xpriv = TensorDataset(x_train, y_train)
    xpub = TensorDataset(x_test, y_test)
    xmem = TensorDataset(x_train[:train_mid], y_train[:train_mid])
    xnomem = TensorDataset(x_test[:test_mid], y_test[:test_mid])
    xmem_test = TensorDataset(x_train[train_mid:], y_train[train_mid:])
    xnomem_test = TensorDataset(x_test[test_mid:], y_test[test_mid:])
    return xpriv, xpub, xmem, xnomem, xmem_test, xnomem_test


def get_mnist_bothloader(batch_size=16, num_workers=2, shuffle=True, num_agent = 1, collude_use_public = False):
    """Build the MNIST training and test dataloaders.

    Args:
        batch_size: dataloader batch size (training and test).
        num_workers: dataloader worker processes.
        shuffle: whether the training loader(s) shuffle; the test loader never does.
        num_agent: number of clients. With 1 a single training loader is
            returned; with more, the training set is cut into ``num_agent``
            contiguous, equally sized shards (samples left over by the integer
            division are dropped) and a list of loaders is returned.
        collude_use_public: accepted for interface compatibility; not used here.
    Returns:
        ``(mnist_training_loader, mnist_testing_loader)`` where the first item
        is a ``DataLoader`` (``num_agent == 1``) or a list of them.
    Raises:
        ValueError: if ``num_agent`` is smaller than 1.
    """
    # Previously a value below 1 fell through both branches and surfaced as an
    # UnboundLocalError at the return statement.
    if num_agent < 1:
        raise ValueError('num_agent must be >= 1, got {}'.format(num_agent))

    mnist_training, mnist_testing = load_mnist()

    if num_agent == 1:
        mnist_training_loader = torch.utils.data.DataLoader(
            mnist_training, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    else:
        shard_size = len(mnist_training) // num_agent
        mnist_training_loader = []
        for i in range(num_agent):
            shard = torch.utils.data.Subset(
                mnist_training, list(range(i * shard_size, (i + 1) * shard_size)))
            mnist_training_loader.append(DataLoader(
                shard, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size))

    mnist_testing_loader = torch.utils.data.DataLoader(
        mnist_testing, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return mnist_training_loader, mnist_testing_loader


def get_facescrub_bothloader(batch_size=16, num_workers=2, shuffle=True, num_agent = 1, collude_use_public = False):
    """Build the FaceScrub training and test dataloaders.

    Images are read with ``ImageFolder`` from ``facescrub-dataset/32x32/train``
    and ``facescrub-dataset/32x32/validate``.  Training images are augmented
    (random horizontal flip, random rotation up to 15 degrees); both splits
    are normalized with the FaceScrub mean / std constants of this module.

    Args:
        batch_size: dataloader batch size (training and test).
        num_workers: dataloader worker processes.
        shuffle: whether the training loader(s) shuffle; the test loader never does.
        num_agent: number of clients. With 1 a single training loader is
            returned; with more, the training set is cut into ``num_agent``
            contiguous, equally sized shards (samples left over by the integer
            division are dropped) and a list of loaders is returned.
        collude_use_public: accepted for interface compatibility; not used here.
    Returns:
        ``(facescrub_training_loader, facescrub_testing_loader)`` where the
        first item is a ``DataLoader`` (``num_agent == 1``) or a list of them.
    Raises:
        ValueError: if ``num_agent`` is smaller than 1.
    """
    # Previously a value below 1 fell through both branches and surfaced as an
    # UnboundLocalError at the return statement.
    if num_agent < 1:
        raise ValueError('num_agent must be >= 1, got {}'.format(num_agent))

    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(FACESCRUB_TRAIN_MEAN, FACESCRUB_TRAIN_STD)
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(FACESCRUB_TRAIN_MEAN, FACESCRUB_TRAIN_STD)
    ])
    facescrub_training = datasets.ImageFolder('facescrub-dataset/32x32/train', transform=transform_train)
    facescrub_testing = datasets.ImageFolder('facescrub-dataset/32x32/validate', transform=transform_test)

    if num_agent == 1:
        facescrub_training_loader = torch.utils.data.DataLoader(
            facescrub_training, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    else:
        shard_size = len(facescrub_training) // num_agent
        facescrub_training_loader = []
        for i in range(num_agent):
            shard = torch.utils.data.Subset(
                facescrub_training, list(range(i * shard_size, (i + 1) * shard_size)))
            facescrub_training_loader.append(DataLoader(
                shard, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size))

    facescrub_testing_loader = torch.utils.data.DataLoader(
        facescrub_testing, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return facescrub_training_loader, facescrub_testing_loader


def get_purchase_trainloader(batch_size=128):
    """Build shadow / target dataloaders for the Purchase-100 dataset.

    The dataset is downloaded and unpacked into ``./datasets/purchase`` when the
    file ``dataset_purchase`` is not already there.  Rows are reordered with the
    permutation stored in ``./dataset_shuffle/random_r_purchase100.npy``; the
    first 10% form the training pool and everything after the first 40% forms
    the test pool.  With NumPy's global seed set to 100, the training pool is
    split in half (shadow / target), and as many rows as the training pool has
    are taken from the start of the test pool and split the same way.

    Args:
        batch_size: batch size of the four loaders.  The original code read an
            undefined name ``batch_size``; it is now an explicit parameter.
            NOTE(review): the default of 128 is a choice made during this fix --
            confirm the value callers expect.
    Returns:
        ``(shadow_train_loader, shadow_test_loader, target_train_loader,
        target_test_loader)``.

    NOTE(review): ``tensor_data_create`` is not defined in the visible part of
    this module; it is presumably provided elsewhere in the file -- confirm.
    Side effect: reseeds NumPy's global random generator with 100.
    """
    # ``import urllib`` alone does not guarantee the ``request`` submodule is loaded.
    import urllib.request

    DATASET_PATH = './datasets/purchase'
    DATASET_NAME = 'dataset_purchase'

    os.makedirs(DATASET_PATH, exist_ok=True)

    DATASET_FILE = os.path.join(DATASET_PATH, DATASET_NAME)

    if not os.path.isfile(DATASET_FILE):
        archive_path = os.path.join(DATASET_PATH, 'tmp.tgz')
        print("Dowloading the dataset...")
        urllib.request.urlretrieve("https://www.comp.nus.edu.sg/~reza/files/dataset_purchase.tgz", archive_path)
        print('Dataset Dowloaded')

        # NOTE(review): extractall trusts the member paths of the downloaded
        # archive; consider an extraction filter where the Python version allows.
        with tarfile.open(archive_path) as tar:
            tar.extractall(path=DATASET_PATH)

    data_set = np.genfromtxt(DATASET_FILE, delimiter=',')

    # Column 0 holds the 1-based class label, the remaining columns the features.
    X = data_set[:, 1:].astype(np.float64)
    Y = (data_set[:, 0]).astype(np.int32) - 1

    len_train = len(X)
    r = np.load('./dataset_shuffle/random_r_purchase100.npy')
    X = X[r]
    Y = Y[r]

    # The former ``xpriv`` / ``xpub`` TensorDatasets were removed: they were built
    # from NumPy arrays (which TensorDataset rejects) and were never used.
    train_classifier_ratio, train_attack_ratio = 0.1, 0.3
    train_end = int(train_classifier_ratio * len_train)
    test_start = int((train_classifier_ratio + train_attack_ratio) * len_train)

    train_data, train_label = X[:train_end], Y[:train_end]
    test_data, test_label = X[test_start:], Y[test_start:]

    np.random.seed(100)
    train_len = train_data.shape[0]
    r = np.arange(train_len)
    np.random.shuffle(r)
    shadow_indices = r[:train_len//2]
    target_indices = r[train_len//2:]

    shadow_train_data, shadow_train_label = train_data[shadow_indices], train_label[shadow_indices]
    target_train_data, target_train_label = train_data[target_indices], train_label[target_indices]

    # Use as many test rows as there are training rows.
    test_len = 1*train_len
    r = np.arange(test_len)
    np.random.shuffle(r)
    shadow_indices = r[:test_len//2]
    target_indices = r[test_len//2:]

    shadow_test_data, shadow_test_label = test_data[shadow_indices], test_label[shadow_indices]
    target_test_data, target_test_label = test_data[target_indices], test_label[target_indices]

    shadow_train = tensor_data_create(shadow_train_data, shadow_train_label)
    shadow_train_loader = DataLoader(shadow_train, batch_size=batch_size, shuffle=True, num_workers=1)

    shadow_test = tensor_data_create(shadow_test_data, shadow_test_label)
    shadow_test_loader = DataLoader(shadow_test, batch_size=batch_size, shuffle=True, num_workers=1)

    target_train = tensor_data_create(target_train_data, target_train_label)
    target_train_loader = DataLoader(target_train, batch_size=batch_size, shuffle=True, num_workers=1)

    target_test = tensor_data_create(target_test_data, target_test_label)
    target_test_loader = DataLoader(target_test, batch_size=batch_size, shuffle=True, num_workers=1)
    print('Data loading finished')
    return shadow_train_loader, shadow_test_loader, target_train_loader, target_test_loader

def get_cifar10_trainloader(batch_size=128, num_workers=4, shuffle=True, num_agent = 1):
    """Build IID CIFAR-10 training dataloader(s).

    Training images are augmented (random crop with padding 4, random
    horizontal flip, random rotation up to 15 degrees) and normalized with the
    CIFAR-10 mean / std constants of this module.

    Args:
        batch_size: dataloader batch size.
        num_workers: dataloader worker processes.
        shuffle: whether the loader(s) shuffle.
        num_agent: number of clients. With 1 a single loader over the whole
            training set is returned (batches are ``(data, target)``). With more,
            the set is cut into ``num_agent`` contiguous, equally sized shards
            wrapped in ``MySubset`` (batches are ``(data, target, index)``);
            samples left over by the integer division are dropped.
    Returns:
        ``(cifar10_training_loader, cnts_dict)``: a ``DataLoader`` or a list of
        them, and a dict mapping each client id to its nominal per-class sample
        count ``50000/10/num_agent`` (a float, repeated for the 10 classes).
    Raises:
        ValueError: if ``num_agent`` is smaller than 1.
    """
    # Previously a value below 1 fell through both branches and surfaced as an
    # UnboundLocalError at the return statement.
    if num_agent < 1:
        raise ValueError('num_agent must be >= 1, got {}'.format(num_agent))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_TRAIN_MEAN, CIFAR10_TRAIN_STD)
    ])
    cifar10_training = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)

    if num_agent == 1:
        cifar10_training_loader = DataLoader(
            cifar10_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
    else:
        shard_size = len(cifar10_training) // num_agent
        cifar10_training_loader = []
        for i in range(num_agent):
            shard = MySubset(cifar10_training, list(range(i * shard_size, (i + 1) * shard_size)))
            cifar10_training_loader.append(DataLoader(
                shard, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size))

    cnts_dict = {i: [50000/10/num_agent for _ in range(10)] for i in range(num_agent)}

    return cifar10_training_loader, cnts_dict


def get_cifar10_trainloader_noniid(logger,batch_size=128, num_workers=4, shuffle=True, num_agent = 1,class_per_client=2):
    """Build label-skewed (non-IID) CIFAR-10 training dataloaders.

    Every client receives ``class_per_client`` class slots.  The samples of each
    class are cut into ``group = class_per_client * num_agent / 10`` chunks and
    each slot consumes one chunk of its class.  Slots are assigned by shuffling
    a list that contains every class ``group`` times (``random.shuffle``), so a
    client may draw the same class more than once.

    Args:
        logger: object with an ``info`` method; receives the classes of each client.
        batch_size: dataloader batch size.
        num_workers: dataloader worker processes.
        shuffle: whether the loader(s) shuffle.
        num_agent: number of clients. With 1 a single loader over the whole
            training set is returned and ``class_per_client`` is ignored.
        class_per_client: number of class slots per client.
    Returns:
        ``(cifar10_training_loader, cnts_dict)``: a ``DataLoader``
        (``num_agent == 1``) or a list of loaders over ``MySubset`` shards, and a
        dict mapping each client id to its list of per-class sample counts.
    Raises:
        ValueError: if ``num_agent`` is smaller than 1.
        RuntimeError: if ``class_per_client * num_agent`` is not a positive
            multiple of 10.
    """
    num_classes = 10
    if num_agent < 1:
        raise ValueError('num_agent must be >= 1, got {}'.format(num_agent))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_TRAIN_MEAN, CIFAR10_TRAIN_STD)
    ])
    cifar10_training = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    # Reading ``targets`` avoids decoding and augmenting every image just to
    # learn its label.
    targets = [int(label) for label in cifar10_training.targets]

    if num_agent == 1:
        cifar10_training_loader = DataLoader(
            cifar10_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
        # Previously cnts_dict was left undefined on this path and the return
        # statement raised UnboundLocalError.
        cnts_dict = {0: np.bincount(np.asarray(targets, dtype=np.int64), minlength=num_classes).tolist()}
        return cifar10_training_loader, cnts_dict

    total_slots = class_per_client * num_agent
    # The former check compared int(x / 10) with x // 10, which are always
    # equal, so uneven configurations were never rejected.
    if total_slots % num_classes != 0 or total_slots < num_classes:
        raise RuntimeError(
            'The 10 classes cannot be divided evenly, please reset the number of clients or the number of classes per client'
        )
    group = total_slots // num_classes

    # Collect the sample positions of every class.
    positions = {c: [] for c in range(num_classes)}
    for i, label in enumerate(targets):
        positions[label].append(i)

    # Cut every class into ``group`` chunks; the last chunk takes the remainder.
    index_dic = {}
    for key, members in positions.items():
        num_per_group = len(members) // group
        chunks = [members[g * num_per_group:(g + 1) * num_per_group] for g in range(group - 1)]
        chunks.append(members[(group - 1) * num_per_group:])
        index_dic[key] = chunks

    selection_list = []
    for _ in range(group):
        selection_list += list(range(num_classes))
    random.shuffle(selection_list)

    client_class = []
    for client_id in range(num_agent):
        client_class.append(selection_list[(client_id*class_per_client):((client_id+1)*class_per_client)])
        logger.info('[client-{}] labels:{}'.format(client_id,client_class[-1]))

    cnts_dict = {i: [0 for _ in range(num_classes)] for i in range(num_agent)}
    cifar10_training_loader = []
    for i, selected_class in enumerate(client_class):
        selected_index = []
        for class_item in selected_class:
            chunk = index_dic[class_item].pop()
            # Count what the client actually receives; the former constant
            # 50000/group/num_agent matched the chunk size only for 10 clients.
            cnts_dict[i][class_item] += len(chunk)
            selected_index = selected_index + chunk
        cifar10_training_subset = MySubset(cifar10_training, selected_index)
        cifar10_training_loader.append(DataLoader(
            cifar10_training_subset, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size))
    return cifar10_training_loader, cnts_dict

def get_cifar10_trainloader_noniid_dic(logger,batch_size=128, num_workers=4, shuffle=True, num_agent = 1,beta=0.1,ordered=False):
    """Build Dirichlet-partitioned (non-IID) CIFAR-10 training dataloaders.

    Every client is first guaranteed 2 samples of each class.  The remaining
    samples of each class are distributed over the clients with proportions
    drawn from ``Dirichlet(beta)``; a client that already holds at least
    ``N / num_agent`` samples gets proportion 0 for later classes.  The draw is
    repeated until every client holds at least 10 of the distributed samples.
    All randomness comes from NumPy's global generator.

    Args:
        logger: object with an ``info`` method; receives per-client class counts.
        batch_size: dataloader batch size.
        num_workers: dataloader worker processes.
        shuffle: whether the loaders shuffle.
        num_agent: number of clients.
        beta: Dirichlet concentration parameter.
        ordered: if true, the proportions of each class are sorted in descending
            order and rotated by ``(10 - k) * 5`` positions for class ``k``, and
            the resulting client shards are permuted randomly afterwards.
            NOTE(review): the rotation offset is hard-coded; presumably tuned
            for a specific number of clients -- confirm.
    Returns:
        ``(cifar10_training_loader, cnts_dict)``: a list with one loader (over a
        ``MySubset`` shard) per client and a dict mapping each client id to a
        NumPy array of its per-class sample counts.
    Raises:
        ValueError: if ``num_agent`` is smaller than 1.
    """
    if num_agent < 1:
        raise ValueError('num_agent must be >= 1, got {}'.format(num_agent))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_TRAIN_MEAN, CIFAR10_TRAIN_STD)
    ])
    dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    labels = np.array(dataset.targets)
    N = labels.shape[0]
    _lst_sample = 2  # samples of every class guaranteed to every client
    K = 10
    min_size = 0

    # ``np.int`` was removed in NumPy 1.24; use an explicit integer dtype.
    least_idx = np.zeros((num_agent, K, _lst_sample), dtype=np.int64)
    for i in range(K):
        idx_i = np.random.choice(np.where(labels==i)[0], num_agent*_lst_sample, replace=False)
        least_idx[:, i, :] = idx_i.reshape((num_agent, _lst_sample))
    least_idx = np.reshape(least_idx, (num_agent, -1))

    least_idx_set = set(np.reshape(least_idx, (-1)))
    remaining = list(set(range(N)) - least_idx_set)
    # The permutation only matters for the random stream; membership is what is used.
    local_idx = np.random.choice(remaining, len(remaining), replace=False)
    local_idx_set = set(local_idx.tolist())

    # Samples of each class that are still free to distribute (set lookup
    # instead of scanning the whole index array for every sample).
    class_pool = [
        [idx for idx in np.where(labels == k)[0] if idx in local_idx_set]
        for k in range(K)
    ]

    while min_size < 10:
        idx_batch = [[] for _ in range(num_agent)]
        # for each class in the dataset
        for k in range(K):
            idx_k = list(class_pool[k])
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(beta, num_agent))
            ## Balance
            proportions = np.array([p*(len(idx_j)<N/num_agent) for p,idx_j in zip(proportions,idx_batch)])
            proportions = proportions/proportions.sum()

            if ordered:
                proportions = np.sort(proportions)[-1::-1]
                proportions = np.concatenate((proportions[(10-k)*5:],proportions[:(10-k)*5]),axis=0)
            proportions = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist() for idx_j,idx in zip(idx_batch,np.split(idx_k,proportions))]
        min_size = min([len(idx_j) for idx_j in idx_batch])

    if ordered:
        shuffle_index = list(range(len(idx_batch)))
        np.random.shuffle(shuffle_index)
        # Reorder with plain list indexing: the shards have different lengths,
        # and recent NumPy refuses to build an array from ragged lists.
        idx_batch = [idx_batch[s] for s in shuffle_index]

    dict_users = {}
    for j in range(num_agent):
        np.random.shuffle(idx_batch[j])
        dict_users[j] = np.concatenate(
            (np.asarray(idx_batch[j], dtype=np.int64), least_idx[j]), axis=0)

    cnts_dict = {}
    cifar10_training_loader = []
    for i in range(num_agent):
        cifar10_training_subset = MySubset(dataset, dict_users[i])
        subset_training_loader = DataLoader(
            cifar10_training_subset, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
        cifar10_training_loader.append(subset_training_loader)

        labels_i = labels[dict_users[i]]
        cnts = np.array([np.count_nonzero(labels_i == j ) for j in range(K)] )
        cnts_dict[i] = cnts
        logger.info('[client-{}] labels:{} sum:{}'.format(i," ".join([str(cnt) for cnt in cnts]),sum(cnts)))
    return cifar10_training_loader,cnts_dict


def get_cifar10_testloader(batch_size=128, num_workers=4, shuffle=True):
    """Build the CIFAR-10 test dataloader.

    Test images are only converted to tensors and normalized with the CIFAR-10
    mean / std constants of this module (no augmentation).

    Args:
        batch_size: accepted for interface compatibility but NOT used: the
            loader always yields the entire test set as a single batch.
        num_workers: dataloader worker processes.
        shuffle: whether to shuffle.
    Returns:
        cifar10_test_loader: torch dataloader over the CIFAR-10 test split.
    """
    # The unused training / "exlabel" transform pipelines that used to be built
    # here were dropped; only the test transform is ever applied.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_TRAIN_MEAN, CIFAR10_TRAIN_STD)
    ])
    cifar10_test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    # One batch holding the whole test split.
    cifar10_test_loader = DataLoader(
        cifar10_test, shuffle=shuffle, num_workers=num_workers, batch_size=len(cifar10_test))

    return cifar10_test_loader


def get_cifar100_trainloader(batch_size=128, num_workers=4, shuffle=True, num_agent = 1, collude_use_public = False):
    """Build IID CIFAR-100 training dataloader(s).

    Training images are augmented (random crop with padding 4, random
    horizontal flip, random rotation up to 15 degrees) and normalized with the
    CIFAR-100 mean / std constants of this module.

    Args:
        batch_size: dataloader batch size.
        num_workers: dataloader worker processes.
        shuffle: whether the loader(s) shuffle.
        num_agent: number of clients. With 1 a single loader over the whole
            training set is returned (batches are ``(data, target)``). With more,
            the set is cut into ``num_agent`` contiguous, equally sized shards
            wrapped in ``MySubset`` (batches are ``(data, target, index)``);
            samples left over by the integer division are dropped.
        collude_use_public: when true and ``num_agent > 1`` no shard loaders are
            created and an empty list is returned.
            NOTE(review): this looks like an unfinished code path -- confirm
            that callers expect an empty list.
    Returns:
        ``(cifar100_training_loader, cnts_dict)``: a ``DataLoader`` or a list of
        them, and a dict mapping each client id to its nominal per-class sample
        count ``50000/100/num_agent`` (a float, repeated for the 100 classes).
    Raises:
        ValueError: if ``num_agent`` is smaller than 1.
    """
    # Previously a value below 1 fell through both branches and surfaced as an
    # UnboundLocalError at the return statement.
    if num_agent < 1:
        raise ValueError('num_agent must be >= 1, got {}'.format(num_agent))

    transform_train = transforms.Compose([
        #transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

    cifar100_training = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)

    if num_agent == 1:
        cifar100_training_loader = DataLoader(
            cifar100_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
    else:
        cifar100_training_loader = []
        if not collude_use_public:
            shard_size = len(cifar100_training) // num_agent
            for i in range(num_agent):
                shard = MySubset(cifar100_training, list(range(i * shard_size, (i + 1) * shard_size)))
                cifar100_training_loader.append(DataLoader(
                    shard, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size))

    cnts_dict = {i: [50000/100/num_agent for _ in range(100)] for i in range(num_agent)}

    return cifar100_training_loader, cnts_dict


def get_cifar100_trainloader_noniid(logger,batch_size=128, num_workers=4, shuffle=True, num_agent = 1,class_per_client=2):
    """Build label-skewed (non-IID) CIFAR-100 training dataloaders.

    Every client receives ``class_per_client`` class slots.  The samples of each
    class are cut into ``group = class_per_client * num_agent / 100`` chunks and
    each slot consumes one chunk of its class.  Slots are assigned by shuffling
    a list that contains every class ``group`` times (``random.shuffle``), so a
    client may draw the same class more than once.

    Args:
        logger: object with an ``info`` method; receives the classes of each client.
        batch_size: dataloader batch size.
        num_workers: dataloader worker processes.
        shuffle: whether the loader(s) shuffle.
        num_agent: number of clients. With 1 a single loader over the whole
            training set is returned and ``class_per_client`` is ignored.
        class_per_client: number of class slots per client.
    Returns:
        ``(cifar100_training_loader, cnts_dict)``: a ``DataLoader``
        (``num_agent == 1``) or a list of loaders over ``MySubset`` shards, and a
        dict mapping each client id to its list of per-class sample counts.
    Raises:
        ValueError: if ``num_agent`` is smaller than 1.
        RuntimeError: if ``class_per_client * num_agent`` is not a positive
            multiple of 100.
    """
    num_classes = 100
    if num_agent < 1:
        raise ValueError('num_agent must be >= 1, got {}'.format(num_agent))

    transform_train = transforms.Compose([
        #transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

    cifar100_training = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
    # Reading ``targets`` avoids decoding and augmenting every image just to
    # learn its label.
    targets = [int(label) for label in cifar100_training.targets]

    if num_agent == 1:
        cifar100_training_loader = DataLoader(
            cifar100_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
        # Previously cnts_dict was left undefined on this path and the return
        # statement raised UnboundLocalError.
        cnts_dict = {0: np.bincount(np.asarray(targets, dtype=np.int64), minlength=num_classes).tolist()}
        return cifar100_training_loader, cnts_dict

    total_slots = class_per_client * num_agent
    # The former check compared int(x / 100) with x // 100, which are always
    # equal, so uneven configurations were never rejected.
    if total_slots % num_classes != 0 or total_slots < num_classes:
        raise RuntimeError(
            'The 100 classes cannot be divided evenly, please reset the number of clients or the number of classes per client'
        )
    group = total_slots // num_classes

    # Collect the sample positions of every class.
    positions = {c: [] for c in range(num_classes)}
    for i, label in enumerate(targets):
        positions[label].append(i)

    # Cut every class into ``group`` chunks; the last chunk takes the remainder.
    index_dic = {}
    for key, members in positions.items():
        num_per_group = len(members) // group
        chunks = [members[g * num_per_group:(g + 1) * num_per_group] for g in range(group - 1)]
        chunks.append(members[(group - 1) * num_per_group:])
        index_dic[key] = chunks

    selection_list = []
    for _ in range(group):
        selection_list += list(range(num_classes))
    random.shuffle(selection_list)

    client_class = []
    for client_id in range(num_agent):
        client_class.append(selection_list[(client_id*class_per_client):((client_id+1)*class_per_client)])
        logger.info('[client-{}] labels:{}'.format(client_id,client_class[-1]))

    cnts_dict = {i: [0 for _ in range(num_classes)] for i in range(num_agent)}
    cifar100_training_loader = []
    for i, selected_class in enumerate(client_class):
        selected_index = []
        for class_item in selected_class:
            chunk = index_dic[class_item].pop()
            # Count what the client actually receives; the former constant
            # 50000/group/num_agent matched the chunk size only for 100 clients.
            cnts_dict[i][class_item] += len(chunk)
            selected_index = selected_index + chunk
        cifar100_training_subset = MySubset(cifar100_training, selected_index)
        cifar100_training_loader.append(DataLoader(
            cifar100_training_subset, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size))
    return cifar100_training_loader, cnts_dict


def get_cifar100_trainloader_noniid_dic(logger,batch_size=128, num_workers=4, shuffle=True, num_agent = 1,beta=0.1):
    """Build Dirichlet-partitioned (non-IID) CIFAR-100 training dataloaders.

    Every client is first guaranteed 2 samples of each class.  The remaining
    samples of each class are distributed over the clients with proportions
    drawn from ``Dirichlet(beta)``; a client that already holds at least
    ``N / num_agent`` samples gets proportion 0 for later classes.  The draw is
    repeated until every client holds at least 10 of the distributed samples.
    All randomness comes from NumPy's global generator.

    Args:
        logger: object with an ``info`` method; receives per-client class counts.
        batch_size: dataloader batch size.
        num_workers: dataloader worker processes.
        shuffle: whether the loaders shuffle.
        num_agent: number of clients.
        beta: Dirichlet concentration parameter.
    Returns:
        ``(cifar100_training_loader, cnts_dict)``: a list with one loader (over a
        ``MySubset`` shard) per client and a dict mapping each client id to a
        NumPy array of its per-class sample counts.
    Raises:
        ValueError: if ``num_agent`` is smaller than 1.
    """
    if num_agent < 1:
        raise ValueError('num_agent must be >= 1, got {}'.format(num_agent))

    transform_train = transforms.Compose([
        #transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

    dataset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
    labels = np.array(dataset.targets)
    N = labels.shape[0]
    _lst_sample = 2  # samples of every class guaranteed to every client
    K = 100
    min_size = 0

    # ``np.int`` was removed in NumPy 1.24; use an explicit integer dtype.
    least_idx = np.zeros((num_agent, K, _lst_sample), dtype=np.int64)
    for i in range(K):
        idx_i = np.random.choice(np.where(labels==i)[0], num_agent*_lst_sample, replace=False)
        least_idx[:, i, :] = idx_i.reshape((num_agent, _lst_sample))
    least_idx = np.reshape(least_idx, (num_agent, -1))

    least_idx_set = set(np.reshape(least_idx, (-1)))
    remaining = list(set(range(N)) - least_idx_set)
    # The permutation only matters for the random stream; membership is what is used.
    local_idx = np.random.choice(remaining, len(remaining), replace=False)
    local_idx_set = set(local_idx.tolist())

    # Samples of each class that are still free to distribute (set lookup
    # instead of scanning the whole index array for every sample).
    class_pool = [
        [idx for idx in np.where(labels == k)[0] if idx in local_idx_set]
        for k in range(K)
    ]

    while min_size < 10:
        idx_batch = [[] for _ in range(num_agent)]
        # for each class in the dataset
        for k in range(K):
            idx_k = list(class_pool[k])
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(beta, num_agent))
            ## Balance
            proportions = np.array([p*(len(idx_j)<N/num_agent) for p,idx_j in zip(proportions,idx_batch)])
            proportions = proportions/proportions.sum()
            proportions = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist() for idx_j,idx in zip(idx_batch,np.split(idx_k,proportions))]
        min_size = min([len(idx_j) for idx_j in idx_batch])

    dict_users = {}
    for j in range(num_agent):
        np.random.shuffle(idx_batch[j])
        dict_users[j] = np.concatenate(
            (np.asarray(idx_batch[j], dtype=np.int64), least_idx[j]), axis=0)

    cnts_dict = {}
    cifar100_training_loader = []
    for i in range(num_agent):
        cifar100_training_subset = MySubset(dataset, dict_users[i])
        subset_training_loader = DataLoader(
            cifar100_training_subset, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
        cifar100_training_loader.append(subset_training_loader)

        labels_i = labels[dict_users[i]]
        cnts = np.array([np.count_nonzero(labels_i == j ) for j in range(K)] )
        cnts_dict[i] = cnts
        logger.info('[client-{}] labels:{} sum:{}'.format(i," ".join([str(cnt) for cnt in cnts]),sum(cnts)))
    return cifar100_training_loader,cnts_dict


def get_cifar100_testloader(batch_size=16, num_workers=2, shuffle=True, extra_cls_removed_dataset = False, cls_to_remove = 0):
    """ return a dataloader over the CIFAR-100 test split
    Args:
        batch_size: accepted for interface compatibility but NOT used; the
            loader is always built with a batch size of half the test set
        num_workers: dataloader num_works
        shuffle: whether to shuffle
        extra_cls_removed_dataset: accepted for interface compatibility, unused
        cls_to_remove: accepted for interface compatibility, unused
    Returns: cifar100_test_loader:torch dataloader object
    """
    # Evaluation pipeline: no augmentation, only tensor conversion and
    # normalisation with the CIFAR-100 training statistics.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

    cifar100_test = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

    # NOTE(review): the batch size is derived from the dataset length instead
    # of the `batch_size` argument -- presumably to evaluate in two large
    # batches; confirm this is intended before changing it.
    cifar100_test_loader = DataLoader(
        cifar100_test, shuffle=shuffle, num_workers=num_workers, batch_size=len(cifar100_test)//2)
    return cifar100_test_loader




def get_cifar50_trainloader_noniid(logger,batch_size=128, num_workers=4, shuffle=True, num_agent = 1,class_per_client=2):
    """ return class-partitioned training dataloader(s) over the first 50 CIFAR-100 classes
    Args:
        logger: object with an ``info`` method; receives the classes assigned
            to every client
        batch_size: dataloader batchsize
        num_workers: dataloader num_works
        shuffle: whether to shuffle
        num_agent: number of clients
        class_per_client: number of class slots handed to each client
    Returns:
        (loaders, cnts_dict). With ``num_agent == 1`` ``loaders`` is a single
        dataloader, otherwise a list with one dataloader per client.
        ``cnts_dict[i]`` is a list of length 50 holding the number of samples
        of each class given to client ``i``.
    Raises:
        ValueError: if ``num_agent`` is smaller than 1
        RuntimeError: if ``class_per_client * num_agent`` is not a positive
            multiple of the number of classes
    """
    class_num = 50
    if num_agent < 1:
        raise ValueError('num_agent must be at least 1, got {}'.format(num_agent))

    transform_train = transforms.Compose([
        #transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

    cifar100_training = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
    # Read the labels from `targets` instead of indexing the dataset: indexing
    # runs the whole augmentation pipeline on every image just to get a label.
    targets = np.asarray(cifar100_training.targets)

    if num_agent == 1:
        # NOTE(review): this loader covers the whole CIFAR-100 training set,
        # not only the first 50 classes -- confirm that is intended.
        cifar100_training_loader = DataLoader(
            cifar100_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
        cnts_dict = {0: [int(np.count_nonzero(targets == c)) for c in range(class_num)]}
        return cifar100_training_loader, cnts_dict

    # Every class is cut into `group` chunks and every client receives
    # `class_per_client` chunks, so the number of slots must be an exact
    # multiple of the number of classes.
    total_slots = class_per_client * num_agent
    group = total_slots // class_num
    if group < 1 or total_slots % class_num != 0:
        raise RuntimeError(
            'The '+str(class_num)+' classes cannot be divided evenly, please reset the number of clients or the number of classes per client'
        )

    # Sample indices of each of the first 50 labels, in dataset order.
    index_dic = {c: np.where(targets == c)[0].tolist() for c in range(class_num)}

    # Cut every class into `group` chunks; the last chunk takes the remainder.
    for key in index_dic:
        class_indices = index_dic[key]
        num_per_group = len(class_indices)//group
        chunks = [class_indices[g*num_per_group:(g+1)*num_per_group] for g in range(group-1)]
        chunks.append(class_indices[(group-1)*num_per_group:])
        index_dic[key] = chunks

    # Each class appears `group` times; a shuffled pass hands consecutive
    # slices of `class_per_client` entries to the clients.
    selection_list = list(range(class_num)) * group
    random.shuffle(selection_list)

    client_class = []
    for client_id in range(num_agent):
        client_class.append(selection_list[(client_id*class_per_client):((client_id+1)*class_per_client)])
        logger.info('[client-{}] labels:{}'.format(client_id,client_class[-1]))

    cnts_dict = {i:[0 for _ in range(class_num)] for i in range(num_agent)}
    cifar100_training_loader = []
    for i,selected_class in enumerate(client_class):
        selected_index = []
        for class_item in selected_class:
            # Take the last unused chunk of this class and record its real size.
            chunk = index_dic[class_item].pop()
            cnts_dict[i][class_item] += len(chunk)
            selected_index = selected_index + chunk
        cifar100_training_subset = MySubset(cifar100_training, selected_index)
        subset_training_loader = DataLoader(
            cifar100_training_subset, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
        cifar100_training_loader.append(subset_training_loader)
    return cifar100_training_loader,cnts_dict


def get_cifar50_trainloader_noniid_dic(logger,batch_size=128, num_workers=4, shuffle=True, num_agent = 1,beta=0.1):
    """ return Dirichlet-partitioned training dataloaders over the first 50 CIFAR-100 classes
    Args:
        logger: object with an ``info`` method; receives the per-class sample
            counts of every client
        batch_size: dataloader batchsize
        num_workers: dataloader num_works
        shuffle: whether to shuffle
        num_agent: number of clients
        beta: concentration parameter of the Dirichlet distribution used to
            split each class between the clients
    Returns:
        (loaders, cnts_dict): ``loaders`` is a list with one dataloader per
        client; ``cnts_dict[i]`` is an array of length 50 with the number of
        samples of each class held by client ``i``.
    """
    class_num = 50
    transform_train = transforms.Compose([
        #transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

    dataset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
    labels = np.array(dataset.targets)
    num_data = 100*500
    # Number of samples of every class reserved up-front for every client
    # (0 means nothing is reserved).
    _lst_sample = 0
    K = class_num
    min_size = 0
    y_train = labels

    # `np.int` no longer exists in current NumPy releases; use a concrete dtype.
    least_idx = np.zeros((num_agent, class_num, _lst_sample), dtype=np.int64)

    for i in range(class_num):
        idx_i = np.random.choice(np.where(labels==i)[0], num_agent*_lst_sample, replace=False)
        least_idx[:, i, :] = idx_i.reshape((num_agent, _lst_sample))
    least_idx = np.reshape(least_idx, (num_agent, -1))

    # Everything that was not reserved above is available for the Dirichlet split.
    least_idx_set = set(np.reshape(least_idx, (-1)))
    remaining_idx = list(set(range(num_data))-least_idx_set)
    local_idx = np.random.choice(remaining_idx, len(remaining_idx), replace=False)
    # Boolean lookup table so the per-class filter below is one vectorised
    # index instead of a linear scan of `local_idx` per sample.
    is_local = np.zeros(num_data, dtype=bool)
    is_local[local_idx] = True

    N = y_train.shape[0]

    # Re-draw the whole partition until every client holds at least 10 samples.
    while min_size < 10:
        idx_batch = [[] for _ in range(num_agent)]
        # for each class in the dataset
        for k in range(K):
            idx_k = np.where(y_train == k)[0]
            idx_k = idx_k[is_local[idx_k]]

            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(beta, num_agent))
            ## Balance: clients already holding N/num_agent samples get no more
            proportions = np.array([p*(len(idx_j)<N/num_agent) for p,idx_j in zip(proportions,idx_batch)])
            proportions = proportions/proportions.sum()
            # Turn the fractions into split points inside idx_k.
            proportions = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist() for idx_j,idx in zip(idx_batch,np.split(idx_k,proportions))]
            min_size = min(len(idx_j) for idx_j in idx_batch)

    dict_users = {}
    for j in range(num_agent):
        np.random.shuffle(idx_batch[j])
        dict_users[j] = np.concatenate(
            (np.asarray(idx_batch[j], dtype=np.int64), least_idx[j]), axis=0)

    cnts_dict = {}
    cifar100_training_loader = []
    for i in range(num_agent):
        cifar10_training_subset = MySubset(dataset, dict_users[i])
        subset_training_loader = DataLoader(
            cifar10_training_subset, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
        cifar100_training_loader.append(subset_training_loader)

        labels_i = labels[dict_users[i]]
        cnts = np.array([np.count_nonzero(labels_i == j ) for j in range(class_num)] )
        cnts_dict[i] = cnts
        logger.info('[client-{}] labels:{} sum:{}'.format(i," ".join([str(cnt) for cnt in cnts]),sum(cnts)))
    return cifar100_training_loader,cnts_dict


def get_cifar50_testloader(batch_size=16, num_workers=2, shuffle=True, extra_cls_removed_dataset = False, cls_to_remove = 0):
    """ return a dataloader over the CIFAR-100 test samples of the first 50 classes
    Args:
        batch_size: accepted for interface compatibility but NOT used; the
            loader is always built with a batch size of half the full
            CIFAR-100 test set
        num_workers: dataloader num_works
        shuffle: whether to shuffle
        extra_cls_removed_dataset: accepted for interface compatibility, unused
        cls_to_remove: accepted for interface compatibility, unused
    Returns: cifar100_test_loader:torch dataloader object
    """
    class_num = 50

    # Evaluation pipeline: no augmentation, only tensor conversion and
    # normalisation with the CIFAR-100 training statistics.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

    cifar100_test = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
    labels = np.array(cifar100_test.targets)

    # Indices are grouped class by class (all of class 0, then class 1, ...),
    # so an unshuffled loader yields the samples sorted by label.
    index = np.concatenate([np.where(labels==i)[0] for i in range(class_num)], axis=0)
    cifar100_training_subset = torch.utils.data.Subset(cifar100_test, index)

    # NOTE(review): the batch size comes from the length of the full test set,
    # not from the `batch_size` argument -- confirm this is intended.
    cifar100_test_loader = DataLoader(
        cifar100_training_subset, shuffle=shuffle, num_workers=num_workers, batch_size=len(cifar100_test)//2)
    return cifar100_test_loader



################


def get_celeba_trainloader(batch_size=16, num_workers=2, shuffle=True, num_agent = 1, collude_use_public = False):
    """ return CelebA training dataloader(s) plus two fixed "member" loaders
    Args:
        batch_size: dataloader batchsize
        num_workers: dataloader num_works
        shuffle: whether to shuffle
        num_agent: number of clients
        collude_use_public: only used when ``num_agent > 1``. If False the
            train split is cut into ``num_agent`` equal contiguous shards (a
            remainder is dropped). If True the first client gets the whole
            train split and each of the other ``num_agent - 1`` clients gets a
            loader over the whole test split.
    Returns:
        (celeba_training_loader, xmem_training_loader, xmem_testing_loader):
        the first item is a single dataloader when ``num_agent == 1`` and a
        list of dataloaders otherwise; the other two iterate train-split
        samples 0-4999 and 5000-9999 respectively.
    Raises:
        ValueError: if ``num_agent`` is smaller than 1
    """
    if num_agent < 1:
        raise ValueError('num_agent must be at least 1, got {}'.format(num_agent))

    transform_train = transforms.Compose([
        #transforms.ToPILImage(),
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # torchvision's CelebA selects the partition with `split=...`; it does not
    # accept the `train=` keyword used by the CIFAR/MNIST datasets.
    celeba_training = torchvision.datasets.CelebA(root='./data', split='train', download=True, transform=transform_train)

    if num_agent == 1:
        celeba_training_loader = DataLoader(
            celeba_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
    else:
        celeba_training_loader = []
        if not collude_use_public:
            shard_size = len(celeba_training)//num_agent
            for i in range(num_agent):
                celeba_training_subset = torch.utils.data.Subset(
                    celeba_training, list(range(i * shard_size, (i+1) * shard_size)))
                subset_training_loader = DataLoader(
                    celeba_training_subset, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
                celeba_training_loader.append(subset_training_loader)
        else:
            # 1 + (n-1) * collude: the single client gets all training data,
            # every other client iterates the test split (with the training
            # transform).
            subset_training_loader = DataLoader(
                celeba_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
            celeba_training_loader.append(subset_training_loader)
            celeba_test = torchvision.datasets.CelebA(root='./data', split='test', download=True, transform=transform_train)
            for i in range(num_agent-1):
                subset_training_loader = DataLoader(
                    celeba_test, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
                celeba_training_loader.append(subset_training_loader)

    # The member loaders use the same split and transform as above, so the
    # already loaded dataset is reused instead of loading it a second time.
    celeba_training_mem = torch.utils.data.Subset(celeba_training, list(range(0, 5000)))
    xmem_training_loader = DataLoader(
        celeba_training_mem, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

    celeba_testing_mem = torch.utils.data.Subset(celeba_training, list(range(5000, 10000)))
    xmem_testing_loader = DataLoader(
        celeba_testing_mem, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

    return celeba_training_loader, xmem_training_loader, xmem_testing_loader

def get_celeba_testloader(batch_size=16, num_workers=2, shuffle=True):
    """ return CelebA test dataloader plus two "non-member" loaders
    Args:
        batch_size: dataloader batchsize
        num_workers: dataloader num_works
        shuffle: whether to shuffle
    Returns:
        (celeba_test_loader, nomem_training_loader, nomem_testing_loader):
        the first iterates the whole test split, the second its first half and
        the third its second half.
    """
    # Deterministic evaluation pipeline shared by all three loaders.
    transform_test = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # torchvision's CelebA selects the partition with `split=...`; it does not
    # accept the `train=` keyword used by the CIFAR/MNIST datasets.
    celeba_test = torchvision.datasets.CelebA(root='./data', split='test', download=True, transform=transform_test)
    celeba_test_loader = DataLoader(
        celeba_test, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

    # Both halves are views on the same test dataset.
    half = len(celeba_test)//2
    celeba_training_nomem = torch.utils.data.Subset(celeba_test, list(range(0, half)))
    nomem_training_loader = DataLoader(
        celeba_training_nomem, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

    celeba_testing_nomem = torch.utils.data.Subset(celeba_test, list(range(half, len(celeba_test))))
    nomem_testing_loader = DataLoader(
        celeba_testing_nomem, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

    return celeba_test_loader, nomem_training_loader, nomem_testing_loader





#####################
def load_mnist_mangled(class_to_remove):
    """ load MNIST and build datasets with one class split off
    Args:
        class_to_remove: label that is removed from the public (test) data
            and separated from the private (train) data
    Returns:
        (xpriv, xpub, xremoved_examples, xpriv_other):
        xpriv is the complete MNIST training dataset with its transform;
        xpub is a TensorDataset of the test images/labels without the class;
        xremoved_examples is a TensorDataset of the training images/labels of
        the class (labels are left as a float tensor);
        xpriv_other is a TensorDataset of the remaining training images/labels.
    """
    def _build_transform():
        # Tensor conversion, normalisation, then a random 32x32 crop with padding 4.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
            transforms.RandomCrop(32, padding=4)
        ])

    xpriv = datasets.MNIST(root='./data', train=True, download=True, transform=_build_transform())
    mnist_test = datasets.MNIST(root='./data', train=False, transform=_build_transform())

    # Work on the raw arrays held by the datasets; the transforms above are
    # not applied to them.
    train_images, train_labels = np.array(xpriv.data), np.array(xpriv.targets)
    test_images, test_labels = np.array(mnist_test.data), np.array(mnist_test.targets)

    # remove class from Xpub
    (pub_images, pub_labels), _ = remove_class(test_images, test_labels, class_to_remove)
    # for evaluation
    seen_pair, removed_pair = remove_class(train_images, train_labels, class_to_remove)
    seen_images, seen_labels = seen_pair
    removed_images, removed_labels = removed_pair

    xpub = TensorDataset(
        torch.Tensor(pub_images),
        torch.Tensor(pub_labels).type(torch.LongTensor))
    xremoved_examples = TensorDataset(
        torch.Tensor(removed_images),
        torch.Tensor(removed_labels))
    xpriv_other = TensorDataset(
        torch.Tensor(seen_images),
        torch.Tensor(seen_labels).type(torch.LongTensor))

    return xpriv, xpub, xremoved_examples, xpriv_other


def load_fashion_mnist():
    """ load the Fashion-MNIST train and test datasets
    Returns:
        (xpriv, xpub): the training split (downloaded if missing) and the
        test split, both with tensor conversion, normalisation and a random
        32x32 crop with padding 4.
    """
    def _build_transform():
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
            transforms.RandomCrop(32, padding=4)
        ])

    train_split = datasets.FashionMNIST(
        root='./data', train=True, download=True, transform=_build_transform())
    test_split = datasets.FashionMNIST(
        root='./data', train=False, transform=_build_transform())

    return train_split, test_split

def remove_class(X, Y, ctr):
    """ split samples by whether their label equals `ctr`
    Args:
        X: samples, indexable with a boolean mask
        Y: labels aligned with X; compared element-wise against `ctr`
        ctr: label to separate out
    Returns:
        ((X_kept, Y_kept), (X_removed, Y_removed)): samples whose label
        differs from `ctr`, then samples whose label equals `ctr`.
    """
    keep_mask = Y != ctr
    drop_mask = Y == ctr
    kept = (X[keep_mask], Y[keep_mask])
    removed = (X[drop_mask], Y[drop_mask])
    return kept, removed

def plot(X, label='', norm=True):
    """ show the images in X side by side in one row
    Args:
        X: sequence of images supporting ``(X + 1) / 2``; values are rescaled
            that way before display (presumably from [-1, 1] to [0, 1] --
            TODO confirm). Images whose first dimension is 1 are squeezed and
            drawn with a gray colormap.
        label: title placed above every image
        norm: accepted for interface compatibility; currently unused, the
            rescaling is always applied
    """
    n = len(X)
    X = (X+1) / 2
    # squeeze=False always returns a 2-D array of axes, so indexing also
    # works when there is only a single image.
    _, ax = plt.subplots(1, n, figsize=(n*3,3), squeeze=False)
    row = ax[0]
    for i in range(n):
        if X[i].shape[0] == 1:
            row[i].imshow(X[i].squeeze(), cmap=plt.get_cmap('gray'))
        else:
            row[i].imshow(X[i])
        row[i].set(xticks=[], yticks=[], title=label)

def get_dataloader(dataset,logger,batch_size=128, num_workers=4, num_agent=20,data_dist='iid'):
    """ build the client-side batch generator and the public test loader
    Args:
        dataset: one of "cifar10", "cifar100", "cifar50", "facescrub", "mnist"
        logger: passed to the non-iid loader builders
        batch_size: dataloader batchsize
        num_workers: dataloader num_works (some builders below are called
            with a fixed value of 4 instead)
        num_agent: number of clients
        data_dist: data distribution. A value containing 'noniid' selects the
            class-based split, the text after 'noniid' being the number of
            classes per client; otherwise a value containing 'dir' selects
            the Dirichlet split, the text after 'dir' being beta (and
            'order' anywhere in the value enables the ordered variant for
            cifar10); anything else means iid.
    Returns:
        (SFL_data_generator, pub_dataloader, orig_class, num_batches, client_class)
    Raises:
        ValueError: if `dataset` is not supported, or for "cifar50" with an
            iid distribution
    """
    # setup dataset
    if 'noniid' in data_dist:
        non_iid=True
        dire=False
        class_per_client = int(data_dist.split('noniid')[-1])
    elif 'dir' in data_dist:
        non_iid=True
        dire=True
        beta = float(data_dist.split('dir')[-1])
        ordered = "order" in data_dist
    else:
        non_iid = False
        class_per_client=10
        beta = np.inf
    if dataset == "cifar10":
        orig_class = 10
        if non_iid:
            if dire:
                client_dataloader,client_class = get_cifar10_trainloader_noniid_dic(logger,batch_size=batch_size,
                                                                    num_workers=num_workers, shuffle=True,
                                                                    num_agent=num_agent,beta=beta,ordered=ordered)
            else:
                client_dataloader,client_class = get_cifar10_trainloader_noniid(logger,batch_size=batch_size,
                                                                    num_workers=num_workers, shuffle=True,
                                                                    num_agent=num_agent,class_per_client=class_per_client)
        else:
            client_dataloader,client_class = get_cifar10_trainloader(batch_size=batch_size,
                                                                num_workers=num_workers,shuffle=True,
                                                                num_agent=num_agent)
        pub_dataloader = get_cifar10_testloader(batch_size=batch_size,num_workers=4,shuffle=False)
    elif dataset == "cifar100":
        orig_class = 100
        if non_iid:
            if dire:
                client_dataloader,client_class = get_cifar100_trainloader_noniid_dic(logger,batch_size=batch_size,
                                                                    num_workers=num_workers, shuffle=True,
                                                                    num_agent=num_agent,beta=beta)
            else:
                client_dataloader,client_class = get_cifar100_trainloader_noniid(logger,batch_size=batch_size,
                                                                    num_workers=num_workers, shuffle=True,
                                                                    num_agent=num_agent,class_per_client=class_per_client)
        else:
            client_dataloader, client_class= get_cifar100_trainloader(batch_size=batch_size,
                                                        num_workers=4, shuffle=True,
                                                        num_agent=num_agent)
        pub_dataloader = get_cifar100_testloader(batch_size=batch_size,
                                                 num_workers=4,
                                                 shuffle=False)

    elif dataset == "cifar50":
        orig_class = 50
        if non_iid:
            if dire:
                client_dataloader,client_class = get_cifar50_trainloader_noniid_dic(logger,batch_size=batch_size,
                                                                    num_workers=num_workers, shuffle=True,
                                                                    num_agent=num_agent,beta=beta)
            else:
                client_dataloader,client_class = get_cifar50_trainloader_noniid(logger,batch_size=batch_size,
                                                                    num_workers=num_workers, shuffle=True,
                                                                    num_agent=num_agent,class_per_client=class_per_client)
        else:
            raise ValueError('Do not support cifar50 iid!')
        pub_dataloader = get_cifar50_testloader(batch_size=batch_size,
                                                 num_workers=4,
                                                 shuffle=False)

    elif dataset == "facescrub":
        orig_class = 530
        # Every client is listed as holding every class.
        client_class = [list(range(orig_class)) for _ in range(num_agent)]
        client_dataloader, pub_dataloader = get_facescrub_bothloader(batch_size=batch_size,
                                                                    num_workers=4,
                                                                    shuffle=True,
                                                                    num_agent=num_agent)
    elif dataset == "mnist":
        orig_class = 10
        # Every client is listed as holding every class.
        client_class = [list(range(orig_class)) for _ in range(num_agent)]
        client_dataloader, pub_dataloader = get_mnist_bothloader(batch_size=batch_size,
                                                                            num_workers=4,
                                                                            shuffle=True,
                                                                            num_agent=num_agent)
    else:
        # A bare string cannot be raised; use a real exception type.
        raise ValueError("Dataset {} is not supported!".format(dataset))

    if num_agent == 1:
        num_batches = len(client_dataloader)
        print("Total number of batches per epoch is ", num_batches)
    else:
        # Clients may hold different amounts of data; report the rounded mean.
        num_batches = round(np.mean([len(client_dataloader[i]) for i in range(num_agent)]))
        print("Total number of batches per epoch for each agent is ", num_batches)

    return SFL_data_generator(client_dataloader,num_agent,batch_size), pub_dataloader, orig_class, num_batches,client_class

class SFL_data_generator:
    """Endless per-client batch source built on top of the client dataloaders.

    Keeps one iterator per client and transparently restarts it when the
    underlying dataloader is exhausted or yields a batch whose size differs
    from `batch_size`.
    """
    def __init__(self,client_dataloader,num_agent,batch_size):
        """
        Args:
            client_dataloader: a list with one dataloader per client when
                `num_agent > 1`, otherwise a single dataloader
            num_agent: number of clients
            batch_size: batch size the dataloaders were built with
        """
        # load data batch
        self.num_agent = num_agent
        self.batch_size = batch_size
        self.client_dataloader = client_dataloader
        if self.num_agent > 1:
            self.client_iterator_list = [
                iter(self.client_dataloader[client_id]) for client_id in range(self.num_agent)]
        else:
            self.client_iterator_list = [iter(self.client_dataloader)]

    def _loader_for(self,client_id):
        # With a single agent `client_dataloader` is the loader itself, not a list.
        if self.num_agent > 1:
            return self.client_dataloader[client_id]
        return self.client_dataloader

    def _restart(self,client_id):
        # Begin a new pass over the client's dataloader.
        self.client_iterator_list[client_id] = iter(self._loader_for(client_id))

    def _next_batch(self,client_id):
        # Batches are (images, labels) or (images, labels, index); only the
        # first two entries are used.
        batch = next(self.client_iterator_list[client_id])
        return batch[0], batch[1]

    def load_data(self,client_id,adjusted_batch_size):
        """Return the next batch of a client, truncated and moved to the GPU.

        Args:
            client_id: index of the client whose data is requested
            adjusted_batch_size: number of leading samples of the batch to
                return
        Returns:
            (images, labels) holding the first `adjusted_batch_size` samples,
            both moved to CUDA.
        Raises:
            ValueError: if `adjusted_batch_size` exceeds the size of the
                fetched batch
        """
        try:
            images, labels = self._next_batch(client_id)
            if images.size(0) != self.batch_size:
                # Skip the batch of a different size (the last one of a pass)
                # and start over.
                self._restart(client_id)
                images, labels = self._next_batch(client_id)
        except StopIteration:
            self._restart(client_id)
            images, labels = self._next_batch(client_id)
        if adjusted_batch_size>images.size(0):
            raise ValueError('The adjusted batch size is larger than original batch size.')
        return images[0:int(adjusted_batch_size)].cuda(), labels[0:int(adjusted_batch_size)].cuda()





