import logging

from federatedscope.core.configs.config import CN
from federatedscope.register import register_config


def extend_asyn_cfg(cfg):
    """Attach the asynchronous-training options to ``cfg`` as ``cfg.asyn``.

    A new ``CN`` node is created under ``cfg.asyn`` and filled with the
    defaults listed below (in this order). Afterwards ``assert_asyn_cfg`` is
    registered on ``cfg`` as a config check function.

    Options and their defaults:
        use (False): switch for asynchronous training.
        time_budget (0): time budget in seconds (int or float).
        min_received_num (2): number of feedbacks to wait for.
        min_received_rate (-1.0): fraction of sampled clients to wait for;
            only values in (0, 1] are treated as valid by the check.
        staleness_toleration (0): must be a non-negative int.
        staleness_discount_factor (1.0)
        aggregator ('goal_achieved'): one of
            'goal_achieved' -- aggregate once the defined number of feedbacks
                               has been received;
            'time_up'       -- aggregate once the allocated time budget has
                               run out.
        broadcast_manner ('after_aggregating'): one of
            'after_aggregating' -- broadcast the up-to-date global model after
                                   performing federated aggregation;
            'after_receiving'   -- broadcast the up-to-date global model after
                                   receiving a model update from clients.
        overselection (False)
    """
    # ---------------------------------------------------------------------- #
    # Asynchronous related options
    # ---------------------------------------------------------------------- #
    cfg.asyn = CN()

    # (name, default) pairs; order mirrors the attribute insertion order.
    asyn_defaults = (
        ('use', False),
        ('time_budget', 0),
        ('min_received_num', 2),
        ('min_received_rate', -1.0),
        ('staleness_toleration', 0),
        ('staleness_discount_factor', 1.0),
        ('aggregator', 'goal_achieved'),
        ('broadcast_manner', 'after_aggregating'),
        ('overselection', False),
    )
    asyn_node = cfg.asyn
    for option_name, default_value in asyn_defaults:
        setattr(asyn_node, option_name, default_value)

    # --------------- register corresponding check function ----------
    cfg.register_cfg_check_fun(assert_asyn_cfg)


def assert_asyn_cfg(cfg):
    """Validate (and partly normalise) the ``cfg.asyn`` options.

    When ``cfg.asyn.use`` is falsy nothing is checked and ``True`` is
    returned. Otherwise the asynchronous options are validated, and
    ``cfg.asyn.min_received_num`` is rewritten in place to the effective
    number of feedbacks to wait for:

    * (a.1) if ``min_received_rate`` lies in ``(0, 1]`` it takes precedence:
      ``min_received_num = max(1, int(rate * sample_client_num))``; a
      warning is logged when a valid ``min_received_num`` is overridden;
    * (a.2) else, if ``min_received_num`` lies in
      ``(0, sample_client_num]``, it is kept unchanged;
    * (b) otherwise it falls back to ``sample_client_num`` (all clients).

    Args:
        cfg: the global config. Reads ``cfg.asyn.*`` and
            ``cfg.federate.sample_client_num``; writes
            ``cfg.asyn.min_received_num``.

    Returns:
        ``True`` when asynchronous training is disabled, otherwise ``None``.

    Raises:
        AssertionError: if ``time_budget``, ``staleness_toleration``,
            ``aggregator`` or ``broadcast_manner`` is invalid.
    """
    if not cfg.asyn.use:
        return True

    # to ensure a valid time budget
    assert isinstance(cfg.asyn.time_budget, (int, float)), \
        "The time budget (seconds) must be an int or a float value, " \
        "but {} is got".format(type(cfg.asyn.time_budget))

    # min received num pre-process
    sample_client_num = cfg.federate.sample_client_num
    min_received_num_valid = (0 < cfg.asyn.min_received_num <=
                              sample_client_num)
    min_received_rate_valid = (0 < cfg.asyn.min_received_rate <= 1)
    # (a) sampling case
    if min_received_rate_valid:
        # (a.1) use min_received_rate
        old_min_received_num = cfg.asyn.min_received_num
        cfg.asyn.min_received_num = max(
            1, int(cfg.asyn.min_received_rate * sample_client_num))
        if min_received_num_valid:
            logging.warning(
                f"Users specify both valid min_received_rate as"
                f" {cfg.asyn.min_received_rate} "
                f"and min_received_num as {old_min_received_num}.\n"
                f"\t\tWe will use the min_received_rate value to calculate "
                f"the actual number of participated clients as"
                f" {cfg.asyn.min_received_num}.")
    # (a.2) a valid min_received_num (with no valid rate) is kept as-is, so
    # no explicit branch is needed for that case.
    if not (min_received_num_valid or min_received_rate_valid):
        # (b) non-sampling case, use all clients
        cfg.asyn.min_received_num = sample_client_num

    # to ensure a valid staleness toleration. The type is checked first so
    # that a non-numeric value trips this assertion (with its message)
    # instead of raising a TypeError from the `>=` comparison.
    assert isinstance(cfg.asyn.staleness_toleration, int) and \
        cfg.asyn.staleness_toleration >= 0, \
        f"Please provide a valid staleness toleration value, " \
        f"expect an integer value that is larger or equal to 0, " \
        f"but got {cfg.asyn.staleness_toleration}."

    assert cfg.asyn.aggregator in ["goal_achieved", "time_up"], \
        f"Please specify the cfg.asyn.aggregator as string 'goal_achieved' " \
        f"or 'time_up'. But got {cfg.asyn.aggregator}."
    assert cfg.asyn.broadcast_manner in ["after_aggregating",
                                         "after_receiving"], \
        f"Please specify the cfg.asyn.broadcast_manner as the string " \
        f"'after_aggregating' or 'after_receiving'. " \
        f"But got {cfg.asyn.broadcast_manner}."


# Register `extend_asyn_cfg` under the name "asyn" via `register_config`;
# presumably the config builder later calls it to add `cfg.asyn` -- confirm
# against federatedscope.register.
register_config("asyn", extend_asyn_cfg)
