import tensorflow as tf
import numpy as np

from python.slalom.vgg16 import VGG16
from python.slalom.global_sgx import sgxutils

import keras.backend as K

import time

def main(args, num_iters=3):
    """Time SGX forward/backward passes through an enclave-backed VGG16.

    Builds a ``VGG16`` model, passing it an ``enclave_list`` that is then
    handed to ``sgxutils.setup_layer``. It finalizes the enclave setup and
    draws one random input batch. It then runs ``num_iters`` rounds of
    ``sgxutils.forward`` followed by ``sgxutils.backward`` on that input.
    Finally it prints the enclave's own timing report, followed by the
    accumulated forward and backward elapsed seconds measured on the Python
    side.

    Args:
        args: Parsed command-line namespace. Currently unused; kept so the
            entry-point signature stays stable.
        num_iters: Number of forward+backward rounds to time. The default of
            3 matches the count the script originally hard-coded.
    """
    config = tf.ConfigProto(log_device_placement=False)
    config.allow_soft_placement = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.90
    config.gpu_options.allow_growth = True

    # Called for its side effects; the values it returns were never used.
    sgxutils.fill_parameter(internal_batch_size=3)
    sgxutils.dnnl_init()

    # Shape arguments passed straight through to the enclave calls.
    # NOTE(review): (112, 112, 32) presumably describes the output of the
    # enclave-resident forward segment, and (224, 224, 3) the input image —
    # confirm against sgxutils.forward/backward.
    fwd_shape = (112, 112, 32)
    bwd_shape = (224, 224, 3)

    with tf.Session(config=config) as sess:
        input_tensor = tf.random_uniform((1, 224, 224, 3), minval=-5000.0, maxval=50000.0, dtype=tf.float32)
        # Filled in via the VGG16 constructor, then registered with the enclave.
        enclave_list = []

        with tf.device("/gpu:0"):
            model = VGG16(input_shape=(224, 224, 3),
                          weights=None,
                          classes=1000,
                          enclave_list=enclave_list,
                          sgxutils=sgxutils)

            # summary() prints the table itself and returns None; wrapping it
            # in print() would emit an extra "None" line.
            model.summary()
            input_np = sess.run(input_tensor, feed_dict={})

        sgxutils.setup_layer(enclave_list)
        sgxutils.setup_final_reorder()

        print("++++++++++")
        sgxutils.enclave_update_backward()

        # Start the enclave-side timers from zero so its report covers only
        # the benchmark rounds below.
        sgxutils.sgx_reset_timing()

        time_fwd = 0.0
        time_bwd = 0.0
        for _ in range(num_iters):
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start = time.perf_counter()
            sgxutils.forward(input_np, fwd_shape)
            time_fwd += time.perf_counter() - start

            start = time.perf_counter()
            sgxutils.backward(input_np, bwd_shape)
            time_bwd += time.perf_counter() - start

        sgxutils.sgx_print_timing()
        print(time_fwd, time_bwd)
        
if __name__ == '__main__':
    # Script entry point: parse the (currently empty) command line and run
    # the benchmark. Unknown flags still make argparse exit with an error.
    import argparse

    cli_args = argparse.ArgumentParser().parse_args()
    main(cli_args)