#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import numpy as np
from torchvision import datasets, transforms

def mnist_iid(dataset, num_users):
    """
    Randomly partition sample indices into equal-sized, disjoint IID client sets.

    Each user receives ``len(dataset) // num_users`` indices drawn without
    replacement from the indices no earlier user has taken; any remainder
    left over by the integer division stays unassigned.
    :param dataset: sized dataset; only ``len(dataset)`` is used
    :param num_users: number of clients to create
    :return: dict mapping user id -> set of sample indices
    """
    per_user = len(dataset) // num_users
    unassigned = list(range(len(dataset)))
    assignment = {}
    for user in range(num_users):
        picked = set(np.random.choice(unassigned, per_user, replace=False))
        assignment[user] = picked
        # Drop this user's picks so later users only draw from what is left.
        unassigned = list(set(unassigned) - picked)
    return assignment



def mnist_noniid(dataset, num_users):
    """
    Give each user a contiguous slice of the label-sorted sample indices.

    Sample indices are ordered by label and cut into ``num_users`` consecutive
    shards of ``int(len(dataset) / num_users)`` indices, so each user sees only
    one or a few labels. Indices past ``num_users * shard`` are left unassigned.
    :param dataset: indexable dataset where ``dataset[i][1]`` is sample i's label
    :param num_users: number of clients to create
    :return: dict mapping user id -> int64 numpy array of sample indices
    """
    total = len(dataset)
    labels = np.empty(total, dtype='int64')
    for pos in range(total):
        labels[pos] = dataset[pos][1]
    # Indices are 0..total-1, so the argsort permutation *is* the list of
    # sample indices ordered by label.
    by_label = np.argsort(labels).astype('int64')
    per_user = int(total / num_users)
    return {user: by_label[user * per_user:(user + 1) * per_user]
            for user in range(num_users)}


def mnist_partial_noniid(dataset, num_users, portion):
    """
    Split samples so each user holds a label-skewed part plus a shuffled IID part.

    Indices are sorted by label and cut into ``num_users`` blocks of
    ``base = int(len(dataset) / num_users)``. User ``i`` keeps the first
    ``shard = int(len(dataset) / num_users * portion)`` indices of block ``i``
    (its label-skewed part). The remainder of every block is pooled, shuffled
    with ``np.random.shuffle``, and dealt out in equal slices to all users.
    Leftovers from the integer divisions are not assigned.
    :param dataset: indexable dataset where ``dataset[i][1]`` is sample i's label
    :param num_users: number of clients to create
    :param portion: fraction in [0, 1] of each block kept as label-skewed data
    :return: dict mapping user id -> int64 numpy array of sample indices
    :raises ValueError: if ``portion`` is outside [0, 1] (a larger value would
        make consecutive users' skewed shards overlap)
    """
    if not 0 <= portion <= 1:
        raise ValueError('portion must be in [0, 1], got {}'.format(portion))
    n_samples = len(dataset)
    labels = np.arange(0, n_samples, dtype='int64')
    idxs = np.arange(0, n_samples, dtype='int64')
    for i in range(0, n_samples):
        labels[i] = dataset[i][1]

    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]
    # NOTE: int() truncates, so float rounding (e.g. 100 * 0.29 == 28.999...)
    # can make the skewed shard one index shorter than the exact product.
    shard = int(n_samples / num_users * portion)
    base = int(n_samples / num_users)

    dict_users = {}
    leftovers = []
    for i in range(0, num_users):
        start = i * base
        dict_users[i] = idxs[start:start + shard]
        leftovers.append(idxs[start + shard:start + base])
    # One concatenation instead of np.append per user (which re-copies the pool).
    if leftovers:
        remaining_idxs = np.concatenate(leftovers)
    else:
        remaining_idxs = np.array([], dtype='int64')
    np.random.shuffle(remaining_idxs)
    rl = int(remaining_idxs.shape[0] / num_users)
    for i in range(0, num_users):
        dict_users[i] = np.append(dict_users[i], remaining_idxs[i * rl:(i + 1) * rl])

    return dict_users






