"""Utility functions for running S3-related tasks on main server loop in flwr next."""

import operator
import pickle  # noqa: S403
import random
import re
import time
from itertools import groupby
from logging import DEBUG, INFO
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any

from composer import Trainer
from composer.loggers import RemoteUploaderDownloader
from composer.utils import (
    S3ObjectStore,
    maybe_create_object_store_from_uri,
    parse_uri,
)
from composer.utils.file_helpers import (
    list_remote_objects,
    validate_given_remote_path,
)
from flwr.common import (
    ConfigsRecord,
    NDArrays,
    log,
)
from omegaconf import OmegaConf

from repo.conf.base_schema import BaseConfig, S3CommConfig
from repo.constants import (
    CURRENT_SERVER_STATE_FILENAME,
    ENDPOINT_ID,
    FIRST_MOMENTUM,
    MODEL_PARAMETERS,
    SECOND_MOMENTUM,
)
from repo.file_utils import (
    create_remote_up_down,
    custom_pickle_load,
    download_file_from_s3,
    dump_model_parameters_to_file,
    upload_file_to_s3,
)
from repo.server.client_sampler import ClientSampler
from repo.server.param_scheduler_dispatcher import dispatch_model_state_scheduler
from repo.server.utils import (
    ServerState,
    check_server_state_lightweight_dict,
    create_server_state_lightweight_dict,
    read_server_state_lightweight_dict,
)
from repo.strategy.dispatcher import dispatch_strategy
from repo.utils import (
    ClientState,
    ModelStateNames,
    get_initial_parameters,
    load_model_parameters_from_file,
    set_trainer_params_from_ndarrays,
)
from repo.wandb_history import WandbHistory


class NoCheckpointsFoundError(Exception):
    """Signal that no checkpointed rounds could be found under a looked-up path.

    Raised by ``interpret_resume_round`` when a negative (reverse-index)
    ``resume_round`` is requested but no rounds are available and
    ``raise_error`` is enabled.
    """


def delete_object(
    object_path: str,
) -> None:
    """Remove an object that lives either in remote storage or on local disk.

    The path is first handed to ``delete_remote_object``. If that call raises
    ``ValueError``, the path is treated as a local filesystem path and the file
    is removed with ``Path.unlink``.

    Parameters
    ----------
    object_path : str
        Location of the object to delete: a remote URI (e.g. ``s3://...``) or a
        local file path.

    Example
    -------
    >>> delete_object("s3://my-bucket/my-folder/myfile.txt")
    >>> delete_object("/local/path/to/myfile.txt")

    Notes
    -----
    Errors raised by ``Path.unlink`` (e.g. ``FileNotFoundError``) are not
    caught and propagate to the caller.

    """
    try:
        delete_remote_object(object_path)
    except ValueError:
        # A ValueError from the remote helper means "not a remote URI": fall
        # back to deleting a local file.
        local_file = Path(object_path)
        local_file.unlink()


def list_objects(
    run_uuid_path: str,
) -> tuple[bool, list[str]]:
    """List the objects stored under a remote URI or a local directory.

    The path is first passed to ``list_remote_objects``. If that raises
    ``ValueError``, the path is treated as a local directory and every regular
    file beneath it (recursively) is returned.

    Parameters
    ----------
    run_uuid_path : str
        Remote URI (e.g. ``s3://bucket/prefix``) or local directory to list.

    Returns
    -------
    tuple[bool, list[str]]
        ``(is_remote, paths)``. ``is_remote`` is True when the remote listing
        succeeded and False when the local fallback was used. ``paths`` holds
        the listed object paths as strings.

    Example
    -------
    >>> is_remote, objects = list_objects("s3://my-bucket/my-folder")
    >>> print(is_remote)
    True
    >>> print(objects)
    ['s3://my-bucket/my-folder/file1.txt', 's3://my-bucket/my-folder/file2.txt']

    >>> is_remote, objects = list_objects("/local/path/to/folder")
    >>> print(is_remote)
    False
    >>> print(objects)
    ['/local/path/to/folder/file1.txt', '/local/path/to/folder/file2.txt']

    """
    try:
        remote_listing = list_remote_objects(run_uuid_path)
    except ValueError:
        # Not a remote URI: walk the local tree and keep only regular files.
        local_listing = [
            str(candidate)
            for candidate in Path(run_uuid_path).rglob("*")
            if candidate.is_file()
        ]
        return False, local_listing
    return True, remote_listing


def extract_s3_comm_config_from_configrecord(
    s3_comm_config: ConfigsRecord,
) -> tuple[str, str, str]:
    """Extract S3 communication details from a ConfigsRecord object.

    Reads the ``ENDPOINT_ID``, ``"file_name"`` and ``"folder_name"`` entries of
    the record and converts each one to ``str``. Together these identify the
    sender and the remote file/folder involved in an S3 exchange.

    Parameters
    ----------
    s3_comm_config : ConfigsRecord
        Record carrying the S3 communication configuration. It must contain the
        keys ``ENDPOINT_ID``, ``"file_name"`` and ``"folder_name"``.

    Returns
    -------
    tuple[str, str, str]
        ``(endpoint_id, file_name, folder_name)``, each passed through ``str``
        so callers can use them directly in S3 paths.

    Raises
    ------
    ValueError
        If a required key is missing. Keys are checked in the order
        ``ENDPOINT_ID``, ``"file_name"``, ``"folder_name"``, and the message
        names the first missing key (``"<key> is not present in the message"``).

    """
    extracted: list[str] = []
    for key in (ENDPOINT_ID, "file_name", "folder_name"):
        if key not in s3_comm_config:
            msg = f"{key} is not present in the message"
            raise ValueError(msg)
        # Normalise to str whatever scalar type the record stored.
        extracted.append(str(s3_comm_config[key]))
    endpoint_id, file_name, folder_name = extracted
    return endpoint_id, file_name, folder_name


