import json
import logging
import os
import random
import numpy as np
import torch
# import matplotlib.pyplot as plt

import torch
from torchvision import datasets, transforms
# Display tensor values in fixed-point notation instead of scientific notation
# (only affects how tensors are printed, not their values).
torch.set_printoptions(sci_mode=False)

def _data_transforms_cifar10_100():
    """Create the (train, valid) torchvision pipelines used for CIFAR images.

    Both pipelines are identical: convert the input image to a tensor, then
    normalize each channel with the mean/std constants below. No data
    augmentation (random crop / horizontal flip) is applied to either split.

    Returns:
        tuple: ``(train_transform, valid_transform)``, two independent
        ``transforms.Compose`` objects.
    """
    channel_mean = [0.49139968, 0.48215827, 0.44653124]
    channel_std = [0.24703233, 0.24348505, 0.26158768]

    def make_pipeline():
        steps = [
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]
        return transforms.Compose(steps)

    train_pipeline = make_pipeline()
    valid_pipeline = make_pipeline()
    return train_pipeline, valid_pipeline

def fdil_batch_data(data, batch_size, model_name="resnet"):
    """Shuffle one client's samples with a fixed seed and cut them into batches.

    Args:
        data: dict with parallel sequences ``data['x']`` (inputs) and
            ``data['y']`` (labels) of equal length. Elements may be torch
            tensors, numpy arrays or numpy scalars.
        batch_size: maximum number of samples per batch; the last batch may
            be smaller.
        model_name: unused; kept for backward compatibility.

    Returns:
        list of ``(batched_x, batched_y)`` tuples, where ``batched_x`` is a
        float32 tensor stacked along a new leading dimension and
        ``batched_y`` is an int64 tensor. Empty input gives an empty list.
    """
    data_x = data['x']
    data_y = data['y']

    # Shuffle x and y with one shared permutation. A private RandomState(100)
    # yields exactly the permutation that ``np.random.seed(100)`` followed by
    # ``np.random.shuffle`` produced. Unlike that version, it does not reseed the
    # global numpy RNG (which silently made every later np.random call in the
    # process deterministic) and does not shuffle the caller's lists in place.
    order = np.random.RandomState(100).permutation(len(data_x))
    shuffled_x = [data_x[k] for k in order]
    shuffled_y = [data_y[k] for k in order]

    def _as_array(value):
        # Convert one element to numpy without a slow round-trip through
        # nested Python lists (the old ``.tolist()`` path).
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    batch_data = []
    for start in range(0, len(shuffled_x), batch_size):
        stop = start + batch_size
        batched_x = np.stack([_as_array(v) for v in shuffled_x[start:stop]])
        batched_y = np.stack([_as_array(v) for v in shuffled_y[start:stop]])
        batch_data.append((torch.from_numpy(batched_x).float(),
                           torch.from_numpy(batched_y).long()))
    return batch_data


def partition_data(label_list, n_nets, alpha, num_classes=100, min_require_size=10):
    """Split sample indices across ``n_nets`` clients with Dirichlet label skew.

    For each label ``k`` in ``range(num_classes)``, the indices of samples with
    that label are shuffled and divided among clients in proportions drawn
    from ``Dirichlet([alpha] * n_nets)``. Clients that already hold at least
    ``N / n_nets`` samples receive a zero share of the current label
    (balancing). The whole assignment is redrawn until every client holds at
    least ``min_require_size`` samples. Uses the global ``np.random`` state.

    Args:
        label_list: 1-D sequence, numpy array or CPU tensor of integer labels.
        n_nets: number of clients.
        alpha: Dirichlet concentration parameter.
        num_classes: labels ``0 .. num_classes - 1`` are distributed; samples
            with any other label are not assigned. Defaults to 100 (the value
            that used to be hard-coded).
        min_require_size: minimum number of samples every client must get.

    Returns:
        dict mapping client index ``j`` (``0 .. n_nets - 1``) to a shuffled
        list of integer indices into ``label_list``.

    Raises:
        ValueError: if fewer than ``min_require_size * n_nets`` samples carry a
            label in range. The retry loop could never succeed in that case
            (the previous code looped forever).
    """
    labels = np.asarray(label_list)
    N = len(labels)

    n_assignable = int(np.count_nonzero(np.isin(labels, np.arange(num_classes))))
    if n_assignable < min_require_size * n_nets:
        raise ValueError(
            "cannot give each of %d clients at least %d samples: only %d samples "
            "have a label in range(%d)" % (n_nets, min_require_size, n_assignable, num_classes))

    while True:
        idx_batch = [[] for _ in range(n_nets)]
        # for each class in the dataset
        for k in range(num_classes):
            idx_k = np.where(labels == k)[0]
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(alpha, n_nets))
            if len(idx_k) == 0:
                # Nothing to distribute. The Dirichlet draw above is still made
                # so the RNG stream stays the same. Skipping avoids a 0/0 (NaN)
                # when every client is already "full".
                continue
            # Balance: a client at or above its fair share gets nothing more.
            proportions = np.array([p * (len(idx_j) < N / n_nets)
                                    for p, idx_j in zip(proportions, idx_batch)])
            proportions = proportions / proportions.sum()
            split_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist()
                         for idx_j, idx in zip(idx_batch, np.split(idx_k, split_points))]
        if min(len(idx_j) for idx_j in idx_batch) >= min_require_size:
            break

    net_dataidx_map = {}
    for j in range(n_nets):
        np.random.shuffle(idx_batch[j])
        net_dataidx_map[j] = idx_batch[j]
    return net_dataidx_map


