// Block AVX-512 headers.
// Every macro in this group is defined before <immintrin.h> is included further
// down. Presumably the names match the include guards of the compiler's AVX-512
// intrinsic headers so that those headers are skipped -- verify against the
// toolchain in use.
#define __AVX512FP16INTRIN_H_INCLUDED 1
#define __AVX512FP16VLINTRIN_H_INCLUDED 1
#define __AVX512INTRIN_H_INCLUDED 1
#define __AVX512VLINTRIN_H_INCLUDED 1
#define __AVX512CDINTRIN_H_INCLUDED 1
#define __AVX512BWINTRIN_H_INCLUDED 1
#define __AVX512DQINTRIN_H_INCLUDED 1
#define __AVX512ERINTRIN_H_INCLUDED 1
#define __AVX512PFINTRIN_H_INCLUDED 1
#define __AVX512VBMIINTRIN_H_INCLUDED 1
#define __AVX512VPOPCNTDQINTRIN_H_INCLUDED 1

// Define missing AVX-512 related types that are not available in AVX2.
// NOTE(review): #ifndef only tests whether a *macro* with that name exists, so
// each fallback below is installed whenever no such macro is defined, whether
// or not the compiler provides a type of the same name.
#ifndef _Float16
#define _Float16 float
#endif

#ifndef __m128h
#define __m128h __m128i
#endif

#ifndef __m256h
#define __m256h __m256i
#endif

#include "Obfuscation.h"
#include "Internal.h"
#include "Enclave.h"
#include "Enclave_t.h"
#include <stdlib.h>
#include <string.h>
#include <sgx_trts.h>
#include <sgx_thread.h>
#include <pthread.h>
#include <immintrin.h>   
#include <omp.h>

// Upper bound on the number of worker threads started by the multi-threaded
// helpers in this file; it also sizes their per-call thread/argument arrays.
// Keep this value below <TCSNum>: strange errors occur in the computation when
// the value is greater than or equal to <TCSNum>.
const size_t MAX_THREADS = 16;



// Global variables to store obfuscation parameters.
// All of them start out empty ({0, nullptr}) and are (re)filled by
// ecall_prepare_input_obf_params, which frees the previous contents first.
// g_v_list0 / g_indices_list0 are the tables used for the inverse
// (_obfus_col_inv) stage of ecall_perform_obfuscation.
VectorListList g_v_list0 = {0, nullptr};
IndicesList g_indices_list0 = {0, nullptr};
// g_v_list / g_indices_list are the tables used for the forward (_obfus_col) stage.
VectorListList g_v_list = {0, nullptr};
IndicesList g_indices_list = {0, nullptr};
// Filled from the indices_list_inv argument of ecall_prepare_input_obf_params.
IndicesList g_indices_list_inv = {0, nullptr};

// Release every buffer owned by `vector_list_list` (each vector's data, each
// inner vector array, then the outer array) and reset it to the empty state.
// Does nothing when the outer array pointer is already null.
void free_vector_list_list(VectorListList& vector_list_list) {
    if (vector_list_list.vectors_list == nullptr) {
        return;
    }

    for (size_t list_idx = 0; list_idx < vector_list_list.count; ++list_idx) {
        VectorList& inner = vector_list_list.vectors_list[list_idx];
        if (inner.vectors == nullptr) {
            continue;
        }
        for (size_t vec_idx = 0; vec_idx < inner.count; ++vec_idx) {
            free(inner.vectors[vec_idx].data);
        }
        free(inner.vectors);
    }

    free(vector_list_list.vectors_list);
    vector_list_list.vectors_list = nullptr;
    vector_list_list.count = 0;
}

// Release every index buffer owned by `indices_list`, then the array of
// entries itself, and reset the list to the empty state.
// Does nothing when the entry array pointer is already null.
void free_indices_list(IndicesList& indices_list) {
    if (indices_list.indices == nullptr) {
        return;
    }

    for (size_t entry = 0; entry < indices_list.count; ++entry) {
        free(indices_list.indices[entry].data);
    }

    free(indices_list.indices);
    indices_list.indices = nullptr;
    indices_list.count = 0;
}

int copy_obfus_vector_list(const ObfusVectorListList& src, VectorListList& dst) {
    dst.count = src.num_lists;
    
    dst.vectors_list = (VectorList*) aligned_alloc(32, dst.count * sizeof(VectorList));
    if (dst.vectors_list == nullptr) {
        return -1;
    }

    for (size_t i = 0; i < dst.count; ++i) {
        dst.vectors_list[i].count = src.vector_list_list[i].num_vectors;
        dst.vectors_list[i].vectors = (Vector*) aligned_alloc(32, dst.vectors_list[i].count * sizeof(Vector));
        if (dst.vectors_list[i].vectors == nullptr) {
            return -1;
        }
        for (size_t j = 0; j < dst.vectors_list[i].count; ++j) {
            dst.vectors_list[i].vectors[j].size = src.vector_list_list[i].vector_list[j].vector_size;
            dst.vectors_list[i].vectors[j].data = (float*) aligned_alloc(32, dst.vectors_list[i].vectors[j].size * sizeof(float));
            if (dst.vectors_list[i].vectors[j].data == nullptr) {
                return -1;
            }
            memcpy(dst.vectors_list[i].vectors[j].data, src.vector_list_list[i].vector_list[j].vectors, dst.vectors_list[i].vectors[j].size * sizeof(float));
        }
    }
    
    return 0;
}

int copy_obfus_indices_list(const ObfusIndicesList& src, IndicesList& dst) {
    dst.count = src.num_lists;
    dst.indices = (Indices*) aligned_alloc(32, dst.count * sizeof(Indices));
    if (dst.indices == nullptr) {
        return -1;
    }
    
    for (size_t i = 0; i < dst.count; ++i) {
        dst.indices[i].size = src.indices_list[i].num_indices;
        dst.indices[i].data = (uint32_t*) aligned_alloc(32, dst.indices[i].size * sizeof(uint32_t));
        if (dst.indices[i].data == nullptr) {
            return -1;
        }
        memcpy(dst.indices[i].data, src.indices_list[i].indices, dst.indices[i].size * sizeof(uint32_t));
    }
    
    return 0;
}

