"""Utilities for initializing the server state of the main Flower Next server loop."""

import copy
import operator
import os
import random
import re
from logging import DEBUG, INFO

import numpy as np
from composer.loggers import RemoteUploaderDownloader
from composer.utils.file_helpers import list_remote_objects
from flwr.common import (
    NDArrays,
    log,
)
from omegaconf import OmegaConf

from repo.clients.configs import CentralizedConfig
from repo.clients.llm_client_functions import (
    get_trainer_object,
)
from repo.clients.utils import (
    get_initial_parameters,
    get_optimizer_state_from_trainer,
)
from repo.conf.base_schema import BaseConfig, StrategyName
from repo.constants import CURRENT_SERVER_STATE_FILENAME
from repo.file_utils import create_remote_up_down
from repo.server.client_sampler import ClientSampler
from repo.server.param_scheduler_dispatcher import dispatch_model_state_scheduler
from repo.server.s3_utils import (
    copy_old_checkpoints_to_new_run,
    get_server_checkpoint,
    interpret_resume_round,
    upload_server_checkpoint,
)
from repo.server.utils import (
    ServerState,
)
from repo.strategy.dispatcher import dispatch_strategy
from repo.strategy.utils import initialize_strategy
from repo.utils import (
    ClientState,
    ModelStateNames,
    get_list_of_parameters_names,
    get_parameters_from_state,
)
from repo.wandb_history import WandbHistory


def initialize_server_state(
    cfg: BaseConfig,
    repo_save_path: str,
) -> ServerState:
    """Initialize the server state for federated learning.

    This function prepares the server state by either resuming from a past checkpoint
    or creating a fresh state. It creates the RemoteUploaderDownloader used for
    checkpointing / S3 operations (only when either is enabled) and advances the
    client sampler to the starting round. It also optionally replaces the model
    states with those of a centralized run and initializes the federated learning
    strategy. Finally, it uploads the resulting server checkpoint when checkpointing
    or S3 communication is enabled.

    Args:
    ----
    cfg : BaseConfig
        The configuration object containing settings for federated learning and system
        behavior, including checkpoint and S3 communication parameters.
    repo_save_path : str
        The local path the server uses to store and load checkpoints.

    Returns:
    -------
    ServerState
        The server state object, either resumed from a checkpoint or newly initialized,
        with the configured strategy.

    Raises:
    ------
    ValueError
        If resuming while `cfg.repo.checkpoint` is None, or if a
        RemoteUploaderDownloader is required (to resume, to restore from a
        centralized run, or to upload the server checkpoint) but not available.

    """
    # Create the RemoteUploaderDownloader only when remote storage is actually used
    remote_up_down: RemoteUploaderDownloader | None = None
    if cfg.repo.checkpoint or cfg.repo.comm_stack.s3:
        remote_up_down = create_remote_up_down(
            bucket_name=cfg.s3_comm_config.bucket_name,
            prefix=f"{cfg.run_uuid}/server",
            run_uuid=cfg.run_uuid,
            num_attempts=cfg.s3_comm_config.num_attempts,
            client_config=OmegaConf.to_container(
                cfg.s3_comm_config.backend_kwargs.client_config,
            ),  # type: ignore[reportArgumentType, arg-type]
        )
    # Resume experiment from a previously saved checkpoint
    if cfg.repo.resume_round is not None:
        # NOTE(review): the type checker reports this comparison as unnecessary (see
        # the ignore), i.e. `checkpoint` is typed as never None. If it is a bool,
        # this guard never fires; confirm whether `not cfg.repo.checkpoint` was meant.
        if cfg.repo.checkpoint is None:  # type: ignore[reportUnnecessaryComparison]
            msg = "Cannot resume if `cfg.repo.checkpoint` is None"
            raise ValueError(msg)
        if remote_up_down is None:
            msg = "Cannot resume without a RemoteUploaderDownloader object"
            raise ValueError(msg)
        server_state = resume_from_round(
            cfg=cfg,
            remote_up_down=remote_up_down,
            repo_save_path=repo_save_path,
        )
    else:
        server_state = initialize_federated_learning(
            cfg=cfg,
            remote_up_down=remote_up_down,
            repo_save_path=repo_save_path,
        )
        log(
            DEBUG,
            "Sampled %s Client IDs: %s",
            len(server_state.sampled_clients),
            server_state.sampled_clients,
        )
    # Advance the client sampler `start_round` times so that the PRNG (and the
    # current/previous sampled clients) match the round being started
    for _ in range(server_state.start_round):
        server_state.previously_sampled_clients = server_state.sampled_clients
        server_state.sampled_clients = server_state.client_sampler.sample_clients(
            rng=server_state.rng,
        )
    # Optionally overwrite the model states with those of a centralized run
    if cfg.repo.restore_cent_run_uuid is not None:
        if remote_up_down is None:
            msg = "Cannot restore without a RemoteUploaderDownloader object"
            raise ValueError(msg)
        log(
            INFO,
            "Restoring from a centralized checkpoint with run UUID %s",
            cfg.repo.restore_cent_run_uuid,
        )
        server_state.model_states, server_state.layer_names_and_types = (
            get_centralized_run_parameters(copy.deepcopy(cfg))
        )

    # TODO(Anonymous): Make this useful
    # # Exclude from server's model states and layer names and types those types which
    # # are not going to be scheduled for synchronization
    # server_state.model_states, server_state.layer_names_and_types = (
    #     reconcile_model_state_with_scheduler(
    #         server_state.model_states, server_state.layer_names_and_types, cfg))

    # Initialize the strategy
    initialize_strategy(server_state)

    # Save the checkpoint to S3 Object Store (w/ model parameters)
    server_state.current_round = server_state.start_round
    server_state.current_time_elapsed = server_state.time_offset
    if cfg.repo.checkpoint or cfg.repo.comm_stack.s3:
        # Explicit raise instead of `assert`: assertions are stripped under `-O`,
        # which would let a None uploader reach `upload_server_checkpoint`.
        if server_state.remote_up_down is None:
            msg = "Cannot checkpoint without a RemoteUploaderDownloader object"
            raise ValueError(msg)
        upload_server_checkpoint(
            server_state=server_state,
        )

    return server_state


