#ifndef _RELU_SP_H_
#define _RELU_SP_H_

#include <assert.h>

#include <math.h>
#include <numeric>
#include <vector>

#include "example_utils.hpp"
#include "Enclave_t.h"
#include "layer.hpp"
using namespace dnnl;

/// ReLU layer (a LayerSp) built on oneDNN eltwise forward/backward primitives.
///
/// The forward and backward primitive descriptors are held as members
/// (`relu_pd_`, `relu_bwd_pd_`); the descriptor accessors below forward to
/// them. The constructor, `forward`, `update_backward` and `backward` are
/// defined out of line.
class Relu_Sp : public LayerSp {
 public:
  /// @param relu_data_sz   four tensor dimensions (presumably N, C, H, W —
  ///                       TODO confirm against the out-of-line definition).
  /// @param tags           memory descriptor for the input layout.
  /// @param eng            oneDNN engine the primitives are created on.
  /// @param s              oneDNN stream the primitives are executed on.
  /// @param negative_slope slope applied to negative inputs; the default 0
  ///                       gives plain ReLU (presumably forwarded as the
  ///                       eltwise alpha — confirm).
  /// @param use_pyth       flag stored as `use_python_`; its effect is
  ///                       defined out of line.
  Relu_Sp(int relu_data_sz[4], memory::desc tags, engine eng, stream s,
          float negative_slope = 0.0f, bool use_pyth = false);

  /// Forward pass from `src` into `dst`. `is_train` presumably controls
  /// whether state needed by `backward` is kept — TODO confirm.
  void forward(float* src, float* dst, bool is_train);

  /// (Re)builds backward-pass state for the given memory descriptor.
  void update_backward(memory::desc tags);

  /// Backward pass: reads `diff_dst_ptr`, writes gradients into `diff_src_ptr`.
  void backward(float* diff_src_ptr, float* diff_dst_ptr);

  /// Destination memory descriptor of the forward primitive.
  memory::desc dst_desc() {
    return relu_pd_.dst_desc();
  }

  /// Alias of diff_src_desc(), kept for callers that use this name.
  memory::desc bwd_src_desc() {
    return diff_src_desc();
  }

  /// Source memory descriptor of the forward primitive.
  memory::desc src_desc() {
    return relu_pd_.src_desc();
  }

  /// Gradient-of-destination descriptor of the backward primitive.
  memory::desc diff_dst_desc() {
    return relu_bwd_pd_.diff_dst_desc();
  }

  /// Gradient-of-source descriptor of the backward primitive.
  memory::desc diff_src_desc() {
    return relu_bwd_pd_.diff_src_desc();
  }

  int input_channels() {
    return input_channels_;
  }

  int out_size() { return out_size_; }
  int input_size() { return in_size_; }
  /// Layer type tag; 0 presumably identifies ReLU in LayerSp — confirm.
  int type() { return 0; }

 private:
  // Declaration order is preserved so out-of-line member-initializer lists
  // keep their initialization order. Default initializers guarantee that any
  // member the constructor does not set is zero/null rather than indeterminate.
  int input_channels_ = 0;
  std::vector<primitive> net_fwd;  // forward-pass primitives
  std::vector<primitive> net_bwd;  // backward-pass primitives
  eltwise_forward::primitive_desc relu_pd_;
  eltwise_backward::primitive_desc relu_bwd_pd_;
  float negative_slope_ = 0.0f;
  memory::dims relu_data_tz_;
  bool bwd_src_reorder_ = false;
  engine eng_;
  stream s_;
  bool use_python_ = false;
  // NOTE(review): raw buffers; this class declares no destructor, so their
  // ownership/lifetime is managed elsewhere — confirm who frees them.
  float* dump_src_ = nullptr;
  float* saved_src_ = nullptr;
  int in_size_ = 0;
  int out_size_ = 0;
};
#endif