// Replace the obfuscation parameters held in the enclave globals with copies
// of the supplied tables.
//
// The previous contents of all five globals are released first. The copies are
// then made in a fixed order (v_list0, indices_list0, v_list, indices_list,
// indices_list_inv) and stop at the first one that fails.
//
// Returns 0 when every copy succeeded, -1 otherwise.
int ecall_prepare_input_obf_params(ObfusVectorListList v_list0, ObfusIndicesList indices_list0, ObfusVectorListList v_list, ObfusIndicesList indices_list, ObfusIndicesList indices_list_inv) {
    // Drop whatever an earlier call installed.
    free_vector_list_list(g_v_list0);
    free_indices_list(g_indices_list0);
    free_vector_list_list(g_v_list);
    free_indices_list(g_indices_list);
    free_indices_list(g_indices_list_inv);

    // || short-circuits, so no later copy is attempted after a failure.
    const bool failed =
        copy_obfus_vector_list(v_list0, g_v_list0) != 0 ||
        copy_obfus_indices_list(indices_list0, g_indices_list0) != 0 ||
        copy_obfus_vector_list(v_list, g_v_list) != 0 ||
        copy_obfus_indices_list(indices_list, g_indices_list) != 0 ||
        copy_obfus_indices_list(indices_list_inv, g_indices_list_inv) != 0;

    return failed ? -1 : 0;
}

// Helper function to perform Householder reflection on columns
void house_holder_cal_col_optimized0(float* input, size_t batch_size, size_t seq_length, size_t hidden_size, const VectorList& vector_list) {
    size_t last_col = 0;
    
    for (size_t v_idx = 0; v_idx < vector_list.count; ++v_idx) {
        Vector* vec = &vector_list.vectors[v_idx];
        size_t dim_v = vec->size;
        
        // Process each element in the batch and sequence
        // Reorder loops to improve cache locality
        for (size_t b = 0; b < batch_size; ++b) {
            for (size_t s = 0; s < seq_length; ++s) {
                // Calculate dot product: x @ vector
                float dot = 0.0f;
                size_t base_idx = b * seq_length * hidden_size + s * hidden_size + last_col;
                
                // Vectorized dot product calculation
                for (size_t i = 0; i < dim_v; ++i) {
                    dot += input[base_idx + i] * vec->data[i];
                }
                
                // Update the block: x - 2 * (x @ vector) * vector
                // Reuse base_idx to avoid recalculating
                float factor = -2.0f * dot;
                for (size_t i = 0; i < dim_v; ++i) {
                    input[base_idx + i] += factor * vec->data[i];
                }
            }
        }
        
        last_col += dim_v;
    }
}

// Thread parameter structure for house_holder_cal_col_optimized1
// (also used by house_holder_cal_col_optimized2). One instance is filled in
// per worker; the worker functions only read it.
struct HouseHolderThreadArgs {
    float* input;                   // buffer updated in place, shared by all workers
    size_t batch_size;              // number of batches in `input`
    size_t seq_length;              // rows per batch
    size_t hidden_size;             // floats per row (row stride)
    const VectorList* vector_list;  // vectors applied to consecutive column blocks
    size_t thread_id;               // index of this worker, selects its slice of rows
    size_t num_threads;             // number of slices the rows are divided into
};

// Thread function for house_holder_cal_col_optimized1 (AVX2).
//
// `args` points to a HouseHolderThreadArgs. The batch_size * seq_length rows
// are split into num_threads contiguous slices and this call handles slice
// number thread_id. For each vector in the list, in order, every row of the
// slice gets x <- x - 2 * (x . v) * v applied to the column block belonging
// to that vector. Always returns NULL.
//
// Unaligned loads/stores are used throughout: a row block starts at
// row * hidden_size + last_col, which is only 32-byte aligned when both terms
// happen to be multiples of 8 floats, and the stack scratch array has no
// 32-byte alignment guarantee either.
void* house_holder_cal_col_thread(void* args) {
    HouseHolderThreadArgs* thread_args = (HouseHolderThreadArgs*)args;
    float* input = thread_args->input;
    const size_t hidden_size = thread_args->hidden_size;
    const VectorList* vector_list = thread_args->vector_list;
    const size_t thread_id = thread_args->thread_id;
    const size_t num_threads = thread_args->num_threads;
    const size_t total_items = thread_args->batch_size * thread_args->seq_length;

    size_t last_col = 0;

    for (size_t v_idx = 0; v_idx < vector_list->count; ++v_idx) {
        const Vector* vec = &vector_list->vectors[v_idx];
        const size_t dim_v = vec->size;

        // Contiguous slice of rows owned by this worker (ceil division, clamped).
        const size_t items_per_thread = (total_items + num_threads - 1) / num_threads;
        const size_t start_item = thread_id * items_per_thread;
        size_t end_item = (thread_id + 1) * items_per_thread;
        if (end_item > total_items) end_item = total_items;

        for (size_t item = start_item; item < end_item; ++item) {
            // item == b * seq_length + s, so this is the start of the row's block.
            float* x = &input[item * hidden_size + last_col];
            size_t i = 0;

            // Dot product x . v, 8 floats at a time.
            __m256 sum_vec = _mm256_setzero_ps();
            for (; i + 8 <= dim_v; i += 8) {
                __m256 x_vec = _mm256_loadu_ps(&x[i]);
                __m256 v_vec = _mm256_loadu_ps(&vec->data[i]);
                sum_vec = _mm256_add_ps(sum_vec, _mm256_mul_ps(x_vec, v_vec));
            }

            // Horizontal sum of the eight partial sums.
            float sum_array[8];
            _mm256_storeu_ps(sum_array, sum_vec);
            float dot = 0.0f;
            dot += sum_array[0] + sum_array[1] + sum_array[2] + sum_array[3] +
                   sum_array[4] + sum_array[5] + sum_array[6] + sum_array[7];

            // Scalar tail of the dot product.
            for (; i < dim_v; ++i) {
                dot += x[i] * vec->data[i];
            }

            // Update the block: x - 2 * (x . v) * v.
            const float factor = -2.0f * dot;
            const __m256 factor_vec = _mm256_set1_ps(factor);
            i = 0;

            for (; i + 8 <= dim_v; i += 8) {
                __m256 v_vec = _mm256_loadu_ps(&vec->data[i]);
                __m256 x_vec = _mm256_loadu_ps(&x[i]);
                __m256 result_vec = _mm256_add_ps(x_vec, _mm256_mul_ps(factor_vec, v_vec));
                _mm256_storeu_ps(&x[i], result_vec);
            }

            // Scalar tail of the update.
            for (; i < dim_v; ++i) {
                x[i] += factor * vec->data[i];
            }
        }

        last_col += dim_v;
    }

    return NULL;
}

