#ifndef _RESNET_BLOCK_H_
#define _RESNET_BLOCK_H_

#include <assert.h>
#include <math.h>

#include <numeric>
#include <string>
#include <vector>

#include "example_utils.hpp"
#include "Enclave_t.h"
#include "layer.hpp"

/// Composite layer assembled from other LayerSp instances.
///
/// NOTE(review): judging by the name this is a ResNet residual block with a
/// main path (networks_l_) and a second path (networks_r_, presumably the
/// shortcut), finished by final_relu_. The method bodies are defined
/// elsewhere; confirm these roles against the implementation file.
///
/// Scalar and pointer members carry in-class default values so that an
/// instance never exposes indeterminate state. The constructor or
/// setup_final_relu() may still overwrite them.
class ResnetBlockSp : public LayerSp {
  public:
    /// @param in_size     4 input dimensions (presumably N, C, H, W — TODO confirm)
    /// @param out_size    4 output dimensions
    /// @param stride      4 stride values
    /// @param identity    presumably selects an identity shortcut vs. a
    ///                    projection on the second path — confirm
    /// @param weight_data weight buffer; ownership is not visible here
    /// @param bias_data   bias buffer; ownership is not visible here
    /// @param tags        memory descriptor, presumably for the input layout
    /// @param eng         engine the primitives run on
    /// @param s           stream the primitives are submitted to
    ResnetBlockSp(int in_size[4], int out_size[4],
                  int stride[4], bool identity,
                  float* weight_data, float* bias_data,
                  memory::desc tags,
                  engine eng, stream s);

    void forward(float* src, float* dst, bool is_train);
    void setup_final_relu();
    void update_backward(memory::desc diff_src_tag);
    void backward(float* diff_src_ptr, float* diff_dst_ptr);

    /// Appends a sub-layer to networks_l_. The pointer is stored as-is;
    /// ownership/lifetime handling is not visible in this header.
    void add_layer(LayerSp* l) {networks_l_.push_back(l);}

    /// Layer type tag; this class always reports 0.
    int type() {return 0;}
    int out_size() {return this->output_size_;}
    int input_size() {return this->input_size_;}

    /// Output descriptor, taken from the final ReLU layer.
    /// final_relu_ must have been set beforehand (presumably by
    /// setup_final_relu()); it defaults to nullptr, which the assert catches
    /// in debug builds instead of dereferencing garbage.
    memory::desc dst_desc() {
        assert(final_relu_ != nullptr);
        return final_relu_->dst_desc();
    }

    /// Gradient-input descriptor, taken from the first layer in networks_l_.
    /// Throws std::out_of_range (via vector::at) if no layer was added.
    memory::desc diff_src_desc() {return networks_l_.at(0)->diff_src_desc();}

    int relu_in_size_[4] = {};

  private:
    std::vector<LayerSp*> networks_l_;   // filled by add_layer()
    std::vector<LayerSp*> networks_r_;   // not filled by any method in this header
    primitive reorder_grad_;
    std::string mode_;
    memory::desc src_tag_;
    int input_size_ = 0;
    int output_size_ = 0;
    bool input_reorder_ = false;
    bool grad_out_reorder_ = false;
    std::vector<primitive> net_fwd;
    memory::desc nhwc_;
    engine eng_;
    stream s_;
    float* src_dump_ = nullptr;
    bool identity_ = false;
    LayerSp* final_relu_ = nullptr;      // read by dst_desc()
    memory::desc right_dst_desc_;
    bool right_fwd_reorder_=false;
    bool right_bwd_reorder_=false;
    primitive reroder_right_f_pd_;
    primitive reroder_right_b_pd_;
}; // class ResnetBlockSp

#endif