#include "../Enclave/batchnorm_sp.h"
#include "layers/batchnorm.hpp"
#include <memory>
#include <string>
#include "immintrin.h"

// Returns the SUM of the first `size` floats of `inp`, accumulated with AVX
// in 8-wide lanes plus a scalar tail for the remaining size % 8 elements.
// NOTE(review): despite the name this does not divide by `size`; the value is
// handed to BatchNormSp::fwd as-is — confirm that fwd expects a sum.
//
// inp  - input buffer; no alignment requirement (unaligned loads are used).
// size - number of elements to sum; <= 0 (or a null inp) yields 0.
#if (defined(__GNUC__) || defined(__clang__)) && !defined(__AVX__)
// Lets this function use AVX intrinsics even if the TU is built without -mavx.
__attribute__((target("avx")))
#endif
float mean(float* inp, int size) {
  if (inp == nullptr || size <= 0)
    return 0.0f;

  __m256 acc = _mm256_setzero_ps();
  int j = 0;
  // Only full 8-float chunks go through the vector path, so we never read
  // past inp[size - 1]. loadu: the caller's buffer may not be 32-byte aligned.
  for (; j + 8 <= size; j += 8)
    acc = _mm256_add_ps(_mm256_loadu_ps(inp + j), acc);

  // Explicitly aligned scratch so the aligned store is well-defined
  // (a plain stack float[8] is only guaranteed 4-byte alignment).
  alignas(32) float lanes[8];
  _mm256_store_ps(lanes, acc);
  float res = 0.0f;
  for (int i = 0; i < 8; i++)
    res += lanes[i];

  // Scalar tail for size % 8 leftover elements.
  for (; j < size; j++)
    res += inp[j];

  return res;
}


// Builds the batch-norm layer wrapper.
//
// size    - four dimensions; the backend gets them in the given order, while
//           the oneDNN NHWC descriptor below is built as
//           {size[0], size[3], size[1], size[2]}, i.e. `size` is presumably
//           laid out as N, H, W, C — TODO confirm with callers.
// mode    - mode string stored in the member mode_ and later passed to
//           BatchNormSp::fwd.
// src_tag - memory descriptor of the incoming activations.
// eng, s  - oneDNN engine and stream used for the optional input reorder.
//
// If src_tag differs from the NHWC descriptor, a reorder primitive is
// appended to net_fwd so forward_special() can convert its input first.
BatchNormSP::BatchNormSP(int size[4], const char* mode, memory::desc src_tag, engine eng, stream s) {
  // Assign the member (a local declaration here would shadow mode_ and leave
  // the member that forward_special() reads uninitialised).
  mode_ = std::string(mode);
  std::string name = std::string("name");
  array4d dim      = {size[0], size[1], size[2], size[3]};
  batchnorm_       = static_cast<void*>(new SGXDNN::BatchNormSp<float>(name, dim));
  output_size_     = size[0] * size[1] * size[2] * size[3];
  input_size_      = output_size_;
  src_tag_         = src_tag;
  eng_             = eng;  // was a self-assignment, leaving eng_ unset
  s_               = s;
  nhwc_ = memory::desc({{size[0], size[3], size[1], size[2]}}, memory::data_type::f32, memory::format_tag::nhwc);

  // Set on both paths so the flag never depends on prior member state.
  input_reorder_ = (nhwc_ != src_tag_);
  if (input_reorder_) {
    primitive_attr dummy_attr;
    primitive reorder_layer = reorder(reorder::primitive_desc(eng, src_tag_, eng, nhwc_, dummy_attr));
    net_fwd.push_back(reorder_layer);
  }
  ocall_extern_alloc((void**) &src_dump_, sizeof(float)*src_tag_.get_size());
}

 
// Runs the enclave batch-norm forward pass on `src`, writing to `dst`.
//
// If the input layout is not NHWC (input_reorder_), `src` is first reordered
// into a temporary 32-byte-aligned NHWC buffer using the reorder primitive
// stored in net_fwd[0]. The sum computed by mean() over the (possibly
// reordered) input is passed to BatchNormSp::fwd together with `skip`.
// `is_train` is currently unused.
void BatchNormSP::forward_special (float* src, float* dst, float* skip, bool is_train) {
  (void)is_train;
  float* batchnorm_src = src;
  // Owns the temporary NHWC buffer (if any); freed on every exit path,
  // including when execute() or fwd() throws.
  std::unique_ptr<float[]> reorder_buf;

  if (input_reorder_) {
    // +8 floats of slack so ALIGN32 can round the start up to 32 bytes.
    reorder_buf.reset(new float[this->output_size_ + 8]);
    float* raw = reorder_buf.get();
    batchnorm_src = ALIGN32(raw);
    // Read from the caller's buffer and write into the aligned scratch buffer
    // (reordering the uninitialised scratch buffer onto itself ignored `src`).
    net_fwd.at(0).execute(s_, {{DNNL_ARG_FROM, memory(src_tag_, eng_, src)},
                               {DNNL_ARG_TO,   memory(nhwc_, eng_, batchnorm_src)}});
    s_.wait();
  }

  // mean() returns the element sum of the input buffer.
  float means = mean(batchnorm_src, this->output_size_);

  ((SGXDNN::BatchNormSp<float>*) batchnorm_)->fwd(dst, batchnorm_src, &means, skip, nullptr, nullptr, 1, 1, mode_.c_str());
}
// Generic forward entry point — currently a no-op: nothing is read from
// `src` and `dst` is left untouched.
// NOTE(review): presumably callers use forward_special() for this layer;
// confirm nothing relies on forward() producing output.
void BatchNormSP::forward(float* src, float* dst, bool is_train) {
  (void)src;
  (void)dst;
  (void)is_train;
}

// Backward-pass setup hook; intentionally empty for this layer, so the
// given gradient descriptor is ignored.
void BatchNormSP::update_backward(memory::desc diff_src_tag) {
  (void)diff_src_tag;
}

// Backward pass — not implemented for this layer.
// Always returns nullptr so callers receive a well-defined value; flowing off
// the end of a non-void function is undefined behaviour in C++.
float* BatchNormSP::backward(float* diff_srt_ptr) {
  (void)diff_srt_ptr;
  return nullptr;
}

int BatchNormSP::out_size() {
  return output_size_;
}

int BatchNormSP::input_size() {
  return input_size_;
}

// Descriptor of the layer's output: the f32 NHWC descriptor built in the
// constructor.
memory::desc BatchNormSP::dst_desc() {
  return this->nhwc_;
}


