import logging

import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms

from .datasets import CIFAR10_truncated  # 真正运行时采用

# from datasets import CIFAR10_truncated  # use this absolute import instead when running this file directly for testing

# Import-time logging setup: fetch the root logger, give it the default
# handler via basicConfig(), and lower its threshold so INFO records show.
logger = logging.getLogger()
logging.basicConfig()
logger.setLevel(logging.INFO)






class Cutout:
    """Zero out one square patch, at a random centre, of an image tensor.

    The tensor is indexed as ``img.size(1)`` x ``img.size(2)`` for the spatial
    dimensions; the same 2-D mask is broadcast over the leading dimension.
    The patch is clipped at the image borders and spans ``2 * (length // 2)``
    pixels per side before clipping.

    The input tensor is modified in place and also returned.
    """

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        height = img.size(1)
        width = img.size(2)
        half = self.length // 2

        # Random centre; y is drawn before x to keep the RNG draw order.
        cy = np.random.randint(height)
        cx = np.random.randint(width)

        top, bottom = np.clip([cy - half, cy + half], 0, height)
        left, right = np.clip([cx - half, cx + half], 0, width)

        keep = np.ones((height, width), np.float32)
        keep[top:bottom, left:right] = 0.0

        img *= torch.from_numpy(keep).expand_as(img)
        return img


def _data_transforms_cifar10():
    """Return ``(train_transform, valid_transform)`` pipelines for CIFAR-10.

    Training: ToPILImage -> RandomCrop(32, padding=4) -> RandomHorizontalFlip
    -> ToTensor -> Normalize -> Cutout(16). The first step means the raw sample
    must be something ``ToPILImage`` accepts (presumably an HWC uint8 array from
    ``CIFAR10_truncated`` -- TODO confirm).

    Validation: ToTensor -> Normalize only.
    """
    channel_mean = [0.49139968, 0.48215827, 0.44653124]
    channel_std = [0.24703233, 0.24348505, 0.26158768]

    augment_steps = [
        transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
        Cutout(16),  # runs last, on the already-normalized tensor
    ]
    eval_steps = [
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]

    return transforms.Compose(augment_steps), transforms.Compose(eval_steps)


def load_cifar10_data(datadir):
    """Load the complete CIFAR-10 train and test splits from ``datadir``.

    Both splits are constructed through ``CIFAR10_truncated`` with
    ``download=True``, train first. Only the datasets' ``.data`` and
    ``.target`` attributes are returned; the transforms are merely attached
    to the dataset objects.

    Returns:
        ``(X_train, y_train, X_test, y_test)``
    """
    train_tf, test_tf = _data_transforms_cifar10()

    arrays = []
    for is_train, tf in ((True, train_tf), (False, test_tf)):
        split = CIFAR10_truncated(datadir, train=is_train, download=True, transform=tf)
        arrays.extend((split.data, split.target))

    return tuple(arrays)