// Thread function for house_holder_cal_col_optimized2 (no AVX)
void* house_holder_cal_col_thread2(void* args) {
    HouseHolderThreadArgs* thread_args = (HouseHolderThreadArgs*)args;
    float* input = thread_args->input;
    size_t batch_size = thread_args->batch_size;
    size_t seq_length = thread_args->seq_length;
    size_t hidden_size = thread_args->hidden_size;
    const VectorList* vector_list = thread_args->vector_list;
    size_t thread_id = thread_args->thread_id;
    size_t num_threads = thread_args->num_threads;
    
    size_t last_col = 0;
    
    for (size_t v_idx = 0; v_idx < vector_list->count; ++v_idx) {
        const Vector* vec = &vector_list->vectors[v_idx];
        size_t dim_v = vec->size;
        
        // Calculate the range of batches and sequences for this thread
        size_t total_batches = batch_size;
        size_t total_sequences = seq_length;
        size_t total_items = total_batches * total_sequences;
        size_t items_per_thread = (total_items + num_threads - 1) / num_threads;
        size_t start_item = thread_id * items_per_thread;
        size_t end_item = (thread_id + 1) * items_per_thread;
        if (end_item > total_items) end_item = total_items;
        
        // Process the assigned batches and sequences
        for (size_t item = start_item; item < end_item; ++item) {
            size_t b = item / total_sequences;
            size_t s = item % total_sequences;
            
            // Calculate dot product: x @ vector (no AVX)
            float dot = 0.0f;
            size_t base_idx = b * seq_length * hidden_size + s * hidden_size + last_col;
            
            // Cache-friendly dot product calculation
            for (size_t i = 0; i < dim_v; ++i) {
                dot += input[base_idx + i] * vec->data[i];
            }
            
            // Update the block: x - 2 * (x @ vector) * vector (no AVX)
            float factor = -2.0f * dot;
            for (size_t i = 0; i < dim_v; ++i) {
                input[base_idx + i] += factor * vec->data[i];
            }
        }
        
        last_col += dim_v;
    }
    
    return NULL;
}

// Helper function to perform Householder reflection on columns - SIMD optimized with multi-threading
void house_holder_cal_col_optimized1(float* input, size_t batch_size, size_t seq_length, size_t hidden_size, const VectorList& vector_list) {
     // Maximum number of threads to use
    size_t num_threads = MAX_THREADS;
    
    // Adjust the number of threads based on the actual workload
    size_t total_items = batch_size * seq_length;
    if (total_items < num_threads) {
        num_threads = total_items;
    }
    
    // Create threads and thread arguments
    pthread_t threads[MAX_THREADS];
    HouseHolderThreadArgs thread_args[MAX_THREADS];
    
    // Initialize and start threads
    for (size_t i = 0; i < num_threads; ++i) {
        thread_args[i].input = input;
        thread_args[i].batch_size = batch_size;
        thread_args[i].seq_length = seq_length;
        thread_args[i].hidden_size = hidden_size;
        thread_args[i].vector_list = &vector_list;
        thread_args[i].thread_id = i;
        thread_args[i].num_threads = num_threads;
        
        pthread_create(&threads[i], NULL, house_holder_cal_col_thread, &thread_args[i]);
    }
    
    // Wait for all threads to complete
    for (size_t i = 0; i < num_threads; ++i) {
        pthread_join(threads[i], NULL);
    }
}

// Helper function to perform Householder reflection on columns - Cache friendly with multi-threading (no AVX)
void house_holder_cal_col_optimized2(float* input, size_t batch_size, size_t seq_length, size_t hidden_size, const VectorList& vector_list) {
    size_t num_threads = MAX_THREADS;
    
    // Adjust the number of threads based on the actual workload
    size_t total_items = batch_size * seq_length;
    if (total_items < num_threads) {
        num_threads = total_items;
    }
    
    // Create threads and thread arguments
    pthread_t threads[MAX_THREADS];
    HouseHolderThreadArgs thread_args[MAX_THREADS];
    
    // Initialize and start threads
    for (size_t i = 0; i < num_threads; ++i) {
        thread_args[i].input = input;
        thread_args[i].batch_size = batch_size;
        thread_args[i].seq_length = seq_length;
        thread_args[i].hidden_size = hidden_size;
        thread_args[i].vector_list = &vector_list;
        thread_args[i].thread_id = i;
        thread_args[i].num_threads = num_threads;
        
        pthread_create(&threads[i], NULL, house_holder_cal_col_thread2, &thread_args[i]);
    }
    
    // Wait for all threads to complete
    for (size_t i = 0; i < num_threads; ++i) {
        pthread_join(threads[i], NULL);
    }
}

// Helper function to perform Householder reflection on columns - Cache friendly with OpenMP (no AVX)
void house_holder_cal_col_optimized3(float* input, size_t batch_size, size_t seq_length, size_t hidden_size, const VectorList& vector_list) {
    size_t last_col = 0;
    
    for (size_t v_idx = 0; v_idx < vector_list.count; ++v_idx) {
        const Vector* vec = &vector_list.vectors[v_idx];
        size_t dim_v = vec->size;
        
        // Calculate the range of batches and sequences
        size_t total_items = batch_size * seq_length;
        
        // Process batches and sequences in parallel using OpenMP
        #pragma omp parallel for num_threads(MAX_THREADS) schedule(static)
        for (size_t item = 0; item < total_items; ++item) {
            size_t b = item / seq_length;
            size_t s = item % seq_length;
            
            // Calculate dot product: x @ vector (no AVX)
            float dot = 0.0f;
            size_t base_idx = b * seq_length * hidden_size + s * hidden_size + last_col;
            
            // Cache-friendly dot product calculation
            for (size_t i = 0; i < dim_v; ++i) {
                dot += input[base_idx + i] * vec->data[i];
            }
            
            // Update the block: x - 2 * (x @ vector) * vector (no AVX)
            float factor = -2.0f * dot;
            for (size_t i = 0; i < dim_v; ++i) {
                input[base_idx + i] += factor * vec->data[i];
            }
        }
        
        last_col += dim_v;
    }
}