def interpret_resume_round(
    resume_round: int | None,
    run_uuid_path: str,
    state_keys: tuple[str, ...],
    *,
    raise_error: bool = True,
) -> int | None:
    """Interpret the resume round parameter for server checkpoint resumption.

    ``None`` means "do not resume" and is returned as-is. A non-negative value
    is already an absolute round number and is also returned unchanged. A
    negative value is used as a Python index into the sorted list of rounds
    that ``obtain_sorted_runs`` finds under ``run_uuid_path``, so ``-1`` selects
    the latest round.

    Parameters
    ----------
    resume_round : int | None
        The round number to resume from. If negative, treated as a reverse index. If
        None, indicates no resumption is required.
    run_uuid_path : str
        The path to the run uuid root.
    state_keys : tuple[str, ...]
        A tuple of state keys used to identify the federated rounds.
    raise_error : bool, optional
        Whether to raise an error if no checkpoints are found and `resume_round` is
        negative. Default is True.

    Returns
    -------
    int | None
        The interpreted round number, or None if no resumption is requested.
        When ``resume_round`` is negative, no rounds are found and
        ``raise_error`` is False, the negative value is returned unchanged.

    Raises
    ------
    NoCheckpointsFoundError
        If `raise_error` is True, no checkpoints are found, and `resume_round` < 0.
    IndexError
        If ``resume_round`` is negative and its magnitude exceeds the number of
        rounds found.

    """
    if resume_round is None:
        return None
    if resume_round >= 0:
        # Already an absolute round number; no lookup needed.
        return resume_round
    # Only a negative value is actually interpreted as an index, so log here
    # rather than for every call.
    log(
        DEBUG,
        "The parameter `resume_round=%s` will be interpreted as an index "
        "for the list of rounds for the run_uuid_path=%s",
        resume_round,
        run_uuid_path,
    )
    server_round_indices = obtain_sorted_runs(run_uuid_path, state_keys)
    log(DEBUG, "Found server round indices %s", server_round_indices)
    if not server_round_indices:
        if raise_error:
            raise NoCheckpointsFoundError
        return resume_round
    return server_round_indices[resume_round]


def upload_server_state(
    server_state: ServerState,
    remote_up_down: RemoteUploaderDownloader,
    local_checkpoint_path: str,
) -> None:
    """Pickle the lightweight server state to disk and push it to S3.

    A lightweight dictionary is built from ``server_state`` with
    ``create_server_state_lightweight_dict``. It is validated with
    ``raise_if_not_lightweight=True``, and then pickled to
    ``<local_checkpoint_path>/server/<current_round>/CURRENT_SERVER_STATE_FILENAME``.
    The file is then uploaded under ``<current_round>/CURRENT_SERVER_STATE_FILENAME``.

    Parameters
    ----------
    server_state : ServerState
        Server state to checkpoint; ``current_round`` selects the round folder.
    remote_up_down : RemoteUploaderDownloader
        Uploader/downloader used to reach the S3 Object Store.
    local_checkpoint_path : str
        Local root directory where checkpoint files are staged before upload.

    Notes
    -----
    If the local file already exists it is uploaded as-is and not rewritten.

    """
    lightweight_state = create_server_state_lightweight_dict(server_state)
    # Refuse to persist anything that is not lightweight.
    check_server_state_lightweight_dict(
        lightweight_state,
        raise_if_not_lightweight=True,
    )
    state_file = Path(local_checkpoint_path).joinpath(
        "server",
        f"{server_state.current_round}",
        CURRENT_SERVER_STATE_FILENAME,
    )
    state_file.parent.mkdir(parents=True, exist_ok=True)
    if not state_file.exists():
        with state_file.open("wb") as handle:
            pickle.dump(lightweight_state, handle)
    log(DEBUG, "Push server state to S3")
    remote_name = f"{server_state.current_round}/{CURRENT_SERVER_STATE_FILENAME}"
    upload_file_to_s3(remote_up_down, remote_name, state_file)


def upload_momentum_vector(
    current_round: int,
    momentum_vector: NDArrays,
    remote_up_down: RemoteUploaderDownloader,
    local_checkpoint_path: str,
    *,
    is_second_momentum: bool = False,
) -> None:
    """Write a momentum buffer to a local ``.npz`` file and push it to S3.

    The file is named ``SECOND_MOMENTUM.npz`` when ``is_second_momentum`` is
    True and ``FIRST_MOMENTUM.npz`` otherwise. It is staged under
    ``<local_checkpoint_path>/server/<current_round>/`` and uploaded to
    ``<current_round>/<name>.npz``. The time spent writing to disk is logged at
    DEBUG level.

    Parameters
    ----------
    current_round : int
        Training round; selects both the local and the remote folder.
    momentum_vector : NDArrays
        Momentum buffer to persist.
    remote_up_down : RemoteUploaderDownloader
        Uploader/downloader used to reach the S3 Object Store.
    local_checkpoint_path : str
        Local root directory where checkpoint files are staged before upload.
    is_second_momentum : bool, optional
        Select the second-momentum file name instead of the first, by default
        False.

    Notes
    -----
    If the local file already exists it is uploaded as-is and not rewritten.

    """
    log(DEBUG, "Dump momentum vector to disk")
    stem = SECOND_MOMENTUM if is_second_momentum else FIRST_MOMENTUM
    file_name = f"{stem}.npz"
    target = Path(local_checkpoint_path).joinpath(
        "server",
        f"{current_round}",
        file_name,
    )
    target.parent.mkdir(parents=True, exist_ok=True)
    started_at = time.time()
    if not target.exists():
        dump_model_parameters_to_file(target, momentum_vector)
    elapsed = time.time() - started_at
    log(
        DEBUG,
        "Push momentum vector to S3 Object Store. Time to dump to disk: %s",
        elapsed,
    )
    upload_file_to_s3(remote_up_down, f"{current_round}/{file_name}", target)


def upload_model_parameters(
    parameters: NDArrays,
    current_round: int,
    remote_up_down: RemoteUploaderDownloader,
    local_checkpoint_path: str,
) -> None:
    """Write the model parameters to a local ``.npz`` file and push it to S3.

    The parameters are staged at
    ``<local_checkpoint_path>/server/<current_round>/MODEL_PARAMETERS.npz`` and
    uploaded to ``<current_round>/MODEL_PARAMETERS.npz``. The time spent from
    the start of the call until the dump finishes is logged at DEBUG level.

    Parameters
    ----------
    parameters : NDArrays
        Model parameters to persist.
    current_round : int
        Training round; selects both the local and the remote folder.
    remote_up_down : RemoteUploaderDownloader
        Uploader/downloader used to reach the S3 Object Store.
    local_checkpoint_path : str
        Local root directory where checkpoint files are staged before upload.

    Notes
    -----
    If the local file already exists it is uploaded as-is and not rewritten.

    """
    started_at = time.time()
    file_name = f"{MODEL_PARAMETERS}.npz"
    round_dir = Path(local_checkpoint_path).joinpath("server", f"{current_round}")
    round_dir.mkdir(parents=True, exist_ok=True)
    target = round_dir / file_name
    if not target.exists():
        dump_model_parameters_to_file(target, parameters)
    log(
        DEBUG,
        "Push parameters to S3 Object Store. Time to dump to disk: %s",
        time.time() - started_at,
    )
    upload_file_to_s3(remote_up_down, f"{current_round}/{file_name}", target)