def partition_data_dataset(X_train, y_train, n_nets, alpha, partition_rate=0):
    """Split sample indices across ``n_nets`` clients using a per-class Dirichlet draw.

    First a random ``int(partition_rate * N)`` indices are reserved as a public
    pool. The remaining indices of each label k in 0..9 are then divided among
    the clients in proportions drawn from Dirichlet(alpha, ..., alpha). Clients
    that already hold ``N / n_nets`` indices get no further share. The whole
    draw is repeated until every client holds at least 10 indices.

    Args:
        X_train: Feature array; not used, kept for interface compatibility.
        y_train: Array of integer labels (``y_train.shape[0]`` is the sample count).
        n_nets: Number of clients.
        alpha: Dirichlet concentration; smaller values give more skewed splits.
        partition_rate: Fraction of all samples reserved for the public pool.

    Returns:
        ``(net_dataidx_map, public_idxs)``. ``net_dataidx_map`` maps each client
        id ``0..n_nets-1`` to a shuffled list of sample indices. ``public_idxs``
        is the list of reserved indices.

    Raises:
        ValueError: If the non-public samples are fewer than ``10 * n_nets``.
            In that case no draw could ever satisfy the minimum size.

    Note:
        This reseeds NumPy's *global* RNG (``np.random.seed``), so it clobbers
        the caller's RNG state. The per-class seeds make the first attempt use
        the same Dirichlet proportions for every call with equal ``n_nets`` and
        ``alpha``; the caller relies on this to align train and test splits.
    """
    num_classes = 10  # number of label values partitioned (CIFAR-10)
    min_require_size = 10  # every client must end up with at least this many indices
    N = y_train.shape[0]

    indices = np.random.permutation(N)
    split_index = int(partition_rate * N)
    y_train_public_index = list(indices[:split_index])

    # Non-public indices of each class, computed once (np.setdiff1d returns them sorted).
    class_pools = [np.setdiff1d(np.where(y_train == k)[0], y_train_public_index)
                   for k in range(num_classes)]

    pool_size = sum(len(pool) for pool in class_pools)
    if pool_size < min_require_size * n_nets:
        # Previously this case spun forever in the retry loop below.
        raise ValueError(
            "cannot give each of %d clients at least %d samples: only %d non-public samples"
            % (n_nets, min_require_size, pool_size))

    attempt = 0
    min_size = 0
    while min_size < min_require_size:
        idx_batch = [[] for _ in range(n_nets)]
        for k in range(num_classes):
            # Copy: the shuffle below is in place and must not disturb the cached pool.
            idx_k = class_pools[k].copy()

            # Attempt 0 uses seed k, as the original code did. Later attempts shift the seed;
            # re-using seed k made every retry repeat the same failing draw forever.
            np.random.seed(k + attempt * num_classes)
            proportions = np.random.dirichlet(np.repeat(alpha, n_nets))  # share of class k per client
            np.random.shuffle(idx_k)

            # Clients already at or above the average size N / n_nets receive nothing more.
            proportions = np.array([p * (len(idx_j) < N / n_nets) for p, idx_j in zip(proportions, idx_batch)])
            proportions = proportions / proportions.sum()
            split_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, split_points))]

        min_size = min(len(idx_j) for idx_j in idx_batch)
        attempt += 1

    net_dataidx_map = {}
    for j in range(n_nets):
        np.random.shuffle(idx_batch[j])
        net_dataidx_map[j] = idx_batch[j]
    return net_dataidx_map, y_train_public_index


def partition_data(dataset, datadir, partition, n_nets, alpha, partition_rate=0):
    """Load CIFAR-10 and split its train/test indices across ``n_nets`` clients.

    Args:
        dataset: Dataset name; not used here (CIFAR-10 is always loaded).
        datadir: Root directory handed to ``load_cifar10_data``.
        partition: ``"homo"`` for a uniform random split, ``"hetero"`` for the
            per-class Dirichlet split of ``partition_data_dataset``.
        n_nets: Number of clients.
        alpha: Dirichlet concentration (only used by ``"hetero"``).
        partition_rate: Fraction of the training set reserved as a public set.

    Returns:
        ``(X_train, y_train, X_test, y_test, net_dataidx_map_train,
        net_dataidx_map_test, net_dataidx_map_train_public)``. The two maps go
        from client id to an index collection. The last element holds the
        reserved public training indices.

    Raises:
        ValueError: For an unknown ``partition`` value. It subclasses
            ``Exception``, so existing ``except Exception`` handlers still work.
    """
    print("*********partition data***************")
    X_train, y_train, X_test, y_test = load_cifar10_data(datadir)
    n_train = X_train.shape[0]
    n_test = X_test.shape[0]

    if partition == "homo":
        idxs = np.random.permutation(n_train)
        idxs_test = np.random.permutation(n_test)
        # Compute the split point once. Use the same rounding as the "hetero" branch, int(rate * N);
        # the old int(N * (1 - rate)) could round down and reserve one extra public sample.
        # With partition_rate == 0 this keeps every index private and the public set empty.
        n_private = n_train - int(partition_rate * n_train)
        batch_idxs = np.array_split(idxs[:n_private], n_nets)
        batch_idxs_test = np.array_split(idxs_test, n_nets)
        net_dataidx_map_train = {i: batch_idxs[i] for i in range(n_nets)}
        net_dataidx_map_train_public = idxs[n_private:]  # tail of the permutation is public
        net_dataidx_map_test = {i: batch_idxs_test[i] for i in range(n_nets)}

    elif partition == "hetero":  # non-IID split happens here
        net_dataidx_map_train, net_dataidx_map_train_public = partition_data_dataset(X_train, y_train, n_nets, alpha,
                                                                                     partition_rate)
        # partition_data_dataset reseeds per class, so the test split's first attempt
        # draws the same per-class proportions as the train split.
        net_dataidx_map_test = partition_data_dataset(X_test, y_test, n_nets, alpha)[0]

    else:
        raise ValueError("partition arg error")

    return X_train, y_train, X_test, y_test, net_dataidx_map_train, net_dataidx_map_test, net_dataidx_map_train_public


