import inspect

from utils.optim.adam_cpr import AdamCPR, group_parameters_for_cpr_optimizer


def apply_CPR(model, optimizer_cls, kappa_init_param, kappa_init_method='warm_start', reg_function='l2',
              kappa_adapt=False, kappa_update=1.0,
              normalization_regularization=False, bias_regularization=False, embedding_regularization=False,
              **optimizer_args):
    """Build an ``AdamCPR`` optimizer over the parameters of ``model``.

    Args:
        model: Passed to ``group_parameters_for_cpr_optimizer`` to build the
            parameter groups.
        optimizer_cls: Used only to validate the names in ``optimizer_args``
            against its constructor signature; the returned optimizer is always
            an ``AdamCPR`` regardless of this class.
        kappa_init_param: Forwarded to ``AdamCPR``.
        kappa_init_method: Forwarded to ``AdamCPR``.
        reg_function: Forwarded to ``AdamCPR``.
        kappa_adapt: Currently unused by this function.
            NOTE(review): confirm whether it should be forwarded.
        kappa_update: Forwarded to ``AdamCPR``.
        normalization_regularization: Forwarded to the parameter grouping as
            ``normalization_weight_decay``.
        bias_regularization: Forwarded to the parameter grouping as
            ``bias_weight_decay``.
        embedding_regularization: Currently unused by this function.
            NOTE(review): confirm whether the grouping helper should receive it.
        **optimizer_args: Extra optimizer keyword arguments. Their names are
            validated against ``optimizer_cls``. ``weight_decay`` is always
            overwritten with 0. Of these, only ``lr`` and ``betas`` are passed
            on to ``AdamCPR``, so both must be supplied.

    Returns:
        The constructed ``AdamCPR`` optimizer.

    Raises:
        UserWarning: If a key in ``optimizer_args`` is not a parameter of
            ``optimizer_cls`` (and ``optimizer_cls`` takes no ``**kwargs``).
        KeyError: If ``lr`` or ``betas`` is missing from ``optimizer_args``.
    """
    optimizer_args['weight_decay'] = 0

    # Validate before grouping parameters so a misspelled argument fails fast.
    spec = inspect.getfullargspec(optimizer_cls)
    if spec.varkw is None:
        # Keyword-only parameters (declared after ``*``) live in ``kwonlyargs``,
        # not ``args``; both are legal keyword names for the constructor.
        # If the class accepts ``**kwargs``, any name is legal, so skip the check.
        optimizer_keys = set(spec.args) | set(spec.kwonlyargs)
        for k in optimizer_args:
            if k not in optimizer_keys:
                raise UserWarning(f"apply_CPR: Unknown optimizer argument {k}")

    param_groups = group_parameters_for_cpr_optimizer(model=model, bias_weight_decay=bias_regularization,
                                                      normalization_weight_decay=normalization_regularization)

    optimizer = AdamCPR(param_groups, kappa_init_param=kappa_init_param, kappa_init_method=kappa_init_method,
                        reg_function=reg_function, kappa_update=kappa_update,
                        lr=optimizer_args['lr'], betas=optimizer_args['betas'])

    return optimizer