def get_centralized_run_parameters(  # noqa: PLR0914
    dummy_config: BaseConfig,
) -> tuple[NDArrays, tuple[tuple[str, ModelStateNames], ...]]:
    """Retrieve model parameters from a previous centralized training run.

    This function lists the objects stored under
    ``s3://checkpoints/<restore_cent_run_uuid>``. It selects the checkpoint whose
    batch counter equals ``dummy_config.repo.restore_cent_run_batches`` and loads it
    into a temporary trainer built without data loading. It then extracts the
    trainable parameters together with the optimizer states requested for the
    ``EXP_AVG`` / ``EXP_AVG_SQ`` entries, and returns them in the format needed for
    federated learning initialization.

    The configuration is deep-copied before its LLM section is modified, so the
    caller's object is left untouched. As a side effect, the process environment
    variable ``APPOINTED_CUDA_DEVICE`` is set to the string ``"None"``.

    Parameters
    ----------
    dummy_config : BaseConfig
        The configuration object containing information about the centralized run to
        restore, including the run UUID and desired number of batches.

    Returns
    -------
    tuple[NDArrays, tuple[tuple[str, ModelStateNames], ...]]
        A tuple containing:
        - model_states: The model parameters followed by the optimizer states
        - layer_names_and_types: A tuple mapping parameter names to their state types

    Raises
    ------
    ValueError
        If no checkpoint with the desired number of batches can be found in S3.

    """
    # Private copy: the LLM sub-config is mutated below
    dummy_config = copy.deepcopy(dummy_config)
    desired_steps = dummy_config.repo.restore_cent_run_batches
    folder = f"s3://checkpoints/{dummy_config.repo.restore_cent_run_uuid}"
    remote_objects = list_remote_objects(folder)
    # Lazy %-style arguments, consistent with the other `log` calls in this module
    log(DEBUG, "Restoring from centralized run, found %s", remote_objects)
    # Checkpoint object names contain ``/ep<epoch>-ba<batches>``
    sorted_pairs = sorted(
        [
            (
                int(reg.group(1)),  # epoch number
                int(reg.group(2)),  # number of batches
            )
            for path in remote_objects
            if (reg := re.search(r"/ep(\d+)-ba(\d+)", path)) is not None
        ],
        key=operator.itemgetter(1),
    )
    # Only an exact batch-count match is accepted
    path_to_check = next(
        (
            (epoch, batches)
            for epoch, batches in sorted_pairs
            if batches == desired_steps
        ),
        None,
    )
    if path_to_check is None:
        msg = f"Could not find a checkpoint with {desired_steps} batches"
        raise ValueError(msg)
    epoch, batches = path_to_check

    dummy_config_llm = dummy_config.llm_config
    # ``{rank}`` is deliberately a literal placeholder (not part of the f-string);
    # presumably the checkpoint loader substitutes it per rank.
    dummy_config_llm.load_path = folder + f"/ep{epoch}-ba{batches}-" + "rank{rank}.pt"
    # Checkpoint keys matching these patterns are not loaded.
    # NOTE(review): "*optim*" excludes the saved optimizer state, yet optimizer
    # states are extracted from the trainer below; confirm that restarting the
    # optimizer moments (rather than restoring them) is intended.
    dummy_config_llm.load_ignore_keys = [
        "*scheduler*",
        "*optim*",
        "*dataset_state*",
    ]
    # Process-wide side effect: sets the literal string "None"
    os.environ["APPOINTED_CUDA_DEVICE"] = str(None)
    dummy_config_llm.save_folder = None
    dummy_config_llm.device_train_microbatch_size = 1
    # Creating ClientConfig object
    client_config = CentralizedConfig(
        allow_unigram_metrics_failures=dummy_config_llm.fl.allow_unigram_metrics_failures,
        resize_vocab=dummy_config_llm.fl.resize_vocab,
        split_eval=dummy_config_llm.centralized.split_eval,
        set_trainer_params_filter_keys=dummy_config_llm.fl.set_trainer_params_filter_keys,
        set_trainer_key_to_filter=dummy_config_llm.fl.set_trainer_key_to_filter,
        use_unigram_metrics=dummy_config_llm.fl.use_unigram_metrics,
        s3_comm_config=dummy_config_llm.s3_comm_config,
        cid=None,
    )
    # Get the temporary trainer object
    trainer, *_ = get_trainer_object(
        dummy_config_llm,
        client_config=client_config,
        no_data_loading=True,
    )
    # Obtain the list of trainable parameters names
    param_names = get_list_of_parameters_names(
        model=trainer.state.model,
        only_requires_grad=True,
        # NOTE: We need to sort the name to ensure that the order is consistent with
        # the parameters obtained earlier
        sort_list=True,
    )
    # Drop the first "model." occurrence from each name (e.g. a leftover wrapper
    # prefix)
    param_names = [param_name.replace("model.", "", 1) for param_name in param_names]
    # Every trainable parameter appears once per state type, grouped by state type
    # (all PARAMETERS first, then EXP_AVG, then EXP_AVG_SQ)
    layer_names_and_types: list[tuple[str, ModelStateNames]] = [
        (param_name, state_type)
        for state_type in (
            ModelStateNames.PARAMETERS,
            ModelStateNames.EXP_AVG,
            ModelStateNames.EXP_AVG_SQ,
        )
        for param_name in param_names
    ]
    # Get lists of layer names and types to retrieve model parameters and, if required,
    # optimizer states
    layer_names = [name for name, _ in layer_names_and_types]
    layer_types = [t for _, t in layer_names_and_types]
    model_states = get_parameters_from_state(
        {},
        trainer,
        parameter_names=(
            name
            for name, layer_type in zip(layer_names, layer_types, strict=True)
            if layer_type == ModelStateNames.PARAMETERS
        ),
    )
    opt_states = get_optimizer_state_from_trainer(
        trainer=trainer,
        layer_names=layer_names,
        layer_types=layer_types,
    )
    # Append the optimizer states after the parameters, in the helper's order
    for acc in opt_states.values():
        model_states.extend(acc)

    return model_states, tuple(layer_names_and_types)