def upload_server_checkpoint(
    server_state: ServerState,
) -> None:
    """Upload every checkpointable component of the server to S3.

    In order, this uploads:

    1. the lightweight server state, but only when both ``history`` and
       ``client_states`` are truthy;
    2. the first and then the second momentum vector, each only when it is
       not None;
    3. the model parameters, always.

    Every component is staged under ``server_state.local_checkpoint_path`` for
    ``server_state.current_round`` before being uploaded.

    Parameters
    ----------
    server_state : ServerState
        Server state holding the model parameters, optional momentum vectors,
        history, client states and the ``remote_up_down`` handle used for S3.

    Raises
    ------
    AssertionError
        If ``server_state.remote_up_down`` is None.

    """
    remote_up_down = server_state.remote_up_down
    assert remote_up_down is not None, "remote_up_down is None"

    if server_state.history and server_state.client_states:
        upload_server_state(
            server_state=server_state,
            remote_up_down=remote_up_down,
            local_checkpoint_path=server_state.local_checkpoint_path,
        )

    # (vector, is_second_momentum) pairs; unset buffers are skipped.
    optional_momenta = (
        (server_state.momentum_vector, False),
        (server_state.second_momentum_vector, True),
    )
    for vector, is_second in optional_momenta:
        if vector is None:
            continue
        upload_momentum_vector(
            current_round=server_state.current_round,
            momentum_vector=vector,
            remote_up_down=remote_up_down,
            local_checkpoint_path=server_state.local_checkpoint_path,
            is_second_momentum=is_second,
        )

    upload_model_parameters(
        parameters=server_state.model_states,
        current_round=server_state.current_round,
        remote_up_down=remote_up_down,
        local_checkpoint_path=server_state.local_checkpoint_path,
    )


def get_model_parameters_checkpoint(
    cfg: BaseConfig,
    remote_up_down: RemoteUploaderDownloader,
    server_path: str,
    local_checkpoint_path: Path,
    timeout: float = 0.5,
) -> tuple[NDArrays, tuple[tuple[str, ModelStateNames], ...]]:
    """Retrieve model parameters from local storage or S3 Object Store.

    The parameters for round ``cfg.repo.resume_round`` are loaded from a local
    ``MODEL_PARAMETERS.npz`` or ``MODEL_PARAMETERS.bin`` file when one exists.
    Otherwise the remote server path is polled until either format is
    available. That file is then downloaded into ``local_checkpoint_path`` and
    loaded. Finally, the number of loaded tensors is checked against the
    tensors produced by ``get_initial_parameters(cfg)``.

    Parameters
    ----------
    cfg : BaseConfig
        Configuration; ``cfg.repo.resume_round`` selects the round folder.
    remote_up_down : RemoteUploaderDownloader
        The uploader/downloader object for interacting with the S3 Object Store.
    server_path : str
        Remote prefix under which round folders are stored. It is concatenated
        directly with ``"<round>/..."``, so it should end with ``/``.
    local_checkpoint_path : Path
        The local directory where checkpoint files are stored or downloaded to.
    timeout : float, optional
        Seconds to sleep between unsuccessful polls of the remote paths, by
        default 0.5.

    Returns
    -------
    tuple[NDArrays, tuple[tuple[str, ModelStateNames], ...]]
        A tuple containing:
        - The model parameters loaded from the checkpoint
        - A tuple of (layer_name, layer_type) pairs defining the model state structure

    Raises
    ------
    AssertionError
        If the number of tensors in the loaded parameters doesn't match the
        initial parameters.

    Notes
    -----
    Locally, ``.npz`` takes precedence over ``.bin`` when both exist, and only
    one file is read. Remotely, ``.bin`` takes precedence over ``.npz``. The
    remote poll has no upper bound: the call blocks until one of the files
    appears.

    """
    # Initialize the model parameters and layer names mapping
    initial_model_states, layer_names_and_types = get_initial_parameters(cfg)
    local_dir = Path(local_checkpoint_path)
    checkpoint_model_states: NDArrays | None = None

    # Local files first. ".npz" is tried before ".bin" so that, when both are
    # present, the .npz contents win, and only a single file is read.
    for extension in (".npz", ".bin"):
        local_candidate = local_dir / f"{MODEL_PARAMETERS}{extension}"
        if local_candidate.exists():
            checkpoint_model_states = load_model_parameters_from_file(
                local_candidate,
            )
            log(
                DEBUG,
                "Loaded model parameters from local path: %s",
                local_candidate,
            )
            break

    if not checkpoint_model_states:
        remote_stem = f"{cfg.repo.resume_round}/{MODEL_PARAMETERS}"
        remote_file_name_no_ext = server_path + remote_stem
        # Poll until one of the formats exists remotely. Remember which one
        # was found so the remote store is not queried again afterwards, and
        # only sleep after an unsuccessful attempt.
        found_extension: str | None = None
        while found_extension is None:
            if validate_given_remote_path(remote_file_name_no_ext + ".bin"):
                found_extension = ".bin"
            elif validate_given_remote_path(remote_file_name_no_ext + ".npz"):
                found_extension = ".npz"
            else:
                time.sleep(timeout)
        remote_file_name = f"{remote_stem}{found_extension}"
        local_file_name = local_dir / f"{MODEL_PARAMETERS}{found_extension}"
        # Download the model states
        download_file_from_s3(remote_up_down, remote_file_name, local_file_name)
        checkpoint_model_states = load_model_parameters_from_file(local_file_name)
        log(DEBUG, "Loaded model parameters from remote path: %s", remote_file_name)

    assert len(initial_model_states) == len(
        checkpoint_model_states,
    ), "Randomly initialized parameters and checkpoints must have the same length."
    return checkpoint_model_states, layer_names_and_types


