# partition.py
import numpy as np


def create_noniid_indices(dataset, num_clients, alpha=0.5, *, rng=None):
    """
    Generate non-IID indices for clients using a Dirichlet distribution.

    For every distinct label, the indices of that label's samples are shuffled
    and split among the clients in proportions drawn from
    ``Dirichlet(alpha, ..., alpha)``. Every sample index is assigned to exactly
    one client.

    :param dataset: Iterable of ``(sample, label)`` pairs (e.g. a PyTorch dataset).
        Labels must be hashable and mutually comparable; they do not need to be
        contiguous integers starting at 0.
    :param num_clients: Number of clients (must be >= 1)
    :param alpha: Dirichlet distribution parameter; smaller values result in more non-IID
    :param rng: Source of randomness. ``None`` (default) uses NumPy's global
        random state, so ``np.random.seed`` controls the result as before.
        An object with a ``dirichlet`` method (e.g. ``np.random.Generator`` or
        ``np.random.RandomState``) is used directly; anything else (e.g. an int
        seed) is passed to ``np.random.default_rng``.
    :return: dict, key=client_id (0..num_clients-1), value=list of sample indices
    :raises ValueError: if ``num_clients`` is less than 1
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    if rng is None:
        rng = np.random
    elif not hasattr(rng, "dirichlet"):
        rng = np.random.default_rng(rng)

    # Single pass over the dataset (iterating it may load/transform every
    # sample); group indices by their actual label value instead of assuming
    # labels are exactly 0..C-1.
    label_indices = {}
    for idx, (_, label) in enumerate(dataset):
        label_indices.setdefault(label, []).append(idx)

    client_indices = {i: [] for i in range(num_clients)}

    # Sorted order processes contiguous integer labels as 0, 1, 2, ..., which
    # keeps the order of RNG draws (shuffle, then dirichlet, per label) stable.
    for label in sorted(label_indices):
        idxs = label_indices[label]
        rng.shuffle(idxs)
        proportions = rng.dirichlet(np.repeat(alpha, num_clients))
        # Only the interior cut points come from the floating-point cumsum; the
        # last shard always ends at len(idxs), so a total such as
        # 0.9999999999999999 can no longer truncate away the class's last sample.
        cuts = (np.cumsum(proportions)[:-1] * len(idxs)).astype(int).tolist()
        bounds = [0, *cuts, len(idxs)]
        for client_id, (start, end) in enumerate(zip(bounds, bounds[1:])):
            client_indices[client_id].extend(idxs[start:end])

    return client_indices
