from ctypes import *
from ctypes import POINTER
import json

import numpy as np
import time

import tensorflow as tf
import sys

from python.slalom.parameters import get_matrix, identity

# Native libraries loaded by SGXDNNUtils (paths are relative to the working
# directory).
# Enclave bridge used when use_sgx=True; otherwise the untrusted DNN library.
SGXDNNLIB: str = "App/enclave_bridge.so"
DNNLIB: str = "lib/sgxdnn.so"

# TensorFlow custom-op libraries, loaded with tf.load_op_library: the SGX
# variant when use_sgx=True, the plain variant otherwise.
SGX_SLALOM_LIB: str = "lib/slalom_ops_sgx.so"
SLALOM_LIB: str = "lib/slalom_ops.so"


# interface with the C++ SGXDNN library
class SGXDNNUtils(object):
    """Python wrapper around the C++ SGXDNN library and its TensorFlow custom ops.

    In SGX mode (``use_sgx=True``) the wrapper loads ``SGXDNNLIB``, creates
    ``num_enclaves`` enclaves and passes the selected enclave id
    (``self.eid[eid_idx]``) as the first argument of most native calls.
    Otherwise it loads the untrusted ``DNNLIB`` and ``self.eid`` is ``None``,
    so the methods that index ``self.eid`` cannot be used in that mode.

    Shapes are marshalled as int32 arrays and weights/activations as float32
    arrays via ``np.ctypeslib.as_ctypes``.  ``argtypes`` is set on each native
    function immediately before it is called.
    """

    def __init__(self, use_sgx=False, num_enclaves=1):
        """Load the native libraries and, in SGX mode, create the enclaves.

        Args:
            use_sgx: load the enclave bridge (``SGXDNNLIB``) and the SGX op
                library instead of ``DNNLIB`` / ``SLALOM_LIB``.
            num_enclaves: number of enclaves to initialise (SGX mode only).
        """
        self.use_sgx = use_sgx
        # Cleared by setup_layer() after the first conv2d/depthwise entry.
        self.is_first = 1
        # Batch size used by resnet_setup_activation/resnet_setup_bottom.
        # Initialised in both modes: previously only the non-SGX branch set it,
        # so SGX-mode callers got AttributeError unless they had called
        # set_dark_batch_size() first.
        self.dark_batch_size = 0
        if use_sgx:
            self.lib = cdll.LoadLibrary(SGXDNNLIB)
            # Enclave ids are returned as unsigned longs.
            self.lib.initialize_enclave.restype = c_ulong
            self.eid = [self.lib.initialize_enclave() for _ in range(num_enclaves)]
            self.slalom_lib = tf.load_op_library(SGX_SLALOM_LIB)
        else:
            self.lib = cdll.LoadLibrary(DNNLIB)
            self.eid = None
            self.slalom_lib = tf.load_op_library(SLALOM_LIB)

    def set_dark_batch_size(self, dark_batch_size):
        """Set the batch size passed to the resnet_setup_* native calls."""
        self.dark_batch_size = dark_batch_size

    def destroy(self):
        """Destroy every enclave (SGX mode only); subsequent calls are no-ops."""
        if self.use_sgx and self.eid is not None:
            self.lib.destroy_enclave.argtypes = [c_ulong]
            for eid in self.eid:
                self.lib.destroy_enclave(eid)
            self.eid = None

    def benchmark(self, num_threads):
        """Run the native ``sgxdnn_benchmarks`` (in enclave 0 when using SGX)."""
        if self.use_sgx:
            self.lib.sgxdnn_benchmarks.argtypes = [c_ulong, c_int]
            self.lib.sgxdnn_benchmarks(self.eid[0], num_threads)
        else:
            self.lib.sgxdnn_benchmarks(num_threads)

    # ------------------------------------------------------------------ #
    # Private marshalling helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _int32_ptr(values):
        """Convert ``values`` to an int32 numpy array and return it as a ctypes array."""
        return np.ctypeslib.as_ctypes(np.array(values).astype(np.int32))

    @staticmethod
    def _shape4_ptr(shape, batch=1):
        """Return ``(batch, shape[1], shape[2], shape[3])`` as an int32 ctypes array.

        The incoming ``shape[0]`` (batch dimension) is ignored.
        """
        return SGXDNNUtils._int32_ptr((batch, shape[1], shape[2], shape[3]))

    def _split_eid(self, eid_idx):
        """Return ``(eid_low, eid_high)`` for ``self.eid[eid_idx]``.

        These are the value modulo 2**32 and the value floor-divided by 2**32.
        The TF custom ops receive the enclave id as these two attributes.
        """
        eid = self.eid[eid_idx]
        return eid % 2**32, eid // 2**32

    def _call_with_eid(self, func_name, eid_idx):
        """Call native function ``func_name`` whose only argument is the enclave id."""
        cfunc = getattr(self.lib, func_name)
        cfunc.argtypes = [c_ulong]
        cfunc(self.eid[eid_idx])

    def _setup_conv_impl(self, cfunc, in_size, out_size, kernel_size,
                         strides, padding, kernel_weight, bias_weight, is_first, eid_idx):
        """Shared marshalling for setup_conv and setup_depth_conv.

        The kernel shape passed to C is
        ``(kernel_size[0], kernel_size[1], in_size[3], out_size[3])``.
        A missing bias is replaced by float32 zeros of length ``out_size[3]``.
        """
        print(in_size, out_size, kernel_size, strides, padding, kernel_weight.shape, bias_weight, is_first)
        cfunc.argtypes = [c_ulong, POINTER(c_int), POINTER(c_int),
                          POINTER(c_int), POINTER(c_int),
                          POINTER(c_int), POINTER(c_float),
                          POINTER(c_float), c_int]
        in_size_np = self._shape4_ptr(in_size)
        out_size_np = self._shape4_ptr(out_size)
        strides_np = self._int32_ptr((strides[0], strides[1]))
        padding_np = self._int32_ptr((padding[0], padding[1]))
        kernel_size_np = self._int32_ptr((kernel_size[0], kernel_size[1], in_size[3], out_size[3]))
        kernel_weight_np = np.ctypeslib.as_ctypes(kernel_weight.reshape(-1))

        is_first_c = 1 if is_first else 0
        if bias_weight is None:
            bias_weight = np.zeros(out_size[3]).astype(np.float32)
        bias_weight_np = np.ctypeslib.as_ctypes(bias_weight.reshape(-1))

        cfunc(self.eid[eid_idx], in_size_np, out_size_np, kernel_size_np, strides_np,
              padding_np, kernel_weight_np, bias_weight_np, is_first_c)

    def _run_in_out(self, func_name, x, out_size, eid_idx):
        """Call native ``func_name(eid, in, out)`` and return ``out`` with shape ``out_size``."""
        dtype = np.float32
        # Keep the flattened copy bound to a local while C reads from it.
        x_typed = x.reshape(-1).astype(dtype)
        inp_ptr = np.ctypeslib.as_ctypes(x_typed)

        res = np.zeros(out_size, dtype=dtype)
        # reshape(-1) of a freshly allocated array is a view, so C writes into res.
        res_ptr = np.ctypeslib.as_ctypes(res.reshape(-1))

        cfunc = getattr(self.lib, func_name)
        cfunc.argtypes = [c_ulong, POINTER(c_float), POINTER(c_float)]
        cfunc(self.eid[eid_idx], inp_ptr, res_ptr)
        return res.reshape(out_size)

    # ------------------------------------------------------------------ #
    # Model loading
    # ------------------------------------------------------------------ #

    def load_model(self, model_json, weights, dtype=np.float32, verify=False, verify_preproc=False):
        """Send a JSON model description and its float32 weights to the native library.

        Args:
            model_json: JSON-serialisable model description.
            weights: sequence of numpy arrays, all of dtype ``dtype``.
            dtype: must be ``np.float32``.
            verify: use ``load_model_float_verify`` (which also receives
                ``verify_preproc``) instead of ``load_model_float``.
            verify_preproc: flag forwarded to the verifying loader.

        In SGX mode the model is loaded into every enclave.
        """
        assert np.all([w.dtype == dtype for w in weights])

        assert dtype == np.float32
        print("loading model in float32")
        ptr_type = c_float
        load_method = self.lib.load_model_float_verify if verify else self.lib.load_model_float

        # Array of pointers, one per weight tensor.
        filter_ptrs = (POINTER(ptr_type) * len(weights))()
        for i, w in enumerate(weights):
            filter_ptrs[i] = np.ctypeslib.as_ctypes(w)

        print("loading model...")
        # Serialise once; every enclave receives the same bytes.
        model_bytes = json.dumps(model_json).encode('utf-8')
        if self.use_sgx:
            if verify:
                load_method.argtypes = [c_ulong, c_char_p, POINTER(POINTER(ptr_type)), c_bool]
                for eid in self.eid:
                    load_method(eid, model_bytes, filter_ptrs, verify_preproc)
            else:
                load_method.argtypes = [c_ulong, c_char_p, POINTER(POINTER(ptr_type))]
                for eid in self.eid:
                    load_method(eid, model_bytes, filter_ptrs)
        else:
            if verify:
                load_method.argtypes = [c_char_p, POINTER(POINTER(ptr_type)), c_bool]
                load_method(model_bytes, filter_ptrs, verify_preproc)
            else:
                load_method.argtypes = [c_char_p, POINTER(POINTER(ptr_type))]
                load_method(model_bytes, filter_ptrs)
        print("model loaded")

    # ------------------------------------------------------------------ #
    # Per-layer setup (enclave side)
    # ------------------------------------------------------------------ #

    def setup_conv(self, in_size, out_size, kernel_size,
                   strides, padding, kernel_weight, bias_weight, is_first, eid_idx=0):
        """Create a convolution layer in enclave ``eid_idx`` via ``sgx_conv_create``."""
        self._setup_conv_impl(self.lib.sgx_conv_create, in_size, out_size, kernel_size,
                              strides, padding, kernel_weight, bias_weight, is_first, eid_idx)

    def setup_depth_conv(self, in_size, out_size, kernel_size,
                         strides, padding, kernel_weight, bias_weight, is_first, eid_idx=0):
        """Create a depthwise convolution in enclave ``eid_idx`` via ``sgx_depth_conv_create``."""
        self._setup_conv_impl(self.lib.sgx_depth_conv_create, in_size, out_size, kernel_size,
                              strides, padding, kernel_weight, bias_weight, is_first, eid_idx)

    def setup_bn(self, in_size, mode, eps, momentum, eid_idx=0):
        """Create a batch-norm layer via ``sgx_bn_create`` (``mode`` is passed as a C int)."""
        in_size_np = self._shape4_ptr(in_size)

        setup_sgx_bn = self.lib.sgx_bn_create

        print(in_size, mode)
        setup_sgx_bn.argtypes = [c_ulong, POINTER(c_int), c_int, c_float, c_float]

        setup_sgx_bn(self.eid[eid_idx], in_size_np, mode, eps, momentum)

    def setup_linear(self, in_size, out_size, kernel_data, bias_data, eid_idx=0):
        """Create a fully connected layer via ``sgx_linear_create``.

        The input shape is sent as 4 ints and the output shape as ``(1, out_size[1])``.
        """
        in_size_np = self._shape4_ptr(in_size)
        out_size_np = self._int32_ptr((1, out_size[1]))
        kernel_weight_np = np.ctypeslib.as_ctypes(kernel_data.reshape(-1))
        bias_weight_np = np.ctypeslib.as_ctypes(bias_data.reshape(-1))

        setup_sgx_linear = self.lib.sgx_linear_create
        setup_sgx_linear.argtypes = [c_ulong, POINTER(c_int), POINTER(c_int), POINTER(c_float), POINTER(c_float)]
        setup_sgx_linear(self.eid[eid_idx], in_size_np, out_size_np, kernel_weight_np, bias_weight_np)

    def setup_pool(self, in_size, out_size, kernel_size, stride, padding, type1, eid_idx=0):
        """Create a pooling layer via ``sgx_pool_create``; ``type1`` is passed as a C int."""
        in_size_np = self._shape4_ptr(in_size)
        out_size_np = self._shape4_ptr(out_size)
        ker_size_np = self._int32_ptr((kernel_size[0], kernel_size[1]))
        stride_size_np = self._int32_ptr((stride[0], stride[1]))
        padding_size_np = self._int32_ptr((padding[0], padding[1]))

        setup_sgx_pool = self.lib.sgx_pool_create
        setup_sgx_pool.argtypes = [c_ulong, POINTER(c_int), POINTER(c_int), POINTER(c_int), POINTER(c_int), POINTER(c_int), c_int]
        setup_sgx_pool(self.eid[eid_idx], in_size_np, out_size_np, ker_size_np, stride_size_np, padding_size_np, type1)

    def setup_sgx_relu(self, in_size, eid_idx=0):
        """Create a ReLU layer via ``sgx_relu_create``."""
        in_size_np = self._shape4_ptr(in_size)
        setup_sgx_relu = self.lib.sgx_relu_create

        setup_sgx_relu.argtypes = [c_ulong, POINTER(c_int)]

        setup_sgx_relu(self.eid[eid_idx], in_size_np)

    def setup_resblock_init(self, in_size, out_size, stride, identity, eid_idx=0):
        """Start a residual block via ``resblock_init``; ``identity`` is sent as 0/1."""
        in_size_np = self._shape4_ptr(in_size)
        out_size_np = self._shape4_ptr(out_size)
        stride_size_np = self._int32_ptr((stride[0], stride[1]))

        identity_int = 1 if identity else 0
        setup_sgx_resblock = self.lib.resblock_init

        setup_sgx_resblock.argtypes = [c_ulong, POINTER(c_int), POINTER(c_int), POINTER(c_int), c_int]
        setup_sgx_resblock(self.eid[eid_idx], in_size_np, out_size_np, stride_size_np, identity_int)

    def setup_resblock_compl(self, eid_idx=0):
        """Close the current residual block via ``resblock_compl``."""
        self._call_with_eid("resblock_compl", eid_idx)

    def sgx_print_timing(self, eid_idx=0):
        """Call the native ``print_time_report`` for enclave ``eid_idx``."""
        self._call_with_eid("print_time_report", eid_idx)

    def sgx_reset_timing(self, eid_idx=0):
        """Call the native ``reset_timing`` for enclave ``eid_idx``."""
        self._call_with_eid("reset_timing", eid_idx)

    def setup_layer(self, arg, eid_idx=0):
        """Register a list of layer descriptors with enclave ``eid_idx``.

        Each entry of ``arg`` is a sequence whose first element is the layer
        kind.  The remaining elements are positional fields:

        * ``conv2d`` / ``depthwise``: in_size, out_size, kernel, strides,
          padding, kernel_weight, bias_weight
        * ``bn``: in_size, mode, eps, momentum
        * ``linear``: in_size, out_size, kernel_data, bias_data
        * ``pool``: in_size, out_size, kernel_size, stride, padding, type1
        * ``relu``: in_size
        * ``resblock_start``: in_size, out_size, stride, identity
        * ``resblock_compl`` / ``inv_start`` / ``inv_end``: no fields

        Unknown kinds are ignored.  Every kind is now set up in the enclave
        selected by ``eid_idx``.  Previously bn/linear/pool/relu/resblock
        entries were always sent to enclave 0.
        """
        for c, line in enumerate(arg):
            name = line[0]
            print("==================================")
            print(name, c)
            if name in ("conv2d", "depthwise"):
                setup = self.setup_conv if name == "conv2d" else self.setup_depth_conv
                setup(in_size=line[1],
                      out_size=line[2],
                      kernel_size=line[3],
                      strides=line[4],
                      padding=line[5],
                      kernel_weight=line[6],
                      bias_weight=line[7],
                      is_first=self.is_first,
                      eid_idx=eid_idx)
                # Only the first convolution of the network is flagged.
                self.is_first = False
            elif name == "bn":
                self.setup_bn(line[1], line[2], line[3], line[4], eid_idx=eid_idx)
            elif name == "inv_start":
                self.inverted_init(eid_idx=eid_idx)
            elif name == "inv_end":
                self.inverted_compl(eid_idx=eid_idx)
            elif name == "linear":
                self.setup_linear(line[1], line[2], line[3], line[4], eid_idx=eid_idx)
            elif name == "pool":
                in_size = line[1]
                out_size = line[2]
                kernel_size = line[3]
                stride = line[4]
                padding = line[5]
                type1 = line[6]
                print(in_size, out_size, kernel_size, stride, padding)
                self.setup_pool(in_size, out_size, kernel_size, stride, padding, type1, eid_idx=eid_idx)
            elif name == "relu":
                self.setup_sgx_relu(line[1], eid_idx=eid_idx)
            elif name == "resblock_start":
                self.setup_resblock_init(line[1], line[2], line[3], line[4], eid_idx=eid_idx)
            elif name == "resblock_compl":
                self.setup_resblock_compl(eid_idx=eid_idx)

    # ------------------------------------------------------------------ #
    # Execution
    # ------------------------------------------------------------------ #

    def forward(self, x, out_size, eid_idx=0):
        """Run the native ``forward`` on ``x`` (cast to float32); return an ``out_size`` array."""
        return self._run_in_out("forward", x, out_size, eid_idx)

    def backward(self, x, out_size, eid_idx=0):
        """Run the native ``enclave_backward`` on ``x``; return an ``out_size`` array."""
        return self._run_in_out("enclave_backward", x, out_size, eid_idx)

    def setup_final_reorder(self, eid_idx=0):
        """Call the native ``setup_final_reorder`` for enclave ``eid_idx``."""
        self._call_with_eid("setup_final_reorder", eid_idx)

    def inverted_compl(self, eid_idx=0):
        """Call the native ``inverted_compl`` (emitted for ``inv_end`` entries)."""
        self._call_with_eid("inverted_compl", eid_idx)

    def inverted_init(self, eid_idx=0):
        """Call the native ``inverted_init`` (emitted for ``inv_start`` entries)."""
        self._call_with_eid("inverted_init", eid_idx)

    def enclave_update_backward(self, eid_idx=0):
        """Call the native ``enclave_update_backward`` for enclave ``eid_idx``."""
        self._call_with_eid("enclave_update_backward", eid_idx)

    def predict(self, x, num_classes=1000, is_intermediate=False, num_inter_feats=(0, 0, 0), eid_idx=0):
        """Run ``predict_float`` on a batch ``x`` and return float32 outputs.

        Args:
            x: input batch; ``x.shape[0]`` is passed to C as the batch size.
            num_classes: width of the output when ``is_intermediate`` is False.
            is_intermediate: when True the output has shape
                ``(len(x), *num_inter_feats[:3])`` instead of
                ``(len(x), num_classes)``.
            num_inter_feats: 3-element feature shape.  The default is now an
                immutable tuple rather than a shared list.
        """
        dtype = np.float32
        x_typed = x.reshape(-1).astype(dtype)
        inp_ptr = np.ctypeslib.as_ctypes(x_typed)

        if is_intermediate:
            res = np.zeros((len(x), num_inter_feats[0], num_inter_feats[1], num_inter_feats[2]), dtype=dtype)
        else:
            res = np.zeros((len(x), num_classes), dtype=dtype)
        res_ptr = np.ctypeslib.as_ctypes(res.reshape(-1))

        ptr_type = c_float
        predict_method = self.lib.predict_float

        if self.use_sgx:
            predict_method.argtypes = [c_ulong, POINTER(ptr_type), POINTER(ptr_type), c_int]
            predict_method(self.eid[eid_idx], inp_ptr, res_ptr, x.shape[0])
        else:
            predict_method.argtypes = [POINTER(ptr_type), POINTER(ptr_type), c_int]
            predict_method(inp_ptr, res_ptr, x.shape[0])

        return res

    def print_time_report(self, eid_idx=0):
        """Call the native ``print_time_report`` for enclave ``eid_idx``."""
        self._call_with_eid("print_time_report", eid_idx)

    def predict_and_verify(self, x, aux_data, num_classes=1000, dtype=np.float32, eid_idx=0):
        """Run ``predict_verify_float`` with auxiliary float32 buffers.

        Args:
            x: input batch; ``x.shape[0]`` is passed to C as the batch size.
            aux_data: sequence of float32 numpy arrays passed as an array of
                pointers.
            num_classes: output width.
            dtype: must be ``np.float32``.  The default used to be
                ``np.float64``, which made the assertion below fail whenever
                the default was used.

        Returns:
            A float32 array of shape ``(len(x), num_classes)``.
        """
        assert dtype == np.float32
        ptr_type = c_float
        predict_method = self.lib.predict_verify_float
        ptr_type_aux = c_float

        x_typed = x.reshape(-1).astype(dtype)
        inp_ptr = np.ctypeslib.as_ctypes(x_typed)

        aux_ptrs = (POINTER(ptr_type_aux) * len(aux_data))()
        for i, aux in enumerate(aux_data):
            aux_ptrs[i] = np.ctypeslib.as_ctypes(aux)

        res = np.zeros((len(x), num_classes), dtype=dtype)
        res_ptr = np.ctypeslib.as_ctypes(res.reshape(-1))

        if self.use_sgx:
            predict_method.argtypes = [c_ulong, POINTER(ptr_type), POINTER(ptr_type),
                                       POINTER(POINTER(ptr_type_aux)), c_int]
            predict_method(self.eid[eid_idx], inp_ptr, res_ptr, aux_ptrs, x.shape[0])
        else:
            predict_method.argtypes = [POINTER(ptr_type), POINTER(ptr_type),
                                       POINTER(POINTER(ptr_type_aux)), c_int]
            predict_method(inp_ptr, res_ptr, aux_ptrs, x.shape[0])
        return res.astype(np.float32)

    # ------------------------------------------------------------------ #
    # Slalom TF custom ops and blinding helpers
    # ------------------------------------------------------------------ #

    def relu_slalom(self, inputs, blind, activation, eid_idx=0):
        """Apply the ``relu_slalom`` TF op (with the enclave id in SGX mode)."""
        if self.use_sgx:
            eid_low, eid_high = self._split_eid(eid_idx)
            return self.slalom_lib.relu_slalom(inputs, blind, activation=activation, eid_low=eid_low, eid_high=eid_high)
        return self.slalom_lib.relu_slalom(inputs, blind, activation=activation)

    def maxpoolrelu_slalom(self, inputs, bias, params, eid_idx=0):
        """Apply the ``relu_max_pool_slalom`` TF op.

        ``params`` must provide ``pool_size`` and ``strides`` (2-element
        sequences) and ``padding`` (a string, upper-cased before use).  The
        non-SGX branch previously passed the undefined name ``blind`` and
        raised NameError; it now passes ``bias`` like the SGX branch.
        """
        ksize = (1, params['pool_size'][0], params['pool_size'][1], 1)
        strides = (1, params['strides'][0], params['strides'][1], 1)
        padding = params['padding'].upper()

        if self.use_sgx:
            eid_low, eid_high = self._split_eid(eid_idx)
            return self.slalom_lib.relu_max_pool_slalom(inputs, bias, ksize, strides, padding, eid_low=eid_low, eid_high=eid_high)
        return self.slalom_lib.relu_max_pool_slalom(inputs, bias, ksize, strides, padding)

    def slalom_init(self, slalom_integrity, slalom_privacy, batch_size, eid_idx=0):
        """Call the native ``slalom_init`` with the integrity/privacy flags and batch size."""
        if self.use_sgx:
            self.lib.slalom_init.argtypes = [c_ulong, c_bool, c_bool, c_int]
            self.lib.slalom_init(self.eid[eid_idx], slalom_integrity, slalom_privacy, batch_size)
        else:
            self.lib.slalom_init(slalom_integrity, slalom_privacy, batch_size)

    def slalom_get_r(self, r, eid_idx=0):
        """Pass the float32 buffer ``r`` (flattened) and its size to ``slalom_get_r``.

        NOTE(review): ``r.reshape(-1)`` is a view only for contiguous ``r``.
        Presumably the native side fills the buffer, so ``r`` should be C-contiguous.
        """
        r_flat = r.reshape(-1)
        inp_ptr = np.ctypeslib.as_ctypes(r_flat)

        if self.use_sgx:
            self.lib.slalom_get_r.argtypes = [c_ulong, POINTER(c_float), c_int]
            self.lib.slalom_get_r(self.eid[eid_idx], inp_ptr, r.size)
        else:
            self.lib.slalom_get_r.argtypes = [POINTER(c_float), c_int]
            self.lib.slalom_get_r(inp_ptr, r.size)

    def slalom_set_z(self, z, z_enc, eid_idx=0):
        """Pass ``z`` and ``z_enc`` (flattened float32 buffers) and ``z.size`` to ``slalom_set_z``."""
        z_flat = z.reshape(-1)
        inp_ptr = np.ctypeslib.as_ctypes(z_flat)

        z_enc_flat = z_enc.reshape(-1)
        out_ptr = np.ctypeslib.as_ctypes(z_enc_flat)

        if self.use_sgx:
            self.lib.slalom_set_z.argtypes = [c_ulong, POINTER(c_float), POINTER(c_float), c_int]
            self.lib.slalom_set_z(self.eid[eid_idx], inp_ptr, out_ptr, z.size)
        else:
            self.lib.slalom_set_z.argtypes = [POINTER(c_float), POINTER(c_float), c_int]
            self.lib.slalom_set_z(inp_ptr, out_ptr, z.size)

    def align_numpy(self, unaligned):
        """Return a copy of ``unaligned`` whose buffer was allocated by a TF session run.

        NOTE(review): presumably done to obtain TF's buffer alignment; confirm.
        Requires a default TF session.
        """
        sess = tf.get_default_session()
        aligned = sess.run(tf.ones(unaligned.shape, dtype=unaligned.dtype))
        np.copyto(aligned, unaligned)
        return aligned

    def slalom_blind_input(self, x, eid_idx=0):
        """Blind ``x`` in place on an aligned copy via ``slalom_blind_input``.

        The same buffer is used as both input and output.  The elapsed time is printed.
        """
        res = self.align_numpy(x)
        res_flat = res.reshape(-1)
        inp_ptr = np.ctypeslib.as_ctypes(res_flat)
        out_ptr = np.ctypeslib.as_ctypes(res_flat)
        s = time.time()

        if self.use_sgx:
            self.lib.slalom_blind_input.argtypes = [c_ulong, POINTER(c_float), POINTER(c_float), c_int]
            self.lib.slalom_blind_input(self.eid[eid_idx], inp_ptr, out_ptr, x.size)
        else:
            self.lib.slalom_blind_input.argtypes = [POINTER(c_float), POINTER(c_float), c_int]
            self.lib.slalom_blind_input(inp_ptr, out_ptr, x.size)
        print(time.time() - s)
        return res

    def fill_parameter(self, eid_idx=0, internal_batch_size=3):
        """Send the four matrices from ``identity()`` to ``fill_parameter_matrix``.

        Returns:
            The ``(bm, um, gm, igm)`` matrices that were sent.
        """
        bm, um, gm, igm = identity()
        print(bm)
        print(um)
        print(gm)
        print(igm)
        bm_ptr = np.ctypeslib.as_ctypes(bm.reshape(-1))
        um_ptr = np.ctypeslib.as_ctypes(um.reshape(-1))
        gm_ptr = np.ctypeslib.as_ctypes(gm.reshape(-1))
        igm_ptr = np.ctypeslib.as_ctypes(igm.reshape(-1))

        self.lib.fill_parameter_matrix.argtypes = [c_ulong, POINTER(c_float), POINTER(c_float),
                                                   POINTER(c_float), POINTER(c_float), c_int]
        self.lib.fill_parameter_matrix(self.eid[eid_idx], bm_ptr, um_ptr, gm_ptr, igm_ptr, internal_batch_size)

        return bm, um, gm, igm

    def dnnl_init(self, eid_idx=0, train_inside=0, internal_batch=2):
        """Call the native ``dnnl_init(eid, train_inside, internal_batch)``."""
        self.lib.dnnl_init.argtypes = [c_ulong, c_int, c_int]
        self.lib.dnnl_init(self.eid[eid_idx], train_inside, internal_batch)

    def setup_relu(self, in_shape, eid_idx=0):
        """Call the native ``setup_relu`` with ``in_shape`` (assumed int32 numpy array)."""
        in_shape = in_shape.reshape(-1)
        in_shape_ptr = np.ctypeslib.as_ctypes(in_shape)
        self.lib.setup_relu.argtypes = [c_ulong, POINTER(c_int)]
        self.lib.setup_relu(self.eid[eid_idx], in_shape_ptr)

    def slalom_relu_back(self, grad, relu_src, eid_idx=0):
        """Apply the ``relu_back`` TF op (prints both inputs first)."""
        eid_low, eid_high = self._split_eid(eid_idx)
        print(grad)
        print(relu_src)

        return self.slalom_lib.relu_back(grad, relu_src, eid_low=eid_low, eid_high=eid_high)

    def setup_batchnormsp_dark(self, in_shape, privacy, eps, momentum, eid_idx=0):
        """Call the native ``setup_batchnormsp`` with batch dim forced to 1.

        ``privacy`` is sent as 1 only when it is exactly ``True`` (identity check).
        """
        in_shape = list(in_shape)
        in_shape[0] = 1
        in_shape_np = np.array(in_shape).astype(np.int32)
        in_shape_ptr = np.ctypeslib.as_ctypes(in_shape_np)

        self.lib.setup_batchnormsp.argtypes = [c_ulong, POINTER(c_int), c_int, c_float, c_float]
        privacy_int = 1 if privacy is True else 0

        self.lib.setup_batchnormsp(self.eid[eid_idx], in_shape_ptr, privacy_int, eps, momentum)

    def setup_maxpoolrelu(self, in_size, out_size, ker_size, strides, padding, eid_idx=0):
        """Call the native ``setup_maxpoolrelu`` and return the 1-element ``work_size`` array.

        All shape arguments are assumed to be int32 numpy arrays (TODO confirm).
        The native call is expected to write into ``work_size``.
        """
        in_shape = in_size.reshape(-1)
        in_shape_ptr = np.ctypeslib.as_ctypes(in_shape)
        out_shape = out_size.reshape(-1)
        out_shape_ptr = np.ctypeslib.as_ctypes(out_shape)
        ker_ptr = np.ctypeslib.as_ctypes(ker_size)
        stride_ptr = np.ctypeslib.as_ctypes(strides)
        padding_ptr = np.ctypeslib.as_ctypes(padding)
        work_size = np.array([0], dtype=np.int32)
        work_size_ptr = np.ctypeslib.as_ctypes(work_size)

        self.lib.setup_maxpoolrelu.argtypes = [c_ulong, POINTER(c_int), POINTER(c_int),
                                               POINTER(c_int), POINTER(c_int), POINTER(c_int),
                                               POINTER(c_int)]

        self.lib.setup_maxpoolrelu(self.eid[eid_idx], work_size_ptr,
                                   in_shape_ptr, out_shape_ptr, ker_ptr, stride_ptr,
                                   padding_ptr)
        return work_size

    def maxpool(self, in_shape, out_shape, work_size, eid_idx=0):
        """Call the native ``maxpool(eid, in, out, work)`` on float32 buffers.

        Returns:
            ``(work_size, out_shape.reshape(-1))``.
        """
        in_ptr = np.ctypeslib.as_ctypes(in_shape.reshape(-1))
        out_shape = out_shape.reshape(-1)
        res_ptr = np.ctypeslib.as_ctypes(out_shape)
        work_ptr = np.ctypeslib.as_ctypes(work_size.reshape(-1))

        self.lib.maxpool.argtypes = [c_ulong, POINTER(c_float), POINTER(c_float), POINTER(c_float)]
        self.lib.maxpool(self.eid[eid_idx], in_ptr, res_ptr, work_ptr)

        return work_size, out_shape

    def maxpoolrelu_back(self, grad, relu_src, workspace, eid_idx=0):
        """Apply the ``maxpool_relu_back`` TF op."""
        eid_low, eid_high = self._split_eid(eid_idx)
        return self.slalom_lib.maxpool_relu_back(grad, relu_src, workspace, eid_low=eid_low, eid_high=eid_high)

    def batchnorm_dark(self, input, means, skip_input, training, act_mode='bn', eid_idx=0):
        """Apply the ``batch_norm_dark`` TF op (``training`` is currently unused)."""
        eid_low, eid_high = self._split_eid(eid_idx)
        return self.slalom_lib.batch_norm_dark(inp=input,
                                               means=means,
                                               skip_input=skip_input,
                                               eid_low=eid_low,
                                               eid_high=eid_high,
                                               act_mode=act_mode)

    def batchnorm_dark_back(self, grad, input, skip_input, act_src, eid_idx=0):
        """Apply the ``batch_norm_dark_back`` TF op."""
        eid_low, eid_high = self._split_eid(eid_idx)
        return self.slalom_lib.batch_norm_dark_back(grad=grad,
                                                    input=input,
                                                    skip_input=skip_input,
                                                    act_src=act_src,
                                                    eid_low=eid_low,
                                                    eid_high=eid_high)

    def resnet_activation_op(self, input, means, act_mode, eid_idx=0):
        """Apply the ``resnet_activation`` TF op."""
        eid_low, eid_high = self._split_eid(eid_idx)
        return self.slalom_lib.resnet_activation(src=input,
                                                 means=means,
                                                 mode=act_mode,
                                                 eid_low=eid_low,
                                                 eid_high=eid_high)

    def resnet_activation_back_op(self, grad_out, act_mode, eid_idx=0):
        """Apply the ``resnet_activation_back`` TF op."""
        eid_low, eid_high = self._split_eid(eid_idx)
        return self.slalom_lib.resnet_activation_back(grad_out=grad_out,
                                                      act_mode=act_mode,
                                                      eid_low=eid_low,
                                                      eid_high=eid_high)

    def resnet_bottom_op(self, left_in, right_in, mean_left, mean_right, eid_idx=0):
        """Apply the ``resnet_bottom`` TF op (``act_mode`` is fixed to "apple")."""
        eid_low, eid_high = self._split_eid(eid_idx)
        return self.slalom_lib.resnet_bottom(left_in=left_in,
                                             right_in=right_in,
                                             mean_left=mean_left,
                                             mean_right=mean_right,
                                             act_mode="apple",
                                             eid_low=eid_low,
                                             eid_high=eid_high)

    def resnet_bottom_back_op(self, grad_out, eid_idx=0):
        """Apply the ``renset_bottom_back`` TF op (``act_mode`` is fixed to "apple")."""
        eid_low, eid_high = self._split_eid(eid_idx)
        # NOTE(review): "renset" looks like a typo of "resnet", but the name must
        # match the op registered in the compiled library. Verify before renaming.
        return self.slalom_lib.renset_bottom_back(grad_out=grad_out,
                                                  act_mode="apple",
                                                  eid_low=eid_low,
                                                  eid_high=eid_high)

    def resnet_setup_activation(self, mode, in_size, out_size,
                                pool_window, pool_stride,
                                eps, momentum, bias_data,
                                eid_idx=0):
        """Call the native ``resnet_setup_activation``.

        ``mode`` is mapped to an int: "bias_add"->0, "bnzerorelu"->1,
        "bnrelupool"->2, anything else->9.  Shapes are sent with the batch
        dimension replaced by ``self.dark_batch_size``.
        """
        setup_act = self.lib.resnet_setup_activation
        batch = self.dark_batch_size
        setup_act.argtypes = [c_ulong, c_int, POINTER(c_int), POINTER(c_int), POINTER(c_int), POINTER(c_int), c_float, c_float, POINTER(c_float)]
        mode_int = {"bias_add": 0, "bnzerorelu": 1, "bnrelupool": 2}.get(mode, 9)
        print(in_size, out_size)
        in_size_np = self._shape4_ptr(in_size, batch=batch)
        out_size_np = self._shape4_ptr(out_size, batch=batch)
        pool_win_np = self._int32_ptr((pool_window[0], pool_window[1]))
        pool_str_np = self._int32_ptr((pool_stride[0], pool_stride[1]))
        bias_weight_np = np.ctypeslib.as_ctypes(bias_data.reshape(-1))

        setup_act(self.eid[eid_idx], mode_int, in_size_np, out_size_np, pool_win_np, pool_str_np,
                  eps, momentum, bias_weight_np)

    def resnet_setup_bottom(self, mode, in_size, out_size,
                            eps, momentum,
                            bias_data_l, bias_data_r,
                            eid_idx=0):
        """Call the native ``resnet_setup_bottom``.

        ``mode`` "downsample" maps to 1; anything else (including "normal")
        maps to 0.  Shapes are sent with the batch dimension replaced by
        ``self.dark_batch_size``.
        """
        setup_bot = self.lib.resnet_setup_bottom
        batch = self.dark_batch_size
        setup_bot.argtypes = [c_ulong, c_int, POINTER(c_int), POINTER(c_int), c_float, c_float, POINTER(c_float), POINTER(c_float)]
        mode_int = 1 if mode == "downsample" else 0

        print(in_size, out_size)

        in_size_np = self._shape4_ptr(in_size, batch=batch)
        out_size_np = self._shape4_ptr(out_size, batch=batch)
        bias_l_np = np.ctypeslib.as_ctypes(bias_data_l.reshape(-1))
        bias_r_np = np.ctypeslib.as_ctypes(bias_data_r.reshape(-1))

        setup_bot(self.eid[eid_idx], mode_int, in_size_np, out_size_np,
                  eps, momentum, bias_l_np, bias_r_np)

    def addresstest(self, x):
        """Debug helper: flatten ``x`` in the default TF session and print its ctypes view."""
        sess = tf.get_default_session()
        aligned = sess.run(tf.layers.Flatten()(x))
        res_ptr = np.ctypeslib.as_ctypes(aligned)

        print(res_ptr)

