#ifndef SGXDNN_CONV2DPURE_H_
#define SGXDNN_CONV2DPURE_H_

#define EIGEN_USE_TENSOR

#include <stdio.h>
#include <iostream>
#include <string>
#include <type_traits>
#include <assert.h>

#include "../mempool.hpp"
#include "../utils.hpp"
#include "../Crypto.h"
#include "layer.hpp"
#include "activation.hpp"
#include "eigen_spatial_convolutions.h"
#include <cmath>
#include "immintrin.h"

#ifndef USE_SGX
#include <chrono>
#else
#include "Enclave.h"
#include "sgx_tcrypto.h"
#include "Crypto.h"
#endif

using namespace tensorflow;

namespace SGXDNN
{

	/**
	 * Conv2DPure: a 2-D convolution layer followed by a per-channel bias add,
	 * evaluated directly with Eigen::SpatialConvolution.
	 *
	 * NOTE(review): unlike a regular Conv2D layer, this constructor receives no
	 * kernel or bias values. Both buffers are allocated here and zero-filled so
	 * the layer never reads uninitialized memory and produces deterministic
	 * output. Write real weights into kernel_data_ / bias_data_ if needed.
	 *
	 * Ownership: bias_data_ is allocated with new[] and freed in the destructor.
	 * kernel_data_ comes from mem_pool_ and is not released by this class.
	 */
	template <typename T>
	class Conv2DPure : public Layer<T>
	{
	public:
		/**
		 * @param name            layer name, forwarded to Layer<T>
		 * @param input_shape     input shape; [1] and [2] are the spatial height and width
		 * @param kernel_shape    {filter_rows, filter_cols, ch_in, ch_out}
		 * @param row_stride      stride along the rows (height)
		 * @param col_stride      stride along the columns (width)
		 * @param padding         Eigen padding type used for output-size computation and the convolution
		 * @param mem_pool        pool used for the kernel buffer and for output buffers
		 * @param is_verif_mode   accepted for interface compatibility; currently unused
		 * @param verif_preproc   accepted for interface compatibility; currently unused
		 * @param activation_type accepted for interface compatibility; currently unused
		 */
		Conv2DPure(const std::string& name,
				   const array4d input_shape,
				   const array4d& kernel_shape,
				   const int row_stride,
				   const int col_stride,
				   const Eigen::PaddingType& padding,
				   MemPool* mem_pool,
				   bool is_verif_mode,
				   bool verif_preproc,
				   const std::string& activation_type
				   ): Layer<T>(name, input_shape),
				   // Initializers are listed in member-declaration order (avoids -Wreorder).
				   h(input_shape[1]),
				   w(input_shape[2]),
				   ch_in(kernel_shape[2]),
				   h_out(0),
				   w_out(0),
				   ch_out(kernel_shape[3]),
				   patch_size(kernel_shape[0] * kernel_shape[1]),
				   image_size(input_shape[1] * input_shape[2]),
				   out_image_size(0),
				   kernel_data_(nullptr),
				   bias_data_(nullptr),
				   kernel_(nullptr, kernel_shape),
				   bias_(nullptr, kernel_shape[3]),
				   padding_(padding),
				   row_stride_(row_stride),
				   col_stride_(col_stride),
				   pad_rows_(0),
				   pad_cols_(0),
				   mem_pool_(mem_pool),
				   kernel_shape_(kernel_shape)
		{
			const int filter_rows = kernel_shape[0];
			const int filter_cols = kernel_shape[1];

			// Output spatial size and padding follow from the input size, filter size, stride and padding type.
			GetWindowedOutputSize(h, filter_rows, row_stride_,
								  padding_, &h_out, &pad_rows_);
			GetWindowedOutputSize(w, filter_cols, col_stride_,
								  padding_, &w_out, &pad_cols_);

			printf("in Conv2D with out_shape = (%d, %d, %d)\n", h_out, w_out, ch_out);
			output_shape_ = {0, h_out, w_out, ch_out};  // batch dimension is filled in by apply_impl
			output_size_ = h_out * w_out * ch_out;
			input_shape_ = {0, h, w, ch_in};
			input_size_ = h * w * ch_in;
			out_image_size = h_out * w_out;

			const long kernel_size = static_cast<long>(kernel_shape[0]) * kernel_shape[1] * kernel_shape[2] * kernel_shape[3];

			// Kernel storage comes from the memory pool.
			kernel_data_ = mem_pool_->alloc<T>(kernel_size);
			if (kernel_data_ == nullptr) {
				printf("malloc for kernel failed\n");
			}
			// Rebind kernel_ to the new buffer. Construct the member's own type so
			// the object living in kernel_'s storage matches its declaration.
			new (&kernel_) TensorMap<T, 4>(kernel_data_, kernel_shape);
			// No kernel values are supplied to this layer. Zero-fill so the
			// convolution never reads uninitialized memory.
			if (kernel_data_ != nullptr) {
				kernel_.setZero();
			}

			const long bias_size = kernel_shape[3];
			// The trailing () value-initializes (zeroes) the bias, because no bias values are supplied either.
			bias_data_ = new T[bias_size]();
			new (&bias_) TensorMap<T, 1>(bias_data_, kernel_shape[3]);
		}

