#define USE_EIGEN_TENSOR

#ifndef USE_SGX
#define EIGEN_USE_THREADS
#else
#include "Enclave.h"
#include "sgx_trts.h"
#endif

#include "sgxdnn_main.hpp"
#include "layers/eigen_maxpool.h"
#include "layers/batchnorm.hpp"
#include "layers/conv2d_pure.hpp"
#include "layers/dense_inner.hpp"
#include "randpool.hpp"
#include "utils.hpp"
#include "benchmark.hpp"

#include <unsupported/Eigen/CXX11/Tensor>
#include "model.hpp"
#include <iostream>
#include <memory>
#include <chrono>
#include <string>
#include <cstring>
#include <deque>
#include <math.h>

#include "Crypto.h"

using namespace SGXDNN;

// prime P chosen for data blinding. Chosen such that P + P/2 < 2^24
int p_int = (1 << 23) + (1 << 21) + 7;
float p = (float) p_int;
float mid = (float) (p_int / 2);
// Global working buffers. They are only declared here; where they are
// allocated and who owns them is not visible in this part of the file.
float* bias_grads;
float* temp_buffer;
float* temp_buffer2;
float* temp_buffer3;
float* temp_buffer4;
// Table of accumulated timings; bias_grad() adds its run time to slot 8.
double* time_report;
float* bm_buf;
// Passed to BatchNormSp::fwd by batchnormSp_dark().
float* um_buf;
float* gm_buf;
float* igm_buf;
// prime used for Freivalds checks. Largest prime smaller than 2^24
int p_verif = ((1 << 24) - 3);
double inv_p_verif = 1.0 / p_verif;

// some vectorized constants
// Each one broadcasts a scalar into all 8 lanes of an AVX register.
__m256 p8f = _mm256_set1_ps(p);
__m256 mid8f = _mm256_set1_ps(mid);
__m256 negmid8f = _mm256_set1_ps(-mid);
__m256 zero8f = _mm256_set1_ps((float)(0));
__m256 inv_shift8f = _mm256_set1_ps((float)(1.0/256));
// NOTE(review): presumably the ReLU6 upper bound in a 256*256 fixed-point
// scale -- confirm against the activation code.
__m256 six8f = _mm256_set1_ps((float) 6 * 256 * 256);

// Ad-hoc timing accumulators. Of these, only elapsed5 is written in the
// visible part of the file (by unblind()).
double elapsed = 0;
double elapsed2 = 0;
double elapsed3 = 0;
double elapsed4 = 0;
double elapsed5 = 0;
double elapsed6 = 0;
double elapsed7 = 0;

double elapsed100 = 0;
double elapsed101 = 0;
double elapsed102 = 0;
double elapsed103 = 0;
double elapsed104 = 0;
double elapsed105 = 0;


// Coefficients of two 2x2 linear maps applied to the two halves of a buffer.
// unblind_internal()/unblind() compute
//   out_lo = alfa1p * in_lo + alfa2p * in_hi
//   out_hi = beta1p * in_lo + beta2p * in_hi
// so the defaults below (1, 0, 0, 1) are the identity map. The non-"p"
// variants are not used in the visible part of the file.
float alfa1_raw = 1;
float alfa2_raw = 0;
float beta1_raw = 0;
float beta2_raw = 1;
float alfa1p_raw = 1;
float alfa2p_raw = 0;
float beta1p_raw = 0;
float beta2p_raw = 1;
// NOTE(review): purpose not visible here; the name also coincides with the
// C library function time().
bool time;
// 8-lane broadcasts of the scalar coefficients above.
__m256 alfa1 = _mm256_set1_ps((float)(alfa1_raw));
__m256 alfa2 = _mm256_set1_ps((float)(alfa2_raw));
__m256 beta1 = _mm256_set1_ps((float)(beta1_raw));
__m256 beta2 = _mm256_set1_ps((float)(beta2_raw));

__m256 alfa1p = _mm256_set1_ps((float)(alfa1p_raw));
__m256 alfa2p = _mm256_set1_ps((float)(alfa2p_raw));
__m256 beta1p = _mm256_set1_ps((float)(beta1p_raw));
__m256 beta2p = _mm256_set1_ps((float)(beta2p_raw));


#include "../Enclave/batchnorm_sp.h"
#include "../Enclave/depthwise_sp.h"
#include "../Enclave/inverted_sp.h"
#include "../Enclave/linear_sp.h"
#include "layers/batchnorm.hpp"
#include <string>
#include "immintrin.h"

float mean(float* inp, int size) {
  float means[16];
  float* mean_aligned = ALIGN32(means);
  __m256 mean_vec = _mm256_set1_ps((float)(0));

  for (int j = 0; j < size; j+=8) {
    __m256 num = _mm256_load_ps(&inp[j]);
    mean_vec = _mm256_add_ps(num, mean_vec);
  }
  _mm256_stream_ps(mean_aligned, mean_vec);
  float res = 0.0;
  for (int i = 0; i < 8; i++)
    res += mean_aligned[i];
  res /= ((float) size);
  return res;
} 

// Builds a batch-norm layer wrapper.
//
// size     : four extents; the oneDNN descriptor below is built from
//            {size[0], size[3], size[1], size[2]} with the nhwc format tag.
// mode     : 0 -> "bn", 1 -> "bnrelu", 2 -> "bnadd"; any other value leaves
//            the mode string empty.
// eps, momentum : forwarded to BatchNormSp::update_params (first argument
//            fixed to false).
// src_tag  : descriptor of the incoming tensor; when it differs from the
//            internal nhwc descriptor a reorder primitive is prepared.
// eng, s   : oneDNN engine and stream used for reorders.
//
// Also asks the host (ocall_extern_alloc) for the buffer that
// forward_special() dumps the input into, plus a second one for the skip
// activation when mode == 2.
BatchNormSP::BatchNormSP(int size[4], int mode, float eps, float momentum, memory::desc src_tag, engine eng, stream s) {
    mode_num_ = mode;

    const char* mode_char = "";
    switch (mode) {
        case 0:  mode_char = "bn";     break;
        case 1:  mode_char = "bnrelu"; break;
        case 2:  mode_char = "bnadd";  break;
        default: break;
    }
    printf("mode %d, %f, %f", mode, eps, momentum);
    mode_ = std::string(mode_char);

    std::string name  = std::string("name");
    array4d dim       = {size[0], size[1], size[2], size[3]};
    batchnorm_        = (void*)new SGXDNN::BatchNormSp<float>(name, dim);
    ((SGXDNN::BatchNormSp<float>*) batchnorm_)->update_params(false, eps, momentum);

    output_size_      = size[0] * size[1] * size[2] * size[3];
    input_size_       = output_size_;
    src_tag_          = src_tag;
    eng_              = eng;
    s_                = s;
    nhwc_ = memory::desc({{size[0], size[3], size[1], size[2]}}, memory::data_type::f32, memory::format_tag::nhwc);

    // Assign the flag on both outcomes so forward_special() never depends on
    // an initializer that is declared elsewhere.
    input_reorder_ = (nhwc_ != src_tag_);
    if (input_reorder_) {
        primitive_attr dummy_attr;
        primitive reorder_layer = reorder(reorder::primitive_desc(eng_, src_tag_, eng_, nhwc_, dummy_attr));
        net_fwd.push_back(reorder_layer);
    }

    // NOTE(review): get_size() already reports bytes, so the request is larger
    // than the plain tensor. Left as is in case the slack is relied upon.
    ocall_extern_alloc((void**) &src_dump_, sizeof(float)*src_tag_.get_size());
    if (mode == 2)
        ocall_extern_alloc((void**) &act_src_dump_, sizeof(float)*src_tag_.get_size());
}

// Accumulated forward / backward wall time. Despite the names, the visible
// writers are LinearSp and DepthwiseConv2dSp forward()/backward().
double batchnorm_fwd_time = 0.0;
double batchnorm_bwd_time = 0.0;

// Forward pass with an explicit skip input.
//
// src      : layer input in the layout described by src_tag_.
// dst      : output buffer handed to BatchNormSp::fwd.
// skip     : skip-connection input handed to BatchNormSp::fwd.
// is_train : currently unused.
//
// When `sharding` is set the (possibly reordered) input, and for mode 2 the
// activation source, are dumped through dump_src_input() and the enclave
// copies are released. Otherwise the buffers stay referenced through
// saved_src_ / act_src_ for the backward pass.
void BatchNormSP::forward_special (float* src, float* dst, float* skip, bool is_train) {
  float* batchnorm_src = src;
  float* batchnorm_src_ptr = nullptr;
  if (input_reorder_) {
    // +8 floats of slack so ALIGN32 can round the pointer up.
    batchnorm_src_ptr = new float[this->output_size_+8];
    batchnorm_src     = ALIGN32(batchnorm_src_ptr);
    memory src_1 = memory(src_tag_, eng_, src);
    memory dst_1 = memory(nhwc_, eng_, batchnorm_src);

    net_fwd.at(0).execute(s_, {{DNNL_ARG_FROM, src_1},
                      {DNNL_ARG_TO, dst_1}});
    s_.wait();
  }

  if (mode_num_ == 2) {
    act_src_ptr_ = new float[this->output_size_+8];
    act_src_     = ALIGN32(act_src_ptr_);
  }

  // Dump the input for the backward pass, or keep a direct reference to it.
  if (sharding) {
    this->dump_src_input((void*) src_dump_,
                         batchnorm_src,
                         sizeof(float) * this->output_size_);
    saved_src_ = nullptr;
  } else {
    saved_src_ = batchnorm_src;
  }

  // Mean over the whole tensor.
  float means = mean(batchnorm_src, this->output_size_);

  ((SGXDNN::BatchNormSp<float>*) batchnorm_)->fwd(dst, batchnorm_src, &means, skip, act_src_, nullptr, 1, 1, mode_.c_str());

  // The reordered copy is only needed again when it was not dumped.
  if (sharding && batchnorm_src_ptr != nullptr)
    delete [] batchnorm_src_ptr;

  if (mode_num_ == 2 && sharding && act_src_ptr_ != nullptr) {
    this->dump_src_input((void*) act_src_dump_,
                         act_src_,
                         sizeof(float) * this->output_size_);
    delete [] act_src_ptr_;
    // Do not leave the members pointing at freed memory; backward_special()
    // reloads the activation from act_src_dump_ when sharding.
    act_src_ptr_ = nullptr;
    act_src_     = nullptr;
  }
}


void BatchNormSP::forward(float* src, float* dst, bool is_train) {
    forward_special(src, dst, src, is_train);
}

// Records the descriptor of the incoming gradient and, when it differs from
// the internal nhwc descriptor, prepares the reorder used by the backward pass.
void BatchNormSP::update_backward(memory::desc diff_src_tag) {
    diff_src_tag_ = diff_src_tag;

    const bool needs_reorder = (nhwc_ != diff_src_tag);
    if (!needs_reorder)
        return;

    primitive_attr dummy_attr;
    grad_reorder_ = true;
    primitive to_nhwc = reorder(reorder::primitive_desc(eng_, diff_src_tag, eng_, nhwc_, dummy_attr));
    net_bwd.push_back(to_nhwc);
}

// Backward pass.
//
// diff_src_ptr  : receives the gradient with respect to the layer input.
// diff_dst_ptr  : incoming gradient, laid out as diff_src_tag_.
// diff_skip_ptr : currently unused.
//
// When `sharding` is set, the tensors dumped by forward_special() are loaded
// back first. load_src_input() consumes IV/MAC entries in FIFO order, so the
// input must be loaded before the skip activation, matching the dump order.
void BatchNormSP::backward_special(float* diff_src_ptr, float* diff_dst_ptr, float* diff_skip_ptr) {
    float* batchnorm_diff_dst_ptr = nullptr;
    float* batchnorm_diff_dst     = diff_dst_ptr;
    if (grad_reorder_) {
        // +8 floats of slack so ALIGN32 can round the pointer up.
        batchnorm_diff_dst_ptr = new float[this->output_size_+8];
        batchnorm_diff_dst     = ALIGN32(batchnorm_diff_dst_ptr);
        net_bwd.at(0).execute(s_, {{DNNL_ARG_FROM, memory(diff_src_tag_, eng_, diff_dst_ptr)},
                                   {DNNL_ARG_TO,   memory(nhwc_,     eng_,     batchnorm_diff_dst)}});
        // Same synchronisation as in the forward pass: the reordered gradient
        // is read right below.
        s_.wait();
    }

    // Load the dumped sources.
    float* saved_src_ptr = nullptr;
    float* saved_act_ptr = nullptr;
    if (sharding) {
        // The buffers hold input_size_ floats; the previous element count was
        // multiplied by sizeof(float) a second time.
        saved_src_ptr    = new float[this->input_size_ + 8];
        this->saved_src_ = ALIGN32(saved_src_ptr);
        this->load_src_input((void*) this->saved_src_, (void*) this->src_dump_,
                             sizeof(float) * this->input_size_);

        if (mode_num_ == 2) {
            saved_act_ptr  = new float[this->input_size_ + 8];
            this->act_src_ = ALIGN32(saved_act_ptr);
            // Decrypt into act_src_; the previous code overwrote saved_src_
            // and left act_src_ uninitialised.
            this->load_src_input((void*) this->act_src_, (void*) this->act_src_dump_,
                                 sizeof(float) * this->input_size_);
        }
    }
    if (mode_num_ != 2) {
        act_src_ = batchnorm_diff_dst;
    }

    ((SGXDNN::BatchNormSp<float>*) batchnorm_)->bwd(batchnorm_diff_dst,
                                                    diff_src_ptr,
                                                    saved_src_,
                                                    nullptr,
                                                    act_src_
                                                    );

    // delete[] on a null pointer is a no-op, so no guards are needed.
    delete [] batchnorm_diff_dst_ptr;
    delete [] saved_src_ptr;
    delete [] saved_act_ptr;
}

void BatchNormSP::backward(float* diff_src_ptr, float* diff_dst_ptr) {
    backward_special(diff_src_ptr, diff_dst_ptr, nullptr);
}


