import numpy as np

def dirichlet_noniid(train_dataset, test_dataset, num_users, alpha, seed, args):
    """Partition train/test sets across clients with Dirichlet label skew.

    For each client a class-proportion vector is drawn from
    ``Dirichlet(alpha * ones(num_class))``. The client then receives
    ``args.train_num`` training indices and ``args.test_num`` test indices,
    split across classes by that vector. Each per-class count is rounded
    half-up and clipped to what is still unassigned. The last class absorbs
    the remainder, so a non-negative total is met exactly.

    Args:
        train_dataset: object exposing a ``targets`` sequence of class labels.
        test_dataset: object exposing a ``targets`` sequence of class labels.
        num_users: number of clients to create.
        alpha: Dirichlet concentration parameter (smaller -> more label skew).
        seed: seed of the ``np.random.RandomState`` that drives all sampling.
        args: namespace providing ``train_num`` (train samples per client),
            ``test_num`` (test samples per client) and ``data_replace``
            (``1`` to sample indices with replacement).

    Returns:
        ``(train_dict_users, test_dict_users)``: dicts mapping each client id
        ``0..num_users-1`` to a 1-D integer ``np.ndarray`` of indices into the
        corresponding dataset.

    Raises:
        ValueError: propagated from ``RandomState.choice`` when a class pool is
            empty or (without replacement) too small for the requested count.
    """
    train_labels = np.array(train_dataset.targets)
    test_labels = np.array(test_dataset.targets)

    # Classes are the distinct labels of the training set. Class position j
    # refers to classes[j] rather than using the raw label as a list index.
    # For labels 0..K-1 this is identical, and non-contiguous or negative
    # labels are no longer mis-bucketed or rejected.
    classes = np.unique(train_labels)
    num_class = len(classes)

    # Sorted index pools per class. Test samples whose label never occurs in
    # the training set belong to no pool and are never sampled.
    train_idxs_classes = [np.flatnonzero(train_labels == c) for c in classes]
    test_idxs_classes = [np.flatnonzero(test_labels == c) for c in classes]

    num_client_train_sample = args.train_num
    num_client_test_sample = args.test_num
    data_replace = (args.data_replace == 1)

    # One proportion vector per client. Every later draw uses this same
    # stream, so the partition is reproducible for a given seed.
    random_state = np.random.RandomState(seed)
    q = random_state.dirichlet(np.repeat(alpha, num_class), num_users)

    train_dict_users = {}
    test_dict_users = {}
    for i in range(num_users):
        remaining_train = num_client_train_sample
        remaining_test = num_client_test_sample
        client_train = []
        client_test = []
        for j in range(num_class):
            if j < num_class - 1:
                # Round half-up, but never hand out more than is left.
                num_train = min(int(num_client_train_sample * q[i, j] + 0.5), remaining_train)
                num_test = min(int(num_client_test_sample * q[i, j] + 0.5), remaining_test)
            else:
                # Last class takes whatever remains so the totals are exact.
                num_train = remaining_train
                num_test = remaining_test
            # Clamp at zero (only relevant for negative train_num/test_num).
            num_train = max(num_train, 0)
            num_test = max(num_test, 0)
            remaining_train -= num_train
            remaining_test -= num_test

            # Draw order (train, then test, class by class) fixes the RNG
            # stream and must stay unchanged for reproducibility.
            client_train += random_state.choice(
                train_idxs_classes[j], num_train, replace=data_replace).tolist()
            client_test += random_state.choice(
                test_idxs_classes[j], num_test, replace=data_replace).tolist()

        # dtype=int keeps empty clients usable as index arrays
        # (np.array([]) would be float64).
        train_dict_users[i] = np.array(client_train, dtype=int)
        test_dict_users[i] = np.array(client_test, dtype=int)
    return train_dict_users, test_dict_users
