#ifndef _RELU_H_
#define _RELU_H_

#include <assert.h>

#include <math.h>
#include <numeric>

#include "example_utils.hpp"
#include "Enclave_t.h"

using namespace dnnl;

class Relu {
 public:
  Relu(memory::dims relu_data_tz, memory::desc tags, engine eng, stream s, float negative_slope=0.0f, bool use_python=false) {
	this->eng_ = eng;
	this->relu_data_tz_ = relu_data_tz;
	this->s_ = s;
	this->negative_slope_ = negative_slope;
	this->use_python_ = use_python;

	this->input_channels_ = relu_data_tz[1];
    if (this->input_channels_ == 1) {
      this->input_channels_ = relu_data_tz[3];
    }

	auto relu_desc = eltwise_forward::desc(prop_kind::forward,
            algorithm::eltwise_relu, tags, negative_slope);
	relu_pd_ = eltwise_forward::primitive_desc(relu_desc, eng_);
	if (relu_pd_.src_desc() != tags || relu_pd_.src_desc() != relu_pd_.dst_desc()) {
	  ocall_print_string("shit\n");
	}
	net_fwd.push_back(eltwise_forward(this->relu_pd_));
  }


  void forward(memory& relu_src_memory, memory& relu_dst_memory) {
	net_fwd.at(0).execute(this->s_, {{DNNL_ARG_SRC, relu_src_memory}, {DNNL_ARG_DST, relu_dst_memory}});
  }

  

  void update_backward(memory::desc tags) {
	using tag = memory::format_tag;
    using dt = memory::data_type;
	
	auto relu_diff_dst_md = memory::desc({relu_data_tz_}, dt::f32, tag::any);
	auto relu_bwd_desc = eltwise_backward::desc(algorithm::eltwise_relu,
												relu_diff_dst_md, relu_pd_.src_desc(), negative_slope_);

	this->relu_bwd_pd_
            = eltwise_backward::primitive_desc(relu_bwd_desc, this->eng_, this->relu_pd_);
	if (tags != relu_bwd_pd_.diff_dst_desc() && !use_python_) {
	  primitive_attr dummy_attr;
	  auto reorder_pd = reorder::primitive_desc(eng_, tags, eng_, relu_bwd_pd_.diff_dst_desc(), dummy_attr);
	  relu_diff_dst_memory = memory(relu_bwd_pd_.diff_dst_desc(), eng_);
	  this->bwd_src_reorder_ = true;
	  net_bwd.push_back(reorder(reorder_pd));
	}
	net_bwd.push_back(eltwise_backward(relu_bwd_pd_));
  }

  void backward(memory& relu_prev_dst_memory, memory& relu_diff_src_memory, memory& relu_src_memory) {
	int i = 0;
	if (bwd_src_reorder_ && !use_python_) {
	  net_bwd.at(i).execute(s_, {
		  {DNNL_ARG_FROM, relu_prev_dst_memory},
            {DNNL_ARG_TO, relu_diff_dst_memory}});
	  i++;
	} else {
	  relu_diff_dst_memory = relu_prev_dst_memory;
	}
	
	net_bwd.at(i).execute(s_, {{DNNL_ARG_SRC, relu_src_memory},
		  {DNNL_ARG_DIFF_DST, relu_diff_dst_memory},
			{DNNL_ARG_DIFF_SRC, relu_diff_src_memory}
	  });
  }
  
  memory::desc dst_desc() {
    return relu_pd_.dst_desc();
  }

  memory::desc bwd_src_desc() {
	return this->relu_bwd_pd_.diff_src_desc();
  }

  memory::desc src_desc() {
	return relu_pd_.src_desc();
  }

  memory::desc diff_dst_desc() {
	return this->relu_bwd_pd_.diff_dst_desc();
  }

  int input_channels() {
	return input_channels_;
  }
  
 private:
  int input_channels_;
  std::vector<primitive> net_fwd;
  std::vector<primitive> net_bwd;
  eltwise_forward::primitive_desc relu_pd_;
  eltwise_backward::primitive_desc relu_bwd_pd_;
  float negative_slope_;
  memory::dims relu_data_tz_;
  memory relu_diff_dst_memory;
  bool bwd_src_reorder_ = false;
  engine eng_;
  stream s_;
  bool use_python_;
};
#endif