

"""A trainable ADAM optimizer that learns its internal variables."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from optimizer import trainable_optimizer as opt
from optimizer import utils


class TrainableAdam(opt.TrainableOptimizer):
  """Adam optimizer with learnable scalar parameters.

  The learning rate, beta1, beta2 and epsilon are held as scalar TensorFlow
  variables in an unconstrained parameterization (log for the learning rate
  and epsilon, logit for the betas), so they can be learned while the values
  actually used in the update stay in their valid ranges.

  See Kingma et al., 2014 for the algorithm (http://arxiv.org/abs/1412.6980).
  """

  def __init__(self,
               learning_rate=1e-3,
               beta1=0.9,
               beta2=0.999,
               epsilon=1e-8,
               **kwargs):
    """Initializes the TrainableAdam optimizer with the given initial values.

    Args:
      learning_rate: The learning rate (default: 1e-3).
      beta1: The exponential decay rate for the 1st moment estimates.
      beta2: The exponential decay rate for the 2nd moment estimates.
      epsilon: A small constant for numerical stability.
      **kwargs: Any additional keyword arguments for TrainableOptimizer.

    Raises:
      ValueError: if the learning rate or epsilon is not positive
      ValueError: if beta1 or beta2 is not in (0, 1).
    """
    if learning_rate <= 0:
      raise ValueError("Learning rate must be positive.")
    if epsilon <= 0:
      raise ValueError("Epsilon must be positive.")
    if not 0 < beta1 < 1 or not 0 < beta2 < 1:
      raise ValueError("Beta values must be between 0 and 1, exclusive.")

    # Set to True by the first _compute_update call; later calls then reuse
    # the variables of the optimizer scope instead of creating new ones.
    self._reuse_vars = False

    with tf.variable_scope(opt.OPTIMIZER_SCOPE):
      def inv_sigmoid(x):
        # Logit: the inverse of tf.nn.sigmoid, used so that
        # sigmoid(initial value) recovers the requested beta.
        return np.log(x / (1.0 - x))

      # Stored as a log so that exp() in _compute_update is always positive.
      self.log_learning_rate = tf.get_variable(
          "log_learning_rate",
          shape=[],
          initializer=tf.constant_initializer(np.log(learning_rate)))
      # Stored as logits so that sigmoid() keeps the betas inside (0, 1).
      self.beta1_logit = tf.get_variable(
          "beta1_logit",
          shape=[],
          initializer=tf.constant_initializer(inv_sigmoid(beta1)))
      self.beta2_logit = tf.get_variable(
          "beta2_logit",
          shape=[],
          initializer=tf.constant_initializer(inv_sigmoid(beta2)))
      self.log_epsilon = tf.get_variable(
          "log_epsilon",
          shape=[],
          initializer=tf.constant_initializer(np.log(epsilon)))

    # Per-parameter optimizer state: first moment, second moment, timestep.
    state_keys = ["m", "v", "t"]
    super(TrainableAdam, self).__init__("Adam", state_keys, **kwargs)

  def _initialize_state(self, var):
    """Returns a dictionary mapping names of state variables to their values.

    Every state entry (one per key in `self.state_keys`) starts as a zero
    tensor of shape (number of elements in `var`, 1), i.e. a column vector
    over the flattened parameter.
    """
    vectorized_shape = var.get_shape().num_elements(), 1

    return {key: tf.zeros(vectorized_shape) for key in self.state_keys}

  def _compute_update(self, param, grad, state):
    """Calculates the new internal state and parameters.

    If the gradient is sparse (a `tf.IndexedSlices`), only the slices of the
    internal state addressed by the gradient indices are read and updated,
    and the update tensor is stacked back to the parameter's leading shape.

    Args:
      param: A tensor of parameters.
      grad: A tensor of gradients with the same shape as param, or a
        `tf.IndexedSlices` holding a sparse gradient for param.
      state: A dictionary containing the optimizer state for param, with
        keys "m" (first moment), "v" (second moment) and "t" (timestep).

    Returns:
      updated_param: The updated parameters.
      updated_state: The updated state variables in a dictionary.
    """

    with tf.variable_scope(opt.OPTIMIZER_SCOPE) as scope:

      # Every call after the first puts the scope into reuse mode.
      if self._reuse_vars:
        scope.reuse_variables()
      else:
        self._reuse_vars = True

      (grad_values, first_moment, second_moment, timestep, grad_indices
      ) = self._extract_gradients_and_internal_state(
          grad, state, tf.shape(param))

      # Map the unconstrained trainable variables back to valid ranges.
      beta1 = tf.nn.sigmoid(self.beta1_logit)
      beta2 = tf.nn.sigmoid(self.beta2_logit)
      # Small floor keeps epsilon away from zero even if log_epsilon -> -inf.
      epsilon = tf.exp(self.log_epsilon) + 1e-10
      learning_rate = tf.exp(self.log_learning_rate)

      # Work on a column vector to match the (num_elements, 1) state layout;
      # the original shape is restored on the update below.
      old_grad_shape = tf.shape(grad_values)
      grad_values = tf.reshape(grad_values, [-1, 1])

      new_timestep = timestep + 1
      # Both moments are exponential moving averages:
      #   m <- beta1 * m + (1 - beta1) * g
      #   v <- beta2 * v + (1 - beta2) * g^2
      new_first_moment = self._update_adam_estimate(
          first_moment, grad_values, beta1)
      new_second_moment = self._update_adam_estimate(
          second_moment, tf.square(grad_values), beta2)

      # Correct the bias of the zero-initialized moments.
      debiased_first_moment = self._debias_adam_estimate(
          new_first_moment, beta1, new_timestep)
      debiased_second_moment = self._debias_adam_estimate(
          new_second_moment, beta2, new_timestep)

      # Propagating through the square root of 0 is very bad for stability.
      update = (learning_rate * debiased_first_moment /
                (tf.sqrt(debiased_second_moment + 1e-10) + epsilon))

      update = tf.reshape(update, old_grad_shape)

      if grad_indices is not None:
        # Sparse gradient: scatter the update and the touched state slices
        # back into full-size tensors.
        param_shape = tf.shape(param)
        update = utils.stack_tensor(
            update, grad_indices, param, param_shape[:1])
        new_first_moment = utils.update_slices(
            new_first_moment, grad_indices, state["m"], param_shape)
        new_second_moment = utils.update_slices(
            new_second_moment, grad_indices, state["v"], param_shape)
        new_timestep = utils.update_slices(
            new_timestep, grad_indices, state["t"], param_shape)

      new_param = param - update

      # collect the update and new state
      new_state = {
          "m": new_first_moment,
          "v": new_second_moment,
          "t": new_timestep
      }

    return new_param, new_state

  def _update_adam_estimate(self, estimate, value, beta):
    """Returns a beta-weighted average of estimate and value.

    Computes `beta * estimate + (1 - beta) * value`, the exponential moving
    average step used for both Adam moment estimates.
    """
    return (beta * estimate) + ((1 - beta) * value)

  def _debias_adam_estimate(self, estimate, beta, t_step):
    """Returns a debiased estimate based on beta and the timestep.

    Computes `estimate / (1 - beta ** t_step)`, Adam's bias correction for
    moment estimates initialized at zero.
    """
    return estimate / (1 - tf.pow(beta, t_step))

  def _extract_gradients_and_internal_state(self, grad, state, param_shape):
    """Extracts the gradients and relevant internal state.

    If the gradient is sparse, extracts the appropriate slices from the state.

    Args:
      grad: The current gradient (a dense tensor or a `tf.IndexedSlices`).
      state: The current state, with keys "m", "v" and "t".
      param_shape: The shape of the parameter (used if gradient is sparse).

    Returns:
      grad_values: The gradient value tensor.
      first_moment: The first moment tensor (internal state).
      second_moment: The second moment tensor (internal state).
      timestep: The current timestep (internal state).
      grad_indices: The indices for the gradient tensor, if sparse.
          None otherwise.
    """
    grad_values = grad
    grad_indices = None
    first_moment = state["m"]
    second_moment = state["v"]
    timestep = state["t"]

    if isinstance(grad, tf.IndexedSlices):
      # NOTE(review): presumably accumulate_sparse_gradients merges duplicate
      # indices before slicing; confirm against utils.
      grad_indices, grad_values = utils.accumulate_sparse_gradients(grad)
      first_moment = utils.slice_tensor(
          first_moment, grad_indices, param_shape)
      second_moment = utils.slice_tensor(
          second_moment, grad_indices, param_shape)
      timestep = utils.slice_tensor(timestep, grad_indices, param_shape)

    return grad_values, first_moment, second_moment, timestep, grad_indices

