#include <assert.h>

#include <math.h>
#include <numeric>
#include <vector>
#include <string>

#include "dnnl.hpp"
#include "sgxdnn_main.hpp"
#include "relu.hpp"
#include "pool.hpp"
#include "conv2d.hpp"
#include "example_utils.hpp"
#include "Enclave_t.h"

using namespace dnnl;
using tag = memory::format_tag;
using dt = memory::data_type;


extern "C" {
    void ecall_dnnl_init(int train_inside, int internal_batch);
    void ecall_setup_relu(int* input_size);
    void ecall_slalom_relu_back(float* grad, float* relu_diff_src_buf, float* relu_src_buf, float* bias_grad);

    int  ecall_setup_maxpoolrelu(int* input_size, int* output_size,
                                 int* kernel_size, int* strides, int* padding);
    void ecall_maxpoolrelu_back(float* grad, float* pool_diff_src_buf, float* relu_src_buf,
                                float* workspace, float* bias_grad);
    void ecall_maxpool(float* src, float* dst, float* work);
    int  ecall_get_work_size();
    void ecall_dnnl_batch(int s);
    void ecall_setup_conv2d(int* input_size, int* output_size, int* kernel_size);

    // Print the accumulated timers (time_report) as two rows of "%f " values:
    //   row 1: time_report[0..6]
    //   row 2: time_report[4] again, followed by time_report[7..11]
    void ecall_print_time_report() {
        static const int first_row[]  = {0, 1, 2, 3, 4, 5, 6};
        static const int second_row[] = {4, 7, 8, 9, 10, 11};

        for (int idx : first_row) {
            printf("%f ", time_report[idx]);
        }
        printf("\n");

        for (int idx : second_row) {
            printf("%f ", time_report[idx]);
        }
        printf("\n");
    }

    // Internal batch size passed to ecall_dnnl_init(); internal linkage.
    static int internal_batch_size = 0;
}

// oneDNN CPU engine and the stream all primitives in this file run on;
// both are assigned in ecall_dnnl_init().
engine eng;
stream s;

// Relu layers in registration order, and the index of the next one used by
// the backward ecalls (-1 while the queue is empty).
std::vector<Relu> reluq;
int relu_exeptr{-1};

// Max-pool layers in registration order, and the index of the next one used
// by ecall_maxpoolrelu_back (-1 while the queue is empty).
std::vector<Pool> poolq;
int pool_exeptr{-1};

// Per-pool-layer workspace sizes (bytes), handed out round-robin by
// ecall_get_work_size() using size_ptr as the cursor.
std::vector<int> workspace_size;
int size_ptr{0};

// Number of sub-batches each per-layer ecall loops over (set by ecall_dnnl_batch).
int batch2{0};

// Convolution layers registered by ecall_setup_conv2d().
std::vector<Conv2d> convq;

// Format `size` floats as text: each value (std::to_string, i.e. "%f") followed
// by a single space, then a trailing newline.
// @param a    pointer to at least `size` floats (not read when size <= 0)
// @param size number of elements to format
// @return e.g. "1.500000 -2.000000 \n"; just "\n" when size <= 0
std::string array_to_string(float* a, int size) {
    std::string res;
    for (int i = 0; i < size; i++) {
        // Append in place; `res = res + ...` copied the whole accumulated
        // string on every element, making the loop quadratic.
        res += std::to_string(a[i]);
        res += ' ';
    }
    res += '\n';
    return res;
}

// Format `size` ints as text: each value followed by a single space, then a
// trailing newline.
// @param a    pointer to at least `size` ints (not read when size <= 0)
// @param size number of elements to format
// @return e.g. "1 -2 30 \n"; just "\n" when size <= 0
std::string array_to_string(int* a, int size) {
    std::string res;
    for (int i = 0; i < size; i++) {
        // Append in place instead of rebuilding the string each iteration
        // (`res = res + ...` is quadratic in the output length).
        res += std::to_string(a[i]);
        res += ' ';
    }
    res += '\n';
    return res;
}

// True when ecall_dnnl_init() was called with train_inside == 1; gates ecall_setup_conv2d().
bool load_conv = false;

