#ifndef _POOL_H_
#define _POOL_H_

#include <assert.h>

#include <math.h>
#include <numeric>

#include "example_utils.hpp"
#include "Enclave_t.h"

using namespace dnnl;

class Pool {
 public:
  Pool(memory::dims pool_dst_tz, memory::dims pool_src_tz, memory::dims pool_kernel,
	   memory::dims pool_strides, memory::dims pool_padding, memory::desc src_tags,
	   engine  eng, stream s) {
	using tag = memory::format_tag;
    using dt = memory::data_type;
	
	this->eng_ = eng;
	this->s_ = s;
	this->pool_kernel = pool_kernel;
	this->pool_strides = pool_strides;
	this->pool_padding = pool_padding;
	this->input_channels_ = pool_src_tz[1];
	if (this->input_channels_ == 1) {
	  this->input_channels_ = pool_src_tz[3];
	}
	
	pool_src_md = memory::desc({pool_src_tz}, dt::f32, tag::any);
	pool_dst_md = memory::desc({pool_dst_tz}, dt::f32, tag::any);
	auto pool_desc = pooling_forward::desc(prop_kind::forward,
            algorithm::pooling_max, src_tags, pool_dst_md,
            pool_strides, pool_kernel, pool_padding, pool_padding);
	pool_pd_ = pooling_forward::primitive_desc(pool_desc, eng_);
	net_fwd.push_back(pooling_forward(pool_pd_));

	
  }


  void forward(memory& pool_src_memory, memory& pool_dst_memory, memory& pool_work_memory) {
	net_fwd.at(0).execute(s_, {{DNNL_ARG_SRC, pool_src_memory}, {DNNL_ARG_WORKSPACE, pool_work_memory},
															{DNNL_ARG_DST, pool_dst_memory}});
	s_.wait();
  }

  void update_backward(memory::desc tags) {
	auto pool_bwd_desc = pooling_backward::desc(algorithm::pooling_max,
												pool_src_md, pool_dst_md, pool_strides, pool_kernel,
												pool_padding, pool_padding);
	
	pool_bwd_pd_ = pooling_backward::primitive_desc(pool_bwd_desc, eng_, pool_pd_);

	if (tags != pool_bwd_pd_.diff_dst_desc()) {
	  primitive_attr dummy_attr;
	  auto reorder_pd = reorder::primitive_desc(eng_, tags, eng_, pool_bwd_pd_.diff_dst_desc(), dummy_attr);
	  pool_diff_dst_memory = memory(pool_bwd_pd_.diff_dst_desc(), eng_);
	  diff_dst_reorder = true;
	  net_bwd.push_back(reorder(reorder_pd));
	}
	net_bwd.push_back(pooling_backward(pool_bwd_pd_));
  }

  void backward(memory& pool_prev_dst_memory, memory& pool_diff_src_memory, memory& pool_workspace_memory) {
	int i = 0;
	if (diff_dst_reorder) {
	  printf("reshaping\n");
	  net_bwd.at(i).execute(s_, {{DNNL_ARG_FROM, pool_prev_dst_memory},
                {DNNL_ARG_TO, pool_diff_dst_memory}});
	  i++;
	} else {
	  pool_diff_dst_memory = pool_prev_dst_memory;
	}

	net_bwd.at(i).execute(s_, {{DNNL_ARG_DIFF_DST, pool_diff_dst_memory},
            {DNNL_ARG_DIFF_SRC, pool_diff_src_memory},
            {DNNL_ARG_WORKSPACE, pool_workspace_memory}});
	s_.wait();
  }

  memory::desc work_desc() {
	return pool_pd_.workspace_desc();
  }

  int input_channel() {
	return input_channels_;
  }
  
  memory::desc src_desc() {
	return pool_pd_.src_desc();
  }
  memory::desc dst_desc() {
    return pool_pd_.dst_desc();
  }

  memory::desc diff_src_desc() {
  	return pool_bwd_pd_.diff_src_desc();
  }

  memory::desc diff_dst_desc() {
	return pool_bwd_pd_.diff_src_desc();
  }
  
 private:
  int input_channels_;
  std::vector<primitive> net_fwd;
  std::vector<primitive> net_bwd;
  pooling_forward::primitive_desc pool_pd_;
  pooling_backward::primitive_desc pool_bwd_pd_;
  memory::desc pool_dst_md;
  memory::desc pool_src_md;
  memory::dims pool_strides;
  memory::dims pool_kernel;
  memory::dims pool_padding;
  memory::desc diff_desc_;
  memory pool_diff_dst_memory;
  bool diff_dst_reorder = false;
  engine eng_;
  stream s_;
};
#endif