#include "relu_sp.h"
#include "Enclave.h"

#include <memory>

/// Builds the forward ReLU (oneDNN eltwise_relu) primitive for one activation tensor.
///
/// @param relu_data_sz   Four tensor extents. They are re-ordered into the oneDNN
///                       dims {sz[0], sz[3], sz[1], sz[2]}; presumably the input is
///                       N,H,W,C and the dims are N,C,H,W (TODO confirm with callers).
/// @param tags           Memory descriptor used as the forward src/dst layout.
/// @param eng            oneDNN engine the primitive is created on.
/// @param s              Stream used by forward()/backward().
/// @param negative_slope Passed as the eltwise alpha (leaky-ReLU slope).
/// @param use_python     When true, update_backward() skips the diff_dst reorder.
Relu_Sp::Relu_Sp(int relu_data_sz[4], memory::desc tags, 
	             engine eng, stream s, 
	            float negative_slope, bool use_python) {
	memory::dims relu_data_tz = {relu_data_sz[0], relu_data_sz[3], relu_data_sz[1], relu_data_sz[2]};
	this->eng_ = eng;
	this->relu_data_tz_ = relu_data_tz;
	this->s_ = s;
	this->negative_slope_ = negative_slope;
	this->use_python_ = use_python;

	this->input_channels_ = relu_data_tz[1];
	if (this->input_channels_ == 1) {
		// NOTE(review): falls back to dims[3] when dims[1] is 1 — confirm intent.
		this->input_channels_ = relu_data_tz[3];
	}

	// Per-sample element count = sz[1] * sz[2] * sz[3]. Index 3 is the last
	// valid element of the 4-element relu_data_sz array; index 4 would read out of bounds.
	const int per_sample_elems = relu_data_sz[1] * relu_data_sz[2] * relu_data_sz[3];
	this->in_size_ = per_sample_elems;
	this->out_size_ = per_sample_elems;

	// backward() reads this flag; update_backward() sets it only when a reorder
	// is actually inserted, so start from a known value.
	this->bwd_src_reorder_ = false;

	auto relu_desc = eltwise_forward::desc(prop_kind::forward,
	        algorithm::eltwise_relu, tags, negative_slope);
	relu_pd_ = eltwise_forward::primitive_desc(relu_desc, eng_);
	// forward()/backward() wrap raw pointers with relu_pd_'s src/dst descriptors,
	// so warn if oneDNN picked a layout different from the requested one.
	if (relu_pd_.src_desc() != tags || relu_pd_.src_desc() != relu_pd_.dst_desc()) {
		ocall_print_string("Relu_Sp: relu src/dst memory descriptor differs from the requested layout\n");
	}

	// Untrusted-side buffer that forward() dumps the input into when sharding.
	ocall_extern_alloc((void**) &dump_src_, 
	                   relu_pd_.src_desc().get_size());

	net_fwd.push_back(eltwise_forward(this->relu_pd_));
}

void Relu_Sp::forward(float* src, float* dst, bool is_train) {
  	// dump src_
	if (sharding) {
		this->dump_src_input((void*) this->dump_src_, src, relu_pd_.src_desc().get_size());
		this->saved_src_ = nullptr;
	} else {
		this->saved_src_ = src;
	}

  	memory relu_src_memory = memory(relu_pd_.src_desc(), eng_, src);
  	memory relu_dst_memory = memory(relu_pd_.dst_desc(), eng_, dst);
	  net_fwd.at(0).execute(this->s_, {{DNNL_ARG_SRC, relu_src_memory}, {DNNL_ARG_DST, relu_dst_memory}});

	this->s_.wait();
}

/// Creates the backward ReLU primitive (and, if needed, a diff_dst reorder).
///
/// The diff_dst layout is left to oneDNN (tag::any), with the forward pd as
/// the hint. If the chosen layout differs from `tags` and `use_python_` is
/// false, a reorder from `tags` is appended to `net_bwd` first and
/// `bwd_src_reorder_` is set, so backward() knows to run it.
void Relu_Sp::update_backward(memory::desc tags) {
	using dt = memory::data_type;
	using tag = memory::format_tag;

	const memory::desc diff_dst_any_md(relu_data_tz_, dt::f32, tag::any);
	const eltwise_backward::desc bwd_desc(algorithm::eltwise_relu,
	                                      diff_dst_any_md,
	                                      relu_pd_.src_desc(),
	                                      negative_slope_);
	relu_bwd_pd_ = eltwise_backward::primitive_desc(bwd_desc, eng_, relu_pd_);

	const bool layout_mismatch = tags != relu_bwd_pd_.diff_dst_desc();
	if (layout_mismatch && !use_python_) {
		primitive_attr dummy_attr;
		const reorder::primitive_desc to_bwd_layout(eng_, tags, eng_,
		                                            relu_bwd_pd_.diff_dst_desc(),
		                                            dummy_attr);
		bwd_src_reorder_ = true;
		net_bwd.push_back(reorder(to_bwd_layout));
	}

	net_bwd.push_back(eltwise_backward(relu_bwd_pd_));
}

/// Runs the backward ReLU: computes diff_src from diff_dst and the saved forward input.
///
/// @param diff_src_ptr Output gradient buffer (described by relu_pd_.src_desc()).
/// @param diff_dst_ptr Incoming gradient buffer (described by relu_pd_.dst_desc()).
///
/// Without sharding, the source activations come from the pointer forward()
/// saved. With sharding, they are reloaded from dump_src_ into a scratch
/// buffer owned by this function.
void Relu_Sp::backward(float* diff_src_ptr, float* diff_dst_ptr) {
	float* src = this->saved_src_;
	// Owns the reload buffer, if any. It is released with delete[] even if
	// execute() throws. The old code freed a new[] allocation with scalar delete (UB).
	std::unique_ptr<float[]> loaded_src;
	if (sharding) {
		const auto src_bytes = relu_pd_.src_desc().get_size();
		loaded_src.reset(new float[src_bytes / sizeof(float)]);
		src = loaded_src.get();
		this->load_src_input(src, this->dump_src_, src_bytes);
	}

	memory relu_prev_dst_memory = memory(relu_pd_.dst_desc(), eng_, diff_dst_ptr);
	memory relu_diff_src_memory = memory(relu_pd_.src_desc(), eng_, diff_src_ptr);
	memory relu_src_memory      = memory(relu_pd_.src_desc(), eng_, src);

	// Index of the next primitive in net_bwd: the optional reorder comes first.
	int i = 0;
	memory relu_diff_dst_memory;
	if (bwd_src_reorder_ && !use_python_) {
		relu_diff_dst_memory = memory(relu_bwd_pd_.diff_dst_desc(), eng_);
		net_bwd.at(i).execute(s_, {
		        {DNNL_ARG_FROM, relu_prev_dst_memory},
		        {DNNL_ARG_TO, relu_diff_dst_memory}});
		i++;
	} else {
		relu_diff_dst_memory = relu_prev_dst_memory;
	}

	net_bwd.at(i).execute(s_, {{DNNL_ARG_SRC, relu_src_memory},
	        {DNNL_ARG_DIFF_DST, relu_diff_dst_memory},
	        {DNNL_ARG_DIFF_SRC, relu_diff_src_memory}});

	// oneDNN execution may be asynchronous. Wait before the scratch src buffer
	// and the temporary reorder memory are released at scope exit, matching forward().
	s_.wait();
}