// Original Householder reflection function for comparison
void house_holder_cal_col(float* input, size_t batch_size, size_t seq_length, size_t hidden_size, const VectorList& vector_list) {
    size_t last_col = 0;
    
    for (size_t v_idx = 0; v_idx < vector_list.count; ++v_idx) {
        Vector* vec = &vector_list.vectors[v_idx];
        size_t dim_v = vec->size;
        
        // Process each element in the batch and sequence
        for (size_t b = 0; b < batch_size; ++b) {
            for (size_t s = 0; s < seq_length; ++s) {
                // Calculate dot product: x @ vector
                float dot = 0.0f;
                for (size_t i = 0; i < dim_v; ++i) {
                    size_t idx = b * seq_length * hidden_size + s * hidden_size + last_col + i;
                    dot += input[idx] * vec->data[i];
                }
                
                // Update the block: x - 2 * (x @ vector) * vector
                for (size_t i = 0; i < dim_v; ++i) {
                    size_t idx = b * seq_length * hidden_size + s * hidden_size + last_col + i;
                    input[idx] -= 2.0f * dot * vec->data[i];
                }
            }
        }
        
        last_col += dim_v;
    }
}

// Column permutation, single-threaded variant.
//
// For every (batch, sequence) row of `input`, column h of the result is taken
// from column indices.data[h] of the original row. The permuted rows are built
// in a scratch buffer and copied back over `input`.
// Returns without touching `input` when the scratch buffer cannot be allocated.
void permutate_col_optimized0(float* input, size_t batch_size, size_t seq_length, size_t hidden_size, 
                   const Indices& indices) {
    const size_t byte_count = batch_size * seq_length * hidden_size * sizeof(float);
    float* scratch = (float*) aligned_alloc(32, byte_count);
    if (scratch == nullptr) {
        return;
    }

    for (size_t b = 0; b < batch_size; ++b) {
        for (size_t s = 0; s < seq_length; ++s) {
            const size_t row_start = b * seq_length * hidden_size + s * hidden_size;
            const float* src_row = input + row_start;
            float* dst_row = scratch + row_start;

            for (size_t h = 0; h < hidden_size; ++h) {
                dst_row[h] = src_row[indices.data[h]];
            }
        }
    }

    memcpy(input, scratch, byte_count);
    free(scratch);
}

// Thread parameter structure for permutate_col_optimized1
// (also used by permutate_col_optimized2). One instance is filled in per
// worker; the worker functions only read it.
struct PermutateThreadArgs {
    float* input;            // source rows, read by the workers
    float* temp;             // destination buffer the permuted rows are written to
    size_t batch_size;       // number of batches in `input`
    size_t seq_length;       // rows per batch
    size_t hidden_size;      // floats per row (row stride)
    const Indices* indices;  // indices->data[h] is the source column for output column h
    size_t thread_id;        // index of this worker, selects its slice of rows
    size_t num_threads;      // number of slices the rows are divided into
};

// Thread function for permutate_col_optimized1 (AVX2 gather).
//
// `args` points to a PermutateThreadArgs. The batch_size * seq_length rows are
// split into num_threads contiguous slices and this call handles slice number
// thread_id: for each row, temp[row][h] = input[row][indices->data[h]].
// Always returns NULL.
//
// The store into `temp` and the index load use the unaligned intrinsics: a row
// starts at row * hidden_size floats, which is only 32-byte aligned when
// hidden_size is a multiple of 8, so the aligned forms could fault.
void* permutate_col_thread(void* args) {
    PermutateThreadArgs* thread_args = (PermutateThreadArgs*)args;
    float* input = thread_args->input;
    float* temp = thread_args->temp;
    const size_t hidden_size = thread_args->hidden_size;
    const Indices* indices = thread_args->indices;
    const size_t thread_id = thread_args->thread_id;
    const size_t num_threads = thread_args->num_threads;

    // Contiguous slice of rows owned by this worker (ceil division, clamped).
    const size_t total_items = thread_args->batch_size * thread_args->seq_length;
    const size_t items_per_thread = (total_items + num_threads - 1) / num_threads;
    const size_t start_item = thread_id * items_per_thread;
    size_t end_item = (thread_id + 1) * items_per_thread;
    if (end_item > total_items) end_item = total_items;

    for (size_t item = start_item; item < end_item; ++item) {
        // item == b * seq_length + s, so this is the offset of the row.
        const size_t base_idx = item * hidden_size;
        const float* src_row = &input[base_idx];
        float* dst_row = &temp[base_idx];
        size_t h = 0;

        // Gather 8 permuted floats at a time.
        for (; h + 8 <= hidden_size; h += 8) {
            // The gather treats the 32-bit indices as signed offsets from src_row.
            __m256i indices_vec = _mm256_loadu_si256((const __m256i*)&indices->data[h]);
            __m256 vec = _mm256_i32gather_ps(src_row, indices_vec, sizeof(float));
            _mm256_storeu_ps(&dst_row[h], vec);
        }

        // Scalar tail for the last hidden_size % 8 columns.
        for (; h < hidden_size; ++h) {
            size_t permuted_h = indices->data[h];
            dst_row[h] = src_row[permuted_h];
        }
    }

    return NULL;
}

// Thread function for permutate_col_optimized2 (no AVX)
void* permutate_col_thread2(void* args) {
    PermutateThreadArgs* thread_args = (PermutateThreadArgs*)args;
    float* input = thread_args->input;
    float* temp = thread_args->temp;
    size_t batch_size = thread_args->batch_size;
    size_t seq_length = thread_args->seq_length;
    size_t hidden_size = thread_args->hidden_size;
    const Indices* indices = thread_args->indices;
    size_t thread_id = thread_args->thread_id;
    size_t num_threads = thread_args->num_threads;
    
    // Calculate the range of batches and sequences for this thread
    size_t total_items = batch_size * seq_length;
    size_t items_per_thread = (total_items + num_threads - 1) / num_threads;
    size_t start_item = thread_id * items_per_thread;
    size_t end_item = (thread_id + 1) * items_per_thread;
    if (end_item > total_items) end_item = total_items;
    
    // Perform permutation with improved cache locality (no AVX)
    for (size_t item = start_item; item < end_item; ++item) {
        size_t b = item / seq_length;
        size_t s = item % seq_length;
        
        size_t base_idx = b * seq_length * hidden_size + s * hidden_size;
        
        // Process each element one by one for better cache locality
        for (size_t h = 0; h < hidden_size; ++h) {
            size_t permuted_h = indices->data[h];
            temp[base_idx + h] = input[base_idx + permuted_h];
        }
    }
    
    return NULL;
}

