#include "resnet_activation.h"
#include "dnnl.hpp"
#include "pool_sp.h"
#include "example_utils.hpp"
#include "Enclave.h"
#include "immintrin.h"
#include "Enclave_t.h"

using namespace dnnl;

extern float* bm_buf;
extern float* gm_buf;
extern float* igm_buf;
extern float* um_buf;

/**
 * Sets up a fused bias / batch-norm / activation / pooling layer.
 *
 * @param act_mode    One of "bias_add", "bnzerorelu", "bnrelupool", "bottom".
 *                    Any other value leaves batch-norm enabled with pooling
 *                    and activation disabled.
 * @param in_size     Input shape; [0] is the batch size and [3] the channel
 *                    count (which must be a multiple of 8). The product of
 *                    [1], [2] and [3] is the per-image element count.
 * @param out_size    Output shape, same layout as in_size.
 * @param pool_window Pooling window, forwarded to PoolSp ("bnrelupool" only).
 * @param pool_stride Pooling stride, forwarded to PoolSp ("bnrelupool" only).
 * @param eps         Added to the variance before the square root.
 * @param momentum    Stored in momentum_; not otherwise used here.
 * @param bias_data   Optional per-channel bias (channel_ floats). When null
 *                    the bias is zero.
 * @param src_tags    Source memory descriptor, stored and forwarded to PoolSp.
 * @param eng         oneDNN engine.
 * @param s           oneDNN stream.
 */
