#include <stdio.h>
#include <string.h>
#include <assert.h>
#include <chrono>
#include <omp.h>
#include <cstdlib>
#include <malloc.h>
#include <time.h>

#include "sgx_tseal.h"
#include "sgx_urts.h"
#include "Enclave_u.h"


// Handle of the enclave created by init_enclave(). A value of 0 means "no enclave
// loaded": destroy_enclave() only acts on a non-zero handle and resets it to 0.
sgx_enclave_id_t global_eid = {0};

#ifdef __cplusplus
extern "C" {
#endif

// Timing counter defined in another translation unit; it is declared here inside the
// extern "C" block, so it has C linkage. It is read by get_sgx_exe_time() and zeroed by
// reset_sgx_exe_time().
// NOTE(review): the units, and which code updates this counter, are not visible in this file.
extern double sgx_exe_time;

/**
* Initialize SGX enclave
*/
int init_enclave(const char* enclave_path) {
    sgx_status_t status = SGX_SUCCESS;
    sgx_launch_token_t token = {0};
    int updated = 0;

    // 创建 enclave
    status = sgx_create_enclave(enclave_path, SGX_DEBUG_FLAG, &token, &updated, &global_eid, NULL);
    if (status != SGX_SUCCESS) {
        return (int)status;
    }

    return 0;
}

/**
* Return the current value of the externally defined sgx_exe_time counter.
*/
double get_sgx_exe_time() {
    const double elapsed = sgx_exe_time;
    return elapsed;
}

/**
* Zero the externally defined sgx_exe_time counter.
*/
void reset_sgx_exe_time() {
    const double zero = 0.0;
    sgx_exe_time = zero;
}

/**
* Return the handle stored in global_eid (0 if no enclave has been created,
* or if it has been destroyed via destroy_enclave()).
*/
sgx_enclave_id_t get_enclave_id() {
    const sgx_enclave_id_t current_eid = global_eid;
    return current_eid;
}

/**
* Destroy the enclave referenced by global_eid (if any) and reset the handle to 0.
* Calling it when no enclave is loaded does nothing.
*
* NOTE(review): the sgx_status_t returned by sgx_destroy_enclave is discarded.
*/
void destroy_enclave() {
    if (global_eid == 0) {
        return;  // nothing loaded
    }

    (void)sgx_destroy_enclave(global_eid);
    global_eid = 0;
}

/**
* Send obfuscation parameters of the given type into the enclave
* (ecall_prepare_obf_params).
*
* @param eid        enclave handle; 0 is rejected
* @param param_type parameter-kind selector forwarded unchanged to the enclave
* @param params     parameter array; it is dereferenced and passed to the ecall by value
* @return -1 if eid is 0, -3 if params is NULL, -2 if the ecall itself fails,
*         otherwise the value returned by the enclave
*/
int prepare_obf_params(sgx_enclave_id_t eid, int param_type, ObfParamArray* params) {
    int enclave_result = 0;

    if (eid == 0) {
        return -1;
    }
    if (!params) {
        return -3;
    }

    if (ecall_prepare_obf_params(eid, &enclave_result, param_type, *params) != SGX_SUCCESS) {
        return -2;
    }
    return enclave_result;
}

/**
* Send normalization parameters of the given type into the enclave
* (ecall_prepare_norm_params).
*
* @param eid        enclave handle; 0 is rejected
* @param param_type parameter-kind selector forwarded unchanged to the enclave
* @param params     parameter array; it is dereferenced and passed to the ecall by value
* @return -1 if eid is 0, -3 if params is NULL, -2 if the ecall itself fails,
*         otherwise the value returned by the enclave
*/
int prepare_norm_params(sgx_enclave_id_t eid, int param_type, NormParamArray* params) {
    int enclave_result = 0;

    if (eid == 0) {
        return -1;
    }
    if (!params) {
        return -3;
    }

    if (ecall_prepare_norm_params(eid, &enclave_result, param_type, *params) != SGX_SUCCESS) {
        return -2;
    }
    return enclave_result;
}

/**
* Run the enclave's normalization step (ecall_norm) on hidden_states.
*
* @param eid           enclave handle; 0 is rejected
* @param hidden_states input buffer passed to the enclave
* @param layer_idx     layer index forwarded unchanged to the enclave
* @param norm_type     selector forwarded unchanged to the enclave
* @param batch_size    dimension forwarded to the enclave
* @param seq_length    dimension forwarded to the enclave
* @param hidden_size   dimension forwarded to the enclave
* @param output        output buffer passed to the enclave
* @return -1 if eid is 0, -3 if hidden_states or output is NULL,
*         -2 if the ecall itself fails, otherwise the enclave's return value
*/
int norm(sgx_enclave_id_t eid, float* hidden_states, int layer_idx, int norm_type, size_t batch_size, size_t seq_length, size_t hidden_size, float* output) {
    if (eid == 0) {
        return -1;
    }

    // Reject missing buffers before crossing into the enclave. This uses the same -3
    // code that prepare_obf_params/prepare_norm_params use for NULL arguments.
    if (hidden_states == NULL || output == NULL) {
        return -3;
    }

    int ret = 0;
    sgx_status_t sgx_ret = ecall_norm(eid, &ret, hidden_states, layer_idx, norm_type,
                                      batch_size, seq_length, hidden_size, output);
    if (sgx_ret != SGX_SUCCESS) {
        return -2;
    }

    return ret;
}

/**
* Run the enclave's tanh-GELU activation (ecall_gelu_tanh_activation).
*
* @param eid           enclave handle; 0 is rejected
* @param hidden_states input buffer passed to the enclave
* @param batch_size    dimension forwarded to the enclave
* @param seq_length    dimension forwarded to the enclave
* @param hidden_size   dimension forwarded to the enclave
* @param output        output buffer passed to the enclave
* @return -1 if eid is 0, -3 if hidden_states or output is NULL,
*         -2 if the ecall itself fails, otherwise the enclave's return value
*/
int gelu_tanh_activation(sgx_enclave_id_t eid, float* hidden_states, size_t batch_size, size_t seq_length, size_t hidden_size, float* output) {
    if (eid == 0) {
        return -1;
    }

    // Reject missing buffers before the ecall, using the -3 NULL-argument
    // code shared with prepare_obf_params/prepare_norm_params.
    if (hidden_states == NULL || output == NULL) {
        return -3;
    }

    int ret = 0;
    sgx_status_t sgx_ret = ecall_gelu_tanh_activation(eid, &ret, hidden_states,
                                                      batch_size, seq_length, hidden_size, output);
    if (sgx_ret != SGX_SUCCESS) {
        return -2;
    }

    return ret;
}

/**
* Run the enclave's "new GELU" activation (ecall_new_gelu_activation).
*
* @param eid           enclave handle; 0 is rejected
* @param hidden_states input buffer passed to the enclave
* @param batch_size    dimension forwarded to the enclave
* @param seq_length    dimension forwarded to the enclave
* @param hidden_size   dimension forwarded to the enclave
* @param output        output buffer passed to the enclave
* @return -1 if eid is 0, -3 if hidden_states or output is NULL,
*         -2 if the ecall itself fails, otherwise the enclave's return value
*/
int new_gelu_activation(sgx_enclave_id_t eid, float* hidden_states, size_t batch_size, size_t seq_length, size_t hidden_size, float* output) {
    if (eid == 0) {
        return -1;
    }

    // Reject missing buffers before the ecall, using the -3 NULL-argument
    // code shared with prepare_obf_params/prepare_norm_params.
    if (hidden_states == NULL || output == NULL) {
        return -3;
    }

    int ret = 0;
    sgx_status_t sgx_ret = ecall_new_gelu_activation(eid, &ret, hidden_states,
                                                     batch_size, seq_length, hidden_size, output);
    if (sgx_ret != SGX_SUCCESS) {
        return -2;
    }

    return ret;
}

/**
* Run the enclave's SiLU activation (ecall_silu_activation).
*
* @param eid           enclave handle; 0 is rejected
* @param hidden_states input buffer passed to the enclave
* @param batch_size    dimension forwarded to the enclave
* @param seq_length    dimension forwarded to the enclave
* @param hidden_size   dimension forwarded to the enclave
* @param output        output buffer passed to the enclave
* @return -1 if eid is 0, -3 if hidden_states or output is NULL,
*         -2 if the ecall itself fails, otherwise the enclave's return value
*/
int silu_activation(sgx_enclave_id_t eid, float* hidden_states, size_t batch_size, size_t seq_length, size_t hidden_size, float* output) {
    if (eid == 0) {
        return -1;
    }

    // Reject missing buffers before the ecall, using the -3 NULL-argument
    // code shared with prepare_obf_params/prepare_norm_params.
    if (hidden_states == NULL || output == NULL) {
        return -3;
    }

    int ret = 0;
    sgx_status_t sgx_ret = ecall_silu_activation(eid, &ret, hidden_states,
                                                 batch_size, seq_length, hidden_size, output);
    if (sgx_ret != SGX_SUCCESS) {
        return -2;
    }

    return ret;
}

/**
* Run the enclave's restore step (ecall_restore) on x and y.
*
* @param eid           enclave handle; 0 is rejected
* @param x             buffer passed to the enclave (sized by hidden_size_x)
* @param y             buffer passed to the enclave (sized by hidden_size_y)
* @param layer_idx     layer index forwarded unchanged to the enclave
* @param layer_type    selector forwarded unchanged to the enclave
* @param bias          forwarded as-is; NOT NULL-checked because it may be optional
*                      (NOTE(review): confirm against the enclave implementation)
* @param batch_size    dimension forwarded to the enclave
* @param seq_length    dimension forwarded to the enclave
* @param hidden_size_x dimension forwarded to the enclave
* @param hidden_size_y dimension forwarded to the enclave
* @param output        output buffer passed to the enclave
* @return -1 if eid is 0, -3 if x, y or output is NULL,
*         -2 if the ecall itself fails, otherwise the enclave's return value
*/
int restore(sgx_enclave_id_t eid, float* x, float* y, int layer_idx, int layer_type, float* bias, size_t batch_size, size_t seq_length, size_t hidden_size_x, size_t hidden_size_y, float* output) {
    if (eid == 0) {
        return -1;
    }

    // Reject missing required buffers before the ecall, using the -3
    // NULL-argument code shared with prepare_obf_params/prepare_norm_params.
    if (x == NULL || y == NULL || output == NULL) {
        return -3;
    }

    int ret = 0;
    sgx_status_t sgx_ret = ecall_restore(eid, &ret, x, y, layer_idx, layer_type, bias,
                                         batch_size, seq_length, hidden_size_x, hidden_size_y,
                                         output);
    if (sgx_ret != SGX_SUCCESS) {
        return -2;
    }

    return ret;
}

#ifdef __cplusplus
}
#endif

