import os
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
# import sys
# logf = open('stdout.log', 'a')
# sys.stdout = logf
# sys.stderr = logf

print("[debug] logging.")

"""
Note that at test time, test_num_updates=10 means the output has 1 (original loss) + 10 (sghmc_num_updates) + 10 (num_updates) dimensions.
"""
import csv
import numpy as np
import pickle
import random
import tensorflow as tf
tf.compat.v1.disable_v2_behavior()
# tf.get_logger().setLevel('ERROR')

from data_generator import DataGenerator
from ipml import IPML
from tensorflow.python.platform import flags

# Command-line flag registry; every option below is read through FLAGS.<name>.
FLAGS = flags.FLAGS

## Dataset/method options
flags.DEFINE_string('datasource', 'sinusoid', 'sinusoid, omniglot, miniimagenet, reduced_mnist, or reduced_omniglot')
flags.DEFINE_integer('num_classes', 5, 'number of classes used in classification (e.g. 5-way classification).')
# oracle means task id is input (only suitable for sinusoid)
flags.DEFINE_string('baseline', None, 'oracle, or None')

## Training options
flags.DEFINE_integer('pretrain_iterations', 0, 'number of pre-training iterations.')
flags.DEFINE_integer('metatrain_iterations', 15000, 'number of metatraining iterations.') # 15k for omniglot, 50k for sinusoid
flags.DEFINE_integer('meta_batch_size', 25, 'number of tasks sampled per meta-update')
flags.DEFINE_integer('active_meta_batch_size', 5, 'number of tasks used for active meta-update')
flags.DEFINE_float('meta_lr', 0.001, 'the base learning rate of the generator')
flags.DEFINE_integer('update_batch_size', 5, 'number of examples used for inner gradient update (K for K-shot learning).')
flags.DEFINE_float('update_lr', 1e-3, 'step size alpha for inner gradient update.') # 0.1 for omniglot
flags.DEFINE_integer('num_updates', 1, 'number of inner gradient updates during training.')
## SGHMC Sampler
# sghmc_num_updates must equal sghmc_num_burnin + sghmc_num_sample.
flags.DEFINE_integer('sghmc_num_burnin', 0, 'number of sghmc burn-in updates during training.')
flags.DEFINE_integer('sghmc_num_sample', 5, 'number of sghmc sampling updates during training.')
flags.DEFINE_integer('sghmc_num_updates', 5, 'number of sghmc gradient updates during training.') # must equal to sghmc_num_burnin + sghmc_num_sample
flags.DEFINE_float('epsilon', 3e-2, 'step size of sampler, epsilon ** 2 approx update_lr.')
flags.DEFINE_float('mdecay', 0.95, 'sampler hyper.')
flags.DEFINE_float('prior_constant', 1e-4, 'prior hyper.')
## Model options
flags.DEFINE_string('norm', 'batch_norm', 'batch_norm, layer_norm, or None')
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
flags.DEFINE_bool('conv', True, 'whether or not to use a convolutional network, only applicable in some cases')
flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions')
flags.DEFINE_bool('stop_grad', True, 'if True, do not use second derivatives in meta-optimization (as ipml does need to use)')
flags.DEFINE_bool('clip_z_grad', True, 'if True, clip grad in sghmc step')
flags.DEFINE_bool('clip_maml_grad', False, 'if True, clip grad in maml step')
flags.DEFINE_bool('get_z_samples', False, 'if True, get_z_samples')