// Initialise the enclave-side oneDNN state: the CPU engine (index 0), a stream
// on it, and the run configuration.
//   train_inside   - 1 enables conv2d setup (stored in load_conv)
//   internal_batch - stored in internal_batch_size; layer descriptors elsewhere
//                    in this file are built for internal_batch - 1 samples
void ecall_dnnl_init(int train_inside, int internal_batch) {
    eng = engine(dnnl::engine::kind::cpu, 0);
    s = stream(eng);

    internal_batch_size = internal_batch;
    load_conv = (train_inside == 1);
}

// Set batch2, the number of sub-batches each per-layer ecall iterates over,
// and log the new value.
// NOTE(review): the parameter name `s` shadows the global oneDNN stream `s`
// inside this function; harmless here, but easy to misread.
void ecall_dnnl_batch(int s) {
    const int sub_batches = s;
    batch2 = sub_batches;
    printf("dnnl batch is %d\n", sub_batches);
}

void ecall_setup_relu(int* input_size) {
    printf("relu setup ");
    ocall_print_string((std::to_string(input_size[0]) + " " + std::to_string(input_size[1]) + " "
                        + std::to_string(input_size[2]) + " " +
                        std::to_string(input_size[3]) + "\n").c_str());
    memory::dims relu_src_tz = {internal_batch_size - 1, input_size[1], input_size[2], input_size[3]};
    memory::desc desc = {{relu_src_tz}, dt::f32, tag::nhwc};
    // set use python to true
    Relu relu = Relu(relu_src_tz, desc, eng, s, 0.0f, true);
    relu.update_backward(desc);
    reluq.push_back(relu);
    relu_exeptr++;
    printf("relu done\n");
}

// Register a max-pool layer whose output feeds a relu.
//   input_size / output_size - 4 dims each; index 0 is only echoed, the batch
//                              dim used is internal_batch_size - 1 (NHWC, f32)
//   kernel_size / strides / padding - 2 spatial values each
// The relu most recently registered (presumably by a preceding ecall_setup_relu
// for the pre-pooling activation) is replaced by one sized for the pool output.
// Returns the pool workspace size in bytes, which is also appended to
// workspace_size.
int ecall_setup_maxpoolrelu(int* input_size, int* output_size, int* kernel_size, int* strides, int* padding) {
    memory::dims pool_src_tz = {internal_batch_size - 1, input_size[1], input_size[2], input_size[3]};
    memory::dims pool_dst_tz = {internal_batch_size - 1, output_size[1], output_size[2], output_size[3]};
    memory::dims kernel_tz = {kernel_size[0], kernel_size[1]};
    memory::dims strides_tz = {strides[0], strides[1]};
    memory::dims padding_tz = {padding[0], padding[1]};
    memory::desc desc = {{pool_src_tz}, dt::f32, tag::nhwc};
    memory::desc desc1 = {{pool_dst_tz}, dt::f32, tag::nhwc};
    printf("maxpool relu setup deleting pre relu \n");
    ocall_print_string((std::to_string(input_size[0]) + " " + std::to_string(input_size[1]) + " "
                        + std::to_string(input_size[2]) + " " +
                        std::to_string(input_size[3]) + "\n").c_str());

    ocall_print_string((std::to_string(output_size[0]) + " " + std::to_string(output_size[1]) + " "
                        + std::to_string(output_size[2]) + " " +
                        std::to_string(output_size[3]) + "\n").c_str());
    ocall_print_string((std::to_string(kernel_size[0]) + " " + std::to_string(kernel_size[1]) + "\n").c_str());

    ocall_print_string((std::to_string(strides[0]) + " " + std::to_string(strides[1]) + "\n").c_str());

    ocall_print_string((std::to_string(padding[0]) + " " + std::to_string(padding[1]) + "\n").c_str());

    Pool pool = Pool(pool_dst_tz, pool_src_tz, kernel_tz, strides_tz, padding_tz, desc, eng, s);
    pool.update_backward(desc1);
    poolq.push_back(pool);
    pool_exeptr++;

    const int work_bytes = static_cast<int>(pool.work_desc().get_size());
    workspace_size.push_back(work_bytes);

    // Drop the pre-pooling relu. pop_back() on an empty vector is undefined
    // behaviour, and decrementing relu_exeptr without a matching element would
    // leave it out of sync with reluq, so only do both when there is one.
    if (!reluq.empty()) {
        reluq.pop_back();
        relu_exeptr--;
    }
    ecall_setup_relu(output_size);
    return work_bytes;
}

