#include "pool_sp.h"

#include <memory>

#include "dnnl.hpp"

using namespace dnnl;

// Builds the forward pooling primitive and caches the geometry that
// update_backward() later needs to build the matching backward primitive.
//
// src_sz, dst_sz : 4-element size arrays. They are permuted into the oneDNN
//                  dims {sz[0], sz[3], sz[1], sz[2]}, so element 3 ends up as
//                  the channel dimension (presumably callers pass N,H,W,C --
//                  TODO confirm against the call sites).
// kernel_sz      : 2-element pooling window.
// stride_tz      : 2-element strides.
// padding        : 2-element padding, used for both the left and right padding.
// type           : 0 selects max pooling, any other value average pooling.
// src_tags       : memory descriptor of the source tensor; used to create the
//                  forward primitive and stored to wrap `src` in forward().
// eng, s         : engine and stream stored for primitive creation/execution.
PoolSp::PoolSp(int src_sz[4], 
	  	       int dst_sz[4], 
	  	       int kernel_sz[2], 
	  	       int stride_tz[2], 
	  	       int padding[2],
	           int type, 
	  	       memory::desc src_tags,
		       engine  eng, stream s) {
	using tag = memory::format_tag;
	using dt = memory::data_type;
	printf("entered pool setup");

	// Reorder the incoming sizes into the order used for the oneDNN dims.
	memory::dims pool_dst_tz = {dst_sz[0], dst_sz[3], dst_sz[1], dst_sz[2]};
	memory::dims pool_src_tz = {src_sz[0], src_sz[3], src_sz[1], src_sz[2]};
	memory::dims kernel_dims  = {kernel_sz[0], kernel_sz[1]};
	memory::dims stride_dims  = {stride_tz[0], stride_tz[1]};
	memory::dims padding_dims = {padding[0], padding[1]};

	this->in_size_  = src_sz[0] * src_sz[3] * src_sz[1] * src_sz[2];
	this->out_size_ = dst_sz[0] * dst_sz[3] * dst_sz[1] * dst_sz[2];
	this->eng_ = eng;
	this->s_ = s;
	this->pool_kernel = kernel_dims;
	this->pool_strides = stride_dims;
	this->pool_padding = padding_dims;
	this->input_channels_ = pool_src_tz[1];

	// The memory::dims elements are 64-bit, so passing them to "%d" is
	// undefined behaviour. Print the original int arguments instead; they
	// hold the same values and keep the format strings unchanged.
	printf("pool type %d", type);
	printf("src shape %d %d %d %d", src_sz[0], src_sz[3], src_sz[1], src_sz[2]);
	printf("dst shape %d %d %d %d", dst_sz[0], dst_sz[3], dst_sz[1], dst_sz[2]);
	printf("ker shape %d %d", kernel_sz[0], kernel_sz[1]);
	printf("str shape %d %d", stride_tz[0], stride_tz[1]);
	printf("pad shape %d %d", padding[0], padding[1]);

	pool_alg_ = (type != 0) ? algorithm::pooling_avg : algorithm::pooling_max;

	// Layout-agnostic descriptors; pool_src_md / pool_dst_md are reused by
	// update_backward(). The forward primitive itself is created with the
	// caller-supplied source descriptor rather than pool_src_md.
	pool_src_md = memory::desc(pool_src_tz, dt::f32, tag::any);
	pool_dst_md = memory::desc(pool_dst_tz, dt::f32, tag::any);
	auto pool_desc = pooling_forward::desc(prop_kind::forward,
			pool_alg_, src_tags, pool_dst_md,
			stride_dims, kernel_dims, padding_dims, padding_dims);

	pool_pd_ = pooling_forward::primitive_desc(pool_desc, eng_);
	net_fwd.push_back(pooling_forward(pool_pd_));

	// Buffer that forward() dumps the workspace into when sharding is enabled.
	// NOTE(review): get_size() already returns a size in bytes, so the extra
	// sizeof(float) factor requests 4x what forward() copies -- confirm whether
	// ocall_extern_alloc expects bytes before shrinking this.
	ocall_extern_alloc((void**) &dump_src_,
			sizeof(float) * pool_pd_.workspace_desc().get_size());

	src_tags_ = src_tags;
}

