
"""Scripts for meta-optimization."""

from __future__ import print_function

import os
import pickle

import tensorflow as tf

import metaopt
from optimizer import coordinatewise_rnn
from optimizer import global_learning_rate
from optimizer import hierarchical_rnn
from optimizer import learning_rate_schedule
from optimizer import trainable_adam
from problems import problem_sets as ps
from problems import problem_spec
from problems import datasets

tf.app.flags.DEFINE_string("train_dir", "opt/",
                           """Directory to store parameters and results.""")

tf.app.flags.DEFINE_string("test_optimizer", "L2o",
                           """Optimizer to test: 'L2o' (the learned optimizer
                              selected by --optimizer) or 'Adam'.""")

# Meta-training / distributed flags. The evaluation loop in main() does not
# read them; they are kept so existing command lines that pass them still parse.
tf.app.flags.DEFINE_integer("task", 0,
                            """Task id of the replica running the training.""")
tf.app.flags.DEFINE_integer("worker_tasks", 1,
                            """Number of tasks in the worker job.""")
tf.app.flags.DEFINE_integer("num_testing_itrs", 10000,
                            """Number of testing iterations.""")
tf.app.flags.DEFINE_integer("num_problems", 1,
                            """Number of sub-problems to run.""")
tf.app.flags.DEFINE_integer("num_meta_iterations", 50,
                            """Number of meta-iterations to optimize.""")
tf.app.flags.DEFINE_integer("num_unroll_scale", 40,
                            """The scale parameter of the exponential
                            distribution from which the number of partial
                            unrolls is drawn""")
tf.app.flags.DEFINE_integer("min_num_unrolls", 4,
                            """The minimum number of unrolls per problem.""")
tf.app.flags.DEFINE_integer("num_partial_unroll_itr_scale", 50,
                            """The scale parameter of the exponential
                               distribution from which the number of iterations
                               per unroll is drawn.""")
tf.app.flags.DEFINE_integer("min_num_itr_partial_unroll", 10,
                            """The minimum number of iterations for one
                               unroll.""")

tf.app.flags.DEFINE_string("optimizer", "HierarchicalRNN",
                           """Which meta-optimizer to train.""")

# RNN cell flags. cell_cls is passed to the learned optimizer; all three are
# also used to name the log directory.
tf.app.flags.DEFINE_integer("cell_size", 10,
                            """Size of the RNN hidden state in each layer.""")
tf.app.flags.DEFINE_integer("num_cells", 2,
                            """Number of RNN layers.""")
tf.app.flags.DEFINE_string("cell_cls", "GRUCell",
                           """Type of RNN cell to use.""")

# Metaoptimization parameters
tf.app.flags.DEFINE_float("meta_learning_rate", 1e-6,
                          """The learning rate for the meta-optimizer.""")
tf.app.flags.DEFINE_float("gradient_clip_level", 1e4,
                          """The level to clip gradients to.""")

# Problem set selection. All default to False, so at least one must be enabled
# for main() to evaluate anything.
tf.app.flags.DEFINE_boolean("include_mnist_conv_problems", False,
                            """Include MNIST convnet problems.""")
tf.app.flags.DEFINE_boolean("include_cifar10_conv_problems", False,
                            """Include CIFAR-10 convnet problems.""")
tf.app.flags.DEFINE_boolean("include_mnist_mlp_problems", False,
                            """Include MNIST MLP problems.""")


# Optimizer parameters: initialization and scale values
tf.app.flags.DEFINE_float("min_lr", 1e-6,
                          """The minimum initial learning rate.""")
tf.app.flags.DEFINE_float("max_lr", 1e-2,
                          """The maximum initial learning rate.""")

# Optimizer parameters: small features.
tf.app.flags.DEFINE_boolean("zero_init_lr_weights", True,
                            """Whether to initialize the learning rate weights
                               to 0 rather than the scaled random initialization
                               used for other RNN variables.""")
tf.app.flags.DEFINE_boolean("use_relative_lr", True,
                            """Whether to use the relative learning rate as an
                               input during training. Can only be used if
                               learnable_decay is also True.""")
tf.app.flags.DEFINE_boolean("use_extreme_indicator", False,
                            """Whether to use the extreme indicator for learning
                               rates as an input during training. Can only be
                               used if learnable_decay is also True.""")
tf.app.flags.DEFINE_boolean("use_log_means_squared", True,
                            """Whether to track the log of the mean squared
                               grads instead of the means squared grads.""")
tf.app.flags.DEFINE_boolean("use_problem_lr_mean", True,
                            """Whether to use the mean over all learning rates
                               in the problem when calculating the relative
                               learning rate.""")

# Optimizer parameters: major features
tf.app.flags.DEFINE_boolean("learnable_decay", True,
                            """Whether to learn weights that dynamically
                               modulate the input scale via RMS decay.""")
tf.app.flags.DEFINE_boolean("dynamic_output_scale", True,
                            """Whether to learn weights that dynamically
                               modulate the output scale.""")
