"""Client Application for a Flower-based LLM training node.

This module defines a client-side application using the Flower framework. It handles
parameter broadcasting, training, evaluation, and communication with a node manager.
"""

import ast
import copy
import time
import uuid
import warnings
from logging import DEBUG
from typing import cast

from flwr.common import (
    Code,
    ConfigsRecord,
    Context,
    EvaluateRes,
    FitRes,
    Message,
    Parameters,
    RecordSet,
    Status,
)
from flwr.common.logger import update_console_handler
from flwr.common.recordset_compat import (
    evaluateres_to_recordset,
    fitres_to_recordset,
    recordset_to_evaluateins,
    recordset_to_fitins,
)

from repo.constants import (
    BATCH_ID,
    BROADCAST_INS,
    CHIAPPE_SODE,
    COMM_STACK,
    ENDPOINT_ID,
    EVALUATE_RESULTS,
    EXPECTED_LATENCY,
    FILE_NAME,
    FOLDER_NAME,
    GATHER_INS,
    GATHER_RES,
    MASK,
    PARAMETERS,
    QUERY,
    SERVER_ROUND,
    TYPE,
)
from repo.masks_utils import (
    combine_masks,
    generate_empty_mask,
    mask_to_batches,
)
from repo.node_manager.node_manager_app import NodeManagerApp
from repo.server.comms import (
    load_recordset_parameters_from_remote,
    offload_recordset_parameters_to_remote,
)
from repo.server.gather_utils import gather_res_to_recordset
from repo.shm.constants import NM_PARAMETERS_SHM
from repo.shm.utils import compress_with_strict, set_parameters_shm

# Configure Flower's console log handler: DEBUG level, uncolored output, with
# timestamps.
update_console_handler(level=DEBUG, colored=False, timestamps=True)

# Filter user warning from configuration of MPT.
# NOTE: `message` is a regular expression that is matched (case-insensitively)
# against the *start* of the warning text, so the trailing `*` acts as a regex
# quantifier on the preceding character, not as a glob wildcard. `append=True`
# inserts the filter at the end of the filter list, so any matching filter that is
# already registered takes precedence over it.
warnings.filterwarnings(
    action="ignore",
    category=UserWarning,
    message=("If not using a Prefix Language Model*"),
    append=True,
)
# TODO(<Anonymous>): The two DeprecationWarning filters below do not appear to take
# effect -- the cause is still unknown.
# NOTE(review): candidate causes to check -- (a) the filters are installed after
# the imports above, so warnings raised while those modules are imported are not
# covered; (b) `append=True` gives them lower priority than filters registered
# earlier -- TODO confirm.
# Filter deprecation warning from pkg_resources
warnings.filterwarnings(
    action="ignore",
    category=DeprecationWarning,
    message=("Deprecated call to *"),
    append=True,
)
warnings.filterwarnings(
    action="ignore",
    category=DeprecationWarning,
    message=("pkg_resources is deprecated*"),
    append=True,
)

# Flower ClientApp: the module-level application object on which the train,
# evaluate and query handlers below are registered, and which holds the state
# (masks, parameters, configuration) those handlers read and write.
app = NodeManagerApp()


def set_parameters(msg: Message, ctx: Context) -> Message:  # noqa: ARG001
    """Store a parameter broadcast from the server into shared memory.

    The broadcast payload may be sparse: the message carries a mask describing
    which parameters were sent, and only that subset of the round parameters is
    written. With batched communication the per-batch masks are accumulated into
    ``app.transmission_mask`` (reset to an empty mask on batch ``0``); without
    batching the received mask simply replaces it.

    Parameters
    ----------
    msg : Message
        Incoming message carrying the mask (and, when batching is enabled, the
        batch ID) in its ``BROADCAST_INS``/``COMM_STACK`` configs record.
    ctx : Context
        Execution context passed by Flower (unused).

    Returns
    -------
    Message
        Reply whose ``"broadcast"`` configs record holds ``{"status": "OK"}``.

    """
    comm_record = msg.content.configs_records[f"{BROADCAST_INS}.{COMM_STACK}"]
    # The mask travels as its textual representation and is parsed back into a
    # triple: (boolean selectors, layer names, layer types).
    received_mask: tuple[tuple[bool, ...], list[str], list[str]] = ast.literal_eval(
        str(comm_record[MASK]),
    )

    if not app.cfg.repo.comm_stack.n_batches:
        # No batching: the received mask is the whole transmission mask
        app.transmission_mask = received_mask
    else:
        # The first batch starts the accumulation from an empty mask
        if comm_record[BATCH_ID] == 0:
            app.transmission_mask = generate_empty_mask(
                layer_names_and_types=app.layer_names_and_types,
            )
        assert app.transmission_mask is not None, (
            "Transmission mask must be initialized"
        )
        # Fold the mask of this batch into what has been accumulated so far
        app.transmission_mask = combine_masks(
            mask_a=app.transmission_mask,
            mask_b=received_mask,
            layer_names_and_types=app.layer_names_and_types,
        )

    # Fetch the actual parameter payload through the communication stack
    loaded_msg, parameters = load_recordset_parameters_from_remote(
        remote_uploader_downloader=app.remote_up_down,
        incoming_message=msg,
        comm_stack=app.cfg.repo.comm_stack,
        msg_str=BROADCAST_INS,
    )

    assert app.round_parameters is not None, "Round parameters must be initialized"

    # Only the round-parameter entries selected by the mask of *this* message
    # are written
    selected_round_parameters = list(
        compress_with_strict(
            data=app.round_parameters,
            selectors=received_mask[0],
            strict=True,
        ),
    )
    set_parameters_shm(
        selected_round_parameters,
        parameters,
        parameter_pos=range(len(parameters)),
    )

    # Acknowledge the broadcast
    reply_content = RecordSet()
    reply_content.configs_records["broadcast"] = ConfigsRecord({"status": "OK"})
    return loaded_msg.create_reply(content=reply_content)


