import torch
import torch.optim.lr_scheduler as lr_scheduler
import math
import horovod.torch as hvd

def get_Hparam(args, model):
    """Build the optimizer parameter-group list for ``args.CLoptimizer``.

    Args:
        args: Namespace providing ``CLoptimizer`` ("SGD" or "AdamW") and
            the base learning rate ``lr``.
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        A one-element list holding a param-group dict. SGD uses momentum 0.9,
        weight decay 1e-4 and Nesterov; AdamW uses weight decay 1e-4.

    Raises:
        NotImplementedError: If ``args.CLoptimizer`` is not supported.
    """
    if args.CLoptimizer == "SGD":
        return [{"params": model.parameters(), "lr": args.lr, "momentum": 0.9,
                 "weight_decay": 1e-4, "nesterov": True}]
    if args.CLoptimizer == "AdamW":
        return [{"params": model.parameters(), "lr": args.lr, "weight_decay": 1e-4}]
    # Fail loudly with the offending name rather than an UnboundLocalError.
    raise NotImplementedError('optimizer [%s] is not implemented' % args.CLoptimizer)

def Getoptim(CLoptimizer, Hparam):
    """Instantiate a torch optimizer from a list of parameter groups.

    Args:
        CLoptimizer: Optimizer name, "SGD" or "AdamW".
        Hparam: Iterable of param-group dicts; each group must carry its own
            ``lr`` (no defaults are passed to the constructor).

    Returns:
        A ``torch.optim.SGD`` or ``torch.optim.AdamW`` instance.

    Raises:
        NotImplementedError: If ``CLoptimizer`` is not supported (previously
            ``None`` was returned silently).
    """
    if CLoptimizer == "SGD":
        return torch.optim.SGD(Hparam)
    if CLoptimizer == "AdamW":
        return torch.optim.AdamW(Hparam)
    raise NotImplementedError('optimizer [%s] is not implemented' % CLoptimizer)

def define_scheduler(opt, optimizer):
    """Create a learning-rate scheduler selected by ``opt.lr_policy``.

    Policies and the ``opt`` attributes each one reads:
        'linear'  -- ``epoch_count``, ``niter``, ``niter_decay``: LR multiplier
                     stays 1.0 while ``epoch + epoch_count <= niter``, then
                     falls linearly by ``1 / (niter_decay + 1)`` per epoch.
        'exp'     -- none: ExponentialLR with gamma 0.1.
        'step'    -- ``lr_decay_iters``: StepLR, gamma 0.1 every
                     ``lr_decay_iters`` scheduler steps.
        'plateau' -- none: ReduceLROnPlateau (mode='min', factor 0.2,
                     threshold 0.01, patience 5).
        'cosine'  -- ``niter``: CosineAnnealingLR with ``T_max=niter`` and
                     ``eta_min=0``.

    Args:
        opt: Options namespace holding ``lr_policy`` plus the attributes above.
        optimizer: The torch optimizer to schedule.

    Returns:
        The constructed scheduler.

    Raises:
        NotImplementedError: If ``opt.lr_policy`` is not one of the above.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            # ``epoch_count`` offsets the scheduler's epoch counter
            # (presumably the starting epoch when resuming -- confirm).
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'exp':
        scheduler = lr_scheduler.ExponentialLR(optimizer, 0.1, last_epoch=-1)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        # Raise (not return) so callers never receive an exception object
        # in place of a scheduler.
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler


def adjust_learning_rate(optimizer, max_step, warmup_steps, train_steps, args):
    """Set every param group's learning rate for the current training step.

    Schedule: linear warmup ``args.lr * max(train_steps / warmup_steps, 0.001)``
    while ``train_steps < warmup_steps``, then cosine decay from ``args.lr``
    down to 0 over ``[warmup_steps, max_step]``. After ``max_step`` the LR
    stays at 0. The unclamped cosine is periodic and would climb back up.

    Args:
        optimizer: Object exposing ``param_groups`` (e.g. a torch optimizer);
            every group's ``'lr'`` entry is overwritten.
        max_step: Step at which the cosine decay reaches zero.
        warmup_steps: Number of warmup steps; ``0`` or ``None`` disables warmup.
        train_steps: Current step.
        args: Namespace providing the base learning rate ``args.lr``.

    Returns:
        The learning rate that was applied.
    """
    warmup_steps = warmup_steps or 0  # None would break the arithmetic below
    if warmup_steps and train_steps < warmup_steps:
        # Linear warmup, floored at 0.1% of the base LR so step 0 is non-zero.
        cur_lr = args.lr * max(train_steps / warmup_steps, 0.001)
    else:
        decay_steps = max_step - warmup_steps
        if decay_steps > 0:
            progress = (train_steps - warmup_steps) / decay_steps
            progress = min(max(progress, 0.0), 1.0)
        else:
            # No decay span left (max_step <= warmup_steps): schedule is done.
            progress = 1.0
        cur_lr = args.lr * 0.5 * (1.0 + math.cos(math.pi * progress))

    for param_group in optimizer.param_groups:
        param_group['lr'] = cur_lr
    return cur_lr
    

    
