"""Utility functions related to file objects."""

import pickle  # noqa: S403
from pathlib import Path
from typing import Any

import numpy as np
from composer.loggers import RemoteUploaderDownloader
from flwr.common import (
    NDArrays,
    parameters_to_ndarrays,
)

from repo.constants import CURRENT_CODEBASE_NAME, OLD_CODEBASE_NAME


def download_file_from_s3(
    remote_up_down: RemoteUploaderDownloader,
    remote_file_name: str,
    local_file_name: Path | str,
) -> None:
    """Download a remote object from S3 to a local path.

    Any existing file at the destination is overwritten.

    Parameters
    ----------
    remote_up_down : RemoteUploaderDownloader
        The uploader/downloader used to reach the bucket.
    remote_file_name : str
        The name of the remote object to fetch.
    local_file_name : Path | str
        Where to store the downloaded file locally.

    """
    # NOTE(review): presumably verifies the background workers are still
    # healthy before we rely on them -- confirm against Composer.
    remote_up_down._check_workers()  # noqa: SLF001

    # The destination is always handed over as a plain string.
    destination = str(local_file_name)
    remote_up_down.download_file(
        remote_file_name=remote_file_name,
        destination=destination,
        overwrite=True,
    )


def upload_file_to_s3(
    remote_up_down: RemoteUploaderDownloader,
    remote_file_name: str,
    local_file_name: Path | str,
) -> None:
    """Upload a local file to S3.

    Any existing remote object with the same name is overwritten.

    Parameters
    ----------
    remote_up_down : RemoteUploaderDownloader
        The uploader/downloader used to reach the bucket.
    remote_file_name : str
        The name the object will have on the remote side.
    local_file_name : Path | str
        The local file to upload. Strings are accepted for symmetry with
        ``download_file_from_s3`` and are converted to ``Path``.

    """
    # NOTE(review): presumably verifies the background workers are still
    # healthy before we enqueue the upload -- confirm against Composer.
    remote_up_down._check_workers()  # noqa: SLF001
    remote_up_down.upload_file(
        state=None,  # No Composer state is passed by this helper
        remote_file_name=remote_file_name,
        file_path=Path(local_file_name),
        overwrite=True,
    )


def load_model_parameters_from_file(file_path: Path) -> NDArrays:
    """Load model parameters from a file.

    The format is selected from the file suffix:

    - ``.npz`` / ``.npzc``: a NumPy archive; arrays are returned in the order
      they are listed in the archive.
    - ``.bin``: a pickle holding either a raw list of arrays or an object
      accepted by ``parameters_to_ndarrays``.

    Parameters
    ----------
    file_path : Path
        The file path.

    Returns
    -------
    NDArrays
        The model parameters.

    Raises
    ------
    ValueError
        If the file format is not supported.

    """
    suffix = file_path.suffix
    if suffix in {".npz", ".npzc"}:
        # Entering the archive as a context manager releases it
        # deterministically; every array is materialised before it is closed.
        with file_path.open("rb") as file, np.load(file) as archive:
            return [archive[key] for key in archive.files]
    if suffix == ".bin":
        # SECURITY: unpickling can execute arbitrary code. Only load files
        # that come from a trusted source.
        with file_path.open("rb") as file:
            payload = pickle.load(file)  # noqa: S301
        # A pickled plain list of arrays is already in its final form, so it
        # is returned as-is; anything else still needs converting.
        if isinstance(payload, list):
            return payload
        return parameters_to_ndarrays(payload)

    msg = f"Unsupported file format: {file_path.suffix}"
    raise ValueError(msg)


def dump_model_parameters_to_file(file_path: Path, model_parameters: NDArrays) -> None:
    """Dump model parameters to a file.

    The format is selected from the file suffix:

    - ``.npz``: uncompressed NumPy archive, one positional entry per array.
    - ``.npzc``: compressed NumPy archive, one positional entry per array.
    - ``.bin``: pickle of the list of arrays exactly as given.

    Parameters
    ----------
    file_path : Path
        The file path.
    model_parameters : NDArrays
        The model parameters.

    Raises
    ------
    ValueError
        If the file format is not supported.

    """
    extension = file_path.suffix

    # Reject unknown formats up front so that no file is created for them.
    if extension not in {".npzc", ".bin", ".npz"}:
        msg = f"Unsupported file format: {file_path.suffix}"
        raise ValueError(msg)

    with file_path.open("wb") as sink:
        if extension == ".bin":
            pickle.dump(model_parameters, sink)
        elif extension == ".npz":
            np.savez(sink, *model_parameters)
        else:
            # NOTE: Very slow for big models b/c compression. Good benchmark available here: https://stackoverflow.com/questions/30329726/fastest-save-and-load-options-for-a-numpy-array
            np.savez_compressed(sink, *model_parameters)