def initialize_federated_learning(
    cfg: BaseConfig,
    remote_up_down: RemoteUploaderDownloader | None,
    repo_save_path: str,
) -> ServerState:
    """Initialize the state for a federated learning process.

    This function sets up the initial state for a new federated learning process:
    bookkeeping variables, client states, global model parameters, and optionally
    saving the initial checkpoint to an S3 Object Store if configured. It prepares the
    server for the federated learning process by setting the starting round, time
    offset, cumulative server steps, and initializing the momentum vector for
    optimization algorithms.

    Args:
    ----
    cfg : BaseConfig
        The configuration object containing settings for federated learning and system
        behavior.
    remote_up_down : RemoteUploaderDownloader | None
        An optional uploader/downloader object for interacting with remote storage,
        required if checkpointing or S3 communication is enabled.
    repo_save_path : str
        The local path the server uses to store and load checkpoints.

    Returns:
    -------
    ServerState
        The initial server state object.

    """
    # Instantiate a PRNG
    rng = random.Random(cfg.seed)  # noqa: S311
    # Initialize the bookkeeping variables
    start_round: int = 0
    time_offset: float = 0.0
    server_steps_cumulative: int = 0
    history = WandbHistory(use_wandb=cfg.use_wandb)
    # Initialize client_state_dict
    client_states: dict[str | int, ClientState] = {
        cid: ClientState(0) for cid in range(cfg.fl.n_total_clients)
    }
    # Initialize the model parameters and layer names mapping
    initial_model_states, layer_names_and_types = get_initial_parameters(cfg)

    momentum_vector: NDArrays = []
    second_momentum_vector: NDArrays = []
    # NOTE: We should unify the state with the strategy
    # Only focusing on saving space right now
    match cfg.fl.strategy_name.lower():
        case StrategyName.FEDMOM | StrategyName.NESTOROV:
            momentum_vector = [np.zeros_like(x) for x in initial_model_states]
        case StrategyName.FEDADAM | StrategyName.FEDYOGI:
            momentum_vector = [np.zeros_like(x) for x in initial_model_states]
            second_momentum_vector = [np.zeros_like(x) for x in initial_model_states]

    return ServerState(
        model_states=initial_model_states,
        layer_names_and_types=layer_names_and_types,
        transmission_scheduler=dispatch_model_state_scheduler(cfg),
        strategy=dispatch_strategy(
            cfg,
        ),
        aggregation_mask_scheduler=dispatch_model_state_scheduler(cfg),
        history=history,
        time_offset=time_offset,
        server_steps_cumulative=server_steps_cumulative,
        client_states=client_states,
        momentum_vector=momentum_vector,
        second_momentum_vector=second_momentum_vector,
        sampled_clients=[],
        previously_sampled_clients=[],
        rng=rng,
        start_round=start_round,
        current_round=start_round,
        current_time_elapsed=time_offset,
        remote_up_down=remote_up_down,
        local_checkpoint_path=repo_save_path,
        client_sampler=ClientSampler(
            total_number_of_clients=cfg.fl.n_total_clients,
            number_of_clients_per_round=cfg.fl.n_clients_per_round,
            dropout_ratio=cfg.fl.dropout_ratio,
            dropout_function_name=cfg.fl.dropout_function_name,
        ),
    )


