#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import numpy as np
from torchvision import datasets, transforms


def mnist_iid(dataset, num_users):
    """
    Partition sample indices uniformly at random (I.I.D.) across clients.

    Every client receives ``int(len(dataset) / num_users)`` distinct indices,
    drawn without replacement from the indices no earlier client has taken;
    any remainder indices are left unassigned.
    :param dataset: any sized dataset; only ``len(dataset)`` is used
    :param num_users: number of clients
    :return: dict mapping client id -> set of sample indices
    """
    share = int(len(dataset) / num_users)
    unused = list(range(len(dataset)))
    assignment = {}
    for client in range(num_users):
        taken = set(np.random.choice(unused, share, replace=False))
        assignment[client] = taken
        # Drop this client's picks so later clients cannot reuse them.
        unused = list(set(unused) - taken)
    return assignment

def mnist_rotated(dataset, num_users):
    """
    Split a two-half dataset across clients with a cyclic first/second-half mix.

    Indices below ``int(len(dataset)/2)`` form the first pool and the rest form
    the second pool. Per the function name these are presumably original
    vs. rotated images; confirm against the dataset builder. Client ``i``
    draws ``num_items + (i%5 - 2)*portion`` indices from the first pool and
    ``num_items - (i%5 - 2)*portion`` from the second, with no reuse across
    clients, where ``num_items = int(len(dataset) / (2*num_users))`` and
    ``portion = int(num_items / 5)``. When ``num_users`` is not a multiple
    of 5, the demand on the second pool can exceed its size. In that case
    ``np.random.choice`` raises ``ValueError``.
    :param dataset: any sized dataset; only ``len(dataset)`` is used
    :param num_users: number of clients
    :return: dict mapping client id -> set of sample indices
    """
    total = len(dataset)
    num_items = int(total / (num_users * 2))
    portion = int(num_items / 5)
    midpoint = int(total / 2)
    first_pool = list(range(midpoint))
    second_pool = list(range(midpoint, total))
    dict_users = {}
    for user in range(num_users):
        # Offsets -2..+2 portions, repeating every 5 clients.
        shift = (user % 5 - 2) * portion
        from_first = set(np.random.choice(first_pool, num_items + shift, replace=False))
        from_second = set(np.random.choice(second_pool, num_items - shift, replace=False))
        chosen = from_first | from_second
        dict_users[user] = chosen
        first_pool = list(set(first_pool) - chosen)
        second_pool = list(set(second_pool) - chosen)
    return dict_users

def mnist_rotated2(dataset, num_users):
    """
    Split a two-half dataset across clients with a cyclic first/second-half mix.

    Same scheme as ``mnist_rotated`` but with a finer step: the first pool is
    indices below ``int(len(dataset)/2)`` and the second pool is the rest.
    Client ``i`` draws ``num_items + 2*(i%5 - 2)*portion`` indices from the
    first pool and ``num_items - 2*(i%5 - 2)*portion`` from the second, with
    no reuse across clients, where ``num_items = int(len(dataset) / (2*num_users))``
    and ``portion = int(num_items / 10)``. When ``num_users`` is not a
    multiple of 5, the second pool can run short. In that case
    ``np.random.choice`` raises ``ValueError``.
    :param dataset: any sized dataset; only ``len(dataset)`` is used
    :param num_users: number of clients
    :return: dict mapping client id -> set of sample indices
    """
    size = len(dataset)
    num_items = int(size / (num_users * 2))
    portion = int(num_items / 10)
    split_at = int(size / 2)
    lower, upper = list(range(split_at)), list(range(split_at, size))
    result = {}
    for client in range(num_users):
        delta = 2 * (client % 5 - 2) * portion
        lower_pick = np.random.choice(lower, num_items + delta, replace=False)
        upper_pick = np.random.choice(upper, num_items - delta, replace=False)
        picked = set(lower_pick) | set(upper_pick)
        result[client] = picked
        lower = list(set(lower) - picked)
        upper = list(set(upper) - picked)
    return result