tf.app.flags.DEFINE_boolean("use_log_objective", True,
                            """Whether to use the log of the scaled objective
                               rather than just the scaled obj for training.""")
tf.app.flags.DEFINE_boolean("use_attention", False,
                            """Whether to learn where to attend.""")
tf.app.flags.DEFINE_boolean("use_second_derivatives", True,
                            """Whether to use second derivatives.""")
tf.app.flags.DEFINE_integer("num_gradient_scales", 4,
                            """How many different timescales to keep for
                               gradient history. If > 1, also learns a scale
                               factor for gradient history.""")
tf.app.flags.DEFINE_float("max_log_lr", 33,
                          """The maximum log learning rate allowed.""")
tf.app.flags.DEFINE_float("objective_training_max_multiplier", -1,
                          """How much the objective can grow before training on
                             this problem / param pair is terminated. Sets a max
                             on the objective value when multiplied by the
                             initial objective. If <= 0, not used.""")
tf.app.flags.DEFINE_boolean("use_gradient_shortcut", True,
                            """Whether to add a learned affine projection of the
                               gradient to the update delta in addition to the
                               gradient function computed by the RNN.""")
tf.app.flags.DEFINE_boolean("use_lr_shortcut", False,
                            """Whether to add the difference between the current
                               learning rate and the desired learning rate to
                               the RNN input.""")
tf.app.flags.DEFINE_boolean("use_grad_products", True,
                            """Whether to use gradient products in the input to
                               the RNN. Only applicable when num_gradient_scales
                               > 1.""")
tf.app.flags.DEFINE_boolean("use_multiple_scale_decays", False,
                            """Whether to use many-timescale scale decays.""")
tf.app.flags.DEFINE_boolean("use_numerator_epsilon", False,
                            """Whether to use epsilon in the numerator of the
                               log objective.""")
tf.app.flags.DEFINE_boolean("learnable_inp_decay", True,
                            """Whether to learn input decay weight and bias.""")
tf.app.flags.DEFINE_boolean("learnable_rnn_init", True,
                            """Whether to learn RNN state initialization.""")

FLAGS = tf.app.flags.FLAGS

# Size of the RNN hidden state in each layer of the hierarchical RNN:
# [PerParam, PerTensor, Global]. The length of this list must be 1, 2, or 3.
# If it is shorter than 3, the Global and/or PerTensor RNNs will not be created.
# Passed as the single positional argument when building the learned optimizer.

HRNN_CELL_SIZES = [10, 20, 20]


def register_optimizers():
    """Map every supported learned-optimizer name to its class.

    Returns:
      A dict whose keys are the names accepted by --optimizer and whose values
      are the corresponding optimizer classes.
    """
    return {
        "CoordinatewiseRNN": coordinatewise_rnn.CoordinatewiseRNN,
        "GlobalLearningRate": global_learning_rate.GlobalLearningRate,
        "HierarchicalRNN": hierarchical_rnn.HierarchicalRNN,
        "LearningRateSchedule": learning_rate_schedule.LearningRateSchedule,
        "TrainableAdam": trainable_adam.TrainableAdam,
    }


def _selected_problems():
    """Collect the (problem, dataset, batch_size) entries enabled by the include_* flags."""
    problems_and_data = []
    if FLAGS.include_mnist_conv_problems:
        problems_and_data.extend(ps.mnist_conv_problems())
    # getattr so a missing flag definition does not raise AttributeError.
    if getattr(FLAGS, "include_cifar10_conv_problems", False):
        problems_and_data.extend(ps.cifar10_conv_problems())
    if FLAGS.include_mnist_mlp_problems:
        problems_and_data.extend(ps.mnist_mlp_problems())
    return problems_and_data


def _l2o_optimizer_kwargs():
    """Build the keyword arguments for the learned optimizer from FLAGS."""
    return {
        "init_lr_range": (FLAGS.min_lr, FLAGS.max_lr),
        "learnable_decay": FLAGS.learnable_decay,
        "dynamic_output_scale": FLAGS.dynamic_output_scale,
        "cell_cls": getattr(tf.contrib.rnn, FLAGS.cell_cls),
        "use_attention": FLAGS.use_attention,
        "use_log_objective": FLAGS.use_log_objective,
        "num_gradient_scales": FLAGS.num_gradient_scales,
        "zero_init_lr_weights": FLAGS.zero_init_lr_weights,
        "use_log_means_squared": FLAGS.use_log_means_squared,
        "use_relative_lr": FLAGS.use_relative_lr,
        "use_extreme_indicator": FLAGS.use_extreme_indicator,
        "max_log_lr": FLAGS.max_log_lr,
        "obj_train_max_multiplier": FLAGS.objective_training_max_multiplier,
        "use_problem_lr_mean": FLAGS.use_problem_lr_mean,
        "use_gradient_shortcut": FLAGS.use_gradient_shortcut,
        "use_second_derivatives": FLAGS.use_second_derivatives,
        "use_lr_shortcut": FLAGS.use_lr_shortcut,
        "use_grad_products": FLAGS.use_grad_products,
        "use_multiple_scale_decays": FLAGS.use_multiple_scale_decays,
        "use_numerator_epsilon": FLAGS.use_numerator_epsilon,
        "learnable_inp_decay": FLAGS.learnable_inp_decay,
        "learnable_rnn_init": FLAGS.learnable_rnn_init,
    }


