from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import python.slalom.keras_fix

import sys
import os
import copy 

import numpy as np
import tensorflow as tf
from tensorflow.python.client import timeline
from python.slalom.sgxdnn import model_to_json, SGXDNNUtils

def relu(x):
    """Rectified linear unit computed as ``x`` times the mask ``x > 0``.

    Works elementwise on NumPy arrays (the boolean mask multiplies as 0/1 and
    keeps ``x``'s dtype) and on plain Python scalars.
    """
    positive_mask = x > 0
    return x * positive_mask

# Smoke test for the enclave max-pool kernels configured via setup_maxpoolrelu:
# run the forward pass on random data, compare it with a NumPy reference, then
# exercise the backward pass and the address test.
sgxutils = SGXDNNUtils(True, num_enclaves=1)

try:
    sgxutils.dnnl_init()

    #relu_src = np.array([2, 2, 2, 2], dtype=np.int32)
    #sgxutils.setup_relu(relu_src)

    # Problem size shared by the enclave setup call and every buffer below.
    BATCH = 2
    CHANNELS = 32
    IN_H = IN_W = 224
    POOL = 2  # square window; kernel size == stride, so windows do not overlap
    OUT_H, OUT_W = IN_H // POOL, IN_W // POOL

    # The sizes handed to the enclave appear to be ordered (N, C, H, W) ...
    pool_src = np.array([BATCH, CHANNELS, IN_H, IN_W], dtype=np.int32)
    pool_dst = np.array([BATCH, CHANNELS, OUT_H, OUT_W], dtype=np.int32)
    ker_size = np.array([POOL, POOL], dtype=np.int32)
    strides = np.array([POOL, POOL], dtype=np.int32)
    padding = np.array([0, 0], dtype=np.int32)

    work_size = sgxutils.setup_maxpoolrelu(in_size=pool_src,
                                           out_size=pool_dst,
                                           ker_size=ker_size,
                                           strides=strides,
                                           padding=padding)

    # ... while the NumPy buffers are laid out (N, H, W, C).
    workspace = np.zeros(work_size[0], dtype=np.float32)
    # Input-shaped, zero-initialised buffer handed to the backward pass
    # (presumably it receives the input gradient -- confirm).
    gs = np.zeros((BATCH, IN_H, IN_W, CHANNELS), np.float32)

    with tf.Session("") as sess:
        src = np.float32(np.random.uniform(low=-100.0, high=100.0,
                                           size=(BATCH, IN_H, IN_W, CHANNELS)))
        # Output buffer for the forward pass, pre-filled with random values so
        # that entries the enclave fails to write are likely to mismatch.
        res = np.float32(np.random.uniform(low=-100.0, high=100.0,
                                           size=(BATCH, OUT_H, OUT_W, CHANNELS)))

        work, res = sgxutils.maxpool(src, res, workspace)

        # NumPy reference: max over every POOL x POOL window, for all batch
        # entries and channels. Splitting H into (OUT_H, POOL) and W into
        # (OUT_W, POOL) puts each window's elements on axes 2 and 4.
        expected = src.reshape(BATCH, OUT_H, POOL, OUT_W, POOL,
                               CHANNELS).max(axis=(2, 4))
        # NOTE(review): the returned ``res`` is assumed to be a flat buffer in
        # (N, H, W, C) order; the wrapper's return layout is not visible here.
        actual = np.asarray(res).reshape(expected.shape)
        # NOTE(review): the reference is plain max-pooling (``relu`` above is
        # unused); if the enclave forward also applies ReLU, compare against
        # relu(expected) instead -- confirm against the kernel.
        mismatch = actual != expected
        # Print only a few offending elements so a broken kernel cannot flood
        # the console, then a summary count.
        for n, h, w, c in np.argwhere(mismatch)[:10]:
            print(actual[n, h, w, c], expected[n, h, w, c])
        print("maxpool: %d of %d outputs mismatch"
              % (np.count_nonzero(mismatch), mismatch.size))

        grad = np.float32(np.random.uniform(low=-100.0, high=100.0,
                                            size=(BATCH, OUT_H, OUT_W, CHANNELS)))

        sgxutils.maxpoolrelu_back(grad, gs, res, workspace)

        # NOTE(review): ``b`` is an unevaluated tf.Tensor; what addresstest
        # does with it is not visible from this file.
        b = tf.random_uniform((32, 112, 112, 64), dtype='float32')
        sgxutils.addresstest(b)

    # Disabled ReLU-backward test, kept for reference:
    #with tf.Session("") as sess:
    #    for i in range(100):
    #        src1 = np.float32(np.random.uniform(low=-100.0, high=100.0, size=(2, 2, 2, 2)))
    #        src2 = np.float32(np.random.uniform(low=-100.0, high=100.0, size=(2, 2, 2, 2)))
    #        grad1 = np.float32(np.random.uniform(low=-100.0, high=100.0, size=(2, 2, 2, 2))).reshape(-1)
    #        grad2 = np.float32(np.random.uniform(low=-100.0, high=100.0, size=(2, 2, 2, 2))).reshape(-1)
    #        res1 = np.float32(np.random.uniform(low=-100.0, high=100.0, size=(2, 2, 2, 2)))
    #        res2 = np.float32(np.random.uniform(low=-100.0, high=100.0, size=(2, 2, 2, 2)))
    #
    #        r2 = sgxutils.slalom_relu_back(grad2, res2, src2)
    #        r1 = sgxutils.slalom_relu_back(grad1, res1, src1)
    #
    #        src1 = src1.reshape(-1) > 0
    #        src2 = src2.reshape(-1) > 0
    #        ep1 = src1 * grad1
    #        ep2 = src2 * grad2
    #        assert(np.sum(ep1 - r1) == 0.0)
    #        assert(np.sum(ep2 - r2) == 0.0)
finally:
    # Always tear the enclave down, even if a step above raised.
    sgxutils.destroy()
