import torch
from .base_scheduler import BaseScheduler
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from torch.optim.sgd import SGD
from warmup_scheduler import GradualWarmupScheduler


class RampUp(BaseScheduler):
    """Learning-rate warm-up followed by stepwise exponential decay.

    Wraps ``GradualWarmupScheduler`` from
    https://github.com/ildoonet/pytorch-gradual-warmup-lr around a
    ``torch.optim.lr_scheduler.StepLR``:

    * During the first ``int(warmup * total_steps)`` steps (at least one),
      the warm-up scheduler (``multiplier=1``) ramps the learning rate up
      towards the optimizer's base learning rate.
    * Afterwards, control passes to ``StepLR``. It multiplies the learning
      rate by ``gamma`` every ``step_size`` steps. ``step_size`` is chosen
      so that roughly ``num_decays`` decays happen over the remaining steps.
    """

    def __init__(self, optimizer, total_steps, warmup, gamma=0.99, num_decays=100):
        """Build the warm-up + step-decay schedule for ``optimizer``.

        Args:
            optimizer: The torch optimizer whose learning rate is scheduled.
            total_steps: Total number of scheduler steps expected in training.
            warmup: Fraction of ``total_steps`` used for warm-up. It is
                multiplied by ``total_steps`` (presumably in ``[0, 1]``;
                TODO confirm against callers).
            gamma: Multiplicative decay factor passed to ``StepLR``
                (default 0.99, the previously hard-coded value).
            num_decays: Approximate number of ``gamma`` decays after warm-up
                (default 100, the previously hard-coded value).
        """
        super().__init__()
        # The upstream GradualWarmupScheduler computes last_epoch / total_epoch
        # while it is being constructed, so zero warm-up steps would raise
        # ZeroDivisionError. Always warm up for at least one step.
        warmup_steps = max(1, int(warmup * total_steps))
        # StepLR decays only when ``last_epoch % step_size == 0``.
        # A float step_size makes decays irregular (or absent), and 0
        # raises. Use a positive integer.
        step_size = max(1, int((total_steps - warmup_steps) // num_decays))
        scheduler_steplr = StepLR(optimizer, step_size=step_size, gamma=gamma)
        scheduler_warmup = GradualWarmupScheduler(
            optimizer,
            multiplier=1,
            total_epoch=warmup_steps,
            after_scheduler=scheduler_steplr,
        )

        # This zero-gradient update is needed to avoid a warning message
        # (upstream pytorch-gradual-warmup-lr issue #8).
        optimizer.zero_grad()
        optimizer.step()
        self.scheduler = scheduler_warmup

    def step(self, loss):
        """Advance the learning-rate schedule by one step.

        ``loss`` is accepted but ignored. It is presumably required by the
        ``BaseScheduler`` interface; TODO confirm.
        """
        self.scheduler.step()
