#ifndef _CONV2D_H_
#define _CONV2D_H_

#include <assert.h>

#include <math.h>
#include <numeric>

#include "example_utils.hpp"
#include "Enclave_t.h"

using namespace dnnl;

// Multiplies together every extent in `dims` and returns the result.
// An empty list yields 1 (the multiplicative identity).
static memory::dim product(const memory::dims &dims) {
    memory::dim total = 1;
    for (auto it = dims.begin(); it != dims.end(); ++it) {
        total = total * (*it);
    }
    return total;
}

class Conv2d {
 public:
  Conv2d(memory::dims conv_src_tz, memory::dims conv_weights_tz, memory::dims conv_dst_tz,
         memory::dims conv_bias_tz, memory::dims conv_strides, memory::dims conv_padding,
		 memory::desc src_tag, memory::desc weights_tag,
		 engine  eng, stream s, bool training=true, bool is_first=false) {
    // namespaces
    using tag = memory::format_tag;
    using dt = memory::data_type;

    auto conv_src_md = memory::desc({conv_src_tz}, dt::f32, tag::any);
    auto conv_bias_md = memory::desc({conv_bias_tz}, dt::f32, tag::any);
    auto conv_weights_md = memory::desc({conv_weights_tz}, dt::f32, tag::any);
    auto conv_dst_md = memory::desc({conv_dst_tz}, dt::f32, tag::any);

    // convolution primitive descriptor
    auto conv_desc = convolution_forward::desc(prop_kind::forward,
            algorithm::convolution_direct, conv_src_md, conv_weights_md,
            conv_bias_md, conv_dst_md, conv_strides, conv_padding,
            conv_padding);
	conv_pd_ = convolution_forward::primitive_desc(conv_desc, eng);
	src_reorder_ = conv_pd_.src_desc() != src_tag;
	weights_reorder_ = conv_pd_.weights_desc() != weights_tag;
	primitive_attr dummy_attr;
	if (src_reorder_) {
	  auto reorder_pd = reorder::primitive_desc(eng, src_tag, eng, conv_pd_.src_desc(), dummy_attr);
	  conv_src_memory = memory(conv_pd_.src_desc(), eng);
	  net_fwd.push_back(reorder(reorder_pd));
	}
	if ( weights_reorder_) {
	  auto reorder_pd =	reorder::primitive_desc(eng, weights_tag, eng, conv_pd_.weights_desc(), dummy_attr);
	  conv_weights_memory = memory(conv_pd_.weights_desc(), eng);
      net_fwd.push_back(reorder(reorder_pd));
	}
	net_fwd.push_back(convolution_forward(conv_pd_));
	eng_ = eng;
	s_ = s;
	is_first_ = is_first;
	is_train_ = training;


  }


    void forward(memory& conv_user_src_memory, memory& conv_user_weights_memory,
               memory& conv_user_bias_memory, memory& conv_dst_memory) {
    int i = 0;
    // reshape operators                                                                                                   
    if (src_reorder_) {
      net_fwd.at(i).execute(s_,
                            {{DNNL_ARG_FROM, conv_user_src_memory},
                                {DNNL_ARG_TO, conv_src_memory}});
        i++;
    } else {
      conv_src_memory = conv_user_src_memory;
    }

    if (weights_reorder_) {
      net_fwd.at(i).execute(s_,
                {{DNNL_ARG_FROM, conv_user_weights_memory},
               {DNNL_ARG_TO, conv_weights_memory}});
      i++;
    } else {
      conv_weights_memory = conv_user_weights_memory;
    }

    // actual convolution                                                                                                  
    net_fwd.at(i).execute(s_, {{DNNL_ARG_SRC, conv_src_memory},
            {DNNL_ARG_WEIGHTS, conv_user_weights_memory},
            {DNNL_ARG_BIAS, conv_user_bias_memory},
              {DNNL_ARG_DST, conv_dst_memory}});

  }

  

