import torch.optim as optim
from warmup_scheduler import GradualWarmupScheduler
from torch.optim.lr_scheduler import CosineAnnealingLR



def get_optimizer(model, args, **kwargs):
    """Build the optimizer and (optionally) the LR scheduler selected by *args*.

    Args:
        model: Object exposing ``parameters()``; its parameters are optimized.
        args: Namespace-like configuration object. Attributes read:

            * ``optim``: ``'Adam'``, ``'SGD'`` or ``'AdamW'``.
            * ``lr``: learning rate passed to the optimizer.
            * ``eps``, ``weight_decay``, ``beta1``, ``beta2``: only read for
              ``'Adam'`` / ``'AdamW'``.
            * ``scheduler``: ``'GradualWarmupScheduler_CosineAnnealingLR'``,
              ``'StepLR'``, or anything else for "no scheduler".
            * ``num_batches``: multiplier applied to every epoch-based
              schedule length below.
            * ``optim_T_max``, ``optim_eta_min``, ``optim_multiplier``,
              ``optim_total_epoch``: only read for the warm-up + cosine
              scheduler.
            * ``lr_decay_step``, ``lr_decay_rate``: only read for ``'StepLR'``.
        **kwargs: Accepted for call-site compatibility; currently unused.

    Returns:
        A tuple ``(optimizer, scheduler)``. ``scheduler`` is ``None`` when
        ``args.scheduler`` names neither supported scheduler.

    Raises:
        ValueError: If ``args.optim`` is not one of the supported optimizers.
    """
    if args.optim == 'Adam':
        optimizer = optim.Adam(
            model.parameters(),
            lr=args.lr,
            eps=args.eps,
            weight_decay=args.weight_decay,
            betas=(args.beta1, args.beta2),
        )
    elif args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=args.lr)
    elif args.optim == 'AdamW':
        optimizer = optim.AdamW(
            model.parameters(),
            lr=args.lr,
            eps=args.eps,
            weight_decay=args.weight_decay,
            betas=(args.beta1, args.beta2),
        )
    else:
        # Fail loudly here; otherwise `optimizer` would be unbound and the
        # function would die later with an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported optimizer {args.optim!r}; "
            "expected one of 'Adam', 'SGD', 'AdamW'"
        )

    if args.scheduler == 'GradualWarmupScheduler_CosineAnnealingLR':
        # Configuration notes (carried over from the original comments):
        #   optim_T_max:       cosine period length, i.e. the number of epochs
        #                      to decay from the current LR to the minimum LR
        #                      (documented default: 4000).
        #   optim_eta_min:     minimum LR reached by the cosine decay
        #                      (documented default: 2e-5).
        #   optim_multiplier:  ratio of the peak LR to the initial LR during
        #                      warm-up (documented default: 10, i.e. ramp up
        #                      10x, then decay with the cosine schedule).
        #   optim_total_epoch: number of warm-up epochs
        #                      (documented default: 400).
        # Epoch counts are scaled by num_batches -- presumably because the
        # scheduler is stepped once per batch; confirm against the train loop.
        scheduler_cosine = CosineAnnealingLR(
            optimizer,
            T_max=int(args.optim_T_max * args.num_batches),
            eta_min=float(args.optim_eta_min),
        )

        # Warm-up wrapper: `multiplier` is the peak/initial LR ratio,
        # `total_epoch` is the warm-up length, and `after_scheduler` is the
        # schedule that takes over once warm-up has finished.
        scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=float(args.optim_multiplier),
            total_epoch=int(args.optim_total_epoch * args.num_batches),
            after_scheduler=scheduler_cosine,
        )
    elif args.scheduler == 'StepLR':
        # StepLR expects an integer step_size; a non-integral float would make
        # its modulo-based decay test fire at the wrong steps. Convert like
        # the branch above does, keeping at least one step between decays.
        step_size = max(1, int(args.lr_decay_step * args.num_batches))
        scheduler = optim.lr_scheduler.StepLR(
            optimizer,
            step_size=step_size,
            gamma=args.lr_decay_rate,
        )
    else:
        # Any other value deliberately means "no LR schedule".
        scheduler = None

    return optimizer, scheduler