def get_server_state_checkpoint(
    cfg: BaseConfig,
    remote_up_down: RemoteUploaderDownloader,
    local_checkpoint_path: Path,
) -> tuple[
    WandbHistory,
    int,
    float,
    int,
    dict[str | int, ClientState],
]:
    """Load the pickled lightweight server state and unpack its components.

    The state file ``CURRENT_SERVER_STATE_FILENAME`` is read from
    ``local_checkpoint_path``. If it is not present locally, it is first
    downloaded from ``<cfg.repo.resume_round>/CURRENT_SERVER_STATE_FILENAME``
    on the S3 Object Store. The loaded dictionary is validated with
    ``check_server_state_lightweight_dict(..., raise_if_not_lightweight=True)``
    and then unpacked by ``read_server_state_lightweight_dict`` using
    ``cfg.fl.n_local_steps`` and ``cfg.fl.n_total_clients``.

    Parameters
    ----------
    cfg : BaseConfig
        Configuration providing the resume round and the federated settings.
    remote_up_down : RemoteUploaderDownloader
        Uploader/downloader used to reach the S3 Object Store.
    local_checkpoint_path : Path
        Local directory holding (or receiving) the server state file.

    Returns
    -------
    tuple[WandbHistory, int, float, int, dict[str | int, ClientState]]
        Whatever ``read_server_state_lightweight_dict`` returns, i.e.:
        - WandbHistory: the metrics/logging history
        - int: the round to start from
        - float: the elapsed-time offset
        - int: the cumulative server step count
        - dict[str | int, ClientState]: per-client states

    Raises
    ------
    Exception
        Any error raised by the validation helper or the reader propagates.
        NOTE(review): earlier docs mentioned an AssertionError on a round
        mismatch. If that check exists, it lives in those helpers; confirm there.

    Notes
    -----
    Filling in defaults for missing fields (e.g. client states in older
    checkpoints) is presumably handled by ``read_server_state_lightweight_dict``,
    which receives ``n_local_steps`` and ``n_total_clients``. Verify against
    that helper.

    """
    state_file = local_checkpoint_path / CURRENT_SERVER_STATE_FILENAME
    found_locally = state_file.exists()
    remote_name = f"{cfg.repo.resume_round}/{CURRENT_SERVER_STATE_FILENAME}"
    if not found_locally:
        download_file_from_s3(remote_up_down, remote_name, str(state_file))
    with state_file.open("rb") as handle:
        server_state = custom_pickle_load(handle)
    if found_locally:
        log(DEBUG, "Loaded server state from local path: %s", state_file)
    else:
        log(DEBUG, "Loaded server state from remote path: %s", remote_name)

    check_server_state_lightweight_dict(
        server_state,
        raise_if_not_lightweight=True,
    )
    return read_server_state_lightweight_dict(
        server_state,
        n_local_steps=cfg.fl.n_local_steps,
        n_total_clients=cfg.fl.n_total_clients,
    )


def get_server_momentum_checkpoint(
    cfg: BaseConfig,
    remote_up_down: RemoteUploaderDownloader,
    server_path: str,
    local_checkpoint_path: Path,
    momentum_type: str,
) -> NDArrays | None:
    """Retrieve a momentum vector from local storage or S3 Object Store.

    When the first momentum is requested, the server state file is loaded
    first (downloading it from S3 if it is not present locally). If that state
    contains a ``"momentum"`` entry, that entry is used. Otherwise, and always
    for the second momentum, the vector is read from a dedicated
    ``<momentum_type>.npz`` file: locally if it exists, or else downloaded from
    S3 if the remote file exists.

    Parameters
    ----------
    cfg : BaseConfig
        Configuration parameters containing S3 bucket information, run UUID, and
        resume round information.
    remote_up_down : RemoteUploaderDownloader
        The uploader/downloader object for interacting with the S3 Object Store.
    server_path : str
        Remote prefix under which round folders are stored. It is concatenated
        directly with ``"<round>/..."``, so it should end with ``/``.
    local_checkpoint_path : Path
        The local path where checkpoint files are stored or downloaded to.
    momentum_type : str
        The type of momentum vector to retrieve (FIRST_MOMENTUM or SECOND_MOMENTUM).

    Returns
    -------
    NDArrays | None
        The retrieved momentum vector, or None if no momentum vector was found.

    Notes
    -----
    Lookup order:
    1. (FIRST_MOMENTUM only) the ``"momentum"`` field of the server state file.
       That single embedded buffer is never returned for SECOND_MOMENTUM.
    2. A local file dedicated to the requested momentum type.
    3. A remote S3 file dedicated to the requested momentum type. This is
       checked once, without polling.

    """
    momentum_vector: NDArrays | None = None

    # Server states may embed one momentum buffer under "momentum". That buffer
    # can only stand for the first momentum. Returning it for SECOND_MOMENTUM
    # would silently initialise the second moment with first-moment values.
    if momentum_type == FIRST_MOMENTUM:
        state_file = local_checkpoint_path / CURRENT_SERVER_STATE_FILENAME
        if not state_file.exists():
            # Download the server state from S3 Object Store
            download_file_from_s3(
                remote_up_down,
                f"{cfg.repo.resume_round}/{CURRENT_SERVER_STATE_FILENAME}",
                str(state_file),
            )
            log(DEBUG, "Read server state from disk")
        with state_file.open("rb") as f:
            server_state = custom_pickle_load(f)
        if "momentum" in server_state:
            log(DEBUG, "Get momentum vector from server state")
            momentum_vector = server_state["momentum"]

    if not momentum_vector:
        local_file_name = Path(local_checkpoint_path) / f"{momentum_type}.npz"
        remote_file_name = f"{cfg.repo.resume_round}/{momentum_type}.npz"
        if local_file_name.exists():
            momentum_vector = load_model_parameters_from_file(local_file_name)
        elif validate_given_remote_path(server_path + remote_file_name):
            # Download the parameters
            download_file_from_s3(remote_up_down, remote_file_name, local_file_name)
            momentum_vector = load_model_parameters_from_file(local_file_name)
    return momentum_vector