int BatchNormSP::out_size() {
  return output_size_;
}

int BatchNormSP::input_size() {
  return input_size_;
}

// Descriptor reported for the layer output: the internal nhwc descriptor.
memory::desc BatchNormSP::dst_desc() {
    return this->nhwc_;
}

// Descriptor reported for the input gradient: the internal nhwc descriptor.
memory::desc BatchNormSP::diff_src_desc() {
    return this->nhwc_;
}

// Profiling counters for LayerSp::dump_src_input / LayerSp::load_src_input:
// accumulated elapsed time and number of calls.
double dump_time = 0.0;
double load_time = 0.0;
int    dump_num=0;
int    load_num=0;
void LayerSp::dump_src_input(void* dst, void* src, int num_bytes) {
    double start = get_time_force();
    sgx_aes_gcm_128bit_iv_t  *iv  = (sgx_aes_gcm_128bit_iv_t*)new sgx_aes_gcm_128bit_iv_t;
    sgx_aes_gcm_128bit_tag_t *mac = (sgx_aes_gcm_128bit_tag_t*)new sgx_aes_gcm_128bit_tag_t;
    iv_ptr.push_back((void*) iv);
    mac_ptr.push_back((void*) mac);

    encrypt((uint8_t*) src, num_bytes, (uint8_t*) dst, iv, mac);
    double end = get_time_force();
    dump_time += get_elapsed_time(start, end);
    dump_num++;
}

void LayerSp::load_src_input(void* dst, void* src, int num_bytes) {

    double start = get_time_force();

    sgx_aes_gcm_128bit_iv_t  *iv   = (sgx_aes_gcm_128bit_iv_t*)  iv_ptr.at(0);
    sgx_aes_gcm_128bit_tag_t *mac  = (sgx_aes_gcm_128bit_tag_t*) mac_ptr.at(0);
    
    decrypt((uint8_t*) src, num_bytes, (uint8_t*) dst, iv, mac, (uint8_t*) src);
    iv_ptr.erase(iv_ptr.begin(), iv_ptr.begin()+1);
    mac_ptr.erase(mac_ptr.begin(), mac_ptr.begin()+1);
    double end = get_time_force();

    load_time += get_elapsed_time(start, end);
    load_num++;
}   

// Shared memory pool, created on first use by the LinearSp and
// DepthwiseConv2dSp constructors and never released in the visible code.
MemPool* mem_pool = nullptr;


// Builds a fully connected layer wrapper around DenseInner<float>.
//
// size        : four input extents; the input feature count is
//               size[1] * size[2] * size[3].
// out_size    : out_size[1] is the output feature count.
// kernel_data, bias_data : parameters handed to DenseInner.
// src_tags    : descriptor of the incoming tensor; a reorder is prepared when
//               it differs from the internal nhwc descriptor.
// eng, s      : oneDNN engine and stream used for reorders.
//
// Creates the shared mem_pool on first use and asks the host for the buffer
// that forward() dumps the input into.
LinearSp::LinearSp(int size[4], int out_size[2],
                   float* kernel_data,
                   float* bias_data,
                   memory::desc src_tags,
                   engine eng, stream s) {
    if (mem_pool == nullptr) {
        mem_pool = new MemPool(1, 1);
    }

    int h_in = size[1] * size[2] * size[3];
    int h_out= out_size[1];
    const array4d input_shape = {size[0], size[1], size[2], size[3]};
    linear_ = (void*) new DenseInner<float>("DenseInner",
                                        input_shape,
                                        h_in,
                                        h_out,
                                        kernel_data,
                                        bias_data,
                                        mem_pool,
                                        false,
                                        false
    );
    this->output_size_ = h_out;
    this->input_size_  = h_in;
    eng_ = eng;
    s_   = s;
    nhwc_ = memory::desc({{size[0], size[3], size[1], size[2]}}, memory::data_type::f32, memory::format_tag::nhwc);
    printf("%d %d %d %d", size[0], size[1], size[2], size[3]);
    nhwc_out_ = memory::desc({{1, h_out, 1, 1}}, memory::data_type::f32, memory::format_tag::nhwc);
    src_tag_ = src_tags;
    ocall_extern_alloc((void**) &src_dump_, sizeof(float)*src_tag_.get_size());
    printf("before reorder");

    // Assign the flag on both outcomes so forward() never depends on an
    // initializer that is declared elsewhere.
    input_reorder_ = (nhwc_ != src_tag_);
    if (input_reorder_) {
        primitive_attr dummy_attr;
        printf("before primitive");
        // get_size() returns size_t; convert explicitly so the arguments
        // match the %d conversions.
        printf("%d %d", (int) src_tag_.get_size(), (int) nhwc_.get_size());
        primitive reorder_layer = reorder(reorder::primitive_desc(eng_, src_tag_, eng_, nhwc_, dummy_attr));
        printf("after primitive");

        net_fwd.push_back(reorder_layer);
    }
    printf("after reorder");
}

// Forward pass of the fully connected layer.
//
// src      : layer input in the layout described by src_tag_.
// dst      : output buffer handed to DenseInner::fwd.
// is_train : not used here.
//
// The run time is added to batchnorm_fwd_time.
void LinearSp::forward(float* src, float* dst, bool is_train) {
    double s = get_time_force();

    // Bring the input into the nhwc layout if required.
    float* reorder_buf = nullptr;
    float* layer_in    = src;
    if (input_reorder_) {
        reorder_buf = new float[this->input_size_+8];
        layer_in    = ALIGN32(reorder_buf);
        net_fwd.at(0).execute(s_, {{DNNL_ARG_FROM, memory(src_tag_, eng_, src)},
                                   {DNNL_ARG_TO,   memory(nhwc_, eng_, layer_in)}});
        s_.wait();
    }

    // Either dump the input for the backward pass or keep a reference to it.
    // layer_in equals src whenever no reorder took place.
    if (!sharding) {
        this->saved_src_ = layer_in;
    } else {
        this->dump_src_input((void*) src_dump_,
                             layer_in,
                             sizeof(float) * this->input_size_);
        this->saved_src_ = nullptr;
    }

    ((SGXDNN::DenseInner<float>*) linear_)->fwd(layer_in, dst);

    // The reordered copy is only needed again when it was not dumped.
    if (sharding && reorder_buf != nullptr)
        delete [] reorder_buf;

    double e = get_time_force();
    batchnorm_fwd_time += get_elapsed_time(s, e);
}

// Records the descriptor of the incoming gradient and, when it differs from
// the layer's output descriptor, prepares the reorder used by backward().
void LinearSp::update_backward(memory::desc diff_src_tag) {
    diff_src_tag_ = diff_src_tag;

    const bool needs_reorder = (nhwc_out_ != diff_src_tag);
    if (!needs_reorder)
        return;

    primitive_attr dummy_attr;
    grad_reorder_ = true;
    primitive to_nhwc = reorder(reorder::primitive_desc(eng_, diff_src_tag, eng_, nhwc_out_, dummy_attr));
    net_bwd.push_back(to_nhwc);
}

// Backward pass of the fully connected layer.
//
// diff_src_ptr : receives the gradient with respect to the layer input.
// diff_dst_ptr : incoming gradient, laid out as diff_src_tag_.
//
// When `sharding` is set, the input dumped by forward() is loaded back first.
// The run time is added to batchnorm_bwd_time.
void LinearSp::backward(float* diff_src_ptr, float* diff_dst_ptr) {
    double s = get_time_force();

    float* saved_src_ptr = nullptr;
    if (sharding) {
        // The buffer holds input_size_ floats (+8 for ALIGN32); the previous
        // element count was multiplied by sizeof(float) a second time.
        saved_src_ptr    = new float[this->input_size_ + 8];
        this->saved_src_ = ALIGN32(saved_src_ptr);
        this->load_src_input((void*) this->saved_src_, (void*) this->src_dump_,
                             sizeof(float) * this->input_size_);
    }

    // Bring the incoming gradient into the layer's output layout if required.
    float* linear_diff_dst     = diff_dst_ptr;
    float* linear_diff_dst_ptr = nullptr;
    if (grad_reorder_) {
        linear_diff_dst_ptr = new float[this->output_size_+8];
        linear_diff_dst     = ALIGN32(linear_diff_dst_ptr);
        net_bwd.at(0).execute(s_, {{DNNL_ARG_FROM, memory(diff_src_tag_, eng_, diff_dst_ptr)},
                                   {DNNL_ARG_TO,   memory(nhwc_out_,     eng_, linear_diff_dst)}});
        // Same synchronisation as in forward(): the reordered gradient is
        // read right below.
        s_.wait();
    }
    ((SGXDNN::DenseInner<float>*) linear_)->bwd(linear_diff_dst, diff_src_ptr, this->saved_src_);

    // delete[] on a null pointer is a no-op, so no guards are needed.
    delete [] linear_diff_dst_ptr;
    delete [] saved_src_ptr;

    double e = get_time_force();
    batchnorm_bwd_time += get_elapsed_time(s, e);
}

// Builds a depthwise convolution wrapper around SGXDNN::DepthwiseConv2D<float>.
//
// conv_src_sz / conv_dst_sz : four input / output extents; element counts are
//                     their products and the oneDNN descriptors are built
//                     from {sz[0], sz[3], sz[1], sz[2]} with the nhwc tag.
// conv_weight_sz    : kernel shape handed to DepthwiseConv2D.
// conv_strides_sz   : row / column strides.
// conv_padding_sz   : not used; padding is always PADDING_SAME.
// src_tags          : descriptor of the incoming tensor; a reorder is prepared
//                     when it differs from the internal nhwc descriptor.
// eng, s            : oneDNN engine and stream used for reorders.
// weight_data, bias_data : parameters handed to DepthwiseConv2D.
DepthwiseConv2dSp::DepthwiseConv2dSp(
           int conv_src_sz[4], int conv_dst_sz[4], int conv_weight_sz[4],
           int conv_strides_sz[2], int conv_padding_sz[2], memory::desc src_tags,
           engine eng, stream s, float* weight_data, float* bias_data) {
    const array4d input_shape = { conv_src_sz[0], conv_src_sz[1], conv_src_sz[2], conv_src_sz[3]};
    const array4d kernel_shape = {conv_weight_sz[0], conv_weight_sz[1], conv_weight_sz[2], conv_weight_sz[3]};
    if (mem_pool == nullptr) {
        mem_pool = new MemPool(1, 1);
    }
    depthwise_ = (void*) new SGXDNN::DepthwiseConv2D<float>("depthwise_conv",
                                                        input_shape,
                                                        kernel_shape,
                                                        conv_strides_sz[0],
                                                        conv_strides_sz[1],
                                                        Eigen::PaddingType::PADDING_SAME,
                                                        nullptr,
                                                        weight_data,
                                                        bias_data,
                                                        mem_pool,
                                                        false,
                                                        false);

    output_size_      = conv_dst_sz[0] * conv_dst_sz[1] * conv_dst_sz[2] * conv_dst_sz[3];
    input_size_       = conv_src_sz[0] * conv_src_sz[1] * conv_src_sz[2] * conv_src_sz[3];
    src_tag_          = src_tags;
    eng_              = eng;
    s_                = s;

    nhwc_             = memory::desc({{conv_src_sz[0], conv_src_sz[3], conv_src_sz[1], conv_src_sz[2]}}, memory::data_type::f32, memory::format_tag::nhwc);
    nhwc_out_         = memory::desc({{conv_dst_sz[0], conv_dst_sz[3], conv_dst_sz[1], conv_dst_sz[2]}}, memory::data_type::f32, memory::format_tag::nhwc);

    // Assign the flag on both outcomes so forward() never depends on an
    // initializer that is declared elsewhere.
    input_reorder_ = (nhwc_ != src_tag_);
    if (input_reorder_) {
        primitive_attr dummy_attr;
        primitive reorder_layer = reorder(reorder::primitive_desc(eng_, src_tag_, eng_, nhwc_, dummy_attr));
        net_fwd.push_back(reorder_layer);
    }
    ocall_extern_alloc((void**) &src_dump_, sizeof(float)*src_tag_.get_size());
}

// Forward pass of the depthwise convolution.
//
// src      : layer input in the layout described by src_tag_.
// dst      : output buffer handed to DepthwiseConv2D::fwd.
// is_train : not used here.
//
// The run time is added to batchnorm_fwd_time.
void DepthwiseConv2dSp::forward(float* src, float* dst, bool is_train) {
    double s = get_time_force();

    // Bring the input into the nhwc layout if required.
    float* reorder_buf = nullptr;
    float* layer_in    = src;
    if (input_reorder_) {
        reorder_buf = new float[this->input_size_+8];
        layer_in    = ALIGN32(reorder_buf);
        net_fwd.at(0).execute(s_, {{DNNL_ARG_FROM, memory(src_tag_, eng_, src)},
                                   {DNNL_ARG_TO,   memory(nhwc_, eng_, layer_in)}});
        s_.wait();
    }

    // Either dump the input for the backward pass or keep a reference to it.
    // layer_in equals src whenever no reorder took place.
    if (!sharding) {
        this->saved_src_ = layer_in;
    } else {
        this->dump_src_input((void*) src_dump_,
                             layer_in,
                             sizeof(float) * this->input_size_);
        this->saved_src_ = nullptr;
    }

    ((SGXDNN::DepthwiseConv2D<float>*) depthwise_)->fwd(layer_in, dst);

    // The reordered copy is only needed again when it was not dumped.
    if (sharding && reorder_buf != nullptr)
        delete [] reorder_buf;

    double e = get_time_force();
    batchnorm_fwd_time += get_elapsed_time(s, e);
}
    