# convert a model to JSON format and potentially precompute integrity check vectors
def model_to_json(sess, model, verif_preproc=False, slalom_privacy=False, dtype=np.float32,
                  bits_w=0, bits_x=0, partitionAtLayer = 1000):
    """Serialize a (quantized) Keras model into the JSON description and flat
    weight list consumed by the SGXDNN library.

    Layers with index < ``partitionAtLayer`` form partition 1. Only partition
    1 is exported; the partition-2 structures are returned empty so that the
    4-tuple interface stays stable for callers.

    :param sess: TF session used to evaluate weights and, when
        ``verif_preproc`` is set, to compute the integrity-check vectors.
    :param model: Keras model whose ``layers`` are serialized in order.
    :param verif_preproc: if True, Conv2DQ / DepthwiseConv2DQ layers export
        precomputed integrity-check vectors instead of their raw kernel/bias.
    :param slalom_privacy: if True, a ResNetBlock is followed by the JSON of
        its ``merge_act`` layer. Must not be combined with ``verif_preproc``
        on models containing Dense layers.
    :param dtype: dtype the exported raw weights are cast to.
    :param bits_w: ``shift_w`` in the output JSON is ``2**bits_w``.
    :param bits_x: ``shift_x`` in the output JSON is ``2**bits_x``.
    :param partitionAtLayer: index of the first layer that is NOT exported.
    :returns: ``(model_json_part1, weights_part1, model_json_part2, weights_part2)``.
    :raises NameError: for a layer type without a serialization rule.
    """
    import keras
    from keras import activations
    from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, DepthwiseConv2D, GlobalAveragePooling2D, Dropout, \
        Reshape, ZeroPadding2D, AveragePooling2D, Lambda
    from python.slalom.quant_layers import Conv2DQ, DenseQ, DepthwiseConv2DQ, ActivationQ
    from python.slalom.resnet import ResNetBlock

    def get_activation_name(layer):
        if layer.activation is not None:
            return activations.serialize(layer.activation)
        else:
            return ''

    # Modulus of the integrity-check arithmetic (all results go through fmod p),
    # bound on the random vector entries, and number of random vectors per layer.
    p = ((1 << 24) - 3)
    r_max = (1 << 19)
    reps = 2

    def append_layer_json(target, layer_json):
        # layer_to_json yields either a single dict or a list of dicts
        # (e.g. an activation fused with a max-pool).
        if isinstance(layer_json, dict):
            target.append(layer_json)
        else:
            target.extend(layer_json)

    def conv_integrity_weights(layer, kernel, bias):
        # Precompute [r_left, r_right, W*r (or r_left*W*r_right), b*r] for a Conv2DQ layer.
        # TF conv kernels are laid out (k_h, k_w, ch_in, ch_out).
        k_h, k_w, ch_in, ch_out = kernel.shape
        h, w = layer.input_shape[1], layer.input_shape[2]
        h_out, w_out = layer.output_shape[1], layer.output_shape[2]

        np.random.seed(0)
        if k_w == 1 and k_h == 1:
            # pointwise conv: only a right vector is needed
            r_left = np.array([]).astype(np.float32)
            r_right = np.ones(shape=(reps, ch_out)).astype(np.float32)
            w_r = kernel.astype(np.float64).reshape((-1, ch_out)).dot(r_right.T.astype(np.float64))
            w_r = w_r.T
            w_r = np.fmod(w_r, p).astype(np.float32)
            assert np.max(np.abs(w_r)) < 2 ** 52
            b_r = np.fmod(np.dot(bias.astype(np.float64), r_right.T.astype(np.float64)), p).astype(
                np.float32)
        else:
            r_left = np.random.randint(low=-r_max, high=r_max + 1, size=(reps, h_out * w_out)).astype(
                np.float32)
            r_right = np.random.randint(low=-r_max, high=r_max + 1, size=(reps, ch_out)).astype(np.float32)
            w_r = np.zeros((reps, h * w, ch_in)).astype(np.float32)
            b_r = np.zeros(reps).astype(np.float32)

            X = np.zeros((1, h, w, ch_in)).astype(np.float64)
            x_ph = tf.placeholder(tf.float64, shape=(None, h, w, ch_in))
            w_ph = tf.placeholder(tf.float64, shape=(k_h, k_w, ch_in, 1))
            y_ph = tf.placeholder(tf.float64, shape=(1, h_out, w_out, 1))
            z = tf.nn.conv2d(x_ph, w_ph, (1,) + layer.strides + (1,), layer.padding.upper())
            # conv(x, W*r_right) is linear in x, so its input-gradient seeded with
            # r_left is the vector w_r with <x, w_r> == r_left . conv(x, W) . r_right.
            dz = tf.gradients(z, x_ph, grad_ys=y_ph)[0]

            for i in range(reps):
                curr_r_left = r_left[i].astype(np.float64)
                curr_r_right = r_right[i].astype(np.float64)

                w_right = kernel.astype(np.float64).reshape((-1, ch_out)).dot(curr_r_right)
                assert np.max(np.abs(w_right)) < 2 ** 52

                w_r_i = sess.run(dz, feed_dict={x_ph: X, w_ph: w_right.reshape(k_h, k_w, ch_in, 1),
                                                y_ph: curr_r_left.reshape((1, h_out, w_out, 1))})
                w_r[i] = np.fmod(w_r_i, p).astype(np.float32).reshape((h * w, ch_in))
                assert np.max(np.abs(w_r[i])) < 2 ** 52
                b_r[i] = np.fmod(
                    np.sum(curr_r_left) * np.fmod(np.dot(bias.astype(np.float64), curr_r_right), p),
                    p).astype(np.float32)

        print("r_left: {}".format(r_left.astype(np.float64).sum()))
        print("r_right: {}".format(r_right.astype(np.float64).sum()))
        print("w_r: {}".format(w_r.astype(np.float64).sum()))
        print("b_r: {}".format(b_r.astype(np.float64).sum()))
        return [r_left.reshape(-1), r_right.reshape(-1), w_r.reshape(-1), b_r.reshape(-1)]

    def depthwise_integrity_weights(layer, kernel, bias):
        # Precompute [r_left, r_left*W (per channel), sum(r_left)*b] for a DepthwiseConv2DQ layer.
        # Depthwise kernels are laid out (k_h, k_w, ch_in, multiplier); the
        # shapes below assume a multiplier of 1.
        k_h, k_w, ch_in, _ = kernel.shape
        h, w = layer.input_shape[1], layer.input_shape[2]
        h_out, w_out = layer.output_shape[1], layer.output_shape[2]

        np.random.seed(0)
        r_left = np.random.randint(low=-r_max, high=r_max + 1, size=(reps, h_out * w_out)).astype(
            np.float32)
        w_r = np.zeros((reps, h * w, ch_in)).astype(np.float32)
        b_r = np.zeros((reps, ch_in)).astype(np.float32)

        X = np.zeros((1, h, w, ch_in)).astype(np.float64)
        x_ph = tf.placeholder(tf.float64, shape=(None, h, w, ch_in))
        w_ph = tf.placeholder(tf.float64, shape=(k_h, k_w, ch_in, 1))
        y_ph = tf.placeholder(tf.float64, shape=(1, h_out, w_out, ch_in))
        z = tf.nn.depthwise_conv2d_native(x_ph, w_ph, (1,) + layer.strides + (1,), layer.padding.upper())
        dz = tf.gradients(z, x_ph, grad_ys=y_ph)[0]

        for i in range(reps):
            curr_r_left = r_left[i].astype(np.float64)
            # the same r_left is applied to every channel
            w_r_i = sess.run(dz, feed_dict={x_ph: X, w_ph: kernel.astype(np.float64),
                                            y_ph: curr_r_left.reshape((1, h_out, w_out, 1)).repeat(ch_in,
                                                                                                   axis=-1)})
            w_r[i] = np.fmod(w_r_i, p).astype(np.float32).reshape((h * w, ch_in))
            assert np.max(np.abs(w_r[i])) < 2 ** 52
            b_r[i] = np.fmod(np.sum(curr_r_left) * bias.astype(np.float64), p)

        print("r_left: {}".format(r_left.astype(np.float64).sum()))
        print("w_r: {}".format(w_r.astype(np.float64).sum()))
        print("b_r: {}".format(b_r.astype(np.float64).sum()))
        return [r_left.reshape(-1), w_r.reshape(-1), b_r.reshape(-1)]

    def layer_to_json(layer):
        # Return (layer_json, layer_weights). layer_json is {} for layers that
        # are not exported (Dropout, Lambda), otherwise a dict or list of dicts.
        layer_json = {}
        layer_weights = []

        if isinstance(layer, keras.layers.InputLayer):
            layer_json = {'name': 'Input', 'shape': layer.batch_input_shape}

        elif isinstance(layer, Conv2D) and not isinstance(layer, DepthwiseConv2D):
            layer_json = {'name': 'Conv2D',
                          'kernel_size': layer.kernel.get_shape().as_list(),
                          'strides': layer.strides,
                          'padding': layer.padding,
                          'activation': get_activation_name(layer),
                          'bits_w': layer.bits_w, 'bits_x': layer.bits_x}

            if isinstance(layer, Conv2DQ):
                kernel = sess.run(layer.kernel_q)
                bias = sess.run(layer.bias_q)

                print(layer, layer.input_shape, layer.output_shape, kernel.shape)

                if verif_preproc:
                    layer_weights.extend(conv_integrity_weights(layer, kernel, bias))
            else:
                kernel = layer.kernel.eval(sess)
                bias = layer.bias.eval(sess)
                print("sum(abs(conv_w)): {}".format(np.abs(kernel).sum()))

            if not verif_preproc:
                layer_weights.append(kernel.reshape(-1).astype(dtype))
                layer_weights.append(bias.reshape(-1).astype(dtype))

        elif isinstance(layer, MaxPooling2D):
            layer_json = {'name': 'MaxPooling2D', 'pool_size': layer.pool_size,
                          'strides': layer.strides, 'padding': layer.padding}

        elif isinstance(layer, AveragePooling2D):
            layer_json = {'name': 'AveragePooling2D', 'pool_size': layer.pool_size,
                          'strides': layer.strides, 'padding': layer.padding}

        elif isinstance(layer, Flatten):
            layer_json = {'name': 'Flatten'}

        elif isinstance(layer, Dense):
            assert not (slalom_privacy and verif_preproc)
            layer_json = {'name': 'Dense', 'kernel_size': layer.kernel.get_shape().as_list(),
                          'pointwise_conv': False, 'activation': get_activation_name(layer),
                          'bits_w': layer.bits_w, 'bits_x': layer.bits_x}

            if isinstance(layer, DenseQ):
                kernel = sess.run(layer.kernel_q).reshape(-1).astype(dtype)
                bias = sess.run(layer.bias_q).reshape(-1).astype(dtype)
            else:
                kernel = layer.kernel.eval(sess).reshape(-1).astype(dtype)
                bias = layer.bias.eval(sess).reshape(-1).astype(dtype)
            print("sum(abs(dense_w)): {}".format(np.abs(kernel).sum()))
            layer_weights.append(kernel)
            layer_weights.append(bias)

        elif isinstance(layer, DepthwiseConv2D):
            layer_json = {'name': 'DepthwiseConv2D', 'kernel_size': layer.depthwise_kernel.get_shape().as_list(),
                          'strides': layer.strides, 'padding': layer.padding,
                          'activation': get_activation_name(layer)}

            if isinstance(layer, DepthwiseConv2DQ):
                kernel = sess.run(layer.kernel_q)
                bias = sess.run(layer.bias_q)

                if verif_preproc:
                    layer_weights.extend(depthwise_integrity_weights(layer, kernel, bias))
            else:
                kernel = layer.depthwise_kernel.eval(sess)
                bias = layer.bias.eval(sess)
            print("sum(abs(depthwise_w)): {}".format(np.abs(kernel).sum()))

            if not verif_preproc:
                layer_weights.append(kernel.reshape(-1).astype(dtype))
                layer_weights.append(bias.reshape(-1).astype(dtype))

        elif isinstance(layer, GlobalAveragePooling2D):
            layer_json = {'name': 'GlobalAveragePooling2D'}

        elif isinstance(layer, (Dropout, Lambda)):
            pass  # not exported

        elif isinstance(layer, Reshape):
            layer_json = {'name': 'Reshape', 'shape': layer.target_shape}

        elif isinstance(layer, ZeroPadding2D):
            layer_json = {'name': 'ZeroPadding2D',
                          'padding': layer.padding if not hasattr(layer.padding, '__len__') else layer.padding[0]}

        elif isinstance(layer, ActivationQ):
            layer_json = {'name': 'Activation', 'type': layer.activation_name(), 'bits_w': layer.bits_w}

            if hasattr(layer, 'maxpool_params') and layer.maxpool_params is not None:
                pool_json = {'name': 'MaxPooling2D', 'pool_size': layer.maxpool_params['pool_size'],
                             'strides': layer.maxpool_params['strides'],
                             'padding': layer.maxpool_params['padding']}
                layer_json = [layer_json, pool_json]

        elif isinstance(layer, ResNetBlock):
            path1 = []
            path2 = []
            # Only Conv2D / ActivationQ sub-layers of each path are exported.
            for sublayers, path_json in ((layer.path1, path1), (layer.path2, path2)):
                for sub in sublayers:
                    if isinstance(sub, (Conv2D, ActivationQ)):
                        js, w = layer_to_json(sub)
                        path_json.append(js)
                        layer_weights.extend(w)

            layer_json = {'name': 'ResNetBlock', 'identity': layer.identity,
                          'bits_w': layer.bits_w, 'bits_x': layer.bits_x,
                          'path1': path1, 'path2': path2}

            if slalom_privacy:
                layer_json = [layer_json]
                merge_json, _ = layer_to_json(layer.merge_act)
                append_layer_json(layer_json, merge_json)

        else:
            raise NameError("Unknown layer {}".format(layer))

        return layer_json, layer_weights

    model_json_part1 = {'layers': [], 'shift_w': 2**bits_w, 'shift_x': 2**bits_x, 'max_tensor_size': 224*224*64}
    weights_part1 = []
    # Partition 2 is not exported; it is returned empty for interface stability.
    model_json_part2 = {'layers': [], 'shift_w': 2**bits_w, 'shift_x': 2**bits_x, 'max_tensor_size': 224*224*64}
    weights_part2 = []
    for idx, layer in enumerate(model.layers):
        if idx >= partitionAtLayer:
            break  # only partition 1 is exported

        layer_json, layer_weights = layer_to_json(layer)
        if layer_json:
            append_layer_json(model_json_part1['layers'], layer_json)
        weights_part1.extend(layer_weights)

    return model_json_part1, weights_part1, model_json_part2, weights_part2


