#include "layer.hpp"
#include "inverted_sp.h"
#include "resnet_block.h"
#include "relu_sp.h"
#include "Enclave.h"
#include "Enclave_t.h"
#include <cstddef>
#include <memory>
#include <string>
#include "immintrin.h"

// Factory for a 2-D convolution layer. It is only declared here (presumably
// defined in another translation unit) and returns the new layer as an untyped
// pointer, which the constructor below casts to LayerSp*.
// NOTE(review): parameter meanings are inferred from the single call site in
// this file — confirm against the definition.
//   conv_src_sz / conv_dst_sz        : input / output tensor dims (4 ints each)
//   conv_weight_sz                   : kernel dims (4 ints)
//   conv_strides_sz / conv_padding_sz: 2-D stride and padding
//   src_tags                         : address of the input memory::desc
//   weight_data / bias_data          : initial parameter values
//   training / is_first              : mode flags
void* conv2d_wrapper (int conv_src_sz[4], int conv_dst_sz[4], int conv_weight_sz[4],
                     int conv_strides_sz[2], int conv_padding_sz[2], void* src_tags,
                     float* weight_data, float* bias_data, bool training, bool is_first);

/// Builds the shortcut ("right") branch of a residual block and records the
/// output shape used later by the final ReLU.
///
/// @param in_size     block input dims; in_size[3] becomes the 1x1 kernel's
///                    input-channel dim (assumes channels-last dims — TODO confirm)
/// @param out_size    block output dims; out_size[3] becomes the kernel's
///                    output-channel dim, and all four are copied to relu_in_size_
/// @param stride      stride of the projection convolution
/// @param identity    true  -> shortcut passes the block input through unchanged
///                    false -> shortcut is a 1x1 convolution followed by batch norm
/// @param weight_data projection-conv weights (unused when identity is true)
/// @param bias_data   projection-conv bias    (unused when identity is true)
/// @param tags        memory descriptor of the block input
/// @param eng, s      oneDNN engine and stream used by every primitive of the block
///
/// networks_l_ (the main branch) is not populated here; presumably the caller
/// fills it before calling setup_final_relu() — verify against the caller.
ResnetBlockSp::ResnetBlockSp(int in_size[4], int out_size[4],
							 int stride[2], bool identity,
							 float* weight_data, float* bias_data,
							 memory::desc tags,
							 engine eng, stream s) {
	this->identity_ = identity;
	this->eng_      = eng;
	this->s_        = s;

	// Until a projection is built, the shortcut output *is* the block input.
	right_dst_desc_ = tags;

	// The reorder flags are decided in setup_final_relu()/update_backward();
	// start from a defined state so forward()/backward() never read garbage.
	right_fwd_reorder_ = false;
	right_bwd_reorder_ = false;

	if (!this->identity_) {
		int kernel_size[4] = {1, 1, in_size[3], out_size[3]};
		int padding[2]     = {0, 0};

		LayerSp* conv = static_cast<LayerSp*>(
			conv2d_wrapper(in_size, out_size, kernel_size, stride, padding,
			               static_cast<void*>(&tags), weight_data, bias_data,
			               true, false));
		const memory::desc conv_dst = conv->dst_desc();
		this->networks_r_.push_back(conv);

		LayerSp* bn = new BatchNormSP(out_size, 0, 0.0, 0.0, conv_dst, eng, s);
		this->networks_r_.push_back(bn);
		right_dst_desc_ = bn->dst_desc();
	}

	for (int d = 0; d < 4; ++d)
		relu_in_size_[d] = out_size[d];
}

void ResnetBlockSp::setup_final_relu() {
	memory::desc left_dst_desc = this->networks_l_.at(this->networks_l_.size() - 1)->dst_desc();
    primitive_attr dummy_attr;
    printf("before if");
	if (right_dst_desc_ != left_dst_desc) {

		right_fwd_reorder_ = true;
		reroder_right_f_pd_ = reorder(reorder::primitive_desc(this->eng_, right_dst_desc_, this->eng_, left_dst_desc, dummy_attr));
	}
    printf("after if");

	this->final_relu_ = (LayerSp*) new Relu_Sp(relu_in_size_, left_dst_desc, 
	                                              eng_, s_);
}

