#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import numpy as np
from torchvision import datasets, transforms
import pdb
import random

# Seed both random sources used by the samplers below (the stdlib `random`
# module and NumPy's global generator) so client partitions are reproducible.
random.seed(886)
# NOTE: the previous `np.random.rand(886)` merely drew 886 numbers and threw
# them away; `seed` is the call that actually fixes NumPy's random stream.
np.random.seed(886)

def cifar_iid_seen(dataset, num_users):
    """
    Deal the dataset indices out I.I.D. to the clients, then divide every
    client's share into a pre-training part and a validation part.

    :param dataset: sized collection; only ``len(dataset)`` is consulted
    :param num_users: number of clients to create
    :return: ``(pre_train_dict, pre_val_dict)`` -- two dicts keyed by client
             id whose values are sets of sample indices (the first 75% and
             the remaining 25% of the client's shuffled share)
    """
    total = len(dataset)
    per_client = int(total / num_users)
    remaining = list(range(total))
    pre_train_dict, pre_val_dict = {}, {}

    for client in range(num_users):
        # Draw this client's share without replacement and take it out of
        # the pool so that shares never overlap.
        share = set(np.random.choice(remaining, per_client, replace=False))
        remaining = list(set(remaining) - share)

        # Shuffle the share before cutting it 75/25.
        ordered = list(share)
        random.shuffle(ordered)
        cut = int(0.75 * len(share))
        pre_train_dict[client] = set(ordered[:cut])
        pre_val_dict[client] = set(ordered[cut:])

    return pre_train_dict, pre_val_dict





def cifar_iid_unseen_train(dataset, num_users):
    """
    Partition the dataset indices I.I.D. across clients and carve an 80/20
    train/validation split out of each client's share.

    Every client receives ``int(len(dataset) / num_users)`` indices drawn
    without replacement; any remainder of the dataset stays unassigned.

    :param dataset: sized collection; only ``len(dataset)`` is used
    :param num_users: number of clients
    :return: ``(train_dict, val_dict)``; each maps client id -> set of sample
             indices (first 80% / last 20% of the client's shuffled share)
    """
    num_items = int(len(dataset)/num_users)
    all_idxs = list(range(len(dataset)))
    train_dict, val_dict = {}, {}

    for i in range(num_users):
        # Sample this client's share and remove it from the pool so clients
        # never share an index.
        client_idxs = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - client_idxs)

        shuffled = list(client_idxs)
        random.shuffle(shuffled)
        split = int(0.8 * len(shuffled))
        train_dict[i] = set(shuffled[:split])
        val_dict[i] = set(shuffled[split:])

    return train_dict, val_dict





def cifar_iid_unseen_train_noquery(dataset, num_users):
    """
    Partition the dataset indices I.I.D. across clients, keeping each
    client's whole share as training data (no validation split).

    Every client receives ``int(len(dataset) / num_users)`` indices drawn
    without replacement; any remainder of the dataset stays unassigned.

    :param dataset: sized collection; only ``len(dataset)`` is used
    :param num_users: number of clients
    :return: dict mapping client id -> set of sample indices
    """
    num_items = int(len(dataset)/num_users)
    all_idxs = list(range(len(dataset)))
    train_dict = {}

    for i in range(num_users):
        client_idxs = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - client_idxs)

        # The shuffle cannot change which indices end up in the set; it is
        # kept so the global `random` stream advances exactly as before and
        # later seeded calls remain reproducible.
        shuffled = list(client_idxs)
        random.shuffle(shuffled)
        train_dict[i] = set(shuffled)

    return train_dict




