"""Utility functions for broadcasting to clients on main server loop in flwr next."""

import ast
import uuid
from collections.abc import Callable, Generator
from functools import partial
from logging import DEBUG, ERROR, WARNING
from typing import TYPE_CHECKING, Any, cast

from flwr.common import (
    Code,
    ConfigsRecord,
    FitRes,
    Message,
    MessageType,
    MetricsRecord,
    NDArrays,
    Parameters,
    RecordSet,
    Scalar,
    Status,
    log,
)
from flwr.common.recordset_compat import (
    _check_mapping_from_recordscalartype_to_scalar,  # noqa: PLC2701
    _embed_status_into_recordset,  # noqa: PLC2701
    _extract_status_from_recordset,  # noqa: PLC2701
    parameters_to_parametersrecord,
    parametersrecord_to_parameters,
)

from repo.conf.base_schema import BaseConfig, CommStack
from repo.constants import (
    BATCH_ID,
    CLIENT_STATE_ACCUMULATOR,
    CODE,
    COMM_STACK,
    EMPTY,
    ENDPOINT_ID,
    FILE_NAME,
    FOLDER_NAME,
    GATHER_INS,
    GATHER_MSG,
    GATHER_P,
    GATHER_RES,
    LAYER_NAMES,
    LAYER_TYPES,
    MASK,
    MESSAGE,
    METRICS,
    NUM_EXAMPLES,
    PARAMETERS,
    repo_TTL,
    QUERY,
    STATUS,
    TYPE,
    UNEXPECTED_EMPTY_CONTENT_MSG,
)
from repo.masks_utils import (
    generate_full_mask,
    generate_mask,
    mask_to_batches,
)
from repo.server.comms import load_recordset_parameters_from_remote
from repo.server.load_balancer import ServerLoadBalancer
from repo.server.utils import ServerState, TooManyFailuresError
from repo.utils import (
    ClientState,
)

# Status attached to the placeholder result used for replies without content.
UNEXPECTED_EMPTY_CONTENT_STATUS = Status(
    message=UNEXPECTED_EMPTY_CONTENT_MSG,
    code=Code.FIT_NOT_IMPLEMENTED,
)
# Placeholder FitRes substituted for replies that carry no content: it holds the
# error status above, no tensors and no metrics, so it is never treated as a
# successful result (its status code is not `Code.OK`).
ERROR_FITRES = FitRes(
    num_examples=1,
    metrics={},
    parameters=Parameters(tensor_type=EMPTY, tensors=[]),
    status=UNEXPECTED_EMPTY_CONTENT_STATUS,
)

if TYPE_CHECKING:
    from flwr.common.typing import ConfigsRecordValues


def gather_res_to_recordset(
    metrics: dict[str, Scalar],
    num_examples: int,
    status: Status,
    *,
    keep_input: bool,
) -> RecordSet:
    """Pack the outcome of a gather operation into a RecordSet.

    Every entry is stored under a key prefixed with ``GATHER_RES``: the metrics go
    into a ConfigsRecord, the number of examples into a MetricsRecord, and an empty
    ``Parameters`` placeholder into a ParametersRecord. The status is embedded last
    through ``_embed_status_into_recordset``.

    Parameters
    ----------
    metrics : dict[str, Scalar]
        Metrics to store in the ConfigsRecord.
    num_examples : int
        Number of examples to store in the MetricsRecord.
    status : Status
        Status embedded into the returned RecordSet.
    keep_input : bool
        Forwarded unchanged to ``parameters_to_parametersrecord`` when building the
        empty parameters placeholder.

    Returns
    -------
    RecordSet
        The RecordSet holding metrics, number of examples, placeholder parameters
        and the embedded status.

    """
    packed = RecordSet()

    # Metrics travel as a ConfigsRecord
    metrics_record = ConfigsRecord(cast("dict[str, ConfigsRecordValues]", metrics))
    packed.configs_records[f"{GATHER_RES}.{METRICS}"] = metrics_record

    # The number of examples travels as a MetricsRecord
    examples_record = MetricsRecord({NUM_EXAMPLES: num_examples})
    packed.metrics_records[f"{GATHER_RES}.{NUM_EXAMPLES}"] = examples_record

    # No tensors are shipped here: only an empty placeholder is stored
    placeholder = parameters_to_parametersrecord(
        Parameters(tensors=[], tensor_type="empty"),
        keep_input=keep_input,
    )
    packed.parameters_records[f"{GATHER_RES}.{PARAMETERS}"] = placeholder

    return _embed_status_into_recordset(GATHER_RES, status, packed)