def get_server_checkpoint(
    cfg: BaseConfig,
    remote_up_down: RemoteUploaderDownloader,
    repo_save_path: str,
    timeout: float = 0.5,
) -> ServerState:
    """Rebuild a full ``ServerState`` from a saved checkpoint.

    The checkpoint for ``cfg.repo.resume_round`` is located remotely under
    ``s3://<bucket>/<run_uuid>/server/`` and locally under
    ``<repo_save_path>/server/<resume_round>``. Components are loaded in this
    order: model parameters, server state, first momentum, second momentum.
    Each loader prefers local files and falls back to S3. Fresh strategy,
    schedulers, RNG and client sampler are then created from ``cfg`` and
    everything is assembled into a ``ServerState``.

    Parameters
    ----------
    cfg : BaseConfig
        Configuration providing the bucket name, run UUID, resume round, seed
        and federated-learning settings.
    remote_up_down : RemoteUploaderDownloader
        Uploader/downloader used to reach the S3 Object Store.
    repo_save_path : str
        Local root directory for checkpoints. Round files are expected under
        ``<repo_save_path>/server/<resume_round>``, and this root is stored as
        the new state's ``local_checkpoint_path``.
    timeout : float, optional
        Seconds between remote polls while waiting for the model parameters,
        by default 0.5.

    Returns
    -------
    ServerState
        State ready to resume training. ``current_round`` equals the restored
        ``start_round``, ``current_time_elapsed`` is reset to 0.0, and both
        client lists start empty.

    """
    server_path = f"s3://{cfg.s3_comm_config.bucket_name}/{cfg.run_uuid}/server/"
    local_checkpoint_path = Path(repo_save_path).joinpath(
        "server",
        f"{cfg.repo.resume_round}",
    )

    model_states, layer_names_and_types = get_model_parameters_checkpoint(
        cfg=cfg,
        remote_up_down=remote_up_down,
        server_path=server_path,
        local_checkpoint_path=local_checkpoint_path,
        timeout=timeout,
    )

    restored = get_server_state_checkpoint(
        cfg=cfg,
        remote_up_down=remote_up_down,
        local_checkpoint_path=local_checkpoint_path,
    )
    history, start_round, time_offset, server_steps_cumulative, client_states = (
        restored
    )

    # Generator unpacking still loads the first momentum before the second.
    momentum_vector, second_momentum_vector = (
        get_server_momentum_checkpoint(
            cfg=cfg,
            remote_up_down=remote_up_down,
            server_path=server_path,
            local_checkpoint_path=local_checkpoint_path,
            momentum_type=kind,
        )
        for kind in (FIRST_MOMENTUM, SECOND_MOMENTUM)
    )

    log(INFO, "Server checkpoint loaded")
    # Build the non-checkpointed pieces in the same order as the keyword
    # arguments below; each scheduler call yields its own instance.
    strategy = dispatch_strategy(cfg)
    transmission_scheduler = dispatch_model_state_scheduler(cfg)
    aggregation_mask_scheduler = dispatch_model_state_scheduler(cfg)
    rng = random.Random(cfg.seed)  # noqa: S311
    client_sampler = ClientSampler(
        number_of_clients_per_round=cfg.fl.n_clients_per_round,
        total_number_of_clients=cfg.fl.n_total_clients,
        dropout_ratio=cfg.fl.dropout_ratio,
        dropout_function_name=cfg.fl.dropout_function_name,
    )
    return ServerState(
        model_states=model_states,
        strategy=strategy,
        transmission_scheduler=transmission_scheduler,
        aggregation_mask_scheduler=aggregation_mask_scheduler,
        history=history,
        time_offset=time_offset,
        server_steps_cumulative=server_steps_cumulative,
        client_states=client_states,
        momentum_vector=momentum_vector,
        second_momentum_vector=second_momentum_vector,
        sampled_clients=[],
        previously_sampled_clients=[],
        rng=rng,
        start_round=start_round,
        current_round=start_round,
        current_time_elapsed=0.0,
        layer_names_and_types=layer_names_and_types,
        remote_up_down=remote_up_down,
        local_checkpoint_path=repo_save_path,
        client_sampler=client_sampler,
    )


def get_file_from_path(
    input_file_path: str,
    run_uuid: str,
    s3_comm_config: S3CommConfig,
    output_parent_path: str,
) -> Path:
    """Retrieve a file from a given path, which can be either a local path or an S3 URI.

    This function interprets the input file path to determine whether it's a local path
    or an S3 URI. If it is an S3 URI, the function downloads the file from the specified
    S3 bucket into ``output_parent_path`` as ``checkpoint<suffix>``, where ``<suffix>``
    is the suffix of the remote file. If it is a local path, the path is used as-is. In
    both cases the existence of the resulting local file is verified before returning.

    Parameters
    ----------
    input_file_path : str
        The path to the file to be retrieved. Can be a local file path or an S3 URI.
    run_uuid : str
        The unique identifier for the run, used for S3 operations.
    s3_comm_config : S3CommConfig
        The S3 communication configuration.
    output_parent_path : str
        The output parent path for storing the downloaded file.

    Returns
    -------
    Path
        The local file path to the retrieved file.

    Raises
    ------
    ValueError
        If the backend specified in the URI is unknown.
    FileNotFoundError
        If the resulting local file does not exist (a missing local path, or a
        download that did not produce the expected file).

    """
    # Interpret the URI: an empty backend means a plain local path
    backend, bucket_name, remote_file_name = parse_uri(input_file_path)
    if backend == "s3":
        log(
            INFO,
            "Downloading model %s from S3 bucket %s",
            remote_file_name,
            bucket_name,
        )
        # Create RemoteUploaderDownloader object
        remote_up_down = create_remote_up_down(
            bucket_name=bucket_name,
            prefix="",
            run_uuid=run_uuid,
            num_attempts=5,
            client_config=OmegaConf.to_container(
                s3_comm_config.backend_kwargs.client_config,
            ),  # type: ignore[reportArgumentType, arg-type]
        )
        # The local copy keeps the remote file's suffix (e.g. ".npz" or ".bin")
        local_file_path = Path(output_parent_path) / (
            "checkpoint" + Path(remote_file_name).suffix
        )
        download_file_from_s3(remote_up_down, remote_file_name, local_file_path)
    elif not backend:
        log(
            INFO,
            "File path %s is local.",
            Path(input_file_path),
        )
        local_file_path = Path(input_file_path)
    else:
        msg = f"Unknown backend: {backend}"
        raise ValueError(msg)
    # Explicit check rather than ``assert`` so validation still runs under ``python -O``
    if not local_file_path.exists():
        msg = f"Local file path {local_file_path} does not exist"
        raise FileNotFoundError(msg)
    return local_file_path


def load_pretrained_model_from_path(
    pretrained_model_path: str,
    run_uuid: str,
    s3_comm_config: S3CommConfig,
    trainer: Trainer,
) -> None:
    """Overwrite a trainer's parameters with those of a pretrained model.

    The model can live on the local filesystem or behind an S3 URI. The file is
    resolved with ``get_file_from_path`` (which downloads it when remote), its
    parameters are loaded, and they are written into ``trainer`` with no layers
    excluded.

    Parameters
    ----------
    pretrained_model_path : str
        Local path or S3 URI of the pretrained model file.
    run_uuid : str
        The unique identifier for the run, used for S3 operations.
    s3_comm_config : S3CommConfig
        The S3 communication configuration.
    trainer : Trainer
        The trainer whose parameters are replaced by the pretrained ones.

    """
    log(INFO, "Loading pretrained model from %s", pretrained_model_path)
    # A downloaded copy is only needed until the parameters are loaded into memory,
    # so it is placed in a temporary directory that is deleted when the block exits.
    with TemporaryDirectory() as scratch_dir:
        model_file = get_file_from_path(
            input_file_path=pretrained_model_path,
            run_uuid=run_uuid,
            s3_comm_config=s3_comm_config,
            output_parent_path=scratch_dir,
        )
        set_trainer_params_from_ndarrays(
            load_model_parameters_from_file(model_file),
            trainer,
            excluded_layers=[],
        )