# for centralized training
def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs_train=None, dataidxs_test=None):
    """Return ``(train_dl, test_dl)`` for CIFAR-10; ``dataset`` is accepted but ignored."""
    return get_dataloader_CIFAR10(
        datadir=datadir,
        train_bs=train_bs,
        test_bs=test_bs,
        dataidxs_train=dataidxs_train,
        dataidxs_test=dataidxs_test,
    )


# for local devices
def get_dataloader_test(dataset, datadir, train_bs, test_bs, dataidxs_train, dataidxs_test):
    """Return ``(train_dl, test_dl)`` for a local device's CIFAR-10 subsets; ``dataset`` is ignored."""
    return get_dataloader_test_CIFAR10(
        datadir=datadir,
        train_bs=train_bs,
        test_bs=test_bs,
        dataidxs_train=dataidxs_train,
        dataidxs_test=dataidxs_test,
    )


def get_dataloader_CIFAR10(datadir, train_bs, test_bs, dataidxs_train=None, dataidxs_test=None):
    """Build CIFAR-10 train/test DataLoaders, optionally restricted to index subsets.

    Args:
        datadir: Root directory passed to ``CIFAR10_truncated`` (``download=True``).
        train_bs: Training batch size.
        test_bs: Evaluation batch size.
        dataidxs_train: Training indices to keep, or ``None`` for the full split.
        dataidxs_test: Test indices to keep, or ``None`` for the full split.

    Returns:
        ``(train_dl, test_dl)``
    """
    dl_obj = CIFAR10_truncated

    transform_train, transform_test = _data_transforms_cifar10()

    train_ds = dl_obj(datadir, dataidxs=dataidxs_train, train=True, transform=transform_train, download=True)
    test_ds = dl_obj(datadir, dataidxs=dataidxs_test, train=False, transform=transform_test, download=True)

    # Training keeps drop_last=True so every optimisation step sees a full batch.
    # NOTE(review): shuffle=False means the same sample order every epoch; the client
    # index lists are shuffled once at partition time, but consider shuffle=True.
    train_dl = data.DataLoader(dataset=train_ds, batch_size=train_bs, shuffle=False, drop_last=True)
    # Evaluation must cover every test sample. drop_last=True silently discarded the
    # final partial batch, and gave an empty loader when a split had fewer than test_bs samples.
    test_dl = data.DataLoader(dataset=test_ds, batch_size=test_bs, shuffle=False, drop_last=False)

    return train_dl, test_dl


def get_dataloader_test_CIFAR10(datadir, train_bs, test_bs, dataidxs_train=None, dataidxs_test=None):
    """Build CIFAR-10 train/test DataLoaders for a local device.

    The loaders are identical to those of ``get_dataloader_CIFAR10``. This
    function delegates to it instead of duplicating the construction code, so
    the two entry points cannot drift apart.

    Returns:
        ``(train_dl, test_dl)``
    """
    return get_dataloader_CIFAR10(datadir, train_bs, test_bs, dataidxs_train, dataidxs_test)


