#ifndef SGXDNN_BATCHNORM_SP_H_
#define SGXDNN_BATCHNORM_SP_H_
#include "dnnl.hpp"

// NOTE(review): a using-directive in a header leaks into every includer; it is
// kept because the declarations below use unqualified dnnl names
// (memory, engine, stream, primitive).
using namespace dnnl;

#include <string>
#include <vector>
#include "immintrin.h"
#include "layer.hpp"
/// Batch-normalization layer (LayerSp subclass) built on oneDNN (dnnl) types.
///
/// Only declarations live here; all non-trivial members are defined out of line.
/// NOTE(review): most semantics below are inferred from parameter/member names
/// and should be confirmed against the implementation file.
class BatchNormSP : public LayerSp {
  public:
    /// @param size      4-element shape array (presumably N,C,H,W — TODO confirm).
    /// @param mode      integer mode selector (meaning defined in the .cpp).
    /// @param eps       batch-norm epsilon (presumably added to the variance — confirm).
    /// @param momentum  running-statistics momentum (convention TODO confirm).
    /// @param src_tag   memory descriptor of the source tensor.
    /// @param eng       oneDNN engine the layer's primitives are bound to.
    /// @param s         oneDNN stream used to execute those primitives.
    BatchNormSP(int size[4], int mode, float eps, float momentum, memory::desc src_tag, engine eng, stream s);

    // The destructor releases batchnorm_. A member-wise copy (and an implicit
    // move, which for a raw pointer is also a copy) would therefore delete the
    // same pointer twice, so copying is forbidden.
    BatchNormSP(const BatchNormSP&) = delete;
    BatchNormSP& operator=(const BatchNormSP&) = delete;

    /// Forward pass variant taking an extra `skip` buffer
    /// (presumably a residual/skip-connection input — confirm in the .cpp).
    void forward_special (float* src, float* dst, float* skip, bool is_train);
    /// Forward pass reading `src` and writing `dst`; `is_train` presumably
    /// selects training vs. inference statistics (TODO confirm).
    void forward(float* src, float* dst, bool is_train);

    /// Prepares backward-pass state for the given diff-src memory descriptor.
    void update_backward(memory::desc diff_src_tag);
    void backward(float*, float*);
    void backward_special(float*, float*, float*);

    int out_size();

    int input_size();

    /// Layer-type identifier; always 1 for this class.
    int type() {return 1;}
    memory::desc diff_src_desc();
    // NOTE(review): `delete` on a `void*` is undefined behavior: the pointee's
    // destructor is not run and the deallocation is not well-defined. Cast to
    // the concrete type allocated in the implementation before deleting (or
    // hold it in a std::unique_ptr of that type). Deleting nullptr is a no-op,
    // which the in-class initializer below relies on.
    ~BatchNormSP() {delete batchnorm_;}
    memory::desc dst_desc();

    // Opaque implementation object; owned by this class and released in the
    // destructor above.
    void* batchnorm_ = nullptr;
    // Mode in textual and numeric form (presumably derived from the ctor's
    // `mode` argument — confirm).
    std::string mode_;
    int mode_num_ = 0;
    // Descriptors for the source and (after update_backward) diff-source tensors.
    memory::desc src_tag_;
    memory::desc diff_src_tag_;
    int input_size_ = 0;
    int output_size_ = 0;
    // Presumably record whether a layout reorder is needed on input / gradient.
    bool input_reorder_ = false;
    bool grad_reorder_ = false;
    // oneDNN primitives for the forward / backward passes (presumably executed
    // in order on s_).
    std::vector<primitive> net_fwd;
    std::vector<primitive> net_bwd;
    // Memory descriptor, presumably for an NHWC layout (name-based — confirm).
    memory::desc nhwc_;
    engine eng_;
    stream s_;
    // Raw float buffers used by the implementation.
    // NOTE(review): the destructor does not free these; confirm who owns them.
    float* src_dump_ = nullptr;
    float* act_src_dump_ = nullptr;
    float* act_src_ptr_ = nullptr;
    float* act_src_ = nullptr;
}; // class BatchNormSP

#endif