#include "resnet_activation.h"
#include "resnet_bottom.h"
#include "dnnl.hpp"
#include "pool_sp.h"
#include "example_utils.hpp"
#include "Enclave.h"
#include "immintrin.h"
#include <string>
#include "Enclave_t.h"

using namespace dnnl;

extern float* bm_buf;
extern float* gm_buf;
extern float* igm_buf;
extern float* um_buf;

/**
 * Builds the residual "bottom" merge block.
 *
 * @param act_mode     "downsample" additionally creates right_norm_ for the
 *                     shortcut branch; any other value leaves right_norm_ null
 *                     and the shortcut input is added unchanged in forward_sp.
 * @param in_size      input dims; indices [1..3] give the per-image element
 *                     count (presumably N,C,H,W — confirm against callers).
 * @param out_size     output dims; [0] is the batch size, [1..3] the
 *                     per-image element count.
 * @param eps          forwarded to the ResnetActivation normalisation(s).
 * @param momentum     forwarded to the ResnetActivation normalisation(s).
 * @param bias_data_l  bias data for left_norm_.
 * @param bias_data_r  bias data for right_norm_ (used only in "downsample").
 * @param src_tags     memory descriptor forwarded to the normalisations.
 * @param eng, s       oneDNN engine / stream used by the normalisations.
 */
ResnetBottom::ResnetBottom(string act_mode, 
	                       int in_size[4], 
	                       int out_size[4], 
	                       float eps,
	                       float momentum,
	                       float* bias_data_l,
	                       float* bias_data_r,
	                       memory::desc src_tags, 
	                       engine eng, 
	                       stream s) {

	this->act_mode_ = act_mode;
	this->eng_       = eng;
	this->s_         = s;
	this->src_tag_   = src_tags;
	this->in_size_   = in_size[1]*in_size[2]*in_size[3];
	this->out_size_  = out_size[1]*out_size[2]*out_size[3];
	this->batch_size_= out_size[0];
	// forward_sp/backward_sp process images in groups of this size
	// (load_triple/dump_triple move three images at a time).
	this->internal_batch_size_ = 3;
	// Zero pooling window/stride: the "bottom" activations do no pooling.
	int pool_window[2] = {0, 0};
	int pool_stride[2] = {0, 0};
	left_norm_ = new ResnetActivation(std::string("bottom"), 
                                      in_size, 
                                      out_size, 
                                      pool_window, 
                                      pool_stride, 
                                      eps,
                                      momentum,
                                      bias_data_l,
                                      src_tags, 
                                      eng, 
                                      s);

	// Only the downsample variant normalises the shortcut branch. Set the
	// pointer explicitly so it is never left indeterminate otherwise.
	right_norm_ = nullptr;
	if (act_mode == "downsample") {
		right_norm_ = new ResnetActivation(std::string("bottom"), 
                                           in_size, 
                                           out_size, 
                                           pool_window, 
                                           pool_stride, 
                                           eps,
                                           momentum,
                                           bias_data_r,
                                           src_tags, 
                                           eng, 
                                           s);
	}

	// Each buffer holds batch_size_ images of out_size_ floats. The extra 32
	// bytes presumably give ALIGN32 room to round the start address up.
	const size_t buf_bytes = (batch_size_) * sizeof(float) * this->out_size_ + 32;

	// NOTE(review): the status returned by ocall_extern_alloc is not checked;
	// a failed allocation would leave these pointers unusable.
	ocall_extern_alloc((void**) &left_res_ptr_,  buf_bytes);
	ocall_extern_alloc((void**) &right_res_ptr_, buf_bytes);
	ocall_extern_alloc((void**) &relu_temp_ptr_, buf_bytes);

	left_res_align_  = ALIGN32(left_res_ptr_);
	right_res_align_ = ALIGN32(right_res_ptr_);
	relu_temp_align_ = ALIGN32(relu_temp_ptr_);
}

