
"""Helper utilities for training and testing optimizers."""
import math
from collections import defaultdict
import random
import sys
import time
import os
import psutil
import numpy as np
from six.moves import xrange
import tensorflow as tf
import pdb
from optimizer import trainable_optimizer
from optimizer import utils
from problems import datasets
from problems import problem_generator
from mnist import load_mnist
from cifar10 import load_cifar10
from problems.datasets import batch_indices, lstm_sin, lstm_UCI_HAR_Dataset, lstm_WISDM_HAR_Dataset, lstm_SML_Dataset, MLP_GAS_Dataset

# Command-line flags consumed by the training / testing helpers in this module.
# They are registered at import time on the process-wide tf.app.flags registry.

# Number of tasks in the parameter-server job (0 disables the ps job).
tf.app.flags.DEFINE_integer("ps_tasks", 0,
                            """Number of tasks in the ps job.
                            If 0 no ps job is used.""")
# Number of optimizee iterations.
tf.app.flags.DEFINE_integer("num_iterations", 100,
                            """Number of optimizee iterations""")
# Scale of the noise added on the data.
tf.app.flags.DEFINE_float("noise_scale", 1e-2,
                          """Scale of the noise added on the data.""")
# L2-regularization strength applied when NaNs are encountered.
tf.app.flags.DEFINE_float("nan_l2_reg", 1e-2,
                          """Strength of l2-reg when NaNs are encountered.""")
# Lambda for parameter regularization (0 by default).
tf.app.flags.DEFINE_float("l2_reg", 0.,
                          """Lambda value for parameter regularization.""")
# Default is 0.9 (presumably the RMSProp library default -- TODO confirm).
tf.app.flags.DEFINE_float("rms_decay", 0.9,
                          """Decay value for the RMSProp metaoptimizer.""")
# Default is 1e-10 (presumably the RMSProp library default -- TODO confirm);
# note that the value registered here is 1e-20, not 1e-10.
tf.app.flags.DEFINE_float("rms_epsilon", 1e-20,
                          """Epsilon value for the RMSProp metaoptimizer.""")
# Toggles memory / compute-time tracing of TensorFlow nodes.
tf.app.flags.DEFINE_boolean("set_profiling", False,
                            """Enable memory usage and computation time """
                            """tracing for tensorflow nodes (available in """
                            """TensorBoard).""")
# Whether the optimizer parameters are reset between meta-iterations.
tf.app.flags.DEFINE_boolean("reset_rnn_params", True,
                            """Reset the parameters of the optimizer
                               from one meta-iteration to the next.""")

# Handle through which the flag values are read.
FLAGS = tf.app.flags.FLAGS
# Variable scope used below to select the optimizer's own variables from
# GLOBAL_VARIABLES (they are restored from a checkpoint or initialized
# separately from the problem parameters).
OPTIMIZER_SCOPE = "LOL"
# Not referenced in this part of the file; presumably the name of the graph
# collection holding the optimizer's summaries -- TODO confirm.
OPT_SUM_COLLECTION = "LOL_summaries"


def sigmoid_weights(n, slope=0.1, offset=5):
  """Builds a sigmoid-shaped weight vector that sums to one.

  The weights ramp up along a sigmoid evaluated at the integer positions
  0 .. n-1, so early positions receive little weight and later positions
  receive most of it. This is meant to mask out the early objective values
  of an optimization run, where the objective still varies a lot.

  Args:
    n: the number of samples (length of the returned vector)
    slope: slope of the sigmoid (Default: 0.1)
    offset: position at which the sigmoid crosses 0.5 (Default: 5)

  Returns:
    A 1-D numpy array of length n containing the sigmoid values divided by
    their sum.
  """
  positions = np.arange(n)
  ramp = 1. / (1. + np.exp(-slope * (positions - offset)))
  total = np.sum(ramp)
  return ramp / total


