"""Parameter scheduler dispatcher.

These dispatchers are meant to extend the standard `get_fit_config`
and `get_eval_config` that the Flower Framework's Server uses to
compile the task instructions for the client. As such, the dispatched
functions should retain the same typing as `get_fit_config` or
`get_eval_config`.
"""

from collections.abc import Callable
from typing import cast

from omegaconf import DictConfig

from repo.conf.base_schema import BaseConfig, ParamSchedulerName
from repo.utils import ModelStateNames


def get_freq_param_scheduler(
    kwargs: DictConfig | dict[str, int],
) -> Callable[[str | int, int], list[ModelStateNames]]:
    """Get a frequency parameter scheduler.

    The returned scheduler selects, for a given server round, every model
    state whose configured frequency evenly divides the round number.

    Parameters
    ----------
    kwargs : DictConfig | dict[str, int]
        The keyword arguments, should be
            a dictionary of parameter names and frequencies.
        Each name is upper-cased and looked up in `ModelStateNames` when
        the scheduler is called.

    Returns
    -------
    Callable[[str | int, int], list[ModelStateNames]]
        The frequency parameter scheduler. It takes the client id (unused)
        and the server round, and returns the selected model states.

    Raises
    ------
    ValueError
        If any configured frequency is zero.

    """
    # Fail fast: a zero frequency would otherwise surface only as a bare
    # ZeroDivisionError from the modulo the first time the scheduler runs.
    zero_freq_names = [name for name, freq in kwargs.items() if freq == 0]
    if zero_freq_names:
        msg = (
            "Parameter scheduler frequencies must be non-zero, "
            f"got 0 for: {zero_freq_names}"
        )
        raise ValueError(msg)

    def freq_scheduler(cid: str | int, server_round: int) -> list[ModelStateNames]:  # noqa: ARG001
        # A state is scheduled whenever the round is a multiple of its
        # frequency; consequently round 0 selects every configured state.
        return [
            ModelStateNames[cast("str", name).upper()]
            for name, freq in kwargs.items()
            if server_round % freq == 0
        ]

    return freq_scheduler


def dispatch_model_state_scheduler(
    cfg: BaseConfig,
) -> Callable[[str | int, int], list[ModelStateNames]]:
    """Select the model state scheduler named in the configuration.

    The lower-cased `cfg.fl.parameter_scheduler_name` is compared against
    the known `ParamSchedulerName` members to pick the scheduler.

    Parameters
    ----------
    cfg : BaseConfig
        The configuration. Reads `cfg.fl.parameter_scheduler_name` and, for
        the frequency scheduler, `cfg.fl.parameter_scheduler_kwargs`.

    Returns
    -------
    Callable[[str | int, int], list[ModelStateNames]]
        The model state scheduler, called with the client id and the
        server round.

    Raises
    ------
    ValueError
        If the configured scheduler name is not recognized.

    """
    scheduler_name = cfg.fl.parameter_scheduler_name.lower()

    # "All" scheduler: every client, every round, gets the full state.
    if scheduler_name == ParamSchedulerName.ALL:
        return lambda _cid, _server_round: [ModelStateNames.ALL]

    # Frequency scheduler: built from the per-state frequency mapping.
    if scheduler_name == ParamSchedulerName.FREQ:
        scheduler_kwargs = cfg.fl.parameter_scheduler_kwargs
        return get_freq_param_scheduler(scheduler_kwargs)

    msg = "Invalid parameter scheduler name."
    raise ValueError(msg)
