import json
import logging
import os
import random
import numpy as np
import torch
# import matplotlib.pyplot as plt

import torch
from torchvision import datasets, transforms
# Print tensor values in fixed-point rather than scientific notation
# (affects only how tensors are printed, not their values).
torch.set_printoptions(sci_mode=False)

def _data_transforms_cifar10():
    """Build the CIFAR-10 preprocessing pipelines.

    Returns:
        A ``(train_transform, valid_transform)`` pair of separate
        ``transforms.Compose`` objects. Both currently apply the same two
        steps: ``transforms.ToTensor()`` followed by per-channel
        ``transforms.Normalize`` with the CIFAR-10 mean/std constants below.
        No data augmentation is applied to the training pipeline.
    """
    channel_mean = [0.49139968, 0.48215827, 0.44653124]
    channel_std = [0.24703233, 0.24348505, 0.26158768]

    def _make_pipeline():
        # Called once per split so each caller gets its own Compose instance.
        steps = [
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]
        return transforms.Compose(steps)

    train_pipeline = _make_pipeline()
    valid_pipeline = _make_pipeline()
    return train_pipeline, valid_pipeline

def fdil_batch_data(data, batch_size, model_name="resnet"):
    """Shuffle one client's samples and cut them into mini-batches.

    Args:
        data: Mapping with parallel sequences ``data['x']`` (samples) and
            ``data['y']`` (labels) of equal length. Elements may be tensors,
            NumPy arrays or scalars; all ``x`` elements must share one shape.
            The caller's sequences are not modified.
        batch_size: Maximum number of samples per batch; the final batch may
            be smaller.
        model_name: Unused; kept for call-site compatibility.

    Returns:
        A list of ``(batched_x, batched_y)`` tuples where ``batched_x`` is a
        ``float32`` tensor and ``batched_y`` an ``int64`` tensor, stacked
        along dim 0. Empty input yields an empty list.

    Raises:
        ValueError: if ``data['x']`` and ``data['y']`` differ in length.

    Note:
        The shuffle is deterministic: the *global* NumPy RNG is reseeded with
        100 on every call, so this also resets the global RNG state observed by
        any later ``np.random`` call (e.g. in ``partition_data``).
    """
    # Shallow copies: the caller's containers are not reordered in place, and
    # tensors/arrays are shuffled as lists of rows, which avoids the aliasing
    # bug of swapping rows of a tensor through views.
    data_x = list(data['x'])
    data_y = list(data['y'])
    if len(data_x) != len(data_y):
        raise ValueError(
            f"data['x'] and data['y'] differ in length: {len(data_x)} vs {len(data_y)}")

    # Apply the same permutation to x and y: seed, remember the state,
    # shuffle x, restore that state, shuffle y.
    np.random.seed(100)
    rng_state = np.random.get_state()
    np.random.shuffle(data_x)
    np.random.set_state(rng_state)
    np.random.shuffle(data_y)

    batch_data = []
    for start in range(0, len(data_x), batch_size):
        stop = start + batch_size
        # Stack directly instead of round-tripping every sample through nested
        # Python lists and a NumPy array.
        batched_x = torch.stack([torch.as_tensor(sample) for sample in data_x[start:stop]]).float()
        batched_y = torch.stack([torch.as_tensor(label) for label in data_y[start:stop]]).long()
        batch_data.append((batched_x, batched_y))
    return batch_data

