
"""Scripts for meta-optimization."""

from __future__ import print_function

import os
import pickle

import tensorflow as tf

import metaopt
from optimizer import coordinatewise_rnn
from optimizer import global_learning_rate
from optimizer import hierarchical_rnn
from optimizer import learning_rate_schedule
from optimizer import trainable_adam
from problems import problem_sets as ps
from problems import problem_spec
from problems import datasets

# NOTE: the random seed is not a flag. main() loops over a fixed set of seeds
# (range(0, 50, 5)), so the definition below is left disabled.
# tf.app.flags.DEFINE_integer("seed", 14,
#                             """random seed.""")

# Input/output locations and the optimizer under test.
# main() builds the L2O log directory underneath train_dir and writes the
# pickled per-seed records into save_dir.
tf.app.flags.DEFINE_string("train_dir", "opt/",
                           """Directory to store parameters and results.""")
tf.app.flags.DEFINE_string("save_dir", "opt/",
                           """Directory to testing results.""")
# Forwarded unchanged to metaopt.test_optimizer as restore_model_name.
tf.app.flags.DEFINE_string("restore_model_name", "", "")
# Values handled by main(): 'L2o', 'SGD', 'Adam' and 'Adagrad'.
tf.app.flags.DEFINE_string("test_optimizer", "L2o",
                           """optimizer to test.""")

# Replica / run-size flags.
tf.app.flags.DEFINE_integer("task", 0,
                            """Task id of the replica running the training.""")
tf.app.flags.DEFINE_integer("worker_tasks", 1,
                            """Number of tasks in the worker job.""")
# Passed to metaopt.test_optimizer as num_epochs.
tf.app.flags.DEFINE_integer("num_testing_itrs", 100,
                            """Number testing iterations.""")
# NOTE(review): num_problems and num_meta_iterations do not appear to be read
# anywhere in this script -- confirm before relying on them.
tf.app.flags.DEFINE_integer("num_problems", 1,
                            """Number of sub-problems to run.""")
tf.app.flags.DEFINE_integer("num_meta_iterations", 5,
                            """Number of meta-iterations to optimize.""")

# Unroll-sampling flags (scale/minimum pairs for the number of unrolls and the
# number of iterations per unroll).
tf.app.flags.DEFINE_integer("num_unroll_scale", 10,
                            """The scale parameter of the exponential
                            distribution from which the number of partial
                            unrolls is drawn""")
tf.app.flags.DEFINE_integer("min_num_unrolls", 4,
                            """The minimum number of unrolls per problem.""")
tf.app.flags.DEFINE_integer("num_partial_unroll_itr_scale", 50,
                            """The scale parameter of the exponential
                               distribution from which the number of iterations
                               per unroll is drawn.""")
tf.app.flags.DEFINE_integer("min_num_itr_partial_unroll", 10,
                            """The minimum number of iterations for one
                               unroll.""")

# Key into the table returned by register_optimizers(); also the first
# component of the log-directory name built in main().
tf.app.flags.DEFINE_string("optimizer", "HierarchicalRNN",
                           """Which meta-optimizer to train.""")

# CoordinatewiseRNN-specific flags
# In main(), cell_cls is resolved with getattr(tf.contrib.rnn, ...) and passed
# to the optimizer as "cell_cls". cell_size and num_cells are only used there
# to name the log directory; the per-level sizes handed to the optimizer come
# from HRNN_CELL_SIZES instead.
tf.app.flags.DEFINE_integer("cell_size", 10,
                            """Size of the RNN hidden state in each layer.""")
tf.app.flags.DEFINE_integer("num_cells", 2,
                            """Number of RNN layers.""")
tf.app.flags.DEFINE_string("cell_cls", "GRUCell",
                           """Type of RNN cell to use.""")

# Metaoptimization parameters
# NOTE(review): neither flag below appears to be read in this script --
# presumably kept for command-line parity with the training script; confirm.
tf.app.flags.DEFINE_float("meta_learning_rate", 1e-6,
                          """The learning rate for the meta-optimizer.""")
tf.app.flags.DEFINE_float("gradient_clip_level", 1e4,
                          """The level to clip gradients to.""")

# Train or test
# Disabled: this script only evaluates (main() calls metaopt.test_optimizer),
# so there is no train/test switch.
# tf.app.flags.DEFINE_boolean("training", False,
#                             """training or testing.""")

# Problem set selection
# Each flag below makes main() append the matching ps.test_*_problems() set to
# the list of problems to evaluate. All of them default to False, so at least
# one has to be enabled for anything to be evaluated.



tf.app.flags.DEFINE_boolean("include_mnist_conv_problems", False,
                            """Include Convnet problems.""")

