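/**
 * Note: in our tests, this multi-threaded implementation ran slower than the
 * single-threaded one, so we did not use the results measured with it.
 * One possible cause is that multi-threading interferes with the compiler's
 * optimizations. Another likely factor is that every kernel call below creates
 * and joins fresh pthreads, which is comparatively expensive inside an SGX enclave.
 */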

// M_PI is a POSIX/XSI extension, not ISO C, and the SGX trusted math.h is not
// guaranteed to provide it, so supply the standard value when nothing else has.
#if !defined(M_PI)
#define M_PI 3.14159265358979323846
#endif /* !defined(M_PI) */

#include <cstring>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <sgx_trts.h>
#include <sgx_thread.h>
#include <pthread.h>

#include "Enclave.h"
#include "Enclave_t.h"
#include "DataStructure.h"

// Upper bound on worker threads per kernel launch; each launcher below uses
// min(batch_size, MAX_THREADS) threads and also sizes its stack arrays with it.
constexpr size_t MAX_THREADS = 4;

// Parameter tables declared here and defined elsewhere (presumably filled in by
// the parameter-loading code — confirm against the loader).
// g_obf_params[layer_type].params[layer_idx] supplies mask, ratio_mask, ratio_w and
// permutation to ecall_restore; .count bounds the valid layer_idx values.
extern ObfParamArray g_obf_params[ObfParamType::OBF_COUNT];
// g_norm_params[norm_type].params[layer_idx] supplies the NormParam (weight, optional
// bias, eps) used by ecall_norm; .count bounds the valid layer_idx values.
extern NormParamArray g_norm_params[NormParamType::NORM_COUNT];

// Thread data structure for rms_norm
struct RMSNormThreadData {
    float* hidden_states;
    NormParam norm_param;
    size_t batch_size;
    size_t seq_length;
    size_t hidden_size;
    size_t start_batch;
    size_t end_batch;
};

// Thread function for rms_norm
void* rms_norm_thread_func(void* arg) {
    RMSNormThreadData* data = (RMSNormThreadData*)arg;
    float* hidden_states = data->hidden_states;
    NormParam norm_param = data->norm_param;
    size_t batch_size = data->batch_size;
    size_t seq_length = data->seq_length;
    size_t hidden_size = data->hidden_size;
    size_t start_batch = data->start_batch;
    size_t end_batch = data->end_batch;

    // Process assigned batches
    for (size_t b = start_batch; b < end_batch; b++) {
        for (size_t s = 0; s < seq_length; s++) {
            // Compute norm
            float norm = 0.0f;
            size_t idx = b * seq_length * hidden_size + s * hidden_size;
            for (size_t h = 0; h < hidden_size; h++) {
                float val = hidden_states[idx + h];
                norm += val * val;
            }
            norm = norm / hidden_size;
            norm = norm + norm_param.eps;
            norm = 1.0f / sqrt(norm);

            // Apply norm, weight, and bias
            for (size_t h = 0; h < hidden_size; h++) {
                float bias = 0.0;
                if (norm_param.bias != nullptr) {
                    bias = norm_param.bias[h];
                }
                hidden_states[idx + h] = hidden_states[idx + h] * norm * norm_param.weight[h] + bias;
            }
        }
    }

    return nullptr;
}

int rms_norm(float* hidden_states, NormParam norm_param, size_t batch_size, size_t seq_length, size_t hidden_size) {
    // Determine number of threads to use
    size_t num_threads = (batch_size < MAX_THREADS) ? batch_size : MAX_THREADS;
    pthread_t threads[MAX_THREADS];
    RMSNormThreadData thread_data[MAX_THREADS];

    // Calculate batch range per thread
    size_t batch_per_thread = batch_size / num_threads;
    size_t remainder = batch_size % num_threads;

    // Create and start threads
    for (size_t i = 0; i < num_threads; i++) {
        thread_data[i].hidden_states = hidden_states;
        thread_data[i].norm_param = norm_param;
        thread_data[i].batch_size = batch_size;
        thread_data[i].seq_length = seq_length;
        thread_data[i].hidden_size = hidden_size;
        thread_data[i].start_batch = i * batch_per_thread + ((i < remainder) ? i : remainder);
        thread_data[i].end_batch = thread_data[i].start_batch + batch_per_thread + ((i < remainder) ? 1 : 0);

        pthread_create(&threads[i], nullptr, rms_norm_thread_func, &thread_data[i]);
    }

    // Wait for all threads to complete
    for (size_t i = 0; i < num_threads; i++) {
        pthread_join(threads[i], nullptr);
    }

    return 0;
}

// Thread data structure for layer_norm
struct LayerNormThreadData {
    float* hidden_states;
    NormParam norm_param;
    size_t batch_size;
    size_t seq_length;
    size_t hidden_size;
    size_t start_batch;
    size_t end_batch;
};

// Thread function for layer_norm
void* layer_norm_thread_func(void* arg) {
    LayerNormThreadData* data = (LayerNormThreadData*)arg;
    float* hidden_states = data->hidden_states;
    NormParam norm_param = data->norm_param;
    size_t batch_size = data->batch_size;
    size_t seq_length = data->seq_length;
    size_t hidden_size = data->hidden_size;
    size_t start_batch = data->start_batch;
    size_t end_batch = data->end_batch;

    // Process assigned batches
    for (size_t b = start_batch; b < end_batch; b++) {
        for (size_t s = 0; s < seq_length; s++) {
            // Compute mean
            float mean = 0.0f;
            size_t idx = b * seq_length * hidden_size + s * hidden_size;
            for (size_t h = 0; h < hidden_size; h++) {
                mean += hidden_states[idx + h];
            }
            mean = mean / hidden_size;

            // Compute variance
            float var = 0.0f;
            for (size_t h = 0; h < hidden_size; h++) {
                float diff = hidden_states[idx + h] - mean;
                var += diff * diff;
            }
            var = var / hidden_size;
            var = var + norm_param.eps;
            var = 1.0f / sqrt(var);

            // Apply norm, weight, and bias
            for (size_t h = 0; h < hidden_size; h++) {
                float bias = 0.0;
                if (norm_param.bias != nullptr) {
                    bias = norm_param.bias[h];
                }
                hidden_states[idx + h] = (hidden_states[idx + h] - mean) * var * norm_param.weight[h] + bias;
            }
        }
    }

    return nullptr;
}

int layer_norm(float* hidden_states, NormParam norm_param, size_t batch_size, size_t seq_length, size_t hidden_size) {
    // Determine number of threads to use
    size_t num_threads = (batch_size < MAX_THREADS) ? batch_size : MAX_THREADS;
    pthread_t threads[MAX_THREADS];
    LayerNormThreadData thread_data[MAX_THREADS];

    // Calculate batch range per thread
    size_t batch_per_thread = batch_size / num_threads;
    size_t remainder = batch_size % num_threads;

    // Create and start threads
    for (size_t i = 0; i < num_threads; i++) {
        thread_data[i].hidden_states = hidden_states;
        thread_data[i].norm_param = norm_param;
        thread_data[i].batch_size = batch_size;
        thread_data[i].seq_length = seq_length;
        thread_data[i].hidden_size = hidden_size;
        thread_data[i].start_batch = i * batch_per_thread + ((i < remainder) ? i : remainder);
        thread_data[i].end_batch = thread_data[i].start_batch + batch_per_thread + ((i < remainder) ? 1 : 0);

        pthread_create(&threads[i], nullptr, layer_norm_thread_func, &thread_data[i]);
    }

    // Wait for all threads to complete
    for (size_t i = 0; i < num_threads; i++) {
        pthread_join(threads[i], nullptr);
    }

    return 0;
}

// Thread data structure for restore
struct RestoreThreadData {
    float* x;
    float* y;
    float* y_permuted;
    int* inv_perm;
    ObfParam obf_param;
    size_t batch_size;
    size_t seq_length;
    size_t hidden_size_x;
    size_t hidden_size_y;
    size_t start_batch;
    size_t end_batch;
};

// Thread function for restore
void* restore_thread_func(void* arg) {
    RestoreThreadData* data = (RestoreThreadData*)arg;
    float* x = data->x;
    float* y = data->y;
    float* y_permuted = data->y_permuted;
    int* inv_perm = data->inv_perm;
    ObfParam obf_param = data->obf_param;
    size_t batch_size = data->batch_size;
    size_t seq_length = data->seq_length;
    size_t hidden_size_x = data->hidden_size_x;
    size_t hidden_size_y = data->hidden_size_y;
    size_t start_batch = data->start_batch;
    size_t end_batch = data->end_batch;

    // Process assigned batches
    for (size_t b = start_batch; b < end_batch; b++) {
        for (size_t s = 0; s < seq_length; s++) {
            // Compute x @ mask.view(-1,1)
            float x_dot_mask = 0.0f;
            size_t x_idx = b * seq_length * hidden_size_x + s * hidden_size_x;
            for (size_t h = 0; h < hidden_size_x; h++) {
                x_dot_mask += x[x_idx + h] * obf_param.mask[h];
            }

            // Permute y: y[:,:,inv_perm]
            size_t y_idx = b * seq_length * hidden_size_y + s * hidden_size_y;
            size_t y_permuted_idx = b * seq_length * hidden_size_y + s * hidden_size_y;
            for (size_t h = 0; h < hidden_size_y; h++) {
                y_permuted[y_permuted_idx + h] = y[y_idx + inv_perm[h]];
            }

            // Compute (y[:,:,inv_perm] - (x @ mask.view(-1,1)).repeat(1, 1, y.size(-1)) * ratio_mask) * inv_ratio_w
            for (size_t h = 0; h < hidden_size_y; h++) {
                float term = x_dot_mask * obf_param.ratio_mask[h];
                y[y_idx + h] = (y_permuted[y_permuted_idx + h] - term) / obf_param.ratio_w[h];
            }
        }
    }

    return nullptr;
}

void restore(float* x, float* y, ObfParam obf_param, size_t batch_size, size_t seq_length, size_t hidden_size_x, size_t hidden_size_y) {
    // inv_perm = torch.argsort(permutation)
    int* inv_perm = (int*)malloc(hidden_size_y * sizeof(int));
    for (size_t i = 0; i < hidden_size_y; i++) {
        inv_perm[obf_param.permutation[i]] = i;
    }

    // Create temporary storage for permuted y
    float* y_permuted = (float*)malloc(batch_size * seq_length * hidden_size_y * sizeof(float));

    // Determine number of threads to use
    size_t num_threads = (batch_size < MAX_THREADS) ? batch_size : MAX_THREADS;
    pthread_t threads[MAX_THREADS];
    RestoreThreadData thread_data[MAX_THREADS];

    // Calculate batch range per thread
    size_t batch_per_thread = batch_size / num_threads;
    size_t remainder = batch_size % num_threads;

    // Create and start threads
    for (size_t i = 0; i < num_threads; i++) {
        thread_data[i].x = x;
        thread_data[i].y = y;
        thread_data[i].y_permuted = y_permuted;
        thread_data[i].inv_perm = inv_perm;
        thread_data[i].obf_param = obf_param;
        thread_data[i].batch_size = batch_size;
        thread_data[i].seq_length = seq_length;
        thread_data[i].hidden_size_x = hidden_size_x;
        thread_data[i].hidden_size_y = hidden_size_y;
        thread_data[i].start_batch = i * batch_per_thread + ((i < remainder) ? i : remainder);
        thread_data[i].end_batch = thread_data[i].start_batch + batch_per_thread + ((i < remainder) ? 1 : 0);

        pthread_create(&threads[i], nullptr, restore_thread_func, &thread_data[i]);
    }

    // Wait for all threads to complete
    for (size_t i = 0; i < num_threads; i++) {
        pthread_join(threads[i], nullptr);
    }

    // Free temporary memory
    free(inv_perm);
    free(y_permuted);
}


int ecall_norm(float* hidden_states, int layer_idx, int norm_type, size_t batch_size, size_t seq_length, size_t hidden_size, float* output) {
    if (norm_type == NormParamType::NORM || norm_type == NormParamType::LN_F) {
        layer_idx = 0;
    }
    if (layer_idx < 0) {
        return -1;
    }

    // copy hidden_states to output
    float* output_enclave = (float*)aligned_alloc(32, batch_size * seq_length * hidden_size * sizeof(float));
    memcpy(output_enclave, hidden_states, batch_size * seq_length * hidden_size * sizeof(float));

    start_clock();
    int ret = 0;
    switch (norm_type) {
        case NormParamType::INPUT_LAYERNORM:
        case NormParamType::POST_ATTENTION_LAYERNORM:
        case NormParamType::Q_NORM:
        case NormParamType::K_NORM:
        case NormParamType::PRE_FEEDFORWARD_LAYERNORM:
        case NormParamType::POST_FEEDFORWARD_LAYERNORM:
        case NormParamType::NORM:
            if (layer_idx >= g_norm_params[norm_type].count) {
                return -1;
            }
            ret = rms_norm(output_enclave, g_norm_params[norm_type].params[layer_idx], batch_size, seq_length, hidden_size);
            break;

        case NormParamType::LN_1:
        case NormParamType::LN_2:
        case NormParamType::LN_F:
            ret =  layer_norm(output_enclave, g_norm_params[norm_type].params[layer_idx], batch_size, seq_length, hidden_size);
            break;
        default:
            ret = -1;
            printf("ecall_norm: norm_type %d is not supported\n", norm_type);
            break;
    }
    end_clock("----------------------------------------------------\nSgx cost(norm): %.6f milliseconds\n----------------------------------------------------\n");
    
    memcpy(output, output_enclave, batch_size * seq_length * hidden_size * sizeof(float));

    free(output_enclave);
    return ret;
}

// Work item for one SiLU worker: transform batches [start_batch, end_batch)
// of hidden_states in place.
struct SiluThreadData {
    float* hidden_states;  // batch-major, seq_length * hidden_size floats per batch
    size_t batch_size;     // total batches (set by the launcher; not needed by the worker)
    size_t seq_length;
    size_t hidden_size;
    size_t start_batch;    // first batch to process (inclusive)
    size_t end_batch;      // one past the last batch to process
};

// pthread entry point: x -> x / (1 + exp(-x)) for every element of the assigned
// batches. The batches are contiguous, so the range is walked as one flat span.
// Always returns nullptr.
void* silu_thread_func(void* arg) {
    const SiluThreadData* task = static_cast<const SiluThreadData*>(arg);
    const size_t batch_stride = task->seq_length * task->hidden_size;
    float* p = task->hidden_states + task->start_batch * batch_stride;
    float* const last = task->hidden_states + task->end_batch * batch_stride;

    for (; p < last; ++p) {
        const float x = *p;
        *p = x / (1.0f + expf(-x));
    }

    return nullptr;
}

// Work item for one "new GELU" worker: transform batches [start_batch, end_batch)
// of hidden_states in place.
struct NewGeluThreadData {
    float* hidden_states;  // batch-major, seq_length * hidden_size floats per batch
    size_t batch_size;     // total batches (set by the launcher; not needed by the worker)
    size_t seq_length;
    size_t hidden_size;
    size_t start_batch;    // first batch to process (inclusive)
    size_t end_batch;      // one past the last batch to process
};

// pthread entry point: tanh-approximated GELU,
//   x -> 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// for every element of the assigned batches (a contiguous span). Always returns nullptr.
void* new_gelu_thread_func(void* arg) {
    const NewGeluThreadData* task = static_cast<const NewGeluThreadData*>(arg);
    const size_t batch_stride = task->seq_length * task->hidden_size;
    float* p = task->hidden_states + task->start_batch * batch_stride;
    float* const last = task->hidden_states + task->end_batch * batch_stride;
    // Loop-invariant; kept in double, exactly as the per-element expression uses it.
    const double sqrt_2_over_pi = sqrt(2.0 / M_PI);

    for (; p < last; ++p) {
        const float x = *p;
        *p = 0.5f * x * (1.0f + tanhf(sqrt_2_over_pi * (x + 0.044715 * x * x * x)));
    }

    return nullptr;
}

// Work item for one gelu_tanh worker: transform batches [start_batch, end_batch)
// of hidden_states in place.
struct GeluTanhThreadData {
    float* hidden_states;  // batch-major, seq_length * hidden_size floats per batch
    size_t batch_size;     // total batches (set by the launcher; not needed by the worker)
    size_t seq_length;
    size_t hidden_size;
    size_t start_batch;    // first batch to process (inclusive)
    size_t end_batch;      // one past the last batch to process
};

// pthread entry point: tanh-approximated GELU,
//   x -> x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// for every element of the assigned batches (a contiguous span). This is the same
// formula as new_gelu_thread_func. Always returns nullptr.
void* gelu_tanh_thread_func(void* arg) {
    const GeluTanhThreadData* task = static_cast<const GeluTanhThreadData*>(arg);
    const size_t batch_stride = task->seq_length * task->hidden_size;
    float* p = task->hidden_states + task->start_batch * batch_stride;
    float* const last = task->hidden_states + task->end_batch * batch_stride;
    // Loop-invariant; kept in double, exactly as the per-element expression uses it.
    const double sqrt_2_over_pi = sqrt(2.0 / M_PI);

    for (; p < last; ++p) {
        const float x = *p;
        *p = x * 0.5f * (1.0f + tanhf(sqrt_2_over_pi * (x + 0.044715 * x * x * x)));
    }

    return nullptr;
}

// Work item for one bias worker: add (add == true) or subtract bias from every
// hidden_size-long row of batches [start_batch, end_batch), in place.
struct BiasThreadData {
    float* hidden_states;  // batch-major, seq_length rows of hidden_size floats per batch
    float* bias;           // hidden_size floats, broadcast over every row
    size_t batch_size;     // total batches (set by the launcher; not needed by the worker)
    size_t seq_length;
    size_t hidden_size;
    size_t start_batch;    // first batch to process (inclusive)
    size_t end_batch;      // one past the last batch to process
    bool add;              // true: += bias, false: -= bias
};

// pthread entry point: applies the bias to each row of the assigned batches.
// Rows b * seq_length + s for b in [start_batch, end_batch) form one contiguous
// range, so they are walked directly. Always returns nullptr.
void* bias_thread_func(void* arg) {
    const BiasThreadData* task = static_cast<const BiasThreadData*>(arg);
    const float* bias = task->bias;
    const size_t hidden_size = task->hidden_size;
    const size_t first_row = task->start_batch * task->seq_length;
    const size_t end_row = task->end_batch * task->seq_length;

    for (size_t r = first_row; r < end_row; r++) {
        float* row = task->hidden_states + r * hidden_size;
        if (task->add) {
            for (size_t h = 0; h < hidden_size; h++) {
                row[h] += bias[h];
            }
        } else {
            for (size_t h = 0; h < hidden_size; h++) {
                row[h] -= bias[h];
            }
        }
    }

    return nullptr;
}

int ecall_silu_activation(float* hidden_states, size_t batch_size, size_t seq_length, size_t hidden_size, float* output) {
    if (hidden_states == nullptr || output == nullptr) {
        return -1;
    }

    float* output_enclave = (float*)aligned_alloc(32, batch_size * seq_length * hidden_size * sizeof(float));
    memcpy(output_enclave, hidden_states, batch_size * seq_length * hidden_size * sizeof(float));
    
    start_clock();

    // Determine number of threads to use
    size_t num_threads = (batch_size < MAX_THREADS) ? batch_size : MAX_THREADS;
    pthread_t threads[MAX_THREADS];
    SiluThreadData thread_data[MAX_THREADS];

    // Calculate batch range per thread
    size_t batch_per_thread = batch_size / num_threads;
    size_t remainder = batch_size % num_threads;

    // Create and start threads
    for (size_t i = 0; i < num_threads; i++) {
        thread_data[i].hidden_states = output_enclave;
        thread_data[i].batch_size = batch_size;
        thread_data[i].seq_length = seq_length;
        thread_data[i].hidden_size = hidden_size;
        thread_data[i].start_batch = i * batch_per_thread + ((i < remainder) ? i : remainder);
        thread_data[i].end_batch = thread_data[i].start_batch + batch_per_thread + ((i < remainder) ? 1 : 0);

        pthread_create(&threads[i], nullptr, silu_thread_func, &thread_data[i]);
    }

    // Wait for all threads to complete
    for (size_t i = 0; i < num_threads; i++) {
        pthread_join(threads[i], nullptr);
    }

    end_clock("----------------------------------------------------\nSgx cost(silu): %.6f milliseconds\n----------------------------------------------------\n");

    memcpy(output, output_enclave, batch_size * seq_length * hidden_size * sizeof(float));

    free(output_enclave);
    return 0;
}

int ecall_new_gelu_activation(float* hidden_states, size_t batch_size, size_t seq_length, size_t hidden_size, float* output) {
    if (hidden_states == nullptr || output == nullptr) {
        return -1;
    }

    float* output_enclave = (float*)aligned_alloc(32, batch_size * seq_length * hidden_size * sizeof(float));
    memcpy(output_enclave, hidden_states, batch_size * seq_length * hidden_size * sizeof(float));
    
    start_clock();
    // Determine number of threads to use
    size_t num_threads = (batch_size < MAX_THREADS) ? batch_size : MAX_THREADS;
    pthread_t threads[MAX_THREADS];
    NewGeluThreadData thread_data[MAX_THREADS];

    // Calculate batch range per thread
    size_t batch_per_thread = batch_size / num_threads;
    size_t remainder = batch_size % num_threads;

    // Create and start threads
    for (size_t i = 0; i < num_threads; i++) {
        thread_data[i].hidden_states = output_enclave;
        thread_data[i].batch_size = batch_size;
        thread_data[i].seq_length = seq_length;
        thread_data[i].hidden_size = hidden_size;
        thread_data[i].start_batch = i * batch_per_thread + ((i < remainder) ? i : remainder);
        thread_data[i].end_batch = thread_data[i].start_batch + batch_per_thread + ((i < remainder) ? 1 : 0);

        pthread_create(&threads[i], nullptr, new_gelu_thread_func, &thread_data[i]);
    }

    // Wait for all threads to complete
    for (size_t i = 0; i < num_threads; i++) {
        pthread_join(threads[i], nullptr);
    }
    end_clock("----------------------------------------------------\nSgx cost(new_gelu): %.6f milliseconds\n----------------------------------------------------\n");

    memcpy(output, output_enclave, batch_size * seq_length * hidden_size * sizeof(float));

    free(output_enclave);
    return 0;
}

int ecall_gelu_tanh_activation(float* hidden_states, size_t batch_size, size_t seq_length, size_t hidden_size, float* output) {
    if (hidden_states == nullptr || output == nullptr) {
        return -1;
    }

    float* output_enclave = (float*)aligned_alloc(32, batch_size * seq_length * hidden_size * sizeof(float));
    memcpy(output_enclave, hidden_states, batch_size * seq_length * hidden_size * sizeof(float));
    
    start_clock();
    // Determine number of threads to use
    size_t num_threads = (batch_size < MAX_THREADS) ? batch_size : MAX_THREADS;
    pthread_t threads[MAX_THREADS];
    GeluTanhThreadData thread_data[MAX_THREADS];

    // Calculate batch range per thread
    size_t batch_per_thread = batch_size / num_threads;
    size_t remainder = batch_size % num_threads;

    // Create and start threads
    for (size_t i = 0; i < num_threads; i++) {
        thread_data[i].hidden_states = output_enclave;
        thread_data[i].batch_size = batch_size;
        thread_data[i].seq_length = seq_length;
        thread_data[i].hidden_size = hidden_size;
        thread_data[i].start_batch = i * batch_per_thread + ((i < remainder) ? i : remainder);
        thread_data[i].end_batch = thread_data[i].start_batch + batch_per_thread + ((i < remainder) ? 1 : 0);

        pthread_create(&threads[i], nullptr, gelu_tanh_thread_func, &thread_data[i]);
    }

    // Wait for all threads to complete
    for (size_t i = 0; i < num_threads; i++) {
        pthread_join(threads[i], nullptr);
    }

    end_clock("----------------------------------------------------\nSgx cost(new_gelu): %.6f milliseconds\n----------------------------------------------------\n");

    memcpy(output, output_enclave, batch_size * seq_length * hidden_size * sizeof(float));

    free(output_enclave);
    return 0;
}

int ecall_restore(float* x, float* y, int layer_idx, int layer_type, float* bias, size_t batch_size, size_t seq_length, size_t hidden_size_x, size_t hidden_size_y, float* output) {
    if (x == nullptr || y == nullptr || output == nullptr) {
        printf("Error: ecall_restore: x, y, output cannot be nullptr\n");
        return -1;
    }

    if (layer_type < 0 || layer_type >= ObfParamType::OBF_COUNT) {
        printf("Error: ecall_restore: layer_type %d is not supported\n", layer_type);
        return -2;
    }

    if (layer_idx < 0 || layer_idx >= g_obf_params[layer_type].count) {
        printf("Error: ecall_restore: layer_idx %d out of range [0, %d)\n", layer_idx, g_obf_params[layer_type].count);
        return -3;
    }

    if (hidden_size_x != g_obf_params[layer_type].params[layer_idx].mask_size) {
        return -4;
    }

    if (hidden_size_y != g_obf_params[layer_type].params[layer_idx].ratio_size) {
        return -5;
    }

    float* output_enclave = (float*)aligned_alloc(32, batch_size * seq_length * hidden_size_y * sizeof(float));
    memcpy(output_enclave, y, batch_size * seq_length * hidden_size_y * sizeof(float));
    
    start_clock();
    if (bias != nullptr) {
        // y = y - bias
        // Determine number of threads to use
        size_t num_threads = (batch_size < MAX_THREADS) ? batch_size : MAX_THREADS;
        pthread_t threads[MAX_THREADS];
        BiasThreadData thread_data[MAX_THREADS];

        // Calculate batch range per thread
        size_t batch_per_thread = batch_size / num_threads;
        size_t remainder = batch_size % num_threads;

        // Create and start threads
        for (size_t i = 0; i < num_threads; i++) {
            thread_data[i].hidden_states = output_enclave;
            thread_data[i].bias = bias;
            thread_data[i].batch_size = batch_size;
            thread_data[i].seq_length = seq_length;
            thread_data[i].hidden_size = hidden_size_y;
            thread_data[i].start_batch = i * batch_per_thread + ((i < remainder) ? i : remainder);
            thread_data[i].end_batch = thread_data[i].start_batch + batch_per_thread + ((i < remainder) ? 1 : 0);
            thread_data[i].add = false;

            pthread_create(&threads[i], nullptr, bias_thread_func, &thread_data[i]);
        }

        // Wait for all threads to complete
        for (size_t i = 0; i < num_threads; i++) {
            pthread_join(threads[i], nullptr);
        }
    }

    // x, y, mask, ratio_mask, ratio_w, permutation

    restore(x, output_enclave, g_obf_params[layer_type].params[layer_idx], batch_size, seq_length, hidden_size_x, hidden_size_y);
    
    // y = y + bias
    if (bias != nullptr) {
        // Determine number of threads to use
        size_t num_threads = (batch_size < MAX_THREADS) ? batch_size : MAX_THREADS;
        pthread_t threads[MAX_THREADS];
        BiasThreadData thread_data[MAX_THREADS];

        // Calculate batch range per thread
        size_t batch_per_thread = batch_size / num_threads;
        size_t remainder = batch_size % num_threads;

        // Create and start threads
        for (size_t i = 0; i < num_threads; i++) {
            thread_data[i].hidden_states = output_enclave;
            thread_data[i].bias = bias;
            thread_data[i].batch_size = batch_size;
            thread_data[i].seq_length = seq_length;
            thread_data[i].hidden_size = hidden_size_y;
            thread_data[i].start_batch = i * batch_per_thread + ((i < remainder) ? i : remainder);
            thread_data[i].end_batch = thread_data[i].start_batch + batch_per_thread + ((i < remainder) ? 1 : 0);
            thread_data[i].add = true;

            pthread_create(&threads[i], nullptr, bias_thread_func, &thread_data[i]);
        }

        // Wait for all threads to complete
        for (size_t i = 0; i < num_threads; i++) {
            pthread_join(threads[i], nullptr);
        }
    }
    end_clock("----------------------------------------------------\nSgx cost(restore): %.6f milliseconds\n----------------------------------------------------\n");

    memcpy(output, output_enclave, batch_size * seq_length * hidden_size_y * sizeof(float));

    free(output_enclave);
    return 0;
}

