
#include <stdio.h>
#include <string.h>
#include <assert.h>
#include <iostream>
#include <fstream>
#include <sstream>
#include <chrono>
#include <omp.h>

#include "sgx_urts.h"
#include "Enclave_u.h"
#include "extern_mmu.hpp"

#ifndef TRUE
#define TRUE 1
#endif

#ifndef FALSE
#define FALSE 0
#endif

#define TOKEN_FILENAME   "enclave.token"
#define ENCLAVE_FILENAME "enclave.signed.so"

using namespace std::chrono;

/*
 * One row of the sgx_errlist lookup table used by print_error_message():
 * the status code to match, the text printed as "Error: ...", and an
 * optional hint printed as "Info: ..." (NULL when there is no hint).
 * Field order matters: the table is built with positional aggregate init.
 */
typedef struct _sgx_errlist_t {
    sgx_status_t err;
    const char *msg;
    const char *sug; /* Suggestion */
} sgx_errlist_t;

/*
 * Descriptions for sgx_status_t codes, searched linearly by
 * print_error_message(). That function is called both after
 * sgx_create_enclave and after every failed ecall wrapper in this file, so
 * the table also covers codes that only occur at ecall time (e.g. running out
 * of TCS when several threads enter the enclave at once). Codes not listed
 * here are reported as "Unexpected error occurred.".
 */
static sgx_errlist_t sgx_errlist[] = {
    {
        SGX_ERROR_UNEXPECTED,
        "Unexpected error occurred.",
        NULL
    },
    {
        SGX_ERROR_INVALID_PARAMETER,
        "Invalid parameter.",
        NULL
    },
    {
        SGX_ERROR_OUT_OF_MEMORY,
        "Out of memory.",
        NULL
    },
    {
        SGX_ERROR_ENCLAVE_LOST,
        "Power transition occurred.",
        "Please refer to the sample \"PowerTransition\" for details."
    },
    {
        SGX_ERROR_INVALID_ENCLAVE,
        "Invalid enclave image.",
        NULL
    },
    {
        SGX_ERROR_INVALID_ENCLAVE_ID,
        "Invalid enclave identification.",
        NULL
    },
    {
        SGX_ERROR_INVALID_SIGNATURE,
        "Invalid enclave signature.",
        NULL
    },
    {
        SGX_ERROR_OUT_OF_EPC,
        "Out of EPC memory.",
        NULL
    },
    {
        SGX_ERROR_NO_DEVICE,
        "Invalid SGX device.",
        "Please make sure SGX module is enabled in the BIOS, and install SGX driver afterwards."
    },
    {
        SGX_ERROR_MEMORY_MAP_CONFLICT,
        "Memory map conflicted.",
        NULL
    },
    {
        SGX_ERROR_INVALID_METADATA,
        "Invalid enclave metadata.",
        NULL
    },
    {
        SGX_ERROR_DEVICE_BUSY,
        "SGX device was busy.",
        NULL
    },
    {
        SGX_ERROR_INVALID_VERSION,
        "Enclave version was invalid.",
        NULL
    },
    {
        SGX_ERROR_INVALID_ATTRIBUTE,
        "Enclave was not authorized.",
        NULL
    },
    {
        SGX_ERROR_ENCLAVE_FILE_ACCESS,
        "Can't open enclave file.",
        NULL
    },
    /* Codes that typically surface from ecalls rather than enclave creation. */
    {
        SGX_ERROR_INVALID_STATE,
        "SGX API was invoked in an incorrect order or state.",
        NULL
    },
    {
        SGX_ERROR_INVALID_FUNCTION,
        "Invalid ecall/ocall function index.",
        NULL
    },
    {
        SGX_ERROR_OUT_OF_TCS,
        "Out of TCS: too many threads are executing inside the enclave.",
        "Increase TCSNum in the enclave configuration or reduce the number of concurrent ecalls."
    },
    {
        SGX_ERROR_ENCLAVE_CRASHED,
        "The enclave has crashed.",
        NULL
    },
    {
        SGX_ERROR_ECALL_NOT_ALLOWED,
        "ECALL is not allowed at this time.",
        NULL
    },
    {
        SGX_ERROR_OCALL_NOT_ALLOWED,
        "OCALL is not allowed at this time.",
        NULL
    },
    {
        SGX_ERROR_STACK_OVERRUN,
        "The enclave ran out of stack.",
        "Consider increasing StackMaxSize in the enclave configuration."
    },
};

