from torch.optim.lr_scheduler import _LRScheduler


def get_linear_lr_update(max_epochs, start_decay_at):
    """
    Build a function mapping an epoch index to a learning rate multiplier.

    The multiplier is 1.0 for every epoch before ``start_decay_at`` and then
    falls linearly, reaching 0.0 at ``max_epochs``. It never goes below 0.0.

    NOTE(review): the returned callable looks like it is meant to be passed
    as ``lr_lambda`` to ``torch.optim.lr_scheduler.LambdaLR`` -- confirm
    against the caller.

    :param max_epochs: epoch at which the multiplier reaches 0.0
    :param start_decay_at: first epoch of the linear decay
    :return: function ``update_lr(epoch) -> float``
    """

    # number of epochs to decay (invariant, so computed once here)
    n_decay = max_epochs - start_decay_at

    def update_lr(epoch):

        # keep learning rate constant in the beginning
        if epoch < start_decay_at:
            return 1.0

        # empty decay window: there is nothing to interpolate over, and the
        # formula below would divide by zero, so drop straight to zero
        if n_decay == 0:
            return 0.0

        # decay learning rate; clamp so epochs past max_epochs cannot
        # produce a negative multiplier
        return max(0.0, float(n_decay - (epoch - start_decay_at)) / n_decay)

    return update_lr


class ConstantLR(_LRScheduler):
    """
    Scheduler whose ``get_lr`` always reports the base learning rates.
    """

    def __init__(self, optimizer, last_epoch=-1):
        # Nothing scheduler-specific to set up; defer entirely to the parent.
        super(ConstantLR, self).__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        """
        Return a fresh list holding every entry of ``self.base_lrs``.
        """
        # Hand back a copy so callers cannot mutate ``self.base_lrs``.
        return list(self.base_lrs)