# for debugging integrity checks
def mod_test(sess, model, images, linear_ops_in, linear_ops_out, verif_preproc=False):
    """Debug helper: recompute the randomized integrity checks of every
    quantized linear layer on ``images`` and assert that they hold.

    For each model layer with a ``kernel_q`` attribute, its actual input X and
    output Z (from ``linear_ops_in`` / ``linear_ops_out``, paired with the
    layers in order via ``zip``) are checked against
    ``r_left * Z * r_right == X * W_r + b_r (mod p)`` for Conv2DQ layers, and
    the analogous per-channel relation for DepthwiseConv2DQ layers. When the
    batch has more than one image, the random vectors are replaced by
    all-ones vectors that sum over the batch.

    :param sess: TF session.
    :param model: Keras model containing the quantized layers.
    :param images: input batch fed to ``model.inputs[0]``.
    :param linear_ops_in: tensors holding the inputs of the linear layers.
    :param linear_ops_out: tensors holding the matching outputs.
    :param verif_preproc: if True, pointwise convolutions are checked with a
        right random vector only (no r_left).
    :raises AssertionError: when any check fails.
    """
    # keras and the quantized layer classes are otherwise only imported inside
    # other functions of this module, so bring them into scope here.
    import keras
    from python.slalom.quant_layers import Conv2DQ, DepthwiseConv2DQ

    linear_inputs, linear_outputs = sess.run([linear_ops_in, linear_ops_out],
                                             feed_dict={model.inputs[0]: images, keras.backend.learning_phase(): 0})

    kernels = [(layer, sess.run([layer.kernel_q, layer.bias_q])) for layer in model.layers if
               hasattr(layer, 'kernel_q')]

    p = ((1 << 24) - 3)
    r_max = (1 << 20)
    batch = images.shape[0]

    def fmod(x, p):
        return np.fmod(x, p)

    def fmod_pos(x, p):
        # np.fmod keeps the sign of x; shift the result into [0, p).
        return np.fmod(np.fmod(x, p) + p, p)

    np.random.seed(0)
    for (layer, (kernel, bias)), inp, out in zip(kernels, linear_inputs, linear_outputs):
        assert (np.max(np.abs(inp)) < p / 2)
        assert (np.max(np.abs(out)) < p / 2)

        print("input = {} {}".format(inp.reshape(-1)[:3], inp.reshape(-1)[-3:]))
        assert np.max(np.abs(out)) < 2 ** 23

        if isinstance(layer, Conv2DQ):
            h, w = layer.input_shape[1], layer.input_shape[2]
            h_out, w_out = layer.output_shape[1], layer.output_shape[2]
            # TF conv kernels are laid out (k_h, k_w, ch_in, ch_out).
            k_h, k_w, ch_in, _ = kernel.shape
            inp = np.reshape(inp, (batch, h, w, ch_in))

            pointwise = k_h == 1 and k_w == 1

            np.random.seed(0)
            if verif_preproc and pointwise:
                r_left = []
                r_right = np.random.randint(low=-r_max, high=r_max + 1, size=(2, kernel.shape[3])).astype(np.float32)
                r_right = r_right[0, :].astype(np.float64)
            else:
                r_left = np.random.randint(low=-r_max, high=r_max + 1, size=(2, h_out * w_out)).astype(np.float32)
                r_right = np.random.randint(low=-r_max, high=r_max + 1, size=(2, kernel.shape[3])).astype(np.float32)
                r_left = r_left[0, :].astype(np.float64)
                r_right = r_right[0, :].astype(np.float64)

            if batch > 1:
                # Batched check: sum over the batch instead of random vectors.
                r_left = np.ones(shape=(1, batch)).astype(np.float64)
                r_right = np.ones(shape=(kernel.shape[3], 1)).astype(np.float64)

            # Left-hand side: r_left * Z * r_right (mod p).
            if batch > 1:
                Z = out.reshape((batch, -1)).astype(np.float64)
                r_Z = fmod(r_left.dot(Z), p).reshape(1, h_out, w_out, kernel.shape[3])
                assert np.max(np.abs(r_Z)) < 2 ** 52
                r_Z_r = r_Z.dot(r_right)
                assert np.max(np.abs(r_Z_r)) < 2 ** 52
                r_Z_r = fmod(r_Z_r, p)
                b_r = fmod(np.sum(r_left) * fmod(np.dot(bias, r_right), p), p)
                print("b_r: {}".format(b_r))
                r_Z_r = fmod_pos(r_Z_r - b_r, p)
                print(r_Z_r.shape)
            else:
                Z = out.reshape((h_out * w_out, -1)).astype(np.float64)
                Z_r = fmod(Z.dot(r_right.astype(np.float64)), p)
                assert np.max(np.abs(Z_r)) < 2 ** 52

                if not pointwise:
                    r_Z_r = r_left.dot(Z_r)
                    assert np.max(np.abs(r_Z_r)) < 2 ** 52
                    r_Z_r = fmod_pos(r_Z_r, p)
                else:
                    r_Z_r = fmod_pos(Z_r, p)
            print("r_Z_r = {} {}".format(r_Z_r.reshape(-1)[:3], r_Z_r.reshape(-1)[-3:]))

            # Right-hand side: X * W_r (mod p).
            if batch > 1:
                r_X = r_left.dot(inp.reshape(batch, -1)).reshape(1, h, w, kernel.shape[2])
                assert np.max(np.abs(r_X)) < 2 ** 52
                r_X = fmod(r_X, p)
                print("r_X = {} {}".format(r_X.reshape(-1)[:3], r_X.reshape(-1)[-3:]))
                w_right = kernel.astype(np.float64).reshape((-1, kernel.shape[-1])).dot(r_right)
                w_right = w_right.reshape(kernel.shape[0], kernel.shape[1], kernel.shape[2], 1)
                assert np.max(np.abs(w_right)) < 2 ** 52
                X_W_r = sess.run(tf.nn.conv2d(r_X, w_right, (1,) + layer.strides + (1,), layer.padding.upper()))
                assert np.max(np.abs(X_W_r)) < 2 ** 52
                X_W_r = fmod_pos(X_W_r, p)
                print(X_W_r.shape)
            else:
                if pointwise:
                    test = inp.reshape((h * w, -1)).astype(np.float64).dot(kernel.reshape(ch_in, -1).astype(np.float64))
                    test += bias.astype(np.float64)
                    print("test = {} {}".format(test.reshape(-1)[:3], test.reshape(-1)[-3:]))
                    print("out = {} {}".format(out.reshape(-1)[:3], out.reshape(-1)[-3:]))
                    assert((test.reshape(-1) == out.reshape(-1)).all())

                    w_right = kernel.astype(np.float64).reshape((-1, kernel.shape[-1])).dot(r_right.astype(np.float64))
                    assert np.max(np.abs(w_right)) < 2 ** 52
                    print("W_r = {} {}".format(w_right.reshape(-1)[:3], w_right.reshape(-1)[-3:]))

                    X_r = inp.reshape((h * w, -1)).astype(np.float64).dot(np.fmod(w_right, p))
                    assert np.max(np.abs(X_r)) < 2 ** 52

                    X_r = np.fmod(X_r, p)
                    b_r = np.fmod(np.dot(bias, r_right), p)
                    X_W_r = fmod_pos(X_r + b_r, p)

                else:
                    w_right = kernel.astype(np.float64).reshape((-1, kernel.shape[-1])).dot(r_right).reshape(-1, kernel.shape[2])
                    assert np.max(np.abs(w_right)) < 2 ** 52
                    print("sum(w_right) = {}".format(np.sum(w_right.astype(np.float64))))

                    X = np.zeros((1, h, w, ch_in)).astype(np.float64)

                    x_ph = tf.placeholder(tf.float64, shape=(None, h, w, ch_in))
                    w_ph = tf.placeholder(tf.float64, shape=(k_h, k_w, ch_in, 1))
                    y_ph = tf.placeholder(tf.float64, shape=(1, h_out, w_out, 1))
                    z = tf.nn.conv2d(x_ph, w_ph, (1,) + layer.strides + (1,), layer.padding.upper())
                    # Input-gradient seeded with r_left gives the vector w_r with
                    # <x, w_r> == r_left . conv(x, W*r_right), since conv is linear in x.
                    dz = tf.gradients(z, x_ph, grad_ys=y_ph)[0]

                    w_r = sess.run(dz, feed_dict={x_ph: X, w_ph: w_right.reshape(k_h, k_w, ch_in, 1),
                                                  y_ph: r_left.reshape((1, h_out, w_out, 1))})

                    assert np.max(np.abs(w_r)) < 2 ** 52

                    w_r = np.fmod(w_r, p).astype(np.float32)
                    print("sum(r_left) = {}".format(np.sum(r_left.astype(np.float64))))
                    print("sum(r_right) = {}".format(np.sum(r_right.astype(np.float64))))
                    print("sum(w_r) = {}".format(np.sum(w_r.astype(np.float64))))

                    X_W_r = inp.astype(np.float64).reshape(-1).dot(w_r.astype(np.float64).reshape(-1))
                    assert np.max(np.abs(X_W_r)) < 2 ** 52
                    X_W_r = fmod(X_W_r, p)
                    b_r = np.fmod(np.sum(r_left) * np.fmod(np.dot(bias, r_right), p), p)
                    print("b_r: {}".format(b_r))
                    X_W_r = fmod_pos(X_W_r + b_r, p)
            print("X*W_r = {} {}".format(X_W_r.reshape(-1)[:3], X_W_r.reshape(-1)[-3:]))
            if not (X_W_r == r_Z_r).all():
                # np.nan is rejected as a threshold by modern NumPy; sys.maxsize disables summarization.
                np.set_printoptions(threshold=sys.maxsize)
                print(X_W_r - r_Z_r)
                # Explicit raise so the check is not stripped under `python -O`.
                raise AssertionError("integrity check failed for layer {}".format(layer))
            print()

        elif isinstance(layer, DepthwiseConv2DQ):
            h, w = layer.input_shape[1], layer.input_shape[2]
            h_out, w_out = layer.output_shape[1], layer.output_shape[2]
            # Depthwise kernels are laid out (k_h, k_w, ch_in, multiplier);
            # the shapes below assume a multiplier of 1.
            k_h, k_w, ch_in, _ = kernel.shape
            inp = np.reshape(inp, (batch, h, w, ch_in))

            reps = 1

            if batch > 1:
                r_left = np.ones(shape=(reps, batch)).astype(np.float64)

                Z = out.reshape((batch, -1)).astype(np.float64)
                r_Z = fmod(r_left.dot(Z), p).reshape(reps, h_out, w_out, ch_in)
                assert np.max(np.abs(r_Z)) < 2 ** 52
                b_r = np.sum(r_left, axis=1).reshape(reps, 1) * bias.reshape(1, ch_in)
                r_Z = fmod_pos(r_Z - b_r.reshape(reps, 1, 1, ch_in), p)
                print("r_Z = {} {}".format(r_Z.reshape(-1)[:3], r_Z.reshape(-1)[-3:]))

                r_X = r_left.dot(inp.reshape(batch, -1)).reshape(reps, h, w, ch_in)
                assert np.max(np.abs(r_X)) < 2 ** 52
                r_X = fmod(r_X, p)
                print("r_X = {} {}".format(r_X.reshape(-1)[:3], r_X.reshape(-1)[-3:]))

                r_X_W = sess.run(tf.nn.depthwise_conv2d_native(r_X, kernel, (1,) + layer.strides + (1,), layer.padding.upper()))
                assert np.max(np.abs(r_X_W)) < 2 ** 52
                r_X_W = fmod_pos(r_X_W, p)
                print("r_X_W = {} {}".format(r_X_W.reshape(-1)[:3], r_X_W.reshape(-1)[-3:]))
                assert (r_X_W == r_Z).all()
                print()

            else:
                np.random.seed(0)
                r_left = np.random.randint(low=-r_max, high=r_max + 1, size=(2, h_out * w_out)).astype(np.float32)
                r_left = r_left[0, :].astype(np.float64)

                X = np.zeros((1, h, w, ch_in)).astype(np.float64)
                x_ph = tf.placeholder(tf.float64, shape=(None, h, w, ch_in))
                w_ph = tf.placeholder(tf.float64, shape=(k_h, k_w, ch_in, 1))
                y_ph = tf.placeholder(tf.float64, shape=(1, h_out, w_out, ch_in))
                z = tf.nn.depthwise_conv2d_native(x_ph, w_ph, (1,) + layer.strides + (1,), layer.padding.upper())
                dz = tf.gradients(z, x_ph, grad_ys=y_ph)[0]

                w_r = sess.run(dz, feed_dict={x_ph: X, w_ph: kernel.astype(np.float64),
                                              y_ph: r_left.reshape((1, h_out, w_out, 1)).repeat(ch_in, axis=-1)})
                w_r = np.fmod(w_r, p).astype(np.float32).reshape((h * w, ch_in))
                assert np.max(np.abs(w_r)) < 2 ** 52
                print("sum(w_r) = {}".format(np.sum(w_r.astype(np.float64))))
                print("r_left: {}".format(r_left.astype(np.float64).sum()))

                Z_r = r_left.dot(out.reshape(h_out * w_out, -1).astype(np.float64))
                assert np.max(np.abs(Z_r)) < 2 ** 52
                Z_r = fmod_pos(Z_r, p)
                print("Z_r = {} {}".format(Z_r.reshape(-1)[:3], Z_r.reshape(-1)[-3:]))

                r_X_W = (inp.astype(np.float64).reshape(h*w, ch_in) * w_r.astype(np.float64)).sum(axis=0)
                assert np.max(np.abs(r_X_W)) < 2 ** 52
                r_X_W = fmod(r_X_W, p)

                b_r = fmod(np.sum(r_left) * bias.astype(np.float64), p)
                print("sum(b_r): {}".format(np.sum(b_r)))
                r_X_W = fmod_pos(r_X_W + b_r, p)
                print("r_X_W = {} {}".format(r_X_W.reshape(-1)[:3], r_X_W.reshape(-1)[-3:]))
                assert (r_X_W == Z_r).all()
                print()