// Thread parameter structure for residual addition. Describes the half-open
// element range [start_idx, end_idx) that one worker adds.
struct ResidualAddThreadArgs {
    float* enclave_output;  // accumulated into in place
    float* residual;        // values added to enclave_output, element by element
    size_t start_idx;       // first element index handled by this worker
    size_t end_idx;         // one past the last element index handled
};

// Thread function for residual addition
void* residual_add_thread(void* args) {
    ResidualAddThreadArgs* thread_args = (ResidualAddThreadArgs*)args;
    float* enclave_output = thread_args->enclave_output;
    float* residual = thread_args->residual;
    size_t start_idx = thread_args->start_idx;
    size_t end_idx = thread_args->end_idx;
    
    // Process this range using AVX2 vectorization
    size_t i = start_idx;
    // Process 8 floats at a time using AVX2
    for (; i + 8 <= end_idx; i += 8) {
        __m256 output_vec = _mm256_loadu_ps(&enclave_output[i]);
        __m256 residual_vec = _mm256_loadu_ps(&residual[i]);
        __m256 result_vec = _mm256_add_ps(output_vec, residual_vec);
        _mm256_storeu_ps(&enclave_output[i], result_vec);
    }
    
    // Handle remaining elements (if any)
    for (; i < end_idx; ++i) {
        enclave_output[i] += residual[i];
    }
    
    return NULL;
}

// Thread function for residual addition - No AVX
void* residual_add_thread2(void* args) {
    ResidualAddThreadArgs* thread_args = (ResidualAddThreadArgs*)args;
    float* enclave_output = thread_args->enclave_output;
    float* residual = thread_args->residual;
    size_t start_idx = thread_args->start_idx;
    size_t end_idx = thread_args->end_idx;
    
    // Process this range using scalar operations for cache friendliness
    size_t i = start_idx;
    // Process elements one by one for better cache control
    for (; i < end_idx; ++i) {
        enclave_output[i] += residual[i];
    }
    
    return NULL;
}

// Helper function to perform column permutation - SIMD optimized with AVX2 gather and multi-threading
void permutate_col_optimized1(float* input, size_t batch_size, size_t seq_length, size_t hidden_size, 
                   const Indices& indices) {
    // Create temporary buffer for permutation
    float* temp = (float*) aligned_alloc(32, batch_size * seq_length * hidden_size * sizeof(float));
    if (!temp) return;
    
    size_t num_threads = MAX_THREADS;
    
    // Adjust the number of threads based on the actual workload
    size_t total_items = batch_size * seq_length;
    if (total_items < num_threads) {
        num_threads = total_items;
    }
    
    // Create threads and thread arguments
    pthread_t threads[MAX_THREADS];
    PermutateThreadArgs thread_args[MAX_THREADS];
    
    // Initialize and start threads
    for (size_t i = 0; i < num_threads; ++i) {
        thread_args[i].input = input;
        thread_args[i].temp = temp;
        thread_args[i].batch_size = batch_size;
        thread_args[i].seq_length = seq_length;
        thread_args[i].hidden_size = hidden_size;
        thread_args[i].indices = &indices;
        thread_args[i].thread_id = i;
        thread_args[i].num_threads = num_threads;
        
        pthread_create(&threads[i], NULL, permutate_col_thread, &thread_args[i]);
    }
    
    // Wait for all threads to complete
    for (size_t i = 0; i < num_threads; ++i) {
        pthread_join(threads[i], NULL);
    }
    
    // Copy back to input
    memcpy(input, temp, batch_size * seq_length * hidden_size * sizeof(float));
    free(temp);
}

// Helper function to perform column permutation - Cache friendly with multi-threading (no AVX)
void permutate_col_optimized2(float* input, size_t batch_size, size_t seq_length, size_t hidden_size, 
                   const Indices& indices) {
    // Create temporary buffer for permutation
    float* temp = (float*) aligned_alloc(32, batch_size * seq_length * hidden_size * sizeof(float));
    if (!temp) return;
    
    size_t num_threads = MAX_THREADS;
    
    // Adjust the number of threads based on the actual workload
    size_t total_items = batch_size * seq_length;
    if (total_items < num_threads) {
        num_threads = total_items;
    }
    
    // Create threads and thread arguments
    pthread_t threads[MAX_THREADS];
    PermutateThreadArgs thread_args[MAX_THREADS];
    
    // Initialize and start threads
    for (size_t i = 0; i < num_threads; ++i) {
        thread_args[i].input = input;
        thread_args[i].temp = temp;
        thread_args[i].batch_size = batch_size;
        thread_args[i].seq_length = seq_length;
        thread_args[i].hidden_size = hidden_size;
        thread_args[i].indices = &indices;
        thread_args[i].thread_id = i;
        thread_args[i].num_threads = num_threads;
        
        pthread_create(&threads[i], NULL, permutate_col_thread2, &thread_args[i]);
    }
    
    // Wait for all threads to complete
    for (size_t i = 0; i < num_threads; ++i) {
        pthread_join(threads[i], NULL);
    }
    
    // Copy back to input
    memcpy(input, temp, batch_size * seq_length * hidden_size * sizeof(float));
    free(temp);
}

