import inspect
import warnings

from torch.optim.lr_scheduler import LambdaLR, _LRScheduler


class WarmupMultistepLR(_LRScheduler):
    """Multi-step learning-rate decay preceded by a linear warmup.

    For ``epoch < warmup_period`` the learning rate of every parameter group
    is ``base_lr / warmup_period * (epoch + 1)``.  Afterwards it is ``base_lr``
    multiplied by ``gamma`` once for every milestone that is ``<= epoch``.

    Args:
        optimizer: Wrapped optimizer.
        milestones: Iterable of epoch indices at which the rate is scaled by
            ``gamma``.
        warmup_period: Number of warmup epochs; values ``<= 0`` disable the
            warmup phase.
        gamma: Multiplicative decay factor applied per reached milestone.
        last_epoch: Index of the last epoch; ``-1`` starts a fresh schedule.
        verbose: Forwarded to the base scheduler when its constructor still
            accepts a ``verbose`` argument, ignored otherwise.
    """

    def __init__(
        self,
        optimizer,
        milestones,
        warmup_period=5,
        gamma=0.1,
        last_epoch=-1,
        verbose=True,
    ):
        # These must exist before the base constructor runs: it performs an
        # initial ``step()``, which already calls ``get_lr()``.
        self.milestones = milestones
        self.warmup_period = warmup_period
        self.gamma = gamma

        # Only forward ``verbose`` when the installed base class takes it, so
        # construction does not fail on versions that dropped the parameter.
        base_params = inspect.signature(_LRScheduler.__init__).parameters
        if "verbose" in base_params:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

        # The base constructor's initial step has already written the
        # schedule value for the first epoch into the optimizer, so the
        # learning rate must NOT be divided by ``warmup_period`` again here.
        #
        # Rewind the counter advanced by that initial step: the first explicit
        # ``step()`` call therefore evaluates the schedule at
        # ``last_epoch + 1`` (epoch 0 for a fresh schedule).
        self.last_epoch = last_epoch

    def lr_schedule_lambda(self, epoch, base_lr):
        """Return the learning rate for ``epoch`` given a group's ``base_lr``."""
        # Guard on ``warmup_period > 0`` so a disabled warmup never divides
        # by zero (e.g. ``epoch == -1`` with ``warmup_period == 0``).
        if self.warmup_period > 0 and epoch < self.warmup_period:
            return base_lr / self.warmup_period * (epoch + 1)

        lr = base_lr
        for milestone in self.milestones:
            if epoch >= milestone:
                lr *= self.gamma
        return lr

    def get_lr(self):
        """Compute the learning rate of every parameter group at ``last_epoch``.

        Emits a warning when called directly instead of from within ``step()``.
        """
        if not getattr(self, "_get_lr_called_within_step", False):
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`.",
                stacklevel=2,
            )

        return [
            self.lr_schedule_lambda(self.last_epoch, base_lr)
            for base_lr in self.base_lrs
        ]
