

"""A base class definition for trainable optimizers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import itertools
import pdb
import tensorflow as tf

from tensorflow.python.framework import tensor_shape

# Variable scope under which local_state_variables() creates its variables;
# is_local_state_variable() matches op names against this scope.
OPTIMIZER_SCOPE = "LOL"
# Name prefix shared by every local state variable created by this module.
_LOCAL_VARIABLE_PREFIX = "local_state_"
# Graph collection holding the per-name counter used to uniquify local state
# variable names.
_LOCAL_STATE_VARIABLE_COLLECTION = "local_state_collection"
EPSILON = 1e-6

# alpha scales the optimizer regularization term accumulated when
# --reg_optimizer is set; beta scales the regularization added to the
# optimizee's batch objective when --reg_optimizee is set.
tf.app.flags.DEFINE_float("alpha", 5e-4,
                          """The scale for optimizer regularization.""")
tf.app.flags.DEFINE_float("beta", 1e-4,
                          """The scale for optimizee regularization.""")
tf.app.flags.DEFINE_boolean("reg_optimizer", False,
                            """Whether to regularize the optimizer.""")
tf.app.flags.DEFINE_boolean("reg_optimizee", False,
                            """Whether to regularize the optimizee.""")
FLAGS = tf.flags.FLAGS

# MODIFY
def list_mult(l):
  """Returns the product of all elements of `l`.

  Args:
    l: An iterable of numbers.

  Returns:
    The running product starting from 1, so an empty iterable yields 1.
  """
  product = 1
  for factor in l:
    product *= factor
  return product

class TrainableOptimizer(tf.train.Optimizer):
  """Base class for trainable optimizers.

  A trainable optimizer is an optimizer that has parameters that can themselves
  be learned (meta-optimized).

  Subclasses must implement:
      _compute_update(self, param, grad, state)

  Subclasses that need to combine information across parameters can override
  _compute_updates instead (see its docstring for the arguments train() passes
  and the values train() expects back).
  """

  def __init__(self, name, state_keys, use_attention=False,
               use_log_objective=False, obj_train_max_multiplier=-1,
               use_second_derivatives=True, use_numerator_epsilon=False,
               **kwargs):
    """Initializes the optimizer with the given name and settings.

    Args:
      name: The name string for this optimizer.
      state_keys: The names of any required state variables (list). They are
          stored sorted so they line up with the key-sorted order produced by
          flatten_and_sort().
      use_attention: Whether this optimizer uses attention (Default: False)
      use_log_objective: Whether this optimizer uses the logarithm of the
          objective when computing the loss (Default: False)
      obj_train_max_multiplier: The maximum multiplier for the increase in the
          objective before meta-training is stopped. If <= 0, meta-training is
          not stopped early. (Default: -1)
      use_second_derivatives: Whether this optimizer uses second derivatives in
          meta-training. This should be set to False if some second derivatives
          in the meta-training problem set are not defined in Tensorflow.
          (Default: True)
      use_numerator_epsilon: Whether to use epsilon in the numerator when
          scaling the problem objective during meta-training. (Default: False)
      **kwargs: Any additional keyword arguments. Accepted but not used by this
          base class.
    """
    self.use_second_derivatives = use_second_derivatives
    self.state_keys = sorted(state_keys)
    self.use_attention = use_attention
    self.use_log_objective = use_log_objective
    self.obj_train_max_multiplier = obj_train_max_multiplier
    self.use_numerator_epsilon = use_numerator_epsilon

    use_locking = False
    super(TrainableOptimizer, self).__init__(use_locking, name)

  def _create_slots(self, var_list):
    """Creates all slots needed by the variables.

    One slot is created per entry returned by _initialize_state(var), named
    "<optimizer name>_<slot name>".

    Args:
      var_list: A list of `Variable` objects.
    """
    for var in var_list:
      init_states = self._initialize_state(var)
      for slot_name in sorted(init_states):
        slot_var_name = "{}_{}".format(self.get_name(), slot_name)
        value = init_states[slot_name]
        self._get_or_make_slot(var, value, slot_name, slot_var_name)

  def _initialize_state(self, var):
    """Initializes any state required for this variable.

    Args:
      var: a tensor containing parameters to be optimized

    Returns:
      state: a dictionary mapping state keys to initial state values (tensors).
          The base class needs no state and returns an empty dict.
    """
    return {}

  def _initialize_global_state(self):
    """Initializes any global state values (none in the base class)."""
    return []

  def _apply_common(self, grad, var):
    """Applies the optimizer updates to the variables.

    Note: this should only get called via _apply_dense or _apply_sparse when
    using the optimizer via optimizer.minimize or optimizer.apply_gradients.
    During meta-training, the optimizer.train function should be used to
    construct an optimization path that is differentiable.

    Args:
      grad: A tensor representing the gradient.
      var: A tf.Variable with the same shape as grad.

    Returns:
      update_op: A tensorflow op that assigns new values to the variable, and
          also defines dependencies that update the state variables for the
          optimizer.
    """
    state = {key: self.get_slot(var, key) for key in self.get_slot_names()}
    new_var, new_state = self._compute_update(var, grad, state)
    state_assign_ops = [tf.assign(state_var, new_state[key])
                        for key, state_var in state.items()]
    # The variable assignment only runs after every slot has been updated.
    with tf.control_dependencies(state_assign_ops):
      update_op = var.assign(new_var)

    return update_op

  def _apply_dense(self, grad, var):
    """Adds ops to apply dense gradients to 'var'."""
    return self._apply_common(grad, var)

  def _apply_sparse(self, grad, var):
    """Adds ops to apply sparse gradients to 'var'."""
    return self._apply_common(grad, var)

  def _compute_update(self, param, grad, state):
    """Computes the update step for optimization.

    Args:
      param: A tensor of parameters to optimize.
      grad: The gradient tensor of the objective with respect to the parameters.
          (It has the same shape as param.)
      state: A dictionary containing any extra state required by the optimizer.

    Returns:
      updated_params: The updated parameters.
      updated_state: The dictionary of updated state variable(s).

    Raises:
      NotImplementedError: always; subclasses must implement this.
    """
    raise NotImplementedError

  def _compute_updates(self, params, grads, states, global_state,
                       mode_mt=None, update_steps_mt=None):
    """Maps the compute update functions for each parameter.

    This function can be overriden by a subclass if the subclass wants to
    combine information across the different parameters in the list.

    Args:
      params: A list of parameter tensors.
      grads: A list of gradients corresponding to each parameter.
      states: A list of state dictionaries corresponding to each parameter.
      global_state: A list of global state variables for the problem.
      mode_mt: Optional. train() passes its boolean `mode_mt` tensor here. It
          is ignored by this base implementation.
      update_steps_mt: Optional. train() passes the per-parameter update
          steps taken from `mt_labels` for the current iteration. It is
          ignored by this base implementation except that passing it (or
          mode_mt) selects the five-value return below.

    Returns:
      new_params: The updated parameters.
      new_states: The updated states.
      new_global_state: The updated global state.
      attention_params: A list of attention parameters. This is the same as
          new_params if the optimizer does not use attention.
      update_steps: Only returned when mode_mt or update_steps_mt is given
          (as train() does). Computed here as new_param - param for each
          parameter. NOTE(review): confirm this matches the definition used for
          the mt_labels fed to train(); subclasses may define it differently.
    """
    if params:
      # Call _compute_update on each (param, grad, state) triple and split the
      # resulting (new_param, new_state) pairs into two sequences.
      new_params, new_states = zip(*itertools.starmap(
          self._compute_update, zip(params, grads, states)))
      new_params = list(new_params)
      new_states = list(new_states)
    else:
      # zip(*[]) cannot be unpacked into two names, so handle empty input here.
      new_params, new_states = [], []

    # Global state is unused in the basic case, just pass it through.
    if mode_mt is None and update_steps_mt is None:
      # Original four-value contract for callers that do not pass the
      # mode_mt / update_steps_mt arguments.
      return new_params, new_states, global_state, list(new_params)

    # train() passes mode_mt and update_steps_mt and unpacks a fifth value.
    update_steps = [new_param - param
                    for new_param, param in zip(new_params, params)]
    return (new_params, new_states, global_state, list(new_params),
            update_steps)

  def train(self, problem):
    """Creates graph operations to train the optimizer.

    Args:
      problem: A problem_generator.Problem instance to train on.

    Returns:
      A TrainingOutput namedtuple with fields:
        metaobj: The scaled meta-objective, plus the accumulated optimizer
            regularization term.
        metaobjmt: The obj_weights-weighted sum over iterations of the MSE
            between the update steps returned by _compute_updates and the
            corresponding entries of mt_labels.
        obj_weights: A placeholder for feeding in the per-iteration objective
            weights; its length sets the maximum number of loop iterations.
        problem_objectives: The full-batch objective values during
            optimization, starting with the initial objective.
        initial_obj: The initial objective used for scaling.
        batches: The batch indexes placeholder for overriding with feed_dict.
        first_unroll: A placeholder signifying if this is a first unroll
            (this will propagate the gradients slightly differently).
        reset_state: A placeholder signifying that the rnn state should be
            reset.
        output_state: The final (flattened) states of the optimizer.
        init_loop_vars: Local variables that can be assigned to propagate the
            optimizer and problem state for unrolling.
        output_loop_vars: Final values of the loop variables that can be
            assigned to init_loop_vars.
        jacob_switch: A boolean placeholder (no default). When
            FLAGS.reg_optimizer is set it selects whether the optimizer
            regularization term is accumulated.
        mini_data: The float32 data placeholder.
        mini_labels: The int32 labels placeholder.
        mt_labels: One float32 placeholder per parameter with shape
            (None,) + param_shape; row `itr` is compared against that
            iteration's update step.
        init_tensors: One float32 placeholder per parameter holding the
            initial parameter values.
        mode_mt: A boolean placeholder (default False) forwarded to
            _compute_updates.
    """
    # Placeholder for the objective weights
    obj_weights = tf.placeholder(tf.float32)
    num_iter = tf.shape(obj_weights)[0]

    # Data, labels and batch indexes are fed via feed_dict.
    mini_data = tf.placeholder(tf.float32)
    mini_labels = tf.placeholder(tf.int32)
    batches = tf.placeholder(tf.int32)
    first_unroll = tf.placeholder_with_default(False, [])
    reset_state = tf.placeholder_with_default(False, [])
    jacob_switch = tf.placeholder(tf.bool)
    # MODIFY
    mode_mt = tf.placeholder_with_default(False, shape=())
    init_tensors = [tf.placeholder(shape=shape, dtype=tf.float32)
                    for shape in problem.param_shapes]
    # (num_params, None(unroll_len), param_shape). tuple() lets param_shapes
    # entries be lists as well as tuples.
    mt_labels = [tf.placeholder(shape=(None,) + tuple(shape), dtype=tf.float32)
                 for shape in problem.param_shapes]

    # MODIFY
    training_output = collections.namedtuple("TrainingOutput",
                                             ["metaobj",
                                              "metaobjmt",
                                              "obj_weights",
                                              "problem_objectives",
                                              "initial_obj",
                                              "batches",
                                              "first_unroll",
                                              "reset_state",
                                              "output_state",
                                              "init_loop_vars",
                                              "output_loop_vars",
                                              "jacob_switch",
                                              "mini_data",
                                              "mini_labels",
                                              "mt_labels",
                                              "init_tensors",
                                              "mode_mt",
                                              ])

    def loop_body(itr, obj_accum, obj_accum_mt, params, attend_params,
                  flattened_states, global_state, all_obj, unused_init_obj,
                  mini_data, mini_labels, batches, regular, mt_labels,
                  mode_mt):
      """Body of the meta-training while loop for optimizing a sub-problem.

      Args:
        itr: The current meta-training iteration.
        obj_accum: The accumulated weighted objective over all steps so far.
        obj_accum_mt: The accumulated weighted update-step MSE so far.
        params: The parameters of the sub-problem.
        attend_params: The parameters of the sub-problems at the attended
            location.
        flattened_states: The states of the trainable optimizer, sorted and
            flattened into a list (since a while loop can't handle nested lists
            or dictionaries).
        global_state: The global state of the optimizer.
        all_obj: The list of all objective values in the training process.
        unused_init_obj: The initial objective (unused here, but needed in the
            variable list because it's used in a stopping condition in the
            loop_cond.)
        mini_data: The data for this problem.
        mini_labels: The labels corresponding to the data.
        batches: The batch indexes needed for shuffled minibatch creation.
        regular: The accumulated optimizer regularization term.
        mt_labels: Per-parameter update-step tensors, indexed by iteration.
        mode_mt: Boolean tensor forwarded to _compute_updates.

      Returns:
        The loop variables in the same order, with itr incremented and
        obj_accum, obj_accum_mt, params, attend_params, flattened_states,
        global_state, all_obj and regular updated; the rest pass through.
      """
      batch_indices = tf.gather(batches, itr)
      batch_data = tf.gather(mini_data, batch_indices)
      batch_labels = tf.gather(mini_labels, batch_indices)
      # MODIFY
      # This iteration's row of each mt_label: (num_params, param_shape).
      update_steps_mt = [tf.gather(mt_label, itr) for mt_label in mt_labels]

      # Compute the objective over the entire dataset (full batch).
      obj = problem.objective(params, mini_data, mini_labels)

      # Compute the gradients on just the current batch
      if self.use_attention:
        current_obj = problem.objective(attend_params, batch_data, batch_labels)
        grads, reg = problem.gradients(current_obj, attend_params)
      else:
        current_obj = problem.objective(params, batch_data, batch_labels)
        grads, reg = problem.gradients(current_obj, params)

      alpha = FLAGS.alpha
      beta = FLAGS.beta

      # Regularize the optimizee: add beta * reg to the batch objective and
      # recompute the gradients from the regularized objective.
      if FLAGS.reg_optimizee:
        current_obj = current_obj + beta * reg
        if self.use_attention:
          grads = tf.gradients(current_obj, attend_params)
        else:
          grads = tf.gradients(current_obj, params)

      if not self.use_second_derivatives:
        new_grads = []
        for grad in grads:
          if isinstance(grad, tf.IndexedSlices):
            new_grads.append(
                tf.IndexedSlices(tf.stop_gradient(grad.values), grad.indices))
          else:
            new_grads.append(tf.stop_gradient(grad))
        grads = new_grads

      # Regularize the optimizer: accumulate alpha * reg into `regular`, which
      # is added to the scaled meta-objective after the loop. When
      # jacob_switch is False the cond selects `regular` itself; because
      # `regular` starts at 0 (init_regular) the term then stays 0.
      if FLAGS.reg_optimizer:
        reg = tf.cond(jacob_switch, lambda: reg, lambda: regular)
        regular += alpha * reg

      # store the objective value for the entire problem at each iteration
      all_obj = tf.concat([all_obj, tf.reshape(obj, (1,))], 0)

      # accumulate the weighted objective for the entire dataset
      acc = tf.gather(obj_weights, itr) * obj

      obj_accum = tf.add(obj_accum, acc)
      # Set the shape to keep the shape invariant for obj_accum. Without this,
      # the graph builder thinks the tensor shape is unknown on the 2nd iter.
      obj_accum.set_shape([])

      # convert flattened_states to dictionaries
      dict_states = [dict(zip(self.state_keys, flat_state))
                     for flat_state in flattened_states]

      # compute the new parameters and states
      # MODIFY
      args = (params, grads, dict_states, global_state, mode_mt,
              update_steps_mt)
      updates = self._compute_updates(*args)
      (new_params, new_states, new_global_state, new_attend_params,
       update_steps) = updates
      # Weighted MSE between the returned update steps and this iteration's
      # mt_labels, averaged over problem.num_params.
      mse_loss = sum([tf.reduce_sum(0.5 * (d1 - d2) * (d1 - d2))
                      for d1, d2 in zip(update_steps, update_steps_mt)])
      mse_loss = mse_loss / problem.num_params
      mse_loss = tf.gather(obj_weights, itr) * mse_loss
      obj_accum_mt = tf.add(obj_accum_mt, mse_loss)
      obj_accum_mt.set_shape([])

      # flatten the states
      new_flattened_states = [flatten_and_sort(item) for item in new_states]
      return [itr + 1, obj_accum, obj_accum_mt, new_params, new_attend_params,
              new_flattened_states, new_global_state, all_obj, unused_init_obj,
              mini_data, mini_labels, batches, regular, mt_labels, mode_mt]

    def loop_cond(itr, obj_accum, unused_obj_accum_mt, unused_params,
                  unused_attend_params, unused_flattened_states,
                  unused_global_state, all_obj, init_obj, *args):
      """Termination conditions of the sub-problem optimization loop."""
      del args  # unused

      cond1 = tf.less(itr, num_iter)  # We've run < num_iter times
      cond2 = tf.is_finite(obj_accum)  # The objective is still finite

      if self.obj_train_max_multiplier > 0:
        current_obj = tf.gather(all_obj, itr)
        # Account for negative init_obj too
        max_diff = (self.obj_train_max_multiplier - 1) * tf.abs(init_obj)
        max_obj = init_obj + max_diff
        # The objective is a reasonable multiplier of the original objective
        cond3 = tf.less(current_obj, max_obj)

        return tf.logical_and(tf.logical_and(cond1, cond2), cond3,
                              name="training_loop_cond")
      else:
        return tf.logical_and(cond1, cond2, name="training_loop_cond")

    # MODIFY
    init = self._initialize_training_loop_parameters(
        problem, mini_data, mini_labels, batches, first_unroll, reset_state,
        init_tensors)
    loop_vars, invariants, initial_obj, init_loop_vars_to_override = init
    # mt_labels and mode_mt are appended as the last two loop variables.
    loop_vars.extend([mt_labels, mode_mt])
    invariants.extend([[mt_label.get_shape() for mt_label in mt_labels],
                       mode_mt.get_shape()])

    loop_output = tf.while_loop(loop_cond, loop_body, loop_vars,
                                swap_memory=False, shape_invariants=invariants)
    # MODIFY
    # Indices follow the loop_vars order: 1 obj_accum, 2 obj_accum_mt,
    # 7 all_obj, -3 regular (followed by mt_labels and mode_mt).
    meta_obj, meta_obj_mt, problem_objectives, jacob = (
        loop_output[1], loop_output[2], loop_output[7], loop_output[-3])

    # The meta objective is normalized by the initial objective at the start of
    # the series of partial unrolls.
    scaled_meta_objective = self.scale_objective(
        meta_obj, problem_objectives, initial_obj)
    scaled_meta_objective += jacob
    # MODIFY
    # Same order as init_loop_vars_to_override: initial objective, params,
    # attend params, global state, then the flattened per-parameter states.
    final_loop_vals = (
        [initial_obj] + loop_output[3] + loop_output[4] + loop_output[6])
    final_loop_vals.extend(itertools.chain(*loop_output[5]))

    # MODIFY
    return training_output(scaled_meta_objective,
                           meta_obj_mt,
                           obj_weights,
                           problem_objectives,
                           initial_obj,
                           batches,
                           first_unroll,
                           reset_state,
                           loop_output[5],
                           init_loop_vars_to_override,
                           final_loop_vals,
                           jacob_switch,
                           mini_data,
                           mini_labels,
                           mt_labels,
                           init_tensors,
                           mode_mt
                           )

  def _initialize_training_loop_parameters(
      self, problem, data, labels, batches, first_unroll, reset_state,
      init_tensors):
    """Initializes the vars and params needed for the training process.

    Args:
      problem: The problem being optimized.
      data: The data for the problem.
      labels: The corresponding labels for the data.
      batches: The indexes needed to create shuffled batches of the data.
      first_unroll: Whether this is the first unroll in a partial unrolling.
      reset_state: Whether RNN state variables should be reset.
      init_tensors: The tensors holding the initial parameter values (used in
          place of problem.init_tensors()).

    Returns:
      loop_vars: The while loop variables for training.
      invariants: The corresponding variable shapes (required by while loop).
      initial_obj: The initial objective (used later for scaling).
      init_loop_vars_to_override: The loop vars that can be overridden when
          performing training via partial unrolls.
    """
    # Extract these separately so we don't have to make inter-variable
    # dependencies.
    # MODIFY
    initial_tensors = init_tensors

    return_initial_tensor_values = first_unroll
    initial_params_vars, initial_params = local_state_variables(
        initial_tensors, return_initial_tensor_values)
    initial_attend_params_vars, initial_attend_params = local_state_variables(
        initial_tensors, return_initial_tensor_values)
    # Recalculate the initial objective for the list on each partial unroll with
    # the new initial_params. initial_obj holds the value from the very first
    # unroll.
    initial_obj_init = problem.objective(initial_params, data, labels)
    return_initial_obj_init = first_unroll
    [initial_obj_var], [initial_obj] = local_state_variables(
        [initial_obj_init], return_initial_obj_init)

    # Initialize the loop variables.
    initial_itr = tf.constant(0, dtype=tf.int32)
    initial_meta_obj = tf.constant(0, dtype=tf.float32)
    # MODIFY
    initial_meta_obj_mt = tf.constant(0, dtype=tf.float32)

    # N.B. the use of initial_obj_init here rather than initial_obj
    initial_problem_objectives = tf.reshape(initial_obj_init, (1,))

    # Initialize the extra state.
    initial_state_vars = []
    initial_state = []
    state_shapes = []
    return_initial_state_values = reset_state
    for param in initial_tensors:
      param_state_vars, param_state = local_state_variables(
          flatten_and_sort(self._initialize_state(param)),
          return_initial_state_values)

      initial_state_vars.append(param_state_vars)
      initial_state.append(param_state)
      state_shapes.append([f.get_shape() for f in param_state])

    # Initialize any global (problem-level) state.
    initial_global_state_vars, initial_global_state = local_state_variables(
        self._initialize_global_state(), return_initial_state_values)

    # The optimizer regularization accumulator restarts at 0 on every unroll
    # (it is not among the overridable local variables below).
    init_regular = tf.constant(0.)

    global_shapes = []
    for item in initial_global_state:
      global_shapes.append(item.get_shape())

    # build the list of loop variables:
    # MODIFY
    loop_vars = [
        initial_itr,
        initial_meta_obj,
        initial_meta_obj_mt,
        initial_params,         # Local variables.
        initial_attend_params,  # Local variables.
        initial_state,          # Local variables.
        initial_global_state,   # Local variables.
        initial_problem_objectives,
        initial_obj,            # Local variable.
        data,
        labels,
        batches,
        init_regular
    ]

    invariants = [
        initial_itr.get_shape(),
        initial_meta_obj.get_shape(),
        initial_meta_obj_mt.get_shape(),
        [t.get_shape() for t in initial_params],
        [t.get_shape() for t in initial_attend_params],
        state_shapes,
        global_shapes,
        tensor_shape.TensorShape([None]),   # The problem objectives list grows
        initial_obj.get_shape(),
        tensor_shape.unknown_shape(),  # Placeholder shapes are unknown
        tensor_shape.unknown_shape(),
        tensor_shape.unknown_shape(),
        init_regular.get_shape()
    ]

    # Initialize local variables that we will override with final tensors at the
    # next iter.
    init_loop_vars_to_override = (
        [initial_obj_var] + initial_params_vars + initial_attend_params_vars +
        initial_global_state_vars)
    init_loop_vars_to_override.extend(itertools.chain(*initial_state_vars))

    return loop_vars, invariants, initial_obj, init_loop_vars_to_override

  def scale_objective(self, total_obj, all_objs, initial_obj,
                      obj_scale_eps=1e-6):
    """Normalizes the objective based on the initial objective value.

    Args:
      total_obj: The total accumulated objective over the training run.
      all_objs: A list of all the individual objectives over the training run.
      initial_obj: The initial objective value.
      obj_scale_eps: The epsilon value to use in computations for stability.

    Returns:
      The scaled objective as a single value: the mean log of all_objs relative
      to initial_obj when use_log_objective is set, otherwise total_obj divided
      by initial_obj.
    """
    if self.use_log_objective:
      if self.use_numerator_epsilon:
        scaled_problem_obj = ((all_objs + obj_scale_eps) /
                              (initial_obj + obj_scale_eps))
        log_scaled_problem_obj = tf.log(scaled_problem_obj)
      else:
        scaled_problem_obj = all_objs / (initial_obj + obj_scale_eps)
        log_scaled_problem_obj = tf.log(scaled_problem_obj + obj_scale_eps)
      return tf.reduce_mean(log_scaled_problem_obj)
    else:
      return total_obj / (initial_obj + obj_scale_eps)

def local_state_variables(init_values, return_init_values):
  """Creates zero-initialized local variables mirroring `init_values`.

  One local variable is created inside the OPTIMIZER_SCOPE variable scope for
  each entry of `init_values`. Its name comes from
  create_local_state_variable_name() plus a per-name use counter kept in the
  _LOCAL_STATE_VARIABLE_COLLECTION graph collection, so repeated calls with
  same-shaped values get distinct names.

  A tf.cond then chooses what to hand back: the original init values when
  `return_init_values` is true (so gradients can flow through them), otherwise
  the freshly created local variables.

  Args:
    init_values: sequence of tensors (shapes must be fully defined).
    return_init_values: boolean tensor selecting which values are returned.

  Returns:
    local_vars: list of the created local variables.
    vals: list of tensors; init_values if return_init_values is true, else the
      local_vars. Both lists are empty when init_values is empty.
  """
  if not init_values:
    return [], []

  # The collection stores a single defaultdict(int) name counter, created on
  # first use. This generates a harmless warning when saving the metagraph.
  counter_holder = tf.get_collection_ref(_LOCAL_STATE_VARIABLE_COLLECTION)
  if not counter_holder:
    counter_holder.append(collections.defaultdict(int))
  name_counts = counter_holder[0]

  created = []
  with tf.variable_scope(OPTIMIZER_SCOPE):
    # Zeros, not the init value, serve as the initializer: an init value may
    # depend on problem variables, which would introduce inter-variable
    # initialization-order dependencies. Uniquifying by shape/dtype name lets
    # variables be reused across sessions on the same master (a workaround for
    # the broken Session.reset()).
    for value in init_values:
      base_name = create_local_state_variable_name(value)
      suffix = name_counts[base_name]
      name_counts[base_name] += 1
      zeros = tf.zeros(value.get_shape(), dtype=value.dtype)
      created.append(
          tf.get_local_variable(base_name + "_" + str(suffix),
                                initializer=zeros))

  # Using the init values on the first iteration lets gradients propagate
  # through them and simplifies initialization; the variables are assigned to
  # afterwards.
  vals = tf.cond(return_init_values, lambda: init_values, lambda: created)
  # tf.cond unwraps singleton lists, so re-wrap to always return a list.
  if len(init_values) == 1:
    vals = [vals]
  return created, vals


def create_local_state_variable_name(tensor):
  """Builds a local-state variable name from a tensor's static shape and dtype.

  The result is _LOCAL_VARIABLE_PREFIX, then the dimensions joined by "_",
  then "_" and the dtype name.

  Raises:
    ValueError: if the tensor's static shape is not fully defined.
  """
  static_shape = tensor.get_shape()
  if not static_shape.is_fully_defined():
    raise ValueError("Need a fully specified shape to create a local variable.")

  dims = "_".join(str(dim) for dim in static_shape.as_list())
  return "{}{}_{}".format(_LOCAL_VARIABLE_PREFIX, dims, tensor.dtype.name)


def is_local_state_variable(op):
  """Returns True if `op` is a local state variable created for training.

  Only "Variable"/"VariableV2" ops whose name starts with
  "<OPTIMIZER_SCOPE>/<_LOCAL_VARIABLE_PREFIX>" qualify.
  """
  if op.node_def.op not in ("Variable", "VariableV2"):
    return False
  return op.name.startswith(OPTIMIZER_SCOPE + "/" + _LOCAL_VARIABLE_PREFIX)


def flatten_and_sort(dictionary):
  """Returns the values of `dictionary` as a list ordered by ascending key.

  Args:
    dictionary: A mapping whose keys are mutually orderable.

  Returns:
    A list with one value per key, in sorted-key order (empty for an empty
    mapping).
  """
  ordered_keys = sorted(dictionary.keys())
  return [dictionary[key] for key in ordered_keys]