/*
 * Print a human-readable description of an SGX status code.
 * Used after sgx_create_enclave and after every failed ecall wrapper.
 * When the sgx_errlist entry carries a hint, "Info: <hint>" is printed first,
 * then "Error: <message>". Codes absent from the table get a generic line.
 */
void print_error_message(sgx_status_t ret)
{
    const size_t entries = sizeof(sgx_errlist) / sizeof(sgx_errlist[0]);
    const sgx_errlist_t *match = NULL;

    /* First matching row wins. */
    for (size_t i = 0; match == NULL && i < entries; ++i) {
        if (sgx_errlist[i].err == ret)
            match = &sgx_errlist[i];
    }

    if (match == NULL) {
        printf("Error: Unexpected error occurred.\n");
        return;
    }

    if (match->sug != NULL)
        printf("Info: %s\n", match->sug);
    printf("Error: %s\n", match->msg);
}

/* OCall functions */
/*
 * OCall: print a string handed out by the enclave, then terminate the line
 * and flush std::cout. The generated bridge checks the length and
 * NUL-terminates `str` before this function runs.
 */
void ocall_print_string(const char *str)
{
    const char *message = str;
    printf("%s", message);
    std::cout << std::endl;  // newline + flush
}

// Timestamp written by ocall_start_clock() and read by ocall_end_clock().
// thread_local so that ecalls running on different host threads each time
// their own interval instead of overwriting a shared value.
thread_local std::chrono::time_point<std::chrono::high_resolution_clock> start;

// OCall: record the current time in this thread's `start` slot; the matching
// ocall_end_clock() on the same thread measures elapsed time from here.
void ocall_start_clock()
{
    const auto now = std::chrono::high_resolution_clock::now();
    start = now;
}

// OCall: print the seconds elapsed since this thread's last
// ocall_start_clock(), formatted with `str`.
// NOTE(review): `str` is used directly as the printf format and receives one
// double argument, so it must contain exactly one floating-point conversion
// (e.g. "%f") and no others — confirm all enclave call sites honour this.
void ocall_end_clock(const char * str)
{
    const std::chrono::duration<double> elapsed =
        std::chrono::high_resolution_clock::now() - start;
    const double seconds = elapsed.count();
    printf(str, seconds);
}

// OCall: current high_resolution_clock reading, truncated to whole
// microseconds since the clock's epoch and returned as a double.
double ocall_get_time()
{
    const auto since_epoch = std::chrono::high_resolution_clock::now().time_since_epoch();
    const auto micros = std::chrono::duration_cast<std::chrono::microseconds>(since_epoch);
    return micros.count();
}

// Process-wide ExternMMU instance (declared in extern_mmu.hpp) that services
// ocall_extern_alloc() requests from the enclave.
// NOTE(review): no locking is visible here; predict_float issues ecalls from
// several OpenMP threads, so confirm ExternMMU::alloc is thread-safe if the
// enclave can allocate concurrently.
ExternMMU emmu;

// OCall: allocate `size` bytes through the global ExternMMU and hand the
// resulting pointer back to the enclave.
void* ocall_extern_alloc(size_t size) {
    void* block = emmu.alloc(size);
    return block;
}


