//#include <fdeep/fdeep.hpp>

#include <stdarg.h>
#include <stdio.h>      /* vsnprintf */
#include <vector>
#include "Enclave.h"
#include "Enclave_t.h"  /* print_string */
#include "dnnl.hpp"
#include "sgx_trts.h"
#include "sgx_tcrypto.h"
#include "batchnorm_sp.h"
#include "depthwise_sp.h"
#include "layer.hpp"
#include "example_utils.hpp"
#include "inverted_sp.h"
#include "pool_sp.h"
#include "relu_sp.h"
#include "linear_sp.h"
#include "resnet_activation.h"
#include "resnet_bottom.h"
#include "resnet_block.h"
using namespace dnnl;
using namespace std;

// Timing helpers defined elsewhere; the resnet ecalls below use them to
// accumulate elapsed time into the counters declared next.
double get_time_force();
double get_elapsed_time(double start, double end);


// Timing and load/dump counters defined elsewhere; zeroed by
// ecall_reset_timing() and reported by ecall_print_timing().
extern double enclave_fwd_time;
extern double enclave_bwd_time;
extern double batchnorm_fwd_time;
extern double batchnorm_bwd_time;
extern double dump_time;
extern double load_time;
extern int load_num;
extern int dump_num;
// oneDNN engine and stream (defined elsewhere) shared by every layer
// constructed in this file.
extern engine      eng;
extern stream      s;
#include "conv2d_sp.hpp"


// Set by ecall_setup_final_reorder(): true when the last layer's output
// descriptor (src_tags) differs from final_src_tags, in which case
// ecall_forward() runs reorder_layer before copying the result out.
bool final_reorder = false;
// Descriptor handed to the next created layer as its input; every
// *_create ecall replaces it with the new layer's dst_desc().
memory::desc src_tags;
// Descriptor of the model output in the layout copied back to the caller.
// Most create ecalls build it as f32/NHWC from the host-supplied output
// shape (ecall_sgx_bn_create copies src_tags instead). It is also the
// starting gradient descriptor in ecall_enclave_update_backward().
memory::desc final_src_tags;
// Top-level layers in creation order. While a residual block or inverted
// cell is being loaded, new layers are added to its last entry instead.
vector<LayerSp*> layers;
// Reorder from src_tags to final_src_tags; only built and executed when
// final_reorder is true.
primitive reorder_layer;
/*
 * Enclave-side printf replacement.
 *
 * Formats the message into a BUFSIZ-byte stack buffer (vsnprintf truncates
 * anything longer) and forwards the resulting C string to the untrusted
 * side through ocall_print_string().
 */
void printf(const char *fmt, ...)
{
    char message[BUFSIZ] = {'\0'};

    va_list args;
    va_start(args, fmt);
    vsnprintf(message, sizeof(message), fmt, args);
    va_end(args);

    ocall_print_string(message);
}

// While true, the *_create ecalls add new layers to the InvertedCellSp at
// the back of `layers` (set by ecall_inverted_init, cleared by
// ecall_inverted_compl).
bool loading_inverted = false;
// While true, new layers go into the ResnetBlockSp at the back of `layers`
// (set by ecall_resblock_init, cleared by ecall_resblock_compl). The create
// ecalls check this flag before loading_inverted.
bool loading_resblock = false;

/* Thin wrapper around ocall_start_clock() (presumably starts a host-side timer). */
void start_clock()
{
    ocall_start_clock();
}

/* Thin wrapper around ocall_end_clock(); `str` is passed through unchanged
 * (presumably a label for the measured interval). */
void end_clock(const char* str)
{
    ocall_end_clock(str);
}

// True until the first layer-creating ecall runs. Pool and linear creation
// use it to seed src_tags from their in_size when they are the first layer;
// conv, depthwise conv, batchnorm and relu creation only clear it.
bool overall_first = true;

