#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
# Pytorch version: 1.1.0 or above


## Builds per-client partitions of a training dataset. Each function returns two dicts keyed by the client index:
## the sample indices assigned to that client, and the number of samples assigned to that client.
import numpy as np
#from torchvision import datasets, transforms


def U_data_random(args, dataset):
    """
    Randomly hand out dataset indices to the U-group clients.

    Every client gets ``int(len(dataset) / args.U_num_users)`` indices drawn
    with replacement from the indices that no earlier client has drawn, so a
    client may hold repeated indices but never shares one with another client.

    :param args: namespace providing ``U_num_users``
    :param dataset: sized collection; only ``len(dataset)`` is used
    :return: (dict: client index -> array of sample indices,
              dict: client index -> 0-d array holding the sample count)
    """
    client_count = args.U_num_users
    per_client = int(len(dataset) / client_count)
    pool = list(range(len(dataset)))

    dict_users = {}
    dict_users_data_size = {}
    for client in range(client_count):
        drawn = np.random.choice(pool, per_client, replace=True)
        dict_users[client] = drawn
        dict_users_data_size[client] = np.array(per_client)
        # Remove everything this client drew so later clients cannot reuse it.
        pool = list(set(pool) - set(drawn))
    return dict_users, dict_users_data_size


def A_data_random(args, dataset):
    """
    Randomly hand out dataset indices to the A-group clients.

    Every client gets ``int(len(dataset) / args.A_num_users)`` indices drawn
    with replacement from the indices that no earlier client has drawn, so a
    client may hold repeated indices but never shares one with another client.

    :param args: namespace providing ``A_num_users``
    :param dataset: sized collection; only ``len(dataset)`` is used
    :return: (dict: client index -> array of sample indices,
              dict: client index -> 0-d array holding the sample count)
    """
    total = len(dataset)
    n_clients = args.A_num_users
    share = int(total / n_clients)

    dict_users, dict_users_data_size = {}, {}
    available = list(range(total))
    for uid in range(n_clients):
        selection = np.random.choice(available, share, replace=True)
        dict_users[uid] = selection
        dict_users_data_size[uid] = np.array(share)
        # Indices already given to this client are withheld from later clients.
        available = list(set(available) - set(selection))
    return dict_users, dict_users_data_size


def U_data_if_iid_equal(args, dataset):
    """
    Build label-skewed index shards for the U-group clients.

    Client ``i`` uses class ``i % num_classes`` as its major class. Roughly
    ``U_iid_ratio`` of the client's sample budget is drawn from the major class
    and the remainder is split evenly over the other classes. All draws are
    made with replacement.

    The base budget is ``round(data_init * rou)``. With ``unequal != 0`` each
    client's budget is shifted by a uniform offset in
    ``[-scale_range, scale_range)``. With ``over_sampling == 1`` each minor
    class draw is repeated ``round(major / minor)`` times (a minor count of 0
    is replaced by 100 first).

    :param args: namespace providing ``U_num_users``, ``num_classes``,
        ``U_iid_ratio``, ``data``, ``train_size``, ``data_init``, ``rou``,
        ``unequal``, ``over_sampling`` and (when ``unequal != 0``) ``scale_range``
    :param dataset: object exposing ``targets``; a ``.numpy()``-capable tensor
        when ``args.data == 'MNIST'``, otherwise anything ``np.array`` accepts
    :return: (dict: client index -> array of dataset indices,
              dict: client index -> 0-d array with that client's sample count)
    """
    num_users = args.U_num_users
    num_classes = args.num_classes
    ratio = args.U_iid_ratio  # ratio=0.01,0.1,0.55
    minor_ratio = round(1 - ratio, 6) / (num_classes - 1)

    if args.data == 'MNIST':
        labels = dataset.targets.numpy()  # For MNIST
    else:
        labels = np.array(dataset.targets)  # For CIFAR10 or 100

    # Sort dataset indices by label so that each class occupies one contiguous
    # run of positions in ``idxs``.
    idxs = np.arange(args.train_size)
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # Derive the per-class position ranges from the actual label counts rather
    # than hard-coding them; this yields the same ranges for full MNIST and
    # CIFAR-10 and also works for CIFAR-100 or subsets.
    class_counts = np.bincount(np.asarray(labels).astype(int), minlength=num_classes)
    bounds = np.concatenate(([0], np.cumsum(class_counts)))
    idxs_of_label = [np.arange(bounds[k], bounds[k + 1]) for k in range(num_classes)]

    base_budget = round(args.data_init * args.rou)
    oversample = args.over_sampling == 1
    dict_users, dict_users_data_size = {}, {}
    for user in range(num_users):
        if args.unequal == 0:
            # Equal splits: every client gets the same budget.
            budget = base_budget
        else:
            # Unequal splits: uniform dist over [low, high) scaled by scale_range.
            budget = base_budget + np.random.uniform(low=-1.0, high=1.0) * args.scale_range
        num_major = int(np.round(budget * ratio))
        num_minor = int(np.round(budget * minor_ratio))
        if oversample and num_minor == 0:
            num_minor = 100
        # num_minor is never 0 here when over-sampling, so the division is safe.
        repeats = int(round(num_major / num_minor)) if oversample else 1

        major_class = user % num_classes
        shard = []
        for k in range(num_classes):
            if k == major_class:
                shard.extend(np.random.choice(idxs_of_label[k], num_major, replace=True))
            else:
                minor_sampling = np.random.choice(idxs_of_label[k], num_minor, replace=True)
                # The repeat counter must not reuse the client index name,
                # otherwise the shard is stored under the wrong client.
                for _ in range(repeats):
                    shard.extend(minor_sampling)
        dict_users[user] = idxs[shard]
        # Report the number of indices actually assigned to this client.
        dict_users_data_size[user] = np.array(len(shard))
    return dict_users, dict_users_data_size