def sample_numiter(scale, min_steps=50):
  """Draws a random iteration count from a shifted exponential distribution.

  Args:
    scale: scale parameter of the exponential distribution
    min_steps: number of steps added to the rounded sample

  Returns:
    num_steps: An integer equal to the rounded exponential sample plus
               min_steps, clipped to the interval [0, 3 * min_steps].
  """
  drawn = np.random.exponential(scale=scale)
  shifted = np.round(drawn) + min_steps
  upper_bound = 3 * min_steps
  return int(np.clip(shifted, 0, upper_bound))




def gen_idx(dataset, dataset_type, num_classes=10):
  """Builds a mask selecting the samples of one half of the classes.

  Args:
    dataset: object exposing a `labels` attribute that is compared against
      each class id with `==`
    dataset_type: "adapt" selects the odd class ids, "pretrain" the even ones;
      any other value prints a message and exits the interpreter
    num_classes: class ids 0 .. num_classes-1 are considered

  Returns:
    The accumulated result of `dataset.labels == c` over the selected class
    ids, or None if no class id is selected.
  """
  if dataset_type == "adapt":
    first_class = 1
  elif dataset_type == "pretrain":
    first_class = 0
  else:
    print("No such dataset type: ", dataset_type)
    exit()

  mask = None
  for class_id in range(first_class, num_classes, 2):
    hits = dataset.labels == class_id
    if mask is None:
      mask = hits
    else:
      mask += hits
  return mask

def resort_idx(dataset, dataset_type, num_classes=10):
  """Relabels the selected classes in place to consecutive ids 0, 1, 2, ...

  Args:
    dataset: object whose `labels` attribute is rewritten in place
    dataset_type: "adapt" relabels the odd class ids, "pretrain" the even
      ones; any other value prints a message and exits the interpreter
    num_classes: class ids 0 .. num_classes-1 are considered
  """
  if dataset_type == "adapt":
    first_class = 1
  elif dataset_type == "pretrain":
    first_class = 0
  else:
    print("No such dataset type: ", dataset_type)
    exit()

  # All masks are captured before any label is rewritten, so a freshly
  # assigned id can never be picked up by a later mask.
  class_masks = [dataset.labels == class_id
                 for class_id in range(first_class, num_classes, 2)]
  for new_label, mask in enumerate(class_masks):
    dataset.labels[mask] = new_label

class dummy_dataset:
  """Minimal container bundling labels and images with their sample count."""

  def __init__(self, labels, images):
    # Stored as given; num_examples is taken from labels.size.
    self.labels, self.images = labels, images
    self.num_examples = int(labels.size)

def convert_dataset(dataset, dataset_type, num_classes=10):
  """Extracts one half of the classes of a dataset and relabels them.

  Args:
    dataset: object exposing `labels` and `images` attributes that can be
      indexed with the mask produced by gen_idx
    dataset_type: "pretrain" keeps the even class ids, "adapt" the odd ones
    num_classes: class ids 0 .. num_classes-1 are considered

  Returns:
    A dummy_dataset holding the selected samples, with its labels rewritten
    to consecutive ids starting at 0.
  """
  assert dataset_type in ["pretrain", "adapt"], ("No supporting dataset: "+dataset_type)
  # num_classes is forwarded to both helpers; previously it was dropped, so
  # the helpers always fell back to their own default of 10.
  idx = gen_idx(dataset, dataset_type, num_classes)
  _dataset = dummy_dataset(dataset.labels[idx], dataset.images[idx])
  resort_idx(_dataset, dataset_type, num_classes)
  return _dataset

