"""Utility functions for broadcasting to clients on main server loop in flwr next."""

import time
import uuid
from logging import DEBUG, WARNING

from flwr.common import (
    Code,
    ConfigsRecord,
    Message,
    MessageType,
    Parameters,
    RecordSet,
    log,
)
from flwr.common.recordset_compat import parameters_to_parametersrecord

from repo.conf.base_schema import CommStack
from repo.constants import (
    BATCH_ID,
    BROADCAST,
    BROADCAST_INS,
    BROADCAST_MSG,
    BROADCAST_P,
    CODE,
    COMM_STACK,
    ENDPOINT_ID,
    FILE_NAME,
    FOLDER_NAME,
    MASK,
    MESSAGE,
    OK,
    PARAMETERS,
    repo_TTL,
    QUERY,
    STATUS,
    TYPE,
)
from repo.masks_utils import generate_full_mask, generate_mask, mask_to_batches
from repo.server.comms import (
    offload_recordset_parameters_to_remote,
)
from repo.server.load_balancer import ServerLoadBalancer
from repo.server.s3_utils import delete_past_communication_states
from repo.server.utils import ServerState
from repo.shm.utils import compress_with_strict
from ray import ObjectRef


def finalize_broadcast_message(
    msg_str: str,
    server_load_balancer: ServerLoadBalancer,
    server_state: ServerState,
    comm_stack: CommStack,
    transmission_mask: tuple[tuple[bool, ...], list[str], list[str]],
) -> tuple[list[Message], list[ObjectRef] | None]:
    """Create messages for broadcasting parameters to all registered nodes.

    The transmission mask is split into batches when ``comm_stack.n_batches`` is
    truthy. Otherwise the whole mask is treated as a single batch with id 0. For
    every batch, a template message is built from ``get_broadcast_recordset``. The
    parameters selected by the batch's boolean selectors (``mask[0]``) are attached
    via ``offload_recordset_parameters_to_remote``. The template content is then
    copied into one message per registered node.

    Parameters
    ----------
    msg_str : str
        Identifier string used for record keys
    server_load_balancer : ServerLoadBalancer
        Load balancer managing node registration and messaging
    server_state : ServerState
        Server state containing model parameters and configuration
    comm_stack : CommStack
        Communication stack configuration controlling transmission behavior
    transmission_mask : tuple[tuple[bool, ...], list[str], list[str]]
        Tuple containing (boolean selectors, layer names, layer types) that
        specifies which parameters to transmit

    Returns
    -------
    tuple[list[Message], list[ObjectRef] | None]
        The messages to push (batch-major: all nodes for batch 0, then batch 1,
        ...). Also the Ray object references created while offloading, or
        ``None`` when none were created.

    """
    # A non-batched broadcast is just one batch carrying the whole mask. Sharing
    # the code path guarantees both modes select parameters with the same
    # selectors that are advertised in the recordset's MASK entry.
    batch_of_masks = (
        mask_to_batches(
            full_mask=transmission_mask,
            layer_names_and_types=server_state.layer_names_and_types,
            n_batches=comm_stack.n_batches,
        )
        if comm_stack.n_batches
        else [transmission_mask]
    )
    group_id = str(server_state.current_round)
    # Snapshot the registered nodes once so every node receives every batch
    node_ids = list(server_load_balancer.get_registered_nodes())

    ray_object_refs: list[ObjectRef] = []
    list_of_messages: list[Message] = []
    for batch_id, mask in enumerate(batch_of_masks):
        # Shared template for this batch; dst_node_id=0 is a placeholder, the
        # content is re-addressed per node below
        template = server_load_balancer.driver.create_message(
            content=get_broadcast_recordset(
                msg_str=msg_str,
                transmission_mask=mask,
                batch_id=batch_id,
            ),
            message_type=MessageType.QUERY,
            dst_node_id=0,
            group_id=group_id,
            ttl=repo_TTL,
        )
        # Attach only the parameters selected by this batch's selectors
        template.content, batch_refs = offload_recordset_parameters_to_remote(
            parameters=list(
                compress_with_strict(
                    data=server_state.model_states,
                    selectors=mask[0],
                    strict=True,
                ),
            ),
            remote_uploader_downloader=server_state.remote_up_down,
            outgoing_recordset=template.content,
            comm_stack=comm_stack,
            msg_str=msg_str,
        )
        list_of_messages.extend(
            server_load_balancer.driver.create_message(
                content=template.content,
                message_type=MessageType.QUERY,
                dst_node_id=node_id,
                group_id=group_id,
                ttl=repo_TTL,
            )
            for node_id in node_ids
        )
        if batch_refs:
            ray_object_refs.extend(batch_refs)
    return list_of_messages, ray_object_refs or None


