#
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import python.slalom.keras_fix

import sys
import os
import copy 

import numpy as np
import tensorflow as tf
from tensorflow.python.client import timeline
from keras import backend
from keras.layers import MaxPooling2D
from python import imagenet
from python.slalom.models import get_model
from python.slalom.quant_layers import transform_full, transform, ActivationF, build_blinding_ops, prepare_blinding_factors, get_all_linear_layers
from python.slalom.utils import Results
from python.slalom.sgxdnn import model_to_json, model_to_json_full
from python.slalom.global_sgx import sgxutils
from python.slalom.global_bias import bias_list
from keras.models import Sequential
import torch
from keras.losses import mean_squared_error as MSE
import time
import pickle

# TensorFlow-related environment switches, set once at import time.
# NOTE(review): these are assigned after `import tensorflow`; presumably
# TensorFlow reads them lazily -- confirm they still take effect here.
os.environ.update({
    'TF_CPP_MIN_LOG_LEVEL': '3',
    "TF_USE_DEEP_CONV2D": '0',
})

def get_activation_input(model):
    """Collect the input tensor of every ``ActivationF`` layer of ``model``.

    A banner line is printed first; each collected input is printed as soon
    as it is gathered.

    Args:
        model: object exposing an iterable ``layers`` attribute.

    Returns:
        list: ``layer.input`` of each ``ActivationF`` layer, in the order the
        layers appear in ``model.layers``.
    """
    print("getting activation inputs")
    collected = []
    for candidate in model.layers:
        # Only ActivationF layers contribute; every other layer is skipped.
        if not isinstance(candidate, ActivationF):
            continue
        collected.append(candidate.input)
        print(candidate.input)
    return collected

