# -*- coding: utf-8 -*-


def adjust_learning_rate(args, optimizer, init_lr=0.1):
    """Recompute the current learning rate and write it to the optimizer.

    The schedule is chosen from ``args``, in this order:

    * ``args.lr_decay_epochs is not None``: step decay. The epoch index is
      ``args.local_index * args.batch_size //
      args.num_train_samples_per_device`` and the base rate is divided by 10
      for every boundary in ``args.lr_change_epochs`` that has been passed.
    * ``args.lr_decay_auto is not None``: the base rate is divided by 10 for
      every pair of consecutive entries in ``args.best_epoch`` whose gap
      exceeds ``args.lr_decay_auto``.
    * otherwise: quadratic decay in
      ``args.local_index / args.num_batches_total_train``.

    When ``args.lr_warmup`` is truthy and ``args.local_index`` is below
    ``args.num_warmup_samples``, the scheduled value is replaced by a linear
    interpolation that starts at ``init_lr`` and reaches the scheduled value
    at the end of the warmup.

    Args:
        args: configuration/state namespace providing the attributes named
            above plus ``learning_rate`` (the base rate) and
            ``old_learning_rate`` (the last value that was applied).
        optimizer: object exposing ``param_groups``, an iterable of dicts
            whose ``'lr'`` entry is overwritten.
        init_lr: learning rate at the very start of the warmup phase.

    Returns:
        None. ``args.old_learning_rate`` and every ``param_group['lr']`` are
        updated in place, but only when the rate actually changed.
    """
    def lr_by_epoch(epoch_index):
        """Step decay: one factor of 0.1 per boundary epoch already passed."""
        num_decays = 0
        for change_epoch in args.lr_change_epochs:
            if epoch_index <= change_epoch:
                break
            num_decays += 1
        # Past every boundary this equals the number of boundaries, rather
        # than a fixed exponent that is only right for three boundaries.
        return args.learning_rate * (0.1 ** num_decays)

    def lr_by_index_poly(power=2):
        """Polynomial decay in the fraction of training batches consumed."""
        progress = args.local_index / args.num_batches_total_train
        return args.learning_rate * (1 - progress) ** power

    def lr_by_auto_detect():
        """Decay by 10x for each long gap between recorded best epochs."""
        best_epoch = args.best_epoch
        if len(best_epoch) < 2:
            # No gap can be measured yet, so stay at the base rate.
            return args.learning_rate

        num_long_gaps = sum(
            1 for previous, current in zip(best_epoch, best_epoch[1:])
            if current - previous > args.lr_decay_auto)
        return args.learning_rate * (0.1 ** num_long_gaps)

    # pick the schedule.
    if args.lr_decay_epochs is not None:
        num_accessed_samples = args.local_index * args.batch_size
        epoch_index = num_accessed_samples // args.num_train_samples_per_device
        lr = lr_by_epoch(epoch_index)
    elif args.lr_decay_auto is not None:
        lr = lr_by_auto_detect()
    else:
        lr = lr_by_index_poly()

    # lr warmup: ramp linearly from init_lr towards the scheduled value.
    # NOTE(review): local_index is compared against num_warmup_samples; the
    # names suggest different units (batches vs. samples) -- confirm.
    if args.lr_warmup and args.local_index < args.num_warmup_samples:
        slope = (lr - init_lr) / args.num_warmup_samples
        lr = slope * args.local_index + init_lr

    # assign learning rate only when it changed.
    if args.old_learning_rate != lr:
        args.old_learning_rate = lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr


def adjust_learning_rate_by_lars(args, global_lr, para):
    """Adjust the learning rate via Layer-Wise Adaptive Rate Scaling (LARS)

    When ``args.lr_lars`` is falsy, ``global_lr`` is returned unchanged.
    Otherwise a local rate ``args.lr_lars_eta * ||para.data|| /
    ||para.grad.data||`` is computed and combined with ``global_lr``
    according to ``args.lr_lars_mode``:

    * ``'clip'``: the smaller of the local and the global rate.
    * ``'scale'``: the product of the local and the global rate.

    Args:
        args: namespace providing ``lr_lars``, ``lr_lars_eta`` and
            ``lr_lars_mode``.
        global_lr: the globally scheduled learning rate.
        para: parameter-like object whose ``data`` and ``grad.data`` expose a
            ``norm()`` method.

    Returns:
        The learning rate to use for ``para``.

    Raises:
        ValueError: if LARS is enabled and ``args.lr_lars_mode`` is neither
            ``'clip'`` nor ``'scale'``.
    """
    if not args.lr_lars:
        return global_lr

    # NOTE(review): a zero gradient norm is not guarded against here; what
    # the division yields then depends on the type norm() returns.
    local_lr = args.lr_lars_eta * para.data.norm() / para.grad.data.norm()

    mode = args.lr_lars_mode
    if mode == 'clip':
        return min(local_lr, global_lr)
    if mode == 'scale':
        return local_lr * global_lr
    # Report the value that was actually tested.
    raise ValueError('Invalid LARS mode: %s' % mode)