void ResnetBottom::forward_sp(float* left_in, float* right_in, float* mean_left, float* mean_right, 
	                       float* dst, bool training=true) {
	
	int image_skip_size = internal_batch_size_ * this->in_size_;
	int iter = batch_size_ / internal_batch_size_;
	this->left_norm_->forward_sp(left_in, left_res_align_, mean_left, false, true);
	if (this->act_mode_ == "downsample")
		this->right_norm_->forward_sp(right_in, right_res_align_, mean_right, false, true);

	else
		right_res_align_ = right_in;

	__m256 zero_v = _mm256_set1_ps((float)0);
	for (int i = 0; i < iter; i++) {
		float* left_ptr  = left_res_align_  + i *image_skip_size; 
		float* right_ptr = right_res_align_ + i *image_skip_size;
		float* out_ptr   = dst + i * image_skip_size;
		for (int j = 0; j < this->out_size_; j+=8) {
			Triple_t image_l = load_triple(left_ptr, j, this->out_size_);
			Triple_t image_r = load_triple(right_ptr, j, this->out_size_);

			image_l = decrypt(this->left_norm_->um, image_l);
			image_r = decrypt(this->left_norm_->um, image_r);

			Triple_t sum;
			sum.first  = _mm256_add_ps(image_l.first,  image_r.first);
			sum.second = _mm256_add_ps(image_l.second, image_r.second);

			Triple_t sum_enc = encrypt(this->left_norm_->bm, sum);

			dump_triple(sum_enc, relu_temp_align_, j, this->in_size_);
			image_l.first  = _mm256_max_ps(zero_v, sum.first);
			image_l.second = _mm256_max_ps(zero_v, sum.second);
			
			image_l = encrypt(this->left_norm_->bm, image_l);
			dump_triple(image_l, out_ptr, j, this->in_size_);
		}
	}
}

/**
 * Backward pass of the residual merge.
 *
 * For every processed image, the cached pre-activation in relu_temp_align_
 * (written by forward_sp) is decrypted with um and fed to relu_back together
 * with the incoming gradient. The result is encrypted with gm and written back
 * in place, at the same per-group offset. The whole buffer is then passed to
 * left_norm_->backward and, in "downsample" mode, to right_norm_->backward.
 *
 * @param left_grad     output of left_norm_->backward.
 * @param right_grad    output of right_norm_->backward; not written unless
 *                      act_mode_ == "downsample".
 * @param diff_dst_ptr  gradient w.r.t. this block's output (out_size_ floats
 *                      per image).
 */
void ResnetBottom::backward_sp(float* left_grad, float* right_grad, float* diff_dst_ptr) {
	// relu_temp and diff_dst hold out_size_ floats per image.
	const int image_size      = this->out_size_;
	const int image_skip_size = internal_batch_size_ * image_size;
	const int iter            = batch_size_ / internal_batch_size_;

	for (int i = 0; i < iter; ++i) {
		const int offset = i * image_skip_size;
		float* grad_out_in = diff_dst_ptr + offset;
		// This group's cached pre-activation. It is overwritten in place with
		// the gradient, so every group keeps its own slot in relu_temp_align_.
		float* act_arc_ptr = this->relu_temp_align_ + offset;

		for (int j = 0; j < image_size; j += 8) {
			Triple_t image = load_triple(act_arc_ptr, j, image_size);
			Triple_t grad  = load_triple(grad_out_in, j, image_size);
			image = decrypt(this->left_norm_->um, image);

			Triple_t grad_next = relu_back(grad, image);

			grad_next = encrypt(this->left_norm_->gm, grad_next);
			dump_triple(grad_next, act_arc_ptr, j, image_size);
		}
	}

	this->left_norm_->backward(left_grad, relu_temp_align_);
	if (this->act_mode_ == "downsample")
		this->right_norm_->backward(right_grad, relu_temp_align_);
}