// Column permutation with the rows distributed by an OpenMP parallel-for
// (scalar, no AVX).
//
// For every (batch, sequence) row, column h of the result is taken from column
// indices.data[h] of the original row. Rows are written to a scratch buffer
// and then copied back over `input`.
// Returns without touching `input` when the scratch buffer cannot be allocated.
void permutate_col_optimized3(float* input, size_t batch_size, size_t seq_length, size_t hidden_size, 
                   const Indices& indices) {
    const size_t byte_count = batch_size * seq_length * hidden_size * sizeof(float);
    float* scratch = (float*) aligned_alloc(32, byte_count);
    if (scratch == nullptr) {
        return;
    }

    const size_t row_total = batch_size * seq_length;

    #pragma omp parallel for num_threads(MAX_THREADS) schedule(static)
    for (size_t row = 0; row < row_total; ++row) {
        // row == b * seq_length + s in the (batch, sequence) layout.
        const size_t row_start = row * hidden_size;

        for (size_t h = 0; h < hidden_size; ++h) {
            scratch[row_start + h] = input[row_start + indices.data[h]];
        }
    }

    memcpy(input, scratch, byte_count);
    free(scratch);
}

// Reference (unoptimized) column permutation, kept for comparison with the
// optimized variants.
//
// For every (batch, sequence) row, column h of the result is taken from column
// indices.data[h] of the original row. The result is assembled in a scratch
// buffer and then copied back over `input`.
// Returns without touching `input` when the scratch buffer cannot be allocated.
void permutate_col(float* input, size_t batch_size, size_t seq_length, size_t hidden_size, 
                   const Indices& indices) {
    const size_t byte_count = batch_size * seq_length * hidden_size * sizeof(float);
    float* scratch = (float*) aligned_alloc(32, byte_count);
    if (scratch == nullptr) {
        return;
    }

    for (size_t b = 0; b < batch_size; ++b) {
        for (size_t s = 0; s < seq_length; ++s) {
            const size_t row_start = b * seq_length * hidden_size + s * hidden_size;
            for (size_t h = 0; h < hidden_size; ++h) {
                scratch[row_start + h] = input[row_start + indices.data[h]];
            }
        }
    }

    memcpy(input, scratch, byte_count);
    free(scratch);
}

// Perform the combined operation: residual + _obfus_col_inv(hidden_states, ...) then _obfus_col(...)
int ecall_perform_obfuscation(float* hidden_states, float* residual, 
                          size_t batch_size, size_t seq_length, size_t hidden_size, 
                          float* output) {
    if (g_v_list0.vectors_list == nullptr || g_indices_list0.indices == nullptr || 
        g_v_list.vectors_list == nullptr || g_indices_list.indices == nullptr) {
        return -1;
    }
    
    // Step 0: Copy hidden_states to output initially
    float* enclave_output = (float*) aligned_alloc(32, batch_size * seq_length * hidden_size * sizeof(float));
    if (enclave_output == nullptr) {
        return -1;
    }

    memcpy(enclave_output, hidden_states, batch_size * seq_length * hidden_size * sizeof(float));

    start_clock();

    // Step 1: Apply _obfus_col_inv to hidden_states
    for (size_t i = g_v_list0.count; i > 0; --i) {
        permutate_col(enclave_output, batch_size, seq_length, hidden_size, g_indices_list0.indices[i - 1]);
        house_holder_cal_col(enclave_output, batch_size, seq_length, hidden_size, g_v_list0.vectors_list[i - 1]);
    }
    
    // Step 2: residual + _obfus_col_inv(hidden_states)
    for (size_t i = 0; i < batch_size * seq_length * hidden_size; ++i) {
        enclave_output[i] += residual[i];
    }
    
    // Step 3: Apply _obfus_col to the result
    for (size_t i = 0; i < g_v_list.count; ++i) {
        house_holder_cal_col(enclave_output, batch_size, seq_length, hidden_size, g_v_list.vectors_list[i]);
        permutate_col(enclave_output, batch_size, seq_length, hidden_size, g_indices_list.indices[i]);
    }

    end_clock("----------------------------------------------------\nSgx cost(input_obf): %.6f milliseconds\n----------------------------------------------------\n");

    // Step4: Copy enclave_output to output
    memcpy(output, enclave_output, batch_size * seq_length * hidden_size * sizeof(float));
    
    // Step5: Free enclave_output
    free(enclave_output);
    
    return 0;
}

// Optimization 0: Cache friendly
//
// Same pipeline as ecall_perform_obfuscation, but using the *_optimized0 helpers and a
// chunked residual addition. hidden_states, residual and output are each accessed as
// batch_size * seq_length * hidden_size contiguous floats.
//
// Returns 0 on success (an empty input succeeds without touching output) and -1 when the
// obfuscation parameters have not been set, the element count does not fit in size_t,
// a buffer pointer is null, or the scratch buffer cannot be allocated.
int ecall_perform_obfuscation_optimized0(float* hidden_states, float* residual, 
                          size_t batch_size, size_t seq_length, size_t hidden_size, 
                          float* output) {
    if (g_v_list0.vectors_list == nullptr || g_indices_list0.indices == nullptr || 
        g_v_list.vectors_list == nullptr || g_indices_list.indices == nullptr) {
        return -1;
    }

    // The dimensions arrive through the ecall interface, so validate them instead of trusting
    // them: if the product wrapped around, the scratch buffer would be smaller than the
    // batch_size x seq_length x hidden_size region the helpers below are asked to process.
    const size_t size_limit = (size_t)-1;
    const size_t alignment = 32;
    if (seq_length != 0 && batch_size > size_limit / seq_length) {
        return -1;
    }
    const size_t rows = batch_size * seq_length;
    if (hidden_size != 0 && rows > size_limit / hidden_size) {
        return -1;
    }
    const size_t total_elements = rows * hidden_size;
    // Leave headroom so rounding the byte count up to the alignment cannot wrap either.
    if (total_elements > (size_limit - (alignment - 1)) / sizeof(float)) {
        return -1;
    }
    const size_t total_bytes = total_elements * sizeof(float);

    // Nothing to transform; also avoids a zero-sized allocation request.
    if (total_elements == 0) {
        return 0;
    }

    if (hidden_states == nullptr || residual == nullptr || output == nullptr) {
        return -1;
    }

    // Step 0: Copy hidden_states into a 32-byte aligned scratch buffer.
    // aligned_alloc expects the size to be a multiple of the alignment, so round it up.
    const size_t alloc_bytes = (total_bytes + alignment - 1) / alignment * alignment;
    float* enclave_output = (float*) aligned_alloc(alignment, alloc_bytes);
    if (enclave_output == nullptr) {
        return -1;
    }

    memcpy(enclave_output, hidden_states, total_bytes);

    start_clock();

    // Step 1: Apply _obfus_col_inv to hidden_states using optimized functions
    for (size_t i = g_v_list0.count; i > 0; --i) {
        permutate_col_optimized0(enclave_output, batch_size, seq_length, hidden_size, g_indices_list0.indices[i - 1]);
        house_holder_cal_col_optimized0(enclave_output, batch_size, seq_length, hidden_size, g_v_list0.vectors_list[i - 1]);
    }

    // Step 2: residual + _obfus_col_inv(hidden_states), processed in fixed-size chunks.
    // Elements are still visited in ascending index order; only the loop is blocked.
    const size_t chunk_size = 1024; // Adjust based on cache size

    for (size_t chunk_start = 0; chunk_start < total_elements; chunk_start += chunk_size) {
        // The last chunk may be shorter than chunk_size.
        size_t chunk_end = chunk_start + chunk_size;
        if (chunk_end > total_elements) {
            chunk_end = total_elements;
        }

        for (size_t i = chunk_start; i < chunk_end; ++i) {
            enclave_output[i] += residual[i];
        }
    }

    // Step 3: Apply _obfus_col to the result using optimized functions
    for (size_t i = 0; i < g_v_list.count; ++i) {
        house_holder_cal_col_optimized0(enclave_output, batch_size, seq_length, hidden_size, g_v_list.vectors_list[i]);
        permutate_col_optimized0(enclave_output, batch_size, seq_length, hidden_size, g_indices_list.indices[i]);
    }

    end_clock("----------------------------------------------------\nSgx cost(input_obf optimized0): %.6f milliseconds\n----------------------------------------------------\n");

    // Step 4: Copy the scratch buffer to output
    memcpy(output, enclave_output, total_bytes);

    // Step 5: Free the scratch buffer
    free(enclave_output);

    return 0;
}