## Logging, saving, and testing options
flags.DEFINE_bool('log', True, 'if false, do not log summaries, for debugging code.')
flags.DEFINE_string('logdir', '/tmp/data', 'directory for summaries and checkpoints.')
flags.DEFINE_bool('resume', True, 'resume training if there is a model available')
flags.DEFINE_bool('train', True, 'True to train, False to test.')
flags.DEFINE_integer('test_iter', -1, 'iteration to load model (-1 for latest model)')
flags.DEFINE_bool('test_set', False, 'Set to true to test on the test set, False for the validation set.')
flags.DEFINE_integer('train_update_batch_size', -1, 'number of examples used for gradient update during training (use if you want to test with a different number).')
flags.DEFINE_float('train_update_lr', -1, 'value of inner gradient step size during training. (use if you want to test with a different value)') # 0.1 for omniglot
flags.DEFINE_bool('enable_active_learning', False, 'Set to true to run active learning.')
# investigate task embedding
flags.DEFINE_bool('embedding_experiment', False, 'Set to true to run embedding experiment.')
# flags.DEFINE_string('embedding_experiment_type', '', 'rotation/brightness/contrast/hue/saturation.')
flags.DEFINE_bool('rotation_experiment', False, 'Set to true to run rotation experiment.')
flags.DEFINE_bool('brightness_experiment', False, 'Set to true to run brightness experiment.')
flags.DEFINE_bool('hue_experiment', False, 'Set to true to run hue experiment.')
flags.DEFINE_bool('contrast_experiment', False, 'Set to true to run contrast experiment.')
flags.DEFINE_bool('saturation_experiment', False, 'Set to true to run saturation experiment.')
flags.DEFINE_bool('zoom_experiment', False, 'Set to true to run zoom experiment.')
flags.DEFINE_bool('task_generation_experiment', False, 'Set to true to run task generation experiment.')
flags.DEFINE_bool('retrain_experiment', False, 'Set to true to run retrain experiment.')

def _z_samples_dir():
    """Return the output directory for pickled Z samples, or None.

    The experiment flags are checked in a fixed order so that, when several
    are set at once, the first match wins. None is returned when no
    experiment flag is set.
    """
    flag_dirs = (
        (FLAGS.rotation_experiment, "Z_samples/omniglot_rotate"),
        (FLAGS.brightness_experiment, "Z_samples/miniimagenet_brightness"),
        (FLAGS.hue_experiment, "Z_samples/miniimagenet_hue"),
        (FLAGS.contrast_experiment, "Z_samples/miniimagenet_contrast"),
        (FLAGS.saturation_experiment, "Z_samples/miniimagenet_saturation"),
        (FLAGS.zoom_experiment, "Z_samples/miniimagenet_zoom"),
        (FLAGS.task_generation_experiment, "Z_samples/omniglot_generate"),
    )
    for enabled, directory in flag_dirs:
        if enabled:
            return directory
    return None