void ecall_sgx_conv_create(int conv_src_sz[4], int conv_dst_sz[4], int conv_weight_sz[4],
                     int conv_strides_sz[2], int conv_padding_sz[2], float* weight_data,
                     float* bias_data,  int is_first) {
    using tag = memory::format_tag;
    using dt = memory::data_type;
    bool is_first_bool = is_first == 1;

    if (is_first_bool) {
        src_tags = {{{conv_src_sz[0], conv_src_sz[3], conv_src_sz[1], conv_src_sz[2]}}, dt::f32, tag::nhwc};
        printf("is first");
    }
    overall_first = false;
    LayerSp* l = new Conv2d(conv_src_sz, conv_dst_sz, 
                          conv_weight_sz, conv_strides_sz, 
                          conv_padding_sz, src_tags, 
                          eng, s, weight_data, bias_data, 
                          true, is_first_bool);
    src_tags = l->dst_desc();
    final_src_tags = {{{conv_dst_sz[0], conv_dst_sz[3], conv_dst_sz[1], conv_dst_sz[2]}}, dt::f32, tag::nhwc};
    if (loading_resblock)
        ((ResnetBlockSp*) layers.at(layers.size()-1))->add_layer(l);
    else if (loading_inverted)
        ((InvertedCellSp*)layers.at(layers.size()-1))->add_layer(l);
    else
        layers.push_back(l);
    printf("conv2d finish setup");
}

void* conv2d_wrapper (int conv_src_sz[4], int conv_dst_sz[4], int conv_weight_sz[4],
                     int conv_strides_sz[2], int conv_padding_sz[2], void* src_tags,
                     float* weight_data, float* bias_data, bool training, bool is_first) {

    return(void*) new Conv2d(conv_src_sz, conv_dst_sz, 
                              conv_weight_sz, conv_strides_sz, 
                              conv_padding_sz, *((memory::desc*)src_tags), 
                              eng, s, weight_data, bias_data, 
                              training, is_first);
}

void ecall_sgx_depth_conv_create(int conv_src_sz[4], int conv_dst_sz[4], int conv_weight_sz[4],
                     int conv_strides_sz[2], int conv_padding_sz[2], float* weight_data,
                     float* bias_data,  int is_first) {
    using tag = memory::format_tag;
    using dt = memory::data_type;
    bool is_first_bool = is_first == 1;
    overall_first = false;

    if (is_first_bool) {
        src_tags = {{{conv_src_sz[0], conv_src_sz[3], conv_src_sz[1], conv_src_sz[2]}}, dt::f32, tag::nhwc};
        printf("is first");
    }
    
    LayerSp* l = new DepthwiseConv2dSp(conv_src_sz, conv_dst_sz, 
                          conv_weight_sz, conv_strides_sz, 
                          conv_padding_sz, src_tags, 
                          eng, s, weight_data, bias_data
                          );
    src_tags = l->dst_desc();

    final_src_tags = {{{conv_dst_sz[0], conv_dst_sz[3], conv_dst_sz[1], conv_dst_sz[2]}}, dt::f32, tag::nhwc};

    if (loading_resblock)
        ((ResnetBlockSp*) layers.at(layers.size()-1))->add_layer(l);
    else if (loading_inverted) 
        ((InvertedCellSp*)layers.at(layers.size()-1))->add_layer(l);

    else
        layers.push_back(l);

}

void ecall_sgx_bn_create(int size[4], int mode, float eps, float momentum) {
    LayerSp* l     = new BatchNormSP(size, mode, eps, momentum, src_tags, eng, s);
    src_tags       = l->dst_desc();
    final_src_tags = src_tags;
    overall_first = false;
    if (loading_resblock)
        ((ResnetBlockSp*) layers.at(layers.size()-1))->add_layer(l);
    else if (loading_inverted)
        ((InvertedCellSp*)layers.at(layers.size()-1))->add_layer(l);
    else
        layers.push_back(l);
}