def get_num_batches_from_checkpoint_name(checkpoint_name: str) -> int:
    """Extract the number of batches from a checkpoint file name or path.

    The checkpoint file name is expected to be in the format:
    ep{n_epochs}-ba{n_batches}-rank{rank}.pt

    ``checkpoint_name`` may be a bare file name or a full local/remote path (callers
    pass paths returned by ``list_objects``). Only the final path component is
    inspected, so a ``-ba<digits>-`` fragment in a parent directory (for example
    inside a run UUID) is never mistaken for the batch count.

    Parameters
    ----------
    checkpoint_name : str
        The name (or path) of the checkpoint file.

    Returns
    -------
    int
        The number of batches extracted from the checkpoint file name.

    Raises
    ------
    ValueError
        If the checkpoint file name does not match the expected format.

    """
    # Restrict the search to the last path component; both "/" (S3 keys, POSIX) and
    # "\" (Windows local paths) are treated as separators.
    file_name = re.split(r"[/\\]", checkpoint_name)[-1]
    match = re.search(r"-ba(\d+)-", file_name)
    if match:
        return int(match.group(1))
    msg = f"Invalid checkpoint name format: {checkpoint_name}"
    raise ValueError(msg)


def obtain_sorted_runs(run_uuid_path: str, state_keys: tuple[str, ...]) -> list[int]:
    """Return, in ascending order, the server rounds that hold every state file.

    All objects under ``run_uuid_path`` are listed with ``list_objects``. Every path
    containing ``server/<n>/`` nominates round ``n`` as a candidate. A candidate is
    kept when, for each key in ``state_keys``, at least one path containing
    ``server/<n>/`` also contains that key as a substring. With an empty
    ``state_keys`` every candidate round is kept.

    Parameters
    ----------
    run_uuid_path : str
        The path to the run UUID root (local path or remote URI).
    state_keys : tuple[str, ...]
        Fragments that must each appear in at least one path of a round for that
        round to be returned, typically the leading part of the expected state file
        names (e.g. ("state", "model")).

    Returns
    -------
    list[int]
        The round numbers that satisfy every key, sorted ascending.

    Example
    -------
    >>> run_uuid_path = "s3://my-bucket/my-folder"
    >>> state_keys = ("state", "model")
    >>> sorted_runs = obtain_sorted_runs(run_uuid_path, state_keys)
    >>> print(sorted_runs)
    [1, 2, 3]

    """
    _is_remote, object_paths = list_objects(run_uuid_path)

    # Gather every round number that appears as server/<n>/ in some path
    server_round_pattern = re.compile(r"server/(\d+)/.*$")
    candidate_runs: set[int] = set()
    for object_path in object_paths:
        found = server_round_pattern.search(object_path)
        if found is not None:
            candidate_runs.add(int(found.group(1)))

    def _has_every_state_key(run: int) -> bool:
        """Tell whether each state key occurs in at least one path of ``run``."""
        marker = f"server/{run}/"
        paths_of_run = [p for p in object_paths if marker in p]
        return all(any(key in p for p in paths_of_run) for key in state_keys)

    return sorted(run for run in candidate_runs if _has_every_state_key(run))


def delete_past_communication_states(
    bucket_name: str,
    prefix: str,
) -> None:
    """Delete previous communication state files from S3 storage.

    This function removes all .npz files from the ``comm_stack`` folder under the
    given bucket and prefix. It lists all objects in that folder, keeps those whose
    name ends with ``.npz``, and deletes each of them with ``delete_object``. This is
    typically used to clean up stale communication state files after they are no
    longer needed, preventing unnecessary storage usage and costs.

    Parameters
    ----------
    bucket_name : str
        The name of the S3 bucket containing the communication state files.
    prefix : str
        The prefix (folder path) within the bucket where communication state files
        are stored.

    Notes
    -----
    Only files with the .npz extension are removed; any other file in the folder is
    left in place.

    """
    # Compose the remote URI of the communication stack folder
    remote_uri = f"s3://{bucket_name}/{prefix}/comm_stack"
    # The backend and bucket are the same for every object, so parse the URI once,
    # outside the loop, and under a new name instead of rebinding ``bucket_name``.
    backend, uri_bucket_name, _prefix = parse_uri(remote_uri)
    # List all the objects in the communication stack folder. The listed entries are
    # assumed to be bucket-relative keys, since each deletion URI is built by
    # prefixing them with "<backend>://<bucket>/".
    remote_objects = list_remote_objects(remote_uri)
    # Keep only the .npz files
    remote_objects_to_remove = [obj for obj in remote_objects if obj.endswith(".npz")]
    for object_to_remove in remote_objects_to_remove:
        delete_object(f"{backend}://{uri_bucket_name}/{object_to_remove}")


