"""Utility functions for evaluate tasks on main server loop in flwr next."""

import time
from collections.abc import Callable, Generator
from logging import DEBUG, ERROR
from typing import cast

from flwr.common import (
    Code,
    EvaluateRes,
    Message,
    MessageType,
    Scalar,
    Status,
    log,
)
from flwr.common.recordset_compat import (
    recordset_to_evaluateres,
)
from flwr.common.typing import ConfigsRecordValues

from repo.conf.base_schema import BaseConfig
from repo.server.load_balancer import ServerLoadBalancer
from repo.server.utils import (
    ServerState,
    TooManyFailuresError,
    message_collaborative,
)


def handle_evaluate_replies(
    cfg: BaseConfig,
    replies: Generator[Message, None, None],
    server_state: ServerState,
) -> (
    tuple[
        float | None,
        dict[str, Scalar],
        tuple[list[tuple[None, EvaluateRes | None]], list[EvaluateRes | None]],
    ]
    | None
):
    """Consume client evaluation replies and aggregate them with the strategy.

    Each reply is converted to an ``EvaluateRes`` (or ``None`` when the message
    carries no content), classified as a success when its status code is
    ``Code.OK`` and as a failure otherwise, and passed through the handler built
    by ``get_handle_success_and_failure_evaluate``. Successful results are then
    handed to ``server_state.strategy.aggregate_evaluate`` together with the
    collected failures.

    Parameters
    ----------
    cfg : BaseConfig
        Run configuration; only ``cfg.fl.accept_failures_cnt`` is read here.
    replies : Generator[Message, None, None]
        Stream of reply messages. It is consumed exactly once, one message at a
        time.
    server_state : ServerState
        Provides the aggregation strategy and the current round number.

    Returns
    -------
    tuple[
        float | None,
        dict[str, Scalar],
        tuple[list[tuple[None, EvaluateRes | None]], list[EvaluateRes | None]]
    ] | None
        ``(loss, metrics, (completed_results, failures))`` where ``loss`` and
        ``metrics`` come from the strategy, ``completed_results`` is a list of
        ``(None, EvaluateRes)`` pairs and ``failures`` holds the failed
        ``EvaluateRes`` objects (or ``None`` for replies without content).
        The annotation allows ``None``, but this body always returns a tuple.

    Raises
    ------
    TooManyFailuresError
        Propagated from the failure handler when it rejects a failed reply.

    """
    # Both lists are filled in place by the handler while replies are consumed.
    # The per-client metrics are accumulated but not read by this function.
    failed: list[EvaluateRes | None] = []
    per_client_metrics: list[tuple[dict[str, Scalar], Status, int]] = []
    record_outcome = get_handle_success_and_failure_evaluate(
        per_client_metrics,
        failed,
        accept_failures_cnt=cfg.fl.accept_failures_cnt,
    )

    completed_results = []
    for reply in replies:
        # A reply without content cannot be decoded and counts as a failure.
        eval_res = (
            recordset_to_evaluateres(reply.content) if reply.has_content() else None
        )
        is_ok = eval_res is not None and eval_res.status.code == Code.OK
        succeeded, handled = record_outcome(
            (is_ok, eval_res),  # type: ignore[arg-type]
        )
        if succeeded:
            # The first slot is always None in the pairs given to the strategy.
            completed_results.append((None, handled))

    aggregate: tuple[
        float | None,
        dict[str, Scalar],
    ] = server_state.strategy.aggregate_evaluate(
        server_state.current_round,
        completed_results,  # type: ignore[reportArgumentType,arg-type]
        failed,  # type: ignore[reportArgumentType,arg-type]
    )

    if failed:
        log(
            ERROR,
            "evaluate_round %s: there are %s failures: %s",
            server_state.current_round,
            len(failed),
            failed,
        )
    log(
        DEBUG,
        "evaluate_round %s received %s results and %s failures",
        server_state.current_round,
        len(completed_results),
        len(failed),
    )

    loss_aggregated, metrics_aggregated = aggregate
    return loss_aggregated, metrics_aggregated, (completed_results, failed)


