
import torch
import numpy as np


def choose_optimizer(optimizer_name: str, *params, **kwargs):
    """Build an optimizer selected by name.

    Currently only ``'AdamW'`` is supported; it is dispatched to this module's
    ``AdamW`` factory.

    Args:
        optimizer_name: Name of the optimizer to construct (case-sensitive).
        *params: Positional arguments forwarded unchanged to the factory
            (for ``'AdamW'`` the first one is the model parameters).
        **kwargs: Keyword arguments forwarded unchanged to the factory.

    Returns:
        The optimizer instance returned by the selected factory.

    Raises:
        NotImplementedError: If ``optimizer_name`` is not a supported name.
    """
    if optimizer_name == 'AdamW':
        return AdamW(*params, **kwargs)
    # Report the rejected name so a misspelled config value is easy to spot.
    raise NotImplementedError(
        f"Unsupported optimizer: {optimizer_name!r} (supported: 'AdamW')"
    )

def AdamW(model_param, lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False):
    """Construct a ``torch.optim.AdamW`` optimizer over ``model_param``.

    All arguments are forwarded to ``torch.optim.AdamW`` unchanged. The
    defaults here (``lr=1e-4``, ``weight_decay=0``) differ from torch's own
    defaults, so callers relying on this wrapper get these values instead.

    Args:
        model_param: Parameters (or parameter groups) to optimize, passed
            straight through to ``torch.optim.AdamW``.
        lr: Learning rate.
        betas: Coefficients for the running averages of the gradient and
            its square.
        eps: Term added to the denominator for numerical stability.
        weight_decay: Decoupled weight-decay coefficient.
        amsgrad: Whether to use the AMSGrad variant.

    Returns:
        The newly created ``torch.optim.AdamW`` instance.
    """
    hyperparams = dict(
        lr=lr,
        betas=betas,
        eps=eps,
        weight_decay=weight_decay,
        amsgrad=amsgrad,
    )
    return torch.optim.AdamW(model_param, **hyperparams)
