
#ifndef _RESNET_BOTTOM_SP_H_
#define _RESNET_BOTTOM_SP_H_
#include <assert.h>

#include <math.h>
#include <numeric>
#include <string>
#include "example_utils.hpp"
#include "Enclave_t.h"
#include "Enclave.h"
#include "layer.hpp"
#include "pool_sp.h"
using namespace dnnl;
using namespace std;
#include "immintrin.h"

/**
 * @brief LayerSp implementation whose work is done by the split ("_sp")
 *        entry points forward_sp() / backward_sp(), which take separate
 *        left and right operands.
 *
 * The generic forward() / backward() / update_backward() overloads are
 * empty no-ops in this class. forward_sp(), backward_sp() and the
 * constructor are defined out of line.
 *
 * NOTE(review): judging only by the name, this is presumably the bottom/tail
 * stage of a ResNet-style network in a two-branch (left/right) execution
 * scheme. Confirm against the forward_sp()/backward_sp() definitions.
 */
class ResnetBottom : public LayerSp {
  public:
    /**
     * @param act_mode     activation mode string (meaning defined by the
     *                     out-of-line implementation; TODO confirm accepted values)
     * @param in_size      4-element input dimensions (dimension order not
     *                     visible here; confirm)
     * @param out_size     4-element output dimensions
     * @param eps          epsilon (presumably for normalization; confirm)
     * @param momentum     momentum (presumably for running statistics; confirm)
     * @param bias_data_l  bias buffer for the left operand/branch
     * @param bias_data_r  bias buffer for the right operand/branch
     * @param src_tags     source memory descriptor
     * @param eng          dnnl engine
     * @param s            dnnl stream
     */
    ResnetBottom(string act_mode,
                 int    in_size[4],
                 int    out_size[4],
                 float  eps,
                 float  momentum,
                 float* bias_data_l,
                 float* bias_data_r,
                 memory::desc src_tags,
                 engine eng,
                 stream s);

    /// No-op: this layer is driven through forward_sp() instead.
    void forward(float*, float*, bool) {}

    /// Split forward pass over left/right inputs; writes into @p dst.
    /// mean_left / mean_right are additional per-branch buffers whose exact
    /// role (input vs. output) is defined by the implementation; confirm.
    void forward_sp(float* left_in, float* right_in, float* mean_left, float* mean_right,
                    float* dst, bool training);

    /// No-op for this layer.
    void update_backward(memory::desc) {}

    /// No-op: this layer is driven through backward_sp() instead.
    void backward(float*, float*) {}

    /// Split backward pass: consumes @p diff_dst_ptr and produces the
    /// left/right gradients (presumably written into left_grad/right_grad; confirm).
    void backward_sp(float* left_grad, float* right_grad, float* diff_dst_ptr);

    /// Layer type tag; always 0 here (NOTE(review): meaning of 0 defined by LayerSp users).
    int type() {return 0;}

    /// Stored output size (set by the constructor).
    int out_size()   {return this->out_size_;}

    /// Stored input size (set by the constructor).
    int input_size() {return this->in_size_;}

    /// Descriptor returned for the gradient w.r.t. the source (nhwc_).
    memory::desc diff_src_desc() {return nhwc_;}

    /// Descriptor for this layer's output (nhwc_out_).
    memory::desc dst_desc() {return nhwc_out_;}

    // NOTE(review): the destructor releases none of the pointer members below
    // (pool_, left_norm_, right_norm_, *_ptr_ buffers). Confirm whether they
    // are owned elsewhere or leaked.
    ~ResnetBottom() {}

  private:
    // Declaration order is preserved so the out-of-line constructor's
    // member initialization order is unchanged. Default member initializers
    // guarantee no member holds an indeterminate value if the constructor
    // does not set it; the constructor's own initializers still take precedence.

    bool diff_dst_reorder = false;
    engine eng_;
    stream s_;
    int in_size_  = 0;
    int out_size_ = 0;
    int channel_  = 0;

    memory::desc src_tag_;
    int internal_batch_size_ = 0;
    int batch_size_          = 0;
    std::string act_mode_;
    PoolSp* pool_ = nullptr;
    memory::desc nhwc_;
    memory::desc nhwc_out_;

    ResnetActivation* left_norm_  = nullptr;
    ResnetActivation* right_norm_ = nullptr;

    // The *_ptr_ / *_align_ pairs presumably hold a raw allocation and an
    // aligned view into it. Confirm in the constructor.
    float* left_res_ptr_    = nullptr;
    float* left_res_align_  = nullptr;
    float* right_res_ptr_   = nullptr;
    float* right_res_align_ = nullptr;
    float* relu_temp_align_ = nullptr;
    float* relu_temp_ptr_   = nullptr;
}; // class ResnetBottom


#endif