// Return the next recorded pool workspace size, cycling round-robin through
// workspace_size with size_ptr as the cursor. Throws std::out_of_range (from
// vector::at) if no sizes have been recorded.
int ecall_get_work_size() {
    const int current = workspace_size.at(size_ptr);

    // Advance the cursor, wrapping to the first entry after the last one.
    ++size_ptr;
    if (static_cast<std::size_t>(size_ptr) == workspace_size.size())
        size_ptr = 0;

    return current;
}

// Backward step for one relu layer, processed over batch2 sub-batches.
//
// Successive calls consume the relu layers in reluq from relu_exeptr downwards
// (it starts at reluq.size() - 1 after setup), wrapping back to the last layer
// once index 0 has been used.
//
// The first call of each backward pass (tracked by the function-local static
// `last_layer`) is handled differently: it does not consume a relu layer,
// hard-codes 1000 channels (presumably the final 1000-way output layer), only
// scrambles `grad` into relu_diff_src_buf and accumulates bias gradients.
//
// Parameters (each advanced once per sub-batch):
//   grad              - incoming gradient (presumably scrambled/blinded)
//   relu_diff_src_buf - output: scrambled gradient w.r.t. the relu input
//   relu_src_buf      - relu forward input (presumably blinded)
//   grad_out          - receives num_channels accumulated bias gradients
//
// NOTE(review): `last_layer` is only re-armed when THIS function wraps
// relu_exeptr. ecall_maxpoolrelu_back also wraps relu_exeptr but cannot reset
// this static; if relu index 0 is consumed there, the next pass would skip the
// output-layer branch. Confirm the layer ordering rules this out.
void ecall_slalom_relu_back(float* grad, float* relu_diff_src_buf, float* relu_src_buf, float* grad_out) {
    assert(relu_exeptr >= 0);
    double rb_start, rb_end;
    ocall_get_time(&rb_start);

    static bool last_layer = true;
    //("================\n");
    if (last_layer) {
        int num_channels = 1000;
        
        //printf("last layer backward\n");
        int output_size = num_channels * internal_batch_size;
        int input_size = output_size;
        for (int i = 0; i < batch2; i++) {
            // Output layer: scramble the gradient (mode 1) straight into the
            // output buffer; no relu backward is applied here.
            scramble_matrix_func(grad, relu_diff_src_buf, output_size, 1);
            // Accumulate bias gradients (reset on the first sub-batch via i == 0).
            bias_grad(grad, bias_grads, i == 0, num_channels, output_size / internal_batch_size, internal_batch_size);
            grad += output_size;
            relu_diff_src_buf += input_size;
            relu_src_buf += input_size;
        }
        std::copy(bias_grads, bias_grads+1000, grad_out);
        last_layer = false;
        // NOTE(review): returns before time_report[11] is updated, so this
        // branch's elapsed time is not recorded.
        return;
    }
    
    
    // Taken by value, as in the other ecalls.
    Relu relu = reluq.at(relu_exeptr);
    

    // Descriptors are built for internal_batch_size - 1 samples; rescale the
    // element counts to internal_batch_size samples per sub-batch.
    int output_size = (relu.dst_desc().get_size() / sizeof(float)) / (internal_batch_size - 1)  * internal_batch_size;
    int input_size  = (relu.src_desc().get_size() / sizeof(float)) / (internal_batch_size - 1)  * internal_batch_size;
    //printf("relu back %d, input size %d batch2 %d\n", relu_exeptr, input_size, batch2);
    int num_channels = relu.input_channels();
    
    for (int i = 0; i < batch2; i++) {
        // The commented-out steps below (unblind relu_src_buf, unscramble grad)
        // appear to be fused into merged_unblind_scramble_reluback().
        // unblind relu_src_buf to buffer1
        //printf("unblind1\n");
        // mode : 0 to unblind
        //blind_matrix_func(relu_src_buf, temp_buffer, input_size, 0);
        //unblind grad to buffer 2
        //printf("unblind2\n");
        // mode : 0 to unscramble
        //scramble_matrix_func(grad, temp_buffer2, output_size, 0);
        //printf("unblind3\n");
        double relu_start, relu_end;
        ocall_get_time(&relu_start);
        // Result lands in temp_buffer3. NOTE(review): called with
        // output_size / 2 — confirm the helper's count unit (e.g. pairs).
        merged_unblind_scramble_reluback(grad, relu_src_buf, temp_buffer2, 
                                temp_buffer3, output_size / 2);
        ocall_get_time(&relu_end);
        time_report[9] += (relu_end - relu_start) / (1000. * 1000);
        bias_grad(temp_buffer3, bias_grads, i == 0, num_channels, output_size / internal_batch_size, internal_batch_size);
        // mode 1: scramble the relu input gradient into the output buffer.
        scramble_matrix_func(temp_buffer3, relu_diff_src_buf, input_size, 1);
        grad += output_size;
        relu_diff_src_buf += input_size;
        relu_src_buf += input_size;
        
    }
    std::copy(bias_grads, bias_grads+num_channels, grad_out);
    relu_exeptr--;
    
    // All relu layers consumed: wrap around and re-arm the output-layer branch
    // for the next backward pass.
    if (relu_exeptr == -1) {
        relu_exeptr = (int) reluq.size() - 1;
        last_layer = true;
    }
    ocall_get_time(&rb_end);
    time_report[11] += (rb_end - rb_start) / (1000. * 1000);
}

