import copy

import numpy as np
from torchvision import datasets, transforms
import random

def dirichlet_noniid(train_dataset, test_dataset, num_users, alpha, seed, args):
    """Split train/test sample indices across clients with Dirichlet label skew.

    One class-proportion vector per client is drawn from a symmetric
    Dirichlet distribution with concentration ``alpha``.  The same vector
    drives both the training draw and the testing draw of that client.

    Args:
        train_dataset: Object exposing ``targets``.  Labels are used as list
            positions, so they must be valid indices into a list whose length
            is the number of distinct training labels.
        test_dataset: Object exposing ``targets``; bucketed with the class
            count derived from the training labels.
        num_users: Number of clients to generate index sets for.
        alpha: Dirichlet concentration value, repeated once per class.
        seed: Seed for the ``np.random.RandomState`` used for every draw.
        args: Object providing ``train_num`` and ``test_num`` (samples per
            client) and ``data_replace`` (``1`` means sample with
            replacement).

    Returns:
        A ``(train_dict_users, test_dict_users)`` tuple.  Each dict maps a
        client id in ``range(num_users)`` to a NumPy array of sample indices.
    """
    train_labels = np.array(train_dataset.targets)
    test_labels = np.array(test_dataset.targets)
    class_count = len(np.unique(train_dataset.targets))

    def bucket_by_label(labels):
        # Sample positions grouped per class, in ascending position order.
        buckets = [[] for _ in range(class_count)]
        for position, label in enumerate(labels):
            buckets[label].append(position)
        return buckets

    train_quota = args.train_num
    test_quota = args.test_num

    train_buckets = bucket_by_label(train_labels)
    test_buckets = bucket_by_label(test_labels)

    # A single seeded generator is used for the proportions and all draws.
    rng = np.random.RandomState(seed)
    proportions = rng.dirichlet(np.repeat(alpha, class_count), num_users)

    with_replacement = args.data_replace == 1
    final_class = class_count - 1

    train_dict_users = {}
    test_dict_users = {}
    for client in range(num_users):
        # Budget still to be handed out so each client ends with its quota.
        train_left = train_quota
        test_left = test_quota
        train_picks = []
        test_picks = []

        for cls, share in enumerate(proportions[client]):
            if cls < final_class:
                # Round the proportional share to the nearest integer.
                want_train = int(train_quota * share + 0.5)
                want_test = int(test_quota * share + 0.5)
            else:
                # The last class absorbs whatever budget remains.
                want_train = train_left
                want_test = test_left

            # Never exceed the remaining budget and never go negative.
            want_train = max(min(want_train, train_left), 0)
            want_test = max(min(want_test, test_left), 0)
            train_left -= want_train
            test_left -= want_test
            assert want_train >= 0 and want_test >= 0

            # Both draws are always issued, in this order, even for a zero
            # count, so the generator is advanced the same way every time.
            train_draw = rng.choice(train_buckets[cls], want_train, replace=with_replacement)
            train_picks.extend(train_draw.tolist())
            test_draw = rng.choice(test_buckets[cls], want_test, replace=with_replacement)
            test_picks.extend(test_draw.tolist())

        train_dict_users[client] = np.array(train_picks)
        test_dict_users[client] = np.array(test_picks)

    return train_dict_users, test_dict_users

def pathological_noniid(train_dataset, test_dataset, num_users, alpha, seed, args, random=True):
    """Split train/test sample indices so each client only sees ``alpha`` classes.

    Every client is assigned ``alpha`` classes and receives
    ``int(args.train_num / alpha)`` training indices and
    ``int(args.test_num / alpha)`` testing indices from each of them.  When
    the per-client quota is not divisible by ``alpha``, the missing samples
    are drawn from the client's first assigned class, so a client never
    receives a class outside its own assignment.

    Args:
        train_dataset: Object exposing ``targets``.  Labels are used as list
            positions, so they must be valid indices into a list whose length
            is the number of distinct training labels.
        test_dataset: Object exposing ``targets``; bucketed with the class
            count derived from the training labels.
        num_users: Number of clients to generate index sets for.
        alpha: Number of classes per client (a positive integer).
        seed: Seed for the ``np.random.RandomState`` used for every draw.
        args: Object providing ``train_num`` and ``test_num`` (samples per
            client) and ``data_replace`` (``1`` means sample with
            replacement).
        random: If true, each client's classes are drawn at random without
            replacement; otherwise classes are handed out round-robin
            (client 0 gets classes ``0..alpha-1``, client 1 the next
            ``alpha`` classes, wrapping around).  NOTE: this parameter name
            shadows the stdlib ``random`` module inside the function.

    Returns:
        A ``(train_dict_users, test_dict_users)`` tuple.  Each dict maps a
        client id in ``range(num_users)`` to a NumPy array of sample indices.

    Raises:
        ValueError: If ``alpha`` is smaller than 1.
    """
    if alpha < 1:
        raise ValueError(f"alpha must be a positive number of classes per client, got {alpha!r}")

    train_labels = np.array(train_dataset.targets)
    test_labels = np.array(test_dataset.targets)
    num_class = len(np.unique(train_dataset.targets))

    # Group sample positions by label; each list keeps ascending order.
    train_idxs_classes = [[] for _ in range(num_class)]
    test_idxs_classes = [[] for _ in range(num_class)]
    for idx, label in enumerate(train_labels):
        train_idxs_classes[label].append(idx)
    for idx, label in enumerate(test_labels):
        test_idxs_classes[label].append(idx)

    num_client_train_sample = args.train_num
    num_client_test_sample = args.test_num
    # Per-class share is loop-invariant, so compute it once.
    per_class_train = int(num_client_train_sample / alpha)
    per_class_test = int(num_client_test_sample / alpha)

    random_state = np.random.RandomState(seed)
    data_replace = (args.data_replace == 1)

    train_dict_users = {}
    test_dict_users = {}

    if random:
        print('Random split class!')
    class_idxs = list(range(num_class))
    next_class = 0  # round-robin cursor, only used when ``random`` is false

    for i in range(num_users):
        if random:
            client_classes = random_state.choice(class_idxs, alpha, replace=False)
        else:
            client_classes = [(next_class + k) % num_class for k in range(alpha)]
            next_class = (next_class + alpha) % num_class

        train_picked = []
        test_picked = []
        for j in client_classes:
            train_selected = random_state.choice(train_idxs_classes[j], per_class_train,
                                                 replace=data_replace)
            train_picked += train_selected.tolist()
            test_selected = random_state.choice(test_idxs_classes[j], per_class_test,
                                                replace=data_replace)
            test_picked += test_selected.tolist()

        # Top up to the exact quota from one of the client's OWN classes.
        # Using the round-robin cursor here would leak a sample of the next
        # client's class into this client.
        # NOTE: the top-up is an independent draw, so it may repeat indices
        # already picked from that class even when data_replace is off.
        first_class = client_classes[0]
        last_train_num = num_client_train_sample - len(train_picked)
        last_test_num = num_client_test_sample - len(test_picked)
        if last_train_num > 0:
            train_selected = random_state.choice(train_idxs_classes[first_class], last_train_num,
                                                 replace=data_replace)
            train_picked += train_selected.tolist()
        if last_test_num > 0:
            test_selected = random_state.choice(test_idxs_classes[first_class], last_test_num,
                                                replace=data_replace)
            test_picked += test_selected.tolist()

        train_dict_users[i] = np.array(train_picked)
        test_dict_users[i] = np.array(test_picked)

    return train_dict_users, test_dict_users

