import numpy as np
import pickle
# Seed NumPy's global RNG when this module is imported so that the np.random
# calls below (shuffle / dirichlet) produce a repeatable sequence.
# NOTE: this mutates process-wide np.random state for every importer.
np.random.seed(42)

def split_data_dirichlet(labels, n_clients, alpha=1.0, min_size=10):
    """Partition sample indices across clients with Dirichlet-distributed class shares.

    Every client first receives ``min_size`` indices drawn uniformly at random
    (without replacement). The remaining indices are then handled one class at
    a time: a proportion vector is sampled from ``Dirichlet(alpha, ..., alpha)``
    and that class's indices are cut into ``n_clients`` contiguous chunks
    according to those proportions.

    Randomness comes from NumPy's global ``np.random`` state, so seed it
    beforehand for reproducible splits.

    Parameters
    ----------
    labels : array-like
        One label per sample. Any label values accepted by ``np.unique`` work;
        they do not need to be the integers ``0..K-1``.
    n_clients : int
        Number of partitions to produce.
    alpha : float, default 1.0
        Concentration value repeated ``n_clients`` times and passed to
        ``np.random.dirichlet``.
    min_size : int, default 10
        Number of samples handed to each client before the Dirichlet split.
        If fewer than ``n_clients * min_size`` samples exist, later clients
        simply receive fewer (possibly zero) preallocated samples.

    Returns
    -------
    list of list
        ``n_clients`` lists of sample indices into ``labels``. Together they
        contain every index exactly once.
    """
    # Accept plain sequences as well as arrays; fancy indexing below needs an ndarray.
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(n_clients)]

    # Shuffle once so the guaranteed per-client samples are a random draw.
    pool = np.arange(len(labels))
    np.random.shuffle(pool)

    # Hand out the guaranteed minimum, consuming the front of the shuffled pool.
    for bucket in client_indices:
        bucket.extend(pool[:min_size])
        pool = pool[min_size:]

    pool_labels = labels[pool]

    # Iterate over the label values actually present rather than range(K):
    # otherwise samples whose label is not in 0..K-1 would never be assigned.
    # np.unique returns sorted values, so for 0..K-1 labels the order (and
    # therefore the sequence of random draws) is unchanged.
    for label in np.unique(labels):
        members = pool[pool_labels == label]
        if len(members) == 0:
            # Every sample of this class was consumed by the preallocation.
            continue
        np.random.shuffle(members)
        shares = np.random.dirichlet(np.repeat(alpha, n_clients))
        # Cumulative shares -> interior cut points; the last one (== len) is dropped
        # so np.split yields exactly n_clients chunks.
        cut_points = (np.cumsum(shares) * len(members)).astype(int)[:-1]
        for bucket, chunk in zip(client_indices, np.split(members, cut_points)):
            bucket.extend(chunk)

    return client_indices

def get_data_dict(path):
    """Load and return the object stored in the pickle file at ``path``.

    The file is opened in binary mode and deserialized with ``pickle.load``;
    whatever object was pickled is returned unchanged.

    NOTE: unpickling can execute arbitrary code, so only call this on files
    from a trusted source.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)