// Forward max-pool for one sub-batch on the enclave's oneDNN stream.
//
// Successive calls walk the layers in poolq. Each layer is run for batch2
// consecutive calls (one per sub-batch) before moving to the next, wrapping
// back to the first layer after the last.
//
// src / dst / workspace are caller-provided buffers that are wrapped (not
// copied) as oneDNN memory for pool.src_desc(), pool.dst_desc() and
// pool.work_desc(). They are presumably sized accordingly by the caller.
void ecall_maxpool(float* src, float* dst, float* workspace) {
    // Which pool layer runs next, and how many of its sub-batches have run.
    static int pool_fwd_ptr = 0;
    static int batch_counter = 0;

    // Taken by value, as in the other ecalls. NOTE(review): a reference would
    // avoid the copy, but only if Pool::forward() keeps no per-call state.
    Pool pool = poolq.at(pool_fwd_ptr);
    memory pool_src(pool.src_desc(), eng, static_cast<void*>(src));
    memory pool_dst(pool.dst_desc(), eng, static_cast<void*>(dst));
    memory pool_work(pool.work_desc(), eng, static_cast<void*>(workspace));
    pool.forward(pool_src, pool_dst, pool_work);
    s.wait();

    // NOTE(review): if batch2 is 0 this never advances to the next layer.
    batch_counter++;
    if (batch_counter == batch2) {
        batch_counter = 0;
        pool_fwd_ptr++;
    }

    if (static_cast<std::size_t>(pool_fwd_ptr) == poolq.size())
        pool_fwd_ptr = 0;
}

