import logging

log = logging.getLogger()


def get_parameter_groups(model, stage_cfg, print_log=False):
    """
    Assign different weight decays and learning rates to different parameters.
    Returns a parameter group which can be passed to the optimizer.
    """
    weight_decay = stage_cfg.weight_decay
    embed_weight_decay = stage_cfg.embed_weight_decay
    backbone_lr_ratio = stage_cfg.backbone_lr_ratio
    base_lr = stage_cfg.learning_rate

    backbone_params = []
    embed_params = []
    other_params = []

    embedding_names = ['summary_pos', 'query_init', 'query_emb', 'obj_pe']
    embedding_names = [e + '.weight' for e in embedding_names]

    # inspired by detectron2
    memo = set()
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Avoid duplicating parameters
        if param in memo:
            continue
        memo.add(param)

        if name.startswith('module'):
            name = name[7:]

        inserted = False
        if name.startswith('pixel_encoder.'):
            backbone_params.append(param)
            inserted = True
            if print_log:
                log.info(f'{name} counted as a backbone parameter.')
        else:
            for e in embedding_names:
                if name.endswith(e):
                    embed_params.append(param)
                    inserted = True
                    if print_log:
                        log.info(f'{name} counted as an embedding parameter.')
                    break

        if not inserted:
            other_params.append(param)

    parameter_groups = [
        {
            'params': backbone_params,
            'lr': base_lr * backbone_lr_ratio,
            'weight_decay': weight_decay
        },
        {
            'params': embed_params,
            'lr': base_lr,
            'weight_decay': embed_weight_decay
        },
        {
            'params': other_params,
            'lr': base_lr,
            'weight_decay': weight_decay
        },
    ]

    return parameter_groups