# def load_partition_data_distributed_cifar10(process_id, dataset, data_dir, partition_method, partition_alpha,
#                                             client_number, batch_size):
#     X_train, y_train, X_test, y_test, net_dataidx_map_train, traindata_cls_counts, public_dataidx_map_train, public_dataidx_map_test  = partition_data(dataset,
#                                                                                                    data_dir,
#                                                                                                    partition_method,
#                                                                                                    client_number,
#                                                                                                    partition_alpha)
#     class_num = len(np.unique(y_train))
#     logging.info("traindata_cls_counts = " + str(traindata_cls_counts))
#     train_data_num = sum([len(net_dataidx_map[r]) for r in range(client_number)])
#
#     # get global test data
#     if process_id == 0:
#         train_data_global, test_data_global = get_dataloader(dataset, data_dir, batch_size, batch_size)
#         logging.info("train_dl_global number = " + str(len(train_data_global)))
#         logging.info("test_dl_global number = " + str(len(test_data_global)))
#         train_data_local = None
#         test_data_local = None
#         local_data_num = 0
#     else:
#         # get local dataset
#         dataidxs = net_dataidx_map[process_id - 1]
#         local_data_num = len(dataidxs)
#         logging.info("rank = %d, local_sample_number = %d" % (process_id, local_data_num))
#         # training batch size = 64; algorithms batch size = 32
#         train_data_local, test_data_local = get_dataloader(dataset, data_dir, batch_size, batch_size,
#                                                            dataidxs)
#         logging.info("process_id = %d, batch_num_train_local = %d, batch_num_test_local = %d" % (
#             process_id, len(train_data_local), len(test_data_local)))
#         train_data_global = None
#         test_data_global = None
#     return train_data_num, train_data_global, test_data_global, local_data_num, train_data_local, test_data_local, class_num


def load_partition_data_cifar10(dataset, data_dir, partition_method, partition_alpha, client_number, batch_size,
                                partition_rate):
    """Partition CIFAR-10 across clients and build every DataLoader a run needs.

    Returns an 11-tuple, in order:
        train_data_num: total number of training indices assigned to clients;
        test_data_num: total number of test indices assigned to clients;
        train_data_global: loader over the full training split;
        test_data_global: loader over the full test split;
        train_data_public: training loader restricted to the public indices;
        data_local_num_dict_train: client id -> number of local training indices;
        data_local_num_dict_test: client id -> number of local test indices;
        train_data_local_dict: client id -> local training loader;
        test_data_local_dict: client id -> local test loader;
        class_num_train: number of distinct labels in the training split;
        class_num_test: number of distinct labels in the test split.
    """
    (X_train, y_train, X_test, y_test,
     train_idx_map, test_idx_map, public_idxs) = partition_data(dataset,
                                                                data_dir,
                                                                partition_method,
                                                                client_number,
                                                                partition_alpha,
                                                                partition_rate)
    class_num_train = len(np.unique(y_train))
    class_num_test = len(np.unique(y_test))

    clients = range(client_number)
    train_data_num = sum(len(train_idx_map[c]) for c in clients)
    test_data_num = sum(len(test_idx_map[c]) for c in clients)

    train_data_global, test_data_global = get_dataloader(dataset, data_dir, batch_size, batch_size)
    # Only the training half of this call is kept: a loader over the public indices.
    train_data_public, _ = get_dataloader(dataset, data_dir, batch_size, batch_size, dataidxs_train=public_idxs)

    data_local_num_dict_train = {}
    data_local_num_dict_test = {}
    train_data_local_dict = {}
    test_data_local_dict = {}

    for client_idx in clients:
        local_train_idxs = train_idx_map[client_idx]
        local_test_idxs = test_idx_map[client_idx]

        data_local_num_dict_train[client_idx] = len(local_train_idxs)
        data_local_num_dict_test[client_idx] = len(local_test_idxs)

        train_data_local_dict[client_idx], test_data_local_dict[client_idx] = get_dataloader(
            dataset, data_dir, batch_size, batch_size, local_train_idxs, local_test_idxs)

    return (train_data_num, test_data_num, train_data_global, test_data_global, train_data_public,
            data_local_num_dict_train, data_local_num_dict_test, train_data_local_dict, test_data_local_dict,
            class_num_train, class_num_test)