def cifar_iid(dataset, num_users):
    """
    Randomly split sample indices into ``num_users`` disjoint, equal-sized shards.

    A random permutation of ``range(len(dataset))`` is cut into consecutive
    slices of ``int(len(dataset) / num_users)`` indices; indices past the last
    full slice are left unassigned.
    :param dataset: sized dataset; only ``len(dataset)`` is used
    :param num_users: number of clients to create
    :return: dict mapping user id -> numpy array of sample indices
    """
    order = np.random.permutation(len(dataset))
    size = int(len(dataset) / num_users)
    return {user: order[user * size:(user + 1) * size] for user in range(num_users)}




def cifar_noniid(dataset, num_users, portion):
    """
    Split a 10-class dataset into label-skewed shards mixed with IID samples.

    Indices are sorted by label; each of the 10 classes is assumed to fill a
    consecutive block of ``int(len(dataset) / 10)`` sorted positions (i.e. the
    classes are assumed balanced). Users are assigned ``num_users / 10`` per
    class. Within each class block:

    * the leading part is cut into contiguous shards of
      ``int(len(dataset) / num_users * portion)`` indices, one per user assigned
      to that class (the label-skewed data);
    * the rest of the block is pooled with the other classes' remainders,
      shuffled with ``np.random.shuffle``, and dealt out in slices of
      ``int(len(dataset) / num_users * (1 - portion))`` to every user.

    The skewed and IID parts are disjoint when ``num_users`` is a multiple of 10;
    other values give class/position arithmetic that can overlap shards.
    :param dataset: indexable dataset where ``dataset[i][1]`` is sample i's label
    :param num_users: number of clients; intended to be a multiple of 10
    :param portion: fraction in [0, 1] of each user's data taken from its class
    :return: dict mapping user id -> int64 numpy array of sample indices
    :raises ValueError: if ``portion`` is outside [0, 1]
    """
    if not 0 <= portion <= 1:
        raise ValueError('portion must be in [0, 1], got {}'.format(portion))
    n_samples = len(dataset)
    class_size = int(n_samples / 10)
    labels = np.arange(0, n_samples, dtype='int64')
    idxs = np.arange(0, n_samples, dtype='int64')
    for i in range(0, n_samples):
        labels[i] = dataset[i][1]

    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idx_sort = idxs_labels[0, :]

    shard = int(n_samples / num_users * portion)
    shard_iid = int(n_samples / num_users * (1 - portion))
    users_per_class = num_users / 10
    # Number of leading positions of each class block consumed by the skewed
    # shards of that class's users.
    skewed_len = users_per_class * shard

    # IID pool: the part of every class block *after* its skewed prefix. (The
    # previous start, (1 - portion) * cl * i, reached back into earlier class
    # blocks and re-used samples already handed out as skewed shards.)
    idx_iid = np.concatenate(
        [idx_sort[int(n_samples / 10 * c + skewed_len):(c + 1) * class_size]
         for c in range(10)])
    np.random.shuffle(idx_iid)

    dict_users = {}
    for i in range(0, num_users):
        cls = int(i / users_per_class)  # class block this user draws from
        pos = int(i % users_per_class)  # user's position among that class's users
        start = int(n_samples / 10 * cls + pos * shard)
        stop = int(n_samples / 10 * cls + (pos + 1) * shard)
        dict_users[i] = np.append(idx_sort[start:stop],
                                  idx_iid[i * shard_iid:(i + 1) * shard_iid])

    return dict_users












if __name__ == '__main__':
    # Manual smoke run: download the MNIST training set and build the
    # label-sorted non-IID split for 100 users.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
                                   transform=mnist_transform)
    num_users = 100
    dict_users = mnist_noniid(dataset_train, num_users)