def main(_):
    """Time weight-gradient computation through the SGX-backed model.

    Reads its configuration from the module-level ``args`` namespace. In the
    reachable part of the function it:

    1. builds two copies of ``args.model_name`` via ``get_model`` -- one with
       batch ``int(batch_size / 2 * 3)`` (``model_sgx``) and one with
       ``batch_size`` (``modelq``) -- and passes both through
       ``transform_full``;
    2. exports ``model_sgx`` with ``model_to_json_full`` and loads it into
       ``sgxutils``;
    3. creates a random input, zero-pads it with ``pad_zero`` and passes it
       through ``sgxutils.slalom_blind_input``;
    4. runs the MSE-loss gradients w.r.t. ``model_sgx.trainable_weights``
       three times with full tracing, printing timings, then runs the
       ``modelq`` gradients once;
    5. writes a Chrome trace to ``timeline.json``, prints the ``sgxutils``
       time report and calls ``sys.exit(0)``.

    NOTE(review): because of that ``sys.exit(0)``, everything that follows it
    in this function (the second ``unblind``/``blind`` helpers, the ImageNet
    batch loop, ``coord.request_stop()`` and ``sgxutils.destroy()``) is never
    executed.

    Args:
        _: ignored. Presumably the argv list supplied by ``tf.app.run`` --
            TODO confirm.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.set_random_seed(0)

    with tf.Graph().as_default():
        # Configuration comes from the module-level `args` parsed in the
        # `__main__` section of this file.
        num_batches = args.max_num_batches
        batch_size = args.batch_size

        device = '/gpu:0'
        config = tf.ConfigProto(log_device_placement=False)
        config.allow_soft_placement = True
        config.gpu_options.per_process_gpu_memory_fraction = 0.90
        config.gpu_options.allow_growth = True

        # `quantize` is only read inside the (unreachable) batch loop below;
        # `slalom` is not read again in this function.
        quantize = True
        slalom = not args.no_slalom
        blinded = args.blinding 
        integrity = args.integrity
        simulate = args.simulate

        # The SGX model uses 1.5x the batch: pad_zero() emits 3 rows for
        # every 2 input rows.
        sgx_batch_size = int(batch_size / 2 * 3)
        
        with tf.Session(config=config) as sess:
            with tf.device(device):
                
                model_sgx, model_info_sgx = get_model(args.model_name, sgx_batch_size, include_top=True, double_prec=False)
                modelq, model_info_q = get_model(args.model_name, batch_size, include_top=True, double_prec=False)

            dataset_images, labels = imagenet.load_validation(
                args.input_dir, batch_size, preprocess=model_info_sgx['preprocess'])

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # NOTE(review): `queues` is built here but not referenced again
            # anywhere in this function.
            num_linear_layers = len(get_all_linear_layers(model_sgx))
            if blinded and not simulate:
                queues = [tf.FIFOQueue(capacity=num_batches + 1, dtypes=[tf.float32]) for _ in range(num_linear_layers)]
            else:
                queues = None


            
            # Only the transformed models are used afterwards; the extra
            # return values (l_sgx_ops_in, loPout_sgx, linq, oq) are unused.
            with tf.device(device):
                model_sgx, l_sgx_ops_in, loPout_sgx = transform_full(model_sgx, sgxutils=sgxutils, quantization=False)
                print(model_sgx.summary())
                modelq, linq, oq = transform_full(modelq, sgxutils=None, quantization=False)
            

            # remove the loss function
            # (original author's note; the call drops the last layer of
            # modelq only -- the matching model_sgx.pop() is disabled)
            modelq.pop()
            #model_sgx.pop()
            #print(modelq.summary())
            print(model_sgx.summary())
            print(modelq.summary())

            dtype = np.float32
            model_json, weights = model_to_json_full(sess, model_sgx)
            
            print(model_json)

            def unblind(mixture, um):
                """Apply ``um`` to every group of 3 rows of ``mixture``.

                ``mixture`` is viewed as ``(rows / 3, 3, -1)``, left-multiplied
                by ``um`` and reshaped back to its original shape.
                Presumably ``um`` is a 3x3 matrix -- TODO confirm.

                NOTE(review): only referenced from the disabled string block
                further down, and re-defined with another signature later.
                """
                blind_size = int(mixture.shape[0] / 3)
                
                return (um @ mixture.reshape(blind_size, 3, -1)).reshape(mixture.shape)

            def pad_zero(image):
                """Insert zero rows so every 2 rows of ``image`` become 3.

                ``image`` is split along axis 0 into ``rows / 2`` equal
                chunks; a float32 zero block with half the chunk's rows is
                appended to each chunk and the chunks are concatenated again.

                NOTE(review): ``merge_size`` is computed but never used.
                """
                merge_size = list(image.shape)
                split_size = list(image.shape)
                merge_size[0] = int(merge_size[0] / 2 * 3)
                split_batch_size = int(split_size[0] / 2)
                splited = np.split(image, split_batch_size, axis=0)
                new_sli = []
                for sli in splited:
                    sli_shape = list(sli.shape)
                    sli_shape[0] = int(sli_shape[0] / 2)
                    sli = np.concatenate([sli, np.zeros(sli_shape).astype(np.float32)])
                    new_sli.append(sli)

                retval = np.concatenate(new_sli)
                return retval

            
                
            # Enclave-side setup. Of the four returned values only `um` is
            # referenced again, and only inside the disabled string block.
            bm, um, gm, igm = sgxutils.fill_parameter(internal_batch_size=3)
            sgxutils.dnnl_init(train_inside=0, internal_batch=3)
            sgxutils.load_model(model_json, weights, dtype=dtype, verify=False, verify_preproc=False)

            # Accuracy is only tracked when the model has 1000 outputs.
            num_classes = np.prod(model_sgx.output.get_shape().as_list()[1:])    
            print(num_classes)
            print_acc = (num_classes == 1000)
            res = Results(acc=print_acc)
            res_sgx = Results(acc=print_acc)
            

            sgxutils.slalom_init(integrity, (blinded and not simulate), sgx_batch_size)

            # Random stand-in input; the ImageNet tensors loaded above are
            # only consumed by the unreachable batch loop.
            x_shape = (batch_size, 224, 224, 3)
            image = (torch.randn(x_shape)).numpy()

            # NOTE(review): sgx_layer / tar_layer are never read.
            sgx_layer = 2
            tar_layer = 2

            
            image_padded  = pad_zero(image)


            
            blinded_image = sgxutils.slalom_blind_input(image_padded)
            #print((blinded_image[0] - image[0]).sum())
            #print((blinded_image[1] - image[1]).sum())
            #print((blinded_image[3] - image[2]).sum())
            #print((blinded_image[4] - image[3]).sum())
 
            #sys.exit(0)
            # Layer-index tables; in the reachable code they are only used by
            # commented-out / string-disabled statements.
            diff_array = []
            index       = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15,16,17,18,19,20,21,22]
            check_layer = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,14,16,18,20,22,24,26,29,31,33]
            target_layer= [0, 1, 2, 3, 5, 6, 7, 8,10,11,12,13,14,17,19,21,24,26,28,31,34,36,37]
            layer_idx = 22

            # loss functions
            # y_sgx uses the padded (1.5x) batch to match model_sgx.
            y_tar = tf.placeholder(shape=(batch_size, 1000), dtype='float32')
            y_sgx = tf.placeholder(shape=(int(batch_size / 2 * 3), 1000), dtype='float32')
            loss_tar_fn = MSE(y_tar, modelq.output)
            loss_sgx_fn = MSE(y_sgx, model_sgx.output)

            # Gradients of each loss w.r.t. that model's trainable weights;
            # these are the tensors fetched in the timing loop below.
            trainable_weights_tar = modelq.trainable_weights
            trainable_weights_sgx = model_sgx.trainable_weights
            gradients_tar = backend.gradients(loss_tar_fn, trainable_weights_tar) 
            gradients_sgx = backend.gradients(loss_sgx_fn, trainable_weights_sgx)



            # NOTE(review): the gradient tensors built from here down to
            # gradients_input_sgx are never fetched before sys.exit(0).
            out_grad_tar = backend.gradients(loss_tar_fn, modelq.layers[36].output)
            out_grad_sgx = backend.gradients(loss_sgx_fn, model_sgx.layers[31].output)


            # Random regression targets, zero-padded for the SGX model.
            gradout_tar = (10000. * torch.randn((batch_size, 1000))).numpy()
            gradout_sgx = pad_zero(gradout_tar)


            act_input_tar = get_activation_input(modelq)
            act_input_sgx = get_activation_input(model_sgx)


            gradients_last_tar  = backend.gradients(loss_tar_fn, modelq.layers[37].output)
            gradients_last_sgx  = backend.gradients(loss_sgx_fn, model_sgx.layers[32].output)
            gradients_input_tar = backend.gradients(loss_tar_fn, act_input_tar)
            gradients_input_sgx = backend.gradients(loss_sgx_fn, act_input_sgx)

            
            
            diff_array = []
            with tf.device(device):
                # Full tracing is requested so a timeline can be written below.
                run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()

                #sgx_grad = sess.run(gradients_sgx,
                #                    feed_dict={model_sgx.input:blinded_image, y_sgx:gradout_sgx})
                #tar_grad = sess.run(gradients_tar,
                #                    feed_dict={modelq.input:image, y_tar:gradout_tar})
                # Three timed runs of the SGX-model weight gradients.
                for _ in range(3):
                    s = time.time()
                    #sgx_res = sess.run(model_sgx.layers[check_layer[layer_idx]].output,
                    #                   feed_dict={model_sgx.input:blinded_image})
                    sgx_grad = sess.run(gradients_sgx,
                                       feed_dict={model_sgx.input:blinded_image, y_sgx:gradout_sgx},
                                       options=run_options, 
                                       run_metadata=run_metadata)
                    end = time.time()
                    total_time = end - s
                    # NOTE(review): this is total time divided by batch_size,
                    # although it is printed as "Time per batch".
                    batch_time = total_time / batch_size
                    print("================================")
                    print("Total compute time ", total_time)
                    print("Time per batch ", batch_time)
                # Reference gradients on the unpadded input (not timed).
                tar_grad = sess.run(gradients_tar,
                                    feed_dict={modelq.input:image, y_tar:gradout_tar})
                 
                # The same run_metadata object is passed to all three timed
                # runs; presumably it holds the last run's trace -- TODO confirm.
                tl = timeline.Timeline(run_metadata.step_stats)
                ctf = tl.generate_chrome_trace_format()
                with open('timeline.json', 'w') as f:
                    f.write(ctf)

                #print(len(sgx_grad), len(tar_grad))
                #for i in range(len(tar_grad)):
                    #diff0 = abs(sgx_grad[i] - tar_grad[i]).mean()
                    #diff1 = abs(sgx_grad[i][1] - tar_grad[i][1]).mean()
                    
                    #diff2 = abs(sgx_grad[i][3] - tar_grad[i][2]).mean()
                    #diff3 = abs(sgx_grad[i][4] - tar_grad[i][3]).mean()
                
                    #diff_array.append((diff0, diff1, diff2, diff3))
                    #diff_array.append(diff0)
                
                sgxutils.print_time_report()
                #print(diff_array)
                #print(sgx_grad[11].shape)
                #print(act_input_sgx[11])
                #sgxc= sgx_grad[11].flatten(-1)
                #tarc= tar_grad[11].flatten(-1)
                #print(len(sgxc))
                #for idx in range(len(sgxc)):
                #    diff = sgxc[idx] - tarc[idx]
                #    if (diff != 0.0):
                #        print(idx, sgxc[idx], tarc[idx])
                #   if idx > 100:
                #        break
                #        
                #for i in range(length):
                #    if sgx_b[i] - tar_b[i] != 0.0:
                #        print(i, sgx_b[i], tar_b[i])
                # NOTE(review): unconditional exit. Nothing below this line in
                # main() runs, including coord.request_stop()/join() and
                # sgxutils.destroy() at the end of the function.
                sys.exit(0)
                
            # The bare string literal below is disabled experiment code (a
            # layer-output comparison); it is evaluated as a no-op expression.
            '''
            print("here")
            with tf.device(device):
                sgx_res = sess.run(model_sgx.layers[check_layer[layer_idx]].output,
                                   feed_dict={model_sgx.input:blinded_image})
                tar_res = sess.run(modelq.layers[target_layer[layer_idx]].output,
                                   feed_dict={modelq.input:image})

                if layer_idx == len(index) - 1:
                    cleared_sgx_res = sgx_res
                else:
                    cleared_sgx_res = unblind(sgx_res, um)
                ite = int(batch_size / 2)
                diff = None
                for idx in range(ite):
                    diffl = abs(tar_res[idx*3] - cleared_sgx_res[idx*3])
                    diffl += abs(tar_res[idx*3+1] - cleared_sgx_res[idx*3+1])
                    diffl /= 2.
                    basel = tar_res[idx*3] + tar_res[idx*3+1]
                    
                    if diff is None:
                        diff = diffl
                    else:
                        diff += diffl

                diff_array.append((diff.mean(), diff.std()))
                print(model_sgx.layers[check_layer[layer_idx]])
                print(modelq.layers[target_layer[layer_idx]])
                print(diff_array)

                
            
            
            sys.exit(0)

            '''


            
            # NOTE(review): from here on the code is unreachable (see the
            # sys.exit(0) above). This second `unblind` would also replace the
            # 2-argument version defined earlier in this function.
            def unblind(mixture, a1, a2, b1, b2):
                """Undo the 2x2 mixing applied by ``blind`` on row pairs.

                For each consecutive pair ``(x, y)`` of rows in ``mixture``
                this solves the linear system with coefficients
                ``a1, a2, b1, b2`` and concatenates the recovered pairs.
                Returns ``None`` when ``mixture`` has fewer than 2 rows.
                """
                size = mixture.shape[0] >> 1
                res = None
                for i in range(size):
                    x = mixture[2*i]
                    y = mixture[2*i+1]
                    
                    xc = (b2*x-a2*y) / (a1*b2-b1*a2)
                    yc = (b1*x-a1*y) / (a2*b1-a1*b2)
                    if res is None:
                        res = np.array([xc, yc])
                    else:
                        res = np.concatenate((res, np.array([xc, yc])), axis=0)

                return res

            def blind(mixture, a1, a2, b1, b2):
                """Mix each consecutive row pair ``(x, y)`` of ``mixture``.

                Produces ``x*a1 + y*a2`` and ``x*b1 + y*b2`` for every pair
                and concatenates the results. Returns ``None`` when
                ``mixture`` has fewer than 2 rows.
                """
                size=mixture.shape[0] >> 1
                res = None

                for i in range(size):
                    x = mixture[2*i]
                    y = mixture[2*i+1]
                    
                    x_b = x*a1 + y*a2
                    y_b = x*b1 + y*b2
                    if res is None:
                        res = np.array([x_b, y_b])
                    else:
                        res=np.concatenate((res, np.array([x_b, y_b])), axis=0)
                return res

            
            # Same tables as above except the last check_layer entry (32
            # here, 33 in the earlier definition).
            check_layer = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,14,16,18,20,22,24,26,29,31,32]
            target_layer= [0, 1, 2, 3, 5, 6, 7, 8,10,11,12,13,14,17,19,21,24,26,28,31,34,36,37]

            #check_layer =[32]
            #target_layer=[38]
            #test_layer = [(True, None)]
            max_idx     = [0, 0, 0, 3, 0, 6, 0, 0,10, 0, 0,14, 0, 0,18, 0, 0, 0]
            test_layer  = [(False, None),  (False, None),
                           (True, "relu"), (False, None),
                           (True, "max"),  (False, None),
                           (True, "relu"), (False, None),
                           (True, "max"),  (False, None),
                           (True, "relu"), (False, None),
                           (True, "relu"),
                           (True, "max"),  (True, "relu"),
                           (True, "relu"), (True, "max"),
                           (True, "relu"), (True, "relu"),
                           (True, "max"),  (True, "relu"),
                           (True, "relu"), (True, "None")]
            assert(len(check_layer) == len(test_layer))

            # loss functions
            # NOTE(review): this y_sgx has batch dimension `batch_size`, while
            # the earlier y_sgx (and model_sgx itself) use the 1.5x padded
            # batch -- confirm which one is intended.
            y_tar = tf.placeholder(shape=(batch_size, 1000), dtype='float32')
            y_sgx = tf.placeholder(shape=(batch_size, 1000), dtype='float32')
            loss_tar_fn = MSE(y_tar, modelq.output)
            loss_sgx_fn = MSE(y_sgx, model_sgx.output)
            
            trainable_weights_tar = modelq.trainable_weights
            trainable_weights_sgx = model_sgx.trainable_weights
            #gradients_tar = backend.gradients(loss_tar_fn, trainable_weights_tar)
            #gradients_sgx = backend.gradients(loss_sgx_fn, trainable_weights_sgx)

            act_input_tar = get_activation_input(modelq)
            act_input_sgx = get_activation_input(model_sgx)

            
            gradients_last_tar  = backend.gradients(loss_tar_fn, modelq.layers[37].output)
            gradients_last_sgx  = backend.gradients(loss_sgx_fn, model_sgx.layers[32].output)
            gradients_input_tar = backend.gradients(loss_tar_fn, act_input_tar)
            gradients_input_sgx = backend.gradients(loss_sgx_fn, act_input_sgx)
            print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
            for i in range(len(gradients_input_tar)):
                print(gradients_input_tar[i])
                print(gradients_input_sgx[i])
            print(len(gradients_input_tar), len(gradients_input_sgx))
            
            
            for i in range(num_batches):
                images, true_labels = sess.run([dataset_images, labels])
                image_original = copy.deepcopy(images[0])
                original_image0 = copy.deepcopy(images[0])
                original_image1 = copy.deepcopy(images[1])
                if quantize:
                    # NOTE(review): `model_info` is not defined anywhere in
                    # the visible file (only model_info_sgx / model_info_q
                    # exist); this would raise NameError if reached.
                    images = np.round(2 ** model_info['bits_x'] * images).astype(np.float32)
                    print("bits_x: {}, input images: {}".format(model_info['bits_x'], np.sum(np.abs(images))))

                print("before blinded", images[0][0][0])

                images_blinded = sgxutils.slalom_blind_input(images)
                #res.start_timer()
                res_sgx.start_timer()
                                
                diff_array = []
                with tf.device(device):
                    prev_layer = 0
                    prev_tar_layer=0
                    prev_value     = images_blinded
                    prev_value_sgx = images_blinded
                    bp = 1
                    # NOTE(review): overwrites the batch loop variable `i`.
                    i = 0

                    # One-hot encoding of the true labels.
                    true_res = np.zeros((batch_size, 1000), dtype=np.float32)
                    for b in range(batch_size):
                        true_res[b][true_labels[b]] = 1.
                    print("=========================================================================")
                    #grad_sgx = sess.run(gradients_sgx, feed_dict={model_sgx.input:images, y_sgx:true_res})
                    
                    #for c in range(len(gradients_input_tar)):
                    print(type(gradients_input_tar))

                    #for bias_tensor in bias_list:
                    #    gradients_input_sgx.append(bias_tensor)
                    
                    #print(len(gradients_input_sgx), len(gradients_input_tar), len(bias_list))
                    #sys.exit(0)
                    #y_tar = sess.run(modelq.output, feed_dict={modelq.input:images})
                    #y_sgx = sess.run(model_sgx.output, feed_dict={model_sgx.input:images})
                    
                    #print((y_tar - y_sgx).sum())
                    #sys.exit(0)
                    print(len(gradients_last_tar))
                    print("gergesdsf")
                    # NOTE(review): these appends mutate lists created outside
                    # the batch loop, so the fetch lists would grow by three
                    # entries on every iteration.
                    gradients_input_tar.append(gradients_last_tar[0])
                    gradients_input_sgx.append(gradients_last_sgx[0])
                    gradients_input_tar.append(modelq.layers[37].input)
                    gradients_input_sgx.append(model_sgx.layers[32].input)
                    gradients_input_tar.append(modelq.output)
                    gradients_input_sgx.append(model_sgx.output)
                    grad_tar = sess.run(gradients_input_tar, feed_dict={modelq.input:images,    y_tar:true_res})
                    grad_sgx = sess.run(gradients_input_sgx, feed_dict={model_sgx.input:images, y_sgx:true_res})
                    
                    # Per fetched tensor: number of elements that differ
                    # between the two models.
                    for idx in range(len(grad_tar)):
                        diff = np.abs((grad_sgx[idx] - grad_tar[idx]) != 0.0).sum()
                        diff_array.append(diff)
                    
                print("diff array content")
                print(diff_array)
                # NOTE(review): another unconditional exit; the rest of the
                # loop body could not run even if this point were reached.
                sys.exit(0)
                for i in range(15):
                    print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
                    print(grad_sgx[i].shape, grad_sgx[29-i].shape)
                    print("means", grad_sgx[i].mean())
                    print(grad_sgx[i].sum(), grad_sgx[29-i].sum())
                    print(grad_sgx[i].sum() - grad_sgx[29-i].sum())
                sys.exit(0)
                # NOTE(review): `cleared_sgx` is not defined anywhere in the
                # visible file.
                preds_sgx = np.reshape(cleared_sgx, (batch_size, -1))
                preds = np.reshape(prev_value, (batch_size, -1))
                #res.end_timer()
                res_sgx.end_timer()
                res_sgx.record_acc(preds_sgx, true_labels)
                #res.record_acc(preds, true_labels)
            #res.print_results()
            res_sgx.print_results()
            sys.stdout.flush()
            coord.request_stop()
            coord.join(threads)

        # NOTE(review): not inside a `finally`, so this cleanup is skipped
        # whenever one of the sys.exit(0) calls above fires.
        if sgxutils is not None:
            sgxutils.destroy()

           
            
if __name__ == '__main__':
    import argparse

    # Command-line interface. The parsed namespace is stored in the
    # module-level name `args`, which main() reads directly.
    parser = argparse.ArgumentParser()

    model_choices = [
        'vgg_16', 'vgg_19', 'inception_v3', 'mobilenet', 'mobilenet_sep',
        'resnet_18', 'resnet_34', 'resnet_50', 'resnet_101', 'resnet_152',
    ]
    parser.add_argument('model_name', type=str, choices=model_choices)

    # Options that take a value.
    parser.add_argument(
        '--input_dir',
        type=str,
        default='../imagenet/',
        help='Input directory with images.',
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=8,
        help='How many images process at one time.',
    )
    parser.add_argument(
        '--max_num_batches',
        type=int,
        default=2,
        help='Max number of batches to evaluate.',
    )

    # Boolean switches, registered in their original order. A help text of
    # None is argparse's default, i.e. the same as not passing `help`.
    boolean_switches = (
        ('--use_sgx', None),
        ('--no_slalom', 'only test GPU quantization'),
        ('--blinding', 'add random blinding for privacy'),
        ('--integrity', 'add integrity checks'),
        ('--simulate', None),
    )
    for switch_name, switch_help in boolean_switches:
        parser.add_argument(switch_name, action='store_true', help=switch_help)

    args = parser.parse_args()

    tf.app.run()
