from .engine import LowBitEngine
from .conf import config

def initialize(
    args,
    model,
    optimizer,
    name_groups=None,
):
    """Apply ``args`` to the shared config and build a ``LowBitEngine``.

    Args:
        args: Options object; passed to ``_configure_config`` and then to
            the engine constructor.
        model: Forwarded to ``LowBitEngine`` unchanged.
        optimizer: Forwarded to ``LowBitEngine`` unchanged.
        name_groups: Optional value forwarded to ``LowBitEngine``
            (defaults to ``None``).

    Returns:
        A ``(engine, engine.optimizer)`` tuple.
    """
    # The config must be populated before the engine is constructed.
    _configure_config(args)
    engine = LowBitEngine(args, model, optimizer, name_groups)
    return engine, engine.optimizer


def _configure_config(args):
    """Copy quantization settings from ``args`` onto the shared ``config``.

    ``args.pb``, ``args.gb``, ``args.mb`` and ``args.smb`` are read directly
    and must exist. Every other option is optional: when the attribute is
    missing or ``None`` a built-in default is used instead. The resulting
    configuration is printed to stdout.
    """

    def option(name, fallback):
        # args.<name> when it is present and not None, otherwise `fallback`.
        value = getattr(args, name, None)
        return fallback if value is None else value

    # (config suffix, args attribute) pairs whose bit width is capped at 8.
    capped = (('grad', 'gb'), ('mm', 'mb'), ('sqmm', 'smb'))

    # Parameters are quantized below 32 bits and keep their requested width.
    config.enable_quantize_p = args.pb < 32
    for suffix, source in capped:
        setattr(config, f'enable_quantize_{suffix}', getattr(args, source) <= 8)

    config.compression_bits_p = args.pb
    for suffix, source in capped:
        setattr(config, f'compression_bits_{suffix}', min(getattr(args, source), 8))

    sq = getattr(args, 'sq', None)
    config.stochastic = True if sq is None else sq > 0
    config.group_size = option('gp_sz', 2048)

    # The default mode is looked up only when args does not supply one.
    for suffix in ('p', 'grad', 'mm', 'sqmm'):
        attr = f'numerical_mode_{suffix}'
        mode = getattr(args, attr, None)
        if mode is None:
            mode = config.default_numerical_mode
        setattr(config, attr, mode)

    # quantization technique
    # Falsy step counts (missing, None or 0) fall back to 1.
    config.steps_requantization_p = getattr(args, 'num_steps_requantization_p', None) or 1
    config.steps_requantization_mm = getattr(args, 'num_steps_requantization_mm', None) or 1
    unbiased = getattr(args, 'unbiased_sqmm_flag', None)
    config.unbiased_sqmm_flag = True if unbiased is None else unbiased > 0
    config.save_grad_outlier = False
    config.params_initial_dilation = option('params_initial_dilation', 1)
    config.params_freeze_dilation = option('params_freeze_dilation', 1)
    # NOTE: the default of 80 is only specific to 90-epoch ImageNet training.
    config.epochs_fix_scale = option('epochs_fix_scale', 80)

    # debug
    config.debug_memory_model = option('debug_memory_model', False)
    config.debug_quantization_difference = True
    config.debug_proba_hist = True
    if config.debug_proba_hist:
        config.debug_GOR_freq = 200

    # print config
    print('Quantization config:')
    for key in config.__dict__:
        value = getattr(config, key)
        print(f'\t{key}: {value}')
    print('Config configuration ended')