  void update_backward(memory::dims conv_src_tz, memory::dims conv_weights_tz, memory::dims conv_dst_tz,
					   memory::dims conv_bias_tz, memory::dims conv_strides, memory::dims conv_padding,
					   memory::desc diff_src_tag, memory::desc diff_weights_tag) {
	if (!is_train_)
	  return;

	  // namespaces
    using tag = memory::format_tag;
    using dt = memory::data_type;

    // backward to gradients
    // backward memory descriptors
    auto conv_bwd_src_md = memory::desc({conv_src_tz}, dt::f32, tag::any);
    auto conv_diff_bias_md = memory::desc({conv_bias_tz}, dt::f32, tag::any);
    auto conv_diff_weights_md = memory::desc({conv_weights_tz}, dt::f32, tag::any);
    auto conv_diff_dst_md = memory::desc({conv_dst_tz}, dt::f32, tag::any);
    auto conv_diff_src_md = memory::desc({conv_src_tz}, dt::f32, tag::any);

	// backwards to input
	if (!is_first_) {
	  // creating backward data descriptors
	  auto conv_bwd_data_desc = convolution_backward_data::desc(algorithm::convolution_direct,
																conv_bwd_src_md,
																conv_diff_weights_md,
																conv_diff_dst_md,
																conv_strides,
																conv_padding, conv_padding
																);
	  auto conv_bwd_data_pd = convolution_backward_data::primitive_desc(conv_bwd_data_desc, eng_, conv_pd_);

	  if (conv_pd_.weights_desc() != conv_bwd_data_pd.weights_desc()) {
		auto reorder_pd = reorder::primitive_desc(eng_, conv_pd_.weights_desc(), eng_, conv_bwd_data_pd.weights_desc());
		conv_bwd_weights_memory = memory(conv_bwd_data_pd.weights_desc(), eng_);
		net_bwd.push_back(reorder(reorder_pd));
		bwd_weights_reorder_ = true;
	  }
	  to_back = conv_bwd_data_pd.diff_src_desc();
	  net_bwd.push_back(convolution_backward_data(conv_bwd_data_pd));
    }
	
    // create backward convolution primitive descriptor
    auto conv_bwd_weights_desc
      = convolution_backward_weights::desc(algorithm::convolution_direct,
                                           conv_bwd_src_md, conv_diff_weights_md, conv_diff_bias_md,
                                           conv_diff_dst_md, conv_strides, conv_padding, conv_padding);
    auto conv_bwd_weights_pd =
      convolution_backward_weights::primitive_desc(conv_bwd_weights_desc, eng_, conv_pd_);
    primitive_attr dummy_attr;
	
    // reshape conv_src
    if (conv_bwd_weights_pd.src_desc() != conv_pd_.src_desc()) {
	  auto reorder_pd = reorder::primitive_desc(eng_, conv_pd_.src_desc(),
                                                eng_, conv_bwd_weights_pd.src_desc(), dummy_attr);
	  conv_bwd_src_memory = memory(conv_bwd_weights_pd.src_desc(), eng_);
      net_bwd.push_back(reorder(reorder_pd));
      bwd_src_reorder_ = true;
	}


	if (diff_src_tag != conv_bwd_weights_pd.diff_dst_desc()) {
	  auto reorder_pd = reorder::primitive_desc(eng_, diff_src_tag,
	                                           eng_, conv_bwd_weights_pd.diff_dst_desc(), dummy_attr);
	  conv_diff_dst_memory = memory(conv_bwd_weights_pd.diff_dst_desc(), eng_);
	  net_bwd.push_back(reorder(reorder_pd));
	  bwd_dst_reorder_ = true;
	}
	net_bwd.push_back(convolution_backward_weights(conv_bwd_weights_pd));
	conv_diff_weights_memory = memory(conv_bwd_weights_pd.diff_weights_desc(), eng_);
	
	// inplace reorder
	if (diff_weights_tag != conv_bwd_weights_pd.diff_weights_desc()) {
	  auto reorder_pd = reorder::primitive_desc(eng_, conv_bwd_weights_pd.diff_weights_desc(),
                                                eng_, diff_weights_tag, dummy_attr);
	  
	  net_bwd.push_back(reorder(reorder_pd));
	  bwd_gradients_reorder_ = true;
	}
  }

  void backward(memory& conv_bwd_data_memory, memory& conv_dst_memory, memory& conv_diff_bias_memory) {
	int i = 0;
	if (!is_first_) {
	  if (bwd_weights_reorder_) {
		net_bwd.at(i).execute(s_, {{DNNL_ARG_FROM, conv_weights_memory},
            {DNNL_ARG_TO, conv_bwd_weights_memory}}) ;
		i++;
	  } else {
		conv_bwd_weights_memory = conv_weights_memory;
	  }


	  net_bwd.at(i).execute(s_, {{DNNL_ARG_DIFF_DST, conv_dst_memory},
          {DNNL_ARG_WEIGHTS, conv_bwd_weights_memory},
            {DNNL_ARG_DIFF_SRC, conv_bwd_data_memory}
		});
	  i++;
	}

	// backwards to gradients with lots of reorders
	if (bwd_src_reorder_) {
	  net_bwd.at(i).execute(s_, {{DNNL_ARG_FROM, conv_src_memory},
			{DNNL_ARG_TO, conv_bwd_src_memory}});
	  i++;
	} else {
	  conv_bwd_src_memory = conv_src_memory;
	}

	if (bwd_dst_reorder_) {
	  net_bwd.at(i).execute(s_, {{DNNL_ARG_FROM, conv_dst_memory},
			{DNNL_ARG_TO, conv_diff_dst_memory}});
	  i++;
	} else {
	  conv_diff_dst_memory = conv_dst_memory;
	}

	// actual gradients
	net_bwd.at(i).execute(s_, {{DNNL_ARG_DIFF_DST, conv_diff_dst_memory},
		  {DNNL_ARG_SRC, conv_bwd_src_memory},
				{DNNL_ARG_DIFF_WEIGHTS, conv_diff_weights_memory},                              
			  {DNNL_ARG_DIFF_BIAS, conv_diff_bias_memory}});
	i++;
  }

  memory::desc dst_desc() {
	return conv_pd_.dst_desc();
  }

  memory::desc bwd_src_desc() {
	return to_back;
  }
  
 private:
  std::vector<primitive> net_fwd;
  std::vector<primitive> net_bwd;
  engine eng_;
  stream s_;
  convolution_forward::primitive_desc conv_pd_;
  bool is_first_ = false;
  bool is_train_ = false;
  bool src_reorder_ = false;
  bool weights_reorder_ = false;
  bool bwd_src_reorder_ = false;
  bool bwd_dst_reorder_ = false;
  bool bwd_gradients_reorder_ = false;
  bool bwd_weights_reorder_ = false;
  memory conv_weights_memory;
  memory conv_src_memory;
  memory conv_bwd_src_memory;
  memory conv_diff_dst_memory;
  memory conv_bwd_weights_memory;
  memory conv_diff_weights_memory;
  memory::desc to_back;
  



};
#endif