"""Manual learning rate schedulers.
"""
import torch


class PresetLRScheduler(object):
  """Apply a manually designed learning rate schedule to an optimizer.

  The schedule is a mapping ``iteration -> lr``. Calling the scheduler with
  an optimizer and an iteration number rewrites the ``'lr'`` entry of every
  parameter group: to the scheduled value when the iteration is a key of
  the mapping, otherwise to the group's current value (re-wrapped).

  The learning rate is always stored as a one-element float32 tensor. It
  lives on the current CUDA device when CUDA is available and on the CPU
  otherwise, so the scheduler also works on CPU-only machines.
  """

  def __init__(self, decay_schedule):
    """Store the schedule and print it.

    Args:
      decay_schedule: dictionary mapping an iteration number to the
        learning rate that takes effect at that iteration.
    """
    self.decay_schedule = decay_schedule
    print('=> Using a preset learning rate schedule:')
    print(decay_schedule)
    # Not read anywhere in this class; kept because code outside this class
    # may rely on the attribute being present.
    self.for_once = True

  def __call__(self, optimizer, iteration):
    """Set the learning rate of every parameter group for `iteration`.

    Args:
      optimizer: object exposing `param_groups`, an iterable of dicts that
        each contain an `'lr'` entry.
      iteration: key looked up in the schedule. When it is absent, each
        group keeps its current learning rate value.
    """
    # Fall back to the CPU instead of failing when no GPU is present.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    for param_group in optimizer.param_groups:
      lr = self.decay_schedule.get(iteration, param_group['lr'])
      # `lr` may already be a one-element tensor written by an earlier call;
      # float() reduces it (or a plain number) to a Python scalar so the
      # stored tensor always has shape (1,).
      param_group['lr'] = torch.tensor(
          [float(lr)], dtype=torch.float32, device=device)

  @staticmethod
  def get_lr(optimizer):
    """Return the `'lr'` entry of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter groups.
    """
    for param_group in optimizer.param_groups:
      return param_group['lr']
    return None
