#ifndef _POOL_SP_H_
#define _POOL_SP_H_

#include <assert.h>
#include <math.h>

#include <numeric>
#include <vector>

#include "example_utils.hpp"
#include "Enclave_t.h"
#include "layer.hpp"
using namespace dnnl;

/// Pooling layer built on oneDNN (dnnl) pooling primitives.
///
/// Holds a forward pooling primitive descriptor (`pool_pd_`), a backward
/// pooling primitive descriptor (`pool_bwd_pd_`), and primitive lists for the
/// forward (`net_fwd`) and backward (`net_bwd`) passes. Method bodies are
/// defined outside this header.
///
/// NOTE(review): every data member is public; code elsewhere may rely on
/// that, so access is intentionally left unchanged.
/// NOTE(review): if forward/backward/type/out_size/input_size override
/// virtuals declared in LayerSp, consider marking them `override`.
class PoolSp : public LayerSp {
  public:
    /// Builds the pooling layer.
    ///
    /// @param src_sz    4 source dimensions (order not shown here; presumably
    ///                  N, C, H, W -- confirm against the constructor body).
    /// @param dst_sz    4 destination dimensions, same ordering as src_sz.
    /// @param kernel_sz 2 kernel extents.
    /// @param stride_tz 2 stride values.
    /// @param padding   2 padding values.
    /// @param type      integer selector; presumably chooses the algorithm
    ///                  stored in pool_alg_ (e.g. max vs. average) -- TODO confirm.
    /// @param src_tags  memory descriptor describing the source layout.
    /// @param eng       oneDNN engine the primitives are created on.
    /// @param s         oneDNN stream the primitives are executed on.
    PoolSp(
        int src_sz[4],
        int dst_sz[4],
        int kernel_sz[2],
        int stride_tz[2],
        int padding[2],
        int type,
        memory::desc src_tags,
        engine eng, stream s);

    /// Forward pass. Parameter roles are not visible in this header;
    /// presumably (input buffer, output buffer, flag) -- confirm in the .cpp.
    void forward(float*, float*, bool);

    /// ResNet-specific forward variant taking three buffers and a flag
    /// (roles not shown here -- confirm in the .cpp).
    void forward_resnet(float*, float*, float*, bool);

    /// Updates backward-pass state from the given memory descriptor
    /// (presumably the incoming gradient layout -- confirm).
    void update_backward(memory::desc);

    /// Backward pass over two buffers (roles not shown here; presumably
    /// gradient-in / gradient-out -- confirm in the .cpp).
    void backward(float*, float*);

    /// ResNet-specific backward variant (roles not shown here).
    void backward_resnet(float*, float*);

    /// Returns the constant 0 (presumably this layer kind's type id as
    /// interpreted by LayerSp users -- confirm).
    int type() { return 0; }

    /// Returns the stored output size (out_size_).
    int out_size() { return this->out_size_; }

    /// Returns the stored input size (in_size_).
    int input_size() { return this->in_size_; }

    /// Diff-source descriptor of the backward pooling primitive descriptor.
    memory::desc diff_src_desc() { return pool_bwd_pd_.diff_src_desc(); }

    /// Destination descriptor of the forward pooling primitive descriptor.
    memory::desc dst_desc() { return pool_pd_.dst_desc(); }

    // Declaration order is preserved on purpose: constructor mem-initializer
    // lists in the .cpp initialize members in this order.
    //
    // Scalars and pointers get default member initializers (matching
    // diff_dst_reorder / saved_work_ below). Without them, any path that does
    // not set these members leaves them indeterminate, and reading them is UB.
    // A constructor's own mem-initializers still take precedence over these.
    int input_channels_ = 0;
    float* dump_src_ = nullptr;
    // Presumably the primitives run in sequence for each pass -- confirm.
    std::vector<primitive> net_fwd;
    std::vector<primitive> net_bwd;
    pooling_forward::primitive_desc pool_pd_;
    pooling_backward::primitive_desc pool_bwd_pd_;
    memory::desc pool_dst_md;
    memory::desc pool_src_md;
    memory::desc src_tags_;
    // Value-initialized, i.e. the enumerator whose value is 0.
    algorithm pool_alg_{};
    memory::dims pool_strides;
    memory::dims pool_kernel;
    memory::dims pool_padding;
    memory pool_diff_dst_memory;
    memory::desc diff_desc_;

    bool diff_dst_reorder = false;
    engine eng_;
    stream s_;
    int in_size_ = 0;
    int out_size_ = 0;
    // Presumably a workspace buffer kept for the backward pass -- confirm.
    float* saved_work_ = nullptr;
    memory::desc nhwc_;
    memory::desc nhwc_out_;
};
#endif
