import tensorflow as tf
import numpy as np

from python.slalom.mobtest import MobileNetV2 as MobileNetV2_sgx
from python.slalom.global_sgx import sgxutils

import keras.backend as K

import time

def main(args):
    """Build the SGX MobileNetV2 model and benchmark enclave forward/backward passes.

    Configures a TF1 session, initialises ``sgxutils`` (parameter fill and
    DNNL init), constructs ``MobileNetV2_sgx`` on ``/gpu:0``, and registers
    the collected enclave layers. It then times ``sgxutils.forward`` and
    ``sgxutils.backward`` over several rounds on a single random input.
    Finally it prints the SGX-side timing report, followed by the total
    forward and backward wall-clock seconds.

    Args:
        args: Parsed command-line namespace. If it carries an ``iterations``
            attribute, that many forward/backward rounds are timed; otherwise
            3 rounds are used (the original hard-coded count).
    """
    num_iterations = getattr(args, "iterations", 3)

    config = tf.ConfigProto(log_device_placement=False)
    config.allow_soft_placement = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.90
    config.gpu_options.allow_growth = True
    # The returned values were never used here; the call is kept because it
    # presumably configures sgxutils internally (TODO confirm).
    sgxutils.fill_parameter(internal_batch_size=3)
    sgxutils.dnnl_init()
    with tf.Session(config=config) as sess:
        input_tensor = tf.random_uniform(
            (1, 224, 224, 3), minval=-5000.0, maxval=50000.0, dtype=tf.float32)
        # Filled in by the model constructor, then handed to setup_layer().
        enclave_list = []

        with tf.device("/gpu:0"):
            model_sgx = MobileNetV2_sgx(input_size=(1, 224, 224, 3),
                                        privacy=False,
                                        weights=None,
                                        classes=1000,
                                        sgxutils=sgxutils,
                                        enclave_list=enclave_list)

            # Keras' summary() prints the table itself and returns None, so
            # wrapping it in print() only emitted a stray "None"
            # (assumes model_sgx is a Keras Model).
            model_sgx.summary()
            input_np = sess.run(input_tensor, feed_dict={})

        sgxutils.setup_layer(enclave_list)
        sgxutils.setup_final_reorder()
        print("++++++++++")
        sgxutils.enclave_update_backward()

        sgxutils.sgx_reset_timing()

        time_fwd = 0.0
        time_bwd = 0.0
        # perf_counter is monotonic and high-resolution, which suits
        # benchmarking better than time.time().
        for _ in range(num_iterations):
            start = time.perf_counter()
            sgxutils.forward(input_np, (112, 112, 32))
            time_fwd += time.perf_counter() - start

            start = time.perf_counter()
            sgxutils.backward(input_np, (224, 224, 3))
            time_bwd += time.perf_counter() - start

        sgxutils.sgx_print_timing()
        print(time_fwd, time_bwd)
if __name__ == '__main__':
    # Script entry point: parse the (currently option-less) command line and
    # hand the resulting namespace to main(). Unknown flags are rejected.
    import argparse

    cli_args = argparse.ArgumentParser().parse_args()
    main(cli_args)