"""Client sampler for federated learning."""

import random


class DropoutFunctionNotFoundError(Exception):
    """The dropout function name passed is not supported."""

    # Dropout function names that ClientSampler.sample_clients knows how to apply.
    SUPPORTED_DROPOUT_FUNCTIONS: tuple[str, ...] = ("random",)

    def __init__(
        self,
        function_name: str,
    ) -> None:
        """Initialize the exception.

        Parameters
        ----------
        function_name : str
            The unsupported dropout function name that was requested.

        """
        # Keep the offending name available to callers that catch the error.
        self.function_name = function_name
        supported = ", ".join(
            repr(name) for name in self.SUPPORTED_DROPOUT_FUNCTIONS
        )
        # Both halves must live in one expression; previously the second line was a
        # stray statement and the message was silently truncated.
        msg = (
            f"The value {function_name} was passed as dropout function name but it is"
            f" not supported. Supported values are {supported}."
        )
        super().__init__(msg)


class ClientSampler:
    """Client sampler for federated learning."""

    def __init__(
        self,
        total_number_of_clients: int,
        number_of_clients_per_round: int,
        dropout_ratio: float,
        dropout_function_name: str,
    ) -> None:
        """Initialize the ClientSampler with sampling configuration parameters.

        This class handles client selection for federated learning rounds, including
        the sampling of clients and applying dropout according to specified strategy.

        Parameters
        ----------
        total_number_of_clients : int
            The total number of clients available in the federated system. Client IDs
            are drawn from ``range(total_number_of_clients)``.
        number_of_clients_per_round : int
            The number of clients to sample for each training round.
        dropout_ratio : float
            The fraction of sampled clients that will be dropped (simulating unreliable
            clients) after initial selection. Values ``<= 0`` disable dropout.
        dropout_function_name : str
            The name of the function to use for dropping clients. Currently supported:
            "random" - randomly drops clients based on dropout_ratio.

        """
        self.total_number_of_clients = total_number_of_clients
        self.number_of_clients_per_round = number_of_clients_per_round
        self.dropout_ratio = dropout_ratio
        self.dropout_function_name = dropout_function_name

    def sample_clients(
        self,
        rng: random.Random,
    ) -> list[int]:
        """Sample clients for a federated learning round.

        This function selects a random subset of clients from the available pool and
        applies the specified dropout function to simulate client unavailability. The
        number of clients selected is determined by `number_of_clients_per_round` and
        the dropout is applied according to `dropout_ratio` using the strategy specified
        in `dropout_function_name`.

        Parameters
        ----------
        rng : random.Random
            Random number generator instance to use for sampling clients.

        Returns
        -------
        list[int]
            A list of client IDs selected for the current round after applying dropout.
            Surviving clients keep their original sampled order. The number dropped is
            ``int(number_of_clients_per_round * dropout_ratio)`` (rounded down).

        Raises
        ------
        DropoutFunctionNotFoundError
            If the specified dropout function name is not supported.
        ValueError
            Propagated from ``rng.sample`` when more clients are requested than exist,
            or when ``dropout_ratio > 1`` asks to drop more clients than were sampled.

        """
        # Randomly sample clients (without replacement)
        sampled_clients = rng.sample(
            range(self.total_number_of_clients),
            self.number_of_clients_per_round,
        )
        # Apply dropout depending on the ratio and the function
        if self.dropout_function_name == "random":
            # Randomly exclude elements in the list of `sampled_clients` depending on
            # the `self.dropout_ratio`
            if self.dropout_ratio > 0:
                # Calculate number of clients to drop (int() floors the product)
                n_dropout = int(len(sampled_clients) * self.dropout_ratio)
                # Choose which positions to drop; a set keeps the membership test below
                # O(1) instead of rescanning a list for every client.
                dropout_indices = set(
                    rng.sample(range(len(sampled_clients)), n_dropout),
                )
                # Create a new list without the dropped clients
                sampled_clients = [
                    client
                    for i, client in enumerate(sampled_clients)
                    if i not in dropout_indices
                ]
        else:
            raise DropoutFunctionNotFoundError(self.dropout_function_name)

        return sampled_clients
