# Lint as: python3
# Copyright 2018 The SPL Authors.
#
# All rights reserved.
#
# This is the code for reproducing results of the paper.
"""Utils for learning rate schedule."""

import abc
from absl import logging

import tensorflow as tf

# Hard-coded approximation of pi to 11 decimal places (math.pi has full
# double precision). NOTE(review): not referenced in this file's visible code;
# presumably imported elsewhere - confirm before changing its value.
PI = 3.14159265359


class StepWithWarmupLearningRateSchedule(
    tf.keras.optimizers.schedules.LearningRateSchedule):
  """Step-decay learning rate schedule with an optional linear warmup.

  Training progress is measured in fractional epochs, `step / steps_per_epoch`.
  The decay rule is chosen by the type of `decay_epochs`:

  * list or tuple: piecewise-constant decay using the listed epochs as
    boundaries; the k-th interval uses `base_lr * decay_rate ** k`.
  * anything else: staircase exponential decay with a period of
    `decay_epochs` epochs.

  With warmup enabled, the rate ramps linearly from 0 towards `base_lr` while
  the epoch is below `warmup_epochs`. Unlike `BaseWarmupLearningRateSchedule`,
  the decay rule is always evaluated on the absolute epoch, not on the number
  of epochs elapsed since warmup ended.
  """

  def __init__(self,
               steps_per_epoch,
               base_lr,
               use_warmup,
               warmup_epochs,
               decay_rate,
               decay_epochs):
    """Builds the schedule.

    Args:
      steps_per_epoch: number of optimizer steps per epoch.
      base_lr: initial learning rate (the value reached at the end of warmup).
      use_warmup: whether to prepend a linear warmup.
      warmup_epochs: warmup length, in epochs.
      decay_rate: multiplicative factor applied at each decay.
      decay_epochs: list/tuple of epoch boundaries for piecewise-constant
        decay, or a single decay period (in epochs) for staircase
        exponential decay.
    """
    super().__init__()
    self.steps_per_epoch = steps_per_epoch
    self.base_lr = base_lr
    self.use_warmup = use_warmup
    self.warmup_epochs = warmup_epochs
    self.decay_rate = decay_rate
    self.decay_epochs = decay_epochs

    schedules = tf.keras.optimizers.schedules
    if not isinstance(decay_epochs, (list, tuple)):
      # A single period: lr = base_lr * decay_rate ** floor(epoch / period).
      self._lr_schedule_no_warmup = schedules.ExponentialDecay(
          base_lr, decay_epochs, decay_rate, staircase=True)
    else:
      # N boundaries split the epoch axis into N + 1 constant intervals.
      num_intervals = len(decay_epochs) + 1
      interval_lrs = [base_lr * decay_rate ** k for k in range(num_intervals)]
      self._lr_schedule_no_warmup = schedules.PiecewiseConstantDecay(
          decay_epochs, interval_lrs)

  def get_config(self):
    """Returns the constructor arguments, for Keras serialization."""
    config = dict(
        steps_per_epoch=self.steps_per_epoch,
        base_lr=self.base_lr,
        use_warmup=self.use_warmup,
        warmup_epochs=self.warmup_epochs,
        decay_rate=self.decay_rate,
        decay_epochs=self.decay_epochs,
    )
    return config

  def __call__(self, step):
    """Returns the learning rate for optimizer step `step`."""
    epoch = tf.cast(step, tf.float32) / self.steps_per_epoch
    if not self.use_warmup:
      return self._lr_schedule_no_warmup(epoch)

    def _warmup_lr():
      return epoch / self.warmup_epochs * self.base_lr

    def _decayed_lr():
      # Absolute epoch: decay boundaries are not shifted by the warmup.
      return self._lr_schedule_no_warmup(epoch)

    return tf.cond(epoch < self.warmup_epochs, _warmup_lr, _decayed_lr)


class BaseWarmupLearningRateSchedule(
    tf.keras.optimizers.schedules.LearningRateSchedule):
  """Base class for learning rate schedules with an optional linear warmup.

  Subclasses implement `_lr_schedule_no_warmup(lr_epoch)`, which maps a
  fractional epoch to a learning rate. Calling the schedule converts the
  optimizer step into a fractional epoch (`step / steps_per_epoch`) and:

  * with warmup enabled, returns `lr_epoch / warmup_epochs * base_lr` while
    `lr_epoch < warmup_epochs` (a linear ramp from 0 to `base_lr`), and
    `_lr_schedule_no_warmup(lr_epoch - warmup_epochs)` afterwards, i.e. the
    subclass schedule sees the epochs elapsed since warmup ended;
  * with warmup disabled, returns `_lr_schedule_no_warmup(lr_epoch)`.
  """

  def __init__(self,
               steps_per_epoch,
               base_lr,
               use_warmup,
               warmup_epochs):
    """Stores the hyperparameters shared by all warmup schedules.

    Args:
      steps_per_epoch: number of optimizer steps per epoch.
      base_lr: learning rate reached at the end of the warmup ramp.
      use_warmup: whether to apply the linear warmup.
      warmup_epochs: length of the warmup, in (possibly fractional) epochs.
    """
    super().__init__()
    self.steps_per_epoch = steps_per_epoch
    self.base_lr = base_lr
    self.use_warmup = use_warmup
    self.warmup_epochs = warmup_epochs

  @abc.abstractmethod
  def _lr_schedule_no_warmup(self, lr_epoch):
    """Returns the learning rate for `lr_epoch` outside of warmup.

    Args:
      lr_epoch: float32 tensor holding the fractional epoch; already shifted
        by `warmup_epochs` when warmup is enabled.

    Raises:
      NotImplementedError: always; subclasses must override this method.
    """
    # `abc.abstractmethod` is only enforced when the class's metaclass is
    # `abc.ABCMeta`, which (depending on the TF/Keras version) the Keras
    # `LearningRateSchedule` base may not use. Raise explicitly so a missing
    # override fails loudly instead of silently returning None into tf.cond.
    raise NotImplementedError(
        f'Learning rate schedule {type(self).__name__!r} must override '
        '`_lr_schedule_no_warmup(self, lr_epoch)`.')

  def __call__(self, step):
    """Returns the learning rate for optimizer step `step`."""
    lr_epoch = tf.cast(step, tf.float32) / self.steps_per_epoch
    if self.use_warmup:
      return tf.cond(
          lr_epoch < self.warmup_epochs,
          lambda: lr_epoch / self.warmup_epochs * self.base_lr,
          # Shift so the subclass schedule starts at epoch 0 after warmup.
          lambda: self._lr_schedule_no_warmup(lr_epoch - self.warmup_epochs))
    else:
      return self._lr_schedule_no_warmup(lr_epoch)


