
"""Scripts for meta-optimization."""

from __future__ import print_function

import os

import tensorflow as tf

import metaopt
from optimizer import coordinatewise_rnn
from optimizer import global_learning_rate
from optimizer import hierarchical_rnn
from optimizer import learning_rate_schedule
from optimizer import trainable_adam
from problems import problem_sets as ps
from problems import problem_spec
from problems import datasets

# Root output directory. main() joins it with a per-configuration subdirectory
# name and uses the result as the log directory for training.
tf.app.flags.DEFINE_string("train_dir", "opt/",
                           """Directory to store parameters and results.""")

# Distributed-run settings. main() treats task 0 as the chief, and selects
# problems randomly when worker_tasks == 1 or this replica is not the chief.
tf.app.flags.DEFINE_integer("task", 0,
                            """Task id of the replica running the training.""")
tf.app.flags.DEFINE_integer("worker_tasks", 1,
                            """Number of tasks in the worker job.""")
# tf.app.flags.DEFINE_integer("num_testing_itrs", 10000,
#                             """Number testing iterations.""")
# The next two values are passed straight through to metaopt.train_optimizer.
tf.app.flags.DEFINE_integer("num_problems", 1,
                            """Number of sub-problems to run.""")
tf.app.flags.DEFINE_integer("num_meta_iterations", 5,
                            """Number of meta-iterations to optimize.""")
# Unroll sampling: each (scale, minimum) pair below is handed to
# metaopt.sample_numiter by the sampling callbacks defined in main().
tf.app.flags.DEFINE_integer("num_unroll_scale", 40,
                            """The scale parameter of the exponential
                            distribution from which the number of partial
                            unrolls is drawn""")
tf.app.flags.DEFINE_integer("min_num_unrolls", 10,
                            """The minimum number of unrolls per problem.""")
tf.app.flags.DEFINE_integer("num_partial_unroll_itr_scale", 20,
                            """The scale parameter of the exponential
                               distribution from which the number of iterations
                               per unroll is drawn.""")
tf.app.flags.DEFINE_integer("min_num_itr_partial_unroll", 10,
                            """The minimum number of iterations for one
                               unroll.""")

# Must be one of the keys of the dict returned by register_optimizers():
# CoordinatewiseRNN, GlobalLearningRate, HierarchicalRNN, LearningRateSchedule
# or TrainableAdam.
tf.app.flags.DEFINE_string("optimizer", "HierarchicalRNN",
                           """Which meta-optimizer to train.""")

# CoordinatewiseRNN-specific flags
# NOTE: within this script, cell_cls is looked up on tf.contrib.rnn and
# forwarded to whichever optimizer is selected, while cell_size and num_cells
# are only used to build the log directory name.
tf.app.flags.DEFINE_integer("cell_size", 20,
                            """Size of the RNN hidden state in each layer.""")
tf.app.flags.DEFINE_integer("num_cells", 2,
                            """Number of RNN layers.""")
tf.app.flags.DEFINE_string("cell_cls", "GRUCell",
                           """Type of RNN cell to use.""")

# Metaoptimization parameters
# main() passes these to metaopt.train_optimizer as the `learning_rate` and
# `gradient_clip` keyword arguments respectively.
tf.app.flags.DEFINE_float("meta_learning_rate", 1e-6,
                          """The learning rate for the meta-optimizer.""")
tf.app.flags.DEFINE_float("gradient_clip_level", 1e4,
                          """The level to clip gradients to.""")

# Train or test
# tf.app.flags.DEFINE_boolean("training", False,
#                             """training or testing.""")

# Training set selection. Each flag adds one problem set from
# problems.problem_sets to the training pool in main(). All of them default to
# False, so unless at least one is enabled main() hands an empty problem list
# to metaopt.train_optimizer.
tf.app.flags.DEFINE_boolean("include_mnist_conv_problems", False,
                            """Include MNIST convnet problems.""")
tf.app.flags.DEFINE_boolean("include_cifar10_conv_problems", False,
                            """Include CIFAR-10 convnet problems.""")
