import tensorflow as tf
import numpy as np

from python.slalom.mobileNetv2 import MobileNetV2 as MobileNetV2_sgx
from python.slalom.resnet_sp import ResNet50
from python.slalom.global_sgx import sgxutils
import time
import keras.backend as K
from keras.losses import mean_squared_error as MSE

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
import sys
from python.slalom.Activation import ResNetActivation

# Number of samples in the random test batch built further down.
batch_size = 6

# Session configuration: no device-placement logging, and let TensorFlow fall
# back to another device when an op cannot be placed where it was requested.
config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
# GPU memory policy: grow allocations on demand, capped at 90% of the device.
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.90

# NOTE(review): the four returned values are not used in the visible part of
# this script; presumably the call matters for its side effects inside
# sgxutils -- confirm before removing.
bm, um, gm, igm = sgxutils.fill_parameter(internal_batch_size=3)
print("her")

# Smoke test: run ResNetActivation in 'bnzerorelu' mode on random data and
# print how far its output is from a NumPy "normalise, then ReLU" reference.
with tf.Session(config=config) as sess:

    with tf.device("/gpu:0"):
        # Uniform random test data between -1000 and 1000.
        input_tensor = tf.random_uniform((batch_size, 224, 224, 3), minval=-1000.0, maxval=1000.0, dtype=tf.float32)
        # Not consumed anywhere in the visible part of this script; kept in
        # case code outside this view refers to it.
        out_tensor = tf.random_uniform((batch_size, 1000), minval=-1000.0, maxval=1000.0, dtype=tf.float32)

        in_tensor = tf.placeholder(shape=(batch_size, 224, 224, 3), dtype=tf.float32)
        res_tensor = ResNetActivation(act_mode='bnzerorelu', epsilon=0.0, privacy=False, sgxutils=None)(in_tensor, training=True)

        # tf.initialize_all_variables() is deprecated; use its replacement.
        # NOTE(review): K.get_session() is presumably the same session as
        # `sess` (the active default session) -- confirm.
        K.get_session().run(tf.global_variables_initializer())

        # Materialise the random input once so the layer and the NumPy
        # reference below are evaluated on exactly the same values.
        input_np = sess.run(input_tensor)

        res = sess.run(res_tensor, feed_dict={in_tensor: input_np})
        print("her")

        # Reference: standardise with the mean/std of the whole array, then
        # clamp negatives to zero.
        # NOTE(review): these statistics are taken over every element; if the
        # layer normalises per channel the two will differ slightly -- confirm
        # which axes ResNetActivation reduces over.
        expected_res = (input_np - input_np.mean()) / input_np.std()
        print(expected_res.shape)
        expected_res = np.maximum(expected_res, 0)

        # Signed mean error (original diagnostic). Positive and negative
        # deviations can cancel here, so a value near zero is not proof of
        # agreement on its own.
        print((expected_res-res).mean())
        # Worst-case element-wise deviation, which cannot cancel out.
        print(np.abs(expected_res - res).max())
        print(expected_res[0][0][0][0], res[0][0][0][0])
        print(expected_res.shape, res.shape)