#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import torch
import numpy as np
from torchvision import datasets, transforms
from collections import Counter
from itertools import combinations,cycle

import pdb

# Pin every random number generator this module touches to the same fixed
# seed, so the client partitions produced below are repeatable between runs.
torch.manual_seed(1)
np.random.seed(1)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(1)

def mnist_iid(dataset, num_users):
    """
    Sample I.I.D. client data from MNIST dataset

    Every client receives ``int(len(dataset) / num_users)`` indices drawn
    without replacement from the indices no earlier client has taken.
    :param dataset: any object supporting ``len()``
    :param num_users: number of clients to create
    :return: dict mapping client id -> set of image indices
    """
    total = len(dataset)
    per_client = int(total/num_users)
    remaining = list(range(total))
    partition = {}
    for client in range(num_users):
        picked = set(np.random.choice(remaining, per_client, replace=False))
        partition[client] = picked
        # Drop what was just handed out so the clients stay disjoint.
        remaining = list(set(remaining) - picked)
    return partition


def mnist_noniid(dataset, num_users):
    """
    Sample non-I.I.D client data from MNIST dataset

    The 60,000 training indices are ordered by label, cut into 200
    contiguous shards of 300 images, and every client is handed 2 shards
    drawn without replacement from the shards still unassigned.
    :param dataset: training set exposing its labels as ``targets`` or,
        failing that, as the legacy ``train_labels`` attribute
    :param num_users: number of clients; only 200 shards exist, so the
        shard draw fails once more than 100 clients are requested
    :return: dict mapping client id -> 1-D integer array of image indices
    """
    # 60,000 training imgs -->  300 imgs/shard X 200 shards
    num_shards, num_imgs = 200, 300
    idx_shard = [i for i in range(num_shards)]
    # Integer dtype up front: concatenating onto a default (float64) empty
    # array would silently turn every index into a float.
    dict_users = {i: np.array([], dtype=np.int64) for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)

    # Prefer ``targets``; ``train_labels`` is kept as a fallback for
    # datasets that only expose the older attribute name.
    labels = getattr(dataset, 'targets', None)
    if labels is None:
        labels = dataset.train_labels
    labels = labels.numpy() if hasattr(labels, 'numpy') else np.asarray(labels)

    # sort labels
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # divide and assign 2 shards/client
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
    return dict_users


def mnist_noniid_degree4(dataset, num_users, num_classes):
    """
    Sample non-I.I.D client data from MNIST dataset

    The label-sorted training indices are cut into equal shards sized so
    that every client can receive ``num_classes`` of them; each client
    then draws ``num_classes`` shards without replacement.
    :param dataset: training set exposing its labels as ``targets`` or,
        failing that, as the legacy ``train_labels`` attribute
    :param num_users: number of clients
    :param num_classes: number of shards handed to each client
    :return: dict mapping client id -> 1-D integer array of image indices
    """
    # 60,000 training imgs --> num_shards shards of num_imgs images each
    num_imgs = 60000 // (num_users * num_classes)
    num_shards = 60000 // num_imgs
    idx_shard = [i for i in range(num_shards)]
    # Integer dtype up front so the concatenated indices stay integers.
    dict_users = {i: np.array([], dtype=np.int64) for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)

    labels = getattr(dataset, 'targets', None)
    if labels is None:
        labels = dataset.train_labels
    labels = labels.numpy() if hasattr(labels, 'numpy') else np.asarray(labels)
    # num_shards*num_imgs falls short of 60,000 whenever the shard size
    # does not divide it evenly; trim the labels to match, otherwise the
    # vstack below fails on mismatched lengths.
    labels = labels[:len(idxs)]

    # sort labels
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # divide and assign num_classes shards/client
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, num_classes, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
    return dict_users






