
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import python.slalom.keras_fix

import sys
import os
import copy 

import numpy as np
import tensorflow as tf
from tensorflow.python.client import timeline
from keras import backend
from keras.layers import MaxPooling2D
from python import imagenet
from python.slalom.models import get_model
from python.slalom.quant_layers import transform, build_blinding_ops, prepare_blinding_factors, get_all_linear_layers
from python.slalom.utils import Results
from python.slalom.sgxdnn import model_to_json
from python.slalom.global_sgx import sgxutils
from keras.models import Sequential


import pickle

# TensorFlow environment switches: TF_CPP_MIN_LOG_LEVEL=3 filters TensorFlow's
# C++ log output; TF_USE_DEEP_CONV2D=0 presumably keeps the DeepConv2D
# convolution path disabled (TODO confirm).
# NOTE(review): these are assigned after `import tensorflow` above, so any
# value TensorFlow reads at import time will not see them; consider moving
# this above the TensorFlow/Keras imports.
os.environ.update({
    'TF_CPP_MIN_LOG_LEVEL': '3',
    'TF_USE_DEEP_CONV2D': '0',
})


def blind(x, y, a1, a2, b1, b2):
    """Mix the pair ``(x, y)`` with the 2x2 matrix ``[[a1, a2], [b1, b2]]``.

    Only ``*`` and ``+`` are used, so scalars and NumPy arrays (elementwise)
    both work.

    Returns:
        tuple: ``(a1*x + a2*y, b1*x + b2*y)``.
    """
    first = x * a1 + y * a2
    second = x * b1 + y * b2
    return first, second

def unblind(x, y, a1, a2, b1, b2):
    """Recover ``(x0, y0)`` from a pair produced by :func:`blind`.

    Applies the inverse of ``[[a1, a2], [b1, b2]]`` using Cramer's rule. The
    matrix must be invertible (``a1*b2 != b1*a2``); otherwise the division
    fails or yields inf/nan depending on the operand types.

    Returns:
        tuple: the two unblinded values, in the order they were blinded.
    """
    det = a1 * b2 - b1 * a2
    neg_det = a2 * b1 - a1 * b2  # equals -det; kept as its own expression
    first = (b2 * x - a2 * y) / det
    second = (b1 * x - a1 * y) / neg_det
    return first, second

    
# Coefficients (a1, a2, b1, b2) of the 2x2 blinding matrix [[a1, a2], [b1, b2]]
# that the debug checks in main() assume the enclave applies to each pair of
# images: blinded0 = 5*x0 + 6*x1, blinded1 = x0 + x1.
_BLIND_COEFFS = (5, 6, 1, 1)


def _count_mismatches(got, expected, tol):
    """Return how many elements satisfy ``|got - expected| > tol``.

    The check is two-sided, so deviations in either direction are counted.
    """
    return np.sum(np.absolute(got - expected) > tol)


