import logging

from federatedscope.core.configs.config import CN
from federatedscope.register import register_config
from federatedscope.core.monitors.metric_calculator import SUPPORT_METRICS

# Module logger; assert_hpo_cfg uses it to warn when it overrides
# `cfg.hpo.larger_better`.
logger = logging.getLogger(__name__)


def extend_hpo_cfg(cfg):
    """Attach the default hyper-parameter optimization (HPO) options to `cfg`.

    Creates ``cfg.hpo`` holding the general HPO options plus one sub-node
    per method (SHA, PBT, FedEx, Table, FTS, pFedHPO). It then registers
    ``assert_hpo_cfg`` as a check function on ``cfg``.

    Args:
        cfg: the config node to extend in place.
    """
    # ---------------------------------------------------------------------- #
    # General HPO options
    # ---------------------------------------------------------------------- #
    cfg.hpo = CN()
    hpo = cfg.hpo  # alias the stored node; all writes below go into cfg.hpo
    hpo.trial_index = 0
    hpo.working_folder = 'hpo'
    hpo.ss = ''
    hpo.num_workers = 0
    hpo.init_cand_num = 16
    # Not meant to be set by hand: assert_hpo_cfg overwrites it according
    # to the direction registered for `metric`.
    hpo.larger_better = False
    hpo.scheduler = 'rs'
    hpo.metric = 'client_summarized_weighted_avg.val_loss'

    # SHA options
    hpo.sha = CN()
    sha = hpo.sha
    sha.elim_rate = 3
    sha.budgets = []
    sha.iter = 0

    # PBT options
    hpo.pbt = CN()
    pbt = hpo.pbt
    pbt.max_stage = 5
    pbt.perf_threshold = 0.1

    # FedEx options
    hpo.fedex = CN()
    fedex = hpo.fedex
    fedex.use = False
    fedex.ss = ''
    fedex.flatten_ss = True
    fedex.eta0 = -1.0  # a value <= 0.0 means 'auto'
    fedex.sched = 'auto'
    # Entropy level below which the config probability stops being updated
    # and the MLE is used instead.
    fedex.cutoff = .0
    # Discount factor: 0.0 keeps only the most recent value, 1.0 the mean.
    fedex.gamma = .0
    fedex.diff = False
    fedex.psn = False
    fedex.pi_lr = 0.01

    # Table options
    hpo.table = CN()
    table = hpo.table
    table.eps = 0.1
    table.num = 27
    table.idx = 0

    # FTS options
    hpo.fts = CN()
    fts = hpo.fts
    fts.use = False
    fts.ss = ''
    fts.target_clients = []
    fts.diff = False
    fts.local_bo_max_iter = 50
    fts.local_bo_epochs = 50
    fts.fed_bo_max_iter = 50
    fts.ls = 1.0
    fts.var = 0.1
    fts.g_var = 1e-6
    fts.v_kernel = 1.0
    fts.obs_noise = 1e-6
    fts.M = 100
    fts.M_target = 200
    fts.gp_opt_schedule = 1
    fts.allow_load_existing_info = True

    # pFedHPO options
    hpo.pfedhpo = CN()
    pfedhpo = hpo.pfedhpo
    pfedhpo.use = False
    pfedhpo.discrete = False
    pfedhpo.train_fl = False
    pfedhpo.train_anchor = False
    pfedhpo.ss = ''
    pfedhpo.target_fl_total_round = 1000

    # Validate the HPO section whenever cfg's check functions run.
    cfg.register_cfg_check_fun(assert_hpo_cfg)


def assert_hpo_cfg(cfg):
    """Validate the ``cfg.hpo`` options and normalize ``larger_better``.

    Side effect: if ``cfg.hpo.metric`` contains a key of
    ``SUPPORT_METRICS``, ``cfg.hpo.larger_better`` is set to that key's
    direction (``value[1]``). The first matching key in iteration order
    wins, and a warning is logged when the configured value changes.

    Args:
        cfg: the config node; reads ``cfg.hpo``, ``cfg.federate`` and
            ``cfg.train.optimizer``.

    Raises:
        AssertionError: if an HPO option is out of range or inconsistent
            with the federate/train options.
    """
    # Derive the optimization direction from the metric name. Stop at the
    # first matching key whether or not it disagrees: previously the loop
    # only broke on disagreement, so a later key that also happened to be a
    # substring of the metric could flip the direction, making the outcome
    # depend on the user's initial `larger_better` value.
    for key, value in SUPPORT_METRICS.items():
        if key not in cfg.hpo.metric:
            continue
        is_larger_the_better = value[1]
        if is_larger_the_better != cfg.hpo.larger_better:
            logger.warning(f'`cfg.hpo.larger_better` is overwritten by '
                           f'{is_larger_the_better} for the metric `'
                           f'{cfg.hpo.metric}` is  {is_larger_the_better} '
                           f'for larger the better.')
            cfg.hpo.larger_better = is_larger_the_better
        break

    assert cfg.hpo.num_workers >= 0, "#worker should be non-negative but " \
                                     "given {}".format(cfg.hpo.num_workers)
    assert not (cfg.hpo.fedex.use and cfg.federate.use_ss), \
        "Cannot use secret sharing and FedEx at the same time"
    assert cfg.train.optimizer.type == 'SGD' or not cfg.hpo.fedex.use, \
        "SGD is required if FedEx is considered"

    fedex_scheds = ['adaptive', 'aggressive', 'auto', 'constant', 'scale']
    assert cfg.hpo.fedex.sched in fedex_scheds, \
        "schedule of FedEx must be choice from {}".format(fedex_scheds)
    assert 0.0 <= cfg.hpo.fedex.gamma <= 1.0, \
        "{} must be in [0, 1]".format(cfg.hpo.fedex.gamma)

    # These are implications, not equalities. Requiring both
    # `fedex.use == use_diff` and `fts.use == use_diff` made FTS impossible
    # to enable without FedEx: FTS needed use_diff=True while the FedEx
    # check then demanded use_diff=False.
    assert not cfg.hpo.fedex.use or cfg.federate.use_diff, \
        "Once FedEx is adopted, federate.use_diff must be True."
    assert not cfg.hpo.fts.use or cfg.federate.use_diff, \
        "Once FTS is adopted, federate.use_diff must be True."


# Register `extend_hpo_cfg` under the key "hpo". Presumably the config
# builder calls it to add the HPO options to the global config — confirm
# against federatedscope.register.
register_config("hpo", extend_hpo_cfg)