void ecall_sgx_relu_create(int size[4]) {
    using tag = memory::format_tag;
    using dt  = memory::data_type;
    LayerSp* l     = new Relu_Sp(size, src_tags, eng, s);
    src_tags       = l->dst_desc();
    printf("%d %d %d %d", size[0], size[3], size[1], size[2]);
    final_src_tags = {{{size[0], size[3], size[1], size[2]}}, dt::f32, tag::nhwc};;
    overall_first = false;
    if (loading_resblock)
        ((ResnetBlockSp*) layers.at(layers.size()-1))->add_layer(l);
    else if (loading_inverted)
        ((InvertedCellSp*)layers.at(layers.size()-1))->add_layer(l);
    else
        layers.push_back(l);
}

void ecall_sgx_pool_create(int in_size[4], int out_size[4], int kernel_size[2], 
                           int stride[2],  int padding[2],  int type) {
    using tag = memory::format_tag;
    using dt = memory::data_type;
    printf("entered ecall create pool");
    if (overall_first) {
        src_tags = {{{in_size[0], in_size[3], in_size[1], in_size[2]}}, dt::f32, tag::nhwc};
        overall_first = false;
    }

    LayerSp* l  = new PoolSp(in_size, out_size, 
                             kernel_size, stride, 
                             padding, type, 
                             src_tags,
                             eng, s);

    src_tags = l->dst_desc();
    final_src_tags = {{{out_size[0], out_size[3], out_size[1], out_size[2]}}, 
                        dt::f32, tag::nhwc};

    if (loading_resblock)
        ((ResnetBlockSp*) layers.at(layers.size()-1))->add_layer(l);
    else if (loading_inverted)
        ((InvertedCellSp*)layers.at(layers.size()-1))->add_layer(l);
    else
        layers.push_back(l);
}

void ecall_sgx_linear_create(int in_size[4], int out_size[2],
                             float* kernel_data, float* bias_data) {
    using tag = memory::format_tag;
    using dt = memory::data_type;
    if (overall_first) {
        src_tags = {{{in_size[0], in_size[3], in_size[1], in_size[2]}}, dt::f32, tag::nhwc};
        overall_first = false;
    }
    LayerSp* l  = new LinearSp(in_size, out_size, 
                             kernel_data, bias_data, 
                             src_tags,
                             eng, s);
    
    src_tags = l->dst_desc();
    final_src_tags = {{{out_size[0], out_size[1], 1, 1 }}, 
                        dt::f32, tag::nhwc};

    if (loading_resblock)
        ((ResnetBlockSp*) layers.at(layers.size()-1))->add_layer(l);
    else if (loading_inverted)
        ((InvertedCellSp*)layers.at(layers.size()-1))->add_layer(l);
    else
        layers.push_back(l);
}

/*
 * ECALL: decide whether ecall_forward() must reorder the last layer's
 * output (src_tags) into the externally visible layout (final_src_tags),
 * and build that reorder primitive if so.
 *
 * The flag is assigned in both directions. Previously it was only ever set
 * to true, so a later setup with matching descriptors kept a stale "true"
 * and ran an out-of-date reorder.
 */
void ecall_setup_final_reorder() {
    final_reorder = (src_tags != final_src_tags);
    if (final_reorder) {
        primitive_attr dummy_attr;
        reorder_layer = reorder(reorder::primitive_desc(eng, src_tags, eng, final_src_tags, dummy_attr));
    }
}

void ecall_inverted_init() {
    loading_inverted = true;
    
    layers.push_back(new InvertedCellSp(eng, s));
}

/* ECALL: close the current inverted-residual cell; later layers go to the top level. */
void ecall_inverted_compl() { loading_inverted = false; }

void ecall_resblock_init(int in_size[4], int out_size[4],
                         int stride[2], int identity) {
    using tag = memory::format_tag;
    using dt = memory::data_type;
    loading_resblock = true;
    printf("%d", layers.size());
    LayerSp* l = new ResnetBlockSp(in_size, out_size, 
                                   stride, identity == 1, 
                                   nullptr, nullptr,
                                   src_tags,
                                   eng, s);
    layers.push_back(l);
}

