"""Implementation of the Flower ServerApp for orchestrating federated learning."""

import contextlib
import os
import time
import timeit
import warnings
from functools import partial
from logging import DEBUG, INFO
from queue import Queue
from typing import TYPE_CHECKING, Any, cast

import flwr as fl
from flwr.common import (
    Context,
    MessageType,
)
from flwr.common.logger import log, update_console_handler
from flwr.server import Driver
from omegaconf import OmegaConf

import ray
import wandb
from repo.clients.configs import (
    get_repo_evaluate_config_fn,
    get_repo_fit_config_fn,
)
from repo.masks_utils import generate_mask
from repo.server.broadcast_utils import broadcast_parameters_to_nodes
from repo.server.evaluate_utils import evaluate_round
from repo.server.gather_utils import gather_round
from repo.server.init_utils import (
    initialize_server_state,
)
from repo.server.load_balancer import ServerLoadBalancer
from repo.server.s3_utils import (
    cleanup_checkpoints,
    upload_server_checkpoint,
)
from repo.server.utils import (
    message_collaborative,
    wait_for_nodes_to_connect,
)
from repo.utils import (
    custom_ray_garbage_collector,
    wandb_init,
)

if TYPE_CHECKING:
    from repo.conf.base_schema import BaseConfig

# Configure Flower's console logger: DEBUG level, no colors, with timestamps.
update_console_handler(level=DEBUG, colored=False, timestamps=True)

# Silence known-noisy warnings.
#
# The filters are inserted at the FRONT of ``warnings.filters`` (the default,
# ``append=False``). They used to be appended, which placed them behind every
# filter already installed (e.g. a ``default``/``always`` action coming from
# ``-W``, ``PYTHONWARNINGS`` or another library). Because the first matching
# filter wins, such an earlier filter shadowed these and the warnings were
# still printed.
#
# ``message`` is a regular expression matched against the *start* of the
# warning text, so a plain prefix is enough (a trailing ``*`` is a regex
# quantifier on the preceding character, not a glob wildcard).
#
# NOTE: warnings raised while the imports above were executing were emitted
# before these lines ran and cannot be suppressed from here; those need
# ``PYTHONWARNINGS`` or a filter installed before the imports.

# Filter the user warning emitted by the MPT configuration.
warnings.filterwarnings(
    action="ignore",
    category=UserWarning,
    message=r"If not using a Prefix Language Model",
)
# Filter deprecation warnings from pkg_resources.
warnings.filterwarnings(
    action="ignore",
    category=DeprecationWarning,
    message=r"Deprecated call to ",
)
warnings.filterwarnings(
    action="ignore",
    category=DeprecationWarning,
    message=r"pkg_resources is deprecated",
)

# Run via `flower-server-app server:app`
app = fl.server.ServerApp()


def _load_config() -> tuple["BaseConfig", str]:
    """Locate the run directory from the environment and load its config.

    Returns
    -------
    tuple[BaseConfig, str]
        The configuration loaded from ``<repo_SAVE_PATH>/config.yaml`` and the
        value of ``repo_SAVE_PATH`` itself.

    Raises
    ------
    ValueError
        If the environmental variable `repo_SAVE_PATH` is unset or empty.

    """
    repo_save_path = os.environ.get("repo_SAVE_PATH", "")
    if not repo_save_path:
        msg = "repo_SAVE_PATH is not set."
        raise ValueError(msg)
    cfg = cast(
        "BaseConfig",
        OmegaConf.load(os.path.join(repo_save_path, "config.yaml")),
    )
    return cfg, repo_save_path


