"""ABEL Scheduler Class."""
from typing import Any, Callable

from flax.training import lr_schedule
import jax


class ABELScheduler:
  """Implementation of the ABEL learning-rate scheduler.

  Once per measurement period (`meas_freq` calls to `update`, i.e. epochs) the
  scheduler records the average weight norm seen over that period. Whenever
  the averaged norm changes direction (switches from increasing to decreasing
  or vice versa), a flag is toggled; on every second direction change the base
  learning rate is multiplied by `decay_factor` and a new train step is built
  through `train_fn`. Every learning-rate function produced additionally
  applies one stepped decay by `decay_factor` after
  `final_decay_fraction * num_epochs` epochs and an optional linear warmup.

  Attributes:
    num_epochs: Number of epochs the model will be trained for.
    learning_rate: Current base learning rate. Starts at the initial value and
      is multiplied by `decay_factor` on each ABEL decay.
    steps_per_epoch: Number of steps per epoch.
    decay_factor: Multiplicative factor applied to the learning rate on decay.
    train_fn: Callable invoked as `train_fn(learning_rate_fn=...)` that returns
      a train step function. It is needed to rebuild the train step (and hence
      the optimizer update) whenever the learning rate changes.
    meas_freq: Number of epochs over which the weight norm is averaged before
      each comparison. Must be a positive integer.
    warmup: Number of linear warmup epochs (0 disables warmup).
    final_decay_fraction: Fraction of `num_epochs` after which the stepped
      schedule applies one extra decay by `decay_factor`.
    learning_rate_fn: Learning-rate function (step -> learning rate) for the
      current base learning rate; kept in sync on every ABEL decay.
    weight_list: Averaged weight norms, one entry per measurement period.
    reached_minima: True when one direction change has been observed since the
      last decay (the next one triggers a decay).
    epoch: Number of calls to `update` so far.
    avg_weight_norm: Sum of the weight norms accumulated in the current
      measurement period.
  """

  def __init__(self,
               num_epochs: int,
               learning_rate: float,
               steps_per_epoch: int,
               decay_factor: float,
               train_fn: Callable[..., Any],
               meas_freq: int = 1,
               warmup: int = 0,
               final_decay_fraction: float = 0.85):
    if meas_freq < 1:
      # meas_freq is used as a modulus and divisor in `update`.
      raise ValueError(
          f'meas_freq must be a positive integer, got {meas_freq}.')

    self.num_epochs = num_epochs
    self.learning_rate = learning_rate
    self.steps_per_epoch = steps_per_epoch
    self.decay_factor = decay_factor
    self.meas_freq = meas_freq
    self.train_fn = train_fn
    self.warmup = warmup
    # Must be set before get_learning_rate_fn is called below.
    self.final_decay_fraction = final_decay_fraction

    self.learning_rate_fn = self.get_learning_rate_fn(self.learning_rate)
    self.weight_list = []
    self.reached_minima = False
    self.epoch = 0
    self.avg_weight_norm = 0

  def get_learning_rate_fn(self, lr):
    """Builds the learning-rate function for base learning rate `lr`.

    The returned function multiplies a stepped schedule (one decay by
    `decay_factor` after `final_decay_fraction * num_epochs` epochs) by an
    optional linear warmup over `warmup` epochs.

    Args:
      lr: Base learning rate.

    Returns:
      A function that takes the current step and returns the learning rate to
      use.
    """
    # NOTE(review): the stepped schedule measures epochs in units of
    # steps_per_epoch // process_count, while the warmup below uses the
    # undivided steps_per_epoch; confirm which is intended on multi-host runs.
    lr_fn = lr_schedule.create_stepped_learning_rate_schedule(
        lr, self.steps_per_epoch // jax.process_count(),
        [[int(self.num_epochs * self.final_decay_fraction),
          self.decay_factor]])

    if self.warmup:
      warmup_fn = lambda step: jax.numpy.minimum(
          1., step / self.steps_per_epoch / self.warmup)
    else:
      warmup_fn = lambda step: 1

    return lambda step: lr_fn(step) * warmup_fn(step)

  def update(self, step_fn, weight_norm):
    """Records one epoch's weight norm and decays the learning rate if due.

    Args:
      step_fn: The train step function currently in use.
      weight_norm: Weight norm measured for this epoch.

    Returns:
      `step_fn` unchanged, or a newly built train step function when the
      learning rate was decayed.
    """
    self.epoch += 1
    self.avg_weight_norm += weight_norm
    # meas_freq != 1 is only used for PyramidNet: training is long and noise in
    # the per-epoch norm might otherwise set `reached_minima` spuriously.
    if self.epoch % self.meas_freq != 0:
      return step_fn

    self.weight_list.append(self.avg_weight_norm / self.meas_freq)
    self.avg_weight_norm = 0

    if len(self.weight_list) < 3:
      return step_fn

    # A negative product of consecutive differences means the averaged norm
    # changed direction (increase -> decrease or vice versa).
    last_delta = self.weight_list[-1] - self.weight_list[-2]
    prev_delta = self.weight_list[-2] - self.weight_list[-3]
    if last_delta * prev_delta < 0:
      if self.reached_minima:
        # Second direction change since the last decay: decay the rate.
        self.reached_minima = False
        self.learning_rate *= self.decay_factor
        step_fn = self.update_train_step(self.learning_rate)
      else:
        self.reached_minima = True

    return step_fn

  def update_train_step(self, learning_rate):
    """Builds a train step for `learning_rate` and records its schedule.

    Args:
      learning_rate: New base learning rate.

    Returns:
      The train step function returned by `train_fn`.
    """
    # Keep the public attribute in sync with the schedule actually in use.
    self.learning_rate_fn = self.get_learning_rate_fn(learning_rate)
    return self.train_fn(learning_rate_fn=self.learning_rate_fn)