#include <iostream>
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include <dlfcn.h>

using namespace std;
using namespace tensorflow;

/// Kernel for the "BatchNormDarkBack" op.
///
/// Forwards its four input tensors and a freshly allocated output tensor to
/// the native routine `batchnormSp_dark_back`, which is looked up in the
/// shared library `App/enclave_bridge.so`. All numerical work happens inside
/// that routine; this class only handles loading, argument marshalling and
/// error reporting.
///
/// Attributes:
///   eid_low  - low part of the 64-bit id passed as the routine's first argument.
///   eid_high - high part of that id (shifted left by 32 bits).
///   NOTE(review): the id is presumably an enclave id created elsewhere - confirm.
///
/// Inputs (by index): 0 `grad`, 1 `input`, 2 `skip_input`, 3 `act_src`.
/// Output 0 (`grad_out`) is allocated with the shape of `input`.
///
/// NOTE(review): buffers are handed over as raw `float*` without any size
/// information and without checking that the input shapes agree; the native
/// routine is assumed to know the expected sizes - confirm against its
/// implementation. The kernel is only registered for T = float.
template <typename Device, typename T>
class BatchNormDarkBackOp : public OpKernel {
  public:
    /// Reads the id attributes, loads the bridge library and resolves the
    /// native entry point once, so that Compute() does not repeat the lookup.
    /// Any failure is reported through the construction context.
    explicit BatchNormDarkBackOp(OpKernelConstruction* context) : OpKernel(context) {
        OP_REQUIRES_OK(context, context->GetAttr("eid_low", &eid_low_));
        OP_REQUIRES_OK(context, context->GetAttr("eid_high", &eid_high_));

        lib_ = dlopen("App/enclave_bridge.so", RTLD_NOW);
        OP_REQUIRES(context, lib_ != nullptr, errors::Unknown("Unable to load .so"));

        // Clear any stale error first: dlsym() reports failure via dlerror().
        dlerror();
        batch_norm_back_ = reinterpret_cast<BatchFuncBack>(dlsym(lib_, "batchnormSp_dark_back"));
        const char* dlsym_error = dlerror();
        OP_REQUIRES(context, dlsym_error == nullptr,
                    errors::Unknown("loading of batchnormSp_dark_back failed: ", dlsym_error));
        // A symbol may legitimately resolve to address 0; never call through it.
        OP_REQUIRES(context, batch_norm_back_ != nullptr,
                    errors::Unknown("batchnormSp_dark_back resolved to a null address"));
    }

    /// Releases the library handle acquired in the constructor (if any).
    ~BatchNormDarkBackOp() override {
        if (lib_ != nullptr) {
            dlclose(lib_);
        }
    }

    /// Allocates output 0 with the shape of input 1 and invokes the native
    /// routine with the combined id and the raw data pointers of all tensors.
    void Compute(OpKernelContext* context) override {
        const Tensor& grad = context->input(0);
        const Tensor& input = context->input(1);
        const Tensor& skip_input = context->input(2);
        const Tensor& act_src = context->input(3);

        Tensor* grad_out = nullptr;
        OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &grad_out));

        // Combine the two halves in unsigned arithmetic: left-shifting a
        // signed 64-bit value by 32 can overflow, which is undefined behaviour.
        const unsigned long long combined =
            (static_cast<unsigned long long>(eid_high_) << 32) +
            static_cast<unsigned long long>(eid_low_);
        const unsigned long int eid = static_cast<unsigned long int>(combined);

        batch_norm_back_(eid,
                         AsFloatPtr(grad_out->template flat<T>().data()),
                         AsFloatPtr(grad.template flat<T>().data()),
                         AsFloatPtr(input.template flat<T>().data()),
                         AsFloatPtr(skip_input.template flat<T>().data()),
                         AsFloatPtr(act_src.template flat<T>().data()));
    }

  private:
    /// Signature of the native `batchnormSp_dark_back` entry point.
    using BatchFuncBack = void (*)(unsigned long int eid, float* grad_out, float* grad,
                                   float* inp, float* skip_src, float* act_src);

    /// The native routine takes mutable `float*` for every buffer, including
    /// the read-only inputs, so constness has to be dropped explicitly here.
    static float* AsFloatPtr(const T* data) {
        return const_cast<float*>(reinterpret_cast<const float*>(data));
    }

    void* lib_ = nullptr;                      // dlopen() handle, closed in the destructor
    BatchFuncBack batch_norm_back_ = nullptr;  // resolved once in the constructor
    int64 eid_low_ = 0;
    int64 eid_high_ = 0;

};  // class BatchNormDarkBackOp

// Device type used to instantiate the kernel template for the CPU registration below.
using CPUDevice = Eigen::ThreadPoolDevice;

// Register the float instantiation of the kernel for the "BatchNormDarkBack" op on DEVICE_CPU.
REGISTER_KERNEL_BUILDER(
    Name("BatchNormDarkBack").Device(DEVICE_CPU),
    BatchNormDarkBackOp<CPUDevice, float>);

// Interface of the "BatchNormDarkBack" op: two integer attributes
// (`eid_low`, `eid_high`), four float inputs and a single float output.
REGISTER_OP("BatchNormDarkBack")
    .Attr("eid_low: int")
    .Attr("eid_high: int")
    .Input("grad: float")
    .Input("input: float")
    .Input("skip_input: float")
    .Input("act_src: float")
    .Output("grad_out: float")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
        // The kernel allocates `grad_out` with the shape of `input` (input 1),
        // so shape inference must report that same shape rather than the
        // shape of `grad` (input 0) to stay consistent with the kernel.
        c->set_output(0, c->input(1));
        return Status::OK();
    });