ResnetActivation::ResnetActivation(string act_mode,
								   int in_size[4],
								   int out_size[4],
								   int pool_window[2],
								   int pool_stride[2],
								   float eps,
								   float momentum,
								   float* bias_data,
								   memory::desc src_tags,
								   engine eng,
								   stream s) {

	this->eng_       = eng;
	this->s_         = s;
	this->src_tag_   = src_tags;
	this->eps_       = eps;
	this->momentum_  = momentum;
	this->in_size_   = in_size[1]*in_size[2]*in_size[3];
	this->out_size_  = out_size[1]*out_size[2]*out_size[3];
	this->act_mode_  = act_mode;
	// configurable
	this->internal_batch_size_ = 3;
	batch_size_          = in_size[0];
	channel_             = in_size[3];

	// Every per-channel loop below and in forward/backward steps 8 floats at a
	// time with aligned AVX loads, so reject unsupported channel counts early.
	if (channel_ % 8 != 0) {
		printf("resnet activation channel_ is not divisible by 8");
		assert(false);
	}

	// Per-channel buffers. The 8 extra floats leave room for ALIGN32 to move the
	// start forward to a 32-byte boundary. They are value-initialised so that no
	// buffer (in particular the bias when bias_data is null) is read before it
	// has a defined value.
	center_ptr_          = new float[channel_+8]();
	scale_ptr_           = new float[channel_+8]();
	center_grad_ptr_     = new float[channel_+8]();
	scale_grad_ptr_      = new float[channel_+8]();
	bias_ptr_            = new float[channel_+8]();
	mean_ptr_            = new float[channel_+8]();
	std_ptr_             = new float[channel_+8]();
	bias_grad_ptr_       = new float[channel_+8]();

	bias_grad_aligned_   = ALIGN32(bias_grad_ptr_);
	std_aligned_         = ALIGN32(std_ptr_);
	mean_aligned_        = ALIGN32(mean_ptr_);
	center_aligned_      = ALIGN32(center_ptr_);
	scale_aligned_       = ALIGN32(scale_ptr_);
	center_grad_aligned_ = ALIGN32(center_grad_ptr_);
	scale_grad_aligned_  = ALIGN32(scale_grad_ptr_);
	bias_aligned_        = ALIGN32(bias_ptr_);

	// Batch-norm affine parameters start as the identity transform.
	for (int c = 0; c < channel_; c++) {
		center_aligned_[c] = 0.0f;
		scale_aligned_[c]  = 1.0f;
	}

	if (bias_data != nullptr)
		std::copy(bias_data, bias_data+channel_, this->bias_aligned_);

	// Mode flags. Defaults cover "bottom" and any unrecognised mode so that the
	// flags are never left indeterminate.
	get_bn_   = true;
	get_pool_ = false;
	get_act_  = false;
	if (act_mode == "bias_add") {
		get_bn_   = false;
	} else if (act_mode == "bnzerorelu") {
		get_act_  = true;
	} else if (act_mode == "bnrelupool") {
		get_pool_ = true;
		get_act_  = true;
	}

	// The leading dimension 2 matches the two lanes (.first/.second) that
	// forward_sp and backward use after decrypting a triple.
	nhwc_ = memory::desc({{2, in_size[3], in_size[1], in_size[2]}}, memory::data_type::f32, memory::format_tag::nhwc);
	nhwc_out_ = memory::desc({{2, out_size[3], out_size[1], out_size[2]}}, memory::data_type::f32, memory::format_tag::nhwc);

	// The destructor tests pool_, so it must be well-defined even without pooling.
	pool_ = nullptr;
	if (get_pool_) {
		int padding[2] = {0,0};
		int pool_insize [4] ={2, in_size[1], in_size[2], in_size[3]};
		int pool_outsize[4] ={2, out_size[1], out_size[2], out_size[3]};
		pool_ = new PoolSp(pool_insize, pool_outsize, pool_window, pool_stride, padding, 0, src_tags, eng, s);

		pool_->update_backward(nhwc_out_);
	}

	// Broadcast one row (three consecutive coefficients) of a 3x3 matrix into
	// three AVX registers. Four registers are allocated so that three remain
	// usable after ALIGN32T moves the start to a 32-byte boundary.
	auto load_row = [](const float* coeff) -> __m256* {
		__m256* row = ALIGN32T(__m256,new __m256[4]);
		row[0] = _mm256_set1_ps(coeff[0]);
		row[1] = _mm256_set1_ps(coeff[1]);
		row[2] = _mm256_set1_ps(coeff[2]);
		return row;
	};

	// loading matrices to vectors
	if (bm_buf != nullptr && um_buf != nullptr) {
		bm.first  = load_row(bm_buf);
		bm.second = load_row(bm_buf + 3);
		bm.third  = load_row(bm_buf + 6);

		um.first  = load_row(um_buf);
		um.second = load_row(um_buf + 3);
		um.third  = load_row(um_buf + 6);

		// The gradient matrices are separate globals; guard each one so a
		// missing buffer is not dereferenced.
		if (gm_buf != nullptr) {
			gm.first  = load_row(gm_buf);
			gm.second = load_row(gm_buf + 3);
			gm.third  = load_row(gm_buf + 6);
		}

		if (igm_buf != nullptr) {
			igm.first  = load_row(igm_buf);
			igm.second = load_row(igm_buf + 3);
			igm.third  = load_row(igm_buf + 6);
		}
	}

	// Batch-sized scratch buffers are obtained through an ocall. The extra 32
	// bytes leave room for the 32-byte alignment adjustment.
	act_src_ptr_   = nullptr;
	batch_src_ptr_ = nullptr;
	ocall_extern_alloc((void**) &act_src_ptr_,
		                           (batch_size_) * sizeof(float) * this->in_size_ + 32);
	ocall_extern_alloc((void**) &batch_src_ptr_,
		                           (batch_size_) * sizeof(float) * this->in_size_ + 32);
	if (act_src_ptr_ == nullptr || batch_src_ptr_ == nullptr) {
		printf("resnet activation external buffer allocation failed");
		assert(false);
	}
	act_src_aligned_   = ALIGN32(act_src_ptr_);
	batch_src_aligned_ = ALIGN32(batch_src_ptr_);

	work_ptr_   = nullptr;
	work_align_ = nullptr;
	if (get_pool_) {
		// NOTE(review): memory::desc::get_size() is documented to return bytes,
		// so the sizeof(float) factor looks like a 4x over-allocation - confirm.
		ocall_extern_alloc((void**) &work_ptr_,
		                           sizeof(float) * pool_->pool_pd_.workspace_desc().get_size() + 32);
		if (work_ptr_ == nullptr) {
			printf("resnet activation pooling workspace allocation failed");
			assert(false);
		}
		work_align_ = ALIGN32(work_ptr_);
	}
}