def recordset_to_gather_res(recordset: RecordSet, *, keep_input: bool) -> FitRes:
    """Rebuild a FitRes from a gather-result RecordSet.

    Reads the entries stored under the ``GATHER_RES`` prefix: the parameters record,
    the number of examples from the metrics records, the metrics from the configs
    records and the embedded status.

    Parameters
    ----------
    recordset : RecordSet
        The RecordSet to read the parameters, number of examples, metrics and
        status from.
    keep_input : bool
        Forwarded unchanged to ``parametersrecord_to_parameters``.

    Returns
    -------
    FitRes
        A FitRes populated with the values extracted from the RecordSet.

    """
    params_key = f"{GATHER_RES}.{PARAMETERS}"
    examples_key = f"{GATHER_RES}.{NUM_EXAMPLES}"
    metrics_key = f"{GATHER_RES}.{METRICS}"

    # Parameters
    extracted_parameters = parametersrecord_to_parameters(
        recordset.parameters_records[params_key],
        keep_input=keep_input,
    )
    # Number of examples
    examples_record = recordset.metrics_records[examples_key]
    extracted_num_examples = cast("int", examples_record[NUM_EXAMPLES])
    # Metrics, converted back from the record value types to scalars
    extracted_metrics = _check_mapping_from_recordscalartype_to_scalar(
        recordset.configs_records[metrics_key]._data,  # noqa: SLF001
    )
    # Status
    extracted_status = _extract_status_from_recordset(GATHER_RES, recordset)

    return FitRes(
        status=extracted_status,
        parameters=extracted_parameters,
        num_examples=extracted_num_examples,
        metrics=extracted_metrics,
    )


def finalize_gather_message(
    node_id: int,
    server_load_balancer: ServerLoadBalancer,
    server_state: ServerState,
    comm_stack: CommStack,
    aggregation_mask: tuple[tuple[bool, ...], list[str], list[str]],
) -> list[Message]:
    """Build the gather query message(s) addressed to one client node.

    When ``comm_stack.n_batches`` is truthy the aggregation mask is split with
    ``mask_to_batches`` and one message is produced per batch, each carrying its
    batch index. Otherwise a single message carrying the whole mask (batch index 0)
    is produced. For every mask a template message addressed to node 0 is created
    first and its content is then reused for the message addressed to ``node_id``.

    Parameters
    ----------
    node_id : int
        Destination node of the returned messages.
    server_load_balancer : ServerLoadBalancer
        Provides the driver used to create the messages.
    server_state : ServerState
        Provides the current round (used as the message group ID) and, in batched
        mode, the layer names and types handed to ``mask_to_batches``.
    comm_stack : CommStack
        Communication configuration; only ``n_batches`` is read here.
    aggregation_mask : tuple[tuple[bool, ...], list[str], list[str]]
        The mask to request, as (boolean selectors, layer names, layer types).

    Returns
    -------
    list[Message]
        The query messages addressed to ``node_id``, one per batch (a single one
        when batching is disabled).

    """

    def _query_message(content: RecordSet, destination: int) -> Message:
        # Every message of this round shares type, group ID and TTL
        return server_load_balancer.driver.create_message(
            content=content,
            message_type=MessageType.QUERY,
            dst_node_id=destination,
            group_id=str(server_state.current_round),
            ttl=repo_TTL,
        )

    # Without batching the whole mask is handled as one single batch
    if comm_stack.n_batches:
        masks = mask_to_batches(
            full_mask=aggregation_mask,
            layer_names_and_types=server_state.layer_names_and_types,
            n_batches=comm_stack.n_batches,
        )
    else:
        masks = [aggregation_mask]

    outgoing: list[Message] = []
    for batch_index, batch_mask in enumerate(masks):
        # Template message holding the gather instructions for this batch
        template = _query_message(
            get_gather_recordset(
                aggregation_mask=batch_mask,
                batch_id=batch_index,
            ),
            0,
        )
        # Re-address the template content to the actual destination node
        outgoing.append(_query_message(template.content, node_id))
    return outgoing