def model_to_json_full(sess, model,dtype=np.float32):
    """Serialize a float (non-quantized) Keras model into the JSON description
    and flat weight list used by the SGXDNN library.

    Supports InputLayer, Conv2DF, DenseF, ActivationF (optionally fused with a
    max-pool) and Flatten. Every layer is exported (no partitioning), and the
    per-layer bit widths / model shifts are fixed at 8 / ``2**8``.

    :param sess: TF session used to evaluate kernels and biases.
    :param model: Keras model whose ``layers`` are serialized in order.
    :param dtype: dtype every exported weight array is cast to, including
        the all-zero biases synthesized for bias-less layers.
    :returns: ``(model_json, weights)``.
    :raises NameError: for any unsupported layer type.
    """
    import keras
    from keras import activations
    from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, DepthwiseConv2D, GlobalAveragePooling2D, Dropout, \
        Reshape, ZeroPadding2D, AveragePooling2D, Lambda
    from python.slalom.quant_layers import Conv2DF, DenseF, ActivationF
    from python.slalom.resnet import ResNetBlock

    def get_activation_name(layer):
        if layer.activation is not None:
            return activations.serialize(layer.activation)
        else:
            return ''

    def layer_to_json(layer):
        # Return (layer_json, layer_weights); layer_json is a dict or a list of dicts.
        layer_json = {}
        layer_weights = []

        print(type(layer))
        if isinstance(layer, keras.layers.InputLayer):
            layer_json = {'name': 'Input', 'shape': layer.batch_input_shape}

        elif isinstance(layer, Conv2DF) and not isinstance(layer, DepthwiseConv2D):
            layer_json = {'name': 'Conv2D',
                          'kernel_size': layer.kernel.get_shape().as_list(),
                          'strides': layer.strides,
                          'padding': layer.padding,
                          'activation': get_activation_name(layer),
                          'bits_w': 8, 'bits_x': 8}

            kernel = layer.kernel.eval(sess)
            # A bias-less conv exports an all-zero bias with one entry per
            # output channel (kernel.shape[3]).
            if layer.bias is not None:
                bias = layer.bias.eval(sess)
            else:
                bias = np.zeros(kernel.shape[3])
            print("sum(abs(conv_w)): {}".format(np.abs(kernel).sum()))
            layer_weights.append(kernel.reshape(-1).astype(dtype))
            layer_weights.append(bias.reshape(-1).astype(dtype))

        elif isinstance(layer, DenseF):
            layer_json = {'name': 'Dense', 'kernel_size': layer.kernel.get_shape().as_list(),
                          'pointwise_conv': False, 'activation': get_activation_name(layer),
                          'bits_w': 8, 'bits_x': 8}
            kernel = layer.kernel.eval(sess)
            bias_size = kernel.shape[1]
            kernel = kernel.reshape(-1).astype(dtype)
            if layer.bias is not None:
                bias = layer.bias.eval(sess).reshape(-1).astype(dtype)
            else:
                # Use `dtype` like every other exported weight (was hard-coded float32).
                bias = np.zeros(bias_size, dtype=dtype)
            print("sum(abs(dense_w)): {}".format(np.abs(kernel).sum()))
            layer_weights.append(kernel)
            layer_weights.append(bias)

        elif isinstance(layer, ActivationF):
            layer_json = {'name': 'Activation', 'type': layer.activation_name(), 'bits_w': 8}

            if hasattr(layer, 'maxpool_params') and layer.maxpool_params is not None:
                pool_json = {'name': 'MaxPooling2D', 'pool_size': layer.maxpool_params['pool_size'],
                             'strides': layer.maxpool_params['strides'],
                             'padding': layer.maxpool_params['padding']}
                layer_json = [layer_json, pool_json]

        elif isinstance(layer, Flatten):
            layer_json = {'name': 'Flatten'}

        else:
            raise NameError("Unknown layer {}".format(layer))

        return layer_json, layer_weights

    model_json = {'layers': [], 'shift_w': 2**8, 'shift_x': 2**8, 'max_tensor_size': 224*224*64}
    weights = []
    for layer in model.layers:
        layer_json, layer_weights = layer_to_json(layer)
        print(layer_json)

        if not layer_json:
            continue
        if isinstance(layer_json, dict):
            model_json['layers'].append(layer_json)
        else:
            model_json['layers'].extend(layer_json)
        weights.extend(layer_weights)

    return model_json, weights