def _results_dir():
    """Create and return (save_dir, problem_name) for the pickled loss records.

    save_dir is <directory of this script>/<first '/'-separated component of
    --train_dir>; problem_name is that component.
    """
    current_dir = os.path.dirname(os.path.realpath(__file__))
    problem_name = FLAGS.train_dir.split('/')[0]
    save_dir = os.path.join(current_dir, problem_name)
    # MakeDirs succeeds if the directory already exists.
    tf.gfile.MakeDirs(save_dir)
    return save_dir, problem_name


def main(_):
    """Evaluate an optimizer on every selected problem and pickle the losses.

    With --test_optimizer=L2o, builds the learned optimizer class named by
    --optimizer (positional arg HRNN_CELL_SIZES, kwargs from FLAGS) and passes a
    log directory under --train_dir to metaopt.test_optimizer. With
    --test_optimizer=Adam, uses tf.train.AdamOptimizer(learning_rate=0.001) and
    no log directory. Each problem is run for --num_testing_itrs iterations and
    the returned objective values are pickled to
    <save_dir>/<test_optimizer>_eval_loss_record.pickle (with an _<index>
    suffix when more than one problem is selected).

    Args:
      _: argv forwarded by tf.app.run; unused.

    Returns:
      0.

    Raises:
      ValueError: if --test_optimizer is not 'L2o' or 'Adam', if --optimizer
        is not a registered optimizer (L2o only), or if HRNN_CELL_SIZES does
        not have 1, 2, or 3 entries (L2o only).
    """
    if FLAGS.test_optimizer not in ("L2o", "Adam"):
        raise ValueError(
            "Unsupported --test_optimizer {!r}; expected 'L2o' or 'Adam'."
            .format(FLAGS.test_optimizer))

    problems_and_data = _selected_problems()
    if not problems_and_data:
        # Every include_* flag defaults to False, so a bare run selects nothing.
        print("No problems selected; enable at least one --include_*_problems "
              "flag.")
        return 0

    if FLAGS.test_optimizer == "L2o":
        opts = register_optimizers()
        if FLAGS.optimizer not in opts:
            raise ValueError("Unknown --optimizer {!r}; expected one of {}."
                             .format(FLAGS.optimizer, sorted(opts)))
        optimizer_cls = opts[FLAGS.optimizer]
        if len(HRNN_CELL_SIZES) not in (1, 2, 3):
            raise ValueError("HRNN_CELL_SIZES must have 1, 2, or 3 entries, "
                             "got {}.".format(len(HRNN_CELL_SIZES)))
        optimizer_args = (HRNN_CELL_SIZES,)
        optimizer_kwargs = _l2o_optimizer_kwargs()

        logdir = os.path.join(FLAGS.train_dir,
                              "{}_{}_{}_{}".format(FLAGS.optimizer,
                                                   FLAGS.cell_cls,
                                                   FLAGS.cell_size,
                                                   FLAGS.num_cells))
        tf.gfile.MakeDirs(logdir)

        def make_optimizer():
            # A fresh learned-optimizer instance for each problem.
            return problem_spec.Spec(
                optimizer_cls, optimizer_args, optimizer_kwargs).build()
    else:
        # The log directory is only used for the learned optimizer.
        logdir = None

        def make_optimizer():
            return tf.train.AdamOptimizer(learning_rate=0.001)

    save_dir, problem_name = _results_dir()
    num_problems = len(problems_and_data)

    for problem_itr, (problem, dataset, batch_size) in enumerate(
            problems_and_data):
        opt = make_optimizer()

        # Problems without data run on EMPTY_DATASET, batched at its full size.
        if dataset is None:
            dataset = datasets.EMPTY_DATASET
            batch_size = dataset.size

        objective_values, _, _ = metaopt.test_optimizer(
            opt,
            problem.build(),
            num_iter=FLAGS.num_testing_itrs,
            dataset=dataset,
            batch_size=batch_size,
            seed=None,
            graph=None,
            logdir=logdir,
            record_every=None)

        # Suffix the index when several problems run so later records do not
        # overwrite earlier ones; a single problem keeps the plain filename.
        suffix = "" if num_problems == 1 else "_{}".format(problem_itr)
        record_path = os.path.join(
            save_dir,
            "{}_eval_loss_record{}.pickle".format(FLAGS.test_optimizer, suffix))
        with open(record_path, 'wb') as l_record:
            pickle.dump(objective_values, l_record)
        print("Saving evaluate loss record {}".format(problem_name))

    return 0


if __name__ == "__main__":
    # Parse flags, then invoke main() with the remaining argv.
    tf.app.run(main=main)