tf.app.flags.DEFINE_boolean("include_mnist_mlp_problems", False,
                            """Include MNIST MLP problems.""")

# Optimizer parameters: initialization and scale values
# main() combines these two into the `init_lr_range` optimizer keyword argument
# as the tuple (min_lr, max_lr).
tf.app.flags.DEFINE_float("min_lr", 1e-6,
                          """The minimum initial learning rate.""")
tf.app.flags.DEFINE_float("max_lr", 1e-2,
                          """The maximum initial learning rate.""")

# Optimizer parameters: small features.
# Each flag below is forwarded by main() as the optimizer keyword argument of
# the same name.
tf.app.flags.DEFINE_boolean("zero_init_lr_weights", True,
                            """Whether to initialize the learning rate weights
                               to 0 rather than the scaled random initialization
                               used for other RNN variables.""")
tf.app.flags.DEFINE_boolean("use_relative_lr", True,
                            """Whether to use the relative learning rate as an
                               input during training. Can only be used if
                               learnable_decay is also True.""")
tf.app.flags.DEFINE_boolean("use_extreme_indicator", False,
                            """Whether to use the extreme indicator for learning
                               rates as an input during training. Can only be
                               used if learnable_decay is also True.""")
tf.app.flags.DEFINE_boolean("use_log_means_squared", True,
                            """Whether to track the log of the mean squared
                               grads instead of the means squared grads.""")
tf.app.flags.DEFINE_boolean("use_problem_lr_mean", True,
                            """Whether to use the mean over all learning rates
                               in the problem when calculating the relative
                               learning rate.""")

# Optimizer parameters: major features
# Unless noted otherwise, each flag below is forwarded by main() as the
# optimizer keyword argument of the same name.
tf.app.flags.DEFINE_boolean("learnable_decay", True,
                            """Whether to learn weights that dynamically
                              modulate the input scale via RMS decay.""")
tf.app.flags.DEFINE_boolean("dynamic_output_scale", True,
                            """Whether to learn weights that dynamically
                               modulate the output scale.""")
tf.app.flags.DEFINE_boolean("use_log_objective", True,
                            """Whether to use the log of the scaled objective
                               rather than just the scaled obj for training.""")
tf.app.flags.DEFINE_boolean("use_attention", False,
                            """Whether to learn where to attend.""")
tf.app.flags.DEFINE_boolean("use_second_derivatives", True,
                            """Whether to use second derivatives.""")
tf.app.flags.DEFINE_integer("num_gradient_scales", 4,
                            """How many different timescales to keep for
                               gradient history. If > 1, also learns a scale
                               factor for gradient history.""")
tf.app.flags.DEFINE_float("max_log_lr", 33,
                          """The maximum log learning rate allowed.""")
# Forwarded twice by main(): to the optimizer as `obj_train_max_multiplier` and
# to metaopt.train_optimizer under the same keyword.
tf.app.flags.DEFINE_float("objective_training_max_multiplier", -1,
                          """How much the objective can grow before training on
                             this problem / param pair is terminated. Sets a max
                             on the objective value when multiplied by the
                             initial objective. If <= 0, not used.""")
tf.app.flags.DEFINE_boolean("use_gradient_shortcut", True,
                            """Whether to add a learned affine projection of the
                               gradient to the update delta in addition to the
                               gradient function computed by the RNN.""")
tf.app.flags.DEFINE_boolean("use_lr_shortcut", False,
                            """Whether to add the difference between the current
                               learning rate and the desired learning rate to
                               the RNN input.""")
tf.app.flags.DEFINE_boolean("use_grad_products", True,
                            """Whether to use gradient products in the input to
                               the RNN. Only applicable when num_gradient_scales
                               > 1.""")
tf.app.flags.DEFINE_boolean("use_multiple_scale_decays", False,
                            """Whether to use many-timescale scale decays.""")
tf.app.flags.DEFINE_boolean("use_numerator_epsilon", False,
                            """Whether to use epsilon in the numerator of the
                               log objective.""")