void ecall_maxpoolrelu_back(float* grad, float* pool_diff_src_buf, float* relu_src_buf, float* workspace, float* grad_out) {
    assert(pool_exeptr >= 0);
    assert(relu_exeptr >= 0);
    double mb_start, mb_end;
    ocall_get_time(&mb_start);
    //("=======================\n");
    //printf("Maxpool reluback %d %d\n", pool_exeptr, relu_exeptr);
    Relu relu = reluq.at(relu_exeptr);
    Pool pool = poolq.at(pool_exeptr);
    int input_size  = (pool.src_desc().get_size() / sizeof(float)) / (internal_batch_size - 1) * internal_batch_size;
    int output_size = (pool.dst_desc().get_size() / sizeof(float)) / (internal_batch_size - 1) * internal_batch_size;

    int num_channels = pool.input_channel();
    //printf("num of channles is %d\n", num_channels);
    for (int i = 0; i < batch2; i++) {
        // unblind inputs
        //printf("maxpool unblind1\n");
        //unblind_internal(grad, NULL, temp_buffer, output_size);
        // unscramble the grad
        // mode : 0 to unscramble
        scramble_matrix_func(grad, temp_buffer, output_size, 0);
        //printf("maxpool unblind2\n");
        //unblind_internal(relu_src_buf, NULL, temp_buffer2, output_size);
        // unblind the relu_src
        blind_matrix_func(relu_src_buf, temp_buffer2, output_size, 0);
        //printf("maxpool unblind1\n");
        
        memory relu_diff_dst = memory(relu.diff_dst_desc(), eng, (void*) temp_buffer);
        memory relu_src = memory(relu.src_desc(), eng, (void*) temp_buffer2);
        memory relu_diff_src = memory(relu.bwd_src_desc(), eng);
        //printf("before relu\n");
        double relu_start, relu_end;
        ocall_get_time(&relu_start);
        relu.backward(relu_diff_dst, relu_diff_src, relu_src);
        s.wait();
        ocall_get_time(&relu_end);
        time_report[10] += (relu_end - relu_start) / (1000. * 1000);    
        //printf("good relu\n");
        memory pool_diff_dst = relu_diff_src;
        memory pool_diff_src = memory(pool.diff_src_desc(), eng, (void*) temp_buffer3);
        memory pool_workspace = memory(pool.work_desc(), eng, workspace);
        //printf("before maxpool\n");
        ocall_get_time(&relu_start);
        pool.backward(pool_diff_dst, pool_diff_src, pool_workspace);
        s.wait();
        ocall_get_time(&relu_end);
        time_report[10] += (relu_end - relu_start) / (1000. * 1000);    

        // write bias gradients to np array
        //printf("bias grad\n");
        bias_grad(temp_buffer3, bias_grads, i == 0, num_channels, input_size / internal_batch_size, internal_batch_size);
        //printf("bias grad21\n");
        // mode : 1 to scramble
        scramble_matrix_func(temp_buffer3, pool_diff_src_buf, input_size, 1);
        //printf("blinded inputer good \n");
        grad += output_size;
        pool_diff_src_buf += input_size;
        relu_src_buf += output_size;
        workspace += output_size;
    }
    // copy gradients out
    std::copy(bias_grads, bias_grads+num_channels, grad_out);
    pool_exeptr--;
    if (pool_exeptr == -1)
    pool_exeptr = (int) poolq.size() - 1;

    relu_exeptr--;
    if (relu_exeptr == -1) {
        relu_exeptr = (int) reluq.size() - 1;
    }
    ocall_get_time(&mb_end);
    time_report[11] += (mb_end - mb_start) / (1000. * 1000);

}

void ecall_setup_conv2d(int* input_size, int* output_size, int* kernel_size) {
    if (!load_conv)
    return;
    memory::dims conv_src_tz = {2, input_size[1], input_size[2], input_size[3]};
    memory::dims conv_dst_tz = {2, output_size[1], output_size[2], output_size[3]};
    memory::dims weight_tz   = {kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3]};
    memory::dims bias_tz     = {kernel_size[0]};
    memory::dims conv_strides = {1, 1};
    memory::dims conv_padding = {1, 1};
    printf("setting up conv2d\n");
    printf("%d, %d, %d, %d\n", conv_src_tz[0], conv_src_tz[1], conv_src_tz[2], conv_src_tz[3]);
    printf("%d, %d, %d, %d\n", conv_dst_tz[0], conv_dst_tz[1], conv_dst_tz[2], conv_dst_tz[3]);
    printf("%d, %d, %d, %d\n", weight_tz[0], weight_tz[1], weight_tz[2], weight_tz[3]);
    Conv2d conv = Conv2d(conv_src_tz, weight_tz, conv_dst_tz,
                         bias_tz, conv_strides, conv_padding,
                         {{conv_src_tz}, dt::f32, tag::nhwc},
                         {{weight_tz}, dt::f32, tag::oihw},
                         eng, s, true, convq.size() == 0);

    conv.update_backward(conv_src_tz, weight_tz, conv_dst_tz,
                         bias_tz, conv_strides, conv_padding, conv.dst_desc(),
                         {{weight_tz}, dt::f32, tag::oihw});
    convq.push_back(conv);
}