def main(_):
    """Run a Slalom-transformed model and cross-check its blinding arithmetic.

    For each of ``args.max_num_batches`` batches this loads ImageNet validation
    images, quantizes them by ``2**bits_x`` and blinds them with
    ``sgxutils.slalom_blind_input``. It then runs the transformed model stage
    by stage. Inputs and intermediate activations are compared against values
    recomputed in NumPy with ``blind``/``unblind`` and ``_BLIND_COEFFS``.
    Mismatches are printed only, never raised. Finally it records top-level
    predictions in a ``Results`` object.

    Configuration comes from the module-level ``args`` namespace parsed in
    ``__main__``. The positional parameter is what ``tf.app.run`` passes and
    is ignored.

    Raises:
        ValueError: if ``args.batch_size`` < 2, because the checks pair
            image 0 with image 1.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.set_random_seed(1234)

    if args.batch_size < 2:
        # The checks below index images[1] and feed a two-element batch to
        # `pool`. Fail early with a clear message, instead of an IndexError
        # after the enclave has been initialised.
        raise ValueError("batch_size must be >= 2 for the blinding checks, got {}".format(args.batch_size))

    with tf.Graph().as_default():
        num_batches = args.max_num_batches
        batch_size = args.batch_size

        device = '/cpu:0'
        config = tf.ConfigProto(log_device_placement=False)
        config.allow_soft_placement = True
        config.gpu_options.per_process_gpu_memory_fraction = 0.90
        config.gpu_options.allow_growth = True

        quantize = True
        slalom = not args.no_slalom
        blinded = args.blinding
        integrity = args.integrity
        simulate = args.simulate
        # Queues are created, and blinding factors prepared, only for real
        # (non-simulated) private runs.
        use_queues = blinded and not simulate

        with tf.Session(config=config) as sess:
            with tf.device(device):
                model, model_info = get_model(args.model_name, batch_size, include_top=True, double_prec=False)
                # Layer 3 of the *untransformed* model (presumably a pooling
                # layer, given the name -- TODO confirm). It is used below to
                # recompute layer 4's expected output from unblinded activations.
                pool = backend.function([model.layers[3].input], [model.layers[3].output])

            dataset_images, labels = imagenet.load_validation(
                args.input_dir, batch_size, preprocess=model_info['preprocess'])

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            num_linear_layers = len(get_all_linear_layers(model))
            if use_queues:
                queues = [tf.FIFOQueue(capacity=num_batches + 1, dtypes=[tf.float32])
                          for _ in range(num_linear_layers)]
            else:
                queues = None

            with tf.device(device):
                # The linear-op lists returned by transform() are not needed here.
                model, _, _ = transform(model, log=False, quantize=quantize, verif_preproc=True,
                                        slalom=slalom, slalom_integrity=integrity, slalom_privacy=blinded,
                                        bits_w=model_info['bits_w'],
                                        bits_x=model_info['bits_x'],
                                        sgxutils=sgxutils, queues=queues)

                # Split the transformed model into stages so the intermediate
                # (blinded) activations can be inspected between them.
                layer01 = backend.function([model.layers[0].input], [model.layers[1].output])
                layer2 = backend.function([model.layers[2].input], [model.layers[2].output])
                layer3 = backend.function([model.layers[3].input], [model.layers[3].output])
                layer4 = backend.function([model.layers[4].input], [model.layers[4].output])
                layer_res = backend.function([model.layers[5].input], [model.layers[-1].output])

            dtype = np.float32
            # Only the first JSON/weights pair is loaded into the enclave.
            model_json, weights, _, _ = model_to_json(sess=sess,
                                                      model=model,
                                                      dtype=dtype,
                                                      verif_preproc=True,
                                                      slalom_privacy=blinded,
                                                      bits_w=model_info['bits_w'],
                                                      bits_x=model_info['bits_x'])
            sgxutils.dnnl_init()
            sgxutils.load_model(model_json, weights, dtype=dtype, verify=True, verify_preproc=True)

            num_classes = np.prod(model.output.get_shape().as_list()[1:])
            print_acc = (num_classes == 1000)
            res = Results(acc=print_acc)

            sgxutils.slalom_init(integrity, use_queues, batch_size)

            if use_queues:
                # temps/out_funcs are returned but currently not passed to
                # prepare_blinding_factors (an earlier variant passed
                # inputs=images, temps=temps, out_funcs=out_funcs).
                in_ph, zs, out_ph, queue_ops, temps, out_funcs = build_blinding_ops(model, queues, batch_size)

            for _batch in range(num_batches):
                images, true_labels = sess.run([dataset_images, labels])
                # Snapshot the first two images before quantization and
                # blinding, for the NumPy cross-check below.
                original_image0 = np.copy(images[0])
                original_image1 = np.copy(images[1])
                if quantize:
                    images = np.round(2 ** model_info['bits_x'] * images).astype(np.float32)
                    print("bits_x: {}, input images: {}".format(model_info['bits_x'], np.sum(np.abs(images))))

                if use_queues:
                    prepare_blinding_factors(sess, model, sgxutils, in_ph, zs, out_ph, queue_ops, batch_size,
                                             num_batches=1)

                images = sgxutils.slalom_blind_input(images)

                encrypted_image0 = np.copy(images[0])
                encrypted_image1 = np.copy(images[1])

                # NOTE(review): the originals were captured *before* the
                # 2**bits_x quantization, while `images` was quantized before
                # blinding. The tolerances (5.5 = 5*0.5 + 6*0.5, 1.0 = 0.5 + 0.5)
                # look like rounding bounds in the quantized domain. Confirm
                # whether the expected values should be scaled by 2**bits_x.
                expected_image0, expected_image1 = blind(original_image0, original_image1, *_BLIND_COEFFS)

                # Count deviations in both directions, consistent with the
                # layer checks below.
                res1 = _count_mismatches(encrypted_image0, expected_image0, 5.5)
                res2 = _count_mismatches(encrypted_image1, expected_image1, 1.0)

                if res1 != 0 or res2 != 0:
                    print("python print")
                    print(original_image0[0][0][0], original_image1[0][0][0])
                    print(encrypted_image0[0][0][0], encrypted_image1[0][0][0])
                    print(expected_image0[0][0][0], expected_image1[0][0][0])
                    print(res1, res2)
                    print("mismatch")

                res.start_timer()

                layer2_input = layer01([images])

                # Unblind the first-stage outputs on the host, apply ReLU and
                # re-blind, to predict what layer 2 should output.
                conv0 = layer2_input[0][0]
                conv1 = layer2_input[0][1]
                unblinded0, unblinded1 = unblind(conv0, conv1, *_BLIND_COEFFS)

                layer2_output = layer2(layer2_input)
                layer3_output = layer3(layer2_output)
                layer4_output = layer4(layer3_output)

                unblinded0 = np.maximum(0, unblinded0)
                unblinded1 = np.maximum(0, unblinded1)

                ep0, ep1 = blind(unblinded0, unblinded1, *_BLIND_COEFFS)

                go0 = layer2_output[0][0]
                go1 = layer2_output[0][1]

                res0 = _count_mismatches(go0, ep0, 100.0)
                res1 = _count_mismatches(go1, ep1, 5.0)

                if res0 > 0 or res1 > 0:
                    print("input ", conv0[0][0][0], conv1[0][0][0])
                    print("got ", go0[0][0][0], go1[0][0][0])
                    print("exp ", ep0[0][0][0], ep1[0][0][0])
                    print(res0, res1)
                    print("mismatch")

                # Same idea for layer 4: unblind layer 3's output, run it
                # through the reference `pool` plus ReLU, and re-blind.
                conv2 = layer3_output[0][0]
                conv3 = layer3_output[0][1]
                unblinded0, unblinded1 = unblind(conv2, conv3, *_BLIND_COEFFS)

                print("pythohn input 1", unblinded0[0][0][0], unblinded0[0][1][0], unblinded0[1][1][0], unblinded0[1][0][0])
                print("pythonn input 2", unblinded1[0][0][0], unblinded1[0][1][0], unblinded1[1][1][0], unblinded1[1][0][0])

                r = np.maximum(0, pool([np.array([unblinded0, unblinded1])]))
                unblinded0 = r[0][0]
                unblinded1 = r[0][1]

                ep0, ep1 = blind(unblinded0, unblinded1, *_BLIND_COEFFS)
                go0 = layer4_output[0][0]
                go1 = layer4_output[0][1]

                res0 = _count_mismatches(go0, ep0, 150.0)
                res1 = _count_mismatches(go1, ep1, 40.0)

                if res0 > 0 or res1 > 0:
                    print("input 1", conv2[0][0][0], conv2[0][1][0], conv2[1][1][0], conv2[1][0][0])
                    print("input 2", conv3[0][0][0], conv3[0][1][0], conv3[1][1][0], conv3[1][0][0])
                    print("got ", go0[0][0][0], go1[0][0][0])
                    print("exp ", ep0[0][0][0], ep1[0][0][0])
                    print(res0, res1)
                    print("mismatch")

                # NOTE(review): this second start_timer() presumably restarts
                # the timer begun above. The recorded interval would then cover
                # almost nothing and exclude layer_res(). Confirm the intended
                # timing window.
                res.start_timer()
                res.end_timer()
                total_res = layer_res(layer4_output)
                preds = np.reshape(total_res, (batch_size, -1))
                print(preds, true_labels)
                res.record_acc(preds, true_labels)
                res.print_results()
                sys.stdout.flush()
            coord.request_stop()
            coord.join(threads)

        if sgxutils is not None:
            sgxutils.destroy()


if __name__ == '__main__':
    import argparse

    # Model names selectable on the command line.
    model_names = ['vgg_16', 'vgg_19', 'inception_v3', 'mobilenet', 'mobilenet_sep',
                   'resnet_18', 'resnet_34', 'resnet_50', 'resnet_101', 'resnet_152']

    cli = argparse.ArgumentParser()
    cli.add_argument('model_name', type=str, choices=model_names)
    cli.add_argument('--input_dir', type=str, default='../imagenet/',
                     help='Input directory with images.')
    cli.add_argument('--batch_size', type=int, default=8,
                     help='How many images process at one time.')
    cli.add_argument('--max_num_batches', type=int, default=2,
                     help='Max number of batches to evaluate.')

    # Boolean on/off switches, registered in this order; None means no help text.
    switches = (
        ('--use_sgx', None),
        ('--no_slalom', 'only test GPU quantization'),
        ('--blinding', 'add random blinding for privacy'),
        ('--integrity', 'add integrity checks'),
        ('--simulate', None),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action='store_true', help=help_text)

    # main() reads this module-level name directly.
    args = cli.parse_args()

    tf.app.run()