def delete_clients_checkpoints(run_uuid_path: str, end_idx: int | None = -1) -> None:
    """Delete old client checkpoints from a local or remote (S3) run folder.

    All objects under ``run_uuid_path`` are listed to discover the ``client_<id>/``
    folders. For each client, the non-symlink files in its folder are sorted from
    oldest to newest by the batch count in their name, and the entries in
    ``sorted_checkpoints[:end_idx]`` are deleted.

    Parameters
    ----------
    run_uuid_path : str
        The path to the run UUID, which can be either a local path or a remote S3 path.
    end_idx : int | None, optional
        Slice end applied to each client's checkpoints sorted oldest-to-newest.
        Defaults to -1, which deletes all checkpoints except the most recent one. A
        positive value ``k`` deletes the ``k`` oldest checkpoints, and ``None``
        deletes all of them.

    Raises
    ------
    ValueError
        If a non-symlink file in a client folder does not follow the
        ``ep{n}-ba{n}-rank{r}`` naming format. This is raised while sorting, before
        anything in that client's folder is deleted, but after earlier clients'
        folders may already have been cleaned.

    Example
    -------
    >>> delete_clients_checkpoints("s3://my-bucket/my-folder", end_idx=5)
    >>> delete_clients_checkpoints("/local/path/to/folder", end_idx=5)

    Notes
    -----
    Objects are listed with ``list_objects`` and removed with ``delete_object``. For
    remote paths, each listed key is turned into a full URI using the backend and
    bucket parsed from ``run_uuid_path``.

    """
    # List all the objects in the run UUID path
    _is_remote, remote_objects = list_objects(run_uuid_path)
    # Extract the unique client ids. The captured digits are kept as text so that the
    # folder listed below is exactly the one that was found (e.g. ``client_01`` is not
    # rewritten to ``client_1``).
    unique_client_ids = {
        reg.group(1)
        for path in remote_objects
        if (reg := re.search(r"client_(\d+)/.*$", path)) is not None
    }
    for client_id in unique_client_ids:
        # List all the objects for the client
        is_remote, client_remote_objects = list_objects(
            f"{run_uuid_path}/client_{client_id}/",
        )
        # Symlinks are skipped; only real checkpoint files are ranked and deleted
        client_checkpoints = [
            cro for cro in client_remote_objects if not cro.endswith(".symlink")
        ]
        # Sort from oldest to newest by the number of batches the client trained on
        sorted_client_objects = sorted(
            client_checkpoints,
            key=get_num_batches_from_checkpoint_name,
        )
        # Delete everything before ``end_idx`` (with the default -1: all but the newest)
        for object_to_remove in sorted_client_objects[:end_idx]:
            if is_remote:
                # Parse the URI to extract the backend and bucket name
                backend, bucket_name, _prefix = parse_uri(run_uuid_path)
                delete_object(f"{backend}://{bucket_name}/{object_to_remove}")
            else:
                delete_object(object_to_remove)


def delete_rounds(
    run_uuid_path: str,
    state_keys: tuple[str, ...],
    end_idx: int | None = -1,
) -> None:
    """Delete old server rounds from a local or remote (S3) run folder.

    The server rounds under ``{run_uuid_path}/server/`` that contain every key in
    ``state_keys`` are found with ``obtain_sorted_runs`` and sorted ascending. For each
    round in ``sorted_rounds[:end_idx]``, every object under
    ``{run_uuid_path}/server/{round}/`` is deleted with ``delete_object``. Rounds that
    are missing any of the state keys are never selected, and therefore never deleted.

    Parameters
    ----------
    run_uuid_path : str
        The run folder containing the ``server/`` rounds. This can be either a local
        path or a remote S3 path.
    state_keys : tuple[str, ...]
        Fragments that must each appear in at least one path of a round for that
        round to be considered (see ``obtain_sorted_runs``).
    end_idx : int | None, optional
        Slice end applied to the ascending list of complete rounds. Defaults to -1,
        which deletes every complete round except the most recent one. A positive
        value ``k`` deletes the ``k`` oldest rounds, and ``None`` deletes all of them.

    Example
    -------
    >>> delete_rounds("s3://my-bucket/my-run", state_keys=("state",), end_idx=-1)
    >>> delete_rounds("/local/path/to/run", state_keys=("state",), end_idx=-1)

    Notes
    -----
    For remote paths, each listed key is turned into a full URI using the backend and
    bucket parsed from ``run_uuid_path``; local paths are deleted as listed.

    """
    # Complete federated rounds in the run UUID path, oldest first
    sorted_rounds = obtain_sorted_runs(run_uuid_path, state_keys)
    # Delete every round before ``end_idx`` (with the default -1: all but the newest)
    rounds_to_delete = sorted_rounds[:end_idx]
    for round_to_delete in rounds_to_delete:
        # List all the objects for the server at the round specified
        is_remote, objects_to_remove = list_objects(
            f"{run_uuid_path}/server/{round_to_delete}/",
        )
        # Remove the objects found
        for object_to_remove in objects_to_remove:
            if is_remote:
                # Parse the URI to extract the backend and bucket name
                backend, bucket_name, _prefix = parse_uri(run_uuid_path)
                delete_object(f"{backend}://{bucket_name}/{object_to_remove}")
            else:
                delete_object(object_to_remove)


def delete_remote_object(object_name: str) -> None:
    """Delete one object, addressed by its full URI, from its S3 bucket.

    An object store is built from the URI with
    ``maybe_create_object_store_from_uri``. The URI's path component is converted to
    a key with the store's ``get_key``, and the object is removed through the store's
    client ``delete_object`` call.

    Parameters
    ----------
    object_name : str
        Full URI of the object to delete (e.g. ``s3://bucket/path/to/file``).

    Raises
    ------
    ValueError
        If no object store can be created from ``object_name``. Errors raised by the
        underlying client during deletion propagate unchanged.

    """
    store: S3ObjectStore | None = maybe_create_object_store_from_uri(object_name)  # type: ignore[assignment,reportAssignmentType]
    if store is None:
        msg = f"Invalid object name: {object_name}"
        raise ValueError(msg)
    # Only the path component of the URI is needed; the store turns it into a key
    *_, object_path = parse_uri(object_name)
    object_key = store.get_key(object_path)
    store.client.delete_object(Bucket=store.bucket, Key=object_key)


def _latest_client_checkpoints(object_keys: list[str], max_step: int) -> list[str]:
    """Return, per client, the newest checkpoint key with a batch count <= max_step.

    Keys are matched against ``client_{id}/ep{epochs}-ba{batches}``; keys that do not
    match are ignored. A client without any checkpoint at or before ``max_step`` is
    skipped (and logged) instead of raising. The result is ordered by client id.
    """
    # NOTE: (?:\d+) is a non-capturing group: any number of epochs is accepted
    # without being extracted.
    pattern = re.compile(r"client_(\d+)/ep(?:\d+)-ba(\d+)")
    client_path_batches = sorted(
        (
            (path, int(reg.group(1)), int(reg.group(2)))
            for path in object_keys
            if (reg := pattern.search(path)) is not None
        ),
        key=operator.itemgetter(1, 2),
    )
    latest: list[str] = []
    # groupby acts like an SQL GROUP BY on the (already sorted) client id
    for client_id, group in groupby(client_path_batches, key=operator.itemgetter(1)):
        eligible = [entry for entry in group if entry[2] <= max_step]
        if not eligible:
            log(
                INFO,
                "Client %d has no checkpoint at or before step %d; skipping it",
                client_id,
                max_step,
            )
            continue
        # Entries are sorted by batch count, so the last eligible one is the newest
        latest.append(eligible[-1][0])
    return latest