void ResnetActivation::forward_sp(float* src, float* dst, float* mean_extern, bool is_first=false, bool final_enc=true) {
	int iter = batch_size_ / internal_batch_size_;
	// all operations with batchnorm
	
	if (get_bn_) {
		/***************** mean recover ***********************/
		int mean_skip = channel_ * internal_batch_size_;
		memset(mean_aligned_, 0, channel_*sizeof(float));
		for (int i = 0; i < iter; i++) {
			float* mean_extern_ptr = mean_extern + mean_skip;
			for (int j = 0; j < channel_; j+=8){
				Triple_t mean_vec = load_triple(mean_extern_ptr, j, channel_);
				__m256 mean_res = _mm256_load_ps(&mean_aligned_[j]);
				mean_vec = decrypt(um, mean_vec); 
				mean_res = _mm256_add_ps(mean_res, mean_vec.first);
				mean_res = _mm256_add_ps(mean_res, mean_vec.second);
				_mm256_stream_ps(&mean_aligned_[j], mean_res);
			}
		}

		// final division and add bias to the mean

		__m256 num = _mm256_set1_ps((float) (iter * (internal_batch_size_-1)));
		for (int j = 0; j < channel_; j+=8) {
			__m256 mean_res = _mm256_load_ps(&mean_aligned_[j]);
			__m256 bias     = _mm256_load_ps(&bias_aligned_[j]);

			mean_res        = _mm256_div_ps(mean_res, num);
			mean_res        = _mm256_add_ps(mean_res, bias);
			_mm256_stream_ps(&mean_aligned_[j], mean_res);
		}

		/****************** add bias & get std ***************************/
		memset(std_aligned_, 0, channel_*sizeof(float));
		int image_skip_size = this->in_size_ * internal_batch_size_;
		for (int i = 0; i < iter; i++) {
			float* input_ptr = src + i * image_skip_size;
			float* dump_ptr  = this->batch_src_aligned_ + i * image_skip_size;
			for (int j = 0; j < this->in_size_; j+=8) {
				Triple_t image;
				image = load_triple(input_ptr, j, this->in_size_);
				int cs = 0;

				image = decrypt(um, image);
				__m256 bias  = _mm256_load_ps(&this->bias_aligned_[j%channel_]);
				image.first  = _mm256_add_ps(image.first, bias);
				image.second = _mm256_add_ps(image.second, bias);

				// write encrypted post-bias result to batch_src_aligned_
				Triple_t dec_bias = encrypt(bm, image);
				dump_triple(dec_bias, dump_ptr, j, this->in_size_);
				__m256 std_res = _mm256_load_ps(&std_aligned_[j%channel_]);

				// 2xstd computation
				__m256 mean_vec = _mm256_load_ps(&mean_aligned_[j%channel_]);
				__m256 temp1 = _mm256_sub_ps(image.first, mean_vec);
				temp1 = _mm256_mul_ps(temp1, temp1);
				__m256 temp2 = _mm256_sub_ps(image.second, mean_vec);
				temp2 = _mm256_mul_ps(temp2, temp2);

				temp2 = _mm256_add_ps(temp1, temp2);
				std_res = _mm256_add_ps(std_res, temp2);

				// stream back the std_align
				_mm256_stream_ps(&std_aligned_[j%channel_], std_res);
			}
		}

		// trim final std
		int effective_batch = batch_size_ / internal_batch_size_ * (internal_batch_size_-1);
		int total_num = effective_batch * (this->in_size_ / channel_) - 1;
		num = _mm256_set1_ps((float)total_num);
		__m256 eps_vec = _mm256_set1_ps(this->eps_);
		for (int i = 0; i < channel_; i+=8) {
			__m256 std_res = _mm256_load_ps(&this->std_aligned_[i]);
			std_res = _mm256_div_ps(std_res, num);
			std_res = _mm256_add_ps(std_res, eps_vec);
			std_res = _mm256_sqrt_ps(std_res);
			_mm256_stream_ps(&std_aligned_[i], std_res);
		}

		// combined norm_func & relu or combined form_func & none activation

		void (*act_src_dump_func)(Triple_t& input, float* ptr, int pos, int image_size);
	    Triple_t (*act_func)(Triple_t& input);
	    Triple_t (*act_src_enc_func)(M_t&, Triple_t& input);
	    Triple_t (*pool_src_enc_func)(M_t&, Triple_t& input);
	    act_func = none;
	    act_src_dump_func = dump_triple_none;
	    act_src_enc_func  = encrypt_none;

	    if (act_mode_ == "bnzerorelu") {
		    act_func = relu;
		    act_src_dump_func = dump_triple;
		    act_src_enc_func  = encrypt;
			pool_src_enc_func = encrypt;

		} else if (act_mode_ == "bnrelupool") {
			act_func = relu;
		    act_src_dump_func = dump_triple;
		    act_src_enc_func  = encrypt;
			pool_src_enc_func = encrypt_none;

		} else if (act_mode_ == "bottom") {
			act_func = none;
		    act_src_dump_func = dump_triple;
		    act_src_enc_func  = encrypt;
			pool_src_enc_func = encrypt;
		}

	    float* norm_destination;
	    float* norm_destination_ptr = nullptr;
	    float* pool_dst;
	    float* pool_dst_ptr = nullptr;
	    if (!get_pool_) {
	    	norm_destination = dst;
	    } else {
	    	norm_destination_ptr = new float[this->in_size_*internal_batch_size_+8];
	    	norm_destination     = ALIGN32(norm_destination_ptr);
	    	pool_dst_ptr         = new float[this->out_size_*internal_batch_size_+8];
			pool_dst             = ALIGN32(pool_dst_ptr);        
	    }

	    
		// int image_skip_size = this->in_size_ * internal_batch_size_;
	    for (int i = 0; i < iter; i++) {
	    	float* input_ptr = this->batch_src_aligned_ + i * image_skip_size;
	    	float* out_ptr   = norm_destination + i * image_skip_size;
	    	if (get_pool_)
	    		out_ptr = norm_destination;
	    	float* act_src_ptr = this->act_src_aligned_ + i * image_skip_size;
	    	
	    	for (int j = 0; j < this->in_size_; j+=8) {
	    		Triple_t image;
	    		image = load_triple(input_ptr, j, this->in_size_);
	    		image = decrypt(um, image);
	    		// load mean_vec, std_vec gamma_vec and beta_vec
	    		int channel_ptr = j % this->channel_;
	    		__m256 mean_vec = _mm256_load_ps(&mean_aligned_[channel_ptr]);
	    		__m256 std_vec = _mm256_load_ps(&std_aligned_[channel_ptr]);
	    		__m256 gamma_vec = _mm256_load_ps(&scale_aligned_[channel_ptr]);
	    		__m256 beta_vec = _mm256_load_ps(&center_aligned_[channel_ptr]);
	    		// normalizing

	    		image = norm_triple(image, mean_vec, std_vec, gamma_vec, beta_vec);

	    		Triple_t act_src = act_src_enc_func(bm, image);
	    		act_src_dump_func(act_src, act_src_ptr, j, this->in_size_);

	    		image = act_func(image);
	    		image = pool_src_enc_func(bm, image);

	    		dump_triple(image, out_ptr, j, this->in_size_);
	    	}

	    	 if (!get_pool_)
	    		continue;
	    	
	    	this->pool_->forward_resnet(norm_destination, pool_dst, work_align_, true);
	    	float* dest_ptr = dst + i * this->out_size_ * internal_batch_size_;
	    	
	    	for (int j = 0; j < this->out_size_; j+=8) {
	    		Triple_t image;
	    		image = load_triple(pool_dst, j, this->out_size_);
	    		image = encrypt(bm, image);

	    		dump_triple(image, dest_ptr, j, this->out_size_);
	    	}

	    }

	    if (norm_destination_ptr != nullptr)
	    	delete [] norm_destination_ptr;
	    if (pool_dst_ptr != nullptr)
	    	delete [] pool_dst_ptr;
	} else {
		// bias_add mode
		int image_skip_size = this->in_size_ * internal_batch_size_;
	    for (int i = 0; i < iter; i++) {
	    	float* input_ptr = src + i * image_skip_size;
	    	float* out_ptr   = dst + i * image_skip_size;
	    	for (int j = 0; j < this->in_size_; j+=8) {
	    		Triple_t image = load_triple(input_ptr, j, this->in_size_);
	    		image = decrypt(um, image);
				__m256 bias = _mm256_load_ps(&this->bias_aligned_[j%channel_]);
				image.first = _mm256_add_ps(image.first, bias);
				image.second = _mm256_add_ps(image.second, bias);
				image = encrypt(bm, image);
				dump_triple(image, out_ptr, j, this->in_size_);
	    	}
	    }
	}
}
// No-op: the memory descriptor argument is ignored and no state is changed.
void ResnetActivation::update_backward(memory::desc) {}

