"""Utility functions for running the main server loop in flwr next."""

import ast
import copy
import gc
import inspect
import random
import time
from collections.abc import Callable, Generator
from dataclasses import asdict, dataclass
from logging import DEBUG, WARNING
from queue import Queue
from typing import Any, cast

from composer.loggers import RemoteUploaderDownloader
from flwr.common import (
    ConfigsRecord,
    EvaluateIns,
    FitIns,
    Message,
    NDArrays,
    Parameters,
    RecordSet,
    log,
)
from flwr.common.recordset_compat import evaluateins_to_recordset, fitins_to_recordset
from flwr.common.typing import ConfigsRecordValues
from flwr.server import Driver

import ray
from repo.constants import (
    CLIENT_STATES,
    EMPTY,
    HISTORY,
    KEY_LIGHTWEIGHT_SERVER_STATE,
    LOCAL_STEPS_CUMULATIVE,
    repo_TTL,
    SERVER_ROUND,
    SERVER_STEPS_CUMULATIVE,
    TIME_OFFSET,
)
from repo.server.client_sampler import ClientSampler
from repo.server.load_balancer import ServerLoadBalancer
from repo.strategy.dispatcher import SupportedStrategies
from repo.utils import ClientState, ModelStateNames
from repo.wandb_history import WandbHistory

# Known client-dropout function names (currently only "random").
# NOTE(review): not referenced in this chunk; presumably checked against a
# configured dropout mode elsewhere — confirm before extending.
DROPOUT_FUNCTIONS: list[str] = ["random"]


class TooManyFailuresError(Exception):
    """Exception raised when a client is dropped out of the tree.

    NOTE(review): this exception is not raised anywhere in this chunk; the name
    suggests it signals that too many client failures occurred in a round —
    confirm against the code that raises it.
    """


class ServerStateDictError(Exception):
    """Exception raised when the server state dictionary is not lightweight.

    Raised by ``check_server_state_lightweight_dict`` when it is called with
    ``raise_if_not_lightweight=True`` and the dictionary contains keys outside
    ``KEY_LIGHTWEIGHT_SERVER_STATE``.
    """


@dataclass
class ServerState:
    """Server state for federated learning.

    This class holds the server state, including the model states, history,
    and other relevant information for the federated learning process.

    Only ``current_round``, ``history``, ``current_time_elapsed``,
    ``client_states`` and ``server_steps_cumulative`` are written to the
    lightweight checkpoint dictionary by
    ``create_server_state_lightweight_dict``; the other fields are not part of
    that dictionary.
    """

    # Model state arrays (flwr ``NDArrays``).
    model_states: NDArrays
    # (layer name, ModelStateNames) pairs — presumably aligned index-by-index
    # with ``model_states``; TODO confirm.
    layer_names_and_types: tuple[tuple[str, ModelStateNames], ...]
    # Maps (client id, int) to the model-state kinds to transmit; the int is
    # presumably the round number — confirm with the scheduler implementation.
    transmission_scheduler: Callable[[str | int, int], list[ModelStateNames]]
    # Strategy selected through the strategy dispatcher.
    strategy: SupportedStrategies
    # Metrics history; stored under ``HISTORY`` in the lightweight checkpoint.
    history: WandbHistory
    # NOTE(review): presumably the elapsed time restored from ``TIME_OFFSET``
    # when resuming; confirm how it relates to ``current_time_elapsed``.
    time_offset: float
    # Cumulative server step count; included in every instruction recordset and
    # stored under ``SERVER_STEPS_CUMULATIVE`` in the lightweight checkpoint.
    server_steps_cumulative: int
    # Per-client state keyed by client id; serialized via ``dataclasses.asdict``.
    client_states: dict[str | int, ClientState]
    # Client ids sampled for the current round; consumed when messages are
    # assigned to nodes.
    sampled_clients: list[int]
    # Client ids sampled earlier (presumably in the previous round — confirm).
    previously_sampled_clients: list[int]
    # Round the run (re)started from — presumably restored from ``SERVER_ROUND``.
    start_round: int
    # Random number generator kept in the server state.
    rng: random.Random
    # Current round; used as the message group id and stored under
    # ``SERVER_ROUND`` in the lightweight checkpoint.
    current_round: int
    # Elapsed time; stored under ``TIME_OFFSET`` in the lightweight checkpoint.
    current_time_elapsed: float
    # Checkpoint path — presumably on the local filesystem, per the name.
    local_checkpoint_path: str
    # Client sampler (see ``repo.server.client_sampler``).
    client_sampler: ClientSampler
    # Optional first/second moment buffers — presumably for a server-side
    # adaptive optimizer, ``None`` when unused; confirm.
    momentum_vector: NDArrays | None
    second_momentum_vector: NDArrays | None
    # Same signature as ``transmission_scheduler``; presumably selects which
    # model-state kinds are aggregated — confirm.
    aggregation_mask_scheduler: Callable[[str | int, int], list[ModelStateNames]]
    # Optional Composer remote uploader/downloader; ``None`` by default.
    remote_up_down: RemoteUploaderDownloader | None = None
    # Optional queue of Ray object references; ``None`` by default.
    # NOTE(review): presumably drained later to release Ray objects — confirm.
    ray_garbage_queue: Queue[ray.ObjectRef] | None = None