def get_gather_recordset(
    aggregation_mask: tuple[tuple[bool, ...], list[str], list[str]],
    batch_id: int = 0,
) -> RecordSet:
    """Build the RecordSet carrying the gather-parameters instructions.

    The RecordSet holds an empty parameters placeholder, a query record tagging the
    request as ``GATHER_P``, an OK status record and a communication-stack record
    with a freshly generated endpoint ID, the folder and file names, the ``repr``
    of the aggregation mask and the batch identifier.

    Parameters
    ----------
    aggregation_mask : tuple[tuple[bool, ...], list[str], list[str]]
        The mask to request, as (boolean selectors, layer names, layer types). It is
        serialised with ``repr``.
    batch_id : int, default=0
        Identifier of the batch this RecordSet refers to.

    Returns
    -------
    RecordSet
        The RecordSet with the gather instructions.

    """
    instructions = RecordSet()

    # Empty parameters placeholder
    empty_parameters = Parameters(tensors=[], tensor_type="empty")
    instructions.parameters_records[f"{GATHER_INS}.{PARAMETERS}"] = (
        parameters_to_parametersrecord(empty_parameters, keep_input=True)
    )

    # Query type
    query_record = ConfigsRecord({TYPE: GATHER_P})
    instructions.configs_records[QUERY] = query_record

    # Status
    status_fields = {
        CODE: int(Code.OK.value),
        MESSAGE: GATHER_MSG,
    }
    instructions.configs_records[f"{GATHER_INS}.{STATUS}"] = ConfigsRecord(
        status_fields,
    )

    # Communication stack instructions, with a new endpoint ID on every call
    comm_stack_fields = {
        ENDPOINT_ID: str(uuid.uuid4()),
        FOLDER_NAME: COMM_STACK,
        FILE_NAME: PARAMETERS,
        MASK: repr(aggregation_mask),
        BATCH_ID: batch_id,
    }
    instructions.configs_records[f"{GATHER_INS}.{COMM_STACK}"] = ConfigsRecord(
        comm_stack_fields,
    )
    return instructions


def gather_parameters_from_nodes(
    fit_round_results: Generator[Message, None, None],
    server_load_balancer: ServerLoadBalancer,
    server_state: ServerState,
    comm_stack: CommStack,
) -> Generator[Message, None, None]:
    """Request and collect the trained parameters of every node that completed fit.

    For each message yielded by ``fit_round_results`` the source node is asked for
    its parameters through the messages built by ``finalize_gather_message``. The
    messages are pushed one at a time and the driver is then polled until one reply
    per pushed message has been received. Replies with content get the layer names
    and types of their mask added to their metrics and are yielded; replies without
    content (errors) are counted as received but are NOT yielded.

    The aggregation mask is computed once, before the first node is processed: the
    full mask is used when the set of sampled clients differs from the previously
    sampled one, otherwise the mask generated for the last sampled client at the
    current round is used.

    Parameters
    ----------
    fit_round_results : Generator[Message, None, None]
        Messages of the completed fit operations; only the source node ID of each
        message is read.
    server_load_balancer : ServerLoadBalancer
        Provides the driver used to push and pull messages.
    server_state : ServerState
        Provides the round number, the sampled clients, the layer names and types
        and the aggregation mask scheduler.
    comm_stack : CommStack
        Communication configuration forwarded to ``finalize_gather_message``.

    Yields
    ------
    Message
        The replies with content, with ``LAYER_NAMES`` and ``LAYER_TYPES`` added to
        their ``GATHER_RES`` metrics record.

    Notes
    -----
    The polling loop has no timeout: it only ends once as many replies as pushed
    messages have been pulled.

    """
    # Keys of the records read/written on every reply
    comm_stack_key = f"{GATHER_RES}.{COMM_STACK}"
    metrics_key = f"{GATHER_RES}.{METRICS}"

    # Construct aggregation mask generator
    aggregation_mask_generator = partial(
        generate_mask,
        server_state.layer_names_and_types,
        server_state.aggregation_mask_scheduler,
    )
    # Obtain the current aggregation mask
    # NOTE: Taking one client as representative since we assume all clients have
    # the same mask. In the future we may have more complex masks and schedulers
    cohort_changed = set(server_state.sampled_clients) != set(
        server_state.previously_sampled_clients,
    )
    aggregation_mask = (
        generate_full_mask(server_state.layer_names_and_types)
        if cohort_changed
        else aggregation_mask_generator(
            server_state.sampled_clients[-1],
            server_state.current_round,
        )
    )

    for fit_round_result in fit_round_results:
        # The node that completed fit is the one we ask the parameters to
        node_id = fit_round_result.metadata.src_node_id
        messages = finalize_gather_message(
            node_id=node_id,
            server_load_balancer=server_load_balancer,
            server_state=server_state,
            comm_stack=comm_stack,
            aggregation_mask=aggregation_mask,
        )
        log(
            DEBUG,
            "Pushing %s gather messages to %s nodes",
            len(messages),
            len(server_load_balancer.get_available_nodes()),
        )
        message_ids: list[str] = []
        # NOTE: We must send over a message at a time to avoid the Driver to pack them
        # together thus creating problems with the ~500MB limit of gRPC
        for message in messages:
            message_ids.extend(server_load_balancer.driver.push_messages([message]))
        log(
            DEBUG,
            "Pushed %s gather messages to %s nodes",
            len(messages),
            len(server_load_balancer.get_available_nodes()),
        )

        # Poll the driver until one reply per pushed message has been received
        total = len(messages)
        received = 0
        while received < total:
            # Pull messages from the driver using the IDs we are still waiting for
            replies = list(
                server_load_balancer.driver.pull_messages(message_ids=message_ids),
            )
            for reply in replies:
                has_content = reply.has_content()
                log(
                    DEBUG,
                    "Got 1 %s from %s",
                    "result" if has_content else "error",
                    reply.metadata.src_node_id,
                )
                received += 1
                # Stop waiting for the message this reply answers to. IDs are
                # compared for equality: a membership test against the string ID
                # would be a substring match
                answered_id = reply.metadata.reply_to_message
                message_ids = [
                    message_id
                    for message_id in message_ids
                    if message_id != answered_id
                ]
                # Error replies carry no content: nothing to enrich nor to yield
                if not has_content:
                    continue
                # Add the mask to the metrics to allow the strategy to
                # reconstruct the parameters correctly and execute partial updates
                configs_records = reply.content.configs_records
                _, layer_names, layer_types = cast(
                    "tuple[tuple[bool, ...], list[str], list[str]]",
                    ast.literal_eval(
                        cast("str", configs_records[comm_stack_key][MASK]),
                    ),
                )
                configs_records[metrics_key][LAYER_NAMES] = repr(layer_names)
                configs_records[metrics_key][LAYER_TYPES] = repr(layer_types)
                yield reply