tf.app.flags.DEFINE_boolean("include_cifar10_conv_problems", False,
                            """Include cifar10 Convnet problems.""")

tf.app.flags.DEFINE_boolean("include_mnist_mlp_relu_problems", False,
                            """include_mnist_mlp_relu_problems.""")


tf.app.flags.DEFINE_boolean("include_mnist_conv_large_problems", False,
                            """include_mnist_conv_large_problems.""")
                            



# Optimizer parameters: initialization and scale values
# main() passes the pair (min_lr, max_lr) to the optimizer as "init_lr_range".
tf.app.flags.DEFINE_float("min_lr", 1e-6,
                          """The minimum initial learning rate.""")
tf.app.flags.DEFINE_float("max_lr", 1e-2,
                          """The maximum initial learning rate.""")

# Optimizer parameters: small features.
# Each flag in this group is forwarded by main() to the optimizer constructor
# as a keyword argument of the same name.
tf.app.flags.DEFINE_boolean("zero_init_lr_weights", True,
                            """Whether to initialize the learning rate weights
                               to 0 rather than the scaled random initialization
                               used for other RNN variables.""")
tf.app.flags.DEFINE_boolean("use_relative_lr", True,
                            """Whether to use the relative learning rate as an
                               input during training. Can only be used if
                               learnable_decay is also True.""")
tf.app.flags.DEFINE_boolean("use_extreme_indicator", False,
                            """Whether to use the extreme indicator for learning
                               rates as an input during training. Can only be
                               used if learnable_decay is also True.""")
tf.app.flags.DEFINE_boolean("use_log_means_squared", True,
                            """Whether to track the log of the mean squared
                               grads instead of the means squared grads.""")
tf.app.flags.DEFINE_boolean("use_problem_lr_mean", True,
                            """Whether to use the mean over all learning rates
                               in the problem when calculating the relative
                               learning rate.""")

# Optimizer parameters: major features
# Every flag in this group is forwarded by main() to the optimizer constructor
# as a keyword argument of the same name, except
# objective_training_max_multiplier, which is passed as
# "obj_train_max_multiplier".
tf.app.flags.DEFINE_boolean("learnable_decay", True,
                            """Whether to learn weights that dynamically
                              modulate the input scale via RMS decay.""")
tf.app.flags.DEFINE_boolean("dynamic_output_scale", True,
                            """Whether to learn weights that dynamically
                               modulate the output scale.""")
tf.app.flags.DEFINE_boolean("use_log_objective", True,
                            """Whether to use the log of the scaled objective
                               rather than just the scaled obj for training.""")
tf.app.flags.DEFINE_boolean("use_attention", False,
                            """Whether to learn where to attend.""")
tf.app.flags.DEFINE_boolean("use_second_derivatives", True,
                            """Whether to use second derivatives.""")
tf.app.flags.DEFINE_integer("num_gradient_scales", 4,
                            """How many different timescales to keep for
                               gradient history. If > 1, also learns a scale
                               factor for gradient history.""")
tf.app.flags.DEFINE_float("max_log_lr", 33,
                          """The maximum log learning rate allowed.""")
tf.app.flags.DEFINE_float("objective_training_max_multiplier", -1,
                          """How much the objective can grow before training on
                             this problem / param pair is terminated. Sets a max
                             on the objective value when multiplied by the
                             initial objective. If <= 0, not used.""")
tf.app.flags.DEFINE_boolean("use_gradient_shortcut", True,
                            """Whether to add a learned affine projection of the
                               gradient to the update delta in addition to the
                               gradient function computed by the RNN.""")
tf.app.flags.DEFINE_boolean("use_lr_shortcut", False,
                            """Whether to add the difference between the current
                               learning rate and the desired learning rate to
                               the RNN input.""")
tf.app.flags.DEFINE_boolean("use_grad_products", True,
                            """Whether to use gradient products in the input to
                               the RNN. Only applicable when num_gradient_scales
                               > 1.""")
tf.app.flags.DEFINE_boolean("use_multiple_scale_decays", False,
                            """Whether to use many-timescale scale decays.""")
tf.app.flags.DEFINE_boolean("use_numerator_epsilon", False,
                            """Whether to use epsilon in the numerator of the
                               log objective.""")
tf.app.flags.DEFINE_boolean("learnable_inp_decay", True,
                            """Whether to learn input decay weight and bias.""")
tf.app.flags.DEFINE_boolean("learnable_rnn_init", True,
                            """Whether to learn RNN state initialization.""")