def train(model, saver, sess, exp_string, data_generator, resume_itr=0):
    """Run the pre-training / meta-training loop.

    Each iteration runs ``model.pretrain_op`` (while ``itr`` is below
    ``FLAGS.pretrain_iterations``) or ``model.metatrain_op``. At fixed
    intervals it also records summaries, prints running means, saves
    checkpoints, optionally pickles Z samples, and prints validation results
    for non-sinusoid data sources.

    Args:
        model: object exposing the ops and placeholders fetched below
            (``pretrain_op``, ``metatrain_op``, ``summ_op``, ``total_loss1``, ...).
        saver: object whose ``save(sess, path)`` writes a checkpoint.
        sess: session used for every ``run`` call.
        exp_string: experiment sub-directory name under ``FLAGS.logdir``.
        data_generator: provides ``num_classes`` and, for generated data,
            a ``generate()`` method.
        resume_itr: first iteration index to run.

    Raises:
        ValueError: if Z samples should be dumped but no experiment flag
            selects an output directory.
    """
    SUMMARY_INTERVAL = 100
    SAVE_INTERVAL = 1000
    if FLAGS.datasource == 'sinusoid':
        PRINT_INTERVAL = 1000
        if FLAGS.enable_active_learning or FLAGS.meta_batch_size==1:
            PRINT_INTERVAL = 100
            SAVE_INTERVAL = 100
        TEST_PRINT_INTERVAL = PRINT_INTERVAL*5
    else:
        PRINT_INTERVAL = 100
        TEST_PRINT_INTERVAL = PRINT_INTERVAL*10
    if FLAGS.embedding_experiment:
        # Embedding experiments log, print and save on every iteration.
        SAVE_INTERVAL = 1
        PRINT_INTERVAL = 1
        SUMMARY_INTERVAL = 1
        TEST_PRINT_INTERVAL = 10

    if FLAGS.log:
        train_writer = tf.summary.FileWriter(FLAGS.logdir + '/' + exp_string, sess.graph)
    print('Done initializing, starting training.')
    prelosses, zlosses, postlosses = [], [], []

    num_classes = data_generator.num_classes # for classification, 1 otherwise
    # The first `split` examples of each task form set "a", the rest set "b".
    split = num_classes*FLAGS.update_batch_size
    has_generate = 'generate' in dir(data_generator)
    total_iterations = FLAGS.pretrain_iterations + FLAGS.metatrain_iterations

    for itr in range(resume_itr, total_iterations):
        feed_dict = {}
        if has_generate:
            batch_x, batch_y, amp, phase = data_generator.generate()

            if FLAGS.baseline == 'oracle':
                # Append two extra input channels carrying the task's amp and phase.
                batch_x = np.concatenate([batch_x, np.zeros([batch_x.shape[0], batch_x.shape[1], 2])], 2)
                for i in range(FLAGS.meta_batch_size):
                    batch_x[i, :, 1] = amp[i]
                    batch_x[i, :, 2] = phase[i]

            inputa = batch_x[:, :split, :]
            labela = batch_y[:, :split, :]
            inputb = batch_x[:, split:, :] # b used for testing
            labelb = batch_y[:, split:, :]
            feed_dict = {model.inputa: inputa, model.inputb: inputb,  model.labela: labela, model.labelb: labelb}

        if itr < FLAGS.pretrain_iterations:
            fetches = [model.pretrain_op]
        else:
            fetches = [model.metatrain_op]

        # holder is 1 when Z_samples is appended as the last fetch, which
        # shifts the negative indices used to read the metrics below.
        holder = 0
        if (itr % SUMMARY_INTERVAL == 0 or itr % PRINT_INTERVAL == 0):
            fetches.extend([model.summ_op, model.total_loss1,
                            model.total_losses2[FLAGS.sghmc_num_updates-1],
                            model.total_losses3[FLAGS.num_updates-1]])
            if model.classification:
                # Accuracies are appended after the losses, so for
                # classification the "losses" lists below hold accuracies.
                fetches.extend([model.total_accuracy1,
                                model.total_accuracies2[FLAGS.sghmc_num_updates-1],
                                model.total_accuracies3[FLAGS.num_updates-1]])
            if FLAGS.get_z_samples and itr % SAVE_INTERVAL == 0:
                fetches.extend([model.Z_samples])
                holder = 1

        result = sess.run(fetches, feed_dict)

        if itr % SUMMARY_INTERVAL == 0:
            prelosses.append(result[-3-holder])
            if FLAGS.log:
                train_writer.add_summary(result[1], itr)
            zlosses.append(result[-2-holder])
            postlosses.append(result[-1-holder])

        if itr % PRINT_INTERVAL == 0:
            if itr < FLAGS.pretrain_iterations:
                print_str = 'Pretrain Iteration ' + str(itr)
            else:
                print_str = 'Iteration ' + str(itr - FLAGS.pretrain_iterations)
            print_str += ': ' + str(np.mean(prelosses)) + ', ' + str(np.mean(zlosses)) + ', ' + str(np.mean(postlosses))
            print(print_str)
            prelosses, zlosses, postlosses = [], [], []

        if itr % SAVE_INTERVAL == 0:
            saver.save(sess, FLAGS.logdir + '/' + exp_string + '/model' + str(itr))
            if FLAGS.get_z_samples and FLAGS.embedding_experiment:
                list_of_Z_samples = result[-1]
                print("[debug] z shape",len(list_of_Z_samples),list_of_Z_samples[0].shape)
                if FLAGS.task_generation_experiment:
                    # NOTE(review): these are two separate sess.run calls, distinct
                    # from the training step above; if model.inputa/inputb come from
                    # an input queue, each call presumably draws a new batch -- confirm
                    # that the dumped inputs are the ones intended.
                    model_inputa = sess.run(model.inputa)
                    model_inputb = sess.run(model.inputb)
                    with open("Z_samples/omniglot_generate/inputa_{}.pkl".format(itr), "wb") as inputa_file:
                        pickle.dump(model_inputa, inputa_file)
                    with open("Z_samples/omniglot_generate/inputb_{}.pkl".format(itr), "wb") as inputb_file:
                        pickle.dump(model_inputb, inputb_file)
                z_dir = _z_samples_dir()
                if z_dir is None:
                    # Without an experiment flag there is no destination for the samples.
                    raise ValueError(
                        'get_z_samples with embedding_experiment requires one of the '
                        'rotation/brightness/hue/contrast/saturation/zoom/task_generation '
                        'experiment flags to be set.')
                for i, Z_samples in enumerate(list_of_Z_samples):
                    with open("{}/Z_samples_{}_{}.pkl".format(z_dir, itr, i), "wb") as a_file:
                        pickle.dump(Z_samples, a_file)

        # sinusoid is infinite data, so no need to test on meta-validation set.
        if itr % TEST_PRINT_INTERVAL == 0 and FLAGS.datasource !='sinusoid' and not FLAGS.embedding_experiment:
            if not has_generate:
                feed_dict = {}
                if model.classification:
                    val_fetches = [model.metaval_total_accuracy1,
                        model.metaval_total_accuracies2[FLAGS.sghmc_num_updates-1],
                        model.metaval_total_accuracies3[FLAGS.num_updates-1], model.summ_op]
                else:
                    val_fetches = [model.metaval_total_loss1,
                        model.metaval_total_losses2[FLAGS.sghmc_num_updates-1],
                        model.metaval_total_losses3[FLAGS.num_updates-1],
                        model.summ_op]
            else:
                batch_x, batch_y, amp, phase = data_generator.generate(train=False)
                inputa = batch_x[:, :split, :]
                inputb = batch_x[:, split:, :]
                labela = batch_y[:, :split, :]
                labelb = batch_y[:, split:, :]
                feed_dict = {model.inputa: inputa, model.inputb: inputb,  model.labela: labela, model.labelb: labelb, model.meta_lr: 0.0}
                if model.classification:
                    val_fetches = [model.total_accuracy1,
                        model.total_accuracies2[FLAGS.sghmc_num_updates-1],
                        model.total_accuracies3[FLAGS.num_updates-1]]
                else:
                    val_fetches = [model.total_loss1,
                        model.total_losses2[FLAGS.sghmc_num_updates-1],
                        model.total_losses3[FLAGS.num_updates-1]]

            result = sess.run(val_fetches, feed_dict)
            print('Validation results: ' + str(result[0]) + ', ' + str(result[1]) + ', ' + str(result[2]))

    if resume_itr < total_iterations:
        # Final checkpoint, named after the last iteration that actually ran.
        saver.save(sess, FLAGS.logdir + '/' + exp_string +  '/model' + str(itr))
    else:
        # The loop body never ran, so there is no iteration index to save under.
        print('No training iterations to run (resume_itr=' + str(resume_itr) + '); skipping final checkpoint.')