// Optimization 1: Cache friendly + AVX + Multi-thread
int ecall_perform_obfuscation_optimized1(float* hidden_states, float* residual, 
                          size_t batch_size, size_t seq_length, size_t hidden_size, 
                          float* output) {
    if (g_v_list0.vectors_list == nullptr || g_indices_list0.indices == nullptr || 
        g_v_list.vectors_list == nullptr || g_indices_list.indices == nullptr) {
        return -1;
    }
    
    // Step 0: Copy hidden_states to output initially
    float* enclave_output = (float*) aligned_alloc(32, batch_size * seq_length * hidden_size * sizeof(float));
    if (enclave_output == nullptr) {
        return -1;
    }

    memcpy(enclave_output, hidden_states, batch_size * seq_length * hidden_size * sizeof(float));

    start_clock();

    // Step 1: Apply _obfus_col_inv to hidden_states using optimized functions
    for (size_t i = g_v_list0.count; i > 0; --i) {
        permutate_col_optimized1(enclave_output, batch_size, seq_length, hidden_size, g_indices_list0.indices[i - 1]);
        house_holder_cal_col_optimized1(enclave_output, batch_size, seq_length, hidden_size, g_v_list0.vectors_list[i - 1]);
    }

    // Step 2: residual + _obfus_col_inv(hidden_states) - Optimized with multi-threading
    size_t num_threads = MAX_THREADS;
    size_t total_elements = batch_size * seq_length * hidden_size;
    
    // Adjust the number of threads based on the actual workload
    if (total_elements < num_threads) {
        num_threads = total_elements;
    }
    
    // Create threads and thread arguments for residual addition
    pthread_t threads[MAX_THREADS];
    ResidualAddThreadArgs thread_args[MAX_THREADS];
    
    // Calculate elements per thread
    size_t elements_per_thread = (total_elements + num_threads - 1) / num_threads;
    
    // Initialize and start threads
    for (size_t i = 0; i < num_threads; ++i) {
        thread_args[i].enclave_output = enclave_output;
        thread_args[i].residual = residual;
        thread_args[i].start_idx = i * elements_per_thread;
        thread_args[i].end_idx = (i + 1) * elements_per_thread;
        if (thread_args[i].end_idx > total_elements) {
            thread_args[i].end_idx = total_elements;
        }
        
        pthread_create(&threads[i], NULL, residual_add_thread, &thread_args[i]);
    }
    
    // Wait for all threads to complete
    for (size_t i = 0; i < num_threads; ++i) {
        pthread_join(threads[i], NULL);
    }

    // Step 3: Apply _obfus_col to the result using optimized functions
    for (size_t i = 0; i < g_v_list.count; ++i) {
        house_holder_cal_col_optimized1(enclave_output, batch_size, seq_length, hidden_size, g_v_list.vectors_list[i]);
        permutate_col_optimized1(enclave_output, batch_size, seq_length, hidden_size, g_indices_list.indices[i]);
    }

    end_clock("----------------------------------------------------\nSgx cost(input_obf optimized1): %.6f milliseconds\n----------------------------------------------------\n");

    // Step4: Copy enclave_output to output
    memcpy(output, enclave_output, batch_size * seq_length * hidden_size * sizeof(float));
    
    // Step5: Free enclave_output
    free(enclave_output);
    
    return 0;
}