def get_broadcast_recordset(
    msg_str: str,
    transmission_mask: tuple[tuple[bool, ...], list[str], list[str]],
    batch_id: int = 0,
) -> RecordSet:
    """Build the RecordSet template used for one parameter-broadcast message.

    The returned RecordSet contains:

    - an empty parameters record under ``f"{BROADCAST_INS}.{PARAMETERS}"``,
      to be filled with the actual parameters later;
    - a ``QUERY`` configs record tagging the message type as ``BROADCAST_P``;
    - a ``f"{msg_str}.{STATUS}"`` configs record with an OK code and
      ``BROADCAST_MSG``;
    - a ``f"{msg_str}.{COMM_STACK}"`` configs record with a freshly generated
      endpoint UUID, the storage folder/file names, the ``repr`` of the mask and
      the batch id.

    Parameters
    ----------
    msg_str : str
        Identifier string used as the prefix of the status and comm-stack keys
    transmission_mask : tuple[tuple[bool, ...], list[str], list[str]]
        Tuple containing (boolean selectors, layer names, layer types) that
        specifies which parameters are transmitted; stored as its ``repr``
    batch_id : int
        Identifier for the batch when using batched transmission (default: 0)

    Returns
    -------
    RecordSet
        A fully configured RecordSet ready for broadcast

    """
    placeholder = parameters_to_parametersrecord(
        Parameters(tensors=[], tensor_type="empty"),
        keep_input=True,
    )
    query_config = ConfigsRecord({TYPE: BROADCAST_P})
    status_config = ConfigsRecord(
        {
            CODE: int(Code.OK.value),
            MESSAGE: BROADCAST_MSG,
        },
    )
    # Each call draws a new random endpoint id
    comm_stack_config = ConfigsRecord(
        {
            ENDPOINT_ID: str(uuid.uuid4()),
            FOLDER_NAME: COMM_STACK,
            FILE_NAME: PARAMETERS,
            MASK: repr(transmission_mask),
            BATCH_ID: batch_id,
        },
    )

    recordset = RecordSet()
    recordset.parameters_records[f"{BROADCAST_INS}.{PARAMETERS}"] = placeholder
    recordset.configs_records[QUERY] = query_config
    recordset.configs_records[f"{msg_str}.{STATUS}"] = status_config
    recordset.configs_records[f"{msg_str}.{COMM_STACK}"] = comm_stack_config
    return recordset


