import tensorflow as tf
import numpy as np

from python.slalom.mobileNetv2 import MobileNetV2 as MobileNetV2_sgx
from python.slalom.resnet_sp import ResNet50
from python.slalom.resnet_bias_sp import ResNet50 as ResNetBias
from python.slalom.resnet_common import ResNet152 as ResNet152
from python.slalom.global_sgx import sgxutils
import time
import keras.backend as K
from keras.losses import mean_squared_error as MSE

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
import sys
import json

@ops.RegisterGradient("ResnetActivation")
def _resnet_activation_grad(op, grad):
    """Gradient function for the custom ``ResnetActivation`` op.

    The gradient w.r.t. the first input is produced by
    ``sgxutils.resnet_activation_back_op`` from the upstream gradient and the
    op's ``mode`` attribute. The second input (named "mean" in this file)
    receives an all-zero gradient.

    Args:
        op: the forward ``ResnetActivation`` operation.
        grad: upstream gradient w.r.t. the op's output.

    Returns:
        ``[grad_wrt_input0, grad_wrt_input1]``.
    """
    act_mode = op.get_attr('mode')
    # NOTE(review): the backward op only receives grad and mode, not
    # op.inputs[0]; presumably it relies on state saved during the forward
    # pass -- confirm against sgxutils.
    grad_in_1 = sgxutils.resnet_activation_back_op(grad_out=grad,
                                                   act_mode=act_mode)

    # Zero gradient for the "mean" input. zeros_like uses the tensor's dynamic
    # shape, so it works even when the static shape is not fully known
    # (unlike tf.zeros(shape=...)).
    mean_grad = tf.zeros_like(op.inputs[1])

    return [grad_in_1, mean_grad]

@ops.RegisterGradient("ResnetBottom")
def _resnet_bottom_grad(op, grad):
    """Gradient function for the custom ``ResnetBottom`` op.

    The gradients w.r.t. the first two inputs are produced by
    ``sgxutils.resnet_bottom_back_op`` from the upstream gradient. The third
    and fourth inputs (named "mean" in this file) receive all-zero gradients.

    Args:
        op: the forward ``ResnetBottom`` operation (four inputs).
        grad: upstream gradient w.r.t. the op's output.

    Returns:
        ``[grad_left, grad_right, grad_wrt_input2, grad_wrt_input3]``.
    """
    grad_left, grad_right = sgxutils.resnet_bottom_back_op(grad)

    # Zero gradients for the two "mean" inputs. zeros_like uses each tensor's
    # dynamic shape, so it works even when static shapes are not fully known.
    mean_grad0 = tf.zeros_like(op.inputs[2])
    mean_grad1 = tf.zeros_like(op.inputs[3])

    return [grad_left, grad_right, mean_grad0, mean_grad1]

def copy_resnet_data(bias_model, normal_model):
    """Copy parameters from ``normal_model`` into ``bias_model``.

    Layer 2 of ``normal_model`` holds a (weight, bias) pair. The weight goes
    into layer 2 of ``bias_model`` and the bias into layer 3. Every layer with
    index in ``[4, len(normal_model.layers) - 3)`` is then copied with the
    destination layer's ``copy_data`` method.
    """
    src_layers = normal_model.layers
    dst_layers = bias_model.layers

    source_params = src_layers[2].get_weights()
    kernel = source_params[0]
    offset = source_params[1]
    dst_layers[2].set_weights([kernel])
    dst_layers[3].set_weights([offset])

    # The first four and last three layers are excluded from the bulk copy.
    stop = len(src_layers) - 3
    for layer_idx in range(4, stop):
        dst_layers[layer_idx].copy_data(src_layers[layer_idx])

def _timed(fn, *fn_args):
    """Invoke ``fn(*fn_args)`` and return the elapsed wall-clock time in seconds."""
    start = time.time()
    fn(*fn_args)
    return time.time() - start


def main(args, num_iters=3):
    """Benchmark enclave forward/backward passes of a ResNet152 via sgxutils.

    Builds a Keras ResNet152 (``privacy=False``) under ``/gpu:0`` and registers
    its layers with ``sgxutils``. It then times ``num_iters`` rounds of
    ``sgxutils.forward`` followed by ``sgxutils.backward`` on one random input
    batch. Finally it prints the sgxutils timing report and the accumulated
    forward and backward wall-clock seconds.

    Args:
        args: parsed command-line namespace (currently unused).
        num_iters: number of forward+backward rounds to time. Defaults to 3,
            the number of rounds this benchmark always ran.
    """
    config = tf.ConfigProto(log_device_placement=False)
    config.allow_soft_placement = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.90
    config.gpu_options.allow_growth = True

    # fill_parameter returns four values that this benchmark does not use.
    sgxutils.fill_parameter(internal_batch_size=3)
    sgxutils.dnnl_init()

    batch_size = 15
    sgxutils.set_dark_batch_size(batch_size)

    # Shape arguments passed verbatim to sgxutils.forward / sgxutils.backward;
    # their exact meaning is defined by the enclave wrapper (not visible here).
    fwd_shape = (224, 224, 2048)
    bwd_shape = (230, 230, 3)

    with tf.Session(config=config) as sess:
        with tf.device("/gpu:0"):
            input_tensor = tf.random_uniform(
                (batch_size, 224, 224, 3),
                minval=-1000.0, maxval=1000.0, dtype=tf.float32)

            # Passed empty to the ResNet152 constructor and then handed to
            # sgxutils.setup_layer; presumably filled in during construction.
            enclave_list = []
            model_trans = ResNet152(input_shape=(224, 224, 3),
                                    weights=None,
                                    classes=1000,
                                    sgxutils=sgxutils,
                                    privacy=False,
                                    enclave_list=enclave_list)

            # Keras' summary() prints the table itself and returns None;
            # wrapping it in print() would also emit a stray "None".
            model_trans.summary()

            input_np = sess.run(input_tensor)
            sgxutils.setup_layer(enclave_list)
            sgxutils.setup_final_reorder()
            sgxutils.enclave_update_backward()

            sgxutils.sgx_reset_timing()

            time_fwd = 0.0
            time_bwd = 0.0
            for _ in range(num_iters):
                time_fwd += _timed(sgxutils.forward, input_np, fwd_shape)
                time_bwd += _timed(sgxutils.backward, input_np, bwd_shape)

            sgxutils.sgx_print_timing()
            print(time_fwd, time_bwd)

            
if __name__ == '__main__':
    # Script entry point: no CLI options are defined yet; the parsed
    # namespace is forwarded to main() unchanged.
    import argparse

    cli_parser = argparse.ArgumentParser()
    main(cli_parser.parse_args())