import tensorflow as tf
import numpy as np

from python.slalom.mobtest import MobileNetV2 as MobileNetV2_sgx
from python.slalom.global_sgx import sgxutils

import keras.backend as K
import keras.layers as layers
import time

import torch
import torch.nn as nn



def main(args):
    """Push one randomly initialised "linear" layer through sgxutils.

    Builds a TensorFlow session, samples random input, output-gradient,
    kernel and bias arrays, registers a single ``"linear"`` layer
    description with ``sgxutils`` and then calls its ``forward`` and
    ``backward`` entry points once each.

    Args:
        args: Parsed command-line namespace. Not read by this function.
    """

    def _uniform(shape):
        # Every sampled tensor shares the same value range and dtype.
        return tf.random_uniform(shape, minval=-500.0, maxval=500.0, dtype=tf.float32)

    session_config = tf.ConfigProto(
        log_device_placement=False,
        allow_soft_placement=True,
    )
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.90
    session_config.gpu_options.allow_growth = True

    # The four returned values are not used below; the call is kept for
    # whatever set-up it performs inside sgxutils.
    _bm, _um, _gm, _igm = sgxutils.fill_parameter(internal_batch_size=3)
    sgxutils.dnnl_init()

    with tf.Session(config=session_config) as sess:
        # Disabled "pool" layer variant of this experiment, kept for reference:
        #
        # in_size = (1, 112, 112, 32)
        # out_size= (1, 56, 56, 32)
        # input_tensor  = tf.random_uniform(in_size, minval=-500.0, maxval=500.0, dtype=tf.float32)
        # output_tensor = tf.random_uniform(out_size, minval=-500.0, maxval=500.0, dtype=tf.float32)
        # kernel_tensor = tf.random_uniform((3, 3, 32, 1), minval=-500.0, maxval=500.0, dtype=tf.float32)
        # enclave_list  = []
        # input_np  = sess.run(input_tensor, feed_dict={})
        # output_np = sess.run(output_tensor, feed_dict={})
        # kernel_np = sess.run(kernel_tensor, feed_dict={})
        # enclave_list.append(["pool", in_size, out_size, (2, 2), (2, 2), (0, 0), 1])
        # sgxutils.setup_layer(enclave_list)
        # sgxutils.setup_final_reorder()
        # sgxutils.enclave_update_backward()
        # res = sgxutils.forward(input_np,   (56, 56, 32))
        # print(res.shape)
        # grad = sgxutils.backward(output_np,   (112, 112, 32))
        #
        # print(grad[0][0][0], output_np[0][0][0][0])

        in_size = (1, 1, 1, 32)
        out_size = (1, 16)

        # Create the random ops first, then evaluate them one by one in the
        # same order, so each array comes from its own session run.
        random_ops = (
            _uniform(in_size),
            _uniform(out_size),
            _uniform((32, 16)),
            _uniform((1, 16)),
        )
        input_np, output_np, kernel_np, bias_np = [sess.run(op) for op in random_ops]

        # Single-layer network description handed to sgxutils.
        enclave_list = [["linear", in_size, out_size, kernel_np, bias_np]]
        sgxutils.setup_layer(enclave_list)

        sgxutils.setup_final_reorder()
        sgxutils.enclave_update_backward()

        # NOTE(review): the second argument looks like the expected result
        # shape without the batch axis -- confirm against sgxutils.
        res = sgxutils.forward(input_np, (1, 1, 16))
        grad = sgxutils.backward(output_np, (1, 1, 32))


if __name__ == "__main__":
    # Script entry point: no options are declared, so argparse only provides
    # the default -h/--help handling and rejects unknown arguments before
    # main() is invoked with the resulting (empty) namespace.
    import argparse

    main(argparse.ArgumentParser().parse_args())