#ifndef DEPTHWISE_CONV_SP_H_
#define DEPTHWISE_CONV_SP_H_
#include <vector>

#include "dnnl.hpp"

using namespace dnnl;

#include "layer.hpp"

/// Depthwise 2-D convolution layer built on oneDNN (dnnl) types.
///
/// Only the interface is declared here; the member functions are defined
/// elsewhere. NOTE(review): the dimension orders and buffer ownership described
/// below are inferred from parameter names — confirm against the implementation.
class DepthwiseConv2dSp: public LayerSp {
  public:
    /// Builds the layer.
    /// @param conv_src_sz     4 source dimensions (presumably N, C, H, W — confirm).
    /// @param conv_dst_sz     4 destination dimensions (same caveat).
    /// @param conv_weight_sz  4 weight dimensions.
    /// @param conv_strides_sz 2 stride values (presumably H, W — confirm).
    /// @param conv_padding_sz 2 padding values (presumably H, W — confirm).
    /// @param src_tags        oneDNN memory descriptor describing the incoming source.
    /// @param eng             oneDNN engine handle.
    /// @param s               oneDNN stream handle.
    /// @param weight_data     weight buffer; ownership is not visible here — TODO confirm.
    /// @param bias_data       bias buffer; ownership is not visible here — TODO confirm.
    DepthwiseConv2dSp(
           int conv_src_sz[4], int conv_dst_sz[4], int conv_weight_sz[4],
           int conv_strides_sz[2], int conv_padding_sz[2], memory::desc src_tags,
           engine eng, stream s, float* weight_data, float* bias_data);

    // The destructor deletes the raw pointer depthwise_, so an implicit
    // member-wise copy would make two objects delete the same pointer.
    // Copying is therefore disabled (this also suppresses implicit moves).
    DepthwiseConv2dSp(const DepthwiseConv2dSp&) = delete;
    DepthwiseConv2dSp& operator=(const DepthwiseConv2dSp&) = delete;

    /// Forward computation from src into dst (presumably; see implementation).
    /// @param src      input buffer.
    /// @param dst      output buffer.
    /// @param is_train training-mode flag — TODO confirm what it changes.
    void forward(float* src, float* dst, bool is_train);

    /// Updates backward-pass setup for the given diff-source memory descriptor
    /// (presumably rebuilds net_bwd — confirm).
    void update_backward(memory::desc diff_src_tag);

    /// Backward computation; parameter roles are unnamed here — TODO document.
    void backward(float*, float*);

    /// Output size (presumably backed by output_size_ — confirm).
    int out_size();

    /// Input size (presumably backed by input_size_ — confirm).
    int input_size();

    /// Layer type identifier; this class always reports 0.
    int type() {return 0;}

    /// Memory descriptor of the diff-source (gradient w.r.t. input).
    memory::desc diff_src_desc();

    // NOTE(review): depthwise_ is a void*, and `delete` through a void* is
    // undefined behavior — the pointee's destructor is never run. The concrete
    // type is only known in the implementation file; TODO store it there as a
    // typed std::unique_ptr (with an out-of-line destructor) instead.
    // depthwise_ is default-initialized to nullptr, so this is a no-op if the
    // constructor never assigned it.
    ~DepthwiseConv2dSp() {delete depthwise_;}

    /// Memory descriptor of the destination.
    memory::desc dst_desc();


  private:
    float* src_dump_ = nullptr;   // purpose not visible here — TODO document
    void* depthwise_ = nullptr;   // owned: deleted in the destructor
    memory::desc src_tag_;        // source descriptor (presumably from src_tags)
    memory::desc diff_src_tag_;   // diff-source descriptor
    int input_size_{0};
    int output_size_{0};
    bool input_reorder_ = false;  // presumably: whether input needs a reorder
    bool grad_reorder_ = false;   // presumably: whether gradient needs a reorder
    std::vector<primitive> net_fwd;  // forward primitive sequence
    std::vector<primitive> net_bwd;  // backward primitive sequence
    memory::desc nhwc_;
    memory::desc nhwc_out_;
    engine eng_;
    stream s_;
}; // class DepthwiseConv2dSp
#endif