import numpy as np
import argparse
import torch
from torchvision import datasets, transforms
import pickle as pkl
import os, shutil
# import utils

def Dirichlet_sampler(args, train_dataset, cur_classes, use_IID=False, use_balance=True):
    """Partition the samples of ``cur_classes`` across clients with a Dirichlet prior.

    For every class a client-share vector is drawn from
    ``Dirichlet(alpha * ones(args.num_clients))`` and that class's samples are
    split across clients in those proportions. With ``use_balance`` the
    (client x class) allocation is further rescaled by alternating Sinkhorn
    iterations so that every client receives (approximately) the same number
    of samples while each class still hands out exactly its own samples.

    Args:
        args: Namespace-like object; reads ``args.num_clients`` and
            ``args.dirichlet_alpha`` (the latter is ignored when ``use_IID``).
            ``args`` is not modified.
        train_dataset: Object whose ``get_data_label(cur_classes)`` returns
            ``(data, targets, ids)``. ``ids`` must support integer-array
            indexing (e.g. a NumPy array); presumably it is aligned
            element-wise with ``targets`` -- TODO confirm against the dataset.
        cur_classes: Sequence of class labels to partition.
        use_IID: If True, use a very large concentration (1e5), which yields
            a near-uniform split of every class.
        use_balance: If True, balance the number of samples per client.

    Returns:
        dict mapping each client index ``0 .. args.num_clients - 1`` to
        ``train_ids[...]`` for that client, grouped class by class in
        ``cur_classes`` order. Client sets are disjoint; flooring may leave
        fewer than one sample per (client, class) pair unassigned.
    """
    _, train_targets, train_ids = train_dataset.get_data_label(cur_classes)
    n_classes = len(cur_classes)
    n_clients = args.num_clients

    ##### Positions (within this subset) of the samples of each class
    all_targets = np.array(train_targets)
    class_ids_train = {
        class_num: np.where(all_targets == cur_classes[class_num])[0]
        for class_num in range(n_classes)
    }
    class_sizes = np.array(
        [len(class_ids_train[c]) for c in range(n_classes)], dtype=np.float64
    )

    ##### Per-class client shares
    # Local alpha: writing it back into ``args`` would leak the IID setting
    # into every later (possibly non-IID) call that shares ``args``.
    alpha = 1e5 if use_IID else args.dirichlet_alpha
    # (n_clients, n_classes): column j is the client-share vector of class j.
    dist_of_client = np.random.dirichlet(
        np.repeat(alpha, n_clients), size=n_classes
    ).transpose()
    col_sums = dist_of_client.sum(axis=0, keepdims=True)
    shares = np.divide(
        dist_of_client, col_sums,
        out=np.zeros_like(dist_of_client), where=col_sums > 0,
    )

    # Expected sample count per (client, class). Every column sums to that
    # class's real size, so unequal class sizes are respected.
    alloc = shares * class_sizes

    #### Balance per-client totals with alternating Sinkhorn scaling
    if use_balance:
        per_client = class_sizes.sum() / n_clients
        for _ in range(100):
            # Row sums must be recomputed after each scaling step (true
            # Sinkhorn). Scaling both from the same matrix oscillates in scale.
            row_sums = alloc.sum(axis=1, keepdims=True)
            alloc *= np.divide(
                per_client, row_sums,
                out=np.zeros_like(row_sums), where=row_sums > 0,
            )
            # Finish on the column step so each class hands out exactly its
            # own samples (never more, never fewer than flooring allows).
            col_tot = alloc.sum(axis=0)
            alloc *= np.divide(
                class_sizes, col_tot,
                out=np.zeros_like(col_tot), where=col_tot > 0,
            )

    ##### Slice boundaries: bounds[i, j] is client i's first position in class j
    counts = np.floor(alloc).astype(np.int64)
    bounds = np.zeros((n_clients + 1, n_classes), dtype=np.int64)
    bounds[1:] = np.cumsum(counts, axis=0)

    ##### Gather the ids for each client
    empty = np.empty(0, dtype=np.intp)  # keeps concatenate valid when n_classes == 0
    client_ids = {}
    for client_num in range(n_clients):
        picked = [empty] + [
            class_ids_train[c][bounds[client_num, c]:bounds[client_num + 1, c]]
            for c in range(n_classes)
        ]
        client_ids[client_num] = train_ids[np.concatenate(picked)]

    return client_ids