# Additional tricks
# Both are forwarded to the optimizer constructor under the same names.
tf.app.flags.DEFINE_boolean("use_output_constrain", False, "Use output constrain on RNN's output.")
tf.app.flags.DEFINE_float("output_constrain_alpha", 0.1, "scale of the LSTM output after adding output constains")

# Accessor for the values of the flags defined above.
FLAGS = tf.app.flags.FLAGS

# The Size of the RNN hidden state in each layer:
# [PerParam, PerTensor, Global]. The length of this list must be 1, 2, or 3.
# If less than 3, the Global and/or PerTensor RNNs will not be created.
# main() checks the length and passes this list as the only positional
# argument of the selected optimizer class.

HRNN_CELL_SIZES = [10, 20, 20]


def register_optimizers():
    """Returns the table of available trainable optimizers.

    Returns:
      A dict mapping each accepted value of the --optimizer flag to the
      optimizer class that implements it.
    """
    return {
        "CoordinatewiseRNN": coordinatewise_rnn.CoordinatewiseRNN,
        "GlobalLearningRate": global_learning_rate.GlobalLearningRate,
        "HierarchicalRNN": hierarchical_rnn.HierarchicalRNN,
        "LearningRateSchedule": learning_rate_schedule.LearningRateSchedule,
        "TrainableAdam": trainable_adam.TrainableAdam,
    }


def _problem_name_from_dir(train_dir):
    """Returns the second-to-last path component of train_dir.

    Falls back to the only component when train_dir contains no path
    separator (e.g. "opt"), where indexing [-2] would raise IndexError.
    """
    parts = train_dir.split(os.path.sep)
    return parts[-2] if len(parts) > 1 else parts[-1]


def _select_test_optimizer(optimizer_cls, optimizer_args, optimizer_kwargs,
                           logdir):
    """Builds the optimizer selected by --test_optimizer.

    Args:
      optimizer_cls: trainable optimizer class, used only for 'L2o'.
      optimizer_args: positional arguments for optimizer_cls.
      optimizer_kwargs: keyword arguments for optimizer_cls.
      logdir: log directory of the learned optimizer.

    Returns:
      A tuple (opt, logdir). logdir is returned unchanged for 'L2o' and is
      None for the hand-designed baselines (SGD, Adam, Adagrad).

    Raises:
      ValueError: if --test_optimizer is not one of the supported names.
    """
    name = FLAGS.test_optimizer
    if name == 'L2o':
        print('using optimzier L2o')
        optimizer_spec = problem_spec.Spec(
            optimizer_cls, optimizer_args, optimizer_kwargs)
        return optimizer_spec.build(), logdir
    if name == 'SGD':
        print('using optimzier SGD')
        return (tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=0.00001), None)
    if name == 'Adam':
        print('using optimzier Adam')
        return tf.train.AdamOptimizer(learning_rate=0.00001), None
    if name == 'Adagrad':
        print('using optimzier Adagrad')
        return tf.train.AdagradOptimizer(learning_rate=0.00001), None
    # Previously an unknown name silently left the optimizer as None and the
    # failure only surfaced inside metaopt.test_optimizer.
    raise ValueError(
        "Unknown test_optimizer {!r}; expected one of "
        "'L2o', 'SGD', 'Adam', 'Adagrad'.".format(name))


def _save_records(save_dir, seed, records):
    """Pickles each record into save_dir.

    Args:
      save_dir: existing directory to write into.
      seed: seed of the run; substituted into every file-name template.
      records: iterable of (file_name_template, value) pairs.
    """
    for template, value in records:
        path = os.path.join(save_dir, template.format(seed))
        with open(path, 'wb') as record_file:
            pickle.dump(value, record_file)


