import logging

from federatedscope.core.configs.config import CN
from federatedscope.register import register_config

# Module-level logger; assert_compression_cfg uses it to warn about an
# unsupported quantization method.
logger = logging.getLogger(__name__)


def extend_compression_cfg(cfg):
    """Attach the ``quantization`` sub-config used for compressed communication.

    Options created (defaults in parentheses):
        quantization.method: quantization scheme; ``assert_compression_cfg``
            accepts 'none' or 'uniform' ('none').
        quantization.nbits: bit width used when quantization is enabled;
            ``assert_compression_cfg`` accepts 8 or 16 (8).

    Also registers ``assert_compression_cfg`` on ``cfg`` through
    ``cfg.register_cfg_check_fun`` so these options get validated.

    Args:
        cfg: the config node to extend in place.
    """
    # Compression (for communication efficiency) related options.
    quantization_node = CN()
    quantization_node.method = 'none'  # one of ['none', 'uniform']
    quantization_node.nbits = 8  # one of [8, 16]
    cfg.quantization = quantization_node

    # Hook up the validator for the options declared above.
    cfg.register_cfg_check_fun(assert_compression_cfg)


def assert_compression_cfg(cfg):
    """Validate, and where possible repair, the ``cfg.quantization`` options.

    An unsupported ``cfg.quantization.method`` (anything other than 'none'
    or 'uniform', compared case-insensitively) is logged and reset to
    'none', which disables quantization. When quantization is enabled,
    ``cfg.quantization.nbits`` must be 8 or 16.

    Args:
        cfg: config node carrying a ``quantization`` sub-config with
            ``method`` (str) and ``nbits`` attributes.

    Raises:
        ValueError: if quantization is enabled and ``nbits`` is not 8 or 16.
    """
    if cfg.quantization.method.lower() not in ['none', 'uniform']:
        logger.warning(
            f'Quantization method is expected to be one of ["none",'
            f'"uniform"], but got "{cfg.quantization.method}". So we '
            f'change it to "none"')
        # Apply the fallback the warning announces. Without this reset the
        # unsupported method would stay in the config and still be subject
        # to the nbits check below.
        cfg.quantization.method = 'none'

    quantization_enabled = cfg.quantization.method.lower() != 'none'
    if quantization_enabled and cfg.quantization.nbits not in [8, 16]:
        raise ValueError(f'The value of cfg.quantization.nbits is invalid, '
                         f'which is expected to be one of [8, 16] but got '
                         f'{cfg.quantization.nbits}.')


# Register extend_compression_cfg under the name "compression" via
# federatedscope.register.register_config (module-level side effect on import).
# NOTE(review): presumably the registry applies it when the global config is
# assembled — confirm in federatedscope.register.
register_config("compression", extend_compression_cfg)