		/** Frees the bias buffer allocated with new[] in the constructor. */
		~Conv2DPure()
		{
			delete[] bias_data_;
		}

		// bias_data_ is owned through a raw pointer, so a copy would double-free it.
		Conv2DPure(const Conv2DPure&) = delete;
		Conv2DPure& operator=(const Conv2DPure&) = delete;

		/** Output shape {batch, h_out, w_out, ch_out}. The batch entry reflects the last apply_impl call. */
		array4d output_shape() override
		{
			return output_shape_;
		}

		/** Number of output elements per batch item (h_out * w_out * ch_out). */
		int output_size() override
		{
			return output_size_;
		}

		/** This layer counts as a single linear operation. */
		int num_linear() override
		{
			return 1;
		}


		int h;               // input height (input_shape[1])
		int w;               // input width (input_shape[2])
		int ch_in;           // input channels (kernel_shape[2])
		int h_out;           // output height, from GetWindowedOutputSize
		int w_out;           // output width, from GetWindowedOutputSize
		int ch_out;          // output channels (kernel_shape[3])
		int patch_size;      // filter_rows * filter_cols
		int image_size;      // h * w
		int out_image_size;  // h_out * w_out

	protected:

		/**
		 * Runs the convolution and the bias add on one batch.
		 *
		 * @param input         input batch; dimension 0 is the batch size
		 * @param device_ptr    not used by this layer
		 * @param release_input if true, input.data() is returned to mem_pool_ after use
		 * @return a map over a freshly pool-allocated buffer of shape output_shape_
		 */
		TensorMap<T, 4> apply_impl(TensorMap<T, 4> input, void* device_ptr = NULL, bool release_input=true) override
		{
			const int batch = input.dimension(0);
			output_shape_[0] = batch;

			// Allocate memory to store the output. NOTE(review): presumably the caller releases it; confirm.
			T* output_mem_ = mem_pool_->alloc<T>(batch * output_size_);
			auto output_map = TensorMap<T, 4>(output_mem_, output_shape_);

			sgx_time_t start = get_time();
			// The column stride is passed before the row stride. TensorFlow's conv_2d.h
			// does the same when calling Eigen on row-major tensors, which is presumably
			// the reason here too. Confirm before changing.
			output_map = Eigen::SpatialConvolution(input, kernel_, col_stride_, row_stride_, padding_);
			sgx_time_t end = get_time();

			if (TIMING) { printf("convd (%ld x %ld x %ld) took %.4f seconds\n", input.dimension(1), input.dimension(2), input.dimension(3), get_elapsed_time(start, end)); };

			// Add bias: flatten the output and add the bias vector tiled rest_size times.
			// Assumes the channel dimension is innermost in memory (row-major, channels last). Confirm.
			const int bias_size = bias_.dimension(0);
			if (bias_size > 0) {  // guard against division by zero for a degenerate 0-channel layer
				const int rest_size = output_map.size() / bias_size;
				Eigen::DSizes<int, 1> one_d(output_map.size());
				Eigen::DSizes<int, 1> bcast(rest_size);

				output_map.reshape(one_d) = output_map.reshape(one_d) + bias_.broadcast(bcast).reshape(one_d);
			}

			if (release_input) {
				mem_pool_->release(input.data());
			}

			return output_map;
		}

		T* kernel_data_;          // kernel storage from mem_pool_ (not released by this class)
		T* bias_data_;            // bias storage owned by this object (new[] / delete[])
		TensorMap<T, 4> kernel_;  // view over kernel_data_ with shape kernel_shape_
		TensorMap<T, 1> bias_;    // view over bias_data_, length ch_out

		const Eigen::PaddingType padding_;
		const int row_stride_;
		const int col_stride_;
		int pad_rows_;            // row padding computed by GetWindowedOutputSize
		int pad_cols_;            // column padding computed by GetWindowedOutputSize

		MemPool* mem_pool_;       // not owned

		array4d input_shape_;     // {0, h, w, ch_in}
		array4d kernel_shape_;
		int input_size_;          // h * w * ch_in

		array4d output_shape_;    // {batch, h_out, w_out, ch_out}
		int output_size_;         // h_out * w_out * ch_out
	};

 
} //SGXDNN namespace

#endif
