from bisect import bisect_right

import torch
from torch.optim.lr_scheduler import MultiStepLR


class MultiStepLRWarmup(MultiStepLR):
    """MultiStepLR with an optional linear warmup phase.

    For steps ``0 <= last_epoch < warmup_iter`` the learning rate of every
    param group is interpolated linearly from ``warmup_init_lr`` (at step 0)
    towards that group's base learning rate. From step ``warmup_iter`` on the
    regular MultiStepLR schedule applies: the base learning rate multiplied by
    ``gamma`` once for every milestone reached so far.

    Args:
        optimizer: Wrapped optimizer.
        milestones: Step indices at which the learning rate is decayed.
        warmup_iter: Number of warmup steps; a value <= 0 disables warmup.
        warmup_init_lr: Learning rate used at step 0 of the warmup.
        gamma: Multiplicative decay factor applied at each milestone.
        last_epoch: Index of the last step (-1 starts a fresh schedule).
        verbose: Forwarded to MultiStepLR, but only when truthy.
    """

    def __init__(
        self,
        optimizer,
        milestones,
        warmup_iter=-1,
        warmup_init_lr=0,
        gamma=0.1,
        last_epoch=-1,
        verbose=False,
    ):
        # Must be set before super().__init__, which performs an initial
        # step() and therefore already calls get_lr().
        self.warmup_iter = warmup_iter
        self.warmup_init_lr = warmup_init_lr
        # Recent PyTorch releases deprecate ``verbose`` and warn on any
        # explicitly passed value (even False); only forward it when the
        # caller actually asked for verbose output.
        extra = {"verbose": verbose} if verbose else {}
        super().__init__(
            optimizer, milestones, gamma=gamma, last_epoch=last_epoch, **extra
        )

    def _warmup_lrs(self):
        """Linearly interpolated learning rates for the current warmup step."""
        return [
            self.warmup_init_lr
            + (base_lr - self.warmup_init_lr) / self.warmup_iter * self.last_epoch
            for base_lr in self.base_lrs
        ]

    def _decayed_lrs(self):
        """Base learning rates decayed by ``gamma`` per milestone reached."""
        # ``elements()`` repeats a milestone listed several times, so it
        # decays the rate once per occurrence, matching MultiStepLR.
        milestones = sorted(self.milestones.elements())
        return [
            base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
            for base_lr in self.base_lrs
        ]

    def get_lr(self):
        """Return the learning rate of each param group for ``last_epoch``."""
        if self.last_epoch < self.warmup_iter:
            return self._warmup_lrs()
        if self.last_epoch == self.warmup_iter:
            # MultiStepLR.get_lr is chainable: it reuses the current group lr
            # and only rescales it at milestones. The final warmup step left
            # that lr below base_lr, so compute the first post-warmup value in
            # closed form instead of inheriting the shortfall forever.
            return self._decayed_lrs()
        return super().get_lr()

    def _get_closed_form_lr(self):
        # Used by ``step(epoch)``; without this override the parent's closed
        # form would skip the warmup phase entirely.
        if self.last_epoch < self.warmup_iter:
            return self._warmup_lrs()
        return self._decayed_lrs()


def multi_step_lr(optimizer, milestones, gamma, warmup_iter=-1, warmup_init_lr=0):
    """Create a ``MultiStepLRWarmup`` scheduler for ``optimizer``.

    ``milestones`` is either a sequence of ints or a string of ``+``-separated
    integers such as ``"30+60+90"``, which is parsed into a list of ints.
    The remaining arguments are forwarded unchanged to the scheduler.
    """
    parsed_milestones = (
        [int(token) for token in milestones.split("+")]
        if isinstance(milestones, str)
        else milestones
    )
    return MultiStepLRWarmup(
        optimizer,
        parsed_milestones,
        warmup_iter=warmup_iter,
        warmup_init_lr=warmup_init_lr,
        gamma=gamma,
    )