extern "C"
{

    /*
     * Initialize the enclave
     */
    unsigned long int initialize_enclave(void)
    {

        std::cout << "Initializing Enclave..." << std::endl;

        sgx_enclave_id_t eid = 0;
        sgx_launch_token_t token = {0};
        sgx_status_t ret = SGX_ERROR_UNEXPECTED;
        int updated = 0;

        /* call sgx_create_enclave to initialize an enclave instance */
        /* Debug Support: set 2nd parameter to 1 */
        ret = sgx_create_enclave(ENCLAVE_FILENAME, SGX_DEBUG_FLAG, &token, &updated, &eid, NULL);
        if (ret != SGX_SUCCESS) {
            print_error_message(ret);
            throw ret;
        }

        std::cout << "Enclave id: " << eid << std::endl;

        return eid;
    }

    /*
     * Destroy the enclave
     */
    void destroy_enclave(unsigned long int eid)
    {
        std::cout << "Destroying Enclave with id: " << eid << std::endl;
        sgx_destroy_enclave(eid);
    }

    void load_model_float(unsigned long int eid, char* model_json, float** filters) {
    	sgx_status_t ret = ecall_load_model_float(eid, model_json, filters);
    	if (ret != SGX_SUCCESS) {
			print_error_message(ret);
			throw ret;
		}
	}

   void sgx_conv_create(unsigned long int eid, int conv_src_sz[4], int conv_dst_sz[4], 
                         int conv_weight_sz[4], int conv_strides_sz[2], int conv_padding_sz[2], float* weight_data,
                         float* bias_data,  int is_first) {
      sgx_status_t ret = ecall_sgx_conv_create(eid, conv_src_sz, conv_dst_sz, 
                                               conv_weight_sz, conv_strides_sz, conv_padding_sz,
                                               weight_data, bias_data, is_first);
      if (ret != SGX_SUCCESS) {
        print_error_message(ret);
        throw ret;
      }
   }

   void sgx_depth_conv_create(unsigned long int eid, int conv_src_sz[4], int conv_dst_sz[4], 
                         int conv_weight_sz[4], int conv_strides_sz[2], int conv_padding_sz[2], float* weight_data,
                         float* bias_data,  int is_first) {
      sgx_status_t ret = ecall_sgx_depth_conv_create(eid, conv_src_sz, conv_dst_sz, 
                                               conv_weight_sz, conv_strides_sz, conv_padding_sz,
                                               weight_data, bias_data, is_first);
      if (ret != SGX_SUCCESS) {
        print_error_message(ret);
        throw ret;
      }
   }
   void sgx_bn_create(unsigned long int eid, int in_size[4], int mode, float eps, float momentum) {
      sgx_status_t ret = ecall_sgx_bn_create(eid, in_size, mode, eps, momentum);
      if (ret != SGX_SUCCESS) {
        print_error_message(ret);
        throw ret;
      }
   }

   void sgx_pool_create(unsigned long int eid, int in_size[4], 
                         int out_size[4], int kernel_size[2], 
                         int stride[2],  int padding[2],  int type) {
      sgx_status_t ret = ecall_sgx_pool_create(
                         eid, in_size, out_size, 
                         kernel_size, stride,  
                         padding,  type);
      if (ret != SGX_SUCCESS) {
        print_error_message(ret);
        throw ret;
      }
   }

   void sgx_linear_create(unsigned long int eid, 
                          int in_size[4], 
                          int out_size[2],
                          float* kernel_data, 
                          float* bias_data) {
      sgx_status_t ret = ecall_sgx_linear_create(eid, 
                                                 in_size, 
                                                 out_size, 
                                                 kernel_data, 
                                                 bias_data);
      if (ret != SGX_SUCCESS) {
        print_error_message(ret);
        throw ret;
      }
   }

   void sgx_relu_create(unsigned long int eid,
                        int in_size[4]) {
      sgx_status_t ret = ecall_sgx_relu_create(eid, in_size);
      if (ret != SGX_SUCCESS) {
        print_error_message(ret);
        throw ret;
      }
   }


  void resblock_init(unsigned long int eid,
                       int in_size[4], int out_size[4],
                       int stride[2],  int identity) {
      sgx_status_t ret = ecall_resblock_init(eid, in_size, out_size, stride, identity);
      if (ret != SGX_SUCCESS) {
        print_error_message(ret);
        throw ret;
      }
   }

   void resblock_compl(unsigned long int eid) {
    sgx_status_t ret = ecall_resblock_compl(eid);
      if (ret != SGX_SUCCESS) {
        print_error_message(ret);
        throw ret;
      }
   }

   void inverted_init(unsigned long int eid) {
    sgx_status_t ret = ecall_inverted_init(eid);
    if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
   }

   void inverted_compl(unsigned long int eid) {
    sgx_status_t ret = ecall_inverted_compl(eid);
      if (ret != SGX_SUCCESS) {
        print_error_message(ret);
        throw ret;
      }
   }

   void enclave_update_backward(unsigned long int eid) {
    sgx_status_t ret = ecall_enclave_update_backward(eid);
      if (ret != SGX_SUCCESS) {
        print_error_message(ret);
        throw ret;
      }
   }
   void forward(unsigned long int eid, float* in, float* out) {
    sgx_status_t ret = ecall_forward(eid, in, out);
    if (ret != SGX_SUCCESS) {
        print_error_message(ret);
        throw ret;
      }
   }

   void enclave_backward(unsigned long int eid, float* in, float* out) {
    sgx_status_t ret = ecall_enclave_backward(eid, in, out);
    if (ret != SGX_SUCCESS) {
        print_error_message(ret);
        throw ret;
      }
   }

   void setup_final_reorder(unsigned long int eid) {
    sgx_status_t ret = ecall_setup_final_reorder(eid);
    if (ret != SGX_SUCCESS) {
        print_error_message(ret);
        throw ret;
      }
   }

   void predict_float(unsigned long int eid, float* input, float* output, int batch_size) {

		#pragma omp parallel for num_threads(3)
		for (int i=0; i<batch_size; i++) {
			//sgx_status_t ret = ecall_predict_float(eid, input, output, batch_size);
			sgx_status_t ret = ecall_predict_float(eid, input, output, 1);
			printf("predict returned!\n");
			if (ret != SGX_SUCCESS) {
				print_error_message(ret);
				throw ret;
			}
		}
        printf("returning...\n");
	}

    void load_model_float_verify(unsigned long int eid, char* model_json, float** filters, bool preproc) {
		sgx_status_t ret = ecall_load_model_float_verify(eid, model_json, filters, preproc);
		if (ret != SGX_SUCCESS) {
			print_error_message(ret);
			throw ret;
		}
	}

  void predict_verify_float(unsigned long int eid, float* input, float* output, float** aux_data, int batch_size) {
	  
		sgx_status_t ret = ecall_predict_verify_float(eid, input, output, aux_data, batch_size);
		if (ret != SGX_SUCCESS) {
			print_error_message(ret);
			throw ret;
		}
		printf("returning...\n");
	}

  void slalom_relu(unsigned long int eid, float* input, float* output, float* blind, float* relu_src,int num_elements, char* activation) {
	sgx_status_t ret = ecall_slalom_relu(eid, input, output, blind, relu_src, num_elements, activation);
		if (ret != SGX_SUCCESS) {
			print_error_message(ret);
			throw ret;
		}
	}

  void slalom_maxpoolrelu(unsigned long int eid, float* input, float* output, float* workspace, float* relu_src, float* bias, 
							long int dim_in[4], long int dim_out[4],
                            int window_rows, int window_cols, 
							int row_stride, int col_stride, 
							bool same_padding)
	{
	  sgx_status_t ret = ecall_slalom_maxpoolrelu(eid, input, output, workspace, relu_src, bias,
										  			dim_in, dim_out, 
											 		window_rows, window_cols, 
											 		row_stride, col_stride,
													same_padding);
		if (ret != SGX_SUCCESS) {
			print_error_message(ret);
			throw ret;
		}
	}

	void slalom_init(unsigned long int eid, bool integrity, bool privacy, int batch_size) {
		sgx_status_t ret = ecall_slalom_init(eid, integrity, privacy, batch_size);
		if (ret != SGX_SUCCESS) {
			print_error_message(ret);
			throw ret;
		}
	}

	void slalom_get_r(unsigned long int eid, float* out, int size) {
		sgx_status_t ret = ecall_slalom_get_r(eid, out, size);
		if (ret != SGX_SUCCESS) {
			print_error_message(ret);
			throw ret;
		}
	}

	void slalom_set_z(unsigned long int eid, float* z, float* z_enc, int size) {
		sgx_status_t ret = ecall_slalom_set_z(eid, z, z_enc, size);
		if (ret != SGX_SUCCESS) {
			print_error_message(ret);
			throw ret;
		}
	}

	void slalom_blind_input(unsigned long int eid, float* inp, float* out, int size) {
		sgx_status_t ret = ecall_slalom_blind_input(eid, inp, out, size);
		if (ret != SGX_SUCCESS) {
			print_error_message(ret);
			throw ret;
		}
	}

	void sgxdnn_benchmarks(unsigned long int eid, int num_threads) {
		sgx_status_t ret = ecall_sgxdnn_benchmarks(eid, num_threads);
		if (ret != SGX_SUCCESS) {
			print_error_message(ret);
			throw ret;
		}
	}
  
  void dnnl_init(unsigned long int eid, int train_inside, int internal_batch) {
    sgx_status_t ret = ecall_dnnl_init(eid, train_inside, internal_batch);
    if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }

  void setup_relu(unsigned long int eid, int* input_size) {
    sgx_status_t ret = ecall_setup_relu(eid, input_size);
    if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }

  void slalom_relu_back(unsigned long int eid, float* grad, float* relu_diff_src_buf, float* relu_src_buf, float* bias_grad) {
    sgx_status_t ret = ecall_slalom_relu_back(eid, grad, relu_diff_src_buf, relu_src_buf, bias_grad);
    if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }
  void maxpool(unsigned long int eid, float* src, float* dst, float* workspace) {
    sgx_status_t ret = ecall_maxpool(eid, src, dst, workspace);
    if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }

  void setup_maxpoolrelu(unsigned long int eid, int* retval, int* input_size,
                         int* output_size, int* kernel_size, int* strides, int* padding) {
    sgx_status_t ret = ecall_setup_maxpoolrelu(eid, retval, input_size, output_size, kernel_size, strides, padding);
    if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }

  void maxpoolrelu_back(unsigned long int eid, float* grad, float* pool_diff_src_buf, float* relu_src, float* work, float* bias_grad) {
    sgx_status_t ret = ecall_maxpoolrelu_back(eid, grad, pool_diff_src_buf, relu_src, work, bias_grad);
    if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }

  void setup_batchnormsp (unsigned long int eid, int* input_shape, int privacy, float eps, float momentum) {
    sgx_status_t ret = ecall_setup_batchnormsp(eid, input_shape, privacy, eps, momentum);
     if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }

  void batchnormSp_dark(unsigned long int eid, float* output, float* inp, float* means, float* skip_input, float* act_src, 
                              int batch_size, const char* act_mode) {

    sgx_status_t ret = ecall_batchnormSp_dark(eid, output, inp, means, skip_input, act_src, 
                                              batch_size, act_mode);
    if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }

  void fill_parameter_matrix(unsigned long int eid, float* bm_ptr, float* um_ptr, float* gm_ptr,
							 float* igm_ptr, int internal_batch_size) {
	sgx_status_t ret = ecall_fill_parameter_matrix(eid, bm_ptr, um_ptr, gm_ptr,
												   igm_ptr, internal_batch_size);
	if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }

  void get_work_size(unsigned long int eid, int* retval) {
	sgx_status_t ret = ecall_get_work_size(eid, retval);
	if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }

  void print_time_report(unsigned long int eid) {
    sgx_status_t ret = ecall_print_timing(eid);
    if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }

  void reset_timing(unsigned long int eid) {
    sgx_status_t ret = ecall_reset_timing(eid);
    if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }

  void test_handler(unsigned long int eid) {
    sgx_status_t ret = ecall_test_handler(eid);
    if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }

  void batchnormSp_dark_back(unsigned long int eid, float* grad_out, float* grad, float* inp, float* skip_src, float* act_src) {
    sgx_status_t ret = ecall_batchnormSp_dark_back(eid, grad_out, grad, inp, skip_src, act_src);
    if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }

  void resnet_setup_activation(unsigned long int eid,
                               int act_mode_int, 
                               int in_size[4], 
                               int out_size[4], 
                               int pool_window[2], 
                               int pool_stride[2], 
                               float eps,
                               float momentum,
                               float* bias_data) {
    sgx_status_t ret = ecall_resnet_setup_activation(eid, act_mode_int, 
                                                     in_size, out_size, 
                                                     pool_window, pool_stride,
                                                     eps,
                                                     momentum,
                                                     bias_data);
    if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }

  void resnet_setup_bottom(unsigned long int eid,
                           int act_mode_int, 
                           int in_size[4], 
                           int out_size[4], 
                           float eps,
                           float momentum,
                           float* bias_data_l,
                           float* bias_data_r
                           ) {
    sgx_status_t ret = ecall_resnet_setup_bottom(eid, act_mode_int, 
                                                 in_size, out_size, 
                                                 eps,
                                                 momentum,
                                                 bias_data_l,
                                                 bias_data_r);
    if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }

  void resnet_activation_fwd(unsigned long int eid, float* src, 
                             float* dst, float* mean_extern) {
    sgx_status_t ret = ecall_resnet_activation_fwd(eid, src, dst, mean_extern);
    if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }

  void resnet_activation_bwd (unsigned long int eid, float* diff_src_ptr, 
                              float* diff_dst_ptr) {
    sgx_status_t ret = ecall_resnet_activation_bwd(eid, diff_src_ptr, diff_dst_ptr);
    if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }

  void resnet_bottom_fwd (unsigned long int eid,
                          float* left_in,   float* right_in, 
                          float* mean_left, float* mean_right, 
                          float* dst) {
    sgx_status_t ret = ecall_resnet_bottom_fwd(eid, left_in, right_in, mean_left, mean_right, dst);
    if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }

  void resnet_bottom_bwd (unsigned long int eid,
                          float* left_grad, 
                          float* right_grad, 
                          float* diff_dst_ptr) {
    sgx_status_t ret = ecall_resnet_bottom_bwd(eid, left_grad, right_grad, diff_dst_ptr);
    if (ret != SGX_SUCCESS) {
      print_error_message(ret);
      throw ret;
    }
  }
}

