from keras.layers import *
from keras import layers
from keras.models import Model, Sequential
from keras import backend as K
from keras.engine import get_source_inputs
from keras.utils import layer_utils
from keras.utils.data_utils import get_file
from keras.applications.imagenet_utils import _obtain_input_shape
from keras.engine.topology import load_weights_from_hdf5_group_by_name, h5py
import tensorflow as tf
from python.slalom.Activation import ResNetActivation
from python.slalom.Activation import ResNetBottom
from python.slalom.global_sgx import sgxutils
from keras.losses import mean_squared_error as MSE
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops

@ops.RegisterGradient("ResnetActivation")
def _resnet_activation_grad(op, grad):
    """Gradient function registered for the custom ``ResnetActivation`` op.

    The gradient w.r.t. ``op.inputs[0]`` comes from the SGX backward kernel
    ``sgxutils.resnet_activation_back_op``, driven by the op's ``mode``
    attribute. ``op.inputs[1]`` gets an all-zero gradient; it is named a
    "mean" in the original local naming, and no gradient flows into it here.

    Args:
        op: The forward ``ResnetActivation`` operation.
        grad: Upstream gradient w.r.t. the op's output.

    Returns:
        ``[input_grad, zero_grad_for_inputs_1]``, one entry per op input.
    """
    # Built first, matching the original graph-construction order.
    zero_second_grad = tf.zeros(shape=op.inputs[1].shape)
    activation_mode = op.get_attr('mode')
    input_grad = sgxutils.resnet_activation_back_op(grad_out=grad,
                                                    act_mode=activation_mode)
    return [input_grad, zero_second_grad]

@ops.RegisterGradient("ResnetBottom")
def _resnet_bottom_grad(op, grad):
    """Gradient function registered for the custom ``ResnetBottom`` op.

    The gradients w.r.t. ``op.inputs[0]`` and ``op.inputs[1]`` come from the
    SGX backward kernel. ``op.inputs[2]`` and ``op.inputs[3]`` (named "mean"
    in the local naming) receive all-zero gradients.

    Args:
        op: The forward ``ResnetBottom`` operation (four inputs).
        grad: Upstream gradient w.r.t. the op's output.

    Returns:
        ``[grad_left, grad_right, zeros_like_input2, zeros_like_input3]``,
        one gradient tensor per op input.
    """
    mean_grad0 = tf.zeros(shape=op.inputs[2].shape)
    mean_grad1 = tf.zeros(shape=op.inputs[3].shape)

    # This is a module-level function, so there is no `self`. Call the
    # module-level `sgxutils`, as the ResnetActivation gradient does.
    # NOTE(review): assumes sgxutils.resnet_bottom_back_op(grad) returns a
    # (left, right) pair, as the original call site implies; confirm.
    grad_left, grad_right = sgxutils.resnet_bottom_back_op(grad)

    # Gradients must be tensors. The original returned the TensorShape
    # objects instead of these zero tensors.
    return [grad_left, grad_right, mean_grad0, mean_grad1]

def main(args):
    """Compare input gradients of a plain and an SGX ``ResNetActivation``.

    Two ``ResNetActivation`` layers are built in ``'bnrelupool'`` mode with
    identical settings. One has ``privacy=False`` and no ``sgxutils``; the
    other has ``privacy=True`` and ``sgxutils``. Both receive the same random
    input and random MSE target. The script prints the summed absolute
    difference of their input gradients per batch element, plus a few raw
    gradient values for side-by-side inspection.

    Args:
        args: Parsed command-line arguments (not used).
    """
    session_config = tf.ConfigProto(log_device_placement=False)
    session_config.allow_soft_placement = True
    gpu_opts = session_config.gpu_options
    gpu_opts.per_process_gpu_memory_fraction = 0.90
    gpu_opts.allow_growth = True

    sgxutils.dnnl_init()
    # The four return values are not used below. The call is presumably
    # needed for its side effects inside sgxutils.
    _bm, _um, _gm, _igm = sgxutils.fill_parameter(internal_batch_size=3)

    with tf.Session(config=session_config) as sess, tf.device("/gpu:0"):
        batch_size = 3
        height = width = 112
        channels = 64
        input_size = (batch_size, height, width, channels)
        # 3x3 window, stride 2 on 112x112 gives (112 - 3) // 2 + 1 = 55
        # (valid padding assumed).
        output_size = (batch_size, 55, 55, channels)

        input_tensor = tf.random_uniform(input_size, minval=-1000.0,
                                         maxval=1000.0, dtype=tf.float32)
        out_tensor = tf.random_uniform(output_size, minval=-1000.0,
                                       maxval=1000.0, dtype=tf.float32)

        def build_branch(banner, privacy, sgx):
            # Build placeholder -> layer -> MSE -> d(loss)/d(input) for one variant.
            print(banner)
            input_ph = tf.placeholder(tf.float32, shape=input_size)
            layer = ResNetActivation(act_mode='bnrelupool',
                                     privacy=privacy,
                                     sgxutils=sgx,
                                     pool_window=(3, 3),
                                     strides=(2, 2),
                                     epsilon=0.0,
                                     use_bias=True,
                                     bias_shape=channels)
            layer_out = layer(input_ph, training=True)
            target_ph = tf.placeholder(tf.float32, shape=output_size)
            loss = MSE(target_ph, layer_out)
            return input_ph, target_ph, K.gradients(loss, input_ph)

        exp_in, exp_target, exp_grads = build_branch(
            "========= Building GPU Model ==========", False, None)
        sgx_in, sgx_target, sgx_grads = build_branch(
            "========= Building sgxutils Model ==========", True, sgxutils)

        input_np = sess.run(input_tensor)
        target_np = sess.run(out_tensor)
        # K.gradients returns a list; [0] is the gradient w.r.t. the single input.
        grad_exp = sess.run(exp_grads,
                            feed_dict={exp_in: input_np, exp_target: target_np})[0]
        grad_sgx = sess.run(sgx_grads,
                            feed_dict={sgx_in: input_np, sgx_target: target_np})[0]

        # Per-sample L1 difference alongside the reference gradient's L1 norm.
        for sample in range(batch_size):
            print(abs(grad_exp[sample] - grad_sgx[sample]).sum(),
                  abs(grad_exp[sample]).sum())

        separator = "================================="
        print(separator)
        # Top-left 3x3 spatial patch of channel 0 for the first two samples.
        for sample in range(2):
            for row in range(3):
                for col in range(3):
                    print(grad_exp[sample][row][col][0],
                          grad_sgx[sample][row][col][0])
            print(separator)


if __name__ == '__main__':
    import argparse

    # No options are defined yet. Parsing still rejects unexpected CLI arguments.
    cli_args = argparse.ArgumentParser().parse_args()
    main(cli_args)