void ResnetBlockSp::forward(float* src, float* dst, bool is_train) {
	int size = this->networks_l_.size();
    float* res_l_ptr;
    float* res_l_ptr_aligned;
    float* in_ptr = src;
    float* in_ptr_aling = ALIGN32(in_ptr);

    // left path
    for (int i = 0; i < size; i++) {
        LayerSp* l = this->networks_l_.at(i);
        res_l_ptr = new float[l->dst_desc().get_size() / sizeof(float) + 8];
        res_l_ptr_aligned = ALIGN32(res_l_ptr);
        l->forward(in_ptr_aling, res_l_ptr_aligned, true);
        if (i != 0)
            delete [] in_ptr;
        in_ptr = res_l_ptr;
        in_ptr_aling = res_l_ptr_aligned;
    }

    float* res_r_ptr=nullptr;
    float* res_r_ptr_aligned=ALIGN32(in_ptr);
	in_ptr = src;
	in_ptr_aling = ALIGN32(in_ptr);
    // right path
    int size_r = this->networks_r_.size();
    for (int i = 0; i < size_r; i++) {
        LayerSp* l = this->networks_r_.at(i);
        res_r_ptr = new float[l->dst_desc().get_size() / sizeof(float) + 8];
        res_r_ptr_aligned = ALIGN32(res_r_ptr);
        l->forward(in_ptr_aling, res_r_ptr_aligned, true);
        if (i != 0)
            delete [] in_ptr;
        in_ptr = res_r_ptr;
        in_ptr_aling = res_r_ptr_aligned;
    }


    memory::desc left_dst_desc = this->networks_l_.at(this->networks_l_.size() - 1)->dst_desc();

    float* right_fwd_reordered = nullptr;
    float* right_fwd_reordered_aligned_ = res_r_ptr_aligned;

    if (right_fwd_reorder_) {
    	right_fwd_reordered = new float[left_dst_desc.get_size() / sizeof(float) + 8];
    	right_fwd_reordered_aligned_ = ALIGN32(right_fwd_reordered);
    	reroder_right_f_pd_.execute(this->s_, {{DNNL_ARG_FROM, memory(right_dst_desc_,   this->eng_, res_r_ptr_aligned)}, 
                                              {DNNL_ARG_TO,    memory(left_dst_desc, this->eng_, right_fwd_reordered_aligned_)}});
    }
    // sum and final relu
    for (int i = 0; i< left_dst_desc.get_size() / sizeof(float); i+=8) {
    	__m256 l = _mm256_load_ps(&res_l_ptr_aligned[i]);
    	__m256 r = _mm256_load_ps(&res_r_ptr_aligned[i]);
    	l        = _mm256_add_ps(l, r);

    	_mm256_stream_ps(&res_l_ptr_aligned[i], l);
    }

    this->final_relu_->forward(res_l_ptr_aligned, dst, false);

    delete [] res_l_ptr;
    
    if (right_fwd_reordered != nullptr)
    	delete [] right_fwd_reordered;

   	if (res_r_ptr != nullptr)
   		delete [] res_r_ptr;
}

/// Propagates backward memory layouts through the block. It walks from the final
/// ReLU down both branches, last layer first. It then creates a backward reorder
/// when the two branches produce input gradients in different layouts.
///
/// @param diff_src_tag layout expected for the gradient coming into the final ReLU.
void ResnetBlockSp::update_backward(memory::desc diff_src_tag) {
	this->final_relu_->update_backward(diff_src_tag);

	// Both branches start from the ReLU's input-gradient layout.
	const memory::desc relu_diff_src = this->final_relu_->diff_src_desc();

	memory::desc tag_next_l = relu_diff_src;
	for (std::size_t k = this->networks_l_.size(); k-- > 0;) {
		LayerSp* layer = this->networks_l_.at(k);
		layer->update_backward(tag_next_l);
		tag_next_l = layer->diff_src_desc();
	}

	memory::desc tag_next_r = relu_diff_src;
	for (std::size_t k = this->networks_r_.size(); k-- > 0;) {
		LayerSp* layer = this->networks_r_.at(k);
		layer->update_backward(tag_next_r);
		tag_next_r = layer->diff_src_desc();
	}

	// Always (re)assign so backward() never sees a stale/uninitialized flag.
	this->right_bwd_reorder_ = (tag_next_l != tag_next_r);
	if (this->right_bwd_reorder_) {
		primitive_attr attr;
		reroder_right_b_pd_ = reorder(reorder::primitive_desc(this->eng_, tag_next_r,
		                                                      this->eng_, tag_next_l, attr));
	}
}