// Records the descriptor of the incoming gradient and, when it differs from
// the layer's output descriptor, prepares the reorder used by backward().
void DepthwiseConv2dSp::update_backward(memory::desc diff_src_tag) {
    diff_src_tag_ = diff_src_tag;

    const bool needs_reorder = (nhwc_out_ != diff_src_tag);
    if (!needs_reorder)
        return;

    primitive_attr dummy_attr;
    grad_reorder_ = true;
    primitive to_nhwc = reorder(reorder::primitive_desc(eng_, diff_src_tag, eng_, nhwc_out_, dummy_attr));
    net_bwd.push_back(to_nhwc);
}
// Backward pass of the depthwise convolution.
//
// diff_src_ptr : receives the gradient with respect to the layer input.
// diff_dst_ptr : incoming gradient, laid out as diff_src_tag_.
//
// When `sharding` is set, the input dumped by forward() is loaded back first.
// The run time is added to batchnorm_bwd_time.
void DepthwiseConv2dSp::backward(float* diff_src_ptr, float* diff_dst_ptr) {
    double s = get_time_force();

    float* saved_src_ptr = nullptr;
    if (sharding) {
        // The buffer holds input_size_ floats (+8 for ALIGN32); the previous
        // element count was multiplied by sizeof(float) a second time.
        saved_src_ptr    = new float[this->input_size_ + 8];
        this->saved_src_ = ALIGN32(saved_src_ptr);
        this->load_src_input((void*) this->saved_src_, (void*) this->src_dump_,
                             sizeof(float) * this->input_size_);
    }

    // Bring the incoming gradient into the layer's output layout if required.
    float* depthwise_diff_dst     =  diff_dst_ptr;
    float* depthwise_diff_dst_ptr = nullptr;
    if (grad_reorder_) {
        depthwise_diff_dst_ptr = new float[this->output_size_+8];
        depthwise_diff_dst     = ALIGN32(depthwise_diff_dst_ptr);
        net_bwd.at(0).execute(s_, {{DNNL_ARG_FROM, memory(diff_src_tag_, eng_, diff_dst_ptr)},
                                   {DNNL_ARG_TO,   memory(nhwc_out_,     eng_, depthwise_diff_dst)}});
        // Same synchronisation as in forward(): the reordered gradient is
        // read right below.
        s_.wait();
    }
    ((SGXDNN::DepthwiseConv2D<float>*) depthwise_)->bwd(depthwise_diff_dst, diff_src_ptr, this->saved_src_);

    // delete[] on a null pointer is a no-op, so no guards are needed.
    delete [] depthwise_diff_dst_ptr;
    delete [] saved_src_ptr;

    double e = get_time_force();
    batchnorm_bwd_time += get_elapsed_time(s, e);
}

int DepthwiseConv2dSp::out_size() {
    return output_size_;
}

int DepthwiseConv2dSp::input_size() {
    return input_size_;
}

// Descriptor reported for the layer output: the nhwc output descriptor.
memory::desc DepthwiseConv2dSp::dst_desc() {
    return this->nhwc_out_;
}

// Descriptor reported for the input gradient: the nhwc input descriptor.
memory::desc DepthwiseConv2dSp::diff_src_desc() {
    return this->nhwc_;
}

// Forward pass through the cell: every layer but the last is run through
// LayerSp::forward, and the last one (treated as a BatchNormSP) through
// forward_special() with the cell input `src` as skip input.
//
// src      : cell input; never freed here.
// dst      : output buffer of the final batch-norm layer.
// is_train : forwarded to the final layer only; the inner layers are always
//            called with `true`.
//
// NOTE(review): intermediate buffers are released as soon as the next layer
// has consumed them. Layers that keep a pointer to their input (the
// non-sharding path of the layers in this file) would be left with a stale
// saved_src_ -- confirm the cell is only used with sharding enabled.
void InvertedCellSp::forward(float* src, float* dst, bool is_train) {
    int size = this->network_.size();

    // Raw allocation backing the current input; stays null while the current
    // input is still the caller's `src`.
    float* owned_buf = nullptr;
    float* cur_in    = ALIGN32(src);

    for (int i = 0; i < size - 1; i++) {
        LayerSp* l = this->network_.at(i);
        // +8 floats of slack so ALIGN32 can round the pointer up.
        float* out_buf     = new float[l->dst_desc().get_size() / sizeof(float) + 8];
        float* out_aligned = ALIGN32(out_buf);
        l->forward(cur_in, out_aligned, true);

        delete [] owned_buf;   // no-op for the first layer
        owned_buf = out_buf;
        cur_in    = out_aligned;
    }
    BatchNormSP* l = ((BatchNormSP*)this->network_.at(size - 1));

    l->forward_special(cur_in, dst, src, is_train);

    // Only release memory allocated above. With a single-layer cell the
    // previous code called delete[] on the caller's `src`.
    delete [] owned_buf;
}

// Descriptor of the gradient this cell produces for its input: the final
// layer's nhwc descriptor in "bnadd" mode, otherwise whatever the first layer
// reports.
memory::desc InvertedCellSp::diff_src_desc() {
    BatchNormSP* tail_bn = (BatchNormSP*) this->network_.at(this->network_.size() - 1);

    const bool has_skip_add = (tail_bn->mode_ == std::string("bnadd"));
    if (!has_skip_add)
        return ((BatchNormSP*)this->network_.at(0))->diff_src_desc();

    return tail_bn->nhwc_;
}

// Propagates gradient descriptors through the cell from the last layer to the
// first. In "bnadd" mode it also prepares a reorder from the first layer's
// gradient layout to the final layer's nhwc layout when the two differ.
void InvertedCellSp::update_backward(memory::desc diff_src_tag) {
    const int layer_count = network_.size();

    // Each layer is told the layout of the gradient it receives, which is the
    // layout the layer after it produces.
    memory::desc incoming = diff_src_tag;
    for (int idx = layer_count - 1; idx >= 0; --idx) {
        printf("update inv %d", idx);
        LayerSp* layer = network_.at(idx);
        layer->update_backward(incoming);
        incoming = layer->diff_src_desc();
    }

    BatchNormSP* tail_bn = ((BatchNormSP*)this->network_.at(this->network_.size() - 1));
    if (!(tail_bn->mode_ == std::string("bnadd")))
        return;

    LayerSp* head = this->network_.at(0);
    if (tail_bn->nhwc_ != head->diff_src_desc()) {
        grad_out_reorder_ = true;
        primitive_attr dummy_attr;
        reorder_grad_ = reorder(reorder::primitive_desc(eng_, head->diff_src_desc(), eng_, tail_bn->nhwc_, dummy_attr));
    }
}

// Backward pass through the cell.
//
// diff_src_ptr : receives (gradient through the layers) + (incoming gradient);
//                written with aligned non-temporal stores.
// diff_dst_ptr : incoming gradient; read with aligned loads and never freed.
//
// The layers are walked from last to first. In "bnadd" mode with a prepared
// reorder, the resulting gradient is first brought into the final layer's
// layout. The element-wise addition runs for every mode, as before, over
// input_size_ elements in steps of 8.
void InvertedCellSp::backward(float* diff_src_ptr, float* diff_dst_ptr) {
    int size = network_.size();

    // Raw allocation backing the current gradient; stays null while the
    // current gradient is still the caller's diff_dst_ptr.
    float* owned_grad = nullptr;
    float* cur_grad   = diff_dst_ptr;
    for (int i = size - 1; i >= 0; i--) {
        LayerSp* l         = network_.at(i);
        // +8 floats of slack so ALIGN32 can round the pointer up.
        float* next_raw     = new float[l->diff_src_desc().get_size() / sizeof(float) + 8];
        float* next_aligned = ALIGN32(next_raw);
        l->backward(next_aligned, cur_grad);

        delete [] owned_grad;   // no-op for the last layer
        owned_grad = next_raw;
        cur_grad   = next_aligned;
    }

    BatchNormSP* l_bn = ((BatchNormSP*)this->network_.at(this->network_.size() - 1));
    LayerSp* l_first  = this->network_.at(0);

    float* reordered_raw = nullptr;
    float* grad1_alig    = cur_grad;
    if (l_bn->mode_ == std::string("bnadd") && grad_out_reorder_) {
        reordered_raw = new float[l_bn->nhwc_.get_size() / sizeof(float) + 8];
        grad1_alig    = ALIGN32(reordered_raw);
        reorder_grad_.execute(s_, {{DNNL_ARG_FROM, memory(l_first->diff_src_desc(), eng_, cur_grad)},
                                   {DNNL_ARG_TO,   memory(l_bn->diff_src_desc(),    eng_, grad1_alig)}});
        // The reordered gradient is read right below.
        s_.wait();
    }

    for (int i = 0; i < this->input_size_; i+=8) {
        __m256 grad1 = _mm256_load_ps(&grad1_alig[i]);
        __m256 grad2 = _mm256_load_ps(&diff_dst_ptr[i]);

        __m256 sum   = _mm256_add_ps(grad1, grad2);
        _mm256_stream_ps(&diff_src_ptr[i], sum);
    }

    // Both buffers come from new[]; the previous code used scalar delete on
    // one of them and never released the reorder buffer.
    delete [] reordered_raw;
    delete [] owned_grad;
}


// Applies the 2x2 linear map given by the alfa*p / beta*p coefficients to the
// two halves of `inp` and writes the result to the two halves of `out`:
//   out[i]        = alfa1p * inp[i] + alfa2p * inp[i + half]
//   out[i + half] = beta1p * inp[i] + beta2p * inp[i + half]
// with half = num_elements / 2. `blind` is not used.
//
// When half is a multiple of 8 the work is done with aligned AVX loads and
// non-temporal stores; otherwise a scalar loop is used.
void unblind_internal(float* inp, float* blind, float* out, int num_elements) {
    const int half = num_elements / 2;

    if (half % 8 == 0) {
        // Vector path: eight elements of each half per iteration.
        for (size_t i = 0; i < half; i += 8) {
            const __m256 lo_in = _mm256_load_ps( &inp[i] );
            const __m256 hi_in = _mm256_load_ps( &inp[half + i] );

            const __m256 lo_a   = _mm256_mul_ps(lo_in, alfa1p);
            const __m256 hi_a   = _mm256_mul_ps(hi_in, alfa2p);
            const __m256 lo_out = _mm256_add_ps(hi_a, lo_a);
            _mm256_stream_ps(&out[i], lo_out);

            const __m256 lo_b   = _mm256_mul_ps(lo_in, beta1p);
            const __m256 hi_b   = _mm256_mul_ps(hi_in, beta2p);
            const __m256 hi_out = _mm256_add_ps(lo_b, hi_b);
            _mm256_stream_ps(&out[half + i], hi_out);
        }
        return;
    }

    // Scalar path for half lengths that do not fill whole AVX registers.
    for (size_t i = 0; i < half; i += 1) {
        float lo_in = inp[i];
        float hi_in = inp[i + half];
        float lo_out = lo_in * alfa1p_raw + hi_in * alfa2p_raw;
        float hi_out = lo_in * beta1p_raw + hi_in * beta2p_raw;

        out[i] = lo_out;
        out[i + half] = hi_out;
    }
}


// Same linear map as unblind_internal(), followed by `func` applied to every
// 8-lane result before it is written:
//   out[i]        = func(alfa1p * inp[i] + alfa2p * inp[i + half])
//   out[i + half] = func(beta1p * inp[i] + beta2p * inp[i + half])
// with half = num_elements / 2. `blind` is not used. The run time is added
// to elapsed5.
//
// func must map a __m256 to a __m256 lane by lane. When half is not a
// multiple of 8, data is staged through small stack buffers (unused lanes are
// zero) so that no access leaves the two halves and no aligned load is issued
// on a misaligned address.
template <typename F>
inline void unblind(F func, float* inp, float* blind, float* out, int num_elements) {
    sgx_time_t start_time;
    sgx_time_t end_time;
    start_time = get_time_force();

    const size_t half = num_elements / 2;

    if (half % 8 == 0) {
        // Fast path, unchanged: aligned loads and non-temporal stores.
        for (size_t i = 0; i < half; i += 8) {
            const __m256 inp8f1 = _mm256_load_ps( &inp[i] );
            const __m256 inp8f2 = _mm256_load_ps( &inp[half + i] );

            const __m256 t11 = _mm256_mul_ps(inp8f1, alfa1p);
            const __m256 t12 = _mm256_mul_ps(inp8f2, alfa2p);
            const __m256 t13 = _mm256_add_ps(t12, t11);
            _mm256_stream_ps(&out[i], func(t13));

            const __m256 t21 = _mm256_mul_ps(inp8f1, beta1p);
            const __m256 t22 = _mm256_mul_ps(inp8f2, beta2p);
            const __m256 t23 = _mm256_add_ps(t21, t22);
            _mm256_stream_ps(&out[half + i], func(t23));
        }
    } else {
        // General path: copy up to 8 elements of each half into scratch,
        // compute, and copy exactly the valid lanes back.
        for (size_t i = 0; i < half; i += 8) {
            const size_t n = (half - i < 8) ? (half - i) : 8;
            float lo_in[8] = {0};
            float hi_in[8] = {0};
            float res[8];
            for (size_t k = 0; k < n; k++) {
                lo_in[k] = inp[i + k];
                hi_in[k] = inp[half + i + k];
            }
            const __m256 inp8f1 = _mm256_loadu_ps(lo_in);
            const __m256 inp8f2 = _mm256_loadu_ps(hi_in);

            const __m256 t11 = _mm256_mul_ps(inp8f1, alfa1p);
            const __m256 t12 = _mm256_mul_ps(inp8f2, alfa2p);
            const __m256 t13 = _mm256_add_ps(t12, t11);
            _mm256_storeu_ps(res, func(t13));
            for (size_t k = 0; k < n; k++)
                out[i + k] = res[k];

            const __m256 t21 = _mm256_mul_ps(inp8f1, beta1p);
            const __m256 t22 = _mm256_mul_ps(inp8f2, beta2p);
            const __m256 t23 = _mm256_add_ps(t21, t22);
            _mm256_storeu_ps(res, func(t23));
            for (size_t k = 0; k < n; k++)
                out[half + i + k] = res[k];
        }
    }

    end_time = get_time_force();
    elapsed5 = get_elapsed_time(start_time, end_time) + elapsed5;
}