def dirichlet_distribution(grouped_data, map, new_users, maximun, batch_size):
    """Turn per-client index partitions of one task into batched client data.

    Args:
        grouped_data: tuple ``(X_train, Y_train, X_test, Y_test)`` of indexable
            sample/label collections.
        map: tuple ``(train_idx_map, test_idx_map)``. Each maps a client index
            ``j`` to a list of sample indices (as returned by
            ``partition_data``). The name shadows the builtin but is kept for
            keyword callers.
        new_users: client names; ``new_users[j]`` names client index ``j``.
            Names are assumed unique.
        maximun: cap on the training samples kept per client (the first
            ``maximun`` entries of its index list).
        batch_size: batch size passed to ``fdil_batch_data``.

    Returns:
        tuple ``(train_data_local_num_dict, train_data_local_dict, test_data_global)``:
        number of kept training samples per client position, training batches
        per client position, and every client's test batches pooled into one
        list shuffled with Python's ``random`` module.
    """
    X_train, Y_train, X_test, Y_test = grouped_data
    train_idx_map, test_idx_map = map

    def _gather(X, Y, idx_list):
        return {"x": [X[i] for i in idx_list], "y": [Y[i] for i in idx_list]}

    # Per-client samples keyed by client name.
    train_by_user = {new_users[j]: _gather(X_train, Y_train, idx_list)
                     for j, idx_list in train_idx_map.items()}
    test_by_user = {new_users[j]: _gather(X_test, Y_test, idx_list)
                    for j, idx_list in test_idx_map.items()}

    train_data_local_num_dict = dict()
    train_data_local_dict = dict()
    test_data_global = list()
    for client_idx, u in enumerate(new_users):
        user_train = train_by_user[u]
        # Keep at most `maximun` training samples (the first ones of the list).
        train_data_local_num_dict[client_idx] = min(len(user_train["x"]), maximun)
        capped = {"x": user_train["x"][:maximun], "y": user_train["y"][:maximun]}
        train_data_local_dict[client_idx] = fdil_batch_data(capped, batch_size)
        # Test data is not kept per client: all test batches form one global set.
        test_data_global += fdil_batch_data(test_by_user[u], batch_size)
    random.shuffle(test_data_global)

    return train_data_local_num_dict, train_data_local_dict, test_data_global
    

def _load_split_as_tensors(dataset_cls, root, train, transform):
    """Load one whole split of a torchvision dataset as a single ``(X, Y)`` pair."""
    dataset = dataset_cls(root=root, train=train, download=True, transform=transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset),
                                         shuffle=False, num_workers=2)
    # One batch holding the entire split, in dataset order (shuffle=False).
    return next(iter(loader))


def _select_label_group(X, Y, label_group):
    """Return ``(X, Y)`` restricted to the samples whose label is in ``label_group``."""
    mask = torch.any(Y.unsqueeze(1) == torch.tensor(list(label_group)), dim=1)
    return X[mask], Y[mask]


def _distribute_task(task_data, client_num_in_total, alpha, new_users, maximun, batch_size):
    """Partition one task's train and test sets across clients and batch them.

    ``task_data`` is ``(X_train, Y_train, X_test, Y_test)``. Returns the
    triple produced by ``dirichlet_distribution``.
    """
    train_map = partition_data(task_data[1], client_num_in_total, alpha)
    test_map = partition_data(task_data[3], client_num_in_total, alpha)
    return dirichlet_distribution(task_data, (train_map, test_map), new_users, maximun, batch_size)