void ResnetBlockSp::backward(float* diff_src_ptr, float* diff_dst_ptr) {
	float* grad_out_ptr   = new float[this->final_relu_->diff_src_desc().get_size() / sizeof(float) + 8];
	float* grad_out_align = ALIGN32(grad_out_ptr);

	// relu backwards
	this->final_relu_->backward(grad_out_align, diff_dst_ptr);
	int size_l = this->networks_l_.size();
	int size_r = this->networks_r_.size();
	float* grad_l_ptr;
    float* grad_l_aligned;
    float* in_ptr       = grad_out_align; 
    float* in_ptr_aling = grad_out_align;

    // left backwards
    for (int i = size_l - 1; i >= 0; i--) {
        LayerSp* l           = this->networks_l_.at(i);
        grad_l_ptr         = new float[l->diff_src_desc().get_size() / sizeof(float) + 8];
        grad_l_aligned = ALIGN32(grad_l_ptr);
        l->backward(grad_l_aligned, in_ptr_aling);
        if (i != size_l - 1)
            delete [] in_ptr;
        in_ptr       = grad_l_ptr;
        in_ptr_aling = grad_l_aligned;
    }

    float* grad_r_ptr = nullptr;
    float* grad_r_aligned = grad_out_align;
    in_ptr       = grad_out_align; 
    in_ptr_aling = grad_out_align;

    for (int i = size_r - 1; i >= 0; i--) {
        LayerSp* l     = this->networks_r_.at(i);
        grad_r_ptr     = new float[l->diff_src_desc().get_size() / sizeof(float) + 8];
        grad_r_aligned = ALIGN32(grad_r_ptr);
        l->backward(grad_r_aligned, in_ptr_aling);
        if (i != size_r - 1)
            delete [] in_ptr;
        in_ptr       = grad_r_ptr;
        in_ptr_aling = grad_r_aligned;
    }

    // right grad reorder
    float* right_order_ptr = nullptr;
    float* right_order_align=grad_r_aligned;
	memory::desc left_desc  = this->diff_src_desc();
	memory::desc right_desc = this->final_relu_->diff_src_desc();
    if (this->networks_r_.size()>0) {
     right_desc = this->networks_r_.at(0)->diff_src_desc(); 
    }
    if (right_bwd_reorder_) {
    	
		right_order_ptr   = new float[left_desc.get_size() / sizeof(float) + 8];
    	right_order_align = ALIGN32(right_order_ptr);
    	reroder_right_b_pd_.execute(this->s_, {{DNNL_ARG_FROM, memory(right_desc, this->eng_, grad_r_aligned)}, 
                                              {DNNL_ARG_TO,    memory(left_desc,   this->eng_, right_order_align)}});
    }


    for (int i = 0; i < left_desc.get_size() / sizeof(float); i+=8) {
    	__m256 l = _mm256_load_ps(&grad_l_aligned[i]);
    	__m256 r = _mm256_load_ps(&right_order_align[i]);
    	l        = _mm256_add_ps(l, r);

    	_mm256_stream_ps(&diff_src_ptr[i], l);
    }	

    if (grad_r_ptr != nullptr)
    	delete [] grad_r_ptr;

    if (right_order_ptr != nullptr)
    	delete [] right_order_ptr;
    delete [] grad_l_ptr;
	delete [] grad_out_ptr;
}
