from spaghettini import quick_register

from torch.optim.lr_scheduler import _LRScheduler


@quick_register
class WeightFreezeSchedule(_LRScheduler):
    """Learning-rate schedule that temporarily freezes weights by zeroing the lr.

    While ``freeze_at <= self.last_epoch < thaw_at`` every parameter group's
    learning rate is set to ``0.0``. Outside that window each group runs at its
    initial learning rate (``self.base_lrs``). With the large defaults
    (``1e9`` / ``1e10``), the freeze window is only reached after 1e9 scheduler
    steps, i.e. it is effectively disabled unless configured.

    Args:
        optimizer: Wrapped optimizer.
        freeze_at: Step index (inclusive) at which the lr is set to zero.
        thaw_at: Step index (exclusive) at which the initial lr is restored.
        last_epoch: Index of the last step; ``-1`` starts a fresh schedule.
    """

    def __init__(self, optimizer, freeze_at=1e9, thaw_at=1e10, last_epoch=-1):
        self.freeze_at = freeze_at
        self.thaw_at = thaw_at
        # Retained for backward compatibility with existing callers and saved
        # state dicts; ``init_lrs`` is filled in on the first ``get_lr`` call.
        self.first = True
        self.init_lrs = None
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate for each param group at ``self.last_epoch``.

        Returns:
            A new list with one learning rate per param group: ``0.0`` inside
            the freeze window, otherwise the group's initial learning rate.
        """
        # Use ``base_lrs`` (built by the base class from each group's
        # ``initial_lr``) rather than the group's current ``lr``. When resuming
        # with ``last_epoch != -1`` after loading an optimizer state saved
        # during the frozen phase, the current lr is already 0.0. Capturing it
        # would keep the weights frozen forever.
        if self.first:
            self.first = False
            self.init_lrs = list(self.base_lrs)

        if self.freeze_at <= self.last_epoch < self.thaw_at:
            return [0.0 for _ in self.optimizer.param_groups]
        # Return a copy so callers cannot mutate the stored initial lrs.
        return list(self.init_lrs)