def main(_):
    """Evaluates the selected optimizer on the selected problems.

    Collects the problem sets enabled through the --include_*_problems flags,
    builds the optimizer named by --test_optimizer, and runs
    metaopt.test_optimizer on every problem for the seeds 0, 5, ..., 45.
    The records returned for each seed are pickled into --save_dir.

    Args:
      _: unused positional argument supplied by tf.app.run().

    Returns:
      0, which tf.app.run() uses as the process exit status.

    Raises:
      ValueError: if --test_optimizer is not a supported optimizer name.
    """
    opts = register_optimizers()

    # Choose the problems to evaluate on. Every include_* flag defaults to
    # False, so nothing is selected unless at least one of them is enabled.
    problems_and_data = []
    if FLAGS.include_mnist_conv_problems:
        problems_and_data.extend(ps.test_mnist_conv_problems())
    if FLAGS.include_cifar10_conv_problems:
        problems_and_data.extend(ps.test_cifar10_conv_problems())
    if FLAGS.include_mnist_mlp_relu_problems:
        problems_and_data.extend(ps.test_mnist_mlp_relu_problems())
    if FLAGS.include_mnist_conv_large_problems:
        problems_and_data.extend(ps.test_mnist_conv_large_problems())

    # log directory
    logdir = os.path.join(FLAGS.train_dir,
                          "{}_{}_{}_{}".format(FLAGS.optimizer,
                                               FLAGS.cell_cls,
                                               FLAGS.cell_size,
                                               FLAGS.num_cells))

    # get the optimizer class and arguments
    optimizer_cls = opts[FLAGS.optimizer]

    assert len(HRNN_CELL_SIZES) in [1, 2, 3]
    optimizer_args = (HRNN_CELL_SIZES,)

    optimizer_kwargs = {
        "init_lr_range": (FLAGS.min_lr, FLAGS.max_lr),
        "learnable_decay": FLAGS.learnable_decay,
        "dynamic_output_scale": FLAGS.dynamic_output_scale,
        "cell_cls": getattr(tf.contrib.rnn, FLAGS.cell_cls),
        "use_attention": FLAGS.use_attention,
        "use_log_objective": FLAGS.use_log_objective,
        "num_gradient_scales": FLAGS.num_gradient_scales,
        "zero_init_lr_weights": FLAGS.zero_init_lr_weights,
        "use_log_means_squared": FLAGS.use_log_means_squared,
        "use_relative_lr": FLAGS.use_relative_lr,
        "use_extreme_indicator": FLAGS.use_extreme_indicator,
        "max_log_lr": FLAGS.max_log_lr,
        "obj_train_max_multiplier": FLAGS.objective_training_max_multiplier,
        "use_problem_lr_mean": FLAGS.use_problem_lr_mean,
        "use_gradient_shortcut": FLAGS.use_gradient_shortcut,
        "use_second_derivatives": FLAGS.use_second_derivatives,
        "use_lr_shortcut": FLAGS.use_lr_shortcut,
        "use_grad_products": FLAGS.use_grad_products,
        "use_multiple_scale_decays": FLAGS.use_multiple_scale_decays,
        "use_numerator_epsilon": FLAGS.use_numerator_epsilon,
        "learnable_inp_decay": FLAGS.learnable_inp_decay,
        "learnable_rnn_init": FLAGS.learnable_rnn_init,
        "use_output_constrain": FLAGS.use_output_constrain,
        "output_constrain_alpha": FLAGS.output_constrain_alpha,
    }

    # make log directory
    tf.gfile.MakeDirs(logdir)

    # Build the optimizer under test; the baselines run without a logdir.
    opt, logdir = _select_test_optimizer(
        optimizer_cls, optimizer_args, optimizer_kwargs, logdir)

    if not problems_and_data:
        # Make the otherwise silent no-op run visible to the user.
        print('No problems selected; enable at least one '
              '--include_*_problems flag.')
        return 0

    # Both values are the same for every problem, so compute them once.
    save_dir = FLAGS.save_dir
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    problem_name = _problem_name_from_dir(FLAGS.train_dir)

    for problem, dataset, batch_size in problems_and_data:
        print(dataset)
        problem = problem.build()

        for seed in range(0, 50, 5):
            print('testing the {} using seed {}'.format(problem_name, seed))

            # The fourth returned value is not recorded.
            (objective_values, acc_records, parameters, _, test_acc,
             test_loss, training_acc, final_test_acc, final_test_loss,
             final_train_loss, final_train_acc) = metaopt.test_optimizer(
                 opt,
                 problem,
                 num_epochs=FLAGS.num_testing_itrs,
                 dataset=dataset,
                 batch_size=batch_size,
                 seed=seed,
                 graph=None,
                 logdir=logdir,
                 record_every=None,
                 restore_model_name=FLAGS.restore_model_name)

            _save_records(save_dir, seed, (
                ('seed{}_train_loss_record.pickle', objective_values),
                ('seed{}_acc_record.pickle', acc_records),
                ('seed{}_model_params.pickle', parameters),
                ('seed{}_test_acc.pickle', test_acc),
                ('seed{}_test_loss.pickle', test_loss),
                ('seed{}_training_acc.pickle', training_acc),
                ('seed{}_final_test_acc.pickle', final_test_acc),
                ('seed{}_final_test_loss.pickle', final_test_loss),
                ('seed{}_final_train_loss.pickle', final_train_loss),
                ('seed{}_final_train_acc.pickle', final_train_acc),
            ))
            print("Saving evaluate seed{} loss and acc records to {} ".format(seed, FLAGS.test_optimizer))

    return 0


# Script entry point: tf.app.run() parses the command-line flags and then
# calls main() defined above.
if __name__ == "__main__":
    tf.app.run()