#ifndef _LINEAR_SP_H_
#define _LINEAR_SP_H_

#include "example_utils.hpp"
#include "Enclave_t.h"
#include "layer.hpp"

#include "dnnl.hpp"

using namespace dnnl;

class LinearSp : public LayerSp {
  public:
    LinearSp(int size[4], int out_size[2], float* kernel_data, float* bias_data, memory::desc srg_tag, engine eng, stream s);
    void forward(float* src, float* dst, bool is_train);
    void update_backward(memory::desc diff_src_tag);
    void backward(float*, float*);

    int out_size() {return this->output_size_;}

    int input_size() {return this->input_size_;}

    int type() {return 32;}
    memory::desc diff_src_desc() {return nhwc_;}
    ~LinearSp() {delete linear_;}
    memory::desc dst_desc()      {return nhwc_out_;}


    void* linear_;
    memory::desc nhwc_;
    memory::desc nhwc_out_;
    engine eng_;
    stream s_;
    float* src_dump_;
    float* weight_dump_;
    std::vector<primitive> net_fwd;
    std::vector<primitive> net_bwd;
    memory::desc src_tag_;
    memory::desc diff_src_tag_;
    int input_size_;
    int output_size_;
    bool input_reorder_ = false;
    bool grad_reorder_  = false;

}; // class LinearSp


#endif // _LINEAR_SP_H_