class RenameUnpickler(pickle.Unpickler):  # noqa: S301
    """Unpickler that maps the old codebase name onto the current one.

    Every module path is rewritten before being resolved: each occurrence of
    ``OLD_CODEBASE_NAME`` is replaced with ``CURRENT_CODEBASE_NAME``.
    """

    def find_class(self, module: str, name: str) -> Any:  # noqa: ANN401
        """Resolve ``module.name`` after applying the codebase rename.

        Parameters
        ----------
        module : str
            The module name as recorded in the pickle stream.
        name : str
            The class name.

        Returns
        -------
        Any
            The class object found in the (possibly renamed) module.

        """
        # ``str.replace`` leaves the string untouched when the old name does
        # not occur, so no separate membership check is required.
        # NOTE(review): this is a plain substring replacement, so it matches
        # anywhere in the module path, not only whole package components.
        target_module = module.replace(OLD_CODEBASE_NAME, CURRENT_CODEBASE_NAME)
        return super().find_class(target_module, name)


def custom_pickle_load(
    file_obj: Any,  # noqa: ANN401
) -> Any:  # noqa: ANN401
    """Unpickle an object from a file object using ``RenameUnpickler``.

    Parameters
    ----------
    file_obj : Any
        The file object to read the pickle stream from.

    Returns
    -------
    Any
        The unpickled object.

    """
    unpickler = RenameUnpickler(file_obj)
    return unpickler.load()


def create_remote_up_down(  # noqa: PLR0913
    bucket_name: str,
    prefix: str,
    run_uuid: str | None,
    num_attempts: int,
    client_config: dict[str, Any],
    *,
    num_concurrent_uploads: int = 1,
    upload_staging_folder: str | None = None,  # Don't touch, it's /tmp by default
    use_procs: bool = True,
) -> RemoteUploaderDownloader:
    """Build and initialise the remote uploader/downloader for an S3 bucket.

    Parameters
    ----------
    bucket_name : str
        The name of the bucket.
    prefix : str
        The prefix inside the bucket.
    run_uuid : str | None
        The UUID of the run, passed as the run name on initialisation.
    num_attempts : int
        The number of attempts.
    client_config : dict[str, Any]
        The configuration of the client.
    num_concurrent_uploads : int, optional
        The number of concurrent uploads, by default 1.
    upload_staging_folder : str | None, optional
        The upload staging folder, don't touch, by default None.
    use_procs : bool, optional
        Whether to use processes, by default True. Don't touch.

    Returns
    -------
    RemoteUploaderDownloader
        The initialised remote uploader/downloader.

    """
    # Settings forwarded to the storage backend; the ``None`` entries are left
    # to be resolved elsewhere, as noted per key.
    backend_settings = {
        "bucket": bucket_name,
        "prefix": prefix,  # Don't touch
        "region_name": None,  # Not necessary
        "endpoint_url": None,  # Will be read from env var
        "aws_access_key_id": None,  # Will be read from config file
        "aws_secret_access_key": None,  # Will be read from config file
        "aws_session_token": None,  # Will be automatically generated
        "client_config": client_config,  # And using defaults
        "transfer_config": None,  # Using defaults
    }

    uploader_downloader = RemoteUploaderDownloader(
        bucket_uri=f"s3://{bucket_name}",
        backend_kwargs=backend_settings,
        file_path_format_string="{remote_file_name}",  # Don't touch
        num_concurrent_uploads=num_concurrent_uploads,
        upload_staging_folder=upload_staging_folder,  # Don't touch, default: /tmp
        use_procs=use_procs,  # Don't touch
        num_attempts=num_attempts,
    )
    uploader_downloader.init(run_name=run_uuid)  # Don't touch
    return uploader_downloader