// Runs the forward pooling primitive on `src`, writing the result to `dst`.
//
// A fresh workspace buffer is allocated for every call. When `sharding` is
// set, the workspace is copied out through dump_src_input() and freed;
// otherwise ownership of the buffer is handed to saved_work_ so that
// backward() can read it.
//
// src      : source buffer, wrapped with the descriptor given at construction.
// dst      : destination buffer, wrapped with the primitive's dst descriptor.
// is_first : not used by this function.
void PoolSp::forward(float* src, float* dst, bool is_first=false) {
	const auto workspace_md = pool_pd_.workspace_desc();
	const auto workspace_bytes = workspace_md.get_size();

	memory pool_src_memory  = memory(src_tags_, eng_, (void*) src);
	memory pool_dst_memory  = memory(pool_pd_.dst_desc(), eng_, (void*) dst);

	// Owned by a unique_ptr so the buffer is released if execute(), wait() or
	// dump_src_input() throws; previously that path leaked it.
	std::unique_ptr<float[]> work(new float[workspace_bytes / sizeof(float)]);
	memory pool_work_memory = memory(workspace_md, eng_, work.get());

	net_fwd.at(0).execute(s_, {{DNNL_ARG_SRC, pool_src_memory}, {DNNL_ARG_WORKSPACE, pool_work_memory},
								{DNNL_ARG_DST, pool_dst_memory}});

	s_.wait();

	if (sharding) {
		this->dump_src_input((void*) this->dump_src_, work.get(), workspace_bytes);
		this->saved_work_ = nullptr;
		// `work` is freed automatically when it goes out of scope.
	} else {
		// NOTE(review): whatever saved_work_ pointed to before is overwritten
		// without being freed, and backward() does not free it either, so each
		// non-sharding call appears to leak one workspace. It is not freed here
		// because forward_resnet() stores a caller-owned pointer in the same
		// member -- confirm the intended ownership.
		this->saved_work_ = work.release();
	}
}

void PoolSp::forward_resnet(float* src, float* dst, float* work, bool is_first=false) {
	memory pool_src_memory  = memory(pool_pd_.src_desc(), eng_, (void*) src);
  	memory pool_dst_memory  = memory(pool_pd_.dst_desc(), eng_, (void*) dst);
  	
  	memory pool_work_memory = memory(pool_pd_.workspace_desc(), eng_, work);
    net_fwd.at(0).execute(s_, {{DNNL_ARG_SRC, pool_src_memory}, {DNNL_ARG_WORKSPACE, pool_work_memory},
								{DNNL_ARG_DST, pool_dst_memory}});
    this->saved_work_ = work;
    s_.wait();
}

// (Re)builds the backward pipeline for an incoming gradient described by
// `diff_desc`.
//
// After this call net_bwd holds, in order, an optional reorder primitive
// (only when `diff_desc` differs from the layout the backward pooling
// primitive expects for diff_dst) followed by the backward pooling primitive.
// diff_dst_reorder records whether the reorder is present, and diff_desc_
// stores the incoming descriptor for use by backward()/backward_resnet().
void PoolSp::update_backward(memory::desc diff_desc) {
	// Drop anything left from an earlier call. backward() always runs the
	// primitives at indices 0 and 1, so appending to a non-empty net_bwd (or
	// keeping a stale reorder flag) would make it execute outdated primitives.
	net_bwd.clear();
	diff_dst_reorder = false;

	auto pool_bwd_desc = pooling_backward::desc(pool_alg_,
			this->pool_src_md, this->pool_dst_md,
			this->pool_strides, this->pool_kernel,
			this->pool_padding, this->pool_padding);
	// The forward primitive descriptor is passed as the hint.
	pool_bwd_pd_ = pooling_backward::primitive_desc(pool_bwd_desc, eng_, pool_pd_);

	if (diff_desc != pool_bwd_pd_.diff_dst_desc()) {
		// Incoming gradient layout differs from what the primitive expects:
		// put a reorder in front of it.
		primitive_attr dummy_attr;
		auto reorder_pd = reorder::primitive_desc(eng_, diff_desc, eng_,
				pool_bwd_pd_.diff_dst_desc(), dummy_attr);
		diff_dst_reorder = true;
		net_bwd.push_back(reorder(reorder_pd));
	}

	this->diff_desc_ = diff_desc;
	net_bwd.push_back(pooling_backward(pool_bwd_pd_));
}