def test_optimizer(optimizer,
                   problem,
                   num_epochs,
                   dataset=datasets.EMPTY_DATASET,
                   batch_size=None,
                   seed=None,
                   graph=None,
                   logdir=None,
                   record_every=None,
                   pretrained_model_path=None,
                   include_is_training=False,
                   lstm_optimizee=False,
                   global_epoch=None,
                   restore_model_name=None):
  """Trains a problem with the given optimizer and evaluates it every epoch.

  Args:
    optimizer: Either a tf.train.Optimizer instance, or an Optimizer instance
               inheriting from trainable_optimizer.py
    problem: A Problem instance that defines an optimization problem to solve
    num_epochs: The number of passes over the training set
    dataset: Name of the dataset to train on. Only 'mnist' and 'cifar10' are
      loaded by this function; anything else raises ValueError.
    batch_size: The number of samples per batch. Must be given; None raises
      ValueError (the dataset is identified by name, so there is no
      dataset.size to fall back on).
    seed: A random seed used for drawing the initial parameters, or a list of
      numpy arrays used to explicitly initialize the parameters.
    graph: The tensorflow graph to execute (if None, uses the default graph)
    logdir: If given, the optimizer variables are restored from a checkpoint
            (see the restore code in each branch for how the path is built);
            otherwise they are freshly initialized.
    record_every: if an integer, collects the parameters, objective, and
                  gradient every record_every iterations of an epoch. The
                  collected records are not part of the return value.
    pretrained_model_path: forwarded to problem.init_variables in the LSTM
      branch only.
    include_is_training: if True, an is_training placeholder is passed to
      problem.objective and fed (True while training, False while testing).
      Non-LSTM branch only.
    lstm_optimizee: unused.
    global_epoch: optional placeholder that is fed the current epoch index
      during training steps.
    restore_model_name: checkpoint file name used by the non-LSTM branch when
      logdir is given. If None, FLAGS.restore_model_name is used.

  Returns:
    For the image datasets (non-LSTM branch) an 11-tuple:
      objective_values: objective of every training step
      acc_records: per-batch test accuracies of every epoch
      parameters: the parameter values after training
      mean duration in seconds of a training step
      test_acc: per-epoch mean test accuracy
      test_loss: per-epoch mean test objective
      training_acc: accuracy on the training batch after every step
      epoch_acc: mean test accuracy of the last evaluation
      epoch_test_obj: mean test objective of the last epoch
      final_train_loss, final_train_acc: means over one extra pass over the
        training set after training
    The LSTM branch returns a 7-tuple (training_loss_records,
    testing_loss_records, testing_acc_records, last_itr_loss,
    best_parameters, best_acc, mean step duration).

  Raises:
    ValueError: if the dataset is not supported or batch_size is None.
  """

  def grad_value(g):
    # Sparse gradients are reduced to their values tensor.
    if isinstance(g, tf.IndexedSlices):
      return g.values
    return g

  if dataset == 'mnist':
    train, test = load_mnist()
    train_imgs, train_labels = train
    train_imgs = train_imgs.reshape(-1, 28, 28, 1)
    test_imgs, test_labels = test
    test_imgs = test_imgs.reshape(-1, 28, 28, 1)

  elif dataset == 'cifar10':
    train, test = load_cifar10('cifar10')
    train_imgs, train_labels = train
    train_imgs = train_imgs.reshape(-1, 32, 32, 3)
    train_labels = train_labels.reshape(-1,)
    test_imgs, test_labels = test
    test_imgs = test_imgs.reshape(-1, 32, 32, 3)
    test_labels = test_labels.reshape(-1, )

  else:
    # Previously this case fell through and failed on the first use of the
    # never-assigned train_imgs.
    raise ValueError(
        "Unsupported dataset %r: only 'mnist' and 'cifar10' are loaded here."
        % (dataset,))

  print("Train imgs shape: ", train_imgs.shape)
  print("Test imgs shape: ", test_imgs.shape)

  if batch_size is None:
    # The dataset is a name here, so the old `dataset.size` fallback could
    # only fail with an AttributeError.
    raise ValueError("batch_size must be specified.")

  if "lstm" in dataset:
    # NOTE(review): this branch cannot be reached at the moment, because no
    # dataset whose name contains "lstm" is loaded above. It also reads
    # data_placeholder / labels_placeholder, which are only assigned in the
    # other branch. Kept so the LSTM path can be revived once the dataset
    # loading and placeholder definitions are restored.
    graph = tf.get_default_graph() if graph is None else graph
    with graph.as_default():

      # define the parameters of the optimization problem
      if isinstance(seed, (list, tuple)):
        # seed is a list of arrays
        params = problem_generator.init_fixed_variables(seed)
      else:
        # seed is an int or None
        print('init params')
        params = problem.init_variables(seed, pretrained_model_path=pretrained_model_path)

      # get the problem objective and gradient(s)
      predictions, obj = problem.objective(params, data_placeholder, labels_placeholder)

      gradients, grad_flag_list = problem.gradients(obj, params)

      # give the variables to the optimizer; plain tf optimizers do not
      # provide this hook
      try:
        optimizer.set_grad_flag_list(grad_flag_list)
      except AttributeError:
        print("Using Adam optimizer")

      vars_to_preinitialize = params

    avg_iter_time = []

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(graph=graph, config=config) as sess:
      # initialize the parameter scope variables; necessary for apply_gradients
      sess.run(tf.variables_initializer(vars_to_preinitialize))
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess=sess, coord=coord)

      # create the train operation and training variables
      try:
        train_op, real_params = optimizer.apply_gradients(zip(gradients, params))
        predictions, obj = problem.objective(real_params, data_placeholder, labels_placeholder)
      except TypeError:
        # If all goes well, this exception should only be thrown when we are using
        # a non-hrnn optimizer.
        train_op = optimizer.apply_gradients(zip(gradients, params))

      test_op = problem.accuracy(predictions, labels_placeholder)

      vars_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                          scope=OPTIMIZER_SCOPE)
      vars_to_initialize = list(
          set(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)) -
          set(vars_to_restore) - set(vars_to_preinitialize))
      # load or initialize optimizer variables
      if logdir is not None:
        restorer = tf.train.Saver(var_list=vars_to_restore)
        ckpt = tf.train.latest_checkpoint(logdir)
        restorer.restore(sess, ckpt)
      else:
        sess.run(tf.variables_initializer(vars_to_restore))
      # initialize all the other variables
      sess.run(tf.variables_initializer(vars_to_initialize))

      problem.init_fn(sess)

      # generate the minibatch indices
      test_num_iter = math.floor(len(test_imgs) / batch_size)
      test_batch_inds = batch_indices(test_imgs, test_labels, test_num_iter, batch_size)

      records = defaultdict(list)
      training_loss_records = []
      testing_loss_records = []
      testing_acc_records = []

      # Stay defined even when num_epochs is 0.
      last_itr_loss = None
      best_parameters = None
      best_acc = None

      # evaluate once before any training step
      testing_loss = []
      testing_acc = []
      for itr, batch in enumerate(test_batch_inds):
        test_feed = {data_placeholder: test_imgs[batch],
                     labels_placeholder: test_labels[batch]}
        testing_loss.append(sess.run([obj], feed_dict=test_feed)[0])
        testing_acc.append(sess.run([test_op], feed_dict=test_feed)[0])

      epoch_acc = sum(testing_acc) / len(testing_acc)

      testing_epoch_loss = sum(testing_loss) / len(testing_loss)
      print('Before training-loss:', testing_epoch_loss/batch_size)
      print('Before training-acc:', epoch_acc)

      train_num_iter = math.floor(len(train_imgs) / batch_size)
      for ei in range(num_epochs):
        train_batch_inds = batch_indices(train_imgs, train_labels, train_num_iter, batch_size)
        training_loss = []
        for itr, batch in enumerate(train_batch_inds):
          # data to feed in
          feed = {data_placeholder: train_imgs[batch],
                  labels_placeholder: train_labels[batch]}
          # feed the current epoch number to global epoch placeholder
          if global_epoch is not None:
            feed[global_epoch]=ei

          # record stuff
          if record_every is not None and (itr % record_every) == 0:
            full_feed = {data_placeholder: train_imgs,
                         labels_placeholder: train_labels}

            records_fetch = {}
            for p in params:
              for key in optimizer.get_slot_names():
                v = optimizer.get_slot(p, key)
                records_fetch[p.name + "_" + key] = v
            gav_fetch = [(grad_value(g), v) for g, v in zip(gradients, params)]

            _, gav_eval, records_eval = sess.run(
                (obj, gav_fetch, records_fetch), feed_dict=feed)
            full_obj_eval = sess.run([obj], feed_dict=full_feed)

            records["objective"].append(full_obj_eval)
            records["grad_norm"].append([np.linalg.norm(g.ravel())
                                         for g, _ in gav_eval])
            records["param_norm"].append([np.linalg.norm(v.ravel())
                                          for _, v in gav_eval])
            records["grad"].append([g for g, _ in gav_eval])
            records["param"].append([v for _, v in gav_eval])
            records["iter"].append(itr)

            # .items() works on both Python 2 and 3 (.iteritems() does not)
            for k, v in records_eval.items():
              records[k].append(v)

          # run the optimization train operation
          train_optimizee_start = time.time()
          fetch = sess.run([train_op, obj], feed_dict=feed)[1]
          train_optimizee_end = time.time()
          train_optimizee_duration = train_optimizee_end-train_optimizee_start
          training_loss.append(fetch)
          avg_iter_time.append(train_optimizee_duration)
        last_itr_loss = training_loss[-1] / batch_size
        print('---'*10)
        print('last training loss :', training_loss[-1]/batch_size)
        training_loss_records.extend(training_loss)

        testing_loss = []
        testing_acc = []

        for itr, batch in enumerate(test_batch_inds):
          test_feed = {data_placeholder: test_imgs[batch],
                       labels_placeholder: test_labels[batch]}
          testing_loss.append(sess.run([obj], feed_dict=test_feed)[0])
          testing_acc.append(sess.run([test_op], feed_dict=test_feed)[0])

        epoch_acc = sum(testing_acc) / len(testing_acc)

        testing_epoch_loss = sum(testing_loss) / len(testing_loss)

        print('Test{}-loss:'.format(ei), testing_epoch_loss/batch_size)
        print('Test{}-acc:'.format(ei), epoch_acc)

        # save best ckpt
        # NOTE(review): the epoch mean is compared against the best
        # *per-batch* accuracy seen so far -- confirm this is intended.
        if len(testing_acc_records) > 0:
          if epoch_acc > max(testing_acc_records):
            best_parameters = [sess.run(p) for p in params]
            best_acc = epoch_acc
        else:
          best_parameters = [sess.run(p) for p in params]
          best_acc = epoch_acc

        testing_loss_records.extend(testing_loss)
        testing_acc_records.extend(testing_acc)

      coord.request_stop()
      coord.join(threads)

    return training_loss_records, testing_loss_records, testing_acc_records, last_itr_loss, best_parameters, best_acc, np.mean(np.array(avg_iter_time))

  else:
    graph = tf.get_default_graph() if graph is None else graph
    with graph.as_default():

      # define the parameters of the optimization problem
      if isinstance(seed, (list, tuple)):
        # seed is a list of arrays
        params = problem_generator.init_fixed_variables(seed)
      else:
        # seed is an int or None; pretrained_model_path is not used here
        params = problem.init_variables(seed)

      data_placeholder = tf.placeholder(tf.float32)
      labels_placeholder = tf.placeholder(tf.int32)
      is_training = tf.placeholder(tf.bool)

      # get the problem objective and gradient(s)
      if include_is_training:
        obj = problem.objective(params, data_placeholder, labels_placeholder, is_training)
      else:
        obj = problem.objective(params, data_placeholder, labels_placeholder)
      test_op = problem.accuracy(params, data_placeholder, labels_placeholder)

      gradients, grad_flag_list = problem.gradients(obj, params)

      # give the variables to the optimizer; plain tf optimizers do not
      # provide this hook
      try:
        optimizer.set_grad_flag_list(grad_flag_list)
      except AttributeError:
        print("Using Adam optimizer")

      vars_to_preinitialize = params

    avg_iter_time = []

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(graph=graph, config=config) as sess:
      # initialize the parameter scope variables; necessary for apply_gradients
      sess.run(tf.variables_initializer(vars_to_preinitialize))
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess=sess, coord=coord)

      # create the train operation and training variables
      # NOTE(review): is_training is passed here even when include_is_training
      # is False; a TypeError raised by problem.objective would also land in
      # the fallback below -- confirm this is intended.
      try:
        train_op, real_params = optimizer.apply_gradients(zip(gradients, params))
        obj = problem.objective(real_params, data_placeholder, labels_placeholder, is_training)
      except TypeError:
        # If all goes well, this exception should only be thrown when we are using
        # a non-hrnn optimizer.
        train_op = optimizer.apply_gradients(zip(gradients, params))

      vars_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                          scope=OPTIMIZER_SCOPE)
      vars_to_initialize = list(
          set(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)) -
          set(vars_to_restore) - set(vars_to_preinitialize))
      # load or initialize optimizer variables
      if logdir is not None:
        restorer = tf.train.Saver(var_list=vars_to_restore)
        # An explicitly passed checkpoint name wins; the flag is the fallback.
        if restore_model_name is None:
          restore_model_name = FLAGS.restore_model_name
        print(restore_model_name)

        # Drops the trailing 28 characters of logdir and points at a fixed
        # model directory instead -- presumably 28 is the length of a
        # run-specific directory name; TODO confirm.
        temp = logdir[0:len(logdir)-28] + 'HierarchicalRNN_GRUCell_20_2/'
        ckpt = os.path.join(temp, restore_model_name)
        restorer.restore(sess, ckpt)
      else:
        sess.run(tf.variables_initializer(vars_to_restore))
      # initialize all the other variables
      sess.run(tf.variables_initializer(vars_to_initialize))

      problem.init_fn(sess)

      # generate the minibatch indices
      test_num_iter = math.floor(len(test_imgs) / batch_size)
      test_batch_inds = batch_indices(test_imgs, test_labels, test_num_iter, batch_size)
      train_num_iter = math.floor(len(train_imgs) / batch_size)

      records = defaultdict(list)
      objective_values = []
      # Training stops once this many steps have been taken in total.
      max_iterations = 10000
      iteration_count = 0
      training_acc = []
      test_acc = []
      test_loss = []
      acc_records = []

      # Stay defined even when num_epochs is 0.
      epoch_loss = None
      epoch_acc_train = None
      epoch_test_obj = None

      # evaluate once before any training step
      acc = []
      for itr, batch in enumerate(test_batch_inds):
        test_feed = {data_placeholder: test_imgs[batch],
                     labels_placeholder: test_labels[batch]}
        if include_is_training:
          test_feed[is_training] = False
        acc.append(sess.run([test_op], feed_dict=test_feed)[0])

      epoch_acc = sum(acc) / len(acc)
      print('Before training-acc:', epoch_acc)

      for ei in range(num_epochs):
        train_batch_inds = batch_indices(train_imgs, train_labels, train_num_iter, batch_size)
        for itr, batch in enumerate(train_batch_inds):
          # data to feed in
          feed = {data_placeholder: train_imgs[batch],
                  labels_placeholder: train_labels[batch]}
          if include_is_training:
            feed[is_training] = True
          # feed the current epoch number to global epoch placeholder
          if global_epoch is not None:
            feed[global_epoch]=ei

          # record stuff
          if record_every is not None and (itr % record_every) == 0:
            full_feed = {data_placeholder: train_imgs,
                         labels_placeholder: train_labels}
            if include_is_training:
              full_feed[is_training] = True

            records_fetch = {}
            for p in params:
              for key in optimizer.get_slot_names():
                v = optimizer.get_slot(p, key)
                records_fetch[p.name + "_" + key] = v
            gav_fetch = [(grad_value(g), v) for g, v in zip(gradients, params)]

            _, gav_eval, records_eval = sess.run(
                (obj, gav_fetch, records_fetch), feed_dict=feed)
            full_obj_eval = sess.run([obj], feed_dict=full_feed)

            records["objective"].append(full_obj_eval)
            records["grad_norm"].append([np.linalg.norm(g.ravel())
                                         for g, _ in gav_eval])
            records["param_norm"].append([np.linalg.norm(v.ravel())
                                          for _, v in gav_eval])
            records["grad"].append([g for g, _ in gav_eval])
            records["param"].append([v for _, v in gav_eval])
            records["iter"].append(itr)

            # .items() works on both Python 2 and 3 (.iteritems() does not)
            for k, v in records_eval.items():
              records[k].append(v)

          # run the optimization train operation
          train_optimizee_start = time.time()
          objective_values.append(sess.run([train_op, obj], feed_dict=feed)[1])
          train_optimizee_end = time.time()
          train_optimizee_duration = train_optimizee_end-train_optimizee_start
          avg_iter_time.append(train_optimizee_duration)

          # accuracy on the batch that was just trained on
          training_acc.append(sess.run([test_op], feed_dict=feed)[0])

          iteration_count = iteration_count + 1
          if iteration_count >= max_iterations:
            break

        # Both means run over every step taken so far, not just this epoch.
        epoch_loss = sum(objective_values)/len(objective_values)
        epoch_acc_train = sum(training_acc) / len(training_acc)
        print('epoch{}-train-loss:'.format(ei), epoch_loss)
        print('epoch{}-train-acc:'.format(ei), epoch_acc_train)

        acc = []
        test_obj = []
        for itr, batch in enumerate(test_batch_inds):
          test_feed = {data_placeholder: test_imgs[batch],
                       labels_placeholder: test_labels[batch]}
          if include_is_training:
            test_feed[is_training] = False
          acc.append(sess.run([test_op], feed_dict=test_feed)[0])
          test_obj.append(sess.run([obj], feed_dict=test_feed)[0])

        epoch_test_obj = sum(test_obj)/len(test_obj)
        epoch_acc = sum(acc) / len(acc)
        test_acc.append(epoch_acc)
        test_loss.append(epoch_test_obj)
        print('epoch{}-test-loss:'.format(ei), epoch_test_obj)
        print('epoch{}-test-acc:'.format(ei), epoch_acc)
        acc_records.extend(acc)

        if iteration_count >= max_iterations:
          break

      # one extra pass over the training set to measure the final loss/accuracy
      final_train_acc = []
      final_train_loss = []
      train_batch_inds = batch_indices(train_imgs, train_labels, train_num_iter, batch_size)
      for itr, batch in enumerate(train_batch_inds):
        feed = {data_placeholder: train_imgs[batch],
                labels_placeholder: train_labels[batch]}
        if include_is_training:
          feed[is_training] = True

        final_train_acc.append(sess.run([test_op], feed_dict=feed)[0])
        final_train_loss.append(sess.run([obj], feed_dict=feed)[0])

      final_train_acc = sum(final_train_acc) / len(final_train_acc)
      final_train_loss = sum(final_train_loss) / len(final_train_loss)

      # final parameters
      parameters = [sess.run(p) for p in params]
      coord.request_stop()
      coord.join(threads)
      print('final test acc = ',epoch_acc)
      print('final test loss = ',epoch_test_obj)
      print('final train loss = ',final_train_loss)
      print('final train acc = ',final_train_acc)
      print('final epoch loss train = ', epoch_loss)
      print('final epoch acc train = ', epoch_acc_train)
      print('final iteration count = ', iteration_count)
    return objective_values, acc_records, parameters, np.mean(np.array(avg_iter_time)), test_acc, test_loss, training_acc, epoch_acc, epoch_test_obj, final_train_loss, final_train_acc