# Number of evaluation rounds run by test(); calculated for omniglot.
NUM_TEST_POINTS = 600

def test(model, saver, sess, exp_string, data_generator, test_num_updates=None):
    """Evaluate the model NUM_TEST_POINTS times and write aggregate statistics.

    Each round fetches the metaval accuracies (classification) or the
    losses (otherwise) for the initial model and every update step. The
    per-column mean, standard deviation and 95% confidence interval are
    printed, the raw results are pickled, and the statistics are written to a
    CSV file under ``FLAGS.logdir/exp_string``.

    Args:
        model: object exposing the metric tensors and placeholders used below.
        saver: unused here; kept for a signature parallel to ``train``.
        sess: session used for every ``run`` call.
        exp_string: experiment sub-directory name under ``FLAGS.logdir``.
        data_generator: provides ``num_classes`` and, for generated data,
            a ``generate(train=False)`` method.
        test_num_updates: unused here; kept for interface compatibility.
    """
    num_classes = data_generator.num_classes # for classification, 1 otherwise
    split = num_classes*FLAGS.update_batch_size

    # Fixed seeds so repeated evaluations draw the same tasks.
    np.random.seed(1)
    random.seed(1)

    has_generate = 'generate' in dir(data_generator)
    if model.classification:
        fetches = [model.metaval_total_accuracy1] + model.metaval_total_accuracies2 + model.metaval_total_accuracies3
    else:  # this is for sinusoid
        fetches = [model.total_loss1] +  model.total_losses2 + model.total_losses3

    metaval_accuracies = []

    for _ in range(NUM_TEST_POINTS):
        if not has_generate:
            feed_dict = {model.meta_lr : 0.0}
        else:
            batch_x, batch_y, amp, phase = data_generator.generate(train=False)

            if FLAGS.baseline == 'oracle': # NOTE - this flag is specific to sinusoid
                batch_x = np.concatenate([batch_x, np.zeros([batch_x.shape[0], batch_x.shape[1], 2])], 2)
                batch_x[0, :, 1] = amp[0]
                batch_x[0, :, 2] = phase[0]

            inputa = batch_x[:, :split, :]
            inputb = batch_x[:, split:, :]
            labela = batch_y[:, :split, :]
            labelb = batch_y[:, split:, :]

            feed_dict = {model.inputa: inputa, model.inputb: inputb,  model.labela: labela, model.labelb: labelb, model.meta_lr: 0.0}

        metaval_accuracies.append(sess.run(fetches, feed_dict))

    metaval_accuracies = np.array(metaval_accuracies)
    print("[DEBUG]: ",metaval_accuracies.shape)
    means = np.mean(metaval_accuracies, 0)
    stds = np.std(metaval_accuracies, 0)
    ci95 = 1.96*stds/np.sqrt(NUM_TEST_POINTS)

    print('Mean validation accuracy/loss, stddev, and confidence intervals')
    print((means, stds, ci95))

    out_prefix = FLAGS.logdir +'/'+ exp_string + '/' + 'test_ubs' + str(FLAGS.update_batch_size) + '_stepsize' + str(FLAGS.update_lr)
    out_filename = out_prefix + '.csv'
    out_pkl = out_prefix + '.pkl'
    with open(out_pkl, 'wb') as f:
        pickle.dump({'mses': metaval_accuracies}, f)
    # newline='' lets the csv module control line endings itself, which avoids
    # extra blank rows on platforms that translate '\n'.
    with open(out_filename, 'w', newline='') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['update'+str(i) for i in range(len(means))])
        writer.writerow(means)
        writer.writerow(stds)
        writer.writerow(ci95)