def mnist_noniid_unequal(dataset, num_users):
    """
    Sample non-I.I.D client data from MNIST dataset s.t clients
    have unequal amount of data

    The label-sorted training indices are cut into 1200 shards of 50
    images. Each client draws a random weight between 1 and 30; the weights
    are rescaled so they sum to roughly 1200 and are used as per-client
    shard counts.

    * If the rounded counts add up to more than 1200, every client first
      gets one shard, then ``count - 1`` further shards (capped by what is
      left in the pool).
    * Otherwise each client gets exactly its count, and any shards left
      over go to the client currently holding the fewest images. In this
      branch a client whose count rounded to 0 receives no data.
    :param dataset: training set exposing its labels as ``targets`` or,
        failing that, as the legacy ``train_labels`` attribute
    :param num_users: number of clients
    :returns a dict of clients with each clients assigned certain
    number of training imgs (1-D integer index arrays)
    """
    # 60,000 training imgs --> 50 imgs/shard X 1200 shards
    num_shards, num_imgs = 1200, 50
    idx_shard = [i for i in range(num_shards)]
    # Integer dtype up front so the concatenated indices stay integers.
    dict_users = {i: np.array([], dtype=np.int64) for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)

    labels = getattr(dataset, 'targets', None)
    if labels is None:
        labels = dataset.train_labels
    labels = labels.numpy() if hasattr(labels, 'numpy') else np.asarray(labels)

    # sort labels
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    def _give_shards(client, count):
        """Move `count` random unassigned shards into `client`'s indices."""
        nonlocal idx_shard
        rand_set = set(np.random.choice(idx_shard, count, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[client] = np.concatenate(
                (dict_users[client], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                axis=0)

    # Minimum and maximum shards assigned per client:
    min_shard = 1
    max_shard = 30

    # Divide the shards into random chunks for every client
    # s.t the sum of these chunks = num_shards
    random_shard_size = np.random.randint(min_shard, max_shard+1,
                                          size=num_users)
    random_shard_size = np.around(random_shard_size /
                                  sum(random_shard_size) * num_shards)
    random_shard_size = random_shard_size.astype(int)

    # Assign the shards randomly to each client
    if sum(random_shard_size) > num_shards:

        # First assign each client 1 shard to ensure every client has
        # at least one shard of data
        for i in range(num_users):
            _give_shards(i, 1)

        random_shard_size = random_shard_size-1

        # Next, randomly assign the remaining shards
        for i in range(num_users):
            if len(idx_shard) == 0:
                continue
            # A count that rounded to 0 is now -1; clamp it to 0 so the
            # draw does not fail, and never ask for more than is left.
            shard_size = min(max(random_shard_size[i], 0), len(idx_shard))
            _give_shards(i, shard_size)
    else:

        for i in range(num_users):
            _give_shards(i, random_shard_size[i])

        if len(idx_shard) > 0:
            # Add the leftover shards to the client with minimum images.
            # The recipient is picked before the draw, as it always was.
            k = min(dict_users, key=lambda x: len(dict_users.get(x)))
            _give_shards(k, len(idx_shard))

    return dict_users


def cifar_iid(dataset, num_users):
    """
    Sample I.I.D. client data from CIFAR10 dataset

    Clients are filled one after another with ``int(len(dataset) /
    num_users)`` indices each, chosen without replacement from the indices
    not yet given away.
    :param dataset: any object supporting ``len()``
    :param num_users: number of clients to create
    :return: dict mapping client id -> set of image indices
    """
    size = len(dataset)
    share = int(size/num_users)
    pool = [idx for idx in range(size)]
    dict_users = {}
    for user in range(num_users):
        chosen = set(np.random.choice(pool, share, replace=False))
        # Shrink the pool before storing so later clients cannot overlap.
        pool = list(set(pool) - chosen)
        dict_users[user] = chosen
    return dict_users


def cifar_noniid(dataset, num_users):
    """
    Sample non-I.I.D client data from CIFAR10 dataset

    The 50,000 training indices are ordered by label, cut into 200
    contiguous shards of 250 images, and every client is handed 2 shards
    drawn without replacement from the shards still unassigned.
    :param dataset: training set exposing its labels as ``targets``
    :param num_users: number of clients; only 200 shards exist, so the
        shard draw fails once more than 100 clients are requested
    :return: dict mapping client id -> 1-D integer array of image indices
    """
    # 50,000 training imgs --> 250 imgs/shard X 200 shards
    num_shards, num_imgs = 200, 250
    idx_shard = [i for i in range(num_shards)]
    # Integer dtype up front: concatenating onto a default (float64) empty
    # array would silently turn every index into a float.
    dict_users = {i: np.array([], dtype=np.int64) for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)
    labels = np.array(dataset.targets)

    # sort labels
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # divide and assign 2 shards/client
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)

    return dict_users


def cifar_noniid_degree4(dataset, num_users, num_classes):
    """
    Sample non-I.I.D client data from CIFAR10 dataset

    The label-sorted training indices are cut into equal shards sized so
    that every client can receive ``num_classes`` of them; each client
    then draws ``num_classes`` shards without replacement.
    :param dataset: training set exposing its labels as ``targets``
    :param num_users: number of clients
    :param num_classes: number of shards handed to each client
    :return: dict mapping client id -> 1-D integer array of image indices
    """
    # 50,000 training imgs --> num_shards shards of num_imgs images each
    num_imgs = 50000 // (num_users * num_classes)
    num_shards = 50000 // num_imgs
    idx_shard = [i for i in range(num_shards)]
    # Integer dtype up front so the concatenated indices stay integers.
    dict_users = {i: np.array([], dtype=np.int64) for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)
    # Only the first num_shards*num_imgs samples take part; the remainder
    # that does not fill a whole shard is dropped.
    labels = np.array(dataset.targets)[:len(idxs)]

    # sort labels
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # divide and assign num_classes shards/client
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, num_classes, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)

    return dict_users



def mnist_forget(dataset, num_users, num_classes):
    """
    Sample unbalanced mnist dataset

    Clients 0-99 each receive 125 images of two different odd digits;
    clients 100 to ``num_users - 1`` each receive 125 images of two
    different even digits. Digit pairs are taken round-robin from all
    2-combinations of the five digits in the group, and every 125-image
    slice is handed out at most once.
    :param dataset: dataset supporting ``len()`` and exposing ``targets``
        (tensor, array or plain list of labels)
    :param num_users: total number of clients. The odd-digit group is
        always 100 clients, so keys 0-99 are present even for smaller
        values; at most 100 even-digit clients can be served
    :param num_classes: unused, kept for signature compatibility
    :return: dict mapping client id -> 1-D integer array of image indices
    """
    # only the first 5000 examples of every digit are used
    idxs = np.arange(len(dataset))
    # Normalise the labels to an ndarray: comparing a plain list with an
    # int yields a single False instead of an element-wise mask.
    targets = dataset.targets
    targets = targets.numpy() if hasattr(targets, 'numpy') else np.asarray(targets)
    refer = {i: idxs[targets == i][:5000] for i in range(10)}
    dict_users = {i: np.array([], dtype=np.int64) for i in range(num_users)}

    shard_imgs = 25000 // (100 * 2)  # 125 images per (client, digit) slice

    def _deal(client_ids, digits):
        """Give each client in `client_ids` one slice of two of `digits`."""
        # Concatenation keeps the digits in blocks of 5000, i.e. the
        # combined array is already sorted by label.
        group_idx = np.concatenate([refer[d] for d in digits], axis=0)
        starts = [s for s in range(0, 25000, 125)]
        # 40 slice offsets per digit (5000 / 125), consumed from the end.
        free = dict(zip(digits,
                        [starts[j:j+40] for j in range(0, len(starts), 40)]))
        tasks = cycle(combinations(digits, 2))

        for client in client_ids:
            starta, startb = -1, -1
            while starta == -1 or startb == -1:
                a, b = next(tasks)
                if a in free and b in free:
                    starta = free[a].pop()
                    startb = free[b].pop()
                    if len(free[a]) == 0:
                        free.pop(a)
                    if len(free[b]) == 0:
                        free.pop(b)
                # Stop once *this* group has no unused slices left.
                if len(free) == 0:
                    break
            assert starta != -1
            assert startb != -1

            aindx = group_idx[starta:starta+shard_imgs]
            bindx = group_idx[startb:startb+shard_imgs]
            dict_users[client] = np.concatenate([aindx, bindx], axis=0)

    # clients 0 - 99 contain only odd digits
    _deal(range(100), [1, 3, 5, 7, 9])
    # clients 100 - (num_users - 1) contain only even digits
    _deal(range(100, num_users), [0, 2, 4, 6, 8])

    return dict_users



def fmnist_forget(dataset, num_users, num_classes):
    """
    Sample an unbalanced split with a 90-client majority group and a
    20-client minority group

    Clients 0-89 each receive 150 images of two different "clothes"
    classes (labels 0, 1, 2, 3, 4, 8); clients 90-109 each receive 150
    images of two different "shoes" classes (labels 5, 7, 6, 9). Class
    pairs are taken round-robin from all 2-combinations of the group's
    labels, every 150-image slice is used at most once, and each client's
    300 indices are shuffled in place.
    :param dataset: dataset supporting ``len()`` and exposing ``targets``
        (tensor, array or plain list of labels)
    :param num_users: number of entries pre-created in the result. The
        group sizes are fixed, so keys 0-109 are always filled; any key
        from 110 upwards is left as an empty array
    :param num_classes: unused, kept for signature compatibility
    :return: dict mapping client id -> 1-D integer array of image indices
    """
    # only the first 4500 examples of every class are used
    # shoes: [5,7,9,6]
    # clothes: [0,1,2,3,4,8(bag)]
    shoes_labels = [5,7,6,9]
    cloth_labels = [0,1,2,3,4,8]

    idxs = np.arange(len(dataset))
    # Normalise the labels to an ndarray: comparing a plain list with an
    # int yields a single False instead of an element-wise mask.
    targets = dataset.targets
    targets = targets.numpy() if hasattr(targets, 'numpy') else np.asarray(targets)
    refer = {i: idxs[targets == i][:5000] for i in range(10)}
    dict_users = {i: np.array([], dtype=np.int64) for i in range(num_users)}

    shard_imgs = 30000 // (100 * 2)  # 150 images per (client, class) slice

    def _deal(client_ids, class_labels):
        """Give each client in `client_ids` one slice of two classes."""
        # Blocks of 4500 per class, so the array is sorted by label.
        group_idx = np.concatenate([refer[c][:4500] for c in class_labels],
                                   axis=0)
        starts = [s for s in range(0, 4500 * len(class_labels), 150)]
        # 30 slice offsets per class (4500 / 150), consumed from the end.
        free = dict(zip(class_labels,
                        [starts[j:j+30] for j in range(0, len(starts), 30)]))
        tasks = cycle(combinations(class_labels, 2))

        for client in client_ids:
            starta, startb = -1, -1
            while starta == -1 or startb == -1:
                a, b = next(tasks)
                if a in free and b in free:
                    starta = free[a].pop()
                    startb = free[b].pop()
                    if len(free[a]) == 0:
                        free.pop(a)
                    if len(free[b]) == 0:
                        free.pop(b)
                if len(free) == 0:
                    break
            assert starta != -1
            assert startb != -1

            aindx = group_idx[starta:starta+shard_imgs]
            bindx = group_idx[startb:startb+shard_imgs]
            current_batch = np.concatenate([aindx, bindx], axis=0)
            np.random.shuffle(current_batch)
            dict_users[client] = current_batch

    # clients 0 - 89 contain only clothes classes
    _deal(range(90), cloth_labels)
    # clients 90 - 109 contain only shoes classes
    _deal(range(90, 110), shoes_labels)

    return dict_users




def cifar_forget(dataset, num_users, num_classes):
    """
    Sample an unbalanced split with a 90-client majority group and a
    20-client minority group

    Clients 0-89 each get 150 images of two different "animal" classes
    (labels 2-7); clients 90-109 each get 150 images of two different
    "vehicle" classes (labels 0, 1, 8, 9). Pairs are taken round-robin
    from all 2-combinations of the group's labels, and each client's 300
    indices are shuffled in place.
    :param dataset: dataset supporting ``len()`` and exposing ``targets``
    :param num_users: number of entries pre-created in the result; keys
        0-109 are always filled regardless of this value
    :param num_classes: unused
    :return: dict mapping client id -> array of image indices
    """
    # only the first 4500 examples of every class are used
    # vehicle: [0,1,8,9]
    # animals: [2,3,4,5,6,7]
    vehi_labels = [0,1,8,9]
    anim_labels = [2,3,4,5,6,7]

    all_positions = np.arange(len(dataset))
    label_array = np.array(dataset.targets)
    refer = {cls: all_positions[label_array==cls][:5000] for cls in range(10)}
    dict_users = {uid: np.array([]) for uid in range(num_users)}

    slice_len = 30000 // (100 * 2)

    # Majority group (animals) first, then the minority group (vehicles).
    plan = (
        (range(90), anim_labels),
        (range(90, 110), vehi_labels),
    )
    for client_ids, class_ids in plan:
        # Blocks of 4500 per class, so the array is sorted by label.
        group_idx = np.concatenate([refer[cls][:4500] for cls in class_ids],axis=0)
        offsets = list(range(0, 4500 * len(class_ids), 150))
        # 30 slice offsets per class (4500 / 150), consumed from the end.
        chunks = [offsets[k:k+30] for k in range(0, len(offsets), 30)]
        unused = dict(zip(class_ids, chunks))
        pair_iter = cycle(combinations(class_ids, 2))

        for client in client_ids:
            first, second = -1, -1
            while True:
                a, b = next(pair_iter)
                if a in unused and b in unused:
                    first = unused[a].pop()
                    second = unused[b].pop()
                    if not unused[a]:
                        del unused[a]
                    if not unused[b]:
                        del unused[b]
                # Leave once the pool is exhausted or a pair was found.
                if not unused:
                    break
                if first != -1 and second != -1:
                    break
            assert first != -1
            assert second != -1

            batch = np.concatenate(
                [group_idx[first:first+slice_len],
                 group_idx[second:second+slice_len]], axis=0)
            np.random.shuffle(batch)
            dict_users[client] = batch

    return dict_users








def nlp_iid(dataset, num_users):
    """
    Sample I.I.D. client data from SST dataset

    Each client draws ``int(len(dataset) / num_users)`` sample indices
    without replacement from whatever previous clients left over.
    :param dataset: any object supporting ``len()``
    :param num_users: number of clients to create
    :return: dict mapping client id -> set of sample indices
    """
    n_samples = len(dataset)
    quota = int(n_samples/num_users)
    unassigned = list(range(n_samples))
    dict_users = {}
    for uid in range(num_users):
        draw = np.random.choice(unassigned, quota, replace=False)
        dict_users[uid] = set(draw)
        # Keep only the indices nobody has received yet.
        unassigned = list(set(unassigned) - dict_users[uid])
    return dict_users

def nlp_noniid_new(dataset, num_users,Lam_high,Num_chunk):
    """
    Sample non-I.I.D client data from a two-class text dataset

    Sample indices are ordered by label and cut into
    ``num_users * Num_chunk`` shards. Clients are served in order; client
    ``i`` draws ``Num_chunk`` of the remaining shards with a probability
    vector that puts extra weight on half of the remaining pool, the half
    being chosen by ``i % 2``.
    :param dataset: sequence whose items carry the label at position 2 as
        an object with ``.numpy()``; only element 0 of that array is used
    :param num_users: number of clients
    :param Lam_high: skew parameter; the boosted half of the remaining
        shards ends up with total probability ``Lam_high``. Must be below
        1 (division by ``1 - Lam_high``)
    :param Num_chunk: number of shards handed to each client
    :return: dict mapping client id -> 1-D integer array of sample indices
    """
    num_class = 2

    def cal_pmatrix(Num_leave,Num_locate,Lam=Lam_high,Num_class=num_class):
        """Build the shard-selection probabilities for one client.

        Starts from a random distribution over the `Num_leave` remaining
        shards, adds extra mass to the `Num_locate`-th block of
        ``Num_leave / Num_class`` shards and rescales so the vector sums
        to 1 again.
        """
        block = int(Num_leave/Num_class)
        lo = int(Num_locate*(Num_leave/Num_class))
        P_matrix=np.random.rand(int(Num_leave))
        P_matrix=P_matrix/np.sum(P_matrix)
        P_add=np.random.rand(block)
        # Extra mass for the boosted block, expressed before rescaling.
        P_add=P_add/sum(P_add)*(0.5*(Lam/(1-Lam))-0.5)
        P_matrix[lo:lo+block]=P_matrix[lo:lo+block]+P_add
        # Total mass is now 0.5/(1-Lam); divide it out.
        P_matrix=P_matrix/(0.5*(1/(1-Lam)))
        return P_matrix

    num_chunk=Num_chunk
    # NOTE(review): 6920 is hard-coded as the sample count used for the
    # shard size, independent of len(dataset) - confirm against the caller.
    num_shards, num_imgs = int(num_users*num_chunk), int(6920/(num_users*num_chunk))
    idx_shard = [i for i in range(num_shards)]
    # Integer dtype up front: concatenating onto a default (float64) empty
    # array would silently turn every index into a float.
    dict_users = {i: np.array([], dtype=np.int64) for i in range(num_users)}
    idxs = np.arange(len(dataset))
    labels = np.array([x[2].numpy() for x in dataset])

    # sort labels
    idxs_labels = np.vstack((idxs, labels[:,0]))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # divide and assign
    for i in range(num_users):
        num_locate=i%num_class
        num_leave=len(idx_shard)
        p_matrix=cal_pmatrix(Num_leave=num_leave,Num_locate=num_locate)
        rand_set = set(np.random.choice(idx_shard, num_chunk, p=p_matrix,replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
    return dict_users




if __name__ == '__main__':
    # Smoke run: download CIFAR10 and split it across 100 clients with
    # 4 shards each.
    # NOTE(review): the normalisation constants are single-channel values
    # while CIFAR10 images have three channels - confirm this is intended.
    # The transform is not used by the partitioning below.
    normalize = transforms.Normalize((0.1307,),
                                     (0.3081,))
    pipeline = transforms.Compose([transforms.ToTensor(), normalize])
    dataset_train = datasets.CIFAR10('./data/cifar/', train=True,
                                     download=True, transform=pipeline)
    num_clients = 100
    num_classes = 4
    users = cifar_noniid_degree4(dataset_train, num_clients, num_classes)