def create_server_state_lightweight_dict(
    server_state: ServerState,
) -> dict[str, Any]:
    """Build the small, checkpoint-friendly view of a ``ServerState``.

    Model arrays, schedulers, samplers and other heavy runtime objects are left
    out; only the values needed to resume bookkeeping are kept.

    Parameters
    ----------
    server_state : ServerState
        The full server state to summarize.

    Returns
    -------
    dict[str, Any]
        Mapping with, in this insertion order:

        - ``SERVER_ROUND``: ``server_state.current_round``
        - ``HISTORY``: ``server_state.history`` (the object itself, not a copy)
        - ``TIME_OFFSET``: ``server_state.current_time_elapsed``
        - ``CLIENT_STATES``: ``str()`` of ``{client_id: asdict(state)}``
        - ``SERVER_STEPS_CUMULATIVE``: ``server_state.server_steps_cumulative``

    Notes
    -----
    Client states are stored as the textual repr of a dict of plain dicts so
    that ``read_server_state_lightweight_dict`` can rebuild them with
    ``ast.literal_eval``.

    """
    plain_client_states = {
        client_id: asdict(state)
        for client_id, state in server_state.client_states.items()
    }

    lightweight: dict[str, Any] = {}
    lightweight[SERVER_ROUND] = server_state.current_round
    lightweight[HISTORY] = server_state.history
    lightweight[TIME_OFFSET] = server_state.current_time_elapsed
    lightweight[CLIENT_STATES] = str(plain_client_states)
    lightweight[SERVER_STEPS_CUMULATIVE] = server_state.server_steps_cumulative
    return lightweight


