import random

import numpy as np
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
from torchvision.datasets import MNIST, CIFAR10, CIFAR100

from args import *
from data.autoaugment import CIFAR10Policy, Cutout


class MNIST_truncated(data.Dataset):
    """MNIST split, optionally restricted to a subset of sample indices.

    Args:
        root: directory forwarded to ``torchvision.datasets.MNIST``.
        dataidxs: optional indices of the samples to keep; ``None`` keeps the
            whole split.
        train: load the training split when True, the test split otherwise.
        transform: optional callable applied to each PIL image.
        target_transform: optional callable applied to each label.
        download: forwarded to ``MNIST`` so missing files are fetched.
    """

    def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False):
        # Constructor arguments are stored verbatim before the data is loaded.
        self.root = root
        self.dataidxs = dataidxs
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download

        self.data, self.target = self.__build_truncated_dataset__()

    def __build_truncated_dataset__(self):
        """Load the full MNIST split and keep only ``self.dataidxs`` when given.

        Returns:
            tuple: ``(data, targets)`` taken from the torchvision dataset's
            ``.data`` / ``.targets``, indexed by ``self.dataidxs`` if it is set.
        """
        full_split = MNIST(
            self.root,
            train=self.train,
            transform=self.transform,
            target_transform=self.target_transform,
            download=self.download,
        )
        images, labels = full_split.data, full_split.targets

        if self.dataidxs is None:
            return images, labels
        return images[self.dataidxs], labels[self.dataidxs]

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample ``index``.

        The stored sample (a tensor, given the ``.numpy()`` call) is wrapped in
        a grayscale ('L') PIL image so transforms receive a PIL image, as with
        the other datasets in this file.
        """
        picture = Image.fromarray(self.data[index].numpy(), mode='L')
        label = self.target[index]

        if self.transform is not None:
            picture = self.transform(picture)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return picture, label

    def __len__(self):
        """Number of (possibly truncated) samples."""
        return len(self.data)


class CIFAR_truncated(data.Dataset):
    """CIFAR-10/100 split, optionally restricted to a subset of sample indices.

    The variant is chosen from the global ``args.dataset`` ('cifar10' or
    'cifar100') when the object is constructed. Images are kept in
    ``self.data`` as returned by torchvision's ``.data``; labels are kept in
    ``self.target`` as a numpy array.

    Args:
        root: directory forwarded to the torchvision dataset class.
        dataidxs: optional indices of the samples to keep; ``None`` keeps the
            whole split.
        train: load the training split when True, the test split otherwise.
        transform: optional callable applied to each PIL image.
        target_transform: optional callable applied to each label.
        download: forwarded to torchvision so missing files are fetched.

    Raises:
        ValueError: if ``args.dataset`` is neither 'cifar10' nor 'cifar100'.
    """

    def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False):

        self.root = root
        self.dataidxs = dataidxs
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download

        self.data, self.target = self.__build_truncated_dataset__()

    def __build_truncated_dataset__(self):
        """Load the CIFAR split named by ``args.dataset`` and apply ``self.dataidxs``.

        Returns:
            tuple: ``(images, targets)`` where ``targets`` is a numpy array.

        Raises:
            ValueError: for an ``args.dataset`` other than 'cifar10'/'cifar100'
                (previously this surfaced as an UnboundLocalError).
        """
        dataset_cls = {'cifar10': CIFAR10, 'cifar100': CIFAR100}.get(args.dataset)
        if dataset_cls is None:
            raise ValueError(
                "CIFAR_truncated requires args.dataset to be 'cifar10' or 'cifar100', got %r"
                % (args.dataset,))

        cifar_dataobj = dataset_cls(self.root, train=self.train, transform=self.transform,
                                    target_transform=self.target_transform, download=self.download)

        images = cifar_dataobj.data
        # torchvision exposes CIFAR targets as a list; a numpy array allows fancy indexing.
        target = np.array(cifar_dataobj.targets)

        if self.dataidxs is not None:
            images = images[self.dataidxs]
            target = target[self.dataidxs]

        return images, target

    def truncate_channel(self, index):
        """Zero channels 1 and 2 (last axis) of the samples at ``index``, in place.

        Only channel 0 is left intact (NOTE(review): the G and B channels if
        the images are stored as RGB — confirm).

        Args:
            index: integer sample indices (any array-like); an empty index is
                a no-op.
        """
        rows = np.asarray(index, dtype=np.intp)
        # Single vectorised assignment instead of a per-sample Python loop.
        self.data[rows, :, :, 1:3] = 0

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample ``index``.

        The stored array is wrapped in a PIL image before ``transform`` runs.
        """
        img, target = self.data[index], self.target[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Number of (possibly truncated) samples."""
        return len(self.data)


def load_mnist_data():
    """Build the MNIST training and test datasets.

    Both splits are stored under ``args.datadir`` (downloaded if missing) and
    only converted with ``transforms.ToTensor()``; no augmentation or
    normalisation is applied. One transform object is shared by both splits.

    Returns:
        tuple: ``(train_dataset, test_dataset)`` as ``MNIST_truncated`` objects.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])

    def _make_split(is_train):
        # Same root/download settings for both splits; only the split flag differs.
        return MNIST_truncated(args.datadir, train=is_train, download=True, transform=to_tensor)

    # Tuple elements are evaluated left to right: training split first.
    return _make_split(True), _make_split(False)


