#include <stdio.h>
#include <string.h>
#include <assert.h>
#include <chrono>
#include <omp.h>
#include <cstdlib>
#include <malloc.h>
#include <time.h>

#include "sgx_tseal.h"
#include "sgx_urts.h"
#include "Enclave_u.h"


// Global enclave ID. Written by init_enclave() (via sgx_create_enclave),
// returned by get_enclave_id(), and reset to 0 by destroy_enclave().
// A value of 0 is treated in this file as "no enclave".
sgx_enclave_id_t global_eid = 0;

#ifdef __cplusplus
extern "C" {
#endif

/**
* Create the SGX enclave and record its ID in global_eid.
*
* @param enclave_path File path handed to sgx_create_enclave; must not be NULL.
* @return 0 on success, -3 if enclave_path is NULL, otherwise the
*         sgx_status_t reported by sgx_create_enclave, cast to int.
*
* NOTE: an enclave created by an earlier call is not destroyed here; call
* destroy_enclave() first if one may already exist.
*/
int init_enclave(const char* enclave_path) {
    // Reject a missing path up front, using the same -3 code the other
    // wrappers in this file return for NULL pointer arguments.
    if (enclave_path == NULL) {
        return -3;
    }

    sgx_status_t status = SGX_SUCCESS;
    sgx_launch_token_t token = {0};
    int updated = 0;

    // Create the enclave; on success its ID is stored in global_eid.
    status = sgx_create_enclave(enclave_path, SGX_DEBUG_FLAG, &token, &updated, &global_eid, NULL);
    if (status != SGX_SUCCESS) {
        return (int)status;
    }

    return 0;
}

/**
* Report the enclave ID currently held in global_eid.
*
* @return The stored ID; 0 when no enclave has been recorded.
*/
sgx_enclave_id_t get_enclave_id() {
    const sgx_enclave_id_t current_id = global_eid;
    return current_id;
}

/**
* Tear down the enclave recorded in global_eid, if any, and reset the ID to 0.
* The status returned by sgx_destroy_enclave is not inspected.
*/
void destroy_enclave() {
    // Nothing recorded, nothing to destroy.
    if (global_eid == 0) {
        return;
    }

    sgx_destroy_enclave(global_eid);
    global_eid = 0;
}

/**
* Forward the input-obfuscation parameters to the enclave.
*
* Each pointed-to structure is dereferenced and handed to
* ecall_prepare_input_obf_params by value.
*
* @return -1 if eid is 0, -3 if any pointer argument is NULL, -2 if the
*         ecall does not report SGX_SUCCESS, otherwise the value the ecall
*         stored in its return slot.
*/
int prepare_input_obf_params(sgx_enclave_id_t eid, ObfusVectorListList* v_list0, ObfusIndicesList* indices_list0, 
                    ObfusVectorListList* v_list, ObfusIndicesList* indices_list, ObfusIndicesList* indices_list_inv) {
    if (eid == 0) {
        return -1;
    }

    const bool missing_argument =
        !v_list0 || !indices_list0 || !v_list || !indices_list || !indices_list_inv;
    if (missing_argument) {
        return -3;
    }

    int enclave_result = 0;
    const sgx_status_t call_status = ecall_prepare_input_obf_params(
        eid, &enclave_result, *v_list0, *indices_list0, *v_list, *indices_list, *indices_list_inv);

    return (call_status == SGX_SUCCESS) ? enclave_result : -2;
}

/**
* Forward the OTP parameters to the enclave.
*
* Both structures are dereferenced and handed to ecall_prepare_otp_params
* by value.
*
* @return -1 if eid is 0, -3 if otp or otp_logits is NULL, -2 if the ecall
*         does not report SGX_SUCCESS, otherwise the value the ecall stored
*         in its return slot.
*/
int prepare_otp_params(sgx_enclave_id_t eid, Otp* otp, OtpLogits* otp_logits) {
    if (eid == 0) {
        return -1;
    }
    if (!otp || !otp_logits) {
        return -3;
    }

    int enclave_result = 0;
    const sgx_status_t call_status = ecall_prepare_otp_params(eid, &enclave_result, *otp, *otp_logits);

    return (call_status == SGX_SUCCESS) ? enclave_result : -2;
}

/**
* Run the obfuscation ecall selected by optimized_stage.
*
* Stages 0 through 3 dispatch to ecall_perform_obfuscation_optimized0..3;
* any other value dispatches to ecall_perform_obfuscation. All variants
* receive the same arguments.
*
* @return -1 if eid is 0, -3 if hidden_states, residual or output is NULL,
*         -2 if the ecall does not report SGX_SUCCESS, otherwise the value
*         the ecall stored in its return slot.
*/
int perform_obfuscation(sgx_enclave_id_t eid, float* hidden_states, float* residual, size_t batch_size, 
                             size_t seq_length, size_t hidden_size, float* output, int optimized_stage) {
    if (eid == 0) {
        return -1;
    }
    if (!hidden_states || !residual || !output) {
        return -3;
    }

    int enclave_result = 0;
    sgx_status_t call_status = SGX_SUCCESS;

    switch (optimized_stage) {
    case 0:
        call_status = ecall_perform_obfuscation_optimized0(eid, &enclave_result, hidden_states, residual,
                                                           batch_size, seq_length, hidden_size, output);
        break;
    case 1:
        call_status = ecall_perform_obfuscation_optimized1(eid, &enclave_result, hidden_states, residual,
                                                           batch_size, seq_length, hidden_size, output);
        break;
    case 2:
        call_status = ecall_perform_obfuscation_optimized2(eid, &enclave_result, hidden_states, residual,
                                                           batch_size, seq_length, hidden_size, output);
        break;
    case 3:
        call_status = ecall_perform_obfuscation_optimized3(eid, &enclave_result, hidden_states, residual,
                                                           batch_size, seq_length, hidden_size, output);
        break;
    default:
        // Any stage outside 0..3 uses the non-optimized entry point.
        call_status = ecall_perform_obfuscation(eid, &enclave_result, hidden_states, residual,
                                                batch_size, seq_length, hidden_size, output);
        break;
    }

    return (call_status == SGX_SUCCESS) ? enclave_result : -2;
}

/**
* Run the OTP ecall on hidden_states, with results delivered through output.
*
* @return -1 if eid is 0, -3 if hidden_states or output is NULL, -2 if the
*         ecall does not report SGX_SUCCESS, otherwise the value the ecall
*         stored in its return slot.
*/
int perform_otp(sgx_enclave_id_t eid, float* hidden_states, size_t batch_size, size_t logits_to_keep, size_t hidden_size, float* output) {
    if (eid == 0) {
        return -1;
    }
    if (!hidden_states || !output) {
        return -3;
    }

    int enclave_result = 0;
    const sgx_status_t call_status = ecall_perform_otp(
        eid, &enclave_result, hidden_states, batch_size, logits_to_keep, hidden_size, output);

    return (call_status == SGX_SUCCESS) ? enclave_result : -2;
}

/**
* Run the logits-recover ecall on hidden_states, with results delivered
* through the int64_t output buffer.
*
* @return -1 if eid is 0, -3 if hidden_states or output is NULL, -2 if the
*         ecall does not report SGX_SUCCESS, otherwise the value the ecall
*         stored in its return slot.
*/
int perform_logits_recover(sgx_enclave_id_t eid, float* hidden_states, size_t batch_size, size_t logits_to_keep, size_t hidden_size, int64_t* output) {
    if (eid == 0) {
        return -1;
    }
    if (!hidden_states || !output) {
        return -3;
    }

    int enclave_result = 0;
    const sgx_status_t call_status = ecall_perform_logits_recover(
        eid, &enclave_result, hidden_states, batch_size, logits_to_keep, hidden_size, output);

    return (call_status == SGX_SUCCESS) ? enclave_result : -2;
}


#ifdef __cplusplus
}
#endif