def read_server_state_lightweight_dict(
    server_state_dict: dict[str, Any],
    n_total_clients: int,
    n_local_steps: int,
) -> tuple[WandbHistory, int, float, int, dict[str | int, ClientState]]:
    """Parse a lightweight server state dictionary into its component parts.

    This function extracts and reconstructs the essential components of a server state
    from a lightweight dictionary representation. It retrieves the current round,
    history object, client states, time offset, and server steps count. If certain
    components are missing from the dictionary, the function generates default values
    to maintain compatibility with previous versions.

    Parameters
    ----------
    server_state_dict : dict[str, Any]
        A lightweight dictionary containing the serialized server state components.
        ``SERVER_ROUND`` and ``HISTORY`` are required; ``CLIENT_STATES``,
        ``TIME_OFFSET`` and ``SERVER_STEPS_CUMULATIVE`` are optional. The
        dictionary is modified in place: every recognised key is popped from it.
    n_total_clients : int
        The total number of clients in the federated learning process. Used to
        generate default client states if they are missing from the dictionary.
    n_local_steps : int
        The number of local steps per round. Multiplied by the start round to
        produce the placeholder cumulative step count of default client states.

    Returns
    -------
    tuple[WandbHistory, int, float, int, dict[str | int, ClientState]]
        A tuple containing:
        - history: The WandbHistory object tracking metrics
        - start_round: The stored round number
        - time_offset: The elapsed time in the federated learning process
        - server_steps_cumulative: The total number of steps taken by the server
        - client_states: Dictionary mapping client IDs to their respective states

    Raises
    ------
    KeyError
        If ``SERVER_ROUND`` or ``HISTORY`` is missing from ``server_state_dict``.

    Notes
    -----
    Defaults used for missing components:

    - ``CLIENT_STATES``: one placeholder state per client id in
      ``range(n_total_clients)`` with ``LOCAL_STEPS_CUMULATIVE`` set to
      ``n_local_steps * start_round``.
    - ``TIME_OFFSET``: ``0.0``.
    - ``SERVER_STEPS_CUMULATIVE``: the largest ``local_steps_cumulative`` among the
      client states, floored at ``0`` (so ``0`` when there are no client states).

    Saved client-state attributes that the current ``ClientState`` constructor does
    not accept are dropped. Keys left over after parsing are logged as a warning,
    not rejected.

    """
    # Required entries: both raise KeyError when absent.
    start_round = int(server_state_dict.pop(SERVER_ROUND))
    history: WandbHistory = server_state_dict.pop(HISTORY)

    if CLIENT_STATES in server_state_dict:
        # Stored as ``str()`` of a dict of plain dicts (see
        # ``create_server_state_lightweight_dict``); ``ast.literal_eval`` only
        # evaluates Python literals, so parsing executes no code.
        saved_client_states: dict[str | int, dict[str, Any]] = ast.literal_eval(
            server_state_dict.pop(CLIENT_STATES),
        )
    else:
        # NOTE: The placeholder ``LOCAL_STEPS_CUMULATIVE`` is only used for
        # bookkeeping and not by any logic during training, so the estimate
        # ``n_local_steps * start_round`` is good enough.
        log(
            WARNING,
            "No client state found in the checkpoint. "
            "We will put dummy values (%s) for the %s clients.",
            int(n_local_steps * start_round),
            n_total_clients,
        )
        saved_client_states = {
            cid: {LOCAL_STEPS_CUMULATIVE: n_local_steps * start_round}
            for cid in range(n_total_clients)
        }

    # NOTE: We maintain partial compatibility across ClientState implementations
    # by only loading the attributes the current constructor accepts; this
    # decision should be revisited at a later time.
    client_states_args = inspect.signature(ClientState.__init__).parameters.keys()
    client_states = {
        k: ClientState(**{attr: v[attr] for attr in v if attr in client_states_args})
        for k, v in saved_client_states.items()
    }

    time_offset = server_state_dict.pop(TIME_OFFSET, 0.0)

    if SERVER_STEPS_CUMULATIVE in server_state_dict:
        server_steps_cumulative = server_state_dict.pop(SERVER_STEPS_CUMULATIVE)
    else:
        # Back-compat with checkpoints that predate this key: take the most
        # advanced client, floored at 0. The 0 is appended to the list rather
        # than passed as a separate ``max`` argument so an empty
        # ``client_states`` yields 0 instead of ``max(0)`` raising TypeError.
        server_steps_cumulative = max(
            [state.local_steps_cumulative for state in client_states.values()] + [0],
        )

    # Anything still present was not understood by this reader.
    if server_state_dict:
        log(
            WARNING,
            "Server state keys %s are not consistent",
            server_state_dict.keys(),
        )

    return (
        history,
        start_round,
        time_offset,
        server_steps_cumulative,
        client_states,
    )