// Runs the backward pooling pass built by update_backward().
//
// diff_src_ptr : output buffer for the gradient w.r.t. the pooling input.
// diff_dst_ptr : incoming gradient, laid out as described by diff_desc_.
//
// The workspace comes from saved_work_, or, when `sharding` is set, is
// reloaded from dump_src_ into a temporary buffer. When a reorder was set up,
// the incoming gradient is first reordered into a temporary scratch buffer.
void PoolSp::backward(float* diff_src_ptr, float* diff_dst_ptr) {
	const auto workspace_md = pool_pd_.workspace_desc();
	const auto diff_dst_md  = pool_bwd_pd_.diff_dst_desc();

	// Temporaries are held by unique_ptr so they are released even if
	// load_src_input(), execute() or wait() throws; the raw delete[] calls at
	// the end of the function were skipped on those paths.
	std::unique_ptr<float[]> reloaded_workspace;
	float* workspace = this->saved_work_;
	if (sharding) {
		reloaded_workspace.reset(new float[workspace_md.get_size() / sizeof(float)]);
		workspace = reloaded_workspace.get();
		this->load_src_input(workspace, this->dump_src_, workspace_md.get_size());
	}

	std::unique_ptr<float[]> reorder_scratch;
	float* diff_dst_reorder_ptr = diff_dst_ptr;
	if (diff_dst_reorder) {
		reorder_scratch.reset(new float[diff_dst_md.get_size() / sizeof(float)]);
		diff_dst_reorder_ptr = reorder_scratch.get();
	}

	pool_diff_dst_memory = memory(diff_dst_md, eng_, diff_dst_reorder_ptr);
	memory pool_prev_dst_memory  = memory(this->diff_desc_, eng_, diff_dst_ptr);
	memory pool_diff_src_memory  = memory(pool_bwd_pd_.diff_src_desc(), eng_, diff_src_ptr);
	memory pool_workspace_memory = memory(workspace_md, eng_, workspace);

	// net_bwd is [reorder, pooling_backward] or just [pooling_backward].
	int i = 0;
	if (diff_dst_reorder) {
		net_bwd.at(i).execute(s_, {{DNNL_ARG_FROM, pool_prev_dst_memory},
				{DNNL_ARG_TO, pool_diff_dst_memory}});
		i++;
	} else {
		pool_diff_dst_memory = pool_prev_dst_memory;
	}

	net_bwd.at(i).execute(s_, {{DNNL_ARG_DIFF_DST, pool_diff_dst_memory},
			{DNNL_ARG_DIFF_SRC, pool_diff_src_memory},
			{DNNL_ARG_WORKSPACE, pool_workspace_memory}});
	s_.wait();

	// NOTE(review): when a reorder ran, pool_diff_dst_memory still refers to
	// the scratch buffer that is freed on return; it must not be read again
	// before the next backward call reassigns it.
}

// Backward pooling pass that always takes its workspace from saved_work_
// (set by forward_resnet() or forward()); unlike backward() it never reloads
// the workspace from dump_src_.
//
// diff_src_ptr : output buffer for the gradient w.r.t. the pooling input.
// diff_dst_ptr : incoming gradient, laid out as described by diff_desc_.
void PoolSp::backward_resnet(float* diff_src_ptr, float* diff_dst_ptr) {
	const auto diff_dst_md = pool_bwd_pd_.diff_dst_desc();

	float* workspace = this->saved_work_;

	// Scratch for the reordered gradient. Held by unique_ptr so it is released
	// even if execute() or wait() throws; the raw delete[] at the end of the
	// function was skipped on those paths.
	std::unique_ptr<float[]> reorder_scratch;
	float* diff_dst_reorder_ptr = diff_dst_ptr;
	if (diff_dst_reorder) {
		reorder_scratch.reset(new float[diff_dst_md.get_size() / sizeof(float)]);
		diff_dst_reorder_ptr = reorder_scratch.get();
	}

	pool_diff_dst_memory = memory(diff_dst_md, eng_, diff_dst_reorder_ptr);
	memory pool_prev_dst_memory  = memory(this->diff_desc_, eng_, diff_dst_ptr);
	// NOTE(review): backward() wraps diff_src with pool_bwd_pd_.diff_src_desc(),
	// whereas this variant uses the forward primitive's src descriptor --
	// confirm the two layouts are always identical here.
	memory pool_diff_src_memory  = memory(pool_pd_.src_desc(), eng_, diff_src_ptr);
	memory pool_workspace_memory = memory(pool_pd_.workspace_desc(), eng_, workspace);

	// net_bwd is [reorder, pooling_backward] or just [pooling_backward].
	int i = 0;
	if (diff_dst_reorder) {
		net_bwd.at(i).execute(s_, {{DNNL_ARG_FROM, pool_prev_dst_memory},
				{DNNL_ARG_TO, pool_diff_dst_memory}});
		i++;
	} else {
		pool_diff_dst_memory = pool_prev_dst_memory;
	}

	net_bwd.at(i).execute(s_, {{DNNL_ARG_DIFF_DST, pool_diff_dst_memory},
			{DNNL_ARG_DIFF_SRC, pool_diff_src_memory},
			{DNNL_ARG_WORKSPACE, pool_workspace_memory}});
	s_.wait();

	// NOTE(review): when a reorder ran, pool_diff_dst_memory still refers to
	// the scratch buffer that is freed on return.
}