from python.slalom.quant_layers import ActivationQ , transform, get_all_linear_layers
from python.slalom.global_sgx import sgxutils
from python.slalom.sgxdnn import model_to_json
import tensorflow as tf
import keras
from keras.layers import Dense, Activation
from keras import Sequential
import numpy as np

# Layer dimensions: a Dense layer maps an `in_size`-vector onto an
# `out_size`-vector (224 * 224 * 3 = 150528 outputs).
out_size = 224 * 224 * 3
in_size = 224

# One random (target, input) pair, each with a leading batch axis of 1:
# target drawn from U[0, 1), input drawn from N(0, 1).
y = tf.random_uniform(shape=(1, out_size), dtype=tf.float32)
x = tf.random_normal(shape=(1, in_size), dtype=tf.float32)

# 4-D int32 shape passed to sgxutils.setup_relu; only the last axis is > 1.
shape = np.array([1, 1, 1, out_size], dtype=np.int32)



# Initialise sgxutils; called before any other sgxutils function in this script.
# NOTE(review): presumably sets up the DNNL (oneDNN) backend — confirm.
sgxutils.dnnl_init()
# Configure the sgxutils ReLU for a 4-D shape (1, 1, 1, out_size), matching the
# out_size outputs of the Dense layer built below.
sgxutils.setup_relu(shape)
with tf.Session("") as sess:

    # Plain Keras model: Dense(in_size -> out_size) followed by ReLU, built on CPU.
    with tf.device('/cpu:0'):
        model = Sequential()
        model.add(Dense(out_size, input_shape=(in_size,)))
        model.add(Activation('relu'))

    # Number of linear layers as reported by get_all_linear_layers; one queue
    # is created for each of them below.
    num_linear_layers = len(get_all_linear_layers(model))

    # One float32 FIFOQueue (capacity 3) per linear layer, handed to transform().
    # NOTE(review): presumably used to pass tensors between the TF graph and
    # sgxutils — confirm against quant_layers.transform.
    queues = [tf.FIFOQueue(capacity=2 + 1, dtypes=[tf.float32]) for _ in range(num_linear_layers)]

    # Rewrite the model for SLALOM: privacy on, integrity off, quantize=True
    # with bits_w=bits_x=0. `model` is rebound to the transformed model;
    # linear_ops_in / linear_ops_out are not used further in this script.
    model, linear_ops_in, linear_ops_out = transform(model, log=False, quantize=True, verif_preproc=False,
                                                     slalom=True, slalom_integrity=False, slalom_privacy=True,
                                                     bits_w=0,
                                                     bits_x=0,
                                                     sgxutils=sgxutils, queues=queues
    )

    # Serialise the transformed model into a JSON description plus weights.
    # Only the first pair is used below; model_json2 / weigh2 are unused.
    model_json, weights, model_json2, weigh2 = model_to_json(sess=sess, model=model,
                                                             dtype=np.float32, verif_preproc=False,
                                                             slalom_privacy=True,
                                                             bits_w=0,
                                                             bits_x=0)

    # Load the serialised model into sgxutils as float32, with verification off.
    sgxutils.load_model(model_json, weights, dtype=np.float32, verify=False, verify_preproc=False)
    # NOTE(review): the positional flags appear to mirror slalom_integrity=False /
    # slalom_privacy=True passed to transform(); the meaning of the trailing 1
    # is not visible here (batch size?) — confirm against sgxutils.
    sgxutils.slalom_init(False, True, 1)

# Take the last layer of the transformed model and stack it on a fresh Dense
# layer. NOTE(review): presumably this is the transformed counterpart of
# Activation('relu') produced by transform() — confirm.
# Read it with [-1] rather than `model.layers.pop()`. Depending on the Keras
# version, `.layers` is either the model's internal list (pop() would leave the
# model inconsistent) or a copy (pop() is a pointless mutation). The old model
# is discarded immediately, so indexing yields the same layer object without
# either ambiguity.
act = model.layers[-1]
model = Sequential()
model.add(Dense(out_size, input_shape=(in_size,)))
model.add(act)

model.compile(
    optimizer=keras.optimizers.RMSprop(),
    loss="categorical_crossentropy",  # loss function to minimize
    metrics=["accuracy"],  # metrics to monitor during training
)

# Run a single training step on one random (x, y) sample.
# NOTE(review): `y` is uniform noise rather than a probability/one-hot target,
# so loss/accuracy here only serve as a smoke test — confirm intent.
with tf.Session("") as sess:
    # Evaluate both random tensors in one session call instead of two evals.
    x_val, y_val = sess.run([x, y])
    history = model.fit(x_val, y_val, steps_per_epoch=1, epochs=1)