def gather_parameters(msg: Message, ctx: Context) -> Message:  # noqa: ARG001
    """Collect and return trained parameters to the server.

    This function handles requests to gather trained parameters from the node
    manager. It takes a copy of the aggregation mask, narrows it to the requested
    batch when the communication stack is configured with batches, and builds a
    reply holding the masked trained parameters. Training metrics are only
    included when the requested batch ID is not greater than zero.

    Parameters
    ----------
    msg : Message
        Incoming message requesting parameters; its ``GATHER_INS``/``COMM_STACK``
        configs record holds the requested batch ID.
    ctx : Context
        Execution context passed by Flower (unused).

    Returns
    -------
    Message
        Reply message containing the requested parameters and, for the first
        batch, training metrics.

    Raises
    ------
    AssertionError
        If ``app.trained_parameters`` has not been set.

    """
    # Obtain the aggregation mask from the NodeManagerApp. We need to copy it as we
    # may change it in-place during the construction of the message
    aggregation_mask = copy.deepcopy(app.aggregation_mask)

    # Trained parameters must be available before anything can be gathered
    assert app.trained_parameters is not None, (
        "Trained parameters must be initialized"
    )

    # Obtain which batch ID has been requested (`cast` only informs the type
    # checker; it performs no runtime conversion)
    batch_id_requested = cast(
        "int",
        msg.content.configs_records[f"{GATHER_INS}.{COMM_STACK}"][BATCH_ID],
    )
    # Don't return any train metric if this is not the first batch
    metrics_to_return = {} if batch_id_requested > 0 else app.training_metrics
    # If the communication stack is configured to use batches, split the mask
    # into per-batch masks and keep only the one that was requested
    if app.cfg.repo.comm_stack.n_batches:
        # Create chunks of the aggregation mask based on the number of batches
        batch_of_masks = mask_to_batches(
            full_mask=aggregation_mask,
            layer_names_and_types=app.layer_names_and_types,
            n_batches=app.cfg.repo.comm_stack.n_batches,
        )
        # Get the mask requested
        aggregation_mask = batch_of_masks[batch_id_requested]
    # Generate the recordset for the gather results
    recordset = gather_res_to_recordset(
        metrics=metrics_to_return,
        num_examples=app.training_samples,
        status=Status(code=Code.OK, message=CHIAPPE_SODE),
        keep_input=True,
    )
    # Describe where/how the parameters are shipped; a fresh UUID is used as the
    # endpoint ID and the mask is sent as its textual representation
    recordset.configs_records[f"{GATHER_RES}.{COMM_STACK}"] = ConfigsRecord(
        {
            ENDPOINT_ID: str(uuid.uuid4()),
            FOLDER_NAME: COMM_STACK,
            FILE_NAME: PARAMETERS,
            MASK: repr(aggregation_mask),
        },
    )
    # Add the parameters to the message; only the trained parameters selected by
    # the (possibly per-batch) mask are offloaded. The second returned value is
    # not needed here.
    gather_recordset, _c_list_of_ray_object_refs = (
        offload_recordset_parameters_to_remote(
            parameters=list(
                compress_with_strict(
                    data=app.trained_parameters,
                    selectors=aggregation_mask[0],
                    strict=True,
                ),
            ),
            remote_uploader_downloader=app.remote_up_down,
            outgoing_recordset=recordset,
            comm_stack=app.cfg.repo.comm_stack,
            msg_str=GATHER_RES,
        )
    )
    # Return the response message
    return msg.create_reply(content=gather_recordset)


def free_resources(msg: Message, ctx: Context) -> Message:  # noqa: ARG001
    """Acknowledge a free-resources query for this client node.

    Nothing is released inside this function itself: it only builds a reply whose
    content is the content of the incoming message.

    Parameters
    ----------
    msg : Message
        Incoming message requesting to free resources.
    ctx : Context
        Execution context passed by Flower (unused).

    Returns
    -------
    Message
        Reply message echoing the content of the request.

    """
    echoed_content = msg.content
    return msg.create_reply(content=echoed_content)


