from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import python.slalom.keras_fix

import sys
import os
import copy 

import numpy as np
import tensorflow as tf
from tensorflow.python.client import timeline
from python.slalom.sgxdnn import model_to_json, SGXDNNUtils

def relu(x):
    """Rectified linear unit implemented by masking.

    Multiplies ``x`` by the boolean mask ``x > 0``, so positive entries are
    kept and all others become zero. Because this is a multiplication,
    negative floats become ``-0.0`` and NaN inputs stay NaN.

    Args:
        x: a scalar or array-like supporting ``>`` and ``*`` (e.g. a NumPy array).

    Returns:
        ``x`` with non-positive entries zeroed.
    """
    keep = x > 0
    return x * keep

# Shape parameters for the max-pool + ReLU backward test. The pooled spatial
# size follows the standard pooling formula (in + 2*pad - kernel) // stride + 1,
# which yields 112 for the values below.
BATCH = 2
CHANNELS = 32
IN_HW = 224
KERNEL = 2
STRIDE = 2
PAD = 0
OUT_HW = (IN_HW + 2 * PAD - KERNEL) // STRIDE + 1

# Size descriptors handed to the enclave; the channel count sits at index 1,
# so these appear to be [N, C, H, W].
# NOTE(review): the buffers/tensors below are laid out (N, H, W, C) — confirm
# which layout setup_maxpoolrelu / maxpoolrelu_back actually expect.
pool_src = np.array([BATCH, CHANNELS, IN_HW, IN_HW], dtype=np.int32)
pool_dst = np.array([BATCH, CHANNELS, OUT_HW, OUT_HW], dtype=np.int32)
ker_size = np.array([KERNEL, KERNEL], dtype=np.int32)
strides = np.array([STRIDE, STRIDE], dtype=np.int32)
padding = np.array([PAD, PAD], dtype=np.int32)

# Buffer shaped like the pooling input; presumably receives the input
# gradient from maxpoolrelu_back — verify against SGXDNNUtils.
gs = np.zeros((BATCH, IN_HW, IN_HW, CHANNELS), np.float32)

sgxutils = SGXDNNUtils(True, num_enclaves=1)

try:
    sgxutils.dnnl_init()

    work_size = sgxutils.setup_maxpoolrelu(in_size=pool_src,
                                           out_size=pool_dst,
                                           ker_size=ker_size,
                                           strides=strides,
                                           padding=padding)

    # assumes work_size[0] is the workspace length in float32 elements — TODO confirm
    workspace = np.zeros(work_size[0], dtype=np.float32)

    # The session context makes it the default session so res.eval() works.
    with tf.Session(""):
        # Random upstream gradient shaped like the pooling output.
        grad = tf.random_uniform((BATCH, OUT_HW, OUT_HW, CHANNELS),
                                 minval=-100.0, maxval=100.0)

        res = sgxutils.maxpoolrelu_back(grad, gs, workspace)

        res.eval()
finally:
    # Always tear down the enclave, even if init, setup or the backward pass
    # raises; otherwise the enclave would be leaked.
    sgxutils.destroy()
