#include <iostream>
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include <dlfcn.h>

using namespace std;
using namespace tensorflow;

template <typename Device, typename T>
class ResnetActivationBackOp : public OpKernel {
  public:
    explicit ResnetActivationBackOp(OpKernelConstruction* context) : OpKernel(context) {
        OP_REQUIRES_OK(context, context->GetAttr("eid_low", &eid_low_));
        OP_REQUIRES_OK(context, context->GetAttr("eid_high", &eid_high_));
        lib_ = dlopen("App/enclave_bridge.so", RTLD_NOW);
        OP_REQUIRES(context, lib_ != NULL, errors::Unknown("Unable to load .so"));
        OP_REQUIRES_OK(context, context->GetAttr("act_mode", &act_mode_));
    }

    void Compute(OpKernelContext* context) override {
        const Tensor& input = context->input(0);

        auto output_shape = input.shape();
        if (act_mode_ == "bnrelupool") {
            output_shape = TensorShape({
             input.shape().dim_sizes()[0],
             (input.shape().dim_sizes()[1] + 1)*2,
             (input.shape().dim_sizes()[2] + 1)*2,
             input.shape().dim_sizes()[3]
             });
        }
        //printf("=================================\n");
        ///printf("resnet act grad out size %d %d, grad in size %d %d\n", input.shape().dim_sizes()[1], input.shape().dim_sizes()[2], output_shape.dim_sizes()[1], output_shape.dim_sizes()[2]);
        //printf("%s\n", act_mode_.c_str());

        Tensor* output = nullptr;
        OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
        
        unsigned long int eid_ = (eid_high_ << 32) + eid_low_;

        typedef void(* back) (unsigned long int, float*, float*);
        dlerror();
        back res_act_back = (back) dlsym(lib_, "resnet_activation_bwd");
        const char *dlsym_error = dlerror();
        OP_REQUIRES(context, !dlsym_error, errors::Unknown("loading of relu failed: ", dlsym_error));

        res_act_back(eid_,
                    (float*) output->flat<T>().data(),
                    (float*) input.flat<T>().data()
                );
    }
  private:
    void* lib_;
    std::string act_mode_;
    int64 eid_low_;
    int64 eid_high_;

};  // class ResnetActivationBackOp

// Device tag used as the first template argument of the kernel below.
using CPUDevice = Eigen::ThreadPoolDevice;

// Register the float instantiation of the kernel for CPU placement.
REGISTER_KERNEL_BUILDER(
    Name("ResnetActivationBack").Device(DEVICE_CPU),
    ResnetActivationBackOp<CPUDevice, float>);

// Op definition for "ResnetActivationBack".
//
// Shape inference mirrors the kernel: for act_mode == "bnrelupool" the input
// must be rank 4 and the output is {d0, (d1 + 1) * 2, (d2 + 1) * 2, d3};
// for every other act_mode the output shape is the input shape.
REGISTER_OP("ResnetActivationBack")
    .Attr("eid_low: int")
    .Attr("eid_high: int")
    .Attr("act_mode: string")
    .Input("grad_out: float")
    .Output("grad_in: float")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c)
                    -> ::tensorflow::Status {
      using ::tensorflow::shape_inference::DimensionHandle;
      using ::tensorflow::shape_inference::ShapeHandle;

      std::string act_mode;
      TF_RETURN_IF_ERROR(c->GetAttr("act_mode", &act_mode));

      if (act_mode != "bnrelupool") {
        c->set_output(0, c->input(0));
        return Status::OK();
      }

      // Reject inputs that are known not to be rank 4; an unknown-rank input
      // is refined to rank 4 with unknown dimensions.
      ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));

      // Add/Multiply keep unknown dimensions unknown. Computing on
      // c->Value() directly would map an unknown size (-1) to 0.
      DimensionHandle padded_h;
      DimensionHandle padded_w;
      DimensionHandle out_h;
      DimensionHandle out_w;
      TF_RETURN_IF_ERROR(c->Add(c->Dim(input, 1), 1, &padded_h));
      TF_RETURN_IF_ERROR(c->Multiply(padded_h, 2, &out_h));
      TF_RETURN_IF_ERROR(c->Add(c->Dim(input, 2), 1, &padded_w));
      TF_RETURN_IF_ERROR(c->Multiply(padded_w, 2, &out_w));

      c->set_output(0, c->MakeShape({c->Dim(input, 0), out_h, out_w, c->Dim(input, 3)}));
      return Status::OK();
    });
    