tf.app.flags.DEFINE_boolean("learnable_inp_decay", True,
                            """Whether to learn input decay weight and bias.""")
tf.app.flags.DEFINE_boolean("learnable_rnn_init", True,
                            """Whether to learn RNN state initialization.""")

# MODIFY
# NOTE(review): presumably local additions to the upstream script -- confirm.
# None of these flags carries help text. main() forwards each one unchanged to
# metaopt.train_optimizer as the keyword argument of the same name, so their
# exact semantics are defined there.
tf.app.flags.DEFINE_boolean("if_cl", False, "")
tf.app.flags.DEFINE_boolean("fix_unroll", False, "")
tf.app.flags.DEFINE_integer("fix_unroll_length", 20, "")
tf.app.flags.DEFINE_integer("fix_num_steps", 100, "")
tf.app.flags.DEFINE_integer("fix_num_steps_eval", 100, "")
tf.app.flags.DEFINE_integer("evaluation_period", 1, "")
tf.app.flags.DEFINE_integer("evaluation_epochs", 20, "")
tf.app.flags.DEFINE_integer("save_period", 1, "")

# Shared handle through which the flag values defined above are read.
FLAGS = tf.app.flags.FLAGS

# The size of the RNN hidden state in each layer:
# [PerParam, PerTensor, Global]. The length of this list must be 1, 2, or 3
# (main() asserts this). If less than 3, the Global and/or PerTensor RNNs will
# not be created.
# main() passes this list as the single positional argument of the selected
# optimizer.

HRNN_CELL_SIZES = [10, 20, 20]



def register_optimizers():
  """Returns a new dict mapping optimizer names to optimizer classes.

  The keys are the values accepted by the --optimizer flag.
  """
  return {
      "CoordinatewiseRNN": coordinatewise_rnn.CoordinatewiseRNN,
      "GlobalLearningRate": global_learning_rate.GlobalLearningRate,
      "HierarchicalRNN": hierarchical_rnn.HierarchicalRNN,
      "LearningRateSchedule": learning_rate_schedule.LearningRateSchedule,
      "TrainableAdam": trainable_adam.TrainableAdam,
  }


