import tensorflow as tf
import numpy as np
import keras.backend as K

from keras.layers import BatchNormalization

from python.slalom.mobileNetv2 import MobileNetV2 as MobileNetV2_sgx
from python.slalom.global_sgx import sgxutils
from python.slalom.Activation import Activation


import keras.layers as layers

def main(args):
    """Compare the SGX ``Activation`` layer in batch-norm mode with NumPy references.

    Builds an ``Activation(act_mode="bn")`` layer on two Keras inputs (the
    second one passed as ``skip_input``), evaluates it on random data, and
    prints how far its output is from two NumPy normalizations of the first
    input:

    * per-channel: mean/std taken over axes (0, 1, 2), i.e. everything but
      the last axis;
    * global: a single mean/std over every element.

    Args:
        args: Parsed command-line arguments; currently unused.
    """
    config = tf.ConfigProto(log_device_placement=False)
    config.allow_soft_placement = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.90
    config.gpu_options.allow_growth = True

    h = w = 224
    c = 20
    batch = 15
    input_shape = (batch, h, w, c)

    # The returned values are not used by this script. The call is kept in case
    # fill_parameter configures sgxutils state. NOTE(review): confirm.
    sgxutils.fill_parameter(internal_batch_size=3)

    with tf.Session(config=config) as sess:
        input_tensor = tf.random_uniform(input_shape, minval=-10000.0, maxval=1000.0, dtype=tf.float32)

        with tf.device("/gpu:0"):
            print("======== BEGIN TO BUILD TARGET ========")

            x1 = layers.Input(shape=(h, w, c))
            y1 = layers.Input(shape=(h, w, c))

            y_tar = Activation(axis=-1,
                               privacy=False,
                               act_mode="bn",
                               sgxutils=sgxutils,
                               epsilon=0.0,
                               momentum=1.0,
                               name="target")(x1, skip_input=y1, training=True)

        # Each run of a random op draws a new sample, so the two inputs differ.
        input_np1 = sess.run(input_tensor)
        input_np2 = sess.run(input_tensor)

        print("======== EVALUATING GPU ========")
        y_tar_val = sess.run(y_tar, feed_dict={x1: input_np1, y1: input_np2})

        print("======== EVALUATING ALG ========")
        mean = input_np1.mean(axis=(0, 1, 2))
        std = input_np1.std(axis=(0, 1, 2))
        print(mean.shape, std.shape)

        # Per-channel normalization via broadcasting: `mean`/`std` have shape (c,)
        # and broadcast over the last axis. This is the same element-wise float32
        # arithmetic as an explicit b/h/w/c loop (about 15M Python iterations),
        # widened to float64 afterwards.
        res = ((input_np1 - mean) / std).astype(np.float64)

        # Alternative hypothesis: one global mean/std over all elements.
        res2 = (input_np1 - input_np1.mean()) / input_np1.std()

        # A signed mean lets positive and negative errors cancel out, so also
        # report the magnitude of the mismatch.
        for label, reference in (("per-channel", res), ("global", res2)):
            diff = reference - y_tar_val
            abs_diff = np.abs(diff)
            print(label,
                  "mean err:", diff.mean(),
                  "mean |err|:", abs_diff.mean(),
                  "max |err|:", abs_diff.max())

if __name__ == '__main__':
    # Imported lazily: argparse is only needed when run as a script.
    import argparse

    # No options are registered yet. Parsing still runs, so unknown flags are
    # rejected and --help works.
    cli_args = argparse.ArgumentParser().parse_args()
    main(cli_args)