import tensorflow as tf
from tensorflow.keras.optimizers import Adam


class Optimizer:
  """Adam optimizer bound to a model, with optional loss scaling.

  The wrapped optimizer is always `Adam(learning_rate, beta_1=0.5,
  beta_2=0.9)`. When `mixed_precision` is true it is additionally wrapped in
  `tf.keras.mixed_precision.LossScaleOptimizer`.

  Attributes:
    name: Label given to this wrapper; not forwarded to the Keras optimizer.
    model: Object exposing `trainable_variables`; these are the variables
      that `minimize` and `set_weights` operate on.
    mixed_precision: Whether the optimizer is wrapped for loss scaling.
    optimizer: The underlying Keras optimizer (possibly loss-scale wrapped).
  """

  def __init__(self,
               model,
               learning_rate: float,
               strategy: tf.distribute.Strategy,
               mixed_precision: bool,
               name: str = 'optimizer'):
    """Builds the optimizer under the given distribution strategy.

    Args:
      model: Object exposing `trainable_variables`.
      learning_rate: Learning rate passed to Adam.
      strategy: Distribution strategy whose scope the optimizer is built in.
      mixed_precision: If true, wrap Adam in a `LossScaleOptimizer`.
      name: Label stored on this wrapper.
    """
    self.name = name
    self.model = model
    self.mixed_precision = mixed_precision

    # Everything that may own variables is built inside the strategy scope.
    # That includes the loss-scale wrapper, which tracks its own loss-scale
    # state, so it must not be constructed after the scope has been left.
    with strategy.scope():
      self.optimizer = Adam(learning_rate, beta_1=0.5, beta_2=0.9)
      if self.mixed_precision:
        self.optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
            self.optimizer)

  def get_weights(self):
    """Returns the weights reported by the underlying optimizer."""
    return self.optimizer.get_weights()

  def set_weights(self, weights):
    """Restores optimizer weights, e.g. ones saved via `get_weights`.

    Args:
      weights: Values forwarded to the underlying optimizer's `set_weights`.
    """
    # Apply one all-zero gradient step first; presumably this is here so the
    # optimizer creates its per-variable state before `set_weights` assigns
    # into it (Keras optimizers create that state lazily).
    # NOTE(review): if the optimizer already holds non-zero Adam moments, a
    # zero-gradient step can still move the model variables - confirm this is
    # only called on a freshly built optimizer.
    variables = self.model.trainable_variables
    self.optimizer.apply_gradients(
        zip([tf.zeros_like(v) for v in variables], variables))
    self.optimizer.set_weights(weights)

  def get_scaled_loss(self, loss):
    """Returns `loss` scaled for mixed precision, or unchanged otherwise.

    Args:
      loss: Loss value to scale.
    """
    if not self.mixed_precision:
      # The plain Adam optimizer has no loss-scaling API; nothing to scale.
      return loss
    return self.optimizer.get_scaled_loss(loss)

  def get_unscaled_gradients(self, scaled_gradients):
    """Returns gradients with loss scaling removed, or unchanged otherwise.

    Args:
      scaled_gradients: Gradients computed from a scaled loss.
    """
    if not self.mixed_precision:
      # The plain Adam optimizer has no loss-scaling API; nothing to undo.
      return scaled_gradients
    return self.optimizer.get_unscaled_gradients(scaled_gradients)

  def minimize(self, loss, tape):
    """Computes gradients of `loss` and applies them to the model variables.

    Args:
      loss: Loss recorded on `tape`. With mixed precision enabled the
        gradients are unscaled here, so the loss is expected to have been
        scaled beforehand (see `get_scaled_loss`).
      tape: Gradient tape on which `loss` was recorded.
    """
    variables = self.model.trainable_variables
    gradients = tape.gradient(loss, variables)
    # No-op unless mixed precision is enabled.
    gradients = self.get_unscaled_gradients(gradients)
    self.optimizer.apply_gradients(zip(gradients, variables))