/* Application entry */
int main(int argc, char *argv[])
{
    (void)(argc);
    (void)(argv);

    try {
        sgx_enclave_id_t eid = initialize_enclave();

        std::cout << "Enclave id: " << eid << std::endl;
		    ecall_test_grad(eid);
        return 0;
		const unsigned int filter_sizes[] = {3*3*3*64, 64, 
											3*3*64*64, 64, 
											3*3*64*128, 128, 
											3*3*128*128, 128, 
											3*3*128*256, 256, 
											3*3*256*256, 256, 
											3*3*256*256, 256, 
											3*3*256*512, 512, 
											3*3*512*512, 512, 
											3*3*512*512, 512, 
											3*3*512*512, 512, 
											3*3*512*512, 512, 
											3*3*512*512, 512, 
											7 * 7 * 512 * 4096, 4096,
											4096 * 4096, 4096,
											4096 * 1000, 1000};

		float** filters = (float**) malloc(2*16*sizeof(float*));
        for (int i=0; i<2*16; i++) {
			filters[i] = (float*) malloc(filter_sizes[i] * sizeof(float));
		}

		const unsigned int output_sizes[] = {224*224*64,
                                             224*224*64, 
                                             112*112*128, 
                                             112*112*128, 
                                             56*56*256, 
                                             56*56*256, 
                                             56*56*256, 
                                             28*28*512, 
                                             28*28*512, 
                                             28*28*512, 
                                             14*14*512, 
                                             14*14*512, 
                                             14*14*512, 
											 4096,
											 4096,
											 1000};

		float** extras = (float**) malloc(16*sizeof(float*));
		for (int i=0; i<16; i++) {
			extras[i] = (float*) malloc(output_sizes[i] * sizeof(float));
		}

		test_handler(eid);
		return 0;
    float* img = (float*) malloc(224 * 224 * 3 * sizeof(float));
    float* output = (float*) malloc(1000 * sizeof(float));
	

		int input_size[] = {2, 32, 224, 224};
		int output_size[] = {2, 32, 112, 112};
		int kernel_size[] = {2, 2};
		int padding[] = {0, 0};
		int stride[] = {2, 2};
		
		float* grad= new float[2*224*224*32];
		float* res = new float[2*224*224*32];
		float* src = new float[2*224*224*32];
		
		// setup relu and slalom                                                                                                                                                            
		dnnl_init(eid, 1, 2);
		int ress;
		setup_maxpoolrelu(eid, &ress, input_size, output_size, kernel_size,stride, padding);
        // Destroy the enclave
        sgx_destroy_enclave(eid);
		

		
        return 0;
    }
    catch (int e)
    {
        printf("Info: Enclave Launch failed!.\n");
        printf("Enter a character before exit ...\n");
        getchar();
        return -1;
    }
}
