from torch.optim.lr_scheduler import LRScheduler


class RobbinsMonroScheduler(LRScheduler):
    """Decay each param group's learning rate as ``base_lr / (1 + alpha * t)``.

    ``t`` is the scheduler's ``last_epoch`` counter and ``base_lr`` is the
    group's initial learning rate as recorded by ``LRScheduler``. For
    ``alpha > 0`` this harmonic decay satisfies the Robbins-Monro step-size
    conditions: the sum of the rates diverges while the sum of their squares
    converges.

    Args:
        optimizer: Wrapped optimizer.
        alpha: Non-negative decay rate. ``0`` keeps the learning rate
            constant; larger values make it decrease faster.
        last_epoch: Index of the last epoch, forwarded to ``LRScheduler``.
            ``-1`` starts a fresh schedule.

    Raises:
        ValueError: If ``alpha`` is negative.
    """

    def __init__(self, optimizer, alpha=0.05, last_epoch=-1):
        # A negative rate lets the denominator hit zero (ZeroDivisionError)
        # or go negative (negative learning rates) after enough steps.
        if alpha < 0:
            raise ValueError(f"alpha must be non-negative, got {alpha}")
        self.alpha = alpha  # controls the rate of decrease in learning rate
        # self.base_lrs is set by LRScheduler.__init__ from each group's
        # "initial_lr", so it is not assigned here.
        super().__init__(optimizer, last_epoch)

    def _get_closed_form_lr(self):
        """Return ``base_lr / (1 + alpha * last_epoch)`` for every param group."""
        denom = 1 + self.alpha * self.last_epoch
        return [base_lr / denom for base_lr in self.base_lrs]

    def get_lr(self):
        """Compute the learning rates for the current ``last_epoch``.

        The schedule is evaluated in closed form from ``base_lrs``. It does not
        read the groups' current ``"lr"`` values. At ``last_epoch == 0`` the
        denominator is 1, so the base rates are returned unchanged and no
        special case is needed.
        """
        return self._get_closed_form_lr()
