"""Communication utilities for efficient model parameter transfer in federated training.

This module implements optimized communication methods for transferring large model
parameters between distributed components using different backends. It provides
functionality to offload parameters from recordsets to remote storage and load them back
as needed, supporting three communication methods:

1. S3 object storage: Parameters are serialized, stored in S3, and replaced with
   references in the recordset.
2. Shared Memory (shm): Parameters are stored in shared memory accessible by
   multiple processes on the same machine.
3. Ray object store: Parameters are stored in Ray's distributed object store
   with optimized conversions.

The module optimizes large parameter transfers by:
- Processing arrays in batches to balance parallelism with overhead
- Converting between precision formats (float16/float32) as appropriate
- Optimizing serialization operations
- Using vectorized operations where possible

These optimizations help reduce memory usage, network transfer times, and overall
communication overhead in distributed federated learning workflows.

Functions:
    offload_recordset_parameters_to_remote: Move parameters from recordset to remote
        storage
    load_recordset_parameters_from_remote: Retrieve parameters from remote storage into
        recordset
"""

import ast
import time
from copy import deepcopy
from logging import DEBUG, ERROR
from multiprocessing.shared_memory import SharedMemory
from pathlib import Path
from queue import Queue
from tempfile import TemporaryDirectory

import numpy as np
from composer.loggers import RemoteUploaderDownloader
from composer.utils.file_helpers import (
    validate_given_remote_path,
)
from flwr.common import (
    Code,
    ConfigsRecord,
    Message,
    NDArrays,
    Parameters,
    RecordSet,
    log,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.common.recordset_compat import (
    _extract_status_from_recordset,  # noqa: PLC2701
    parameters_to_parametersrecord,
    parametersrecord_to_parameters,
)

import ray
from repo.conf.base_schema import CommStack
from repo.constants import COMM_STACK, ENDPOINT_ID, PARAMETERS
from repo.file_utils import (
    download_file_from_s3,
    dump_model_parameters_to_file,
    upload_file_to_s3,
)
from repo.server.s3_utils import (
    extract_s3_comm_config_from_configrecord,
)
from repo.shm.utils import (
    ModelParametersMetadata,
    get_parameters_shm,
    is_shm_existing,
    set_parameters_shm,
)
from repo.utils import (
    batch_convert_and_put_float16,
    convert_to_float32,
    load_model_parameters_from_file,
)


def offload_to_s3(
    parameters: NDArrays,
    remote_uploader_downloader: RemoteUploaderDownloader,
    recordset: RecordSet,
    msg_str: str,
) -> RecordSet:
    """Upload model parameters to S3 and blank them out in the recordset.

    The parameters are written to a ``.npz`` file inside a temporary directory,
    handed to ``upload_file_to_s3``, and the ``"{msg_str}.parameters"`` record of
    the recordset is replaced with an empty placeholder. The function then blocks,
    re-checking every 0.5 seconds, until ``validate_given_remote_path`` returns a
    truthy value for either the ``.bin`` or the ``.npz`` object at the expected
    remote location. The wait has no timeout.

    Parameters
    ----------
    parameters : NDArrays
        The model parameters to upload.
    remote_uploader_downloader : RemoteUploaderDownloader
        Uploader used for the transfer; its ``remote_bucket_name`` and
        ``backend_kwargs["prefix"]`` are read to build the remote path to poll.
    recordset : RecordSet
        Recordset holding the ``"{msg_str}.{COMM_STACK}"`` configuration record.
        It is modified in place.
    msg_str : str
        Prefix used for the recordset keys.

    Returns
    -------
    RecordSet
        The same recordset object, with its parameters record replaced by an empty
        placeholder.

    """
    # The temporary directory (and the file staged inside it) stays alive until the
    # remote object has been observed, because the polling loop runs inside it
    with TemporaryDirectory() as staging_dir:
        # Endpoint, file and folder names all come from the communication config
        comm_config = recordset.configs_records[f"{msg_str}.{COMM_STACK}"]
        endpoint_id, file_name, folder_name = extract_s3_comm_config_from_configrecord(
            comm_config,
        )
        # Remote object key without its extension, shared by upload and polling
        remote_stem = f"{folder_name}/{endpoint_id}/{file_name}"

        # Serialize the parameters to a local file and push that file to S3
        staged_file = Path(staging_dir) / f"{endpoint_id}_{file_name}.npz"
        dump_model_parameters_to_file(staged_file, parameters)
        upload_file_to_s3(
            remote_up_down=remote_uploader_downloader,
            remote_file_name=f"{remote_stem}.npz",
            local_file_name=staged_file,
        )

        # The payload now lives remotely: leave only an empty placeholder behind
        empty_record = parameters_to_parametersrecord(
            Parameters(tensors=[], tensor_type="empty"),
            keep_input=False,
        )
        recordset.parameters_records[f"{msg_str}.parameters"] = empty_record

        # Block until the uploaded object can be seen at its remote location
        bucket = remote_uploader_downloader.remote_bucket_name
        prefix = remote_uploader_downloader.backend_kwargs["prefix"]
        remote_uri_stem = f"s3://{bucket}/{prefix}/{remote_stem}"
        while True:
            if validate_given_remote_path(f"{remote_uri_stem}.bin"):
                break
            if validate_given_remote_path(f"{remote_uri_stem}.npz"):
                break
            time.sleep(0.5)
        log(DEBUG, "Node %s parameters have been pushed to the S3", endpoint_id)
    return recordset


def offload_to_shm(
    parameters: NDArrays,
    recordset: RecordSet,
    msg_str: str,
) -> RecordSet:
    """Offload model parameters to shared memory.

    The parameters are written into a shared memory segment named after the
    ``ENDPOINT_ID`` entry of the ``"{msg_str}.{COMM_STACK}"`` configuration record.
    The segment is created when ``is_shm_existing`` reports it missing and reused
    otherwise. If ``get_parameters_shm`` raises ``TypeError`` (treated here as
    "existing segment too small", as the log message states), the segment is
    unlinked and created again. Afterwards the ``"{msg_str}.parameters"`` record is
    replaced by an empty placeholder and a stringified copy of the parameters
    metadata is stored under ``"{msg_str}.parameters_metadata"``.

    Parameters
    ----------
    parameters : NDArrays
        The model parameters to offload to shared memory.
    recordset : RecordSet
        The recordset containing the message configuration and parameters. It is
        modified in place.
    msg_str : str
        A string identifier used to prefix keys in the recordset.

    Returns
    -------
    RecordSet
        The same recordset object with emptied parameters and the serialized
        metadata needed to read the shared memory back.

    """
    # Get parameters metadata
    parameters_metadata = ModelParametersMetadata.from_ndarrays(parameters)
    # The shared memory segment is named after the endpoint id
    shm_name = str(
        recordset.configs_records[f"{msg_str}.{COMM_STACK}"][ENDPOINT_ID],
    )
    # NOTE: `_shm_parameters_sh` is never read, but binding it keeps the value
    # returned alongside the arrays referenced for the rest of this function
    try:
        shm_parameters, _shm_parameters_sh = get_parameters_shm(
            parameters_metadata=parameters_metadata,
            create=not is_shm_existing(shm_name),
            name=shm_name,
        )
    except TypeError:
        log(DEBUG, "offload_to_shm: Shared memory too small, replacing it.")
        try:
            stale_shm = SharedMemory(name=shm_name)
            stale_shm.close()
            stale_shm.unlink()
        except FileNotFoundError:
            # The segment is already gone, which is the state we want; matching on
            # the exception type avoids depending on the OS error message text
            pass
        except Exception as e:  # noqa: BLE001
            # Best effort: report the failure and still try to recreate below
            log(
                ERROR,
                "Removing Shared Memory %s failed because of %s",
                shm_name,
                e,
            )
        shm_parameters, _shm_parameters_sh = get_parameters_shm(
            parameters_metadata=parameters_metadata,
            create=True,
            name=shm_name,
        )
    # Set the parameters in the shared memory
    set_parameters_shm(
        old_parameters_sh=shm_parameters,
        new_parameters=parameters,
        parameter_pos=range(len(parameters)),
    )
    # Empty the recordset parameters
    recordset.parameters_records[f"{msg_str}.parameters"] = (
        parameters_to_parametersrecord(
            Parameters(tensors=[], tensor_type="empty"),
            keep_input=False,
        )
    )
    # Serialize the metadata from a shallow copy: `__dict__` is the live attribute
    # dictionary, so editing it directly would overwrite the object's dtypes
    parameters_metadata_dict = dict(parameters_metadata.__dict__)
    parameters_metadata_dict["dtypes"] = [
        str(dtype) for dtype in parameters_metadata_dict["dtypes"]
    ]
    recordset.configs_records[f"{msg_str}.parameters_metadata"] = ConfigsRecord(
        {"metadata": str(parameters_metadata_dict)},
    )
    return recordset


def offload_to_ray(
    parameters: NDArrays,
    recordset: RecordSet,
    msg_str: str,
) -> tuple[RecordSet, list[ray.ObjectRef] | None]:
    """Hand model parameters to Ray and leave only references in the recordset.

    One ``batch_convert_and_put_float16`` remote task is launched per array (judging
    by its name, the helper converts to float16 and puts the result in the object
    store). All tasks are submitted before any result is awaited. The value each
    task returns is collected, in input order, as the reference for that array. The
    collected list is pickled with ``ray.cloudpickle`` into the
    ``"{msg_str}.ray"`` configuration record, and the ``"{msg_str}.parameters"``
    record is replaced by an empty placeholder.

    Parameters
    ----------
    parameters : NDArrays
        The model parameters to offload.
    recordset : RecordSet
        The recordset to update. It is modified in place.
    msg_str : str
        Prefix used for the recordset keys.

    Returns
    -------
    tuple[RecordSet, list[ray.ObjectRef] | None]
        The same recordset object and the list of values returned by the remote
        tasks, one per input array. This function always returns a list.

    """
    # Submit every task up front so they can run concurrently
    pending_tasks = [
        batch_convert_and_put_float16.remote(array) for array in parameters
    ]
    # Resolve the tasks in submission order, which keeps the parameter ordering
    ray_object_refs = [ray.get(task) for task in pending_tasks]

    # Store the pickled references so the receiving side can rebuild them
    pickled_refs = ray.cloudpickle.dumps(ray_object_refs)  # type: ignore[reportAttributeAccessIssue]
    recordset.configs_records[f"{msg_str}.ray"] = ConfigsRecord(
        {"object_refs": pickled_refs},
    )

    # The payload is now referenced remotely: keep only an empty placeholder
    empty_record = parameters_to_parametersrecord(
        Parameters(tensors=[], tensor_type="empty"),
        keep_input=False,
    )
    recordset.parameters_records[f"{msg_str}.parameters"] = empty_record
    return recordset, ray_object_refs


def offload_recordset_parameters_to_remote(
    parameters: NDArrays,
    remote_uploader_downloader: RemoteUploaderDownloader | None,
    outgoing_recordset: RecordSet,
    msg_str: str,
    comm_stack: CommStack,
) -> tuple[RecordSet, list[ray.ObjectRef] | None]:
    """Dispatch parameter offloading to the configured communication backend.

    Backends are tried in a fixed order and the first match wins:

    1. S3, when ``comm_stack.s3`` is set and an uploader was supplied
    2. shared memory, when ``comm_stack.shm`` is set
    3. Ray, when ``comm_stack.ray`` is set
    4. otherwise the parameters are stored inline in the recordset under
       ``"{msg_str}.{PARAMETERS}"``

    When ``comm_stack.s3`` is set but ``remote_uploader_downloader`` is ``None``,
    the S3 branch is skipped and the remaining options are considered.

    Parameters
    ----------
    parameters : NDArrays
        The model parameters to offload.
    remote_uploader_downloader : RemoteUploaderDownloader | None
        The uploader/downloader used for S3, or None if S3 is not used.
    outgoing_recordset : RecordSet
        The recordset to update.
    msg_str : str
        Prefix used for the recordset keys.
    comm_stack : CommStack
        Configuration whose ``s3``, ``shm`` and ``ray`` flags select the backend.

    Returns
    -------
    tuple[RecordSet, list[ray.ObjectRef] | None]
        The updated recordset and, for the Ray backend only, the object references
        returned by ``offload_to_ray``. The second element is None for every other
        backend.

    """
    # S3 needs both the flag and an uploader instance to be usable
    if comm_stack.s3 and remote_uploader_downloader is not None:
        updated_recordset = offload_to_s3(
            parameters,
            remote_uploader_downloader,
            outgoing_recordset,
            msg_str,
        )
        return updated_recordset, None

    if comm_stack.shm:
        return offload_to_shm(parameters, outgoing_recordset, msg_str), None

    if comm_stack.ray:
        # Only this backend produces object references for the caller to keep
        updated_recordset, object_refs = offload_to_ray(
            parameters,
            outgoing_recordset,
            msg_str,
        )
        return updated_recordset, object_refs

    # No remote backend selected: ship the parameters inside the recordset itself
    inline_record = parameters_to_parametersrecord(
        ndarrays_to_parameters(parameters),
        keep_input=True,
    )
    outgoing_recordset.parameters_records[f"{msg_str}.{PARAMETERS}"] = inline_record
    return outgoing_recordset, None


def load_from_s3_parameters(
    remote_uploader_downloader: RemoteUploaderDownloader,
    recordset: RecordSet,
    msg_str: str,
) -> tuple[RecordSet, NDArrays]:
    """Load model parameters from S3 object storage.

    The remote location is derived from the ``"{msg_str}.{COMM_STACK}"``
    configuration record together with the uploader's bucket name and prefix. The
    function blocks, re-checking every 0.5 seconds, until
    ``validate_given_remote_path`` returns a truthy value for either the ``.bin``
    or the ``.npz`` object; ``.bin`` is preferred when both are present. The wait
    has no timeout. The object is then downloaded into a temporary directory and
    read with ``load_model_parameters_from_file``.

    The extension is decided exactly once, at the moment the object is first
    observed, so the local file suffix and the remote key always agree and no
    additional remote checks are issued after the wait.

    Parameters
    ----------
    remote_uploader_downloader : RemoteUploaderDownloader
        The uploader/downloader instance for interacting with S3.
    recordset : RecordSet
        The recordset containing the S3 communication configuration.
    msg_str : str
        A string identifier used to prefix keys in the recordset.

    Returns
    -------
    tuple[RecordSet, NDArrays]
        The unmodified recordset and the retrieved model parameters.

    """
    # Create a temporary directory for staging the downloaded parameters
    with TemporaryDirectory() as temp_dir:
        # Extract endpoint id, file and folder names from the message content
        endpoint_id, file_name, folder_name = extract_s3_comm_config_from_configrecord(
            recordset.configs_records[f"{msg_str}.{COMM_STACK}"],
        )
        remote_key_no_ext = f"{folder_name}/{endpoint_id}/{file_name}"
        remote_file_name_no_ext = (
            f"s3://{remote_uploader_downloader.remote_bucket_name}/"
            f"{remote_uploader_downloader.backend_kwargs['prefix']}/"
            f"{remote_key_no_ext}"
        )
        # Wait for the parameters to show up remotely and remember which format
        # was found, so the choice cannot change between later steps
        extension: str | None = None
        while extension is None:
            if validate_given_remote_path(remote_file_name_no_ext + ".bin"):
                extension = ".bin"
            elif validate_given_remote_path(remote_file_name_no_ext + ".npz"):
                extension = ".npz"
            else:
                time.sleep(0.5)
        # Local and remote names share the extension that was observed
        local_file_name = Path(temp_dir) / f"tmp-{endpoint_id}{extension}"
        download_file_from_s3(
            remote_up_down=remote_uploader_downloader,
            remote_file_name=f"{remote_key_no_ext}{extension}",
            local_file_name=local_file_name,
        )
        # Read the file while the temporary directory still exists
        ndarrays = load_model_parameters_from_file(local_file_name)
        log(
            DEBUG,
            "Node %s parameters have been read from disk and assigned to the Message",
            endpoint_id,
        )
        return recordset, ndarrays


def load_from_shm_parameters(
    recordset: RecordSet,
    msg_str: str,
) -> tuple[RecordSet, NDArrays]:
    """Load model parameters from shared memory.

    The metadata string stored under ``"{msg_str}.parameters_metadata"`` is parsed
    with ``ast.literal_eval``, its dtype names are turned back into ``np.dtype``
    objects, and the result is used to open the shared memory segment named after
    the ``ENDPOINT_ID`` entry of the ``"{msg_str}.{COMM_STACK}"`` record. A deep
    copy of the arrays obtained from ``get_parameters_shm`` is returned so that
    they stay valid after this function returns.

    If ``get_parameters_shm`` raises ``TypeError`` (treated here as "segment too
    small", as the log message states), the segment is unlinked and created again
    before the copy is taken.

    Parameters
    ----------
    recordset : RecordSet
        The recordset containing the shared memory name and parameters metadata.
    msg_str : str
        A string identifier used to prefix keys in the recordset.

    Returns
    -------
    tuple[RecordSet, NDArrays]
        The unmodified recordset and a deep copy of the model parameters.

    """
    # Rebuild the parameters metadata from its string representation
    parameters_metadata_dict = ast.literal_eval(
        str(
            recordset.configs_records[f"{msg_str}.parameters_metadata"]["metadata"],
        ),
    )
    parameters_metadata_dict["dtypes"] = [
        np.dtype(dtype_name) for dtype_name in parameters_metadata_dict["dtypes"]
    ]
    # Built once and shared by both the attach and the recreate attempts
    parameters_metadata = ModelParametersMetadata(**parameters_metadata_dict)
    # The shared memory segment is named after the endpoint id
    shm_name = str(
        recordset.configs_records[f"{msg_str}.{COMM_STACK}"][ENDPOINT_ID],
    )
    try:
        shm_parameters, _shm_parameters_sh = get_parameters_shm(
            parameters_metadata=parameters_metadata,
            create=False,
            name=shm_name,
        )
    except TypeError:
        log(DEBUG, "load_from_shm_parameters: Shared memory too small, replacing it.")
        try:
            stale_shm = SharedMemory(name=shm_name)
            stale_shm.close()
            stale_shm.unlink()
        except FileNotFoundError:
            # The segment is already gone, which is the state we want; matching on
            # the exception type avoids depending on the OS error message text
            pass
        except Exception as e:  # noqa: BLE001
            # Best effort: report the failure and still try to recreate below
            log(
                ERROR,
                "Removing Shared Memory %s failed because of %s",
                shm_name,
                e,
            )
        # NOTE(review): nothing in this function writes into the recreated segment,
        # so the arrays copied below presumably do not hold the sender's values on
        # this path - TODO confirm this fallback is intended
        shm_parameters, _shm_parameters_sh = get_parameters_shm(
            parameters_metadata=parameters_metadata,
            create=True,
            name=shm_name,
        )
    # NOTE: We need to copy since we will lose the reference to the shared memory once
    # this function returns, making invalid all the buffers backing the ndarrays
    return recordset, deepcopy(shm_parameters)


def load_from_ray_parameters(
    recordset: RecordSet,
    msg_str: str,
    ray_gc_queue: Queue[ray.ObjectRef] | None = None,
) -> tuple[RecordSet, NDArrays]:
    """Fetch model parameters referenced through Ray and restore their order.

    The pickled object references stored under ``"{msg_str}.ray"`` are unpickled
    and consumed in batches of at most 20. For each batch the function waits for
    the references with ``ray.wait``, fetches the objects, runs every object
    through a ``convert_to_float32`` remote task and records the converted array
    under the position its reference had in the unpickled list. When a queue is
    supplied, each consumed reference is pushed onto it afterwards.

    Parameters
    ----------
    recordset : RecordSet
        The recordset holding the pickled object references.
    msg_str : str
        Prefix used for the recordset keys.
    ray_gc_queue : Queue[ray.ObjectRef] | None, optional
        Queue that receives every consumed object reference, or None to skip
        that step.

    Returns
    -------
    tuple[RecordSet, NDArrays]
        The unmodified recordset and the converted arrays, ordered like the
        references in the recordset.

    """
    # Rebuild the object references that the sender pickled into the recordset
    pickled_refs = recordset.configs_records[f"{msg_str}.ray"]["object_refs"]
    pending_refs = ray.cloudpickle.loads(pickled_refs)  # type: ignore[reportAttributeAccessIssue]

    # ray.wait hands references back in readiness order, so remember where each
    # one belongs in the original list
    position_of = {ref: position for position, ref in enumerate(pending_refs)}
    arrays_by_position: dict[int, np.ndarray] = {}
    batch_size = 20  # Adjust based on memory constraints and performance needs

    while pending_refs:
        # Block until a full batch (or whatever is left) is ready
        ready_refs, pending_refs = ray.wait(
            pending_refs,
            num_returns=min(batch_size, len(pending_refs)),
        )
        # Fetch the whole batch, then convert its arrays with parallel remote tasks
        fetched_arrays = ray.get(ready_refs)
        conversion_tasks = [
            convert_to_float32.remote(array) for array in fetched_arrays
        ]
        converted_arrays = ray.get(conversion_tasks)
        # File each converted array under its original position
        for ref, converted in zip(ready_refs, converted_arrays, strict=True):
            arrays_by_position[position_of[ref]] = converted
        # Hand the consumed references to the garbage collection queue, if any
        if ray_gc_queue is not None:
            for ref in ready_refs:
                ray_gc_queue.put(ref)

    # Emit the arrays in their original order
    float32_ndarrays: list[np.ndarray] = [
        arrays_by_position[position] for position in range(len(arrays_by_position))
    ]
    return recordset, float32_ndarrays


def load_recordset_parameters_from_remote(
    remote_uploader_downloader: RemoteUploaderDownloader | None,
    incoming_message: Message,
    msg_str: str,
    comm_stack: CommStack,
    ray_gc_queue: Queue[ray.ObjectRef] | None = None,
) -> tuple[Message, NDArrays]:
    """Load model parameters for an incoming message from the configured backend.

    The status stored in the message recordset is checked first. When its code is
    not ``Code.OK`` the message is returned untouched together with an empty list
    and nothing is loaded or deserialized. Otherwise the first matching backend
    is used, in this order:

    1. S3, when ``comm_stack.s3`` is set and an uploader was supplied
    2. shared memory, when ``comm_stack.shm`` is set
    3. Ray, when ``comm_stack.ray`` is set
    4. otherwise the parameters stored inline under ``"{msg_str}.{PARAMETERS}"``
       are deserialized

    The inline record is only deserialized in the last case; for the remote
    backends its content is never used.

    Parameters
    ----------
    remote_uploader_downloader : RemoteUploaderDownloader | None
        The uploader/downloader instance for interacting with S3, or None if S3 is
        not used.
    incoming_message : Message
        The incoming message whose content holds the parameter references.
    msg_str : str
        A string identifier used to prefix keys in the recordset.
    comm_stack : CommStack
        Configuration whose ``s3``, ``shm`` and ``ray`` flags select the backend.
    ray_gc_queue : Queue[ray.ObjectRef] | None, optional
        Queue forwarded to the Ray loader, or None if not using Ray.

    Returns
    -------
    tuple[Message, NDArrays]
        The incoming message (with its content re-assigned after a successful
        load) and the loaded model parameters, or an empty list when the task
        status is not OK.

    """
    # Extract the content of the incoming message
    recordset = incoming_message.content

    # A failed task carries nothing worth loading: return before doing any work
    if _extract_status_from_recordset(msg_str, recordset).code != Code.OK:
        return incoming_message, []

    ndarrays: NDArrays
    # Dispatch to the appropriate handler based on communication protocol
    if comm_stack.s3 and remote_uploader_downloader is not None:
        recordset, ndarrays = load_from_s3_parameters(
            remote_uploader_downloader,
            recordset,
            msg_str,
        )
    elif comm_stack.shm:
        recordset, ndarrays = load_from_shm_parameters(recordset, msg_str)
    elif comm_stack.ray:
        recordset, ndarrays = load_from_ray_parameters(recordset, msg_str, ray_gc_queue)
    else:
        # No remote backend applies: the parameters travel inside the recordset.
        # keep_input=True leaves the record in the message intact
        ndarrays = parameters_to_ndarrays(
            parametersrecord_to_parameters(
                record=recordset.parameters_records[f"{msg_str}.{PARAMETERS}"],
                keep_input=True,
            ),
        )

    # Update the message content with the recordset returned by the handler
    incoming_message.content = recordset
    return incoming_message, ndarrays