def run_wall_clock_test(optimizer,
                        problem,
                        num_steps,
                        dataset=datasets.EMPTY_DATASET,
                        seed=None,
                        logdir=None,
                        batch_size=None,
                        pretrained_model_path=None):
  """Runs optimization with the given parameters and returns the median iter time.

  Args:
    optimizer: The tf.train.Optimizer instance
    problem: The problem to optimize (a problem_generator.Problem)
    num_steps: The number of steps to run optimization for
    dataset: The dataset to train the problem against
    seed: The seed used for drawing the initial parameters, or a list of
      numpy arrays used to explicitly initialize the parameters
    logdir: A directory containing model checkpoints. If given, then the
            parameters of the optimizer are loaded from the latest checkpoint
            in this folder.
    batch_size: The number of samples per batch. If None, dataset.size is used.
    pretrained_model_path: forwarded to problem.init_variables when seed is
      not a list or tuple.

  Returns:
    The median time in seconds for a single optimization iteration.
  """
  if dataset is None:
    dataset = datasets.EMPTY_DATASET
    batch_size = dataset.size
  else:
    # default batch size is the entire dataset
    batch_size = dataset.size if batch_size is None else batch_size

  # define the parameters of the optimization problem
  if isinstance(seed, (list, tuple)):
    # seed is a list of arrays
    params = problem_generator.init_fixed_variables(seed)
  else:
    # seed is an int or None
    params = problem.init_variables(seed, pretrained_model_path=pretrained_model_path)

  data_placeholder = tf.placeholder(tf.float32)
  labels_placeholder = tf.placeholder(tf.int32)

  obj = problem.objective(params, data_placeholder, labels_placeholder)
  # NOTE(review): test_optimizer unpacks problem.gradients(...) into
  # (gradients, grad_flag_list); here the result is used as the gradient list
  # directly -- confirm which return shape applies to the problems used here.
  gradients = problem.gradients(obj, params)
  vars_to_preinitialize = params

  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  with tf.Session(graph=tf.get_default_graph(), config=config) as sess:
    # initialize the parameter scope variables; necessary for apply_gradients
    sess.run(tf.variables_initializer(vars_to_preinitialize))
    train_op = optimizer.apply_gradients(zip(gradients, params))
    if isinstance(train_op, (tuple, list)):
      # LOL apply_gradients returns a tuple. Regular optimizers do not.
      train_op = train_op[0]
    vars_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        scope=OPTIMIZER_SCOPE)
    vars_to_initialize = list(
        set(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)) -
        set(vars_to_restore) - set(vars_to_preinitialize))
    # load or initialize optimizer variables
    if logdir is not None:
      # The Saver class lives under tf.train (as used in test_optimizer).
      restorer = tf.train.Saver(var_list=vars_to_restore)
      ckpt = tf.train.latest_checkpoint(logdir)
      restorer.restore(sess, ckpt)
    else:
      sess.run(tf.variables_initializer(vars_to_restore))
    # initialize all the other variables
    sess.run(tf.variables_initializer(vars_to_initialize))

    problem.init_fn(sess)

    # generate the minibatch indices
    batch_inds = dataset.batch_indices(num_steps, batch_size)

    iter_times = []
    for batch in batch_inds:
      # data to feed in
      feed = {data_placeholder: dataset.data[batch],
              labels_placeholder: dataset.labels[batch]}

      # time a single run of the train operation
      start = time.time()
      sess.run([train_op], feed_dict=feed)
      iter_times.append(time.time() - start)

  return np.median(np.array(iter_times))