def load_cifar_data():
    """Build the CIFAR training and test datasets selected by ``args.dataset``.

    Training pipeline: RandomCrop(32, padding=4), RandomHorizontalFlip,
    ``CIFAR10Policy`` (from ``data.autoaugment``), ToTensor,
    ``Cutout(n_holes=1, length=16)`` and per-channel Normalize. The test
    pipeline only applies ToTensor and the same Normalize. The CIFAR-10
    statistics are used when ``args.dataset == 'cifar10'``; any other value
    gets the CIFAR-100 statistics.

    Returns:
        tuple: ``(train_dataset, test_dataset)`` as ``CIFAR_truncated`` objects
        stored under ``args.datadir`` (downloaded if missing).
    """
    if args.dataset == 'cifar10':
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    else:  # every non-'cifar10' name falls back to the CIFAR-100 statistics
        mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)

    train_pipeline = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        CIFAR10Policy(),
        transforms.ToTensor(),
        Cutout(n_holes=1, length=16),
        transforms.Normalize(mean, std),
    ])
    test_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_set = CIFAR_truncated(args.datadir, train=True, download=True, transform=train_pipeline)
    test_set = CIFAR_truncated(args.datadir, train=False, download=True, transform=test_pipeline)
    return train_set, test_set


def partition_data(n_parties, ii):
    """Load the configured dataset and split its training set across clients.

    The dataset is chosen by ``args.dataset`` ('mnist', 'cifar10' or
    'cifar100') and the split strategy by ``args.partition``:

    * ``"homo"``: IID. The training indices are shuffled and cut into
      ``n_parties`` near-equal chunks.
    * ``"noniid-labeldir"``: distribution-based label imbalance. Each label's
      samples are divided among clients with Dirichlet(``ii``) proportions.
      The whole draw is repeated until every client holds at least
      ``args.batch_size`` samples.
    * ``"noniid-#label"``: quantity-based label imbalance. Each client
      receives samples from ``ii`` distinct labels.
    * ``"iid-diff-quantity"``: quantity skew. Client dataset sizes follow
      Dirichlet(``ii``) proportions, with roughly at least 10 samples each.

    Noise-based feature imbalance is not implemented.

    Args:
        n_parties: number of clients.
        ii: strategy parameter. It is the Dirichlet concentration for
            "noniid-labeldir" and "iid-diff-quantity", the number of labels
            per client for "noniid-#label", and is ignored for "homo".

    Returns:
        tuple: ``(train_data, test_data, net_dataidx_map)``, where
        ``net_dataidx_map`` maps each client id ``0..n_parties-1`` to the
        indices of the training samples assigned to that client.

    Raises:
        ValueError: for an unsupported dataset or partition name, or when the
            requested split cannot be satisfied (e.g. more labels per client
            than the dataset has classes).
    """
    dataset = args.dataset  # dataset name

    if dataset == 'mnist':
        train_data, test_data = load_mnist_data()
    elif dataset in ('cifar10', 'cifar100'):
        train_data, test_data = load_cifar_data()
    else:
        raise ValueError("Unsupported dataset: %r" % (dataset,))

    partition = args.partition  # data-distribution strategy
    # Normalise labels to a numpy array once: CIFAR_truncated already stores
    # one, MNIST_truncated keeps torchvision's ``targets`` object as-is.
    targets = np.asarray(train_data.target)
    n_train = targets.shape[0]
    # MNIST and CIFAR-10 have 10 classes, CIFAR-100 has 100.
    n_classes = 100 if dataset == 'cifar100' else 10
    beta = ii  # Dirichlet concentration for the Dirichlet-based strategies

    if partition == "homo":
        net_dataidx_map = _partition_homo(n_train, n_parties)

    elif partition == "noniid-labeldir":
        print('labels, partition', n_classes, beta)
        net_dataidx_map = _partition_label_dirichlet(
            targets, n_parties, n_classes, beta, args.batch_size)

    elif partition == "noniid-#label":
        print('labels, partition=', n_classes, ii)
        net_dataidx_map = _partition_label_quantity(targets, n_parties, n_classes, ii)

    elif partition == "iid-diff-quantity":
        net_dataidx_map = _partition_quantity_skew(n_train, n_parties, beta)

    else:
        raise ValueError("Unsupported partition: %r" % (partition,))

    return (train_data, test_data, net_dataidx_map)


def _partition_homo(n_train, n_parties):
    """IID split: a random permutation of ``range(n_train)`` cut into ``n_parties`` near-equal chunks."""
    shuffled = np.random.permutation(n_train)
    return dict(enumerate(np.array_split(shuffled, n_parties)))