def _make_ray_context(
    cfg: "BaseConfig",
    server_state: Any,
) -> contextlib.AbstractContextManager[None]:
    """Connect to Ray and build the garbage-collector context when enabled.

    When ``cfg.repo.comm_stack.ray`` is falsy nothing is initialized and a
    no-op context is returned. Otherwise a fresh queue is attached to
    ``server_state.ray_garbage_queue``, ``ray.init("auto")`` is called, and
    the (not yet entered) context from ``custom_ray_garbage_collector`` is
    returned so the caller can enter it alongside the W&B run.
    """
    if not cfg.repo.comm_stack.ray:
        return contextlib.nullcontext()
    # Queue consumed by the custom Ray garbage collector
    server_state.ray_garbage_queue = Queue()
    ray.init("auto")
    return custom_ray_garbage_collector(
        garbage_queue=server_state.ray_garbage_queue,
        list_of_threads=[],
        process_name="server",
    )


def _record_duration(server_state: Any, metric_name: str, start_ns: int) -> None:
    """Store the seconds elapsed since ``start_ns`` as a centralized metric.

    ``start_ns`` must come from ``time.perf_counter_ns()``: a monotonic clock,
    so the duration is not distorted by wall-clock adjustments.
    """
    server_state.history.add_metrics_centralized(
        server_round=server_state.current_round,
        metrics={metric_name: (time.perf_counter_ns() - start_ns) * 1e-9},
    )


def _run_round(
    current_federated_round: int,
    *,
    cfg: "BaseConfig",
    server_state: Any,
    server_load_balancer: ServerLoadBalancer,
    start_time: float,
) -> None:
    """Run one federated round: sample, broadcast, evaluate, fit, checkpoint.

    Parameters
    ----------
    current_federated_round : int
        The round being executed; stored in ``server_state.current_round``.
    cfg : BaseConfig
        The run configuration.
    server_state
        Mutable server state returned by ``initialize_server_state``.
    server_load_balancer : ServerLoadBalancer
        Load balancer used to reach the connected nodes.
    start_time : float
        ``timeit.default_timer()`` value taken when FL started; used to compute
        ``server_state.current_time_elapsed``.

    Raises
    ------
    ValueError
        If checkpointing or S3 is enabled but
        ``server_state.remote_up_down`` is ``None``.

    """
    round_start_ns = time.perf_counter_ns()
    server_state.current_round = current_federated_round
    log(DEBUG, "Commencing server round %s", server_state.current_round)

    # Sampling clients for the current round
    server_state.previously_sampled_clients = server_state.sampled_clients
    server_state.sampled_clients = server_state.client_sampler.sample_clients(
        rng=server_state.rng,
    )

    # Broadcast the model parameters to all NodeManagers; the timed span
    # includes the health check of the connected NodeManagers.
    broadcast_start_ns = time.perf_counter_ns()
    server_load_balancer.check_nodes_health()
    broadcast_parameters_to_nodes(
        server_load_balancer=server_load_balancer,
        server_state=server_state,
        comm_stack=cfg.repo.comm_stack,
    )
    _record_duration(server_state, "server/broadcast_time", broadcast_start_ns)

    # Evaluate the model every `eval_period` rounds (disabled when None)
    if (
        cfg.fl.eval_period is not None
        and server_state.current_round % cfg.fl.eval_period == 0
    ):
        evaluate_round(
            server_load_balancer=server_load_balancer,
            evaluate_config_fn=get_repo_evaluate_config_fn(cfg),
            cfg=cfg,
            server_state=server_state,
        )

    # Fit the model and aggregate the results. The fit-config function is
    # rebuilt each round so the mask partial binds the current server state.
    fit_start_ns = time.perf_counter_ns()
    fit_config_fn = get_repo_fit_config_fn(
        cfg,
        aggregation_mask=partial(
            generate_mask,
            server_state.layer_names_and_types,
            server_state.aggregation_mask_scheduler,
        ),
    )
    gather_round(
        server_state=server_state,
        server_load_balancer=server_load_balancer,
        cfg=cfg,
        fit_round_results=message_collaborative(
            server_load_balancer=server_load_balancer,
            message_constants=(MessageType.TRAIN, "fitins"),
            server_state=server_state,
            gen_ins_function=fit_config_fn,
        ),
    )
    _record_duration(server_state, "server/fit_and_gather_round_time", fit_start_ns)
    log(DEBUG, "Fit and gather round %s completed", server_state.current_round)

    # Save the checkpoint to S3 Object Store
    server_state.current_time_elapsed = (
        timeit.default_timer() - start_time + server_state.time_offset
    )
    if cfg.repo.checkpoint or cfg.repo.comm_stack.s3:
        if server_state.remote_up_down is None:
            msg = "Cannot checkpoint without a RemoteUploaderDownloader object"
            raise ValueError(msg)
        upload_server_checkpoint(
            server_state=server_state,
        )

    # Log the time taken for the round (excludes the cleanup below)
    _record_duration(server_state, "server/round_time", round_start_ns)

    # Clean up checkpoints if asked to
    if cfg.repo.cleanup_checkpoints_per_round:
        cleanup_checkpoints(
            run_uuid=cfg.run_uuid,
            strategy_state_keys=server_state.strategy.state_keys,
            end_idx=-1,
            server_state=server_state,
        )