// Optimization 2: Cache friendly + Multi-thread (no AVX)
int ecall_perform_obfuscation_optimized2(float* hidden_states, float* residual, 
                          size_t batch_size, size_t seq_length, size_t hidden_size, 
                          float* output) {
    if (g_v_list0.vectors_list == nullptr || g_indices_list0.indices == nullptr || 
        g_v_list.vectors_list == nullptr || g_indices_list.indices == nullptr) {
        return -1;
    }
    
    // Step 0: Copy hidden_states to output initially
    float* enclave_output = (float*) aligned_alloc(32, batch_size * seq_length * hidden_size * sizeof(float));
    if (enclave_output == nullptr) {
        return -1;
    }

    memcpy(enclave_output, hidden_states, batch_size * seq_length * hidden_size * sizeof(float));

    start_clock();

    // Step 1: Apply _obfus_col_inv to hidden_states using optimized functions (no AVX)
    for (size_t i = g_v_list0.count; i > 0; --i) {
        permutate_col_optimized2(enclave_output, batch_size, seq_length, hidden_size, g_indices_list0.indices[i - 1]);
        house_holder_cal_col_optimized2(enclave_output, batch_size, seq_length, hidden_size, g_v_list0.vectors_list[i - 1]);
    }

    // Step 2: residual + _obfus_col_inv(hidden_states) - Optimized with multi-threading (no AVX)
    size_t num_threads = MAX_THREADS;
    size_t total_elements = batch_size * seq_length * hidden_size;
    
    // Adjust the number of threads based on the actual workload
    if (total_elements < num_threads) {
        num_threads = total_elements;
    }
    
    // Create threads and thread arguments for residual addition
    pthread_t threads[MAX_THREADS];
    ResidualAddThreadArgs thread_args[MAX_THREADS];
    
    // Calculate elements per thread
    size_t elements_per_thread = (total_elements + num_threads - 1) / num_threads;
    
    // Initialize and start threads
    for (size_t i = 0; i < num_threads; ++i) {
        thread_args[i].enclave_output = enclave_output;
        thread_args[i].residual = residual;
        thread_args[i].start_idx = i * elements_per_thread;
        thread_args[i].end_idx = (i + 1) * elements_per_thread;
        if (thread_args[i].end_idx > total_elements) {
            thread_args[i].end_idx = total_elements;
        }
        
        pthread_create(&threads[i], NULL, residual_add_thread2, &thread_args[i]);
    }
    
    // Wait for all threads to complete
    for (size_t i = 0; i < num_threads; ++i) {
        pthread_join(threads[i], NULL);
    }

    // Step 3: Apply _obfus_col to the result using optimized functions (no AVX)
    for (size_t i = 0; i < g_v_list.count; ++i) {
        house_holder_cal_col_optimized2(enclave_output, batch_size, seq_length, hidden_size, g_v_list.vectors_list[i]);
        permutate_col_optimized2(enclave_output, batch_size, seq_length, hidden_size, g_indices_list.indices[i]);
    }

    end_clock("----------------------------------------------------\nSgx cost(input_obf optimized2): %.6f milliseconds\n----------------------------------------------------\n");

    // Step4: Copy enclave_output to output
    memcpy(output, enclave_output, batch_size * seq_length * hidden_size * sizeof(float));
    
    // Step5: Free enclave_output
    free(enclave_output);
    
    return 0;
}

// Optimization 3: Cache friendly + OpenMP (no AVX)
//
// Same pipeline as ecall_perform_obfuscation, but using the *_optimized3 helpers and an
// OpenMP parallel loop for the residual addition. hidden_states, residual and output are
// each accessed as batch_size * seq_length * hidden_size contiguous floats.
//
// Returns 0 on success (an empty input succeeds without touching output) and -1 when the
// obfuscation parameters have not been set, the element count does not fit in size_t,
// a buffer pointer is null, or the scratch buffer cannot be allocated.
int ecall_perform_obfuscation_optimized3(float* hidden_states, float* residual, 
                          size_t batch_size, size_t seq_length, size_t hidden_size, 
                          float* output) {
    if (g_v_list0.vectors_list == nullptr || g_indices_list0.indices == nullptr || 
        g_v_list.vectors_list == nullptr || g_indices_list.indices == nullptr) {
        return -1;
    }

    // The dimensions arrive through the ecall interface, so validate them instead of trusting
    // them: if the product wrapped around, the scratch buffer would be smaller than the
    // batch_size x seq_length x hidden_size region the helpers below are asked to process.
    const size_t size_limit = (size_t)-1;
    const size_t alignment = 32;
    if (seq_length != 0 && batch_size > size_limit / seq_length) {
        return -1;
    }
    const size_t rows = batch_size * seq_length;
    if (hidden_size != 0 && rows > size_limit / hidden_size) {
        return -1;
    }
    const size_t total_elements = rows * hidden_size;
    // Leave headroom so rounding the byte count up to the alignment cannot wrap either.
    if (total_elements > (size_limit - (alignment - 1)) / sizeof(float)) {
        return -1;
    }
    const size_t total_bytes = total_elements * sizeof(float);

    // Nothing to transform; also avoids a zero-sized allocation request.
    if (total_elements == 0) {
        return 0;
    }

    if (hidden_states == nullptr || residual == nullptr || output == nullptr) {
        return -1;
    }

    // Step 0: Copy hidden_states into a 32-byte aligned scratch buffer.
    // aligned_alloc expects the size to be a multiple of the alignment, so round it up.
    const size_t alloc_bytes = (total_bytes + alignment - 1) / alignment * alignment;
    float* enclave_output = (float*) aligned_alloc(alignment, alloc_bytes);
    if (enclave_output == nullptr) {
        return -1;
    }

    memcpy(enclave_output, hidden_states, total_bytes);

    start_clock();

    // Step 1: Apply _obfus_col_inv to hidden_states using optimized functions (no AVX)
    for (size_t i = g_v_list0.count; i > 0; --i) {
        permutate_col_optimized3(enclave_output, batch_size, seq_length, hidden_size, g_indices_list0.indices[i - 1]);
        house_holder_cal_col_optimized3(enclave_output, batch_size, seq_length, hidden_size, g_v_list0.vectors_list[i - 1]);
    }

    // Step 2: residual + _obfus_col_inv(hidden_states) - Optimized with OpenMP (no AVX)
    // Use OpenMP for parallelization instead of pthread; every iteration touches only its
    // own index, so the iterations are independent of each other.
    #pragma omp parallel for simd schedule(static) num_threads(MAX_THREADS)
    for (size_t i = 0; i < total_elements; ++i) {
        enclave_output[i] += residual[i];
    }

    // Step 3: Apply _obfus_col to the result using optimized functions (no AVX)
    for (size_t i = 0; i < g_v_list.count; ++i) {
        house_holder_cal_col_optimized3(enclave_output, batch_size, seq_length, hidden_size, g_v_list.vectors_list[i]);
        permutate_col_optimized3(enclave_output, batch_size, seq_length, hidden_size, g_indices_list.indices[i]);
    }

    end_clock("----------------------------------------------------\nSgx cost(input_obf optimized3): %.6f milliseconds\n----------------------------------------------------\n");

    // Step 4: Copy the scratch buffer to output
    memcpy(output, enclave_output, total_bytes);

    // Step 5: Free the scratch buffer
    free(enclave_output);

    return 0;
}