def get_handle_success_and_failure_gather(
    metrics_accumulator: list[tuple[dict[str, Scalar], Status, int]],
    gather_failures: list[FitRes | BaseException],
    accept_failures_cnt: int | None,
) -> Callable[
    [tuple[bool, FitRes, NDArrays]],
    tuple[bool, FitRes, NDArrays],
]:
    """Closure to generate a function handling parameter gathering success and failure.

    This function creates a handler that processes gather operation results by
    distinguishing between successful and failed operations. For successful results, it
    appends metrics to the accumulator. For failures, it enforces constraints on the
    maximum number of allowed failures and tracks them appropriately. This approach
    allows consistent handling of gather results throughout the parameter collection
    process.

    Parameters
    ----------
    metrics_accumulator : list[tuple[dict[str, Scalar], Status, int]]
        The list where successful gather metrics are accumulated.
    gather_failures : list[FitRes | BaseException]
        The list where failed gather operations are accumulated.
    accept_failures_cnt : int | None
        The maximum number of unintentional failures to accept. If None, all failures
        are accepted. If the failures exceed this number, a TooManyFailuresError
        is raised.

    Returns
    -------
    Callable[[tuple[bool, FitRes, NDArrays]], tuple[bool, FitRes, NDArrays]]
        A callback function that processes gather results, updates accumulators,
        and handles potential failures based on the specified thresholds.

    """

    def handle_success_and_failure_gather(
        result: tuple[bool, FitRes, NDArrays],
    ) -> tuple[bool, FitRes, NDArrays]:
        cnt_failures = 0

        match result:
            case (True, gather_result, returned_parameters):
                metrics_accumulator.append(
                    (
                        gather_result.metrics,
                        gather_result.status,
                        gather_result.num_examples,
                    ),
                )
                return (True, gather_result, returned_parameters)
            case (False, gather_result, returned_parameters):
                cnt_failures += 1
                if (
                    accept_failures_cnt is not None
                    and cnt_failures > accept_failures_cnt
                ):
                    msg = f"""Unintentional failures passed
                        the maximum: {accept_failures_cnt}"""
                    raise TooManyFailuresError(
                        msg,
                    )
                gather_failures.append(gather_result)
                return (False, gather_result, returned_parameters)
        return result

    return handle_success_and_failure_gather