@app.train()
def train(msg: Message, ctx: Context) -> Message:  # noqa: ARG001
    """Process training messages from the server.

    This function handles training requests by retrieving instructions from the
    incoming message, refreshing workers if needed, executing the training process,
    and returning an (empty) success result. The reply's time-to-live is set to
    ``EXPECTED_LATENCY`` plus the number of whole seconds spent in this handler.

    Parameters
    ----------
    msg : Message
        Incoming training message from the server containing configurations.
    ctx : Context
        Execution context passed by Flower (unused).

    Returns
    -------
    Message
        Reply message confirming training execution with updated time-to-live.
        It carries no parameters, no metrics and zero examples.

    Raises
    ------
    AssertionError
        If the server round is missing from the fit configuration.

    """
    # Measure elapsed time with a monotonic clock: unlike the wall clock, it cannot
    # jump backwards or forwards if the system time is adjusted during training
    start_time = time.monotonic()
    # Retrieve the training instructions from the message
    fitins = recordset_to_fitins(msg.content, keep_input=False)
    config = fitins.config
    assert SERVER_ROUND in config, "Server round must be in the config"
    # Refresh the workers when (server round + 1) is a multiple of the refresh
    # period
    if (int(config[SERVER_ROUND]) + 1) % app.refresh_period == 0:
        app.close_workers()
        app.create_and_start_workers()
    # Execute the training process
    app.fit(
        configs=msg.content.configs_records,
    )
    # Compose the success message; it is intentionally empty (no tensors, metrics
    # or examples)
    status = Status(code=Code.OK, message=CHIAPPE_SODE)
    fitres = FitRes(
        status=status,
        parameters=Parameters(tensors=[], tensor_type="empty"),
        metrics={},
        num_examples=0,
    )
    recordset = fitres_to_recordset(fitres, keep_input=False)
    # Extend the expected latency by the time spent handling this request
    ttl = EXPECTED_LATENCY + int(time.monotonic() - start_time)
    # Return the success message
    return msg.create_reply(
        content=recordset,
        ttl=ttl,
    )


@app.evaluate()
def evaluate(msg: Message, ctx: Context) -> Message:  # noqa: ARG001
    """Run an evaluation requested by the server and report its results.

    The evaluation instructions are decoded from the incoming message, the
    message's configs records are forwarded to ``app.eval``, and the returned
    loss, number of examples and metrics are packed into the reply. The reply
    also carries a communication-stack configs record whose endpoint ID is the
    node manager UUID concatenated with ``NM_PARAMETERS_SHM``.

    Parameters
    ----------
    msg : Message
        Incoming evaluation message from the server containing configurations.
    ctx : Context
        Execution context passed by Flower (unused).

    Returns
    -------
    Message
        Reply message holding the loss, the metrics and the number of examples
        processed.

    """
    # Decode the instructions and make sure the server round was provided
    eval_config = recordset_to_evaluateins(msg.content, keep_input=False).config
    assert SERVER_ROUND in eval_config, "Server round must be in the config"

    # Run the evaluation
    eval_loss, n_examples, eval_metrics = app.eval(
        configs=msg.content.configs_records,
    )

    # Pack the results into a recordset
    result = EvaluateRes(
        status=Status(code=Code.OK, message=CHIAPPE_SODE),
        loss=eval_loss,
        metrics=eval_metrics,
        num_examples=n_examples,
    )
    reply_content = evaluateres_to_recordset(result)

    # Attach the communication-stack description to the reply
    comm_stack_key = f"{EVALUATE_RESULTS}.{COMM_STACK}"
    reply_content.configs_records[comm_stack_key] = ConfigsRecord(
        {
            ENDPOINT_ID: app.node_manager_uuid + NM_PARAMETERS_SHM,
            FOLDER_NAME: COMM_STACK,
            FILE_NAME: PARAMETERS,
        },
    )
    return msg.create_reply(reply_content)


@app.query()
def query(msg: Message, ctx: Context) -> Message:
    """Dispatch a query message from the server to its handler.

    The query type is read from the ``QUERY`` configs record of the message and
    compared against the supported types: ``'broadcast_parameters'``,
    ``'gather_parameters'`` and ``'free_resources'``.

    Parameters
    ----------
    msg : Message
        Incoming query message from the server.
    ctx : Context
        Execution context passed by Flower.

    Returns
    -------
    Message
        Reply message produced by the handler of the query type.

    Raises
    ------
    ValueError
        If an unknown query type is received.

    """
    records = msg.content.configs_records
    assert QUERY in records, "Query message must contain 'query' key"
    query_type = records[QUERY][TYPE]

    # Guard clauses: each supported query type returns straight from its handler
    if query_type == "broadcast_parameters":
        return set_parameters(msg=msg, ctx=ctx)
    if query_type == "gather_parameters":
        return gather_parameters(msg=msg, ctx=ctx)
    if query_type == "free_resources":
        return free_resources(msg=msg, ctx=ctx)

    # Anything else is not a query this node knows how to serve
    error_msg = f"Unknown query_type: {query_type!s}."
    raise ValueError(error_msg)
