import math
import numpy as np


def warmup_lr(args, epoch, dl, it, optimizer):
    """Linearly ramp the learning rate up during the warm-up phase.

    A 1-based global step is derived from the epoch and in-epoch iteration.
    While that step is at most ``args.warmup_iters``, every optimizer
    parameter group gets ``args.lr * step / args.warmup_iters``. Past the
    warm-up window the optimizer is left untouched.

    Args:
        args: Config object providing ``lr``, ``warmup_iters`` and ``bsz``.
        epoch: Current epoch index.
        dl: Object whose ``len()`` is divided by ``args.bsz`` to get the
            per-epoch step count.
        it: Iteration index within the current epoch.
        optimizer: Object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` key); mutated in place.
    """
    # NOTE(review): len(dl) / args.bsz only equals batches-per-epoch if `dl`
    # is a dataset; if it is a DataLoader, len() already counts batches.
    # Confirm against the caller.
    step = 1 + it + epoch * len(dl) / args.bsz
    in_warmup = step <= args.warmup_iters
    if not in_warmup:
        return
    # step >= 1 here, so warmup_iters >= 1 and the division is safe.
    ramped_lr = args.lr * step / float(args.warmup_iters)
    for group in optimizer.param_groups:
        group['lr'] = ramped_lr


def adjust_lr(args, epoch, optimizer, log=None):
    """Apply a step decay to the learning rate at scheduled epochs.

    When ``epoch`` is listed in ``args.lr_decay_epochs``, each parameter
    group's ``'lr'`` is multiplied by ``args.lr_decay_rate``. For any other
    epoch this is a no-op.

    Args:
        args: Config object providing ``lr_decay_epochs`` (supports ``in``)
            and ``lr_decay_rate``.
        epoch: Current epoch index.
        optimizer: Object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` key); mutated in place.
        log: Accepted for interface compatibility; currently unused.
    """
    if epoch not in args.lr_decay_epochs:
        return
    decay = args.lr_decay_rate
    for group in optimizer.param_groups:
        # Explicit multiply-then-assign (not `*=`) so the value is replaced
        # rather than mutated in place.
        group['lr'] = group['lr'] * decay
