import torch.optim as optim

from utils.sam import SAM


def get_optim(
        model,
        opt: str,
        lr: float = 1e-4,
        weight_decay: float = 5e-4,
        momentum: float = 0.9,
        **opt_kwargs):
    """Build an optimizer for ``model`` from a short, case-insensitive name.

    Supported names:
        - ``"sgd"``, ``"nesterov"`` (and the legacy misspelling
          ``"nestrov"``): ``torch.optim.SGD`` with Nesterov momentum.
        - ``"momentum"``: ``torch.optim.SGD`` with classical
          (non-Nesterov) momentum.
        - ``"adam"``: ``torch.optim.Adam``.
        - ``"adamw"``: ``torch.optim.AdamW``.
        - ``"sam"``: ``utils.sam.SAM`` wrapping ``torch.optim.SGD``
          (non-Nesterov momentum).
        - ``"sam_adam"``: ``utils.sam.SAM`` wrapping ``torch.optim.Adam``.

    Args:
        model: Object whose ``parameters()`` are handed to the optimizer
            (typically a ``torch.nn.Module``).
        opt: Optimizer name, matched case-insensitively.
        lr: Learning rate.
        weight_decay: Weight-decay coefficient, passed to every choice.
        momentum: Momentum factor; only used by the SGD-based choices
            (``sgd``/``nesterov``/``momentum``/``sam``).
        **opt_kwargs: Extra keyword arguments forwarded unchanged to the
            optimizer constructor (or to ``SAM``).

    Returns:
        The constructed optimizer instance.

    Raises:
        NotImplementedError: If ``opt`` is not one of the supported names.
    """
    opt = opt.lower()

    if opt in ("sgd", "nesterov", "nestrov"):
        # "nestrov" is kept as an alias so existing configs that used the
        # original misspelling keep working.
        optimizer = optim.SGD(model.parameters(), lr=lr,
                              momentum=momentum, nesterov=True,
                              weight_decay=weight_decay, **opt_kwargs)
    elif opt == 'momentum':
        optimizer = optim.SGD(model.parameters(), lr=lr,
                              momentum=momentum, nesterov=False,
                              weight_decay=weight_decay, **opt_kwargs)
    elif opt == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=lr,
                               weight_decay=weight_decay, **opt_kwargs)
    elif opt == 'adamw':
        optimizer = optim.AdamW(model.parameters(), lr=lr,
                                weight_decay=weight_decay, **opt_kwargs)
    elif opt == 'sam':
        # SAM receives the base optimizer *class* plus that optimizer's
        # kwargs (inferred from this call shape - confirm against utils.sam).
        optimizer = SAM(model.parameters(), optim.SGD, lr=lr,
                        weight_decay=weight_decay, momentum=momentum,
                        nesterov=False, **opt_kwargs)
    elif opt == 'sam_adam':
        optimizer = SAM(model.parameters(), optim.Adam, lr=lr,
                        weight_decay=weight_decay, **opt_kwargs)
    else:
        # Decide "unsupported" from the name itself rather than from the
        # truthiness of the constructed object.
        raise NotImplementedError('check your optim : {}'.format(opt))

    print('optimizer type: {}'.format(opt))
    return optimizer