def check_server_state_lightweight_dict(
    server_state_dict: dict[str, Any],
    *,
    raise_if_not_lightweight: bool = False,
) -> bool:
    """Check if a server state dictionary contains only lightweight keys.

    Every key of ``server_state_dict`` must belong to
    ``KEY_LIGHTWEIGHT_SERVER_STATE``. Each offending key is logged at DEBUG level.

    Missing keys are *not* flagged, so a dictionary holding only a subset of the
    lightweight keys passes. This is consistent with
    ``read_server_state_lightweight_dict``, which treats ``CLIENT_STATES``,
    ``TIME_OFFSET`` and ``SERVER_STEPS_CUMULATIVE`` as optional.

    Parameters
    ----------
    server_state_dict : dict[str, Any]
        The server state dictionary to check, typically created by
        ``create_server_state_lightweight_dict``. It is not modified.
    raise_if_not_lightweight : bool, optional
        If True and the dictionary is found not to be lightweight, raise
        ``ServerStateDictError`` instead of returning False. Default is False.

    Returns
    -------
    bool
        True if every key is a lightweight key, False otherwise.

    Raises
    ------
    ServerStateDictError
        If ``raise_if_not_lightweight`` is True and the dictionary contains keys
        outside ``KEY_LIGHTWEIGHT_SERVER_STATE``.

    """
    # Only the keys are inspected, so there is no need to copy the dictionary
    # (which holds e.g. the history object).
    unexpected_keys = [
        key for key in server_state_dict if key not in KEY_LIGHTWEIGHT_SERVER_STATE
    ]
    for key in unexpected_keys:
        log(
            DEBUG,
            "Server state key %s is not consistent",
            key,
        )
    is_lightweight = not unexpected_keys
    if not is_lightweight and raise_if_not_lightweight:
        msg = f"Server state dictionary has non-lightweight keys: {unexpected_keys}"
        raise ServerStateDictError(msg)
    return is_lightweight


def wait_for_nodes_to_connect(driver: Driver, n_nodes: int, timeout: float = 3) -> None:
    """Block until at least ``n_nodes`` client nodes are connected.

    Repeatedly calls ``driver.get_node_ids()``, logging the connected node IDs at
    DEBUG level on every attempt, and returns as soon as the count reaches
    ``n_nodes``. There is no overall deadline: despite its name, ``timeout`` is
    the pause between two polls, and the function keeps polling until enough
    nodes are present.

    Parameters
    ----------
    driver : Driver
        The driver object used to query the connected client nodes.
    n_nodes : int
        The minimum number of client nodes that must be connected.
    timeout : float, optional
        Seconds to sleep between two consecutive checks. Default is 3 seconds.

    """
    # The Driver API might not immediately return enough client node IDs, so we
    # loop and wait until enough client nodes are available.
    while True:
        all_node_ids = driver.get_node_ids()
        # Lazy %-style args: the (possibly long) ID list is only formatted when
        # DEBUG logging is enabled, matching the rest of this module.
        log(DEBUG, "Got %s client nodes: %s", len(all_node_ids), all_node_ids)
        if len(all_node_ids) >= n_nodes:
            return
        time.sleep(timeout)


def construct_message_for_client(
    ids: tuple[int, int],
    gen_ins_function: Callable[[int, int | str], dict[str, ConfigsRecordValues]],
    driver: Driver,
    message_constants: tuple[str, str],
    server_state: ServerState,
) -> Message:
    """Build the instruction message for one client hosted on one node.

    The content starts from the shared fit/evaluate recordset produced by
    ``fit_or_evaluate_ins_recordset`` for the single client ``cid``. An extra
    ``ConfigsRecord`` is then added under the key ``str(cid)``, filled with
    ``gen_ins_function(current_round, cid)``.

    Parameters
    ----------
    ids : tuple[int, int]
        ``(node_id, cid)``: the destination node and the client it should run.
    gen_ins_function : Callable[[int, int | str], dict[str, ConfigsRecordValues]]
        Called as ``gen_ins_function(current_round, cid)`` to produce the
        client-specific configuration.
    driver : Driver
        Driver used to create the message.
    message_constants : tuple[str, str]
        ``(message_type, msg_str)``: the Flower message type and the recordset
        kind passed to ``fit_or_evaluate_ins_recordset`` ("fitins" or
        "evaluateins").
    server_state : ServerState
        Supplies the current round, the client states and the cumulative server
        step count.

    Returns
    -------
    Message
        A message addressed to ``node_id``, grouped by the current round and
        using ``repo_TTL`` as its time-to-live.

    """
    node_id, cid = ids
    message_type, msg_str = message_constants
    current_round = server_state.current_round

    content = fit_or_evaluate_ins_recordset(
        msg_str,
        current_round,
        cast("list[int] | list[str]", [cid]),
        server_state.client_states,
        server_state.server_steps_cumulative,
    )
    # Per-client settings sit next to the shared instructions, keyed by cid.
    per_client_config = ConfigsRecord(gen_ins_function(current_round, cid))
    content.configs_records.update({str(cid): per_client_config})

    return driver.create_message(
        content=content,
        message_type=message_type,
        dst_node_id=node_id,
        group_id=str(current_round),
        ttl=repo_TTL,
    )