void ecall_resblock_compl() {
    using tag = memory::format_tag;
    using dt = memory::data_type;
    printf("%d", layers.size());

    LayerSp* l = layers.at(layers.size() - 1);
    ResnetBlockSp* ls = (ResnetBlockSp*) l;
    final_src_tags = {{{ls->relu_in_size_[0], ls->relu_in_size_[3], ls->relu_in_size_[1], ls->relu_in_size_[2]}}, 
                         dt::f32, tag::nhwc};
    loading_resblock = false;
    ls->setup_final_relu();

    src_tags = ls->dst_desc();


}

void ecall_forward(float* in, float* out) {
    int size_in_byte;
    float* in_ptr = new float[224*224*3 + 8];

    float* in_ptr_aling = ALIGN32(in_ptr);

    std::copy(in, in+224*224*3, in_ptr_aling);

    float* res_ptr;
    float* res_ptr_aligned;
    int i = 0;
    for (LayerSp* l : layers) {
        res_ptr = new float[l->dst_desc().get_size() / sizeof(float) + 8];
        res_ptr_aligned = ALIGN32(res_ptr);
        l->forward(in_ptr_aling, res_ptr_aligned, true);

        size_in_byte = l->out_size();
        delete [] in_ptr;
        in_ptr = res_ptr;
        in_ptr_aling = res_ptr_aligned;

    }
    float* reorder_ptr = res_ptr;
    float* reordered = res_ptr_aligned;
    if (final_reorder) {
        reordered   = new float[size_in_byte];
        reorder_ptr = reordered;
        reorder_layer.execute(s, {{DNNL_ARG_FROM, memory(src_tags, eng, res_ptr_aligned)}, 
                              {DNNL_ARG_TO,   memory(final_src_tags, eng, reordered)}});
        delete [] res_ptr;
    }
    printf("after fwd");
    std::copy(reordered, reordered + size_in_byte, out);
    delete [] reorder_ptr;

}

void ecall_enclave_update_backward() {
    memory::desc back_tag = final_src_tags;

    int size = layers.size();
    for (int i = size - 1; i >=0; i--) {
        LayerSp* l = layers.at(i);
        printf("____________");
        printf("%d", i);
        l->update_backward(back_tag);
        if (i != 0)
            back_tag = l->diff_src_desc();
    }
}

/* ECALL: zero all accumulated timings and load/dump counters. */
void ecall_reset_timing() {
    enclave_fwd_time   = enclave_bwd_time   = 0.0;
    batchnorm_fwd_time = batchnorm_bwd_time = 0.0;
    dump_time          = load_time          = 0.0;
    load_num           = dump_num           = 0;
}

/*
 * ECALL: report accumulated timings as (fwd, bwd) pairs for the enclave,
 * batchnorm and dump/load, followed by the dump/load counts.
 */
void ecall_print_timing() {
    const double timing_pairs[3][2] = {
        {enclave_fwd_time,   enclave_bwd_time},
        {batchnorm_fwd_time, batchnorm_bwd_time},
        {dump_time,          load_time},
    };
    for (const auto& pair : timing_pairs)
        printf("%4.4f, %4.4f", pair[0], pair[1]);

    printf("dump times %d load times %d\n", dump_num, load_num);
}

/*
void ecall_enclave_backward(float* grad_out, float* grad_in) {
    int size_in_byte    = 16;
    float* in_ptr       = new float[16+8];
    float* in_ptr_aling = ALIGN32(in_ptr);
    std::copy(grad_out, grad_out+16, in_ptr_aling);

    float* grad_out_ptr;
    float* grad_out_ptr_aligned;
    int size = layers.size();

    for (int i = size - 1; i >= 0; i--) {
        LayerSp* l           = layers.at(i);
        printf("layer %d %d", i, l->diff_src_desc().get_size() / sizeof(float));
        grad_out_ptr         = new float[l->diff_src_desc().get_size() / sizeof(float) + 8];
        grad_out_ptr_aligned = ALIGN32(grad_out_ptr);
        l->backward(grad_out_ptr_aligned, in_ptr_aling);
        printf("back complete");
        delete [] in_ptr;

        in_ptr       = grad_out_ptr;
        in_ptr_aling = grad_out_ptr_aligned;
    }
    std::copy(in_ptr_aling, in_ptr_aling+32, grad_in);

}*/