# Split the data into n_nets client shards (client_total) using a Dirichlet distribution.
def partition_data(label_list, n_nets, alpha, num_classes=10, min_require_size=10):
    """Split sample indices across ``n_nets`` clients using Dirichlet proportions.

    For every label ``k`` in ``range(num_classes)``:
    - the indices of that class are shuffled;
    - they are divided among the clients in proportions drawn from
      ``np.random.dirichlet([alpha] * n_nets)``.

    Balancing step: a client that already holds at least
    ``len(label_list) / n_nets`` indices gets a zero share of that class.
    The whole assignment is redrawn until every client holds at least
    ``min_require_size`` indices.

    Args:
        label_list: 1-D labels (NumPy array, CPU tensor or sequence); entry
            ``i`` is the label of sample ``i``.
        n_nets: Number of clients (>= 1).
        alpha: Dirichlet concentration passed to ``np.random.dirichlet``.
        num_classes: Labels ``0 .. num_classes - 1`` are distributed; samples
            with any other label are never assigned.
        min_require_size: Minimum number of indices per client.

    Returns:
        Dict mapping client id ``0 .. n_nets - 1`` to a shuffled list of
        sample indices (Python ints).

    Raises:
        ValueError: if ``n_nets < 1``, or if fewer than
            ``n_nets * min_require_size`` samples carry an assignable label.
            In that case no draw could satisfy the size requirement and the
            retry loop would never terminate.

    Uses the global NumPy RNG; results are reproducible only when it has been
    seeded beforehand.
    """
    labels = np.asarray(label_list)
    N = len(labels)
    if n_nets < 1:
        raise ValueError(f"n_nets must be >= 1, got {n_nets}")
    assignable = int(np.isin(labels, np.arange(num_classes)).sum())
    if assignable < n_nets * min_require_size:
        raise ValueError(
            f"cannot give each of {n_nets} clients at least {min_require_size} samples: "
            f"only {assignable} samples have a label in range({num_classes})")

    while True:
        idx_batch = [[] for _ in range(n_nets)]
        for k in range(num_classes):
            idx_k = np.where(labels == k)[0]
            np.random.shuffle(idx_k)
            # Draw even for an absent class so the RNG stream (and thus the
            # split for a given seed) does not depend on which classes exist.
            proportions = np.random.dirichlet(np.repeat(alpha, n_nets))
            if len(idx_k) == 0:
                continue  # nothing to distribute; also avoids a 0/0 below
            # Balance: clients already at or above the average share get nothing.
            proportions = np.array([p * (len(idx_j) < N / n_nets)
                                    for p, idx_j in zip(proportions, idx_batch)])
            proportions = proportions / proportions.sum()
            split_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist()
                         for idx_j, idx in zip(idx_batch, np.split(idx_k, split_points))]
        if min(len(idx_j) for idx_j in idx_batch) >= min_require_size:
            break

    net_dataidx_map = {}
    for j in range(n_nets):
        np.random.shuffle(idx_batch[j])
        net_dataidx_map[j] = idx_batch[j]
    return net_dataidx_map

# Assign the index shards produced by partition_data to clients and build their train/test batches.
def dirichlet_distribution(grouped_data, map, new_users, maximun, batch_size):
    """Materialize per-client training batches and a pooled test set for one task.

    Args:
        grouped_data: ``(X_train, Y_train, X_test, Y_test)`` for one label
            group; each is indexed by sample position.
        map: ``(train_idx_map, test_idx_map)``. Each maps a client position
            (an index into ``new_users``) to a list of sample indices, e.g.
            the output of ``partition_data``. The name shadows the builtin
            ``map`` but is kept for keyword callers.
        new_users: Client names in client order; used only as intermediate
            dictionary keys, so they must be unique.
        maximun: Cap on the number of training samples kept per client (the
            first ``maximun`` entries of its index list). Test samples are
            not capped.
        batch_size: Forwarded to ``fdil_batch_data``.

    Returns:
        ``(train_data_local_num_dict, train_data_local_dict, test_data_global)``:
        - client index -> number of training samples kept;
        - client index -> list of ``(x, y)`` batches from ``fdil_batch_data``;
        - the test batches of all clients pooled into one list and shuffled
          with Python's ``random`` module.
    """
    X_train, Y_train, X_test, Y_test = grouped_data
    train_idx_map, test_idx_map = map

    # Gather each client's raw samples, keyed by client name.
    new_train_data = {
        new_users[index]: {"x": [X_train[i] for i in idx_list],
                           "y": [Y_train[i] for i in idx_list]}
        for index, idx_list in train_idx_map.items()
    }
    new_test_data = {
        new_users[index]: {"x": [X_test[i] for i in idx_list],
                           "y": [Y_test[i] for i in idx_list]}
        for index, idx_list in test_idx_map.items()
    }

    train_data_local_num_dict = {}
    train_data_local_dict = {}
    test_data_global = []
    for client_idx, user in enumerate(new_users):
        train_data = new_train_data[user]
        # Keep at most `maximun` training samples; for a shorter list this is
        # just the whole list.
        keep = min(len(train_data["x"]), maximun)
        capped = {"x": train_data["x"][:keep], "y": train_data["y"][:keep]}
        train_data_local_num_dict[client_idx] = keep
        # Train is batched before test for every client. fdil_batch_data
        # reseeds the global NumPy RNG, so this call order is kept fixed.
        train_data_local_dict[client_idx] = fdil_batch_data(capped, batch_size)
        test_data_global += fdil_batch_data(new_test_data[user], batch_size)
    random.shuffle(test_data_global)
    return train_data_local_num_dict, train_data_local_dict, test_data_global
    