def copy_old_checkpoints_to_new_run(
    remote_up_down: RemoteUploaderDownloader,
    bucket_uri: str,
    restore_run_data: tuple[str, str, int, int],
    n_total_clients: int | None,
    *,
    copy_client_checkpoints: bool = True,
) -> None:
    """Copy old checkpoints to the new run folder.

    The following objects of the restored round are copied, server-side, from
    ``{bucket_uri}/{restore_run_uuid}`` to the same relative location under
    ``{bucket_uri}/{run_uuid}``:

    - the server state file;
    - the model parameters (``.bin`` if present, otherwise ``.npz``);
    - the first and second momentum vectors, when they exist;
    - if ``copy_client_checkpoints`` is set, for each client the newest checkpoint
      whose batch count does not exceed the restore step. Clients with no such
      checkpoint are skipped.

    Parameters
    ----------
    remote_up_down : RemoteUploaderDownloader
        The remote uploader/downloader object.
    bucket_uri : str
        The URI of the bucket.
    restore_run_data : tuple[str, str, int, int]
        A tuple containing the run UUID, restore run UUID, restore run round, and
        restore run step.
    n_total_clients : int | None
        The total number of clients expected in the new run. When not None and
        client checkpoints are copied, the number of selected client checkpoints
        must match it.
    copy_client_checkpoints : bool, optional
        A flag indicating whether to copy the client checkpoints. Defaults to True.

    Raises
    ------
    NotImplementedError
        If the backend is not an S3ObjectStore.
    ValueError
        If the old run folder or the new run folder is not found, or if the number of
        client checkpoints found differs from ``n_total_clients``.

    """
    run_uuid, restore_run_uuid, restore_run_round, restore_run_step = restore_run_data
    if not isinstance(remote_up_down.remote_backend, S3ObjectStore):
        msg = "Support for resuming from non-S3 backends is not yet implemented."
        raise NotImplementedError(msg)

    new_run_folder = bucket_uri + f"/{run_uuid}"
    old_run_folder = bucket_uri + f"/{restore_run_uuid}"
    if not validate_given_remote_path(old_run_folder):
        msg = (
            f"Could not find the old run folder {old_run_folder} to copy"
            " checkpoints."
        )
        raise ValueError(msg)
    if not validate_given_remote_path(new_run_folder):
        msg = f"Could not find the new run folder {new_run_folder} to copy checkpoints."
        raise ValueError(msg)

    # Copy sources are bucket-relative keys: strip "<bucket_uri>/" from full URIs
    bucket_prefix = bucket_uri + "/"
    old_round_folder = old_run_folder + f"/server/{restore_run_round}"

    state_bin = (
        restore_run_uuid + f"/server/{restore_run_round}/{CURRENT_SERVER_STATE_FILENAME}"
    )

    parameters_no_ext = old_round_folder + f"/{MODEL_PARAMETERS}"
    parameters_ext = (
        ".bin" if validate_given_remote_path(parameters_no_ext + ".bin") else ".npz"
    )
    parameters = parameters_no_ext.replace(bucket_prefix, "") + parameters_ext

    paths_to_copy = [state_bin, parameters]

    # Both momentum vectors are restored when the server checkpoint is loaded, so
    # copy each one that exists.
    for momentum_type in (FIRST_MOMENTUM, SECOND_MOMENTUM):
        momentum_vec = old_round_folder + f"/{momentum_type}.npz"
        if validate_given_remote_path(momentum_vec):
            paths_to_copy.append(momentum_vec.replace(bucket_prefix, ""))
        else:
            log(DEBUG, "Could not find momentum vector to copy from %s", momentum_vec)

    if copy_client_checkpoints:
        # For each client, choose the latest checkpoint consistent with the step of
        # the resume round. Only computed when needed, so a client lacking such a
        # checkpoint cannot break a server-only copy.
        client_paths = _latest_client_checkpoints(
            list_remote_objects(old_run_folder),
            restore_run_step,
        )
        if (
            n_total_clients is not None
            and (found_clients := len(client_paths)) != n_total_clients
        ):
            msg = (
                f"Found {found_clients} clients in the old run folder {old_run_folder},"
                f" but expected {n_total_clients}."
            )
            raise ValueError(msg)
        paths_to_copy.extend(client_paths)

    for path in paths_to_copy:
        target_key = path.replace(restore_run_uuid, run_uuid)
        log(DEBUG, "Copying %s to %s", path, target_key)
        remote_up_down.remote_backend.client.copy(
            {"Bucket": remote_up_down.remote_backend.bucket, "Key": path},
            remote_up_down.remote_backend.bucket,
            target_key,
        )


def cleanup_checkpoints(
    server_state: ServerState,
    run_uuid: str,
    strategy_state_keys: tuple[str, ...],
    end_idx: int | None = -1,
    *,
    bucket_uri: str = "s3://checkpoints",
) -> None:
    """Clean up checkpoints from the remote object store and local storage.

    Three deletions are performed, all using the same ``end_idx``:

    1. old client checkpoints under ``{bucket_uri}/{run_uuid}`` (remote);
    2. old server rounds under ``{bucket_uri}/{run_uuid}`` (remote);
    3. old server rounds under ``server_state.local_checkpoint_path`` (local).

    A server round only counts as a checkpoint when it contains the server state
    file and every strategy state file. This is useful for periodic cleanup during
    training to avoid excessive storage usage.

    Parameters
    ----------
    server_state : ServerState
        The server state object containing the local checkpoint path.
    run_uuid : str
        The unique identifier for the run whose checkpoints need to be cleaned up.
    strategy_state_keys : tuple[str, ...]
        The keys used to identify strategy state files in the checkpoints.
    end_idx : int | None, optional
        The slice end of checkpoints to delete, oldest first. Defaults to -1, which
        means all checkpoints except for the latest one will be removed.
    bucket_uri : str, optional
        URI of the bucket holding the run folders. Defaults to "s3://checkpoints".

    Notes
    -----
    The actual deletions are delegated to ``delete_clients_checkpoints`` and
    ``delete_rounds``.

    """
    remote_run_path = f"{bucket_uri.rstrip('/')}/{run_uuid}"
    state_keys = (CURRENT_SERVER_STATE_FILENAME, *strategy_state_keys)

    # Remove old client checkpoints from the remote object store
    delete_clients_checkpoints(
        run_uuid_path=remote_run_path,
        end_idx=end_idx,
    )
    # Remove old server rounds from the remote object store
    delete_rounds(
        run_uuid_path=remote_run_path,
        state_keys=state_keys,
        end_idx=end_idx,
    )
    # Remove old server rounds from the local checkpoint folder
    delete_rounds(
        run_uuid_path=server_state.local_checkpoint_path,
        state_keys=state_keys,
        end_idx=end_idx,
    )