/*
 * ECALL: backward pass through every top-level layer, last layer first.
 *
 * grad_out : incoming gradient. Exactly 7*7*2048 floats are copied from it.
 *            The size is hard-coded for the current model; TODO confirm it
 *            matches final_src_tags.
 * grad_in  : not written. The input gradient is not exported.
 *
 * Each layer writes its diff_src into a fresh buffer, which then becomes
 * the next (earlier) layer's diff_dst. Layer 0 gets no separate diff_src
 * buffer: it is handed its own diff_dst buffer, which is exactly what the
 * previous implementation did implicitly. Writing that down removes two
 * problems: an uninitialised pointer with a single layer, and a leak with
 * zero layers.
 */
void ecall_enclave_backward(float* grad_out, float* grad_in) {
    (void)grad_in;
    const int grad_elems = 7 * 7 * 2048;

    // +8 floats (32 bytes) of slack so ALIGN32 can align the working pointer.
    float* in_raw     = new float[grad_elems + 8];
    float* in_aligned = ALIGN32(in_raw);
    std::copy(grad_out, grad_out + grad_elems, in_aligned);

    for (int i = static_cast<int>(layers.size()) - 1; i >= 0; --i) {
        LayerSp* l = layers.at(i);

        float* dst_raw;
        float* dst_aligned;
        if (i != 0) {
            dst_raw     = new float[l->diff_src_desc().get_size() / sizeof(float) + 8];
            dst_aligned = ALIGN32(dst_raw);
        } else {
            dst_raw     = in_raw;
            dst_aligned = in_aligned;
        }
        l->backward(dst_aligned, in_aligned);

        // Release the consumed diff_dst unless it is being reused as diff_src.
        if (dst_raw != in_raw)
            delete [] in_raw;
        in_raw     = dst_raw;
        in_aligned = dst_aligned;
    }
    delete [] in_raw;
}

// ResnetActivation / ResnetBottom objects in the order their setup ecalls
// ran. The *_fwd ecalls walk them front-to-back and the *_bwd ecalls
// back-to-front, each wrapping around at the end.
std::vector<ResnetActivation*> res_act;
std::vector<ResnetBottom*>     res_bot;



/*
 * ECALL: create a ResnetActivation and append it to res_act.
 *
 * act_mode_int selects the mode: 0 = "bias_add", 1 = "bnzerorelu",
 * 2 = "bnrelupool". Any other value is reported and rejected. In release
 * builds (NDEBUG) assert() is a no-op, and the old code then registered an
 * activation with an empty mode string.
 */
void ecall_resnet_setup_activation(int act_mode_int, 
                                   int in_size[4], 
                                   int out_size[4], 
                                   int pool_window[2], 
                                   int pool_stride[2], 
                                   float eps,
                                   float momentum,
                                   float* bias_data) {
    std::string act_mode;
    switch (act_mode_int) {
        case 0: act_mode = std::string("bias_add");   break;
        case 1: act_mode = std::string("bnzerorelu"); break;
        case 2: act_mode = std::string("bnrelupool"); break;
        default:
            printf("act mode error");
            assert(false);
            return;
    }
    printf("%s", act_mode.c_str());

    using tag = memory::format_tag;
    using dt = memory::data_type;
    // NOTE(review): the batch dimension is hard-coded to 2 here, whereas
    // ecall_resnet_setup_bottom and the other setup ecalls use in_size[0].
    // Confirm whether this should be in_size[0].
    memory::desc n_tags = {{{2, in_size[3], in_size[1], in_size[2]}}, dt::f32, tag::nhwc};

    ResnetActivation* l = new ResnetActivation(act_mode, in_size, out_size, 
                                               pool_window, pool_stride, 
                                               eps, momentum, bias_data,
                                               n_tags,
                                               eng, s);
    res_act.push_back(l);
}

