from typing import List, Optional

import torch

def _assign_evenly(n_data: int, n_owners: int) -> torch.Tensor:
    """Randomly split samples into near-equal chunks; the last owner absorbs the remainder."""
    if n_owners <= 0:
        raise ValueError(f"n_owners must be positive, got {n_owners}")
    # With n_data < n_owners the chunk size would be 0 (division by zero);
    # a chunk of 1 gives each sample its own owner instead.
    chunk = max(n_data // n_owners, 1)
    return torch.clamp(torch.randperm(n_data) // chunk, max=n_owners - 1)


def _assign_by_quota(n_data: int, n_owners: int, data_per_owners: List[int]) -> torch.Tensor:
    """Randomly assign samples so that owner ``i`` gets ``data_per_owners[i]`` of them."""
    # One slot per allotted sample, in owner order: [0]*q0 + [1]*q1 + ...
    slots = torch.tensor(
        [owner for owner in range(n_owners) for _ in range(data_per_owners[owner])],
        dtype=torch.long,
    )
    if len(slots) < n_data:
        raise ValueError(
            f"data_per_owners allots {len(slots)} samples in total, but n_data={n_data}"
        )
    # Sample i takes slot perm[i]. A single indexing op replaces the per-element
    # loop. If the quotas sum to more than n_data, the trailing slots are never
    # drawn, so the last owners may get fewer samples than requested.
    return slots[torch.randperm(n_data)]


def assign_data_owners(n_data: int, n_owners: int, data_per_owners: Optional[List[int]] = None,
                       method='random', **kwargs):
    """Assign each of ``n_data`` samples to one of ``n_owners`` owners.

    Args:
        n_data: Number of samples to assign.
        n_owners: Number of owners; owner ids are ``0 .. n_owners - 1``.
        data_per_owners: Only used by ``method='random'``. If given and
            non-empty, owner ``i`` receives ``data_per_owners[i]`` samples,
            chosen uniformly at random. Otherwise samples are split as evenly
            as possible and the last owner takes the remainder.
        method: ``'random'`` or ``'by class'``.
        **kwargs: For ``method='by class'``:
            targets: Indexable so that ``targets[i, 0]`` is the class label of
                sample ``i`` (e.g. a 2-D label tensor). Required.
            class_to_owner_dict: Optional mapping from class label to owner id.
                Defaults to a fixed mapping of labels 0-9 onto owners 0-2,
                independent of ``n_owners``.

    Returns:
        A 1-D tensor of length ``n_data`` whose ``i``-th entry is the owner of
        sample ``i``. It is int64 for ``'random'`` and int32 for ``'by class'``.

    Raises:
        ValueError: If ``method`` is unknown, ``n_owners`` is not positive (even
            split), or ``data_per_owners`` allots fewer than ``n_data`` samples.
        KeyError: If ``'by class'`` is used without ``targets``, or a label has
            no entry in the class-to-owner mapping.
    """
    if method == 'random':
        if not data_per_owners:
            return _assign_evenly(n_data, n_owners)
        return _assign_by_quota(n_data, n_owners, data_per_owners)

    if method == 'by class':
        targets = kwargs['targets']
        class_to_owner = kwargs.get('class_to_owner_dict')
        if class_to_owner is None:
            # Historical hard-coded split: 10 labels across 3 owners.
            class_to_owner = {0: 2, 1: 2, 2: 0, 3: 1, 4: 2, 5: 0, 6: 2, 7: 1, 8: 1, 9: 0}
        owners = [class_to_owner[targets[i, 0].item()] for i in range(n_data)]
        # Build int32 directly; a float32 detour would round large owner ids.
        return torch.tensor(owners, dtype=torch.int32)

    raise ValueError(f"Unknown method {method!r}; expected 'random' or 'by class'")

        
def assign_data_owners_by_class(targets, class_to_owner_dict):
    """Map each sample's class label to its owner id.

    Args:
        targets: Iterable of per-sample labels whose elements support
            ``.item()`` (e.g. a 1-D label tensor).
        class_to_owner_dict: Mapping from class label to owner id.

    Returns:
        A 1-D ``torch.int32`` tensor holding one owner id per element of ``targets``.

    Raises:
        KeyError: If a label has no entry in ``class_to_owner_dict``.
    """
    # Build the int tensor directly. Going through float32 (``torch.Tensor``)
    # would silently round owner ids above 2**24.
    return torch.tensor(
        [class_to_owner_dict[label.item()] for label in targets],
        dtype=torch.int32,
    )


