import torch
import numpy as np
import copy
from tool.utils import logger


def client_selection(client_num, fraction, dataset_size, client_dataset_size_list, drop_rate, probabilities=None, style="FedAvg"):
    if probabilities is None:
        probabilities = []
    assert sum(client_dataset_size_list) == dataset_size
    idxs_users = [0]

    selected_num = max(int(fraction * client_num), 1)
    if float(drop_rate) != 0:
        drop_num = max(int(selected_num * drop_rate), 1)
        selected_num -= drop_num

    if "FedAvg".lower() in style.lower():
        idxs_users = np.random.choice(
            a=range(client_num),
            size=selected_num,
            replace=False,
        )

    return idxs_users