import torch
from .base_scheduler import BaseScheduler


class ReduceOnPlateau(BaseScheduler):
    """Scheduler wrapper around ``torch.optim.lr_scheduler.ReduceLROnPlateau``.

    The wrapped scheduler runs in ``'min'`` mode, i.e. it treats a lower
    metric as an improvement, and is driven by passing the current loss to
    :meth:`step`.
    """

    def __init__(self, optimizer, total_steps, *, patience=300, initial_loss=100):
        """Build the wrapped plateau scheduler and prime it with one loss value.

        Args:
            optimizer: Optimizer whose learning rate the wrapped scheduler
                adjusts.
            total_steps: Accepted but not used by this class.
                NOTE(review): presumably kept so every scheduler shares the
                same constructor signature -- confirm against the caller.
            patience: Forwarded to ``ReduceLROnPlateau`` as ``patience``
                (number of non-improving ``step`` calls tolerated before the
                learning rate is reduced). Defaults to the previously
                hard-coded value of 300.
            initial_loss: Loss value fed to the wrapped scheduler once during
                construction. Defaults to the previously hard-coded value
                of 100.
        """
        super().__init__()

        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', patience=patience
        )
        # Prime the wrapped scheduler with one synthetic loss so it starts
        # from a finite reference value instead of its initial "worst" value.
        # NOTE(review): the original motivation for this call is not visible
        # here -- confirm before changing the default.
        self.scheduler.step(initial_loss)

    def step(self, loss):
        """Report the latest loss to the wrapped plateau scheduler.

        Args:
            loss: Current value of the monitored metric; forwarded unchanged
                to ``ReduceLROnPlateau.step``.
        """
        self.scheduler.step(loss)
