import tensorflow as tf
import numpy as np
import keras.backend as K
import torch
import torch.nn as nn
from keras.layers import BatchNormalization

from python.slalom.mobileNetv2 import MobileNetV2 as MobileNetV2_sgx
from python.slalom.global_sgx import sgxutils
from python.slalom.Activation import Activation


import keras.layers as layers




def main(args):
    """Exercise the SGX ``bnrelu`` Activation layer and ``batchnorm_dark_back``.

    Builds both graphs on Keras ``Input`` placeholders and feeds them four
    independently drawn random tensors. It also runs a PyTorch ReLU6 backward
    pass on the same ``a``/``grad`` arrays and reports the results.

    Args:
        args: Parsed command-line namespace. Currently unused; accepted so the
            ``__main__`` entry point can pass ``argparse`` results through.
    """
    config = tf.ConfigProto(log_device_placement=False)
    config.allow_soft_placement = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.90
    config.gpu_options.allow_growth = True

    # The random input batch and the SGX internal batch size must agree.
    batch_size = 3
    h = w = 224
    c = 8
    input_shape = (batch_size, h, w, c)
    # The returned values are not used below; the call is kept, presumably
    # for its side effects on sgxutils. TODO confirm.
    bm, um, gm, igm = sgxutils.fill_parameter(internal_batch_size=batch_size)

    with tf.Session(config=config) as sess:
        input_tensor = tf.random_uniform(input_shape, minval=-10.0, maxval=14.0, dtype=tf.float32)

        with tf.device("/gpu:0"):
            print("======== BEGIN TO BUILD ENCLAV ========")
            x2 = layers.Input(shape=(h, w, c))
            y2 = layers.Input(shape=(h, w, c))
            g2 = layers.Input(shape=(h, w, c))
            a2 = layers.Input(shape=(h, w, c))

            y_sgx = Activation(
                               axis=-1,
                               act_mode="bnrelu",
                               privacy=True,
                               epsilon=0.0,
                               momentum=1.0,
                               sgxutils=sgxutils,
                               name="sfs")(x2, skip_input=y2)

            grad_out = sgxutils.batchnorm_dark_back(g2, x2, y2, a2)

        # beta & gamma stay uninitialized until this point.
        # tf.initialize_all_variables() is deprecated; global_variables_initializer()
        # is its documented drop-in replacement.
        K.get_session().run(tf.global_variables_initializer())

        # Each separate run of the random op draws a fresh sample, so the four
        # arrays differ. Fetching it several times in ONE run would not.
        input_np1 = sess.run(input_tensor)
        input_np2 = sess.run(input_tensor)
        grad = sess.run(input_tensor)
        a = sess.run(input_tensor)
        feed = {x2: input_np1, y2: input_np2, g2: grad, a2: a}

        print("======== EVALUATING SGX ========")
        y_sgx_val = sess.run([y_sgx], feed_dict=feed)
        grad_out_val = sess.run(grad_out, feed_dict=feed)

        # PyTorch reference: ReLU6 backward of `a` with upstream gradient `grad`.
        model = nn.Sequential(nn.ReLU6())
        x = torch.from_numpy(a).float()
        grad_ = torch.from_numpy(grad).float()
        x.requires_grad = True

        y = model(x)
        y.backward(gradient=grad_)

        # Report what was computed. Previously every result was discarded, so
        # the run checked nothing.
        ref_grad = x.grad.numpy()
        print("SGX activation output shape:", np.shape(y_sgx_val[0]))
        if isinstance(grad_out_val, np.ndarray):
            print("SGX backward shape:", grad_out_val.shape,
                  "reference shape:", ref_grad.shape)
            if grad_out_val.shape == ref_grad.shape:
                # NOTE(review): whether batchnorm_dark_back should equal a plain
                # ReLU6 backward is not established here. Treat as informational.
                print("max |SGX - reference| =",
                      float(np.max(np.abs(grad_out_val - ref_grad))))
        else:
            print("SGX backward returned", type(grad_out_val).__name__)
        

if __name__ == '__main__':
    # Script entry point. No command-line options are declared yet, but the
    # arguments are still parsed (unknown flags are rejected) and passed to main().
    import argparse

    cli_args = argparse.ArgumentParser().parse_args()
    main(cli_args)