// Accumulates per-channel sums of a gradient tensor into `bias_res`.
//
// clean_gradient : two consecutive blocks of num_elements floats each; inside
//                  a block the channel index is the fastest-varying one
//                  (element j of channel i is at num_channel * j + i).
// bias_res       : num_channel results. With `overwrite` set, the first block
//                  replaces the previous content and the second block is added;
//                  otherwise both blocks are added to the previous content.
// num_channel    : channel count. A multiple of 8 selects the AVX path, which
//                  needs 32-byte aligned buffers.
// num_elements   : element count of one block.
// batch_size     : not used; exactly two blocks are always processed.
//
// The run time is added to time_report[8] on both paths. Nothing is done for
// num_channel <= 0.
void bias_grad(float* clean_gradient, float* bias_res, bool overwrite, int num_channel, int num_elements, int batch_size) {
    if (num_channel <= 0)
        return;   // avoids the division by zero below

    int sum_len = num_elements / num_channel;
    sgx_time_t bias_s = get_time_force();

    if (num_channel % 8 != 0) {
        // Scalar path.
        for (int batch_count = 0; batch_count < 2; batch_count++) {
            for (int i = 0; i < num_channel; i++) {
                float prev_res = bias_res[i];
                float channel_sum = 0.0;
                for (int j = 0; j < sum_len; j++) {
                    channel_sum += clean_gradient[num_channel*j + i];
                }
                if (!overwrite || batch_count != 0) {
                    channel_sum += prev_res;
                }
                bias_res[i] = channel_sum;
            }
            clean_gradient += num_elements;
        }
    } else {
        // AVX path: eight channels at a time.
        for (int batch_count = 0; batch_count < 2; batch_count++) {
            for (int i = 0; i < num_channel; i+=8) {
                __m256 prev_res  = _mm256_load_ps(&bias_res[i]);
                __m256 eight_res = _mm256_set1_ps((float)(0));
                for (int j = 0; j < sum_len; j++) {
                    const __m256 input = _mm256_load_ps(&clean_gradient[num_channel*j + i]);
                    eight_res = _mm256_add_ps(eight_res, input);
                }
                if (!overwrite || batch_count != 0) {
                    eight_res = _mm256_add_ps(eight_res, prev_res);
                }
                _mm256_stream_ps(&bias_res[i], eight_res);
            }
            clean_gradient += num_elements;
        }
    }

    // Previously the scalar path returned before this point, so its time was
    // never reported.
    sgx_time_t bias_e = get_time_force();
    time_report[8] += get_elapsed_time(bias_s, bias_e);
}