def resume_from_round(
    cfg: BaseConfig,
    remote_up_down: RemoteUploaderDownloader,
    repo_save_path: str,
) -> ServerState:
    """Resume from a previously checkpointed round, or start fresh if none exists.

    ``cfg.repo.resume_round`` is resolved through ``interpret_resume_round``, and
    the resolved value is written back into ``cfg``. A resolved value of -1 means
    there is nothing to resume, so a fresh state is built with
    ``initialize_federated_learning``. Otherwise, when
    ``cfg.repo.restore_run_uuid`` is set, that run's checkpoints are first copied
    into this run. The server state is then loaded with ``get_server_checkpoint``.

    Parameters
    ----------
    cfg : BaseConfig
        The configuration object (``cfg.repo.resume_round`` is updated in place).
    remote_up_down : RemoteUploaderDownloader
        The object to upload/download files from/to the S3 Object Store.
    repo_save_path : str
        The local path the server uses to store and load checkpoints.

    Returns
    -------
    ServerState
        The restored server state, or a freshly initialized one when there is no
        checkpoint to resume.

    """
    strategy = dispatch_strategy(cfg)
    # State keys handed to `interpret_resume_round`: the server-state file plus the
    # strategy's own state keys
    required_keys = (CURRENT_SERVER_STATE_FILENAME, *strategy.state_keys)
    requested_round = cfg.repo.resume_round
    cfg.repo.resume_round = interpret_resume_round(
        resume_round=requested_round,
        run_uuid_path=f"s3://{cfg.s3_comm_config.bucket_name}/{cfg.run_uuid}/",
        # NOTE: A request of exactly -1 passes `raise_error=False`, so when no
        # checkpoint is found no error is raised and federated learning starts from
        # scratch. Any other requested value passes `raise_error=True`.
        raise_error=requested_round != -1,
        state_keys=required_keys,
    )
    assert cfg.repo.resume_round is not None, (
        "Cannot resume run if `cfg.repo.resume_round` is None"
    )
    # NOTE: When a checkpoint for some round R is found, `interpret_resume_round`
    # replaces the request with R; a value still equal to -1 therefore means that no
    # checkpoint was found.
    if cfg.repo.resume_round != -1:
        log(DEBUG, "Resume round %s", cfg.repo.resume_round)
        source_run_uuid = cfg.repo.restore_run_uuid
        # Import another experiment's checkpoints for restoration
        if source_run_uuid is not None:
            log(
                DEBUG,
                "Copy past checkpoints from run with UUID %s",
                source_run_uuid,
            )
            resumed_round = cfg.repo.resume_round
            copy_old_checkpoints_to_new_run(
                remote_up_down=remote_up_down,
                bucket_uri=f"s3://{cfg.s3_comm_config.bucket_name}",
                restore_run_data=(
                    cfg.run_uuid,
                    source_run_uuid,
                    resumed_round,
                    resumed_round * cfg.fl.n_local_steps,
                ),
                n_total_clients=cfg.fl.n_total_clients,
                copy_client_checkpoints=cfg.repo.copy_client_checkpoints,
            )
        log(INFO, "A checkpoint was found. Resuming from it.")
        return get_server_checkpoint(
            cfg=cfg,
            remote_up_down=remote_up_down,
            repo_save_path=repo_save_path,
        )
    log(INFO, "There is no checkpoint to resume. Starting from scratch.")
    return initialize_federated_learning(
        cfg=cfg,
        remote_up_down=remote_up_down,
        repo_save_path=repo_save_path,
    )