def _partition_label_dirichlet(targets, n_parties, n_classes, beta, min_require_size):
    """Label-distribution skew: split each label among clients with Dirichlet(beta) weights.

    The whole assignment is redrawn until the smallest client holds at least
    ``min_require_size`` samples. Each client's index list is shuffled before
    it is returned.
    """
    n_train = targets.shape[0]
    if n_parties * min_require_size > n_train:
        raise ValueError(
            "cannot give each of %d clients at least %d samples out of %d training samples"
            % (n_parties, min_require_size, n_train))

    # Draw at least once, then redraw while some client is below the minimum.
    while True:
        idx_batch = [[] for _ in range(n_parties)]
        for k in range(n_classes):
            idx_k = np.where(targets == k)[0]  # indices of samples with label k
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(beta, n_parties))
            # Balance: clients that already hold at least an equal share
            # (n_train / n_parties) receive nothing more of this label.
            proportions = np.array([p * (len(idx_j) < n_train / n_parties)
                                    for p, idx_j in zip(proportions, idx_batch)])
            proportions = proportions / proportions.sum()
            # Cumulative proportions become the cut points np.split expects.
            cut_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist()
                         for idx_j, idx in zip(idx_batch, np.split(idx_k, cut_points))]
        if min(len(idx_j) for idx_j in idx_batch) >= min_require_size:
            break

    net_dataidx_map = {}
    for j in range(n_parties):
        np.random.shuffle(idx_batch[j])
        net_dataidx_map[j] = idx_batch[j]
    return net_dataidx_map


def _partition_label_quantity(targets, n_parties, n_classes, num):
    """Quantity-based label imbalance: each client holds samples of ``num`` distinct labels."""
    if num > n_classes:
        # Otherwise the distinct-label sampling loop below could never finish.
        raise ValueError(
            "cannot assign %s distinct labels per client with only %d classes" % (num, n_classes))

    net_dataidx_map = {i: np.ndarray(0, dtype=np.int64) for i in range(n_parties)}

    if num == n_classes:
        # Every client owns every label: split each label's samples evenly.
        for i in range(n_classes):
            idx_k = np.where(targets == i)[0]
            np.random.shuffle(idx_k)
            split = np.array_split(idx_k, n_parties)
            for j in range(n_parties):
                net_dataidx_map[j] = np.append(net_dataidx_map[j], split[j])
        return net_dataidx_map

    # Client i always owns label i % n_classes, plus random distinct labels
    # until it owns ``num`` of them; times[c] counts the owners of label c.
    times = [0] * n_classes
    contain = []
    for i in range(n_parties):
        current = [i % n_classes]
        times[i % n_classes] += 1
        j = 1
        while j < num:
            ind = random.randint(0, n_classes - 1)
            if ind not in current:
                j += 1
                current.append(ind)
                times[ind] += 1
        contain.append(current)

    # Each label's samples are split evenly among its owners. Clients therefore
    # hold the same number of labels but not necessarily the same number of samples.
    for i in range(n_classes):
        if times[i] == 0:
            continue
        idx_k = np.where(targets == i)[0]
        np.random.shuffle(idx_k)
        split = np.array_split(idx_k, times[i])
        ids = 0
        for j in range(n_parties):
            if i in contain[j]:
                net_dataidx_map[j] = np.append(net_dataidx_map[j], split[ids])
                ids += 1
    return net_dataidx_map


def _partition_quantity_skew(n_train, n_parties, beta):
    """Quantity skew: IID samples, client sizes drawn from Dirichlet(beta) proportions.

    Proportions are redrawn until every client would receive at least 10
    samples before rounding to integer cut points.
    """
    if 10 * n_parties > n_train:
        raise ValueError(
            "cannot give each of %d clients at least 10 samples out of %d training samples"
            % (n_parties, n_train))

    idxs = np.random.permutation(n_train)
    min_size = 0
    while min_size < 10:
        proportions = np.random.dirichlet(np.repeat(beta, n_parties))
        proportions = proportions / proportions.sum()
        min_size = np.min(proportions * len(idxs))
    cut_points = (np.cumsum(proportions) * len(idxs)).astype(int)[:-1]
    return dict(enumerate(np.split(idxs, cut_points)))


# NOTE(review): legacy draft of a config-driven partition_data, kept for
# reference only. It is a bare module-level string literal, so none of it is
# executed. The draft is incomplete: for example, `indices` is never assigned
# on the "noniid-label" branch before `return indices`.
'''
def partition_data(conf,train_dataset,datadir,partition,n_parties,beta=0.4):

    
    conf  配置文件
    train_dataset 数据集名称
    partition 数据分布（IID/Non-IID/...）
    
    all_range = list(range(len(train_dataset)))
    #数据独立同分布
    if conf["partition"]=="iid":
        data_len = int(len(train_dataset) /conf['no_module'])
        indices = all_range[id * data_len:(id + 1) * data_len]

    #Non-IID Quantity-based label imbalance
    elif conf["partition"]=="noniid-label":
        #各方拥有的标签数
        K=conf["K"]

        #读取训练集的行数：all_range=N

        #生成随机数
        np.random.seed(2020)


    return indices
'''