def get_handle_success_and_failure_evaluate(
    metrics_accumulator: list[tuple[dict[str, Scalar], Status, int]],
    evaluate_failures: list[EvaluateRes | None],
    accept_failures_cnt: int | None,
) -> Callable[
    [tuple[bool, EvaluateRes] | tuple[bool, None]],
    tuple[bool, EvaluateRes] | tuple[bool, None],
]:
    """Closure to generate a function which handles client success and failure.

    The function distinguishes between intentional and unintentional failures.
    It enforces constraints on the number of unintentional failures.
    It stores results and failures in the respective lists.

    Parameters
    ----------
    metrics_accumulator : List[Tuple[ClientProxy, Dict[str, Scalar], Status, int]]
        The list where the metrics are accumulated.
    evaluate_failures : List[Union[EvaluateRes, BaseException]]
        The list where the failures are accumulated.
    accept_failures_cnt : int | None
        The maximum number of unintentional failures to accept.

    Returns
    -------
    handle_success_and_failure : Callable[
        [
            Tuple[bool, EvaluateRes]
            | Tuple[bool, EvaluateRes | BaseException]
        ],
        Tuple[bool, EvaluateRes]
        | Tuple[bool, EvaluateRes | BaseException],
    ]
        The function which handles client success and failure while saving the outputs.

    """

    def handle_success_and_failure_evaluate(
        result: tuple[bool, EvaluateRes] | tuple[bool, None],
    ) -> tuple[bool, EvaluateRes] | tuple[bool, None]:
        cnt_failures = 0

        match result:
            case (True, res):
                evaluate_res = cast("EvaluateRes", res)
                metrics_accumulator.append(
                    (
                        evaluate_res.metrics,
                        evaluate_res.status,
                        evaluate_res.num_examples,
                    ),
                )
                return (True, evaluate_res)
            case (False, res):
                cnt_failures += 1
                if (
                    accept_failures_cnt is not None
                    and cnt_failures > accept_failures_cnt
                ):
                    msg = f"""Unintentional failures passed
                        the maximum: {accept_failures_cnt}"""
                    raise TooManyFailuresError(
                        msg,
                    )
                evaluate_failures.append(res)  # type: ignore[arg-type]
                return (False, res) if res is not None else (False, None)  # type: ignore[return-value]
        return result

    return handle_success_and_failure_evaluate


def evaluate_round(
    server_load_balancer: ServerLoadBalancer,
    evaluate_config_fn: Callable[[int, int | str], dict[str, ConfigsRecordValues]],
    server_state: ServerState,
    cfg: BaseConfig,
) -> None:
    """Run one federated evaluation round and record its outcome in the history.

    Evaluation messages are issued through ``message_collaborative``, the replies
    are processed by ``handle_evaluate_replies`` and, when an aggregated loss is
    available, the loss and the aggregated metrics are stored as distributed
    results for the current round. The duration of the whole round is always
    stored as a centralized metric.

    Parameters
    ----------
    server_load_balancer : ServerLoadBalancer
        Forwarded to ``message_collaborative`` to dispatch the messages.
    evaluate_config_fn : Callable[[int, int | str], dict[str, ConfigsRecordValues]]
        Forwarded to ``message_collaborative`` as the function generating the
        evaluation instructions.
    server_state : ServerState
        Provides the current round number and the history being updated.
    cfg : BaseConfig
        Run configuration, forwarded to ``handle_evaluate_replies``.

    """
    started_at_ns = time.time_ns()

    # Send a single client-message to every NodeManager at a time
    pending_replies = message_collaborative(
        server_load_balancer=server_load_balancer,
        message_constants=(MessageType.EVALUATE, "evaluateins"),
        server_state=server_state,
        gen_ins_function=evaluate_config_fn,
    )
    outcome = handle_evaluate_replies(
        cfg=cfg,
        replies=pending_replies,
        server_state=server_state,
    )

    # Treat a missing outcome exactly like a missing loss: nothing is recorded.
    if outcome is None:
        loss_fed, metrics_fed = None, {}
    else:
        loss_fed, metrics_fed, _ = outcome

    # Distributed metrics are only stored together with a loss value.
    if loss_fed is not None:
        server_state.history.add_loss_distributed(
            server_round=server_state.current_round,
            loss=loss_fed,
        )
        server_state.history.add_metrics_distributed(
            server_round=server_state.current_round,
            metrics=metrics_fed,
        )

    # Nanoseconds converted to seconds.
    elapsed_seconds = (time.time_ns() - started_at_ns) * 1e-9
    server_state.history.add_metrics_centralized(
        server_round=server_state.current_round,
        metrics={"server/evaluate_round_time": elapsed_seconds},
    )
