import tensorflow as tf
import numpy as np

from python.slalom.mobtest import MobileNetV2 as MobileNetV2_sgx
from python.slalom.global_sgx import sgxutils

import keras.backend as K
import keras.layers as layers
import time

import torch
import torch.nn as nn



def _report_diff(label, expected, actual):
    """Print error statistics between a reference tensor and a candidate.

    A signed sum of differences alone lets positive and negative errors
    cancel, so the maximum absolute error and that error normalised by the
    largest reference magnitude are printed as well. With inputs drawn from
    [-500, 500] the accumulated values are large, so the normalised figure is
    the more meaningful one for float32 results.

    Args:
        label: Short name printed in front of each line.
        expected: Reference tensor (e.g. the PyTorch result).
        actual: Tensor under test; must be broadcast-compatible with
            ``expected``.
    """
    diff = expected - actual
    max_abs = diff.abs().max().item()
    # Clamp so an all-zero reference does not divide by zero.
    scale = expected.abs().max().clamp(min=1e-12).item()
    print("[%s] signed sum:" % label, diff.sum().item())
    print("[%s] max abs err:" % label, max_abs,
          " max abs err / max |ref|:", max_abs / scale)
    print("[%s] means (ref, test):" % label,
          expected.mean().item(), actual.mean().item())


def main(args):
    """Check the SGX depthwise convolution against a PyTorch reference.

    Random input, upstream-gradient and kernel tensors are generated through
    a TF1 session. They are fed to a single "depthwise" enclave layer through
    ``sgxutils`` (forward, then backward to obtain the kernel gradient). The
    same data goes through an equivalent grouped ``nn.Conv2d``. Error
    statistics for the forward output and the kernel gradient are printed.

    Args:
        args: Parsed command-line namespace; currently unused.
    """
    batch, height, width, channels = 1, 112, 112, 32
    ksize = 3
    nhwc_shape = (batch, height, width, channels)
    # (kH, kW, channels, 1): one filter per input channel.
    kernel_shape = (ksize, ksize, channels, 1)

    config = tf.ConfigProto(log_device_placement=False)
    config.allow_soft_placement = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.90
    config.gpu_options.allow_growth = True
    # The returned values are not used by this script.
    sgxutils.fill_parameter(internal_batch_size=3)
    sgxutils.dnnl_init()
    with tf.Session(config=config) as sess:
        # One run draws input, upstream gradient and kernel together.
        input_np, grad_out_np, kernel_np = sess.run([
            tf.random_uniform(shape, minval=-500.0, maxval=500.0,
                              dtype=tf.float32)
            for shape in (nhwc_shape, nhwc_shape, kernel_shape)
        ])

        # NOTE(review): the enclave layer is given padding (0, 0) but a
        # 112x112 output shape, while the PyTorch reference below uses
        # padding (1, 1). This presumably means sgxutils applies "same"
        # padding internally -- confirm against sgxutils.setup_layer.
        enclave_list = [["depthwise", nhwc_shape, nhwc_shape,
                         (ksize, ksize), (1, 1), (0, 0), kernel_np, None]]

        sgxutils.setup_layer(enclave_list)
        sgxutils.setup_final_reorder()
        sgxutils.enclave_update_backward()
        # Keep forward before backward, as originally written; the backward
        # call may rely on state left by forward -- TODO confirm.
        res_sgx = sgxutils.forward(input_np, (height, width, channels))
        res_sgx = torch.tensor(res_sgx).reshape(nhwc_shape).permute(0, 3, 1, 2)
        grad_sgx = sgxutils.backward(grad_out_np, kernel_shape)
        grad_sgx = torch.tensor(grad_sgx).reshape(kernel_shape).permute(2, 3, 0, 1)

        # PyTorch wants NCHW activations and (out, in/groups, kH, kW) weights.
        weight_torch = torch.tensor(kernel_np).permute(2, 3, 0, 1)
        input_torch = torch.tensor(input_np).permute(0, 3, 1, 2)
        grad_out_torch = torch.tensor(grad_out_np).permute(0, 3, 1, 2)

        dep_layer = nn.Conv2d(in_channels=channels,
                              out_channels=channels,
                              kernel_size=ksize,
                              stride=(1, 1),
                              padding=(ksize // 2, ksize // 2),
                              groups=channels,
                              bias=False)
        print(dep_layer.weight.size(), weight_torch.size())
        dep_layer.weight = nn.Parameter(weight_torch)
        res_exp = dep_layer(input_torch)
        res_exp.backward(gradient=grad_out_torch)

        # The forward output was previously computed but never checked.
        _report_diff("forward", res_exp.detach(), res_sgx)

        weight_grad = dep_layer.weight.grad
        _report_diff("weight grad", weight_grad, grad_sgx)
        print(weight_grad[0][0][0][0], grad_sgx[0][0][0][0])
        print(weight_grad[2][0][0][0], grad_sgx[2][0][0][0])

        diff = weight_grad - grad_sgx
        for ch, ch_diff in enumerate(diff[:, 0]):
            print(ch, ch_diff)
if __name__ == '__main__':
    # Script entry point. No options are defined yet, but parsing still
    # handles --help and rejects unknown flags before the check starts.
    import argparse

    cli_args = argparse.ArgumentParser().parse_args()
    main(cli_args)