def cifar_iid_unseen_test(dataset, num_users):
    """
    Partition the dataset indices I.I.D. across clients.

    Every client receives ``int(len(dataset) / num_users)`` indices drawn
    without replacement; any remainder of the dataset stays unassigned.

    :param dataset: sized collection; only ``len(dataset)`` is used
    :param num_users: number of clients
    :return: dict mapping client id -> set of sample indices
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, list(range(len(dataset)))
    for i in range(num_users):
        # Remove each drawn share from the pool so shares are disjoint.
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])

    return dict_users

def cifar_noniid(dataset, num_users):
    """
    Sample non-I.I.D. client data by label-sorted shards.

    The sample indices are sorted by label, cut into 200 equally sized
    shards, and every client is handed 2 randomly chosen shards. For a
    50000-sample dataset each shard holds 250 samples, as before. If the
    dataset size is not a multiple of 200, the trailing samples are left out.

    :param dataset: object exposing its labels as ``targets`` (list, array or
                    tensor) or, failing that, as ``train_labels``
    :param num_users: number of clients; at most 100 (2 shards each out of 200)
    :return: dict mapping client id -> integer array of sample indices
    :raises ValueError: if there are fewer than 200 samples or more clients
                        than the shards can serve
    """
    num_shards = 200
    shards_per_user = 2

    # The other samplers in this file read labels from `.targets`; fall back
    # to `.train_labels` for datasets that only provide that attribute.
    labels = getattr(dataset, 'targets', None)
    if labels is None:
        labels = dataset.train_labels
    if hasattr(labels, 'numpy'):
        # Tensor-like labels: convert through their own .numpy() method.
        labels = labels.numpy()
    labels = np.asarray(labels)

    num_imgs = len(labels) // num_shards
    if num_imgs == 0:
        raise ValueError(
            'dataset has %d samples, need at least %d' % (len(labels), num_shards))
    if shards_per_user * num_users > num_shards:
        raise ValueError(
            'num_users=%d needs %d shards but only %d exist'
            % (num_users, shards_per_user * num_users, num_shards))

    usable = num_shards * num_imgs
    idx_shard = [i for i in range(num_shards)]
    idxs = np.arange(usable)

    # sort labels
    idxs_labels = np.vstack((idxs, labels[:usable]))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # divide and assign
    dict_users = {}
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, shards_per_user, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        dict_users[i] = np.concatenate(
            [idxs[rand*num_imgs:(rand+1)*num_imgs] for rand in rand_set],
            axis=0).astype(int)

    return dict_users



def cifar_noniid_alpha_unseen_train(dataset, num_users, alpha, num_class):
    """
    Dirichlet (non-I.I.D.) partition of the selected classes across clients,
    followed by an 80/20 train/validation split of every client's share.

    For each class the samples are divided among the clients according to a
    fresh ``Dirichlet(alpha)`` draw. The whole assignment is redrawn until
    every client holds at least 10 samples.

    :param dataset: object with a ``datasets`` iterable whose members expose
                    ``targets``; labels are concatenated in that order, so
                    the returned indices refer to the concatenation
    :param num_users: number of clients
    :param alpha: Dirichlet concentration parameter (same value per client)
    :param num_class: iterable of the class labels to distribute, or an int
                      ``n`` meaning classes ``0 .. n-1``
    :return: ``(train_dict, val_dict)``; each maps client id -> set of indices
    :raises ValueError: if the selected classes hold fewer than
                        ``10 * num_users`` samples (the redraw loop could
                        never terminate)
    """
    labels = []
    for ds in dataset.datasets:
        labels.extend(ds.targets)
    labels = np.array(labels)
    N = len(labels)

    # Accept both a collection of class labels and a plain class count.
    if isinstance(num_class, (int, np.integer)):
        classes = list(range(num_class))
    else:
        classes = list(num_class)

    min_required = 10
    available = sum(int(np.count_nonzero(labels == k)) for k in classes)
    if available < min_required * num_users:
        raise ValueError(
            'only %d samples in the selected classes; need at least %d for %d users'
            % (available, min_required * num_users, num_users))

    min_size = 0
    while min_size < min_required:
        idx_batch = [[] for _ in range(num_users)]
        for k in classes:
            idx_k = np.where(labels == k)[0]
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(alpha, num_users))

            ## Balance: a client already holding N/num_users samples or more
            ## gets zero weight for this class.
            proportions = np.array([p*(len(idx_j) < N/num_users)
                                    for p, idx_j in zip(proportions, idx_batch)])
            proportions = proportions/proportions.sum()

            # Turn the proportions into split positions inside idx_k.
            split_points = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist()
                         for idx_j, idx in zip(idx_batch, np.split(idx_k, split_points))]
        min_size = min(len(idx_j) for idx_j in idx_batch)

    train_dict, val_dict = {}, {}
    for j in range(num_users):
        np.random.shuffle(idx_batch[j])
        shuffled = list(idx_batch[j])
        random.shuffle(shuffled)
        split = int(0.80 * len(shuffled))
        train_dict[j] = set(shuffled[:split])
        val_dict[j] = set(shuffled[split:])

    return train_dict, val_dict



def cifar_noniid_alpha_unseen_train_noquery(dataset, num_users, alpha, num_class):
    """
    Dirichlet (non-I.I.D.) partition of the selected classes across clients,
    keeping each client's whole share as training data (no validation split).

    For each class the samples are divided among the clients according to a
    fresh ``Dirichlet(alpha)`` draw. The whole assignment is redrawn until
    every client holds at least 10 samples.

    :param dataset: object with a ``datasets`` iterable whose members expose
                    ``targets``; labels are concatenated in that order, so
                    the returned indices refer to the concatenation
    :param num_users: number of clients
    :param alpha: Dirichlet concentration parameter (same value per client)
    :param num_class: iterable of the class labels to distribute, or an int
                      ``n`` meaning classes ``0 .. n-1``
    :return: dict mapping client id -> set of sample indices
    :raises ValueError: if the selected classes hold fewer than
                        ``10 * num_users`` samples (the redraw loop could
                        never terminate)
    """
    labels = []
    for ds in dataset.datasets:
        labels.extend(ds.targets)
    labels = np.array(labels)
    N = len(labels)

    # Accept both a collection of class labels and a plain class count.
    if isinstance(num_class, (int, np.integer)):
        classes = list(range(num_class))
    else:
        classes = list(num_class)

    min_required = 10
    available = sum(int(np.count_nonzero(labels == k)) for k in classes)
    if available < min_required * num_users:
        raise ValueError(
            'only %d samples in the selected classes; need at least %d for %d users'
            % (available, min_required * num_users, num_users))

    min_size = 0
    while min_size < min_required:
        idx_batch = [[] for _ in range(num_users)]
        for k in classes:
            idx_k = np.where(labels == k)[0]
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(alpha, num_users))

            ## Balance: a client already holding N/num_users samples or more
            ## gets zero weight for this class.
            proportions = np.array([p*(len(idx_j) < N/num_users)
                                    for p, idx_j in zip(proportions, idx_batch)])
            proportions = proportions/proportions.sum()

            # Turn the proportions into split positions inside idx_k.
            split_points = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist()
                         for idx_j, idx in zip(idx_batch, np.split(idx_k, split_points))]
        min_size = min(len(idx_j) for idx_j in idx_batch)

    train_dict = {}
    for j in range(num_users):
        np.random.shuffle(idx_batch[j])
        # The second shuffle cannot change the resulting set; it is kept so
        # the global `random` stream advances exactly as before.
        shuffled = list(idx_batch[j])
        random.shuffle(shuffled)
        train_dict[j] = set(shuffled)

    return train_dict


def cifar_noniid_alpha_unseen_test(dataset, num_users, alpha, num_class):
    """
    Dirichlet (non-I.I.D.) partition of the selected classes across clients.

    For each class the samples are divided among the clients according to a
    fresh ``Dirichlet(alpha)`` draw. The whole assignment is redrawn until
    every client holds at least 10 samples.

    :param dataset: object with a ``datasets`` iterable whose members expose
                    ``targets``; labels are concatenated in that order, so
                    the returned indices refer to the concatenation
    :param num_users: number of clients
    :param alpha: Dirichlet concentration parameter (same value per client)
    :param num_class: int ``n`` meaning classes ``0 .. n-1``, or an iterable
                      of the class labels to distribute
    :return: dict mapping client id -> shuffled list of sample indices
    :raises ValueError: if the selected classes hold fewer than
                        ``10 * num_users`` samples (the redraw loop could
                        never terminate)
    """
    labels = []
    for ds in dataset.datasets:
        labels.extend(ds.targets)
    labels = np.array(labels)
    N = len(labels)

    # Accept both a plain class count and a collection of class labels, like
    # the sibling *_unseen_train samplers.
    if isinstance(num_class, (int, np.integer)):
        classes = list(range(num_class))
    else:
        classes = list(num_class)

    min_required = 10
    available = sum(int(np.count_nonzero(labels == k)) for k in classes)
    if available < min_required * num_users:
        raise ValueError(
            'only %d samples in the selected classes; need at least %d for %d users'
            % (available, min_required * num_users, num_users))

    min_size = 0
    while min_size < min_required:
        idx_batch = [[] for _ in range(num_users)]
        for k in classes:
            idx_k = np.where(labels == k)[0]
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(alpha, num_users))

            ## Balance: a client already holding N/num_users samples or more
            ## gets zero weight for this class.
            proportions = np.array([p*(len(idx_j) < N/num_users)
                                    for p, idx_j in zip(proportions, idx_batch)])
            proportions = proportions/proportions.sum()

            # Turn the proportions into split positions inside idx_k.
            split_points = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist()
                         for idx_j, idx in zip(idx_batch, np.split(idx_k, split_points))]
        min_size = min(len(idx_j) for idx_j in idx_batch)

    net_dataidx_map = {}
    for j in range(num_users):
        np.random.shuffle(idx_batch[j])
        net_dataidx_map[j] = idx_batch[j]

    return net_dataidx_map


if __name__ == '__main__':
    # Manual smoke test: download the CIFAR-10 training set and partition it
    # I.I.D. across 100 clients. The previous version called `mnist_noniid`,
    # which is not defined in this module, and therefore raised NameError.
    dataset_train = datasets.CIFAR10('../data/cifar/', train=True, download=True,
                                     transform=transforms.Compose([
                                         transforms.ToTensor(),
                                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                     ]))
    num = 100
    d = cifar_iid_unseen_test(dataset_train, num)