def message_collaborative_assignment(  # noqa: C901
    server_state: ServerState,
    server_load_balancer: ServerLoadBalancer,
    message_constants: tuple[str, str],
    gen_ins_function: Callable[[int, int | str], dict[str, ConfigsRecordValues]],
) -> tuple[list[Message], list[int], int]:
    """Assign messages to available nodes for collaborative processing.

    This function builds the initial batch of messages for the available client
    nodes. There are three modes:

    - Production (``server_load_balancer.is_production``): one message per
      available node, addressed with that node's fake client id. No real client
      ids are returned as remaining.
    - Balanced (``assign_clients_to_nodes`` returns True): every node receives a
      message for each client in its ``node_id_to_client_ids`` list, and all
      sampled clients are consumed.
    - Unbalanced: each node gets one message for the first of its assigned clients
      that was sampled this round. Nodes left without a client then receive the
      next remaining sampled client, visited in ``node_ids`` order. This makes the
      node-to-client pairing deterministic.

    Parameters
    ----------
    server_state : ServerState
        The current state of the server. ``sampled_clients`` is read but not
        modified.
    server_load_balancer : ServerLoadBalancer
        Load balancer that manages the distribution of client tasks across
        available nodes.
    message_constants : tuple[str, str]
        A tuple containing two strings: the first is the message type and
        the second is a string identifier used in recordset keys.
    gen_ins_function : Callable[[int, int | str], dict[str, ConfigsRecordValues]]
        A function that generates instruction configurations for clients based on
        the current round and client ID.

    Returns
    -------
    tuple[list[Message], list[int], int]
        A tuple containing:
        - A list of constructed messages ready to be sent to nodes
        - A list of remaining sampled client IDs that were not assigned yet
        - ``len(server_state.sampled_clients)``

    """

    def _message_for(node_id: int, cid: int) -> Message:
        # Single construction point so every branch builds messages identically.
        return construct_message_for_client(
            ids=(node_id, cid),
            message_constants=message_constants,
            gen_ins_function=gen_ins_function,
            driver=server_load_balancer.driver,
            server_state=server_state,
        )

    # Shallow copy: this list is consumed below, the caller's list must not be.
    cids = list(server_state.sampled_clients)
    # NOTE(review): in production mode only one message per available node is
    # built, yet ``total`` still counts every sampled client. A caller waiting
    # for ``total`` replies would block if there are more sampled clients than
    # available nodes. Confirm the sampler guarantees the two counts match.
    total = len(cids)
    node_ids = server_load_balancer.get_available_nodes()

    if server_load_balancer.is_production:
        messages = [
            _message_for(node_id, server_load_balancer.nodes_to_fake_ids[node_id])
            for node_id in node_ids
        ]
        return messages, [], total

    messages: list[Message] = []
    if server_load_balancer.assign_clients_to_nodes(cids):
        # Balanced: every node gets its whole share of clients up front.
        for node_id in node_ids:
            for cid in server_load_balancer.node_id_to_client_ids[node_id]:
                messages.append(_message_for(node_id, cid))
                cids.remove(cid)
        assert len(cids) == 0, "All cids should be assigned to nodes"
        return messages, cids, total

    # Unbalanced: give each node at most one of its own sampled clients...
    sampled = set(server_state.sampled_clients)
    assignees: set[int] = set()
    for node_id in node_ids:
        for cid in server_load_balancer.node_id_to_client_ids[node_id]:
            if cid in sampled:
                messages.append(_message_for(node_id, cid))
                cids.remove(cid)
                assignees.add(node_id)
                break
    # ...then hand every node still idle the next spare client. Iterating
    # ``node_ids`` (instead of a set difference) keeps the pairing deterministic.
    for node_id in node_ids:
        if node_id not in assignees:
            messages.append(_message_for(node_id, cids.pop(0)))
    assert len(messages) == len(
        node_ids,
    ), "Number of messages should be equal to the number of nodes"
    return messages, cids, total