def _log_history(server_state: Any) -> None:
    """Log every series tracked in ``server_state.history`` at DEBUG level."""
    history = server_state.history
    for series_name in (
        "losses_distributed",
        "metrics_distributed_fit",
        "metrics_distributed",
        "losses_centralized",
        "metrics_centralized",
    ):
        log(DEBUG, "app_fit: %s %s", series_name, str(getattr(history, series_name)))


@app.main()
def main(
    driver: Driver,
    context: Context,  # noqa: ARG001
) -> None:
    """Implement the main function for the Flower ServerApp.

    Loads the configuration, initializes the server state (and Ray, when
    enabled), waits for the NodeManagers to connect, then runs the federated
    rounds from ``server_state.start_round + 1`` up to ``cfg.fl.n_rounds``.

    Parameters
    ----------
    driver : fl.server.Driver
        The driver object for the server.
    context : fl.common.Context
        The context object for the server (unused).

    Raises
    ------
    ValueError
        If the environmental variable `repo_SAVE_PATH` is not set, or if
        checkpointing/S3 is enabled without a RemoteUploaderDownloader object.

    """
    start_up_time = timeit.default_timer()
    cfg, repo_save_path = _load_config()
    log(INFO, "Initializing repo Server")
    server_state = initialize_server_state(cfg, repo_save_path)
    server_state.local_checkpoint_path = repo_save_path
    custom_ray_context = _make_ray_context(cfg, server_state)

    # Resolve WandB config
    wandb_config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    with (
        wandb_init(  # type: ignore[union-attr,misc]
            cfg.use_wandb,
            **cfg.wandb.setup,  # type: ignore[reportCallIssue]
            settings=wandb.Settings(start_method="thread"),  # type: ignore[arg-type]
            config=wandb_config,  # type: ignore[arg-type]
        ),
        custom_ray_context,
    ):
        # Wait for the minimum number of nodes to connect
        wait_for_nodes_to_connect(driver, cfg.repo.n_nodes)
        log(
            INFO,
            "Start-up time for the server is %s",
            timeit.default_timer() - start_up_time,
        )

        log(INFO, "FL starting from round %s", server_state.start_round)
        start_time = timeit.default_timer()
        server_load_balancer = ServerLoadBalancer(
            driver=driver,
            total_number_of_clients=cfg.fl.n_total_clients,
            rng=server_state.rng,
            is_production=cfg.repo.is_production,
        )
        for current_federated_round in range(
            server_state.start_round + 1,
            cfg.fl.n_rounds + 1,
        ):
            _run_round(
                current_federated_round,
                cfg=cfg,
                server_state=server_state,
                server_load_balancer=server_load_balancer,
                start_time=start_time,
            )

        # Bookkeeping
        log(
            INFO,
            "FL finished in %s",
            timeit.default_timer() - start_time + server_state.time_offset,
        )
        _log_history(server_state)

        # Clean up checkpoints if asked to
        if cfg.repo.cleanup_checkpoints:
            cleanup_checkpoints(
                run_uuid=cfg.run_uuid,
                strategy_state_keys=server_state.strategy.state_keys,
                end_idx=None,
                server_state=server_state,
            )
