from keras.layers import *
from keras import layers
from keras.models import Model, Sequential
from keras import backend as K
from keras.engine import get_source_inputs
from keras.utils import layer_utils
from keras.utils.data_utils import get_file
from keras.applications.imagenet_utils import _obtain_input_shape
from keras.engine.topology import load_weights_from_hdf5_group_by_name, h5py
import tensorflow as tf
from python.slalom.Activation import ResNetActivation
from python.slalom.Activation import ResNetBottom
from python.slalom.global_sgx import sgxutils
from keras.losses import mean_squared_error as MSE
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops

@ops.RegisterGradient("ResnetActivation")
def _resnet_activation_grad(op, grad):
    """Gradient function for the custom ``ResnetActivation`` op.

    The gradient w.r.t. the first input is produced by the SGX backend via
    ``sgxutils.resnet_activation_back_op``, parameterised by the forward op's
    ``mode`` attribute. The second input (called "mean" here) receives an
    all-zero gradient, i.e. it is treated as non-differentiable.

    Args:
        op: The forward ``ResnetActivation`` operation.
        grad: Upstream gradient w.r.t. the op's output.

    Returns:
        A two-element list ``[grad_wrt_input0, grad_wrt_input1]``.
    """
    # NOTE(review): op.inputs[0] is not passed to the backward op, so it
    # presumably relies on state cached by the forward pass -- confirm.
    act_mode = op.get_attr('mode')
    grad_in_0 = sgxutils.resnet_activation_back_op(grad_out=grad,
                                                   act_mode=act_mode)

    # zeros_like matches the input's dtype and also works when its static
    # shape is only partially known (tf.zeros(shape=<partial TensorShape>)
    # cannot convert the shape to a tensor and fails).
    mean_grad = tf.zeros_like(op.inputs[1])

    return [grad_in_0, mean_grad]

@ops.RegisterGradient("ResnetBottom")
def _resnet_bottom_grad(op, grad):
    """Gradient function for the custom ``ResnetBottom`` op.

    Gradients for the first two inputs (the "left" and "right" operands)
    come from ``sgxutils.resnet_bottom_back_op``. Inputs 2 and 3 (called
    "mean" tensors here) receive all-zero gradients, i.e. they are treated
    as non-differentiable.

    Args:
        op: The forward ``ResnetBottom`` operation.
        grad: Upstream gradient w.r.t. the op's output.

    Returns:
        A four-element list, one gradient per op input, in input order.
    """
    # NOTE(review): only the upstream grad is passed in, so the backward op
    # presumably relies on state cached by the forward pass -- confirm.
    grad_left, grad_right = sgxutils.resnet_bottom_back_op(grad)

    # zeros_like matches each input's dtype and tolerates partially-known
    # static shapes, unlike tf.zeros(shape=<TensorShape>).
    mean_grad0 = tf.zeros_like(op.inputs[2])
    mean_grad1 = tf.zeros_like(op.inputs[3])

    return [grad_left, grad_right, mean_grad0, mean_grad1]