def emnist_rotated(dataset, num_users):
    """
    Assign each client indices from exactly one half of the dataset.

    Clients ``0 .. num_users//2 - 1`` draw from the first half of the indices
    (``range(len(dataset)//2)``). The remaining clients draw from the second
    half. Within a half, draws are without replacement and never reused by a
    later client. Each client of a half receives an equal ``//`` share of
    that half.
    :param dataset: any sized dataset; only ``len(dataset)`` is used
    :param num_users: number of clients
    :return: dict mapping client id -> list of sample indices
    """
    dict_users = {}
    total_len = len(dataset)
    half = num_users // 2
    first_half_indices = list(range(total_len // 2))
    second_half_indices = list(range(total_len // 2, total_len))
    # Guard both divisions. With num_users == 1 no client uses the first half
    # (half == 0), and with num_users == 0 there is no client at all.
    num_items_first = len(first_half_indices) // half if half else 0
    num_items_second = (len(second_half_indices) // (num_users - half)
                        if num_users > half else 0)

    for i in range(num_users):
        if i < half:
            dict_users[i] = list(np.random.choice(first_half_indices, num_items_first, replace=False))
            first_half_indices = list(set(first_half_indices) - set(dict_users[i]))
        else:
            dict_users[i] = list(np.random.choice(second_half_indices, num_items_second, replace=False))
            second_half_indices = list(set(second_half_indices) - set(dict_users[i]))
    return dict_users

def unbalanced(dataset, num_users):
    """
    Split clients between even- and odd-labelled samples in unequal proportions.

    Indices are grouped by label parity. Each of the first ``num_users - 1``
    clients draws ``num_items + 2*(i%5 - 2)*portion`` even-label indices and
    ``num_items - 2*(i%5 - 2)*portion`` odd-label indices, without reuse
    across clients. Here ``num_items = int(len(dataset) / (2*num_users))``
    and ``portion = int(num_items / 10)``. The last client receives every
    index not yet assigned.
    :param dataset: iterable of ``(sample, label)`` pairs supporting ``len``
    :param num_users: number of clients
    :return: dict mapping client id -> list of sample indices
    """
    def filter_dataset_by_label(dataset):
        # Single pass: iterating ``dataset`` may load/transform every sample,
        # so it should not be done twice. Samples whose ``label % 2`` is
        # neither 0 nor 1 are ignored.
        even_label_indices, odd_label_indices = [], []
        for idx, (_, label) in enumerate(dataset):
            if label % 2 == 0:
                even_label_indices.append(idx)
            elif label % 2 == 1:
                odd_label_indices.append(idx)
        return even_label_indices, odd_label_indices

    num_items = int(len(dataset)/(num_users*2))
    portion = int(num_items/10)
    dict_users = {}
    first_idxs, sec_idxs = filter_dataset_by_label(dataset)  # even, odd labels
    for i in range(num_users):
        if i < num_users-1:
            shift = 2*(i%5 - 2)*portion
            chosen = set(np.random.choice(first_idxs, num_items+shift, replace=False)) | set(np.random.choice(sec_idxs, num_items-shift, replace=False))
            first_idxs = list(set(first_idxs) - chosen)
            sec_idxs = list(set(sec_idxs) - chosen)
            dict_users[i] = list(chosen)
        else:
            # The last client takes whatever is left in both pools.
            dict_users[i] = list(set(first_idxs) | set(sec_idxs))
    return dict_users

def unbalanced_lab(dataset, num_users):
    """
    Build 5 overlapping clients with an increasing share of even-labelled samples.

    Client ``i`` (0..4) draws ``(2*i + 1) * num_items`` even-label indices and
    ``(9 - 2*i) * num_items`` odd-label indices, where
    ``num_items = int(len(dataset) / 20)``. Draws are without replacement
    within a client but independent across clients, so different clients
    may share indices.
    :param dataset: iterable of ``(sample, label)`` pairs supporting ``len``
    :param num_users: currently ignored; exactly 5 clients are always built
    :return: dict mapping client id (0..4) -> list of sample indices
    """
    def filter_dataset_by_label(dataset):
        # Single pass: iterating ``dataset`` may load/transform every sample,
        # so it should not be done twice. Samples whose ``label % 2`` is
        # neither 0 nor 1 are ignored.
        even_label_indices, odd_label_indices = [], []
        for idx, (_, label) in enumerate(dataset):
            if label % 2 == 0:
                even_label_indices.append(idx)
            elif label % 2 == 1:
                odd_label_indices.append(idx)
        return even_label_indices, odd_label_indices

    num_items = int(len(dataset)/(20))
    dict_users = {}
    first_idxs, sec_idxs = filter_dataset_by_label(dataset)  # even, odd labels
    for i in range(5):
        evens = set(np.random.choice(first_idxs, (2*i+1)*num_items, replace=False))
        odds = set(np.random.choice(sec_idxs, (9-2*i)*num_items, replace=False))
        dict_users[i] = list(evens | odds)
    return dict_users

def cifar_rotated2(dataset, num_users):
    """
    Split a two-half dataset across clients with a cyclic first/second-half mix.

    The first pool is indices below ``int(len(dataset)/2)`` and the second
    pool is the rest. Client ``i`` draws ``num_items + 2*(i%5 - 2)*portion``
    indices from the first pool and ``num_items - 2*(i%5 - 2)*portion`` from
    the second, with no reuse across clients, where
    ``num_items = int(len(dataset) / (2*num_users))`` and
    ``portion = int(num_items / 10)``. When ``num_users`` is not a multiple
    of 5, the second pool can run short. In that case ``np.random.choice``
    raises ``ValueError``.
    :param dataset: any sized dataset; only ``len(dataset)`` is used
    :param num_users: number of clients
    :return: dict mapping client id -> set of sample indices
    """
    n = len(dataset)
    per_half = int(n / (num_users * 2))
    step = int(per_half / 10)
    pools = [list(range(int(n / 2))), list(range(int(n / 2), n))]
    out = {}
    for uid in range(num_users):
        offset = 2 * (uid % 5 - 2) * step
        head = set(np.random.choice(pools[0], per_half + offset, replace=False))
        tail = set(np.random.choice(pools[1], per_half - offset, replace=False))
        out[uid] = head | tail
        pools[0] = list(set(pools[0]) - out[uid])
        pools[1] = list(set(pools[1]) - out[uid])
    return out

def mnist_noniid(dataset, num_users, num_shards=200, num_imgs=300):
    """
    Sample non-I.I.D client data from MNIST dataset via label-sorted shards.

    The ``num_shards * num_imgs`` indices are sorted by label and cut into
    ``num_shards`` contiguous shards of ``num_imgs`` indices. Every client
    receives 2 distinct, randomly chosen shards, never reused.
    :param dataset: a dataset exposing ``targets`` (or legacy ``train_labels``),
        or a wrapper whose ``datasets[0].dataset`` does. The label array must
        have exactly ``num_shards * num_imgs`` entries.
    :param num_users: number of clients; needs ``2 * num_users <= num_shards``
    :param num_shards: number of label-sorted shards (default 200)
    :param num_imgs: indices per shard (default 300; 200 x 300 = 60,000)
    :return: dict mapping client id -> int64 numpy array of sample indices
    """
    # Unwrap the nested ``dataset.datasets[0].dataset`` layout only when it is
    # present; a plain torchvision dataset has no ``datasets`` attribute.
    source = dataset.datasets[0].dataset if hasattr(dataset, 'datasets') else dataset
    labels = getattr(source, 'targets', None)
    if labels is None:  # older torchvision only exposes ``train_labels``
        labels = source.train_labels
    labels = np.asarray(labels)

    idx_shard = list(range(num_shards))
    # int64 start so the concatenated indices stay integers; the default
    # float64 empty array would silently turn every index into a float.
    dict_users = {i: np.array([], dtype=np.int64) for i in range(num_users)}
    idxs = np.arange(num_shards * num_imgs)

    # sort indices by label
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :].astype(np.int64)

    # divide and assign 2 shards/client
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
    return dict_users

def mnist_2c(dataset, num_users, num_shards=200, num_imgs=300):
    """
    Sample non-I.I.D client data from MNIST dataset: 2 label-sorted shards per client.

    The ``num_shards * num_imgs`` indices are sorted by label and cut into
    ``num_shards`` contiguous shards of ``num_imgs`` indices. Every client
    receives 2 distinct, randomly chosen shards, never reused, so with
    label-pure shards it sees at most two labels.
    :param dataset: dataset exposing ``targets`` (or legacy ``train_labels``)
        with exactly ``num_shards * num_imgs`` labels
    :param num_users: number of clients; needs ``2 * num_users <= num_shards``
    :param num_shards: number of label-sorted shards (default 200)
    :param num_imgs: indices per shard (default 300; 200 x 300 = 60,000)
    :return: dict mapping client id -> int64 numpy array of sample indices
    """
    labels = getattr(dataset, 'targets', None)
    if labels is None:  # older torchvision only exposes ``train_labels``
        labels = dataset.train_labels
    labels = np.asarray(labels)

    idx_shard = list(range(num_shards))
    # int64 start so the concatenated indices stay integers; the default
    # float64 empty array would silently turn every index into a float.
    dict_users = {i: np.array([], dtype=np.int64) for i in range(num_users)}
    idxs = np.arange(num_shards * num_imgs)

    # sort indices by label
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :].astype(np.int64)

    # divide and assign 2 shards/client
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
    return dict_users


def mnist_noniid_unequal(dataset, num_users, num_shards=1200, num_imgs=50):
    """
    Sample non-I.I.D client data from MNIST dataset s.t clients
    have unequal amount of data.

    Indices are sorted by label and cut into ``num_shards`` contiguous shards
    of ``num_imgs`` indices. Each client draws a random weight in [1, 30].
    The weights are rescaled to shard counts summing to roughly
    ``num_shards``, and shards are handed out at random. On return, with at
    least one client, every shard belongs to exactly one client.
    :param dataset: dataset exposing ``targets`` (or legacy ``train_labels``)
        with exactly ``num_shards * num_imgs`` labels
    :param num_users: number of clients
    :param num_shards: number of label-sorted shards (default 1200)
    :param num_imgs: indices per shard (default 50; 1200 x 50 = 60,000)
    :returns: dict mapping client id -> int64 numpy array of sample indices
    """
    idx_shard = list(range(num_shards))
    # int64 start so the concatenated indices stay integers; the default
    # float64 empty array would silently turn every index into a float.
    dict_users = {i: np.array([], dtype=np.int64) for i in range(num_users)}
    idxs = np.arange(num_shards * num_imgs)
    labels = getattr(dataset, 'targets', None)
    if labels is None:  # older torchvision only exposes ``train_labels``
        labels = dataset.train_labels
    labels = np.asarray(labels)

    # sort indices by label
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :].astype(np.int64)

    def _give(client, shards):
        """Append the sample indices of each shard id in ``shards`` to ``client``."""
        for rand in shards:
            dict_users[client] = np.concatenate(
                (dict_users[client], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                axis=0)

    # Minimum and maximum shards assigned per client:
    min_shard = 1
    max_shard = 30

    # Divide the shards into random chunks for every client
    # s.t the sum of these chunks is approximately num_shards
    random_shard_size = np.random.randint(min_shard, max_shard+1,
                                          size=num_users)
    random_shard_size = np.around(random_shard_size /
                                  sum(random_shard_size) * num_shards)
    random_shard_size = random_shard_size.astype(int)

    if sum(random_shard_size) > num_shards:
        # Rounding overshot: first give every client 1 shard so each has
        # at least some data.
        for i in range(num_users):
            rand_set = set(np.random.choice(idx_shard, 1, replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            _give(i, rand_set)

        # Clamp at 0: a client whose share rounded to 0 would otherwise ask
        # for -1 shards, which makes np.random.choice raise.
        random_shard_size = np.maximum(random_shard_size - 1, 0)

        # Next, randomly assign the remaining shards until none are left.
        for i in range(num_users):
            if len(idx_shard) == 0:
                continue
            shard_size = min(random_shard_size[i], len(idx_shard))
            rand_set = set(np.random.choice(idx_shard, shard_size,
                                            replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            _give(i, rand_set)
    else:
        for i in range(num_users):
            shard_size = random_shard_size[i]
            rand_set = set(np.random.choice(idx_shard, shard_size,
                                            replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            _give(i, rand_set)

        # Rounding undershot: hand the leftover shards to the client with the
        # fewest images (skipped when there are no clients at all).
        if len(idx_shard) > 0 and dict_users:
            shard_size = len(idx_shard)
            k = min(dict_users, key=lambda x: len(dict_users.get(x)))
            rand_set = set(np.random.choice(idx_shard, shard_size,
                                            replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            _give(k, rand_set)

    return dict_users


def cifar_iid(dataset, num_users):
    """
    Partition CIFAR-10 sample indices uniformly at random (I.I.D.) across clients.

    Each client gets ``int(len(dataset) / num_users)`` distinct indices,
    sampled without replacement from those not yet given to an earlier
    client; leftover indices stay unassigned.
    :param dataset: any sized dataset; only ``len(dataset)`` is used
    :param num_users: number of clients
    :return: dict mapping client id -> set of sample indices
    """
    n_samples = len(dataset)
    quota = int(n_samples / num_users)
    available = list(range(n_samples))
    clients = {}
    for cid in range(num_users):
        draw = np.random.choice(available, quota, replace=False)
        clients[cid] = set(draw)
        available = list(set(available) - clients[cid])
    return clients


def cifar_noniid(dataset, num_users):
    """
    Deterministically split CIFAR-10 into label-skewed, two-shard clients.

    The 50,000 indices (40 shards x 1250) are sorted by label and cut into 40
    contiguous shards. Client ``i`` receives the fixed shard pair ``(r1, r2)``
    below, with ``mod = i % 5``:

    * ``i < 5``:   ``(4*i, 4*(i+5))``
    * ``i < 10``:  ``(4*mod + 1, 4*(mod+5) + 1)``
    * ``i < 15``:  ``(8*mod + 2, 8*mod + 6)``
    * otherwise:   ``(8*mod + 3, 8*mod + 7)``

    Clients 0..19 cover all 40 shards exactly once. Clients 20 and above
    reuse the shard pairs of clients 15..19. No randomness is involved.
    :param dataset: dataset with a ``targets`` sequence of exactly 50,000 labels
    :param num_users: number of clients
    :return: dict mapping client id -> numpy array of 2500 sample indices
    """
    num_shards, num_imgs = 40, 1250
    labels = np.array(dataset.targets)
    idxs = np.arange(num_shards * num_imgs)

    # sort indices by label
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    def shard(r):
        """Return the sample indices of shard ``r``."""
        return idxs[r * num_imgs:(r + 1) * num_imgs]

    dict_users = {}
    for i in range(num_users):
        mod = i % 5
        if i < 5:
            r1, r2 = 4 * i, 4 * (i + 5)
        elif i < 10:
            r1, r2 = 4 * mod + 1, 4 * (mod + 5) + 1
        elif i < 15:
            r1, r2 = 8 * mod + 2, 8 * mod + 4 + 2
        else:
            r1, r2 = 8 * mod + 3, 8 * mod + 4 + 3
        dict_users[i] = np.concatenate((shard(r1), shard(r2)), axis=0)
    return dict_users


if __name__ == '__main__':
    # Smoke run: build the MNIST training set (downloading it if needed) and
    # partition it non-IID across 100 clients.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
                                   transform=mnist_transform)
    num = 100
    d = mnist_noniid(dataset_train, num)