def broadcast_parameters_to_nodes(
    server_load_balancer: ServerLoadBalancer,
    server_state: ServerState,
    comm_stack: CommStack,
    *,
    pull_interval: float = 0.1,
) -> None:
    """Broadcast model parameters to all registered nodes in the federated system.

    The function proceeds in these steps:

    1. Choose a transmission mask: the full mask when the sampled clients differ
       from the previously sampled ones, otherwise the scheduler-generated mask.
    2. Build the per-node messages via ``finalize_broadcast_message`` and push
       them one at a time.
    3. Poll until one reply per pushed message has arrived, then queue any Ray
       object references for garbage collection.
    4. Validate the reply statuses and, when S3 is enabled, delete past
       communication states.

    Nodes are marked unavailable (``set_all_nodes_states(state=False)``) during
    validation and cleanup, and marked available again afterwards.

    Parameters
    ----------
    server_load_balancer : ServerLoadBalancer
        Load balancer managing node registration and messaging
    server_state : ServerState
        Server state containing model parameters and configuration
    comm_stack : CommStack
        Communication stack configuration controlling transmission behavior
    pull_interval : float
        Seconds to sleep between ``pull_messages`` polls while replies are still
        outstanding (default: 0.1). This keeps the wait loop from spinning.

    Raises
    ------
    ValueError
        If a reply lacks the broadcast/status keys, or reports a non-OK status.

    """
    # Construct transmission mask
    # NOTE: Taking one client as representative since we assume all clients have
    # the same mask. In the future we may have more complex masks and schedulers
    if set(server_state.sampled_clients) != set(
        server_state.previously_sampled_clients,
    ):
        transmission_mask = generate_full_mask(server_state.layer_names_and_types)
    else:
        transmission_mask = generate_mask(
            server_state.layer_names_and_types,
            server_state.transmission_scheduler,
            server_state.sampled_clients[-1],
            # NOTE: The transmission scheduler must look backward in time to the
            # last round's masks
            server_state.current_round - 1,
        )
    # Message name
    msg_str = f"{BROADCAST_INS}"
    # Create a message for each node
    messages, list_of_ray_object_refs = finalize_broadcast_message(
        msg_str=msg_str,
        server_load_balancer=server_load_balancer,
        server_state=server_state,
        comm_stack=comm_stack,
        transmission_mask=transmission_mask,
    )
    log(
        DEBUG,
        "Pushing %s broadcast messages to %s nodes",
        len(messages),
        len(server_load_balancer.get_registered_nodes()),
    )
    message_ids: list[str] = []
    # NOTE: We must send over a message at a time to avoid the Driver to pack them
    # together thus creating problems with the ~500MB limit of gRPC
    for message in messages:
        message_ids.extend(server_load_balancer.driver.push_messages([message]))
    log(
        DEBUG,
        "Pushed %s broadcast messages to %s nodes",
        len(messages),
        len(server_load_balancer.get_registered_nodes()),
    )
    # Wait for results, ignore empty message_ids
    message_ids = [message_id for message_id in message_ids if message_id]
    all_replies: list[Message] = []
    while True:
        all_replies += server_load_balancer.driver.pull_messages(
            message_ids=message_ids,
        )
        # ``>=`` so an unexpected surplus reply cannot make the loop run forever
        if len(all_replies) >= len(message_ids):
            break
        # Back off between polls instead of hammering the driver in a busy loop
        time.sleep(pull_interval)
    # Append list_of_ray_object_refs to the garbage collector queue
    if list_of_ray_object_refs:
        assert server_state.ray_garbage_queue is not None, (
            "Ray garbage queue is None. "
            "Please check if the server state is initialized correctly."
        )
        for obj_ref in list_of_ray_object_refs:
            server_state.ray_garbage_queue.put(obj_ref)
    # Filter correct results
    messages_with_content = [msg for msg in all_replies if msg.has_content()]
    n_skipped = len(all_replies) - len(messages_with_content)
    if n_skipped:
        # Replies without content (e.g. error replies) are not validated; make
        # that visible instead of dropping them silently
        log(
            WARNING,
            "Skipping %s broadcast replies without content",
            n_skipped,
        )
    # Elaborate results
    server_load_balancer.set_all_nodes_states(state=False)
    for message_with_content in messages_with_content:
        configs = message_with_content.content.configs_records
        # Explicit raises (not asserts) so validation survives ``python -O`` and
        # matches the documented ValueError contract
        if BROADCAST not in configs:
            msg = "Broadcast key not found"
            raise ValueError(msg)
        if STATUS not in configs[BROADCAST]:
            msg = "Status key not found"
            raise ValueError(msg)
        if configs[BROADCAST][STATUS] != OK:
            msg = "Broadcast failed"
            raise ValueError(msg)
    if comm_stack.s3:
        assert server_state.remote_up_down is not None, (
            "Remote uploader/downloader is None. "
            "Please check if the server state is initialized correctly."
        )
        # Clean up all the past arrays stored in the S3 as they are stale
        delete_past_communication_states(
            bucket_name=server_state.remote_up_down.remote_bucket_name,
            prefix=server_state.remote_up_down.backend_kwargs["prefix"],
        )
    server_load_balancer.set_all_nodes_states(state=True)
    log(DEBUG, "Received %s correct results", len(messages_with_content))
