from scipy.optimize import linear_sum_assignment
import torch
def select_experts_balanced(count_matrix, k=6):
    """Pick up to ``k`` experts per sample via ``k`` rounds of max-score matching.

    Each round solves a linear sum assignment on the negated scores. This
    maximises the total score of the chosen (sample, expert) pairs while letting
    each expert be picked by at most one sample within that round. A pair chosen
    in an earlier round is penalised. Later rounds therefore avoid reusing a
    pair whenever an assignment without reuse exists, and only then maximise
    the score.

    NOTE: only per-round uniqueness is enforced; there is no cap on how many
    times an expert is chosen across rounds.

    Args:
        count_matrix: 2-D tensor of shape ``(batch_size, expert_num)`` holding a
            score for every (sample, expert) pair; higher is preferred. Any
            numeric dtype is accepted. The tensor is NOT modified.
        k: number of selection rounds, i.e. the number of columns in the result.

    Returns:
        ``torch.long`` tensor of shape ``(batch_size, k)`` on
        ``count_matrix.device``. Column ``i`` holds the expert picked for each
        sample in round ``i``. A round assigns only ``min(batch_size,
        expert_num)`` pairs, so samples left unmatched in a round keep ``-1``.
        This happens whenever ``batch_size > expert_num``.
    """
    batch_size, expert_num = count_matrix.shape

    # Private float64 CPU copy. The negation allocates a fresh array, so the
    # caller's tensor is never written to. detach() allows grad-tracking input.
    # scipy minimises cost, so negating turns "max score" into "min cost".
    cost = -count_matrix.detach().to(device="cpu", dtype=torch.float64).numpy()

    # Finite penalty for already-used pairs. With n pairs per assignment and
    # |cost| <= magnitude, swapping one reused pair for an unused one always
    # lowers the total by at least 2*magnitude + 2n + 2 > 0. Reuse is therefore
    # avoided first. Keeping the penalty near the data's scale (unlike an
    # int64-max sentinel) stops it from swamping real scores in float64.
    n_pairs = min(batch_size, expert_num)
    magnitude = float(abs(cost).max()) if cost.size else 0.0
    penalty = 2.0 * (magnitude + 1.0) * (n_pairs + 1)

    # Build on CPU and move to the target device once at the end.
    selected = torch.full((batch_size, k), -1, dtype=torch.long)
    for round_idx in range(k):
        rows, cols = linear_sum_assignment(cost)
        selected[torch.as_tensor(rows, dtype=torch.long), round_idx] = torch.as_tensor(
            cols, dtype=torch.long
        )
        cost[rows, cols] = penalty
    return selected.to(count_matrix.device)