extern "C" {

        // Model filled in by load_model_float() / load_model_float_verify().
        Model<float> model_float;
        // Layers registered by setup_batchnormsp(); batchnormSp_dark() walks
        // through them in registration order.
        std::vector<std::shared_ptr<BatchNormSp<float>>> batchnorms;

        // The following state is only declared here; its users are outside
        // the visible part of the file.
        bool slalom_privacy;
        bool slalom_integrity;
        int batch_size;
        aes_stream_state producer_PRG;
        aes_stream_state consumer_PRG;
        std::deque<sgx_aes_gcm_128bit_iv_t*> aes_gcm_ivs;
        std::deque<sgx_aes_gcm_128bit_tag_t*> aes_gcm_macs;
        
        Tensor<float, 1> buffer_t;
        Tensor<float, 1> buffer2_t;
        Tensor<float, 1> buffer3_t;
        Tensor<float, 1> buffer4_t;
        int act_idx;
        bool verbose;
        std::vector<int> activation_idxs;

        
        // matrices
        Tensor<float, 1> bm, um, igm, gm;

        // matrix buffers
        // Passed to BatchNormSp::fwd by batchnormSp_dark().
        int mini_batch_size = 3;

        Tensor<float, 1> bias_t;
        // Single-threaded Eigen thread pool and device, only available when
        // EIGEN_USE_THREADS is defined (i.e. when USE_SGX is not).
        #ifdef EIGEN_USE_THREADS
        int n_threads = 1;
        Eigen::ThreadPool pool(n_threads);
        Eigen::ThreadPoolDevice device(&pool, n_threads);
        #endif

        // Smoke test. Despite its name it currently only creates a convolution
        // layer through ecall_sgx_conv_create (1x224x224x3 input, 3x3x3x32
        // kernel, stride 2) and prints "here" afterwards. The older
        // encryption round-trip test is kept below, commented out.
        void test_bias_grad() {
            /*
            printf("test encryptions");
            memory::desc desc;
            engine eng;
            stream s;
            int size [4]= {1, 224, 224, 8};
            LayerSp* l =  new BatchNormSP(size, 1, 0.0, 0.0, desc, eng, s);
            void* plain = (void*) new float[128];
            void* cihpe = (void*) new float[256];
            void* dec   = (void*) new float[128];

            sgx_read_rand((uint8_t*) plain, 128 * 4);
            printf("here0");

            l->dump_src_input(cihpe, plain, 128*4);
            printf("here1");
            l->load_src_input(dec,   cihpe, 128*4);
            printf("here2");

            int* orignal = (int*) plain;
            int* decrpted = (int*) dec;

            for (int i = 0; i < 128; i++)
                if (orignal[i] != decrpted[i])
                    printf("mismatch");
            */
            int conv_src_sz[4]     = {1, 224, 224, 3};
            int conv_dst_sz[4]     = {1, 112, 112, 32};
            int conv_weight_sz[4]  = {3, 3, 3, 32};
            int conv_strides_sz[2] = {2, 2};
            int conv_padding_sz[2] = {1, 1};

            // Derive the weight count from the kernel shape (3*3*3*32 = 864)
            // and zero-initialise it; the buffer used to be handed over with
            // indeterminate contents.
            const int weight_count = conv_weight_sz[0] * conv_weight_sz[1]
                                   * conv_weight_sz[2] * conv_weight_sz[3];
            // Intentionally not freed here: whether the created layer keeps
            // the pointer is not visible from this function.
            float* weight_data     = new float[weight_count]();
            ecall_sgx_conv_create(conv_src_sz,
                                  conv_dst_sz,
                                  conv_weight_sz,
                                  conv_strides_sz,
                                  conv_padding_sz,
                                  weight_data,
                                  nullptr,
                                  1);
            printf("here");
        }


        // Loads a model into the enclave with both verification flags off.
        // When Eigen threading is compiled in, Eigen is also told how many
        // threads to use.
        void load_model_float(char* model_json, float** filters) {
                const bool verify         = false;
                const bool verify_preproc = false;
                model_float.load_model(model_json, filters, verify, verify_preproc);
                #ifdef EIGEN_USE_THREADS
                Eigen::setNbThreads(n_threads);
                #endif
        }

        // Loads a model in verify mode; `verify_preproc` is passed through to
        // Model::load_model. When Eigen threading is compiled in, Eigen is
        // also told how many threads to use.
        void load_model_float_verify(char* model_json, float** filters, bool verify_preproc) {
                const bool verify = true;
                model_float.load_model(model_json, filters, verify, verify_preproc);
                #ifdef EIGEN_USE_THREADS
                Eigen::setNbThreads(n_threads);
                #endif
        }

        void setup_batchnormsp(int* input_shape, bool privacy, float eps, float momentum) {
            array4d shape = {2, input_shape[1], input_shape[2], input_shape[3]};
            BatchNormSp<float>* ptr = new BatchNormSp<float>(std::string("batchnormsp"), shape);
            ptr->update_params(privacy, eps, momentum);
            std::shared_ptr<BatchNormSp<float>> new_layer(ptr);
            batchnorms.push_back(new_layer);
            printf("input shape [%d, %d, %d, %d]\n", 
                    input_shape[0], 
                    input_shape[1], 
                    input_shape[2], 
                    input_shape[3]);
            printf("eps %f, momentum %f privacy %d\n", eps, momentum, privacy); 
        }

        


        void test_handler() {
            array4d input_shape =  {0, 224, 224, 3};
            array4d kernel_shape = {3, 3, 3, 16,};
            MemPool* mem_pool = new MemPool(1, 1);
            Conv2DPure<float> conv = Conv2DPure<float>("name",
                                                       input_shape, 
                                                       kernel_shape, 
                                                       1, 1, 
                                                       Eigen::PaddingType::PADDING_SAME,

                                                       mem_pool, false, false, "shit");

        }



                        

        
        double enclave_fwd_time = 0.0;
        double enclave_bwd_time = 0.0;
        void batchnormSp_dark(float* output, float* inp, float* means, float* skip_input, float* act_src, 
                              int batch_size, const char* act_mode) {
            static int batchnormsp_ptr = 0;
            //printf("------------------------------------------------------\n");
            //printf("batch norm ptr %d / %d\n", batchnormsp_ptr, batchnorms.size());
            double s = get_time_force();
            if (batchnormsp_ptr == 0) {
                enclave_fwd_time = 0.0;
                enclave_bwd_time = 0.0;
            }
            batchnorms.at(batchnormsp_ptr)->fwd(output, inp, means, skip_input, act_src,
                                                um_buf, batch_size, mini_batch_size, act_mode);
            
            // reset pointer when possible
            batchnormsp_ptr++;
            if (batchnormsp_ptr >= batchnorms.size()) {
                batchnormsp_ptr = 0;
            }
            double e = get_time_force();  
            enclave_fwd_time += get_elapsed_time(s, e);
        }

        void batchnormSp_dark_back(float* grad_out, float* grad, float* inp, float* skip_src, float* act_src) {
            double s = get_time_force();


            static int batchnormsp_ptr = batchnorms.size() - 1;
            //printf("------------------------------------------------------\n");
            //printf("batch norm ptr %d / %d\n", batchnormsp_ptr, batchnorms.size());
            batchnorms.at(batchnormsp_ptr)->bwd(grad_out, grad, inp, skip_src, act_src);
            batchnormsp_ptr--;
            if (batchnormsp_ptr < 0) {
                batchnormsp_ptr = batchnorms.size() - 1;
                ecall_print_timing();
            }

            double e = get_time_force();  
            enclave_bwd_time += get_elapsed_time(s, e);

        }

        // forward pass
        void predict_float(float* input, float* output, int batch_size) {

                array4d input_dims = {batch_size,
                                    model_float.input_shape[0],
                                    model_float.input_shape[1],
                                    model_float.input_shape[2]};

                int input_size = batch_size * model_float.input_shape[0] * model_float.input_shape[1] * model_float.input_shape[2];
                assert(input_size != 0);

                // copy input into enclave
                float* inp_copy = model_float.mem_pool->alloc<float>(input_size);
                std::copy(input, input + input_size, inp_copy);

                auto map_in = TensorMap<float, 4>(inp_copy, input_dims);
                TensorMap<float, 4>* in_ptr = &map_in;

                sgx_time_t start_time;
                sgx_time_t end_time;
                double elapsed;

                start_time = get_time_force();

                // loop over all layers
                for (int i=0; i<model_float.layers.size(); i++) {
                        if (TIMING) {
                                printf("before layer %d (%s)\n", i, model_float.layers[i]->name_.c_str());
                        }
                        //printf("before %s\n", model_float.layers[i]->name_.c_str());
                        sgx_time_t layer_start = get_time();
                        #ifdef EIGEN_USE_THREADS
                        auto temp_output = model_float.layers[i]->apply(*in_ptr, (void*) &device);
                        #else
                        auto temp_output = model_float.layers[i]->apply(*in_ptr);
                        #endif

                        in_ptr = &temp_output;
                        //printf("after %d\n", i);
                        sgx_time_t layer_end = get_time();
                        if (TIMING) {
                                printf("layer %d required %4.4f sec\n", i, get_elapsed_time(layer_start, layer_end));
                        }
                }

                // copy output outside enclave
                std::copy(((float*)in_ptr->data()), ((float*)in_ptr->data()) + ((int)in_ptr->size()), output);
                model_float.mem_pool->release(in_ptr->data());

                end_time = get_time_force();
                printf("total time: %4.4f sec\n", get_elapsed_time(start_time, end_time));
        }

        // forward pass with verification
        void predict_verify_float(float* input, float* output, float** aux_data, int batch_size) {
                array4d input_dims = {batch_size,
                                model_float.input_shape[0],
                                model_float.input_shape[1],
                                model_float.input_shape[2]
                };

                int input_size = batch_size * model_float.input_shape[0] * model_float.input_shape[1] * model_float.input_shape[2];
                assert(input_size != 0);

                float* inp_copy = model_float.mem_pool->alloc<float>(input_size);
                std::copy(input, input + input_size, inp_copy);

                auto map_in = TensorMap<float, 4>(inp_copy, input_dims);

                TensorMap<float, 4>* in_ptr = &map_in;
                sgx_time_t start_time;
                sgx_time_t end_time;
                double elapsed;

                start_time = get_time_force();

                int linear_idx = 0;
                for (int i=0; i<model_float.layers.size(); i++) {
                        if (TIMING) {
                                printf("before layer %d (%s)\n", i, model_float.layers[i]->name_.c_str());
                        }

                        #ifdef USE_SGX
                        size_t aux_size = batch_size * model_float.layers[i]->output_size();
                        assert(sgx_is_outside_enclave(aux_data[linear_idx], aux_size * sizeof(float)));
                        #endif

                        sgx_time_t layer_start = get_time();
                        #ifdef EIGEN_USE_THREADS
                        auto temp_output = model_float.layers[i]->fwd_verify(*in_ptr, aux_data, linear_idx, (void*) &device);
                        #else
                        auto temp_output = model_float.layers[i]->fwd_verify(*in_ptr, aux_data, linear_idx);
                        #endif

                        in_ptr = &temp_output;

                        linear_idx += model_float.layers[i]->num_linear();

                        sgx_time_t layer_end = get_time();
                        if (TIMING) {
                                printf("layer %d required %4.4f sec\n", i, get_elapsed_time(layer_start, layer_end));
                        }
                }
                
                std::copy(((float*)in_ptr->data()), ((float*)in_ptr->data()) + ((int)in_ptr->size()), output);
                model_float.mem_pool->release(in_ptr->data());

                end_time = get_time_force();
                printf("total time: %4.4f sec\n", get_elapsed_time(start_time, end_time));
        }

        // Print the first `size` floats of `buf` on one line, each followed by a
        // space, then a newline. Prints only the newline when size <= 0.
        void print_buffer(float* buf, int size) {
            int idx = 0;
            while (idx < size) {
                printf("%f ", buf[idx]);
                ++idx;
            }
            printf("\n");
        }
    
        // Install the four parameter matrices used by the blind / unblind /
        // scramble routines.
        //
        // Each source pointer must provide internal_batch_size^2 floats. The data
        // is copied into freshly sized global tensors (bm, um, gm, igm) and the
        // raw *_buf pointers are re-pointed at the tensors' storage.
        // Also records internal_batch_size as the global mini_batch_size.
        void fill_parameter_matrix(float* bm_ptr, float* um_ptr, float* gm_ptr,
                                                             float* igm_ptr, int internal_batch_size) {
            const int m_size = internal_batch_size * internal_batch_size;
            mini_batch_size = internal_batch_size;

            bm = Tensor<float, 1>(m_size);
            bm_buf = bm.data();
            std::copy(bm_ptr, bm_ptr + m_size, bm_buf);

            um = Tensor<float, 1>(m_size);
            um_buf = um.data();
            std::copy(um_ptr, um_ptr + m_size, um_buf);

            gm = Tensor<float, 1>(m_size);
            gm_buf = gm.data();
            std::copy(gm_ptr, gm_ptr + m_size, gm_buf);

            igm = Tensor<float, 1>(m_size);
            igm_buf = igm.data();
            std::copy(igm_ptr, igm_ptr + m_size, igm_buf);
        }
        // initialize SLALOM protocol
        void slalom_init(bool integrity, bool privacy, int batch) {

                slalom_privacy = privacy;
                slalom_integrity = integrity;
                batch_size = batch;
                // batch slalom init stop here for now
                if (batch % mini_batch_size != 0) {
                    printf("batch size %d is not divisible by %d\n", batch, mini_batch_size);
                    assert(false);
                }
                time_report = new double[12];
                for (int i = 0; i < 12; i++)
                        time_report[i] = 0.0;
                ecall_dnnl_batch(batch / mini_batch_size);
                
                // TODO pass max size as a parameter
                buffer_t = Tensor<float, 1>(3*224*224*64);
                buffer2_t = Tensor<float, 1>(3*224*224*64);
                buffer3_t = Tensor<float, 1>(3*224*224*64);
                buffer4_t = Tensor<float, 1>(3*224*224*64);
                bias_t = Tensor<float, 1>(4096);
                bias_grads = bias_t.data();
                temp_buffer = buffer_t.data();
                temp_buffer2 = buffer2_t.data();
                temp_buffer3 = buffer3_t.data();
                temp_buffer4 = buffer4_t.data();

                printf("SLALOM INIT: PRIVACY = %d, INTEGRITY = %d, batch_size = %d\n", privacy, integrity, batch_size);


                if (model_float.layers.size() <= 0) {
                    printf("model float size is zero\n");
                    assert(model_float.layers.size() > 0);
                }
                std::vector<std::shared_ptr<Layer<float>>> new_layers;  
                for (int i=0; i<model_float.layers.size(); i++) {
                        if (dynamic_cast<ResNetBlock<float>*>(model_float.layers[i].get()) != nullptr) {
                                auto resblock = dynamic_cast<ResNetBlock<float>*>(model_float.layers[i].get());
                                for (int j=0; j<resblock->get_path1().size(); j++) {
                                        new_layers.push_back(resblock->get_path1()[j]);
                                }
                                for (int j=0; j<resblock->get_path2().size(); j++) {
                                        new_layers.push_back(resblock->get_path2()[j]);
                                }
                        } else {
                                new_layers.push_back(model_float.layers[i]);
                        }
                }
                model_float.layers.assign(new_layers.begin(), new_layers.end());

                printf("MODEL LAYERS (%lu layers):\n", model_float.layers.size());
                // get the indices of all the activation layers
                for (int i=0; i<model_float.layers.size(); i++) {
                        printf("%s\n", model_float.layers[i].get()->name_.c_str());
                        if (dynamic_cast<Activation<float>*>(model_float.layers[i].get()) != nullptr) {
                                activation_idxs.push_back(i);
                        }
                }

                printf("=========\n");
                printf("ACTIVATION IDXS:\n");
                for (int i=0; i<activation_idxs.size(); i++) {
                        printf("%d\n", activation_idxs[i]);
                }
                printf("=========\n");

                act_idx = 0;
                verbose = false;
                time = true;
        }

        // Blind an input of `size` floats, one mini-batch at a time.
        //
        // The input is treated as batch_size equally sized samples. Consecutive
        // groups of mini_batch_size samples are passed to blind_matrix_func in
        // blinding mode (mode 1). Also resets act_idx to 0.
        void slalom_blind_input(float* inp, float* out, int size) {
                act_idx = 0;

                const int size_per_batch = size / batch_size;
                const int mini_batch_len = size_per_batch * mini_batch_size;
                if (batch_size % mini_batch_size != 0) {
                    printf("batch size error got %d\n", batch_size);
                    assert(false);
                }

                const int num_mini_batches = batch_size / mini_batch_size;
                for (int mb = 0; mb < num_mini_batches; ++mb) {
                    const int offset = mb * mini_batch_len;
                    blind_matrix_func(inp + offset, out + offset, mini_batch_len, 1);
                }
        }


        // Unblind three input planes into two output planes and add a bias.
        //
        //   inp  : 3 * batch_len floats (three consecutive planes)
        //   out  : 2 * batch_len floats (two consecutive planes)
        //   bias : bias_len floats, repeated cyclically along each plane
        //
        // um_buf[0..2] and um_buf[3..5] are the coefficient rows for the first and
        // second output plane. batch_len and bias_len are expected to be multiples
        // of 8 and batch_len a multiple of bias_len. The aligned AVX load/stream
        // intrinsics used here require 32-byte aligned addresses.
        void unblind_bias(float* inp, float*bias, float* out, int batch_len, int bias_len) {
                const __m256 row0_a = _mm256_set1_ps(um_buf[0]);
                const __m256 row0_b = _mm256_set1_ps(um_buf[1]);
                const __m256 row0_c = _mm256_set1_ps(um_buf[2]);
                const __m256 row1_a = _mm256_set1_ps(um_buf[3]);
                const __m256 row1_b = _mm256_set1_ps(um_buf[4]);
                const __m256 row1_c = _mm256_set1_ps(um_buf[5]);

                if (!(batch_len % 8 == 0 && bias_len % 8 == 0)) {
                        printf("bias addtion len error\n");
                        assert(false);
                }

                if (batch_len % bias_len != 0) {
                        printf("batch len and bias len not divisible\n");
                        assert(false);
                }

                float* const plane0 = inp;
                float* const plane1 = inp + batch_len;
                float* const plane2 = inp + batch_len * 2;
                float* const out0 = out;
                float* const out1 = out + batch_len;

                for (int off = 0; off < batch_len; off += 8) {
                        const __m256 a = _mm256_load_ps(plane0 + off);
                        const __m256 b = _mm256_load_ps(plane1 + off);
                        const __m256 c = _mm256_load_ps(plane2 + off);
                        const __m256 bias8 = _mm256_load_ps(bias + (off % bias_len));

                        // ((row * planes) summed left to right) + bias
                        __m256 acc = _mm256_add_ps(_mm256_mul_ps(row0_a, a), _mm256_mul_ps(row0_b, b));
                        acc = _mm256_add_ps(acc, _mm256_mul_ps(row0_c, c));
                        _mm256_stream_ps(out0 + off, _mm256_add_ps(acc, bias8));

                        acc = _mm256_add_ps(_mm256_mul_ps(row1_a, a), _mm256_mul_ps(row1_b, b));
                        acc = _mm256_add_ps(acc, _mm256_mul_ps(row1_c, c));
                        _mm256_stream_ps(out1 + off, _mm256_add_ps(acc, bias8));
                }
        }

        // Unblind three input planes into two output planes, add a bias, apply
        // ReLU, and additionally store the pre-ReLU values.
        //
        //   inp      : 3 * batch_len floats (three consecutive planes)
        //   out      : 2 * batch_len floats, receives max(value, 0)
        //   relu_src : 2 * batch_len floats, receives the value before the max
        //   bias     : bias_len floats, repeated cyclically along each plane
        //
        // um_buf[0..2] and um_buf[3..5] are the coefficient rows for the first and
        // second output plane. batch_len and bias_len are expected to be multiples
        // of 8 and batch_len a multiple of bias_len. The aligned AVX load/stream
        // intrinsics used here require 32-byte aligned addresses.
        void unblind_bias_relu_wb(float* inp, float*bias, float* out, float* relu_src, int batch_len, int bias_len) {
            const __m256 row0_a = _mm256_set1_ps(um_buf[0]);
            const __m256 row0_b = _mm256_set1_ps(um_buf[1]);
            const __m256 row0_c = _mm256_set1_ps(um_buf[2]);
            const __m256 row1_a = _mm256_set1_ps(um_buf[3]);
            const __m256 row1_b = _mm256_set1_ps(um_buf[4]);
            const __m256 row1_c = _mm256_set1_ps(um_buf[5]);

            if (!(batch_len % 8 == 0 && bias_len % 8 == 0)) {
                    printf("bias addtion len error\n");
                    assert(false);
            }

            if (batch_len % bias_len != 0) {
                    printf("batch len and bias len not divisible\n");
                    assert(false);
            }

            float* const plane0 = inp;
            float* const plane1 = inp + batch_len;
            float* const plane2 = inp + batch_len * 2;
            float* const out0 = out;
            float* const out1 = out + batch_len;
            float* const pre0 = relu_src;
            float* const pre1 = relu_src + batch_len;

            for (int off = 0; off < batch_len; off += 8) {
                const __m256 a = _mm256_load_ps(plane0 + off);
                const __m256 b = _mm256_load_ps(plane1 + off);
                const __m256 c = _mm256_load_ps(plane2 + off);
                const __m256 bias8 = _mm256_load_ps(bias + (off % bias_len));

                // First output plane: keep the pre-activation, then clamp at zero.
                __m256 acc = _mm256_add_ps(_mm256_mul_ps(row0_a, a), _mm256_mul_ps(row0_b, b));
                acc = _mm256_add_ps(acc, _mm256_mul_ps(row0_c, c));
                acc = _mm256_add_ps(acc, bias8);
                _mm256_stream_ps(pre0 + off, acc);
                _mm256_stream_ps(out0 + off, _mm256_max_ps(acc, zero8f));

                // Second output plane.
                acc = _mm256_add_ps(_mm256_mul_ps(row1_a, a), _mm256_mul_ps(row1_b, b));
                acc = _mm256_add_ps(acc, _mm256_mul_ps(row1_c, c));
                acc = _mm256_add_ps(acc, bias8);
                _mm256_stream_ps(pre1 + off, acc);
                _mm256_stream_ps(out1 + off, _mm256_max_ps(acc, zero8f));
            }
        }

        // Unblind three input planes into two output planes, add a bias and apply
        // ReLU.
        //
        //   inp      : 3 * batch_len floats (three consecutive planes)
        //   out      : 2 * batch_len floats, receives max(value, 0)
        //   relu_src : not accessed by this variant
        //   bias     : bias_len floats, repeated cyclically along each plane
        //
        // um_buf[0..2] and um_buf[3..5] are the coefficient rows for the first and
        // second output plane. batch_len and bias_len are expected to be multiples
        // of 8 and batch_len a multiple of bias_len. The aligned AVX load/stream
        // intrinsics used here require 32-byte aligned addresses.
        void unblind_bias_relu(float* inp, float*bias, float* out, float* relu_src, int batch_len, int bias_len) {
            const __m256 row0_a = _mm256_set1_ps(um_buf[0]);
            const __m256 row0_b = _mm256_set1_ps(um_buf[1]);
            const __m256 row0_c = _mm256_set1_ps(um_buf[2]);
            const __m256 row1_a = _mm256_set1_ps(um_buf[3]);
            const __m256 row1_b = _mm256_set1_ps(um_buf[4]);
            const __m256 row1_c = _mm256_set1_ps(um_buf[5]);

            if (!(batch_len % 8 == 0 && bias_len % 8 == 0)) {
                    printf("bias addtion len error\n");
                    assert(false);
            }

            if (batch_len % bias_len != 0) {
                    printf("batch len and bias len not divisible\n");
                    assert(false);
            }

            float* const plane0 = inp;
            float* const plane1 = inp + batch_len;
            float* const plane2 = inp + batch_len * 2;
            float* const out0 = out;
            float* const out1 = out + batch_len;

            for (int off = 0; off < batch_len; off += 8) {
                const __m256 a = _mm256_load_ps(plane0 + off);
                const __m256 b = _mm256_load_ps(plane1 + off);
                const __m256 c = _mm256_load_ps(plane2 + off);
                const __m256 bias8 = _mm256_load_ps(bias + (off % bias_len));

                // First output plane: weighted sum, plus bias, clamped at zero.
                __m256 acc = _mm256_add_ps(_mm256_mul_ps(row0_a, a), _mm256_mul_ps(row0_b, b));
                acc = _mm256_add_ps(acc, _mm256_mul_ps(row0_c, c));
                acc = _mm256_add_ps(acc, bias8);
                _mm256_stream_ps(out0 + off, _mm256_max_ps(acc, zero8f));

                // Second output plane.
                acc = _mm256_add_ps(_mm256_mul_ps(row1_a, a), _mm256_mul_ps(row1_b, b));
                acc = _mm256_add_ps(acc, _mm256_mul_ps(row1_c, c));
                acc = _mm256_add_ps(acc, bias8);
                _mm256_stream_ps(out1 + off, _mm256_max_ps(acc, zero8f));
            }
        }

        // Add `bias` in place to two consecutive planes of `inp`.
        //
        // Each plane holds batch_len floats; the bias_len-long bias vector is
        // broadcast batch_len / bias_len times to cover a plane. Both lengths are
        // expected to be multiples of 8.
        void bias_add(float* inp, float* bias, int batch_len, int bias_len) {
                if (!(batch_len % 8 == 0 && bias_len % 8 == 0)) {
                        printf("bias addtion len error\n");
                        assert(false);
                }

                const int repeats = batch_len / bias_len;
                array1d bcast = {repeats};
                array1d one_d = {batch_len};
                array4d input_size = {1, 1, repeats, bias_len};
                auto biasmap = TensorMap<float, 1>(bias, bias_len);

                for (int plane = 0; plane < 2; plane++) {
                        auto out_map = TensorMap<float, 4>(inp + plane * batch_len, input_size);
                        out_map.reshape(one_d) = out_map.reshape(one_d) + biasmap.broadcast(bcast).reshape(one_d);
                }
        }


        // ReLU backward for eight lanes: passes a lane of `grad` through where the
        // matching lane of `relu_src` is strictly greater than zero and yields
        // 0.0f elsewhere (including NaN lanes, since the compare is ordered).
        __m256 inline relu_grad(__m256 grad, __m256 relu_src) {
            // _CMP_LT_OQ (0x11): ordered, non-signalling "0 < relu_src".
            const __m256 positive_mask = _mm256_cmp_ps(zero8f, relu_src, _CMP_LT_OQ);
            return _mm256_and_ps(positive_mask, grad);
        }

        void merged_unblind_scramble_reluback(float* grad, float* relu_src, float* grad_out, 
                                     float* grad_in, int batch_len) {
            const __m256 gm00 = _mm256_set1_ps(gm_buf[0]);
            const __m256 gm11 = _mm256_set1_ps(gm_buf[4]);
            

            const __m256 um00 = _mm256_set1_ps(um_buf[0]);
            const __m256 um01 = _mm256_set1_ps(um_buf[1]);
            const __m256 um02 = _mm256_set1_ps(um_buf[2]);
            const __m256 um10 = _mm256_set1_ps(um_buf[3]);
            const __m256 um11 = _mm256_set1_ps(um_buf[4]);
            const __m256 um12 = _mm256_set1_ps(um_buf[5]);

            if (batch_len % 8 != 0) {
                printf("num of elements in array is off\n");
                assert(false);
            }

            float* grad_f = grad;
            float* grad_s = grad + batch_len;
            float* grad_out1 = grad_out;
            float* grad_out2 = grad_out + batch_len;

            float* relu_sf = relu_src;
            float* relu_ss = relu_src + batch_len;
            float* relu_st = relu_src + batch_len * 2;

            float* grad_inf = grad_in;
            float* grad_ins = grad_in + batch_len;
            
            for (int i = 0; i < batch_len; i+=8) {
                const __m256 grad_fm = _mm256_load_ps(&grad_f[i]);
                const __m256 grad_sm = _mm256_load_ps(&grad_s[i]); 
                
                const __m256 relu_fm = _mm256_load_ps(&relu_sf[i]);
                const __m256 relu_sm = _mm256_load_ps(&relu_ss[i]); 
                const __m256 relu_tm = _mm256_load_ps(&relu_st[i]); 

                // the first image
                __m256 grad_res = _mm256_mul_ps(gm00, grad_fm);
                _mm256_stream_ps(&grad_out1[i], grad_res);

                __m256 relu_fr = _mm256_mul_ps(um00, relu_fm);
                __m256 relu_sr = _mm256_mul_ps(um01, relu_sm);
                __m256 relu_tr = _mm256_mul_ps(um02, relu_tm);

                __m256 relu_src_res = _mm256_add_ps(relu_fr, relu_sr);
                relu_src_res = _mm256_add_ps(relu_src_res, relu_tr);

                
                _mm256_stream_ps(&grad_inf[i], relu_grad(grad_res, relu_src_res));
                // the second image
                grad_res = _mm256_mul_ps(gm11, grad_sm);
                _mm256_stream_ps(&grad_out2[i], grad_res);

                relu_fr = _mm256_mul_ps(um10, relu_fm);
                relu_sr = _mm256_mul_ps(um11, relu_sm);
                relu_tr = _mm256_mul_ps(um12, relu_tm);

                relu_src_res = _mm256_add_ps(relu_fr, relu_sr);
                relu_src_res = _mm256_add_ps(relu_src_res, relu_tr);
                _mm256_stream_ps(&grad_ins[i], relu_grad(grad_res, relu_src_res));
            }
        }


        // Map two input planes onto three output planes.
        //
        //   inp       : 2 * batch_len floats (planes f and s)
        //   out       : 3 * batch_len floats
        //   matrx_buf : coefficients; only entries [0], [4], [6] and [7] are read
        //
        //   out plane 0 = matrx_buf[0] * f
        //   out plane 1 = matrx_buf[4] * s
        //   out plane 2 = matrx_buf[6] * f + matrx_buf[7] * s
        //
        // batch_len is expected to be a multiple of 8. The aligned AVX load/stream
        // intrinsics used here require 32-byte aligned addresses.
        void matrix_mul2x3(float* inp, float* out, float* matrx_buf, int batch_len) {
            const __m256 coef_f0 = _mm256_set1_ps(matrx_buf[0]);
            const __m256 coef_s1 = _mm256_set1_ps(matrx_buf[4]);
            const __m256 coef_f2 = _mm256_set1_ps(matrx_buf[6]);
            const __m256 coef_s2 = _mm256_set1_ps(matrx_buf[7]);

            if (batch_len % 8 != 0) {
                printf("num of elements in array is off\n");
                assert(false);
            }

            float* const in_f = inp;
            float* const in_s = inp + batch_len;
            float* const out0 = out;
            float* const out1 = out + batch_len;
            float* const out2 = out + 2*batch_len;

            for (int off = 0; off < batch_len; off += 8) {
                const __m256 f = _mm256_load_ps(in_f + off);
                const __m256 s = _mm256_load_ps(in_s + off);

                // The first two planes are plain per-plane scalings.
                _mm256_stream_ps(out0 + off, _mm256_mul_ps(coef_f0, f));
                _mm256_stream_ps(out1 + off, _mm256_mul_ps(coef_s1, s));

                // The third plane mixes both inputs.
                _mm256_stream_ps(out2 + off,
                                 _mm256_add_ps(_mm256_mul_ps(coef_f2, f), _mm256_mul_ps(coef_s2, s)));
            }
        }

        // Multiply three input planes by a row-major 3x3 coefficient matrix.
        //
        //   inp       : planes f, s, t of batch_len floats each (t is only read
        //               from inp when mode != 1)
        //   out       : 3 * batch_len floats
        //   matrx_buf : nine coefficients, row r at [3r .. 3r+2]
        //   noise     : batch_len floats used in place of the third input plane
        //               when mode == 1
        //
        //   out plane r = m[r][0]*f + m[r][1]*s + m[r][2]*t   (summed left to right)
        //
        // batch_len is expected to be a multiple of 8. The aligned AVX load/stream
        // intrinsics used here require 32-byte aligned addresses.
        void matrix_mul3x3(float* inp, float* out, float* matrx_buf, float* noise, int batch_len, int mode) {
            const __m256 c00 = _mm256_set1_ps(matrx_buf[0]);
            const __m256 c01 = _mm256_set1_ps(matrx_buf[1]);
            const __m256 c02 = _mm256_set1_ps(matrx_buf[2]);
            const __m256 c10 = _mm256_set1_ps(matrx_buf[3]);
            const __m256 c11 = _mm256_set1_ps(matrx_buf[4]);
            const __m256 c12 = _mm256_set1_ps(matrx_buf[5]);
            const __m256 c20 = _mm256_set1_ps(matrx_buf[6]);
            const __m256 c21 = _mm256_set1_ps(matrx_buf[7]);
            const __m256 c22 = _mm256_set1_ps(matrx_buf[8]);

            if (batch_len % 8 != 0) {
                printf("num of elements in array is off\n");
                assert(false);
            }

            float* const in_f = inp;
            float* const in_s = inp + batch_len;
            // In mode 1 the third operand comes from the noise buffer.
            float* const in_t = (mode == 1) ? noise : inp + batch_len * 2;
            float* const out0 = out;
            float* const out1 = out + batch_len;
            float* const out2 = out + 2*batch_len;

            for (int off = 0; off < batch_len; off += 8) {
                const __m256 f = _mm256_load_ps(in_f + off);
                const __m256 s = _mm256_load_ps(in_s + off);
                const __m256 t = _mm256_load_ps(in_t + off);

                __m256 acc = _mm256_add_ps(_mm256_mul_ps(c00, f), _mm256_mul_ps(c01, s));
                _mm256_stream_ps(out0 + off, _mm256_add_ps(acc, _mm256_mul_ps(c02, t)));

                acc = _mm256_add_ps(_mm256_mul_ps(c10, f), _mm256_mul_ps(c11, s));
                _mm256_stream_ps(out1 + off, _mm256_add_ps(acc, _mm256_mul_ps(c12, t)));

                acc = _mm256_add_ps(_mm256_mul_ps(c20, f), _mm256_mul_ps(c21, s));
                _mm256_stream_ps(out2 + off, _mm256_add_ps(acc, _mm256_mul_ps(c22, t)));
            }
        }

        // mode : 1 to scramble & 0 to unscramble 
        void scramble_matrix_func(float* inp, float* out, int mini_batch_len, int mode) {
                sgx_time_t scr_s = get_time_force();
                int single_batch_len = mini_batch_len / mini_batch_size;
                float* mat = igm_buf;

                if (mode == 1)
                        mat = gm_buf;
                matrix_mul2x3(inp, out, mat, single_batch_len);
                sgx_time_t scr_e = get_time_force();
                time_report[7] += get_elapsed_time(scr_s, scr_e);
        }
    
        // mode : 1 to blind & 0 to unblind
        void blind_matrix_func(float* inp, float* out, int mini_batch_len, int mode) {
                sgx_time_t blind_start;
                if (time) {
                 blind_start = get_time_force();
                }
                int single_batch_len = mini_batch_len / mini_batch_size;
                float* mat = um_buf;
                // noise generation
                if (mode == 1) {
                        //sgx_time_t gs = get_time_force();
                        //get_r(&producer_PRG, (unsigned char*) temp_buffer, single_batch_len*sizeof(float), 2);
                        //sgx_time_t end = get_time_force();
                        //time_report[4] += get_elapsed_time(gs, end);        
                        //for (int i = 0; i < single_batch_len; i++)
                        //temp_buffer[i] = 10.0;
                        mat = bm_buf;
                }

                // matrix multiplication
                matrix_mul3x3(inp, out, mat, temp_buffer, single_batch_len, mode);
                

            sgx_time_t end;
            if (time) {
                end = get_time_force();
                if (mode == 1)
                    time_report[3] += get_elapsed_time(blind_start, end);  
                else
                    time_report[4] += get_elapsed_time(blind_start, end);  
            }
                        
        }

    
        // blind an input buffer and write to output
        void slalom_blind_input_internal(float* inp, float* out, int size, float* temp) {
            //printf("blindddddd_input_internal:::::::::::::::::::::::%d\n", size);
            sgx_time_t start_time;
            sgx_time_t end_time;
            double elapsed;
            
            start_time = get_time();
                
            int num_bytes = size * sizeof(float);
            
            end_time = get_time();
            elapsed = get_elapsed_time(start_time, end_time);
            if (TIMING) {
                printf("\trandgen of %d bytes required %f sec\n", num_bytes, elapsed);
            }

            TensorMap<float, 1> inp_map(inp, size);
            TensorMap<float, 1> out_map(out, size);
            TensorMap<float, 1> r_map((float*) temp, size);
            
            if (verbose) {
                Tensor<double, 0> res;
                res = r_map.template cast<double>().minimum();
                printf("min(r) = %f\n", res.data()[0]);
                res = r_map.template cast<double>().maximum();
                printf("max(r) = %f\n", res.data()[0]);
                res = r_map.template cast<double>().abs().sum();
                printf("sum(abs(r)) = %f\n", res.data()[0]);
            }

            start_time = get_time();  
            
            //assert(size % 8 == 0);
            //assert ((long int)inp % 32 == 0);
            //assert ((long int)temp % 32 == 0);
            //assert ((long int)out % 32 == 0);

            if ((size / 2) % 8 != 0) {
                printf("slow blind\n");
                for(size_t i = 0; i < size/2; i += 1) {
                    float inputf1 = inp[i];
                    float inputf2 = inp[i+size/2];
                    float t13 = inputf1 * alfa1_raw + inputf2 * alfa2_raw;
                    float t23 = inputf1 * beta1_raw + inputf2 * beta2_raw;
                        
                    out[i] = t13;
                    out[i+size/2] = t23;
                }
                
                return;
            }

                
            // loop over data and add blinding factors (mod p)
            size_t i = 0;
            
            for(; i < size/2; i += 8) {
                const __m256 inp8f1 = _mm256_load_ps( &inp[i] );                         // unblinded input
                const __m256 inp8f2 = _mm256_load_ps( &inp[size/2+i] );   
                
                const __m256 t11 = _mm256_mul_ps(alfa1, inp8f1);//t11 = alfa1*x1
                const __m256 t12 = _mm256_mul_ps(alfa2, inp8f2);//t12 = alfa2*x2
                const __m256 f1 = _mm256_add_ps(t11, t12);//f1 = t11 + t12
                
                const __m256 t21 = _mm256_mul_ps(beta1, inp8f1);//t21 = beta1*x1
                const __m256 t22 = _mm256_mul_ps(beta2, inp8f2);//t22 = beta2*x2
                const __m256 f2 = _mm256_add_ps(t21, t22);//f2 = t21 + t22      
                
                _mm256_stream_ps(&out[i], f1);
                _mm256_stream_ps(&out[(size/2)+i], f2);
            }

            if (verbose) {
                Tensor<double, 0> res;
                res = out_map.template cast<double>().minimum();
                printf("min(new blinded) = %f\n", res.data()[0]);
                res = out_map.template cast<double>().maximum();
                printf("max(new blinded) = %f\n", res.data()[0]);
                res = out_map.template cast<double>().abs().sum();
                printf("sum(abs(new blinded)) = %f\n", res.data()[0]);
            }
                
            end_time = get_time();
            elapsed = get_elapsed_time(start_time, end_time);
            if (TIMING) {
                printf("\tblinding of size %d required %f sec\n", num_bytes, elapsed);
            }
        }
    

        // get a blinding factor
        // TODO: this is obviously insecure but we currently compute the unblinding factors outside of the enclave for simplicity
        void slalom_get_r(float* out, int size) {
                if (slalom_privacy) {
                        int num_bytes = size * sizeof(float) / batch_size;

                        for(int i=0; i<batch_size; i++) {
                                //Hanieh Removed: get_r(&producer_PRG, ((unsigned char*) out) + i * num_bytes, num_bytes, 9);

                                if (verbose) {
                                        TensorMap<float, 1> r_map(out, size);
                                        Tensor<double, 0> res;
                                        res = r_map.template cast<double>().minimum();
                                        printf("min(r) = %f\n", res.data()[0]);
                                        res = r_map.template cast<double>().maximum();
                                        printf("max(r) = %f\n", res.data()[0]);
                                        res = r_map.template cast<double>().abs().sum();
                                        printf("sum(abs(r)) = %f\n", res.data()[0]);
                                }
                        }
                }
        }

        // get an unblinding factor, encrypt it, store its MAC and write the ciphertext outside the enclave
        // TODO: this is obviously insecure but we currently compute the unblinding factors outside of the enclave for simplicity
        void slalom_set_z(float* z, float* dest, int size) {
                int num_bytes = size * sizeof(float) / batch_size;

                sgx_time_t start_time;
                sgx_time_t end_time;
                double elapsed;

                for(int i=0; i<batch_size; i++) {
                        //TODO should be randomized
                        sgx_aes_gcm_128bit_iv_t *iv = (sgx_aes_gcm_128bit_iv_t*)new sgx_aes_gcm_128bit_iv_t;
                        sgx_aes_gcm_128bit_tag_t *mac = (sgx_aes_gcm_128bit_tag_t*)new sgx_aes_gcm_128bit_tag_t;

                        start_time = get_time();
                        encrypt((uint8_t *) z + i * num_bytes, num_bytes, ((uint8_t *) dest) + i * num_bytes, iv, mac);

                        end_time = get_time();
                        elapsed = get_elapsed_time(start_time, end_time);
                        if (TIMING) {
                                printf("encrypt of size %d required %f sec\n", num_bytes, elapsed);
                        }

                        aes_gcm_ivs.push_back(iv);
                        aes_gcm_macs.push_back(mac);
                }
        }

        // compute a ReLU, ReLU6 or fused AvgPool+ReLU on blinded data
        //
        // The activation layer is looked up through activation_idxs[act_idx]; act_idx is
        // then advanced and wraps back to 0 after the last entry. The batch is processed
        // in chunks of mini_batch_len = num_elements / batch_size * mini_batch_size floats.
        //
        // Parameters:
        //   inp          - input values; advanced by mini_batch_len after every chunk
        //   out          - output buffer; advanced by num_out_elements after every chunk
        //   blind        - forwarded as the 2nd argument of unblind_bias_relu_wb / unblind_bias
        //                  (presumably bias/unblinding data - TODO confirm against those helpers)
        //   relu_src     - destination of blind_matrix_func(temp_buffer3, ...) in the
        //                  relu/relu6 privacy path; advanced by mini_batch_len after every chunk
        //   num_elements - number of floats in inp for the whole batch
        //   activation   - "relu", "relu6", "avgpoolrelu" or "avgpoolrelu6"; any other
        //                  string takes the fallback branch (privacy mode: unblind_bias + copy,
        //                  otherwise a plain copy)
        //
        // Timing: when the global `time` flag is set, elapsed times are accumulated into
        // time_report[0] (relu/relu6 unblinding), time_report[1] (fallback unblind_bias)
        // and time_report[6] (whole call).
        void slalom_relu(float* inp, float* out, float* blind, float* relu_src, int num_elements, char* activation) {
            //printf("slalom_relu %d / %d\n", act_idx, activation_idxs.size());
            sgx_time_t start0_time;
            sgx_time_t srelu_start;
            sgx_time_t end0_time;
            // Persisted across calls; re-armed by the privacy-mode fallback branch below.
            static bool first_layer = true;

            // Clear the 12 timing slots on the first call and on the first call after the flag was re-armed.
            if (first_layer) {
                for (int i = 0; i < 12; i++)
                    time_report[i] = 0.0;
                first_layer = false;
            }
            if (time) {
                srelu_start = get_time_force();
            }

            // NOTE(review): layers[layer_idx - 1] is read without checking layer_idx > 0.
            int layer_idx = activation_idxs[act_idx];
            auto curr_layer = model_float.layers[layer_idx];
            auto prev_layer = model_float.layers[layer_idx - 1];
            std::shared_ptr<Layer<float>> next_layer = nullptr;
            // Move on to the next activation for the following call, wrapping at the end.
            act_idx += 1;
            if (act_idx == activation_idxs.size())
                act_idx = 0;
            // The indexed layer must be an Activation layer.
            if (dynamic_cast<Activation<float>*>(curr_layer.get()) == nullptr) {
                printf("curr layer is null pointer\n");
                assert(false);
            }


            if (slalom_integrity) {
                // we only handle convolution layers for integrity checks
                assert(dynamic_cast<Conv2D<float>*>(prev_layer.get()) != nullptr);

                // NOTE(review): next_layer is still nullptr at this point (it is only assigned
                // further down), so the nested block below never executes.
                if (next_layer != nullptr) {
                    // skip reshape layer for MobileNet
                    if (dynamic_cast<Conv2D<float>*>(next_layer.get()) == nullptr) {
                        assert(dynamic_cast<Reshape<float>*>(next_layer.get()) != nullptr);
                        next_layer = model_float.layers[layer_idx + 2];
                    }
                    assert(dynamic_cast<Conv2D<float>*>(next_layer.get()) != nullptr);
                }
            }



            std::string act(activation);

            //printf("slalom_relu\n");
            // Look up the following layer only when the current one is not the last layer.
            if (layer_idx < model_float.layers.size() -1) {
                next_layer = model_float.layers[layer_idx + 1];
                if (verbose) {
                    printf("\nin activation %s: prev layer: %s, curr layer: %s, next_layer: %s\n",
                                 activation, prev_layer->name_.c_str(), curr_layer->name_.c_str(), next_layer->name_.c_str());
                }
            } else {
                if (verbose) {
                    printf("\nin activation %s: prev layer: %s, curr layer: %s\n",
                                 activation, prev_layer->name_.c_str(), curr_layer->name_.c_str());
                }
            }

                // Entries 1..3 of the previous layer's output shape, each clamped to at least 1.
                array4d out_shape = prev_layer->output_shape();
                int h = std::max((int) out_shape[1], 1);
                int w = std::max((int) out_shape[2], 1);
                int ch = std::max((int) out_shape[3], 1);

                // batch: samples contained in inp; mini_batch_len: floats handled per loop iteration.
                int batch = num_elements / (h*w*ch);
                int mini_batch_len = num_elements / batch_size * mini_batch_size;
                int num_out_elements = mini_batch_len;

                int iter = batch / mini_batch_size;
                //printf("relu %d %d %d\n", batch, iter, mini_batch_len);
                for (int i = 0; i<iter; i++) {
                        TensorMap<float, 1> inp_map(inp, mini_batch_len);
                        TensorMap<float, 1> blind_map((float*) temp_buffer, mini_batch_len);

                        sgx_time_t start_time;
                        sgx_time_t end_time;
                        double elapsed;

                        if (slalom_privacy) {

                                Tensor<double, 0> res;


                                start_time = get_time();

                                // Requires num_elements to be a multiple of 8 and inp to be 32-byte aligned.
                                // NOTE(review): the format uses %lu while the argument is cast to long int.
                                if (num_elements % 8 != 0 || (long int)inp % 32 != 0) {
                                    printf("num elements %d or inp %lu error\n", num_elements, (long int)inp);
                                    assert(false);
                                }

                                //("before activation check\n");
                                // "relu" and "relu6" share this branch; no distinction is made between them here.
                                if (act == "relu" || act == "relu6") {

                                        //printf("layer shape: %d, %d, %d, %d\n", batch, h, w, ch);

                                        // unblind intermediate features
                                        // mode 0 to unblind
                                        sgx_time_t b_start;
                                        if (time) {
                                            b_start = get_time_force();
                                        }
                                        // Fills temp_buffer2 and temp_buffer3 (presumably post- and pre-activation
                                        // values - TODO confirm against unblind_bias_relu_wb).
                                        unblind_bias_relu_wb(inp, blind, temp_buffer2, temp_buffer3, mini_batch_len / mini_batch_size, ch);


                                        sgx_time_t b_end;
                                        if (time) {
                                            b_end = get_time_force();
                                            time_report[0] += get_elapsed_time(b_start, b_end);
                                        }

                                        // relu_inp_shape / out_map are not referenced again in this branch.
                                        array4d relu_inp_shape = {mini_batch_size - 1, h, w, ch};
                                        auto out_map = TensorMap<float, 4>(temp_buffer2, relu_inp_shape);

                                        // temp_buffer3 goes to relu_src and temp_buffer2 goes to out, both through blind_matrix_func.
                                        blind_matrix_func(temp_buffer3, relu_src, mini_batch_len, 1);
                                        //printf("after cwise relu\n");
                                        blind_matrix_func(temp_buffer2, out, mini_batch_len, 1);
                                        //printf("got out\n");
                                } else if (act == "avgpoolrelu" || act == "avgpoolrelu6") {

                                    // special case for fused AvgPool+ReLU used in MobileNet
                                    // NOTE(review): both sub-branches pass temp_buffer to unblind() and never
                                    // reference `out`; timing here is accumulated into elapsed100 regardless of `time`.
                                    if (act == "avgpoolrelu") {
                                        // clamp below at zero
                                        auto act_func = [] (__m256 res8f) {
                                            return _mm256_max_ps(res8f, zero8f);
                                        };
                                        printf("avgpoolrelu::::::::::::::::::::::%d\n", num_elements);
                                        start0_time = get_time_force();
                                        unblind(act_func, inp, temp_buffer, temp_buffer, num_elements);
                                        end0_time = get_time_force();
                                        elapsed100 = elapsed100 + get_elapsed_time(start0_time, end0_time);

                                    } else if (act == "avgpoolrelu6") {
                                        // clamp to the range [0, six8f]
                                        auto act_func = [] (__m256 res8f) {
                                            return _mm256_min_ps(_mm256_max_ps(res8f, zero8f), six8f);
                                        };
                                        printf("avgpoolrelu6::::::::::::::::::::::%d\n", num_elements);
                                        start0_time = get_time_force();
                                        unblind(act_func, inp, temp_buffer, temp_buffer, num_elements);
                                        end0_time = get_time_force();
                                        elapsed100 = elapsed100 + get_elapsed_time(start0_time, end0_time);
                                        } else {
                                        // cannot be reached given the enclosing condition
                                        assert(0);
                                    }


                                } else {
                                        // unblind and add bias
                                        // Re-arm the static flag so the next call clears time_report again.
                                        first_layer = true;
                                        sgx_time_t b_start;
                                        if (time) {
                                            b_start = get_time_force();
                                        }
                                        unblind_bias(inp, blind, temp_buffer2,  mini_batch_len / mini_batch_size, ch);
                                        sgx_time_t b_end;
                                        if (time) {
                                            b_end = get_time_force();
                                            time_report[1] += get_elapsed_time(b_start, b_end);
                                        }

                                        //printf("after bias_add\n");
                                        // The result is copied to out as-is, without blind_matrix_func.
                                        std::copy(temp_buffer2, temp_buffer2+mini_batch_len, out);
                                }

                        } else {
                                // no privacy, just compute activation

                                // NOTE(review): out_map spans num_elements floats while inp_map spans mini_batch_len.
                                TensorMap<float, 1> out_map((float*) out, num_elements);
                                float shift = 1.0/256;

                                if (act == "relu") {
                                        // scale by 1/256, round, then clamp below at 0
                                        out_map = (inp_map * shift).round().cwiseMax(static_cast<float>(0));
                                } else if (act == "relu6") {
                                        // clamp to [0, 6*256*256], then scale by 1/256 and round
                                        out_map = (inp_map.cwiseMax(static_cast<float>(0)).cwiseMin(static_cast<float>(6 * 256 * 256)) * shift).round();
                                } else if (act == "avgpoolrelu" || act == "avgpoolrelu6") {
                                        // View inp as a single sample using the current layer's input shape.
                                        array4d input_shape = curr_layer->input_shape_;
                                        input_shape[0] = 1;
                                        array4d output_shape = {{1, 1, 1, input_shape[3]}};
                                        assert(input_shape[1] * input_shape[2] * input_shape[3] == num_elements);

                                        TensorMap<float, 4> inp_map4d(inp, input_shape);
                                        TensorMap<float, 4> out_map4d((float*) out, output_shape);

                                        Eigen::array<int, 2> mean_dims({1, 2 /* dimensions to reduce */});

                                        // Clamp first, then average over dimensions 1 and 2, scale by 1/256 and round.
                                        if (act == "avgpoolrelu") {
                                                auto temp = inp_map4d.cwiseMax(static_cast<float>(0));
                                                out_map4d = ((temp.eval().mean(mean_dims).reshape(out_map4d.dimensions())) * shift).round();
                                        } else if (act == "avgpoolrelu6") {
                                                auto temp = inp_map4d.cwiseMax(static_cast<float>(0)).cwiseMin(static_cast<float>(6 * 256 * 256));
                                                out_map4d = ((temp.eval().mean(mean_dims).reshape(out_map4d.dimensions())) * shift).round();
                                        } else {
                                                assert(0);
                                        }
                                        // One output value per entry of the last input dimension.
                                        num_out_elements = input_shape[3];
                                } else {
                                        // unknown activation: plain copy
                                        out_map = inp_map;
                                }
                        }

                        // Advance all buffers to the next chunk.
                        inp += mini_batch_len;
                        out += num_out_elements;
                        relu_src += mini_batch_len;
                }
                if (time) {
                    sgx_time_t srelu_end = get_time_force();
                    time_report[6] += get_elapsed_time(srelu_start, srelu_end);
                }
        }

    // compute a MaxPool + ReLU on blinded data
    //
    // The activation layer is looked up through activation_idxs[act_idx]; act_idx is
    // then advanced and wraps back to 0 after the last entry. The batch is processed in
    // batch_size / mini_batch_size iterations.
    //
    // Parameters:
    //   inp             - input values; advanced by num_elements after every iteration
    //   out             - output buffer; advanced by num_out_elements after every iteration
    //   workspace       - forwarded as 3rd argument of ecall_maxpool; advanced by num_out_elements
    //   relu_src        - destination of blind_matrix_func(temp_buffer2, ...) right after pooling,
    //                     i.e. before the ReLU below is applied; advanced by num_out_elements
    //   bias            - forwarded as 2nd argument of unblind_bias
    //   dim_in          - input dimensions; [0] is asserted to equal batch_size, [1..3] are used as h, w, ch
    //   dim_out         - not referenced in this function
    //   window_rows_, window_cols_, row_stride_, col_stride_
    //                   - pooling window and strides; used here only to derive the output size
    //   is_padding_same - selects PADDING_SAME instead of PADDING_VALID for the size computation
    //
    // Timing: when the global `time` flag is set, elapsed times are accumulated into
    // time_report[1] (unblind_bias), time_report[2] (ReLU), time_report[5] (ecall_maxpool)
    // and time_report[6] (whole call).
    void slalom_maxpoolrelu(float* inp, float* out, float* workspace, float* relu_src, float* bias, long int dim_in[4], long int dim_out[4],
                                                    int window_rows_, int window_cols_, int row_stride_, int col_stride_, bool is_padding_same)
        {
                //printf("slalom_maxpoolrelu %d / %d\n", act_idx, activation_idxs.size());
                sgx_time_t maxpoolrelu_start;
                if (time) {
                 maxpoolrelu_start = get_time_force();
                }

                // Number of input floats handled per loop iteration (one mini-batch).
                int num_elements = dim_in[1] * dim_in[2] * dim_in[3] * mini_batch_size;
                assert(dim_in[0] == batch_size);
                int num_bytes = num_elements * sizeof(float);

                // NOTE(review): the neighbours at layer_idx - 1, + 1 and + 2 are read without bounds checks.
                int layer_idx = activation_idxs[act_idx];
                auto curr_layer = model_float.layers[layer_idx];
                auto prev_layer = model_float.layers[layer_idx - 1];
                auto next_layer = model_float.layers[layer_idx + 1];
                auto next2_layer = model_float.layers[layer_idx + 2];

                sgx_time_t start0_time;
                sgx_time_t end0_time;
                if (verbose) {
                        printf("\nin maxpoolrelu: prev layer: %s, curr layer: %s, next layer: %s\n",
                                     prev_layer->name_.c_str(), curr_layer->name_.c_str(), next_layer->name_.c_str());
                }
                // Move on to the next activation for the following call, wrapping at the end.
                act_idx += 1;
                if (act_idx == activation_idxs.size())
                    act_idx = 0;
                // Expected layout: an Activation layer immediately followed by a MaxPool2D layer.
                assert(dynamic_cast<Activation<float>*>(curr_layer.get()) != nullptr);
                assert(dynamic_cast<MaxPool2D<float>*>(next_layer.get()) != nullptr);

                if (slalom_integrity) {
                    // we only handle convolutional layers
                    assert(dynamic_cast<Conv2D<float>*>(prev_layer.get()) != nullptr);
                    // The layer after the pooling must be a Conv2D, optionally preceded by one Reshape.
                    if (dynamic_cast<Conv2D<float>*>(next2_layer.get()) == nullptr) {
                        assert(dynamic_cast<Reshape<float>*>(next2_layer.get()) != nullptr);
                        next2_layer = model_float.layers[layer_idx + 3];
                    }
                    assert(dynamic_cast<Conv2D<float>*>(next2_layer.get()) != nullptr);
                }

                int h = dim_in[1];
                int w = dim_in[2];
                int ch = dim_in[3];

                sgx_time_t start_time;
                sgx_time_t end_time;

                int iter = batch_size / mini_batch_size;
                for (int b = 0; b<iter; b++) {

                        // blind_map starts as a view over temp_buffer; the non-privacy path re-seats it onto inp.
                        TensorMap<float, 1> inp_map(inp, num_elements);
                        TensorMap<float, 1> blind_map((float*) temp_buffer, num_elements);
                        Tensor<double, 0> res;

                        if (slalom_privacy) {
                                assert(num_elements % 8 == 0);
                                assert ((long int)inp % 32 == 0);

                                // unblind the data
                                //printf("maxpoolrelu:::::::::::::::::::%d\n", num_elements);
                                // start0_time is assigned here but not read anywhere in this function.
                                start0_time = get_time_force();

                                int n = h*w*ch;
                                //unblind(id_avx, inp, temp_buffer, temp_buffer, num_elements);
                                sgx_time_t b_start;
                                if (time) {
                                 b_start = get_time_force();
                                }
                                // Result lands in temp_buffer (3rd argument), which blind_map views.
                                unblind_bias(inp, bias, temp_buffer,  num_elements / mini_batch_size, ch);
                                sgx_time_t b_end;
                                if (time) {
                                    b_end = get_time_force();
                                    time_report[1] += get_elapsed_time(b_start, b_end);
                                }

                                // add the bias
                                // max_inp_shape / biasadd are not referenced again below.
                                array4d max_inp_shape = {mini_batch_size - 1, h, w, ch};
                                auto biasadd = TensorMap<float, 4>(temp_buffer, max_inp_shape);

                                if (verbose) {
                                    res = blind_map.template cast<double>().minimum();
                                    printf("min(unblinded) = %f\n", res.data()[0]);
                                    res = blind_map.template cast<double>().maximum();
                                    printf("max(unblinded) = %f\n", res.data()[0]);
                                    res = blind_map.template cast<double>().abs().sum();
                                    printf("sum(abs(unblinded)) = %f\n", res.data()[0]);
                                }


                                // NOTE(review): no local `elapsed` is declared in this function, so this prints a
                                // variable from an enclosing scope that is never updated here.
                                if (TIMING) {
                                        printf("\tintrinsic stuff required %f sec\n", elapsed);
                                }

                                if (slalom_integrity) {
                                        // do the integrity checks on the output of the last convolution
                                        Conv2D<float>* conv2d_prev = dynamic_cast<Conv2D<float>*>(prev_layer.get());
                                        // r_left_data_ == NULL selects the pointwise variant, which uses h*w*REPS
                                        // accumulators instead of REPS.
                                        if (conv2d_prev->r_left_data_ == NULL) {
                                                conv2d_prev->res_z = model_float.mem_pool->alloc<double>(h*w*REPS);
                                                TensorMap<double, 1> res_z_map(conv2d_prev->res_z, h*w*REPS);
                                                res_z_map.setZero();
                                                conv2d_prev->preproc_verif_pointwise_Z(relu_avx, blind_map.data());
                                                // The sum is only reported under TIMING; it is not compared against anything here.
                                                double sum = 0;
                                                for (int i=0; i<h*w*REPS; i++) {
                                                        sum += conv2d_prev->res_z[i];
                                                }
                                                if (TIMING) {
                                                        printf("r_out_r: %f\n", sum);
                                                }
                                                model_float.mem_pool->release(conv2d_prev->res_z);
                                        } else {
                                                conv2d_prev->res_z = model_float.mem_pool->alloc<double>(REPS);
                                                TensorMap<double, 1> res_z_map(conv2d_prev->res_z, REPS);
                                                res_z_map.setZero();
                                                conv2d_prev->res_z_temp = model_float.mem_pool->alloc<double>(REPS);
                                                TensorMap<double, 1> res_z_temp_map(conv2d_prev->res_z_temp, REPS);
                                                res_z_temp_map.setZero();
                                                conv2d_prev->preproc_verif_Z(relu_avx, blind_map.data());
                                                if (TIMING) {
                                                        printf("r_out_r: %f, %f\n", mod_pos(conv2d_prev->res_z[0], p_verif), mod_pos(conv2d_prev->res_z[1], p_verif));
                                                }
                                                model_float.mem_pool->release(conv2d_prev->res_z);
                                                model_float.mem_pool->release(conv2d_prev->res_z_temp);
                                        }
                                }
                        }
                        else
                        {
                                // Re-seat blind_map onto the raw input via placement new.
                                // NOTE(review): ecall_maxpool below still reads temp_buffer, not blind_map.
                                new (&blind_map) TensorMap<float, 1>(inp, num_elements);
                        }


                        int h_out;
                        int w_out;
                        int pad_rows_;
                        int pad_cols_;

                        Eigen::PaddingType padding_ = Eigen::PaddingType::PADDING_VALID;
                        if (is_padding_same) {
                                padding_ = Eigen::PaddingType::PADDING_SAME;
                        }

                        // Derive the pooled height/width (and padding) from window, stride and padding mode.
                        GetWindowedOutputSize(h, window_rows_, row_stride_,
                                                                    padding_, &h_out, &pad_rows_);
                        GetWindowedOutputSize(w, window_cols_, col_stride_,
                                                                    padding_, &w_out, &pad_cols_);

                        // Pooled floats produced per iteration (one mini-batch).
                        int num_out_elements = h_out * w_out * ch * mini_batch_size;

                        sgx_time_t start1_time;
                                sgx_time_t end1_time;
                        double elapsed11 = 0;

                        //fast_maxpool(blind_map.data(), temp_buffer2, 2, h, w, ch, h_out, w_out,
                        //           window_rows_, window_cols_, pad_rows_, pad_cols_, row_stride_, col_stride_, false);
                        int out_offset = h_out * w_out *ch;
                        int in_offset = h * w * ch;
                        float* res11 = temp_buffer2;
                        sgx_time_t m_start;
                        if (time) {
                         m_start = get_time_force();
                        }
                        // Pooling itself: reads temp_buffer, writes temp_buffer2; receives none of the window parameters.
                        ecall_maxpool(temp_buffer, temp_buffer2, workspace);
                        sgx_time_t m_end;
                        if (time) {
                         m_end = get_time_force();
                         time_report[5] += get_elapsed_time(m_start, m_end);
                        }
                        // blind relu_src
                        blind_matrix_func(temp_buffer2, relu_src, num_out_elements, 1);

                        // View over the first (mini_batch_size - 1) samples of the pooled output only.
                        TensorMap<float, 1> out_map((float*) temp_buffer2, num_out_elements / mini_batch_size * (mini_batch_size - 1));


                        if (slalom_privacy) {
                                // compute ReLU
                                sgx_time_t start_time;
                                sgx_time_t end_time;
                                double elapsed10 = 0;
                                //out_map = (out_map * static_cast<float>(1.0/256)).round().cwiseMax(static_cast<float>(0));

                                sgx_time_t re_start;
                                if (time) {
                                    re_start = get_time_force();
                                }
                                // In-place ReLU on the samples covered by out_map; the last sample is left untouched.
                                out_map = out_map.cwiseMax(static_cast<float>(0));
                                sgx_time_t re_end;
                                if (time) {
                                    re_end = get_time_force();
                                    time_report[2] += get_elapsed_time(re_start, re_end);
                                }

                                if (slalom_integrity) {
                                        // do the integrity checks on the input of the next convolution
                                        Conv2D<float>* conv2d_next = dynamic_cast<Conv2D<float>*>(next2_layer.get());
                                        if (conv2d_next->r_left_data_ == NULL) {
                                                conv2d_next->res_x = model_float.mem_pool->alloc<double>(conv2d_next->h*conv2d_next->w*REPS);

                                                conv2d_next->preproc_verif_pointwise_X(out_map.data());
                                                double sum = 0;
                                                for (int i=0; i<conv2d_next->h*conv2d_next->w*REPS; i++) {
                                                        sum += conv2d_next->res_x[i];
                                                }
                                                if (TIMING) {
                                                        printf("r_inp_wr: %f\n", sum);
                                                }
                                                model_float.mem_pool->release(conv2d_next->res_x);
                                        } else {
                                                conv2d_next->res_x = model_float.mem_pool->alloc<double>(REPS);
                                                conv2d_next->preproc_verif_X(out_map.data());
                                                if (TIMING) {
                                                        printf("r_inp_wr: %f, %f\n", mod_pos(conv2d_next->res_x[0], p_verif), mod_pos(conv2d_next->res_x[1], p_verif));
                                                }
                                                model_float.mem_pool->release(conv2d_next->res_x);
                                        }
                                }

                                // reblind output
                                //printf("num_out_elementsi_re:::::::::::::::::::::::%d\n" , num_out_elements);
                                //start0_time = get_time_force();
                                // Covers all num_out_elements floats, including the sample skipped by the ReLU above.
                                blind_matrix_func(temp_buffer2, out, num_out_elements, 1);
                                //end0_time = get_time_force();
                                //  elapsed102 = elapsed102 + get_elapsed_time(start0_time, end0_time);
                        } else {
                                // No privacy: scale by 1/256, round and clamp below at 0 straight into out.
                                // NOTE(review): out_map2 spans num_out_elements floats while out_map spans fewer
                                // whenever mini_batch_size > 1.
                                TensorMap<float, 1> out_map2(out, num_out_elements);
                                out_map2 = (out_map * static_cast<float>(1.0/256)).round().cwiseMax(static_cast<float>(0));
                        }

                // Advance all buffers to the next mini-batch.
                inp += num_elements;
                out += num_out_elements;
                workspace += num_out_elements;
                relu_src += num_out_elements;
                }
                if (time) {
                    sgx_time_t maxpoolrelu_end = get_time_force();
                    time_report[6] += get_elapsed_time(maxpoolrelu_start, maxpoolrelu_end);
                }
        }

        // Entry point for the benchmarks: forwards num_threads unchanged to benchmark().
        void sgxdnn_benchmarks(int num_threads) {
                benchmark(num_threads);
        }
}