def handle_gather_replies(
    cfg: BaseConfig,
    replies: Generator[Message, None, None],
    server_state: ServerState,
) -> (
    tuple[
        NDArrays | None,
        dict[str, Scalar],
        tuple[
            list[tuple[dict[str, Scalar], Status, int]],
            list[FitRes | BaseException],
        ],
    ]
    | None
):
    """Process parameter gathering responses from client nodes.

    The replies are processed through a chain of lazy generators: parameters are
    loaded through the configured communication stack, messages are converted to
    FitRes objects (``ERROR_FITRES`` for replies without content), successes and
    failures are separated, and the successful results are handed to
    ``server_state.strategy.aggregate_fit``. Afterwards the client states shipped
    in the metrics are merged into ``server_state.client_states``, the cumulative
    server steps are updated, the per-client ``steps_done`` are reset to zero and
    the metrics are aggregated with ``fit_metrics_aggregation_fn`` when available.

    Parameters
    ----------
    cfg : BaseConfig
        Server configuration; ``cfg.repo.comm_stack``, ``cfg.fl.accept_failures_cnt``
        and ``cfg.fl.ignore_failed_rounds`` are read.
    replies : Generator[Message, None, None]
        A generator yielding client response messages containing trained parameters.
    server_state : ServerState
        The current state of the server; its client states and cumulative steps are
        updated in place.

    Returns
    -------
    tuple[
        NDArrays | None,
        dict[str, Scalar],
        tuple[list[tuple[dict[str, Scalar], Status, int]],
        list[FitRes | BaseException]]
    ] | None
        A tuple containing: (1) aggregated model parameters, or None when the round
        failed and failed rounds are ignored, (2) aggregated metrics dictionary, and
        (3) a tuple of raw metrics and failures for detailed tracking. This
        implementation always returns a tuple; ``None`` is kept in the annotation
        for compatibility with callers.

    Raises
    ------
    TooManyFailuresError
        If the number of failures exceeds the configured threshold and
        ignore_failed_rounds is not set.

    """
    # Process the replies by first obtaining the parameters depending on the
    # communication stack configured and the rest of the message, which may contain
    # train metrics and other results
    processed_msgs = (
        (
            load_recordset_parameters_from_remote(
                remote_uploader_downloader=server_state.remote_up_down,
                incoming_message=msg,
                comm_stack=cfg.repo.comm_stack,
                msg_str=GATHER_RES,
                ray_gc_queue=server_state.ray_garbage_queue,
            )
            if msg.has_content()
            else (msg, [])
        )
        for msg in replies
    )

    # Translate Messages to FitRes
    all_fitres = (
        (
            (recordset_to_gather_res(msg.content, keep_input=True), returned_parameters)
            if msg.has_content()
            else (ERROR_FITRES, [])
        )
        for (msg, returned_parameters) in processed_msgs
    )

    # Add a boolean to better discriminate between success and failure
    results_and_failures = (
        (fit_res.status.code == Code.OK, fit_res, returned_parameters)
        for (fit_res, returned_parameters) in all_fitres
    )

    # NOTE: Because everything above is lazy, the two accumulators below are only
    # populated while the results generator is being consumed, which presumably
    # happens inside `aggregate_fit` (verify against the strategy in use)
    failures: list[FitRes | BaseException] = []
    metrics_accumulator: list[tuple[dict[str, Scalar], Status, int]] = []

    # Handle success and failure, get the handling function first (necessary because we
    # are using generators throughout this function)
    handle_success_and_failure = get_handle_success_and_failure_gather(
        metrics_accumulator,
        failures,
        accept_failures_cnt=cfg.fl.accept_failures_cnt,
    )
    handled_results_and_failures = (
        handle_success_and_failure(result)  # type: ignore[arg-type]
        for result in results_and_failures
    )

    # Retrieve the successful results
    results = (
        (result, returned_parameters)
        for success, result, returned_parameters in handled_results_and_failures
        if success
    )

    # Process the results by aggregating the parameters and metrics using the strategy
    parameters_aggregated = None
    metrics_aggregated: dict = {}
    try:
        # Aggregate training results
        parameters_aggregated, metrics_aggregated = server_state.strategy.aggregate_fit(
            server_round=server_state.current_round,
            results=results,
            failures=failures,
        )

        # Collect statistics that repo uses from the FitRes of the NodeManagers
        fit_metrics = [
            (num_examples, metrics) for metrics, _, num_examples in metrics_accumulator
        ]
        # Collect client states
        client_state_accumulator: dict[str | int, dict[str, Any]] = {}
        for _, inner_metrics in fit_metrics:
            # NOTE: If the communication of the gather parameters is batched, only one
            # reply will contain the client state accumulator
            if CLIENT_STATE_ACCUMULATOR not in inner_metrics:
                continue
            acc: dict[str | int, dict[str, Any]] = ast.literal_eval(
                cast("str", inner_metrics.pop(CLIENT_STATE_ACCUMULATOR)),
            )
            client_state_accumulator |= acc
        # Remove from the metrics the layer names and types used by the aggregation
        # strategy to reconstruct the parameters
        for _, inner_metrics in fit_metrics:
            inner_metrics.pop(LAYER_NAMES)
            inner_metrics.pop(LAYER_TYPES)

        # NOTE: When using partial participation we need to accumulate the keys of the
        # old state and the new state
        server_state.client_states |= {
            k: ClientState(**v) for k, v in client_state_accumulator.items()
        }
        # NOTE: Update the server steps cumulative by adding to the previous
        # value the maximum number of local steps done across the clients sample
        # in this round. `default=0` covers the case in which no client state is
        # known yet (unpacking an empty list into `max` would leave a single
        # non-iterable argument and raise a TypeError); the outer `max` keeps the
        # value from going below zero
        max_steps = max(
            max(
                (
                    _client_state.steps_done
                    for _client_state in server_state.client_states.values()
                ),
                default=0,
            ),
            0,
        )
        server_state.server_steps_cumulative += max_steps
        # NOTE: Reset the steps_done for all clients to zero as it is just meant
        # to be an ephemeral record of the local steps done during one federated
        # round
        for c_state in server_state.client_states.values():
            c_state.steps_done = 0

        # Aggregate the metrics
        # NOTE: This bypasses any metrics aggregation in the aggregate_fit of the
        # strategy because the metrics are empty there
        if server_state.strategy.fit_metrics_aggregation_fn:
            metrics_aggregated |= server_state.strategy.fit_metrics_aggregation_fn(
                fit_metrics,
            )

        elif server_state.current_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")
    except TooManyFailuresError as e:
        # The server state is left untouched by a failed round: either the round
        # is skipped (and logged) or the error is propagated to the caller
        if cfg.fl.ignore_failed_rounds:
            log(
                ERROR,
                """Ignoring failed round %s: %s,
                there are %s failures: %s""",
                server_state.current_round,
                e,
                len(failures),
                failures,
            )
        else:
            raise
    return (
        parameters_aggregated,
        metrics_aggregated,
        (metrics_accumulator, failures),
    )


