import math
import torch

def get_optimizer(
    optimizer_name: str,
    model: torch.nn.Module,
    lr: float = 1e-3,
    wd: float = 0.1,
    momentum: float = 0.95,
    beta1: float = 0.9,
    beta2: float = 0.95,
) -> torch.optim.Optimizer:
    """Build a torch optimizer over ``model.parameters()`` selected by name.

    Args:
        optimizer_name: ``"adamw"`` for ``torch.optim.AdamW`` or ``"sgdm"``
            for ``torch.optim.SGD`` with momentum.
        model: Object whose ``parameters()`` are handed to the optimizer.
        lr: Learning rate, forwarded as ``lr``.
        wd: Weight decay, forwarded as ``weight_decay``.
        momentum: Momentum factor; only used when ``optimizer_name`` is
            ``"sgdm"``.
        beta1: First AdamW beta; only used when ``optimizer_name`` is
            ``"adamw"``.
        beta2: Second AdamW beta; only used when ``optimizer_name`` is
            ``"adamw"``.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: If ``optimizer_name`` is not one of the supported names.
    """
    if optimizer_name == "adamw":
        return torch.optim.AdamW(
            model.parameters(), lr=lr, weight_decay=wd, betas=(beta1, beta2)
        )
    if optimizer_name == "sgdm":
        return torch.optim.SGD(
            model.parameters(), lr=lr, weight_decay=wd, momentum=momentum
        )
    # Raise explicitly instead of `assert 0`: asserts are stripped under
    # `python -O`, which would make this function silently return None.
    raise ValueError(f"optimizer not supported: {optimizer_name!r}")