def main(_):
  """Trains the meta-optimizer selected by the command-line flags.

  Collects the enabled problem sets, packages the chosen optimizer class with
  its constructor arguments into a problem_spec.Spec, creates the log
  directory, and hands everything to metaopt.train_optimizer.

  Args:
    _: Unused positional argument supplied by the launcher.

  Returns:
    0 once metaopt.train_optimizer has returned.

  Raises:
    ValueError: If --optimizer is not a key of register_optimizers().
  """
  # Resolve the optimizer first so that a mistyped --optimizer fails with a
  # readable message before any problem set is constructed.
  opts = register_optimizers()
  if FLAGS.optimizer not in opts:
    raise ValueError(
        "Unknown optimizer '{}'; expected one of: {}".format(
            FLAGS.optimizer, ", ".join(sorted(opts))))
  optimizer_cls = opts[FLAGS.optimizer]

  # Choose the set of problems to optimize. Only explicitly enabled problem
  # sets are included, and every --include_* flag defaults to False.
  problems_and_data = []

  if FLAGS.include_mnist_conv_problems:
    problems_and_data.extend(ps.mnist_conv_problems())

  if FLAGS.include_cifar10_conv_problems:
    problems_and_data.extend(ps.cifar10_conv_problems())

  if FLAGS.include_mnist_mlp_problems:
    problems_and_data.extend(ps.mnist_mlp_problems())

  if not problems_and_data:
    # Surface the most likely misconfiguration instead of passing an empty
    # problem list on silently.
    tf.logging.warning(
        "No training problems selected; enable at least one of "
        "--include_mnist_conv_problems, --include_cifar10_conv_problems or "
        "--include_mnist_mlp_problems.")

  # log directory, named after the optimizer and RNN cell configuration
  logdir = os.path.join(FLAGS.train_dir,
                        "{}_{}_{}_{}".format(FLAGS.optimizer,
                                             FLAGS.cell_cls,
                                             FLAGS.cell_size,
                                             FLAGS.num_cells))

  # Constructor arguments for the optimizer. The same positional and keyword
  # arguments are used regardless of which optimizer class was selected.
  assert len(HRNN_CELL_SIZES) in [1, 2, 3]
  optimizer_args = (HRNN_CELL_SIZES,)

  optimizer_kwargs = {
      "init_lr_range": (FLAGS.min_lr, FLAGS.max_lr),
      "learnable_decay": FLAGS.learnable_decay,
      "dynamic_output_scale": FLAGS.dynamic_output_scale,
      "cell_cls": getattr(tf.contrib.rnn, FLAGS.cell_cls),
      "use_attention": FLAGS.use_attention,
      "use_log_objective": FLAGS.use_log_objective,
      "num_gradient_scales": FLAGS.num_gradient_scales,
      "zero_init_lr_weights": FLAGS.zero_init_lr_weights,
      "use_log_means_squared": FLAGS.use_log_means_squared,
      "use_relative_lr": FLAGS.use_relative_lr,
      "use_extreme_indicator": FLAGS.use_extreme_indicator,
      "max_log_lr": FLAGS.max_log_lr,
      "obj_train_max_multiplier": FLAGS.objective_training_max_multiplier,
      "use_problem_lr_mean": FLAGS.use_problem_lr_mean,
      "use_gradient_shortcut": FLAGS.use_gradient_shortcut,
      "use_second_derivatives": FLAGS.use_second_derivatives,
      "use_lr_shortcut": FLAGS.use_lr_shortcut,
      "use_grad_products": FLAGS.use_grad_products,
      "use_multiple_scale_decays": FLAGS.use_multiple_scale_decays,
      "use_numerator_epsilon": FLAGS.use_numerator_epsilon,
      "learnable_inp_decay": FLAGS.learnable_inp_decay,
      "learnable_rnn_init": FLAGS.learnable_rnn_init,
  }
  optimizer_spec = problem_spec.Spec(
      optimizer_cls, optimizer_args, optimizer_kwargs)

  # make log directory
  tf.gfile.MakeDirs(logdir)

  is_chief = FLAGS.task == 0
  # if this is a distributed run, make the chief run through problems in order
  select_random_problems = FLAGS.worker_tasks == 1 or not is_chief

  def num_unrolls():
    """Samples the number of unrolls for one problem."""
    return metaopt.sample_numiter(FLAGS.num_unroll_scale, FLAGS.min_num_unrolls)

  def num_partial_unroll_itrs():
    """Samples the number of iterations for one partial unroll."""
    return metaopt.sample_numiter(FLAGS.num_partial_unroll_itr_scale,
                                  FLAGS.min_num_itr_partial_unroll)

  # run it
  metaopt.train_optimizer(
      logdir,
      optimizer_spec,
      problems_and_data,
      FLAGS.num_problems,
      FLAGS.num_meta_iterations,
      num_unrolls,
      num_partial_unroll_itrs,
      learning_rate=FLAGS.meta_learning_rate,
      gradient_clip=FLAGS.gradient_clip_level,
      is_chief=is_chief,
      select_random_problems=select_random_problems,
      obj_train_max_multiplier=FLAGS.objective_training_max_multiplier,
      callbacks=[],
      fix_unroll=FLAGS.fix_unroll,
      fix_unroll_length=FLAGS.fix_unroll_length,
      fix_num_steps=FLAGS.fix_num_steps,
      fix_num_steps_eval=FLAGS.fix_num_steps_eval,
      evaluation_period=FLAGS.evaluation_period,
      evaluation_epochs=FLAGS.evaluation_epochs,
      save_period=FLAGS.save_period,
      if_cl=FLAGS.if_cl)

  return 0


if __name__ == "__main__":
  # Pass main explicitly rather than relying on the launcher's default lookup
  # of main() in the __main__ module.
  tf.app.run(main=main)
