import torch.optim as optim
from warmup_scheduler import GradualWarmupScheduler
from torch.optim.lr_scheduler import CosineAnnealingLR



def get_optimizer(model, args, **kwargs):
    """Build an optimizer and an optional learning-rate scheduler for ``model``.

    Args:
        model: Object whose ``parameters()`` are handed to the optimizer
            (presumably a ``torch.nn.Module`` -- confirm against callers).
        args: Config object read for ``optim`` (one of ``'Adam'``, ``'SGD'``,
            ``'AdamW'``), ``lr`` (learning rate) and ``scheduler``.
        **kwargs: Optional scheduler overrides, consulted only when
            ``args.scheduler == 'GradualWarmupScheduler_CosineAnnealingLR'``:
            ``optim_T_max`` (default 4000) and ``optim_eta_min`` (default
            2e-5) for ``CosineAnnealingLR``; ``optim_multiplier`` (default 10)
            and ``optim_total_epoch`` (default 400) for
            ``GradualWarmupScheduler``. Each key falls back to its own default,
            so overrides may be supplied individually. Values are cast with
            ``int()`` / ``float()``, so numeric strings are accepted.

    Returns:
        ``(optimizer, scheduler)``. ``scheduler`` is ``None`` for any
        ``args.scheduler`` other than the warm-up + cosine combination.

    Raises:
        ValueError: if ``args.optim`` is not a supported optimizer name.
    """
    optimizer_classes = {
        'Adam': optim.Adam,
        'SGD': optim.SGD,
        'AdamW': optim.AdamW,
    }
    try:
        optimizer_cls = optimizer_classes[args.optim]
    except KeyError:
        # Fail fast: an unknown name used to leave `optimizer` unbound and
        # surface later as an unrelated UnboundLocalError.
        raise ValueError(
            f"Unsupported optimizer {args.optim!r}; "
            f"expected one of {sorted(optimizer_classes)}"
        ) from None
    optimizer = optimizer_cls(model.parameters(), lr=args.lr)

    if args.scheduler != 'GradualWarmupScheduler_CosineAnnealingLR':
        return optimizer, None

    scheduler_cosine = CosineAnnealingLR(
        optimizer,
        T_max=int(kwargs.get('optim_T_max', 4000)),
        eta_min=float(kwargs.get('optim_eta_min', 2e-5)),
    )
    # The cosine schedule is passed as `after_scheduler`, so the warm-up
    # scheduler is the single object callers step.
    scheduler = GradualWarmupScheduler(
        optimizer,
        multiplier=float(kwargs.get('optim_multiplier', 10)),
        total_epoch=int(kwargs.get('optim_total_epoch', 400)),
        after_scheduler=scheduler_cosine,
    )
    return optimizer, scheduler