def _make_data_tensors(data_generator, train):
    """Build (image_tensor, label_tensor) with the maker chosen by the flags.

    The experiment flags are checked in a fixed order; the first one set
    selects the ``data_generator`` method. With ``train`` true the method is
    called without arguments, otherwise with ``train=False``.
    """
    if FLAGS.rotation_experiment:
        maker = data_generator.make_rotate_data_tensor
    elif FLAGS.brightness_experiment:
        maker = data_generator.make_bright_data_tensor
    elif FLAGS.hue_experiment:
        maker = data_generator.make_hue_data_tensor
    elif FLAGS.contrast_experiment:
        maker = data_generator.make_contrast_data_tensor
    elif FLAGS.saturation_experiment:
        maker = data_generator.make_saturation_data_tensor
    elif FLAGS.zoom_experiment:
        maker = data_generator.make_zoom_data_tensor
    elif FLAGS.task_generation_experiment:
        maker = data_generator.make_generate_data_tensor
    elif FLAGS.retrain_experiment:
        maker = data_generator.make_fake_data_tensor
    elif 'reduced_' in FLAGS.datasource:
        maker = data_generator.make_reduced_data_tensor
    else:
        maker = data_generator.make_data_tensor
    # The training call keeps the maker's own defaults, exactly like a bare call.
    return maker() if train else maker(train=False)


def _split_a_b(image_tensor, label_tensor, split):
    """Slice images and labels along axis 1 into the "a" and "b" sets."""
    inputa = tf.slice(image_tensor, [0,0,0], [-1,split, -1])
    inputb = tf.slice(image_tensor, [0,split, 0], [-1,-1,-1])
    labela = tf.slice(label_tensor, [0,0,0], [-1,split, -1])
    labelb = tf.slice(label_tensor, [0,split, 0], [-1,-1,-1])
    return {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}


