#ifndef SGXDNN_INVERTED_CELL_SP_H_
#define SGXDNN_INVERTED_CELL_SP_H_
#include "dnnl.hpp"

using namespace dnnl;

#include <string>
#include <vector>
#include "immintrin.h"
#include "layer.hpp"
#include "batchnorm_sp.h"
/// Composite layer ("cell") made of an ordered sequence of LayerSp
/// sub-layers, appended one at a time with add_layer().
///
/// NOTE(review): the name suggests a MobileNet-style inverted residual
/// block; confirm against the out-of-line implementation.
///
/// Ownership: network_ holds raw LayerSp pointers and this header declares no
/// destructor, so the cell does not delete its sub-layers here. The caller
/// presumably owns them and must keep them alive for the cell's lifetime;
/// confirm.
class InvertedCellSp : public LayerSp {
  public:
    /// @param eng oneDNN engine kept for building this cell's primitives.
    /// @param s   oneDNN stream kept for executing them.
    // Initialize the handles directly instead of default-constructing and
    // then assigning them in the constructor body.
    InvertedCellSp(engine eng, stream s) : eng_(eng), s_(s) {}

    /// Forward pass from `src` into `dst` (defined out of line).
    /// `is_train` presumably selects training vs. inference behaviour; confirm.
    void forward(float* src, float* dst, bool is_train);

    /// Prepares the backward pass for the given gradient descriptor
    /// (defined out of line).
    void update_backward(memory::desc diff_src_tag);
    /// Backward pass (defined out of line). By parameter naming, this
    /// presumably reads `diff_dst_ptr` and writes `diff_src_ptr`; confirm.
    void backward(float* diff_src_ptr, float* diff_dst_ptr);

    /// Stored output size. Zero-initialized here; presumably set by the
    /// out-of-line implementation.
    int out_size() {return output_size_;}

    /// Stored input size. Zero-initialized here; presumably set by the
    /// out-of-line implementation.
    int input_size() {return input_size_;}

    /// Layer-kind tag reported by this class.
    int type() {return 1;}

    /// Appends a sub-layer to the end of the cell (pointer is not owned).
    void add_layer(LayerSp* l) {network_.push_back(l);}

    /// Destination descriptor of the last sub-layer.
    /// If no sub-layer has been added, `size() - 1` wraps around and
    /// std::vector::at throws std::out_of_range.
    memory::desc dst_desc() {return network_.at(network_.size()-1)->dst_desc();}
    memory::desc diff_src_desc();
  private:
    std::vector<LayerSp*> network_;   // sub-layers in execution order (non-owning)
    primitive reorder_grad_;
    std::string mode_;
    memory::desc src_tag_;
    // Initialized so the accessors never read indeterminate values.
    int input_size_{0};
    int output_size_{0};
    bool input_reorder_ = false;
    bool grad_out_reorder_ = false;
    std::vector<primitive> net_fwd;
    memory::desc nhwc_;
    engine eng_;
    stream s_;
    // Raw, non-owning buffer pointer; null until the implementation assigns it.
    float* src_dump_{nullptr};
}; // class InvertedCellSp

#endif