from torch.optim import SGD, Adam


def sgd(params, lr, momentum, weight_decay, *args, **kwargs):
    """Create a ``torch.optim.SGD`` optimizer.

    Args:
        params: Parameters (or parameter groups) to optimize.
        lr: Learning rate.
        momentum: Momentum factor passed to SGD.
        weight_decay: Weight-decay coefficient passed to SGD.
        *args, **kwargs: Accepted so every factory in this module shares one
            call signature; they are ignored.

    Returns:
        A configured ``SGD`` instance.
    """
    hyperparams = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
    return SGD(params, **hyperparams)


def adam(params, lr, momentum, weight_decay, *args, **kwargs):
    """Create a ``torch.optim.Adam`` optimizer.

    Args:
        params: Parameters (or parameter groups) to optimize.
        lr: Learning rate.
        momentum: Accepted for signature parity with ``sgd`` but NOT used;
            Adam keeps its default ``betas``.
        weight_decay: Weight-decay coefficient passed to Adam.
        *args, **kwargs: Accepted so every factory in this module shares one
            call signature; they are ignored.

    Returns:
        A configured ``Adam`` instance.
    """
    hyperparams = dict(lr=lr, weight_decay=weight_decay)
    return Adam(params, **hyperparams)


# Registry mapping an optimizer name to its factory function. Every factory
# takes (params, lr, momentum, weight_decay, *args, **kwargs).
optimizer_factories = dict(
    sgd=sgd,
    adam=adam,
)


def get_available_optimizers():
    """Return the names accepted by ``get_optimizer``.

    The result is the registry's ``keys()`` view, so it reflects any later
    changes to ``optimizer_factories``.
    """
    registry = optimizer_factories
    return registry.keys()


def get_optimizer(name, *args, **kwargs):
    """Build an optimizer by its registry name.

    Args:
        name: Key into ``optimizer_factories`` (e.g. ``'sgd'``, ``'adam'``).
        *args, **kwargs: Forwarded unchanged to the selected factory.

    Returns:
        Whatever the selected factory returns.

    Raises:
        KeyError: If ``name`` is not registered. The message lists the
            available names so misconfigurations are easy to diagnose.
    """
    # Only the lookup sits inside the try: a KeyError raised by the factory
    # itself must propagate untouched rather than being reported as an
    # unknown optimizer name.
    try:
        factory = optimizer_factories[name]
    except KeyError:
        available = ", ".join(sorted(optimizer_factories))
        raise KeyError(
            f"Unknown optimizer {name!r}; available optimizers: {available}"
        ) from None
    return factory(*args, **kwargs)