def main():
    """Build the data pipeline and model from FLAGS, then train or test.

    Steps: choose ``test_num_updates``, create the ``DataGenerator``, build
    the input tensors for tensor-loaded data sources, construct the ``IPML``
    graph, derive the experiment directory name, optionally restore a
    checkpoint, and finally dispatch to ``train`` or ``test``.
    """
    main_seed1 = 10
    random.seed(main_seed1)

    if FLAGS.train:
        if FLAGS.datasource == 'sinusoid':
            test_num_updates = 5
        elif FLAGS.datasource == 'miniimagenet':
            test_num_updates = 1  # eval on at least one update during training
        else:
            test_num_updates = 10
    else:
        test_num_updates = 10
        if FLAGS.datasource == 'sinusoid':
            print("[Tune]:test_num_updates=", test_num_updates)

    if not FLAGS.train:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    # First DataGenerator argument (presumably samples per class -- confirm
    # against DataGenerator).
    if FLAGS.datasource == 'miniimagenet' and FLAGS.metatrain_iterations == 0:
        assert FLAGS.meta_batch_size == 1
        assert FLAGS.update_batch_size == 1
        num_samples = 1  # only use one datapoint,
    elif FLAGS.datasource == 'miniimagenet' and FLAGS.train:
        # TODO - use 15 val examples for imagenet?
        num_samples = FLAGS.update_batch_size+15
    else:
        num_samples = FLAGS.update_batch_size*2
    data_generator = DataGenerator(num_samples, FLAGS.meta_batch_size)

    dim_output = data_generator.dim_output
    if FLAGS.baseline == 'oracle':
        assert FLAGS.datasource == 'sinusoid'
        dim_input = 3
        # The oracle baseline runs every iteration as a pre-training iteration.
        FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
        FLAGS.metatrain_iterations = 0
    else:
        dim_input = data_generator.dim_input

    tf_data_load = FLAGS.datasource in ('miniimagenet', 'omniglot', 'reduced_mnist', 'reduced_omniglot')
    input_tensors = None
    metaval_input_tensors = None
    if tf_data_load:
        num_classes = data_generator.num_classes
        split = num_classes*FLAGS.update_batch_size

        if FLAGS.train: # only construct training model if needed
            random.seed(5)
            image_tensor, label_tensor = _make_data_tensors(data_generator, train=True)
            input_tensors = _split_a_b(image_tensor, label_tensor, split)
            print(input_tensors['inputa'],input_tensors['inputb'])

        random.seed(6)
        image_tensor, label_tensor = _make_data_tensors(data_generator, train=False)
        metaval_input_tensors = _split_a_b(image_tensor, label_tensor, split)

    model = IPML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.enable_active_learning and FLAGS.train:
        model.construct_model(input_tensors=input_tensors, prefix='active_')
    else:
        if FLAGS.train or not tf_data_load:
            model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)

    sess = tf.InteractiveSession()

    if not FLAGS.train:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    # -1 means "same value as the corresponding test-time flag".
    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_'+str(FLAGS.num_classes)+'.mbs_'+str(FLAGS.meta_batch_size) + '.ubs_' + str(FLAGS.train_update_batch_size) \
    + '.burnin' + str(FLAGS.sghmc_num_burnin)  \
    + '.sample' + str(FLAGS.sghmc_num_sample)  \
    + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr) \
    + '.seed1' + str(main_seed1)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.enable_active_learning:
        exp_string += 'active'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        checkpoint_dir = FLAGS.logdir + '/' + exp_string
        model_file = tf.train.latest_checkpoint(checkpoint_dir)
        # Checkpoints are saved as <logdir>/<exp_string>/model<itr>; search for
        # 'model' from the right so a logdir containing 'model' is not matched.
        if model_file and FLAGS.test_iter > 0:
            model_file = model_file[:model_file.rindex('model')] + 'model' + str(FLAGS.test_iter)
        if model_file:
            ind1 = model_file.rindex('model')
            resume_itr = int(model_file[ind1+5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)
        else:
            print("No checkpoint found in " + checkpoint_dir + "; using freshly initialized weights.")

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)

# Script entry point: run main() only when this file is executed directly.
if __name__ == "__main__":
    main()