class ConstantLearningRateSchedule(BaseWarmupLearningRateSchedule):
  """Constant learning rate, optionally preceded by a linear warmup.

  Outside of warmup the learning rate is always `base_lr`.
  """

  def __init__(self,
               steps_per_epoch,
               base_lr,
               use_warmup,
               warmup_epochs):
    """Builds the schedule.

    Args:
      steps_per_epoch: number of optimizer steps per epoch.
      base_lr: the constant learning rate (also the end point of warmup).
      use_warmup: whether to prepend a linear warmup.
      warmup_epochs: warmup length, in epochs.
    """
    super().__init__(
        steps_per_epoch=steps_per_epoch,
        base_lr=base_lr,
        use_warmup=use_warmup,
        warmup_epochs=warmup_epochs)

  def get_config(self):
    """Returns the constructor arguments, for Keras serialization."""
    return {
        'steps_per_epoch': self.steps_per_epoch,
        'base_lr': self.base_lr,
        'use_warmup': self.use_warmup,
        'warmup_epochs': self.warmup_epochs,
    }

  def _lr_schedule_no_warmup(self, lr_epoch):
    """Returns `base_lr` as a float32 tensor, independent of `lr_epoch`.

    The value is converted to a tensor instead of being returned as the raw
    Python number. Calling the schedule therefore yields a float32 tensor
    whether or not warmup is enabled, matching the float32 warmup branch.
    """
    del lr_epoch  # The rate does not depend on training progress.
    return tf.cast(self.base_lr, tf.float32)


def make_learning_rate_schedule(batch_size,
                                steps_per_epoch,
                                num_training_epochs,
                                learning_rate_hparams):
  """Creates learning rate schedule from hyperparameters.

  When `learning_rate_hparams.distributed_mode` is truthy, the configured
  `base_lr` is scaled linearly by `batch_size / 256`; otherwise it is used
  unchanged.

  Args:
    batch_size: batch size, needed for learning rate scaling.
    steps_per_epoch: number of steps per epoch.
    num_training_epochs: number of training epochs. Accepted but currently
      unused by the supported schedules.
    learning_rate_hparams: object exposing the learning rate hyperparameters
      as attributes (`distributed_mode`, `base_lr`, `schedule_type`,
      `use_warmup`, `warmup_epochs`, plus `decay_rate` and `decay_epochs` for
      the 'step' schedule), typically obtained from global hyperparameters as
      hparams.learning_rate.

  Returns:
    A `StepWithWarmupLearningRateSchedule` when `schedule_type` is 'step', or
    a `ConstantLearningRateSchedule` when it is 'constant'.

  Raises:
    ValueError: if `learning_rate_hparams.schedule_type` is not supported.
  """
  del num_training_epochs  # Unused by the currently supported schedules.
  base_lr = learning_rate_hparams.base_lr
  if learning_rate_hparams.distributed_mode:
    base_lr = base_lr * (batch_size / 256)
    # Lazy %-args: the message is only formatted if INFO logging is enabled.
    logging.info('Scaled learning rate %s for batch size %s',
                 base_lr, batch_size)
  else:
    logging.info('Learning rate %s (not scaled by batch size)', base_lr)

  schedule_type = learning_rate_hparams.schedule_type
  if schedule_type == 'step':
    return StepWithWarmupLearningRateSchedule(
        steps_per_epoch=steps_per_epoch,
        base_lr=base_lr,
        use_warmup=learning_rate_hparams.use_warmup,
        warmup_epochs=learning_rate_hparams.warmup_epochs,
        decay_rate=learning_rate_hparams.decay_rate,
        decay_epochs=learning_rate_hparams.decay_epochs)
  elif schedule_type == 'constant':
    return ConstantLearningRateSchedule(
        steps_per_epoch=steps_per_epoch,
        base_lr=base_lr,
        use_warmup=learning_rate_hparams.use_warmup,
        warmup_epochs=learning_rate_hparams.warmup_epochs)
  else:
    raise ValueError('Invalid learning rate schedule type: ' +
                     str(schedule_type))
