

import torch.optim as optim
from sgd_lr_decay import SGD_LRDecay
from sls import Sls
from sgd_lr_band import SGD_LRBand

def load_optim(params, optim_method, step_mode, epoch_mode, eta0, alpha, ratio, milestones, T_max, n_batches_per_epoch, nesterov, momentum, weight_decay):
    """
    Args:
        params: iterable of parameters to optimize or dicts defining
            parameter groups.
        optim_method: which optimizer to use.
        eta0: starting step size.
        alpha: decaying factor for various methods.
        milestones: used for SGD stage decay denoting when to decrease the
            step size, unit in iteration.
        T_max: total number of steps.
        n_batches_per_epoch: number of batches in one train epoch.
        nesterov: whether to use nesterov momentum (True) or not (False).
        momentum: momentum factor used in variants of SGD.
        weight_decay: weight decay factor.

    Outputs:
        an optimizer
    """
    if optim_method == 'Adam':
        optimizer = optim.Adam(params=params, lr=eta0,
                               weight_decay=weight_decay)
    elif optim_method.startswith('SGD') and optim_method.endswith('Decay'):
        if optim_method == 'SGD_Const_Decay':
            scheme = 'const'
        elif optim_method == 'SGD_Exp_Decay':
            scheme = 'exp'
        elif optim_method == 'SGD_1t_Decay':
            scheme = '1t'
        elif optim_method == 'SGD_1sqrt_Decay':
            scheme = '1sqrt'
        elif optim_method == 'SGD_Step_Decay':
            scheme = 'step-decay'
        elif optim_method == 'SGD_Cosine_Decay':
            scheme = 'cosine'
        optimizer = SGD_LRDecay(params=params, scheme=scheme, epoch_mode = epoch_mode, eta0=eta0, alpha=alpha, milestones=milestones, T_max=T_max,n_batches_per_epoch=n_batches_per_epoch, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
    elif optim_method.startswith('SGD') and optim_method.endswith('Band'):
        if optim_method == 'SGD_1t_Band':
            scheme = '1t_band'
        elif optim_method == 'SGD_1sqrt_Band':
            scheme = '1sqrt_band'
        elif optim_method == 'SGD_Step_Band':
            scheme = 'step_band'
        elif optim_method == 'SGD_Exp_Band':
            scheme = 'exp_band'
        optimizer = SGD_LRBand(params=params, scheme=scheme, step_mode=step_mode, epoch_mode = epoch_mode, eta0=eta0, alpha=alpha, ratio=ratio, milestones=milestones, T_max=T_max, n_batches_per_epoch=n_batches_per_epoch, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)

    else:
        raise ValueError("Invalid optimizer: {}".format(optim_method))

    return optimizer