def find_and_send_message_to_node(
    node_id: int,
    messages: list[Message],
    server_load_balancer: ServerLoadBalancer,
    message_ids: list[str],
) -> None:
    """Push the first queued message addressed to ``node_id``, if there is one.

    ``messages`` is scanned in order for the first entry whose
    ``metadata.dst_node_id`` equals ``node_id``. That message is removed from
    ``messages`` and pushed through ``server_load_balancer.driver``, and the IDs
    returned by ``push_messages`` are appended to ``message_ids``. When no
    message targets the node, nothing is pushed and both lists stay unchanged.

    Parameters
    ----------
    node_id : int
        Destination node to look for.
    messages : list[Message]
        Queue of pending messages; the sent message is removed in place.
    server_load_balancer : ServerLoadBalancer
        Provides the driver used to push the message.
    message_ids : list[str]
        In-flight message IDs; extended in place with the newly pushed IDs.

    """
    match_index = next(
        (
            index
            for index, queued in enumerate(messages)
            if queued.metadata.dst_node_id == node_id
        ),
        None,
    )
    if match_index is None:
        return
    outgoing = messages.pop(match_index)
    pushed_ids = server_load_balancer.driver.push_messages([outgoing])
    message_ids.extend(list(pushed_ids))


def message_collaborative(
    server_load_balancer: ServerLoadBalancer,
    message_constants: tuple[str, str],
    server_state: ServerState,
    gen_ins_function: Callable[[int, int | str], dict[str, ConfigsRecordValues]],
) -> Generator[Message, None, None]:
    """Coordinate message exchange with client nodes in a collaborative manner.

    This function hands out tasks (training or evaluation) to client nodes one at
    a time. It first pushes at most one pre-built message to each available node,
    then repeatedly polls for replies. Whenever a node replies, it is given either
    its next pre-built message or a freshly constructed message for the next
    unassigned client, so nodes stay busy while work remains.

    Parameters
    ----------
    server_load_balancer : ServerLoadBalancer
        The load balancer responsible for tracking node availability and handling
        message routing between server and client nodes.
    message_constants : tuple[str, str]
        A tuple containing two strings: the first is the message type (e.g., "train"
        or "evaluate") and the second is a string identifier used in recordset keys.
    server_state : ServerState
        The current state of the server, including round number, client states, and
        other metadata needed for task construction.
    gen_ins_function : Callable[[int, int | str], dict[str, ConfigsRecordValues]]
        A function that generates instruction configurations for clients based on
        the current round and client ID.

    Yields
    ------
    Message
        Reply messages that carry content. Error replies are logged and counted
        toward completion but are not yielded.

    Notes
    -----
    The loop ends after ``total`` replies, where ``total`` is the value returned
    by ``message_collaborative_assignment``. ``gc.collect()`` is called once after
    the initial push.

    """
    log(DEBUG, "Collaborative messaging.")
    messages, cids, total = message_collaborative_assignment(
        server_state=server_state,
        server_load_balancer=server_load_balancer,
        message_constants=message_constants,
        gen_ins_function=gen_ins_function,
    )
    # IDs of messages we are still waiting on
    message_ids: list[str] = []
    # Send one message per node at a time
    for node_id in server_load_balancer.get_available_nodes():
        find_and_send_message_to_node(
            node_id=node_id,
            messages=messages,
            server_load_balancer=server_load_balancer,
            message_ids=message_ids,
        )

    # ``find_and_send_message_to_node`` pops what it sends, so ``messages`` now
    # only holds queued ones; report what was actually pushed.
    log(DEBUG, "Pushed %s messages: %s", len(message_ids), message_ids)

    gc.collect()

    received = 0
    while received < total:
        # NOTE(review): there is no pause between polls; confirm that
        # ``pull_messages`` blocks, or that busy polling is acceptable here.
        replies = list(
            server_load_balancer.driver.pull_messages(message_ids=message_ids),
        )
        for reply in replies:
            # Read routing info before yielding so the consumer cannot affect it.
            src_node_id = reply.metadata.src_node_id
            answered_id = reply.metadata.reply_to_message
            has_content = reply.has_content()
            log(
                DEBUG,
                "Got 1 %s from %s",
                "result" if has_content else "error",
                src_node_id,
            )
            received += 1
            # ``reply_to_message`` is a single ID string: compare with ``!=``,
            # since ``in`` on a string would be a substring test.
            message_ids = [
                message_id for message_id in message_ids if message_id != answered_id
            ]
            # Error replies count toward completion but are not yielded
            if has_content:
                yield reply
            if messages:
                # Prefer a pre-built message queued for this node, if any
                find_and_send_message_to_node(
                    node_id=src_node_id,
                    messages=messages,
                    server_load_balancer=server_load_balancer,
                    message_ids=message_ids,
                )
            elif cids:
                # Otherwise give the node the next unassigned client
                cid = int(cids.pop(0))
                new_message = construct_message_for_client(
                    ids=(src_node_id, cid),
                    message_constants=message_constants,
                    gen_ins_function=gen_ins_function,
                    driver=server_load_balancer.driver,
                    server_state=server_state,
                )
                message_ids.extend(
                    server_load_balancer.driver.push_messages(
                        [new_message],
                    ),
                )


