#define EIGEN_USE_THREADS

#include <dlfcn.h>

#include <iostream>
#include <type_traits>

#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

using namespace std;
using namespace tensorflow;

template <typename Device, typename T>
class MaxpoolReluBackOp : public OpKernel {
 public:
  explicit MaxpoolReluBackOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("eid_low", &eid_low_));
    OP_REQUIRES_OK(context, context->GetAttr("eid_high", &eid_high_));
	lib_ = dlopen("App/enclave_bridge.so", RTLD_NOW);
	OP_REQUIRES(context, lib_ != NULL, errors::Unknown("Unable to load .so"));
  }
 
  void Compute(OpKernelContext* context) override {
    const Tensor& grad = context->input(0);
    const Tensor& relu_src = context->input(1);
	const Tensor& workspace = context->input(2);
    Tensor* output = nullptr;
	Tensor* grad_out = nullptr;
	// FOR NOW hardcoded parameter
	auto output_shape = TensorShape({
		grad.shape().dim_sizes()[0],
		  grad.shape().dim_sizes()[1]*2,
		  grad.shape().dim_sizes()[2]*2,
		  grad.shape().dim_sizes()[3]
		  });

	auto grad_shape = TensorShape({grad.shape().dim_sizes()[3]});
	OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
	OP_REQUIRES_OK(context, context->allocate_output(1, grad_shape, &grad_out));
 
    const Device& d = context->eigen_device<Device>();

    unsigned long int eid_ = (eid_high_ << 32) + eid_low_;

	typedef void (*maxrelu_back)(unsigned long int eid, float* grad, float* output, float* relu_src, float* work, float* grad_out);
	dlerror();
	maxrelu_back maxreluback = (maxrelu_back) dlsym(lib_, "maxpoolrelu_back");
	const char *dlsym_error = dlerror();

	OP_REQUIRES(context, !dlsym_error, errors::Unknown("loading of relu failed: ", dlsym_error));	

	maxreluback(eid_,
			 (float*) grad.flat<T>().data(),
			 (float*) output->flat<T>().data(),
			 (float*) relu_src.flat<T>().data(),
			 (float*) workspace.flat<T>().data(),
			 (float*) grad_out->flat<T>().data()
			 );

  }

 private:
  void* lib_;
  int64 eid_low_;
  int64 eid_high_;
};

typedef Eigen::ThreadPoolDevice CPUDevice;

REGISTER_KERNEL_BUILDER(Name("MaxpoolReluBack").Device(DEVICE_CPU), MaxpoolReluBackOp<CPUDevice, float>);

REGISTER_OP("MaxpoolReluBack")
    .Attr("eid_low: int")
    .Attr("eid_high: int")
    .Input("grad: float")
    .Input("relu_src: float")
    .Input("workspace: float")
    .Output("out: float")
    .Output("grad_out: float");
	