def main(args):
    """Compare a plain ``ResNetBottom`` layer against its SGX-backed twin.

    Builds two ``ResNetBottom`` layers that differ only in ``privacy`` and
    ``sgxutils``, copies the SGX layer's weights into the plain one, feeds
    both the same random inputs and MSE target, and prints:

    * mean |difference| of the forward outputs vs. mean |reference output|;
    * per batch element, mean |difference| of the gradient w.r.t. the left
      input vs. mean |reference gradient|;
    * one spot-checked scalar of the output and of the gradient.

    Args:
        args: Parsed command-line arguments (currently unused).
    """
    import numpy as np

    # Single source of truth for the batch size: it sizes both the SGX
    # parameter buffers and the test tensors, so the two cannot drift apart.
    batch_size = 3

    config = tf.ConfigProto(log_device_placement=False)
    config.allow_soft_placement = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.90
    config.gpu_options.allow_growth = True
    sgxutils.dnnl_init()
    # The returned objects are not used below but are kept referenced for the
    # whole run. NOTE(review): unclear whether the SGX backend holds pointers
    # into them; do not drop these names without confirming.
    bm, um, gm, igm = sgxutils.fill_parameter(internal_batch_size=batch_size)
    with tf.Session(config=config) as sess:
        with tf.device("/gpu:0"):
            h = w = 112
            c = 64
            input_size  = (batch_size, h, w, c)
            output_size = (batch_size, h, w, c)
            input_tensor = tf.random_uniform(input_size, minval=-1000.0, maxval=1000.0, dtype=tf.float32)
            out_tensor = tf.random_uniform(output_size, minval=-1000.0, maxval=1000.0, dtype=tf.float32)

            print("========= Building GPU Model ==========")

            in_exp_tensor0 = tf.placeholder(tf.float32, shape=input_size)
            in_exp_tensor1 = tf.placeholder(tf.float32, shape=input_size)
            tar_layer      = ResNetBottom(right_norm=False,
                                          privacy=False,
                                          sgxutils=None,
                                          epsilon=0.0,
                                          use_bias=True,
                                          bias_shape=c)
            out_exp_tensor  = tar_layer(in_exp_tensor0, right=in_exp_tensor1, training=True)
            diff_exp_tensor = tf.placeholder(tf.float32, shape=output_size)
            loss_tar_fn     = MSE(diff_exp_tensor, out_exp_tensor)

            gradients_tar0  = K.gradients(loss_tar_fn, in_exp_tensor0)
            gradients_tar1  = K.gradients(loss_tar_fn, in_exp_tensor1)

            print("========= Building sgxutils Model ==========")
            in_sgx_tensor0 = tf.placeholder(tf.float32, shape=input_size)
            in_sgx_tensor1 = tf.placeholder(tf.float32, shape=input_size)
            sgx_layer      = ResNetBottom(right_norm=False,
                                          privacy=True,
                                          sgxutils=sgxutils,
                                          epsilon=0.0,
                                          use_bias=True,
                                          bias_shape=c)

            out_sgx_tensor  = sgx_layer(in_sgx_tensor0, right=in_sgx_tensor1, training=True)
            diff_sgx_tensor = tf.placeholder(tf.float32, shape=output_size)
            loss_sgx_fn     = MSE(diff_sgx_tensor, out_sgx_tensor)

            gradients_sgx0  = K.gradients(loss_sgx_fn, in_sgx_tensor0)
            gradients_sgx1  = K.gradients(loss_sgx_fn, in_sgx_tensor1)

            # Share parameters so both layers should compute the same function.
            tar_layer.set_weights(sgx_layer.get_weights())
            # tf.random_uniform re-samples on every sess.run, so the two
            # inputs are independent draws.
            input_np0 = sess.run(input_tensor)
            input_np1 = sess.run(input_tensor)
            out_np   = sess.run(out_tensor)

            # NOTE(review): only the left-input gradients are fetched;
            # gradients_*1 are built but never evaluated. Fetching them may
            # invoke the SGX backward op a second time per run -- confirm
            # before enabling.
            sgx_res    = sess.run([gradients_sgx0, out_sgx_tensor], feed_dict={in_sgx_tensor0:  input_np0,
                                                                  in_sgx_tensor1:  input_np1,
                                                                  diff_sgx_tensor: out_np
                                                                  })

            tar_res    = sess.run([gradients_tar0, out_exp_tensor], feed_dict={in_exp_tensor0:  input_np0,
                                                                  in_exp_tensor1:  input_np1,
                                                                  diff_exp_tensor: out_np
                                                                  })
            # K.gradients returns a list with one entry per requested input.
            grad_sgx0_np = sgx_res[0][0]
            grad_tar0_np = tar_res[0][0]
            y_sgx        = sgx_res[1]
            y_exp        = tar_res[1]

            # Forward pass: mean |difference| vs. mean |reference|.
            y_sgx_arr = np.asarray(y_sgx)
            y_exp_arr = np.asarray(y_exp)
            print(abs(y_sgx_arr - y_exp_arr).mean(), abs(y_exp_arr).mean())
            # Left-input gradient: one line per batch element (all of them,
            # not just the first two).
            for b in range(batch_size):
                print(abs(grad_sgx0_np[b] - grad_tar0_np[b]).mean(), abs(grad_tar0_np[b]).mean())
            # Spot-check a single element of the output and of the gradient.
            print(y_sgx[0][0][0][0], y_exp[0][0][0][0])
            print(grad_sgx0_np[0][0][0][0], grad_tar0_np[0][0][0][0])

if __name__ == '__main__':
    # No command-line options are defined yet; parsing still rejects
    # unexpected arguments and provides --help.
    import argparse
    cli_args = argparse.ArgumentParser().parse_args()
    main(cli_args)