def load_cifar10_100(client_num_in_total, alpha, batch_size, maximun,
                     cifar10_root='/home/ycli/CL_WYY_2403/Dataset/CIFAR10',
                     cifar100_root='/home/ycli/Dataset/CIFAR100'):
    """Build the federated class-incremental CIFAR-10 / CIFAR-100 benchmark.

    The base task is all of CIFAR-10. Five incremental tasks follow, each made
    of 10 consecutive CIFAR-100 labels (10-19, 20-29, 30-39, 40-49, 50-59).
    Each task's train and test sets are split across clients with
    ``partition_data`` and batched with ``dirichlet_distribution``.

    Args:
        client_num_in_total: number of clients.
        alpha: Dirichlet concentration passed to ``partition_data``.
        batch_size: batch size for all client batches.
        maximun: cap on training samples per client per task.
        cifar10_root: CIFAR-10 directory (downloaded if missing).
        cifar100_root: CIFAR-100 directory (downloaded if missing).

    Returns:
        list ``[train_num_dict, init_train, init_test, incremental_train_data,
        incremental_test_data, 60]``:

        - ``train_num_dict`` / ``init_train`` / ``init_test``: the base
          (CIFAR-10) task's per-client sample counts, per-client training
          batches and pooled test batches.
        - ``incremental_train_data[c]``: client ``c``'s training batches for
          the five incremental tasks, in order.
        - ``incremental_test_data[c]``: the five pooled test-batch lists, in
          order. The same list objects are shared by all clients.
        - ``60``: size of the combined label space (CIFAR-10 labels 0-9 plus
          the CIFAR-100 labels 10-59 used by the increments).
    """
    class_num = 100
    n = 10

    # Client names, used only as internal dictionary keys.
    new_users = [("f_0000" if i < 10 else "f_000") + str(i)
                 for i in range(client_num_in_total)]

    train_transform, test_transform = _data_transforms_cifar10_100()

    # grouped_data[0]: all of CIFAR-10 (50000 train / 10000 test samples).
    X_train, Y_train = _load_split_as_tensors(datasets.CIFAR10, cifar10_root, True, train_transform)
    X_test, Y_test = _load_split_as_tensors(datasets.CIFAR10, cifar10_root, False, test_transform)
    grouped_data = [(X_train, Y_train, X_test, Y_test)]

    # grouped_data[1..10]: CIFAR-100 split into 10 tasks of n consecutive
    # labels each. Only 5 of them are used below.
    X_train, Y_train = _load_split_as_tensors(datasets.CIFAR100, cifar100_root, True, train_transform)
    X_test, Y_test = _load_split_as_tensors(datasets.CIFAR100, cifar100_root, False, test_transform)
    for start in range(0, class_num - n + 1, n):
        label_group = range(start, start + n)
        grouped_data.append(_select_label_group(X_train, Y_train, label_group)
                            + _select_label_group(X_test, Y_test, label_group))

    print('load cifar-10 data')
    train_num_dict, init_train, init_test = _distribute_task(
        grouped_data[0], client_num_in_total, alpha, new_users, maximun, batch_size)

    # Incremental tasks: CIFAR-100 labels 10-59 (grouped_data[2..6]). The
    # group with labels 0-9 (grouped_data[1]) is not used.
    incre_trains, incre_tests = [], []
    for step, task_data in enumerate(grouped_data[2:7], start=1):
        print('load increment data' + str(step))
        _, incre_train, incre_test = _distribute_task(
            task_data, client_num_in_total, alpha, new_users, maximun, batch_size)
        incre_trains.append(incre_train)
        incre_tests.append(incre_test)

    incremental_train_data = {c: [incre_train[c] for incre_train in incre_trains]
                              for c in range(client_num_in_total)}
    incremental_test_data = {c: list(incre_tests) for c in range(client_num_in_total)}

    return [train_num_dict, init_train, init_test, incremental_train_data, incremental_test_data, 60]


if __name__ == '__main__':
    logging.basicConfig()
    logger = logging.getLogger()
    # setLevel needs a level (an int or a level name). The original passed the
    # builtin `print`, which raises TypeError at startup.
    logger.setLevel(logging.INFO)

    print("finish")