def _select_label_group(X, Y, label_group):
    """Return the rows of ``X`` and ``Y`` whose label in ``Y`` is one of ``label_group``."""
    mask = torch.any(Y.unsqueeze(1) == torch.tensor(label_group), dim=1)
    return X[mask], Y[mask]


def _build_task(task_data, client_num_in_total, alpha, new_users, maximun, batch_size):
    """Partition one task's train/test labels across clients and batch them.

    ``task_data`` is ``(X_train, Y_train, X_test, Y_test)``; returns the
    result of ``dirichlet_distribution`` for that task.
    """
    train_map = partition_data(task_data[1], client_num_in_total, alpha)
    test_map = partition_data(task_data[3], client_num_in_total, alpha)
    return dirichlet_distribution(task_data, (train_map, test_map), new_users, maximun, batch_size)


def load_cifar10(client_num_in_total, alpha, batch_size, maximun,
                 root='/home/ycli/Dataset/CIFAR10', classes_per_task=2):
    """Build a federated, class-incremental CIFAR-10 split.

    The 10 labels are cut into consecutive groups of ``classes_per_task``
    labels (default: {0,1}, {2,3}, ..., {8,9}); a trailing incomplete group is
    dropped. Group 0 is the initial task and every later group is one
    incremental task. For each task, in order:
    - train and test labels are split across clients with ``partition_data``;
    - the results are batched with ``dirichlet_distribution``.

    Args:
        client_num_in_total: Number of clients.
        alpha: Dirichlet concentration forwarded to ``partition_data``.
        batch_size: Mini-batch size.
        maximun: Per-client cap on training samples for each task.
        root: Directory holding CIFAR-10; ``download=True`` is passed to
            torchvision.
        classes_per_task: Number of labels per task.

    Returns:
        ``[train_num_dict, init_train, init_test, incremental_train_data,
        incremental_test_data, 10]`` where:
        - ``train_num_dict``, ``init_train`` and ``init_test`` describe the
          initial task (see ``dirichlet_distribution``);
        - ``incremental_train_data[c]`` lists client ``c``'s training batches
          for each incremental task, in order;
        - ``incremental_test_data[c]`` lists each task's pooled test batches
          (the same list objects for every client);
        - the last element is the number of classes.

    Raises:
        ValueError: if ``classes_per_task`` is not in ``1 .. 10``.
    """
    class_num = 10
    if not 1 <= classes_per_task <= class_num:
        raise ValueError(f"classes_per_task must be in 1..{class_num}, got {classes_per_task}")

    # Client names; only used as intermediate keys by dirichlet_distribution.
    new_users = [f"f_{i:05d}" for i in range(client_num_in_total)]

    train_transform, test_transform = _data_transforms_cifar10()
    trainset = datasets.CIFAR10(root=root, train=True,
                                download=True, transform=train_transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=len(trainset),
                                              shuffle=False, num_workers=2)
    testset = datasets.CIFAR10(root=root, train=False,
                               download=True, transform=test_transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=len(testset),
                                             shuffle=False, num_workers=2)

    # A single batch spanning the whole split yields every image/label as one tensor.
    X_train, Y_train = next(iter(trainloader))
    X_test, Y_test = next(iter(testloader))

    # grouped_data[t] = (X_train, Y_train, X_test, Y_test) restricted to task t's labels.
    labels = range(class_num)
    grouped_data = []
    for start in range(0, class_num - classes_per_task + 1, classes_per_task):
        label_group = labels[start:start + classes_per_task]
        grouped_data.append(_select_label_group(X_train, Y_train, label_group)
                            + _select_label_group(X_test, Y_test, label_group))

    # Tasks are processed strictly in order: partition_data and
    # fdil_batch_data share the global NumPy RNG, so order affects the splits.
    print('load init data')
    train_num_dict, init_train, init_test = _build_task(
        grouped_data[0], client_num_in_total, alpha, new_users, maximun, batch_size)

    incremental_train_data = {i: [] for i in range(client_num_in_total)}
    incremental_test_data = {i: [] for i in range(client_num_in_total)}
    for task in range(1, len(grouped_data)):
        print(f'load increment data{task}')
        _, incre_train, incre_test = _build_task(
            grouped_data[task], client_num_in_total, alpha, new_users, maximun, batch_size)
        for i in range(client_num_in_total):
            incremental_train_data[i].append(incre_train[i])  # client i's batches for this task
            incremental_test_data[i].append(incre_test)      # pooled test set, shared by all clients

    return [train_num_dict, init_train, init_test, incremental_train_data, incremental_test_data, class_num]


if __name__ == '__main__':
    # Running the module directly loads no data: after the module-level
    # imports and print options above have executed, it only prints a marker.
    print("finish")