void ecall_resnet_setup_bottom(int act_mode_int, 
                               int in_size[4], 
                               int out_size[4], 
                               float eps,
                               float momentum,
                               float* bias_data_l,
                               float* bias_data_r) {
    std::string act_mode;
    if (act_mode_int == 0) {
        act_mode = std::string("normal");
    } else {
        act_mode = std::string("downsample");
    }
    using tag = memory::format_tag;
    using dt = memory::data_type;
    memory::desc n_tags = {{{in_size[0], in_size[3], in_size[1], in_size[2]}}, dt::f32, tag::nhwc};
    
    ResnetBottom* l = new ResnetBottom(act_mode, 
                                       in_size, 
                                       out_size, 
                                       eps,
                                       momentum,
                                       bias_data_l,
                                       bias_data_r,
                                       n_tags, 
                                       eng, 
                                       s);
    
    res_bot.push_back(l); 

}

void ecall_resnet_activation_fwd(float* src, float* dst, float* mean_extern) {
    static int fwd_ptr = 0;

    ResnetActivation* l = res_act.at(fwd_ptr);
    //printf("resnet activation ptr %d / %d", fwd_ptr, res_act.size()-1);
    double s = get_time_force();
    l->forward_sp(src, dst, mean_extern, false, false);
    double e = get_time_force();  
    enclave_fwd_time += get_elapsed_time(s, e);
    fwd_ptr++;
    if (fwd_ptr >= res_act.size()) {
        fwd_ptr = 0;
    }

}

void ecall_resnet_activation_bwd(float* diff_src_ptr, float* diff_dst_ptr) {
    static int bwd_ptr = res_act.size() - 1;
    //printf("resnet bwd activation ptr %d / %d", bwd_ptr, res_act.size()-1);
 
    ResnetActivation* l = res_act.at(bwd_ptr);
    double s = get_time_force();
    l->backward(diff_src_ptr, diff_dst_ptr);
    double e = get_time_force();  
    enclave_bwd_time += get_elapsed_time(s, e);
    bwd_ptr--;
    if (bwd_ptr < 0) {
        bwd_ptr = res_act.size() - 1;
    }
}


void ecall_resnet_bottom_fwd(float* left_in,   float* right_in, 
                             float* mean_left, float* mean_right, 
                             float* dst) {
    static int fwd_ptr = 0;
    //printf("resnet fwd bottom ptr %d / %d", fwd_ptr, res_bot.size()-1);

    ResnetBottom* l = res_bot.at(fwd_ptr);
    double s = get_time_force();

    l->forward_sp(left_in, 
               right_in,   
               mean_left,  
               mean_right, 
               dst,
               false
               );
    double e = get_time_force();  
    enclave_fwd_time += get_elapsed_time(s, e);

    fwd_ptr++;
    if (fwd_ptr >= res_bot.size()) {
        fwd_ptr = 0;
    }
}

void ecall_resnet_bottom_bwd(float* left_grad, float* right_grad, float* diff_dst_ptr) {
    static int bwd_ptr = res_bot.size() - 1;
    //printf("resnet bwd bottom ptr %d / %d", bwd_ptr, res_bot.size()-1);

    ResnetBottom* l = res_bot.at(bwd_ptr);
    double s = get_time_force();

    l->backward_sp(left_grad,
                   right_grad,
                   diff_dst_ptr
                   );
    double e = get_time_force();  
    enclave_bwd_time += get_elapsed_time(s, e);

    bwd_ptr--;
    if (bwd_ptr < 0) {
        bwd_ptr = res_bot.size() - 1;
    }
}