def fit_or_evaluate_ins_recordset(
    msg_str: str,
    current_round: int,
    sampled_clients: list[int] | list[str],
    client_states: dict[str | int, ClientState],
    server_steps_cumulative: int,
) -> RecordSet:
    """Create a RecordSet for fit or evaluation instructions based on the message type.

    Builds a ``FitIns`` (``msg_str == "fitins"``) or ``EvaluateIns``
    (``msg_str == "evaluateins"``) with no parameters and a shared config. The
    instructions are then converted to a ``RecordSet`` with
    ``fitins_to_recordset`` / ``evaluateins_to_recordset`` (``keep_input=True``).

    Parameters
    ----------
    msg_str : str
        Either "fitins" or "evaluateins".
    current_round : int
        The current round of the federated learning process, stored under
        ``SERVER_ROUND`` in the config.
    sampled_clients : list[int] | list[str]
        Client identifiers targeted by these instructions, stored as ``str()``
        under "client_ids".
    client_states : dict[str | int, ClientState]
        State of *every* known client (not only the sampled ones), stored as
        ``str()`` of ``{client_id: asdict(state)}`` under ``CLIENT_STATES``.
    server_steps_cumulative : int
        Cumulative number of server steps, stored under
        ``SERVER_STEPS_CUMULATIVE``.

    Returns
    -------
    RecordSet
        The converted fit or evaluate instructions.

    Raises
    ------
    ValueError
        If ``msg_str`` is neither "fitins" nor "evaluateins".

    Notes
    -----
    ``parameters`` is ``Parameters(tensors=[], tensor_type=EMPTY)``: no model
    tensors travel in this recordset.

    """
    # Validate before doing any work: an ``assert`` here would be stripped
    # under ``python -O`` and let the function silently return None.
    if msg_str not in {"fitins", "evaluateins"}:
        msg = f"Unsupported msg_str {msg_str!r}; expected 'fitins' or 'evaluateins'"
        raise ValueError(msg)

    # NOTE: We always need to pass the server round to the config record.
    # The config is identical for both instruction types.
    config = {
        SERVER_ROUND: current_round,
        "client_ids": str(sampled_clients),
        CLIENT_STATES: str(
            {k: asdict(v) for k, v in client_states.items()},
        ),
        SERVER_STEPS_CUMULATIVE: server_steps_cumulative,
    }
    parameters = Parameters(tensors=[], tensor_type=EMPTY)

    if msg_str == "fitins":
        return fitins_to_recordset(
            FitIns(parameters=parameters, config=config),
            keep_input=True,
        )
    return evaluateins_to_recordset(
        EvaluateIns(parameters=parameters, config=config),
        keep_input=True,
    )