/**
 * Backward pass over an encrypted mini-batch.
 *
 * Resets the bias / center / scale gradient accumulators, then walks the batch
 * in groups of internal_batch_size_ images. Incoming gradients are decrypted
 * with `igm`; gradients written to diff_src_ptr are encrypted with `gm`.
 *
 * In "bias_add" mode the incoming gradient is accumulated into the bias
 * gradient and copied through to diff_src_ptr unchanged (still encrypted).
 * Otherwise each group goes through pooling backward (if enabled), activation
 * backward and a batch-norm backward that scales by scale / std and
 * accumulates the parameter gradients.
 *
 * @param diff_src_ptr  Output: encrypted gradient w.r.t. the layer input.
 * @param diff_dst_ptr  Input: encrypted gradient w.r.t. the layer output.
 */
void ResnetActivation::backward(float* diff_src_ptr, float* diff_dst_ptr) {
	const int iter = batch_size_ / internal_batch_size_;
	const int image_skip_size = this->in_size_ * internal_batch_size_;
	const int out_skip_size = this->out_size_ * internal_batch_size_;

	// memset counts bytes, so scale by the element size to clear every channel.
	memset(this->bias_grad_aligned_,   0 , this->channel_ * sizeof(float));
	memset(this->center_grad_aligned_, 0 , this->channel_ * sizeof(float));
	memset(this->scale_grad_aligned_,  0 , this->channel_ * sizeof(float));

	if (this->act_mode_ == "bias_add") {

		for (int i = 0; i < iter; i++) {
			float* grad_out_ptr = diff_dst_ptr + i * out_skip_size;
			float* grad_in_ptr =  diff_src_ptr + i * out_skip_size;

			for (int j = 0; j < this->out_size_; j+=8) {
				Triple_t image = load_triple(grad_out_ptr, j, this->out_size_);
				Triple_t image_clean = decrypt(igm, image);
				__m256 bias_grad = _mm256_load_ps(&bias_grad_aligned_[j%channel_]);
				bias_grad = _mm256_add_ps(bias_grad, image_clean.first);
				bias_grad = _mm256_add_ps(bias_grad, image_clean.second);
				// Regular store: the accumulator is re-read on later iterations.
				_mm256_store_ps(&bias_grad_aligned_[j%channel_], bias_grad);
				// the still-encrypted gradient is passed through as is
				dump_triple(image, grad_in_ptr, j, this->out_size_);
			}
		}

		return;
	}

	// Scratch for one decrypted group of output gradients. Sized as a triple of
	// out_size_ floats plus alignment slack, like the buffers in forward_sp.
	float* pool_dst_ptr = new float[this->out_size_ * internal_batch_size_ + 8];
	float* pool_dst_align = ALIGN32(pool_dst_ptr);

	// Without pooling the decrypted gradient is used directly as the gradient
	// w.r.t. the activation output (this relies on in_size_ == out_size_ in the
	// non-pooling modes). With pooling it is expanded into a second buffer.
	float* diff_act_dst_ptr = nullptr;
	float* diff_act_dst_aligned = pool_dst_align;
	if (get_pool_) {
		diff_act_dst_ptr = new float[this->in_size_ * internal_batch_size_ + 8];
		diff_act_dst_aligned = ALIGN32(diff_act_dst_ptr);
	}

	Triple_t (*act_back_func)(Triple_t& grad, Triple_t& input) = grad_none;
	Triple_t (*act_src_dec_func)(M_t&, Triple_t& input) = encrypt_none;
	if (get_act_ ) {
		act_back_func = relu_back;
		act_src_dec_func = decrypt;
	}

	// NOTE(review): forward_sp applies no activation in "bottom" mode, yet
	// relu_back is selected here - confirm this is intended.
	if (this->act_mode_ == "bottom") {
		act_src_dec_func = decrypt;
		act_back_func = relu_back;
	}

	const __m256 one_vec = _mm256_set1_ps((float) 1);

	for (int i = 0; i < iter; i++) {
		// pool back
		// decrypt first
		float* grad_out_ptr = diff_dst_ptr + i * out_skip_size;
		for (int j = 0; j < this->out_size_; j+=8) {
			Triple_t image = load_triple(grad_out_ptr, j, this->out_size_);
			image = decrypt(igm, image);
			// The scratch buffer holds out_size_-sized images, so the lane
			// stride must be out_size_ as well.
			dump_triple(image, pool_dst_align, j, this->out_size_);
		}

		if (get_pool_) {
			pool_->backward_resnet(diff_act_dst_aligned, pool_dst_align);
		}

		float* act_src_ptr = act_src_aligned_ + i * image_skip_size;
		float* grad_in_ptr = diff_src_ptr + i * image_skip_size;
		for (int j = 0; j < this->in_size_; j+=8) {
			Triple_t act_src  = load_triple(act_src_ptr , j, this->in_size_);
			Triple_t grad_out = load_triple(diff_act_dst_aligned, j, this->in_size_);
			// NOTE(review): forward_sp encrypted act_src with `bm` but it is
			// decrypted with `igm` here - verify the matrices are compatible.
			act_src = act_src_dec_func(igm, act_src);
			// activation backward
			grad_out = act_back_func(grad_out, act_src);

			//  norm back
			const int channel_ptr = j % channel_;
			__m256 std_m  = _mm256_load_ps(&this->std_aligned_[channel_ptr]);
			std_m  = _mm256_div_ps(one_vec, std_m);

			__m256 beta  = _mm256_load_ps(&center_grad_aligned_[channel_ptr]);
			__m256 gamma = _mm256_load_ps(&scale_grad_aligned_[channel_ptr]);
			__m256 scale_m = _mm256_load_ps(&scale_aligned_[channel_ptr]);
			__m256 center_m = _mm256_load_ps(&center_aligned_[channel_ptr]);

			// gradients to parameters: undo the affine transform to recover
			// the normalised input that scale multiplied in the forward pass
			__m256 input0 = _mm256_div_ps(_mm256_sub_ps(act_src.first, center_m), scale_m);
			__m256 input1 = _mm256_div_ps(_mm256_sub_ps(act_src.second, center_m), scale_m);

			__m256 grad_gamma = _mm256_add_ps(_mm256_mul_ps(grad_out.first, input0), gamma);
			__m256 grad_beta  = _mm256_add_ps(grad_out.first, beta);
			grad_gamma = _mm256_add_ps(_mm256_mul_ps(grad_out.second, input1), grad_gamma);
			grad_beta  = _mm256_add_ps(grad_out.second, grad_beta);

			// gradients to input
			Triple_t grad_in;
			grad_in.first  = _mm256_mul_ps(grad_out.first, _mm256_mul_ps(scale_m, std_m));
			grad_in.second = _mm256_mul_ps(grad_out.second, _mm256_mul_ps(scale_m, std_m));

			_mm256_store_ps(&center_grad_aligned_[channel_ptr], grad_beta);
			_mm256_store_ps(&scale_grad_aligned_[channel_ptr], grad_gamma);

			// gradients to bias
			__m256 bias_grad = _mm256_load_ps(&bias_grad_aligned_[channel_ptr]);
			bias_grad = _mm256_add_ps(bias_grad, grad_in.first);
			bias_grad = _mm256_add_ps(bias_grad, grad_in.second);
			_mm256_store_ps(&bias_grad_aligned_[channel_ptr], bias_grad);

			grad_in = encrypt(gm, grad_in);
			dump_triple(grad_in, grad_in_ptr, j, this->in_size_);
		}
	}

	// delete[] on a null pointer is a no-op
	delete [] diff_act_dst_ptr;
	delete [] pool_dst_ptr;
}



/**
 * Releases the pooling helper and the per-channel buffers allocated by the
 * constructor.
 */
ResnetActivation::~ResnetActivation() {
	// delete on a null pointer is a no-op
	delete pool_;

	// per-channel buffers allocated with new[] in the constructor
	delete [] center_ptr_;
	delete [] scale_ptr_;
	delete [] center_grad_ptr_;
	delete [] scale_grad_ptr_;
	delete [] bias_ptr_;
	delete [] mean_ptr_;
	delete [] std_ptr_;
	delete [] bias_grad_ptr_;

	// NOTE(review): these two buffers come from ocall_extern_alloc, not new[].
	// Releasing them with delete[] is kept from the original code but looks
	// wrong if the ocall allocates outside the enclave heap - confirm and switch
	// to the matching release ocall. work_ptr_ (same origin) is never released.
	delete [] batch_src_ptr_;
	delete [] act_src_ptr_;
}