def A_data_if_iid_equal(args, dataset):
    """
    Build label-skewed index shards for the A-group clients.

    Client ``i`` uses class ``i % num_classes`` as its major class. Roughly
    ``A_iid_ratio`` of the client's sample budget is drawn from the major class
    and the remainder is split evenly over the other classes. All draws are
    made with replacement.

    The base budget is ``data_init - round(data_init * rou)``. With
    ``unequal != 0`` each client's budget is shifted by a uniform offset in
    ``[-scale_range, scale_range)``.

    :param args: namespace providing ``A_num_users``, ``num_classes``,
        ``A_iid_ratio``, ``data``, ``train_size``, ``data_init``, ``rou``,
        ``unequal`` and (when ``unequal != 0``) ``scale_range``
    :param dataset: object exposing ``targets``; a ``.numpy()``-capable tensor
        when ``args.data == 'MNIST'``, otherwise anything ``np.array`` accepts
    :return: (dict: client index -> array of dataset indices,
              dict: client index -> 0-d array with that client's sample count)
    """
    num_users = args.A_num_users
    num_classes = args.num_classes
    ratio = args.A_iid_ratio  # ratio=0.01,0.1,0.55
    minor_ratio = round(1 - ratio, 6) / (num_classes - 1)

    if args.data == 'MNIST':
        labels = dataset.targets.numpy()  # For MNIST
    else:
        labels = np.array(dataset.targets)  # For CIFAR10 or 100

    # Sort dataset indices by label so that each class occupies one contiguous
    # run of positions in ``idxs``.
    idxs = np.arange(args.train_size)
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # Derive the per-class position ranges from the actual label counts rather
    # than hard-coding them; this yields the same ranges for full MNIST and
    # CIFAR-10 and also works for CIFAR-100 or subsets.
    class_counts = np.bincount(np.asarray(labels).astype(int), minlength=num_classes)
    bounds = np.concatenate(([0], np.cumsum(class_counts)))
    idxs_of_label = [np.arange(bounds[k], bounds[k + 1]) for k in range(num_classes)]

    base_budget = args.data_init - round(args.data_init * args.rou)
    dict_users, dict_users_data_size = {}, {}
    for user in range(num_users):
        if args.unequal == 0:
            # Equal splits: every client gets the same budget.
            budget = base_budget
        else:
            # Unequal splits: uniform dist over [low, high) scaled by scale_range.
            budget = base_budget + np.random.uniform(low=-1.0, high=1.0) * args.scale_range
        num_major = int(np.round(budget * ratio))
        num_minor = int(np.round(budget * minor_ratio))

        major_class = user % num_classes
        shard = []
        for k in range(num_classes):
            take = num_major if k == major_class else num_minor
            shard.extend(np.random.choice(idxs_of_label[k], take, replace=True))
        dict_users[user] = idxs[shard]
        # Equals num_major + num_minor * (num_classes - 1).
        dict_users_data_size[user] = np.array(len(shard))
    return dict_users, dict_users_data_size


"""
#################################################################### MAIN ##########################################################################
if __name__ == '__main__':
    data = 'cifar10'
    if data == 'MNIST':
        dataset_train = datasets.MNIST(root='./data/mnist/', train=True, download=True, transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307,), (0.3081,))
                                       ]))
    else:
        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        dataset_train = datasets.CIFAR10(root='./data/mnist/', train=True, download=True, transform=transforms.Compose([
                                            transforms.ToTensor(),
                                            normalize,
                                        ]))
    dict_users = data_if_iid_equal(dataset_train) 
    #dict_users = data_random(dataset_train)
    
    #labels = dataset_train.targets.numpy() # For MNIST
    labels = np.array(dataset_train.targets) # For CIFAR10 or 100
    #print(labels[dict_users[0].astype(int)])
    for i in range(0, 100, 7):
        data_user = dict_users[i].astype(int)
        labels_i = labels[data_user]
        major_class = i % 10
        for j in range(10):
            print("cifar10 of User ", i, " for label ", j, ": ", len(labels_i[labels_i==j]), "total: ", len(labels_i), "major_class: ", major_class)
"""    