import tensorflow as tf
import numpy as np

from python.slalom.mobileNetv2 import MobileNetV2 as MobileNetV2_sgx
from python.slalom.global_sgx import sgxutils
import time
import keras.backend as K
from keras.losses import mean_squared_error as MSE

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
import sys
@ops.RegisterGradient("BatchNormDark")
def _batchnorm_dark_grad(op, grad, grad1):
    """Gradient function registered for the custom ``BatchNormDark`` op.

    The incoming gradient is back-propagated through
    ``sgxutils.batchnorm_dark_back``. The gradient arriving on the op's second
    output (``grad1``) is not used.

    Args:
        op: The forward ``BatchNormDark`` operation. ``op.inputs`` are read as
            (batch source, mean, skip input), and ``op.outputs[1]`` is passed
            to the SGX backward call as the activation source.
        grad: Gradient with respect to ``op.outputs[0]``.
        grad1: Gradient with respect to ``op.outputs[1]`` (ignored).

    Returns:
        A list with one entry per op input: ``[d_batch_src, d_mean, d_skip]``.
    """
    act_src = op.outputs[1]
    batch_src = op.inputs[0]
    mean = op.inputs[1]
    skip_input = op.inputs[2]

    res = sgxutils.batchnorm_dark_back(grad, batch_src, skip_input, act_src)

    # The skip input receives the back-propagated gradient when the shapes
    # agree; otherwise fall back to a tensor with skip_input's own shape so the
    # gradient returned for that input is shape-compatible with it.
    grad_skip = res
    if res.shape != skip_input.shape:
        grad_skip = skip_input

    # NOTE(review): ``mean`` is returned as the gradient for the mean input and
    # ``skip_input`` as the fallback skip gradient; these look like
    # placeholders rather than true derivatives -- confirm intent.
    return [res, mean, grad_skip]

def main(args):
    """Build the SGX MobileNetV2 model and time one forward+backward pass.

    Random inputs and targets are generated, and the MSE-loss gradient with
    respect to the model input is built. One warm-up ``sess.run`` is executed,
    then a second run of the same fetches is timed. Its wall-clock duration
    (seconds) is printed.

    Args:
        args: Parsed command-line namespace (currently unused).
    """
    config = tf.ConfigProto(log_device_placement=False)
    config.allow_soft_placement = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.90
    config.gpu_options.allow_growth = True
    # NOTE(review): the returned values are unused here; the call is kept for
    # its presumed side effect of configuring the SGX backend -- confirm.
    bm, um, gm, igm = sgxutils.fill_parameter(internal_batch_size=3)
    batch_size = 15
    with tf.Session(config=config) as sess:
        input_tensor = tf.random_uniform((batch_size, 224, 224, 3), minval=-1000.0, maxval=1000.0, dtype=tf.float32)
        out_tensor = tf.random_uniform((batch_size, 1000), minval=-1000.0, maxval=1000.0, dtype=tf.float32)

        with tf.device("/gpu:0"):
            model_sgx = MobileNetV2_sgx(input_size=(batch_size, 224, 224, 3),
                                        privacy=True,
                                        weights=None,
                                        classes=1000,
                                        sgxutils=sgxutils)
            # summary() prints on its own and returns None, so don't wrap it
            # in print().
            model_sgx.summary()

        # global_variables_initializer() supersedes the deprecated
        # initialize_all_variables().
        K.get_session().run(tf.global_variables_initializer())

        y_tar = tf.placeholder(shape=(batch_size, 1000), dtype='float32')
        loss_tar_fn = MSE(y_tar, model_sgx.output)
        gradients_tar = K.gradients(loss_tar_fn, model_sgx.layers[0].input)
        # Fetch the input gradient(s) together with the forward output. Build
        # the list once, without mutating gradients_tar, so both runs below
        # do identical work.
        fetches = gradients_tar + [model_sgx.output]

        input_np = sess.run(input_tensor)
        out_np = sess.run(out_tensor)
        feed = {model_sgx.input: input_np, y_tar: out_np}

        # Warm-up run, excluded from the timing below.
        sess.run(fetches, feed_dict=feed)

        start = time.perf_counter()
        sess.run(fetches, feed_dict=feed)
        print(time.perf_counter() - start)
if __name__ == '__main__':
    # argparse is only needed when this file is executed as a script.
    from argparse import ArgumentParser

    # No options are defined yet; parsing still handles -h and rejects
    # unexpected command-line arguments.
    main(ArgumentParser().parse_args())