import torch
from torch.optim import SGD, AdamW

def build_optimizer(cfg, params_groups):
    """Construct the optimizer selected by ``cfg.SOLVER.OPTIMIZER_NAME``.

    Args:
        cfg: Config object. Reads ``cfg.SOLVER.OPTIMIZER_NAME``, ``BASE_LR`` and
            ``WEIGHT_DECAY``, plus ``BETA1``/``BETA2`` for AdamW.
        params_groups: Parameters or a list of param-group dicts, passed
            unchanged to the torch optimizer constructor. Any per-group
            ``lr`` / ``weight_decay`` entries override the defaults set here.

    Returns:
        A ``torch.optim.AdamW`` or ``torch.optim.SGD`` instance.

    Raises:
        ValueError: If ``OPTIMIZER_NAME`` is not ``'AdamW'`` or ``'SGD'``.
    """
    solver = cfg.SOLVER
    name = solver.OPTIMIZER_NAME
    if name == 'AdamW':
        # Pass lr / weight_decay explicitly (as the SGD branch does) so groups
        # without their own values use the configured settings rather than
        # AdamW's built-in defaults (lr=1e-3, weight_decay=1e-2).
        return AdamW(
            params_groups,
            lr=solver.BASE_LR,
            betas=(solver.BETA1, solver.BETA2),
            weight_decay=solver.WEIGHT_DECAY,
        )
    if name == 'SGD':
        return SGD(
            params_groups,
            lr=solver.BASE_LR,
            weight_decay=solver.WEIGHT_DECAY,
        )
    # Fail loudly here instead of returning None and crashing later in the caller.
    raise ValueError(
        f"Unsupported optimizer name: {name!r} (expected 'AdamW' or 'SGD')"
    )

def get_mean_lr(optimizer) -> float:
    """Return the arithmetic mean of the ``lr`` values across ``optimizer.param_groups``.

    The average is computed in Python floats (double precision); the previous
    float32 ``torch.Tensor`` round-trip reported e.g. 0.1 as 0.10000000149011612.
    A scalar-tensor ``lr`` is accepted and converted with ``float()``.

    Returns:
        The mean learning rate, or ``nan`` if there are no param groups.
    """
    lrs = [float(group['lr']) for group in optimizer.param_groups]
    if not lrs:
        # Mirrors the mean of an empty tensor, which the original returned.
        return float('nan')
    return sum(lrs) / len(lrs)