def gather_round(
    server_state: ServerState,
    server_load_balancer: ServerLoadBalancer,
    cfg: BaseConfig,
    fit_round_results: Generator[Message, None, None],
) -> None:
    """Run one parameter gathering round and store its outcome in the server state.

    The replies produced by ``gather_parameters_from_nodes`` are handed to
    ``handle_gather_replies``. When a result is returned, the aggregated parameters
    (if any) replace ``server_state.model_states`` and the aggregated fit metrics
    are recorded in the server history for the current round.

    Parameters
    ----------
    server_state : ServerState
        The current state of the server; its model states and history are updated.
    server_load_balancer : ServerLoadBalancer
        The load balancer used to exchange messages with the client nodes.
    cfg : BaseConfig
        Configuration of the federated learning process; ``cfg.repo.comm_stack`` is
        forwarded to the gathering step.
    fit_round_results : Generator[Message, None, None]
        Generator yielding messages from completed fit operations, used to determine
        which nodes to request parameters from.

    """
    # Lazily produce the replies and let the handler consume and aggregate them
    outcome = handle_gather_replies(
        cfg=cfg,
        replies=gather_parameters_from_nodes(
            fit_round_results=fit_round_results,
            server_load_balancer=server_load_balancer,
            server_state=server_state,
            comm_stack=cfg.repo.comm_stack,
        ),
        server_state=server_state,
    )
    # Nothing to record when no result is available
    if outcome is None:
        return

    new_parameters, aggregated_metrics, (_metrics_raw, _failed) = outcome
    # Keep the previous model when the aggregation produced no parameters
    if parameters_were_aggregated := new_parameters is not None:
        server_state.model_states = new_parameters
    del parameters_were_aggregated
    server_state.history.add_metrics_distributed_fit(
        server_round=server_state.current_round,
        metrics=aggregated_metrics,
    )
