// Define M_PI directly for SGX compatibility
#include <cstdlib>
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

#include <cstring>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <immintrin.h>
#include <sgx_trts.h>
#include <sgx_thread.h>
#include <pthread.h>

#include "Enclave.h"
#include "Enclave_t.h"
#include "DataStructure.h"
#include "../../third_party/fmath.hpp"

extern ObfParamArray g_obf_params[ObfParamType::OBF_COUNT];
extern NormParamArray g_norm_params[NormParamType::NORM_COUNT];

// RMS normalization, applied in place to each of the batch_size * seq_length
// consecutive rows of hidden_size floats in hidden_states:
//   x[h] = x[h] * (1 / sqrt(mean(x^2) + eps)) * weight[h] + bias[h]
// where mean(x^2) is taken over the row. bias is optional: a null
// norm_param.bias contributes 0. Always returns 0.
int rms_norm(float* hidden_states, NormParam norm_param, size_t batch_size, size_t seq_length, size_t hidden_size) {
    const size_t row_count = batch_size * seq_length;
    const bool has_bias = (norm_param.bias != nullptr);

    for (size_t row = 0; row < row_count; ++row) {
        float* x = hidden_states + row * hidden_size;

        // Sum of squares over the row.
        float sum_sq = 0.0f;
        for (size_t h = 0; h < hidden_size; ++h) {
            const float v = x[h];
            sum_sq += v * v;
        }

        // 1 / sqrt(mean square + eps); each step is stored back into a float.
        float scale = sum_sq / hidden_size;
        scale = scale + norm_param.eps;
        scale = 1.0f / sqrt(scale);

        // Normalize, then apply the per-channel weight and optional bias.
        for (size_t h = 0; h < hidden_size; ++h) {
            const float shift = has_bias ? norm_param.bias[h] : 0.0f;
            x[h] = x[h] * scale * norm_param.weight[h] + shift;
        }
    }
    return 0;
}

// Layer normalization, applied in place to each of the batch_size * seq_length
// consecutive rows of hidden_size floats in hidden_states:
//   x[h] = (x[h] - mean) * (1 / sqrt(var + eps)) * weight[h] + bias[h]
// where mean and var (sum of squared deviations divided by hidden_size) are
// taken over the row. bias is optional: a null norm_param.bias contributes 0.
// Always returns 0.
int layer_norm(float* hidden_states, NormParam norm_param, size_t batch_size, size_t seq_length, size_t hidden_size) {
    const size_t row_count = batch_size * seq_length;
    const bool has_bias = (norm_param.bias != nullptr);

    for (size_t row = 0; row < row_count; ++row) {
        float* x = hidden_states + row * hidden_size;

        // First pass: row mean.
        float mean = 0.0f;
        for (size_t h = 0; h < hidden_size; ++h) {
            mean += x[h];
        }
        mean /= hidden_size;

        // Second pass: variance around that mean, turned into 1 / stddev.
        float inv_std = 0.0f;
        for (size_t h = 0; h < hidden_size; ++h) {
            const float centered = x[h] - mean;
            inv_std += centered * centered;
        }
        inv_std /= hidden_size;
        inv_std = inv_std + norm_param.eps;
        inv_std = 1.0f / sqrt(inv_std);

        // Third pass: center, scale, apply weight and optional bias.
        for (size_t h = 0; h < hidden_size; ++h) {
            const float shift = has_bias ? norm_param.bias[h] : 0.0f;
            x[h] = (x[h] - mean) * inv_std * norm_param.weight[h] + shift;
        }
    }
    return 0;
}

void restore(float* y, ObfParam obf_param, size_t batch_size, size_t seq_length, size_t hidden_size) {
    // Step 1: Compute inverse permutation
    int* inv_perm = (int*)aligned_alloc(32, hidden_size * sizeof(int));
    for (size_t i = 0; i < hidden_size; i++) {
        inv_perm[obf_param.permutation[i]] = i;
    }

    // Step 2: Create a temporary buffer for permuted y
    float* y_permuted = (float*)aligned_alloc(32, batch_size * seq_length * hidden_size * sizeof(float));
    
    // Step 3: Permute y using inverse permutation
    for (size_t b = 0; b < batch_size; b++) {
        for (size_t s = 0; s < seq_length; s++) {
            for (size_t h = 0; h < hidden_size; h++) {
                size_t permuted_idx = b * seq_length * hidden_size + s * hidden_size + h;
                size_t original_idx = b * seq_length * hidden_size + s * hidden_size + inv_perm[h];
                y_permuted[permuted_idx] = y[original_idx];
            }
        }
    }

    float* temp = (float*) aligned_alloc(32, batch_size * seq_length * hidden_size * sizeof(float));
    
    // Process each block
    size_t current_start = 0;
    for (size_t i = 0; i < obf_param.block_count; i++) {
        Block block = obf_param.blocks[i];
        size_t block_size = block.size;
        size_t end = current_start + block_size;
        
        // Perform matrix multiplication for this block
        for (size_t b = 0; b < batch_size; b++) {
            for (size_t s = 0; s < seq_length; s++) {
                size_t base = b * (seq_length * hidden_size) + s * hidden_size + current_start;

                for (size_t k = 0; k < block_size; k++) {
                    float sum = 0.0;
                    for (size_t j = 0; j < block_size; j++) {
                        sum += y_permuted[base + j] * block.data[j * block_size + k];
                    }
                    temp[base + k] = sum;
                }
            }
        }
        
        current_start = end;
    }
    
    // Step 5: Copy result back to y
    memcpy(y, temp, batch_size * seq_length * hidden_size * sizeof(float));

    // Step 6: Free temporary memory
    free(inv_perm);
    free(y_permuted);
    free(temp);
}


int ecall_norm(float* hidden_states, int layer_idx, int norm_type, size_t batch_size, size_t seq_length, size_t hidden_size, float* output) {
    if (norm_type == NormParamType::NORM || norm_type == NormParamType::LN_F) {
        layer_idx = 0;
    }
    if (layer_idx < 0) {
        return -1;
    }

    // copy hidden_states to output
    float* output_enclave = (float*)aligned_alloc(32, batch_size * seq_length * hidden_size * sizeof(float));
    memcpy(output_enclave, hidden_states, batch_size * seq_length * hidden_size * sizeof(float));

    start_clock();
    int ret = 0;
    switch (norm_type) {
        case NormParamType::INPUT_LAYERNORM:
        case NormParamType::POST_ATTENTION_LAYERNORM:
        case NormParamType::Q_NORM:
        case NormParamType::K_NORM:
        case NormParamType::PRE_FEEDFORWARD_LAYERNORM:
        case NormParamType::POST_FEEDFORWARD_LAYERNORM:
        case NormParamType::NORM:
            if (layer_idx >= g_norm_params[norm_type].count) {
                return -1;
            }
            ret = rms_norm(output_enclave, g_norm_params[norm_type].params[layer_idx], batch_size, seq_length, hidden_size);
            break;

        case NormParamType::LN_1:
        case NormParamType::LN_2:
        case NormParamType::LN_F:
            ret =  layer_norm(output_enclave, g_norm_params[norm_type].params[layer_idx], batch_size, seq_length, hidden_size);
            break;
        default:
            ret = -1;
            printf("ecall_norm: norm_type %d is not supported\n", norm_type);
            break;
    }
    ocall_log_clock();
    
    memcpy(output, output_enclave, batch_size * seq_length * hidden_size * sizeof(float));

    free(output_enclave);
    return ret;
}

// ECALL: SiLU activation, out = x / (1 + exp(-x)), element-wise over the
// batch_size * seq_length * hidden_size floats of hidden_states, computed on an
// enclave-private copy and written to output.
//
// Full groups of 8 floats use AVX (fmath::exp_ps256); the remaining tail uses
// scalar expf. Returns 0 on success, -1 for a null pointer, a shape whose byte
// size overflows size_t, or an allocation failure (output untouched on error).
int ecall_silu_activation(float* hidden_states, size_t batch_size, size_t seq_length, size_t hidden_size, float* output) {
    if (hidden_states == nullptr || output == nullptr) {
        return -1;
    }

    // Sizes come from outside the enclave: reject shapes whose byte count
    // (rounded up to the allocation alignment) would overflow size_t.
    const size_t kAlign = 32;
    const size_t kMaxBytes = static_cast<size_t>(-1) - (kAlign - 1);
    if ((seq_length != 0 && batch_size > kMaxBytes / seq_length) ||
        (hidden_size != 0 && batch_size * seq_length > kMaxBytes / sizeof(float) / hidden_size)) {
        return -1;
    }
    const size_t total_elements = batch_size * seq_length * hidden_size;
    const size_t total_bytes = total_elements * sizeof(float);
    // aligned_alloc needs a size that is a non-zero multiple of the alignment.
    const size_t alloc_bytes = total_bytes == 0 ? kAlign : (total_bytes + kAlign - 1) / kAlign * kAlign;

    float* output_enclave = static_cast<float*>(aligned_alloc(kAlign, alloc_bytes));
    if (output_enclave == nullptr) {
        return -1;
    }
    memcpy(output_enclave, hidden_states, total_bytes);

    start_clock();
    const size_t packet_size = 8;
    const __m256 one_256 = _mm256_set1_ps(1.f);
    const __m256 zero_256 = _mm256_setzero_ps();

    // Written as index + packet_size <= total_elements: the form
    // index <= total_elements - packet_size underflows (size_t) when fewer
    // than 8 elements exist and runs the vector loop past the buffer.
    size_t index = 0;
    for (; index + packet_size <= total_elements; index += packet_size) {
        __m256 p = _mm256_loadu_ps(output_enclave + index);
        p = _mm256_div_ps(p, _mm256_add_ps(one_256, fmath::exp_ps256(_mm256_sub_ps(zero_256, p))));
        _mm256_storeu_ps(output_enclave + index, p);
    }

    // Handle remaining elements
    for (; index < total_elements; index++) {
        float input = output_enclave[index];
        output_enclave[index] = input / (1.0f + expf(-input));
    }

    ocall_log_clock();

    memcpy(output, output_enclave, total_bytes);

    free(output_enclave);
    return 0;
}

// ECALL: "new" GELU (tanh approximation), element-wise over the
// batch_size * seq_length * hidden_size floats of hidden_states, computed on an
// enclave-private copy and written to output:
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
//
// Full groups of 8 floats use AVX with tanh built from fmath::exp_ps256; the
// tail uses scalar tanhf with the same float constants. Returns 0 on success,
// -1 for a null pointer, a shape whose byte size overflows size_t, or an
// allocation failure (output untouched on error).
int ecall_new_gelu_activation(float* hidden_states, size_t batch_size, size_t seq_length, size_t hidden_size, float* output) {
    if (hidden_states == nullptr || output == nullptr) {
        return -1;
    }

    // Sizes come from outside the enclave: reject shapes whose byte count
    // (rounded up to the allocation alignment) would overflow size_t.
    const size_t kAlign = 32;
    const size_t kMaxBytes = static_cast<size_t>(-1) - (kAlign - 1);
    if ((seq_length != 0 && batch_size > kMaxBytes / seq_length) ||
        (hidden_size != 0 && batch_size * seq_length > kMaxBytes / sizeof(float) / hidden_size)) {
        return -1;
    }
    const size_t total_elements = batch_size * seq_length * hidden_size;
    const size_t total_bytes = total_elements * sizeof(float);
    // aligned_alloc needs a size that is a non-zero multiple of the alignment.
    const size_t alloc_bytes = total_bytes == 0 ? kAlign : (total_bytes + kAlign - 1) / kAlign * kAlign;

    float* output_enclave = static_cast<float*>(aligned_alloc(kAlign, alloc_bytes));
    if (output_enclave == nullptr) {
        return -1;
    }
    memcpy(output_enclave, hidden_states, total_bytes);

    start_clock();
    // 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
    const size_t packet_size = 8;
    const float sqrt_2_over_pi = sqrtf(2.0f / M_PI);
    const float coeff = 0.044715f;

    const __m256 half_256 = _mm256_set1_ps(0.5f);
    const __m256 one_256 = _mm256_set1_ps(1.0f);
    const __m256 two_256 = _mm256_set1_ps(2.0f);
    const __m256 sqrt_2_pi_256 = _mm256_set1_ps(sqrt_2_over_pi);
    const __m256 coeff_256 = _mm256_set1_ps(coeff);

    // Written as index + packet_size <= total_elements: the form
    // index <= total_elements - packet_size underflows (size_t) when fewer
    // than 8 elements exist and runs the vector loop past the buffer.
    size_t index = 0;
    for (; index + packet_size <= total_elements; index += packet_size) {
        const __m256 input = _mm256_loadu_ps(output_enclave + index);
        const __m256 input_cubed = _mm256_mul_ps(_mm256_mul_ps(input, input), input);
        __m256 term = _mm256_add_ps(input, _mm256_mul_ps(coeff_256, input_cubed));
        term = _mm256_mul_ps(sqrt_2_pi_256, term);

        // tanh(t) = 1 - 2 / (exp(2t) + 1). Algebraically equal to
        // (e - 1) / (e + 1), but yields 1 rather than NaN (inf / inf) if
        // exp(2t) overflows to +inf for large t.
        const __m256 exp_2x = fmath::exp_ps256(_mm256_add_ps(term, term));
        const __m256 tanh_val = _mm256_sub_ps(one_256, _mm256_div_ps(two_256, _mm256_add_ps(exp_2x, one_256)));

        const __m256 result = _mm256_mul_ps(half_256, _mm256_mul_ps(input, _mm256_add_ps(one_256, tanh_val)));
        _mm256_storeu_ps(output_enclave + index, result);
    }

    // Handle remaining elements in float, with the same constants as above.
    for (; index < total_elements; index++) {
        const float input = output_enclave[index];
        output_enclave[index] = 0.5f * input *
                (1.0f + tanhf(sqrt_2_over_pi * (input + coeff * input * input * input)));
    }

    ocall_log_clock();

    memcpy(output, output_enclave, total_bytes);

    free(output_enclave);
    return 0;
}

// ECALL: GELU with tanh approximation, element-wise over the
// batch_size * seq_length * hidden_size floats of hidden_states, computed on an
// enclave-private copy and written to output:
//   x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
//
// Full groups of 8 floats use AVX with tanh built from fmath::exp_ps256; the
// tail uses scalar tanhf with the same float constants. Returns 0 on success,
// -1 for a null pointer, a shape whose byte size overflows size_t, or an
// allocation failure (output untouched on error).
int ecall_gelu_tanh_activation(float* hidden_states, size_t batch_size, size_t seq_length, size_t hidden_size, float* output) {
    if (hidden_states == nullptr || output == nullptr) {
        return -1;
    }

    // Sizes come from outside the enclave: reject shapes whose byte count
    // (rounded up to the allocation alignment) would overflow size_t.
    const size_t kAlign = 32;
    const size_t kMaxBytes = static_cast<size_t>(-1) - (kAlign - 1);
    if ((seq_length != 0 && batch_size > kMaxBytes / seq_length) ||
        (hidden_size != 0 && batch_size * seq_length > kMaxBytes / sizeof(float) / hidden_size)) {
        return -1;
    }
    const size_t total_elements = batch_size * seq_length * hidden_size;
    const size_t total_bytes = total_elements * sizeof(float);
    // aligned_alloc needs a size that is a non-zero multiple of the alignment.
    const size_t alloc_bytes = total_bytes == 0 ? kAlign : (total_bytes + kAlign - 1) / kAlign * kAlign;

    float* output_enclave = static_cast<float*>(aligned_alloc(kAlign, alloc_bytes));
    if (output_enclave == nullptr) {
        return -1;
    }
    memcpy(output_enclave, hidden_states, total_bytes);

    start_clock();
    // input * 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
    const size_t packet_size = 8;
    const float sqrt_2_over_pi = sqrtf(2.0f / M_PI);
    const float coeff = 0.044715f;

    const __m256 half_256 = _mm256_set1_ps(0.5f);
    const __m256 one_256 = _mm256_set1_ps(1.0f);
    const __m256 two_256 = _mm256_set1_ps(2.0f);
    const __m256 sqrt_2_pi_256 = _mm256_set1_ps(sqrt_2_over_pi);
    const __m256 coeff_256 = _mm256_set1_ps(coeff);

    // Written as index + packet_size <= total_elements: the form
    // index <= total_elements - packet_size underflows (size_t) when fewer
    // than 8 elements exist and runs the vector loop past the buffer.
    size_t index = 0;
    for (; index + packet_size <= total_elements; index += packet_size) {
        const __m256 input = _mm256_loadu_ps(output_enclave + index);
        const __m256 input_cubed = _mm256_mul_ps(_mm256_mul_ps(input, input), input);
        __m256 term = _mm256_add_ps(input, _mm256_mul_ps(coeff_256, input_cubed));
        term = _mm256_mul_ps(sqrt_2_pi_256, term);

        // tanh(t) = 1 - 2 / (exp(2t) + 1). Algebraically equal to
        // (e - 1) / (e + 1), but yields 1 rather than NaN (inf / inf) if
        // exp(2t) overflows to +inf for large t.
        const __m256 exp_2x = fmath::exp_ps256(_mm256_add_ps(term, term));
        const __m256 tanh_val = _mm256_sub_ps(one_256, _mm256_div_ps(two_256, _mm256_add_ps(exp_2x, one_256)));

        const __m256 result = _mm256_mul_ps(input, _mm256_mul_ps(half_256, _mm256_add_ps(one_256, tanh_val)));
        _mm256_storeu_ps(output_enclave + index, result);
    }

    // Handle remaining elements in float, with the same constants as above.
    for (; index < total_elements; index++) {
        const float input = output_enclave[index];
        output_enclave[index] = input * 0.5f *
                (1.0f + tanhf(sqrt_2_over_pi * (input + coeff * input * input * input)));
    }

    ocall_log_clock();

    memcpy(output, output_enclave, total_bytes);

    free(output_enclave);
    return 0;
}

// ECALL: undo the obfuscation of layer (layer_type, layer_idx) on an
// enclave-private copy of x and write the result to output.
//
// x and output hold batch_size * seq_length rows of hidden_size floats. When
// bias is non-null it is read as hidden_size floats, subtracted from every row
// before restore() and added back afterwards.
//
// Returns 0 on success, -1 for null x/output, -2 for an unsupported
// layer_type, -3 for an out-of-range layer_idx, -4 when hidden_size differs
// from the parameter set's perm_size, and -5 when the shape's byte size
// overflows size_t or the enclave buffer cannot be allocated. Output is only
// written on success.
int ecall_restore(float* x, int layer_idx, int layer_type, float* bias, size_t batch_size, size_t seq_length, size_t hidden_size, float* output) {
    if (x == nullptr || output == nullptr) {
        printf("Error: ecall_restore: x, output cannot be nullptr\n");
        return -1;
    }

    if (layer_type < 0 || layer_type >= ObfParamType::OBF_COUNT) {
        printf("Error: ecall_restore: layer_type %d is not supported\n", layer_type);
        return -2;
    }

    if (layer_idx < 0 || layer_idx >= g_obf_params[layer_type].count) {
        printf("Error: ecall_restore: layer_idx %d out of range [0, %d)\n", layer_idx, g_obf_params[layer_type].count);
        return -3;
    }

    if (hidden_size != g_obf_params[layer_type].params[layer_idx].perm_size) {
        printf("Error: ecall_restore: hidden_size %lu does not match perm_size\n",
               static_cast<unsigned long>(hidden_size));
        return -4;
    }

    // Sizes come from outside the enclave: reject shapes whose byte count
    // (rounded up to the allocation alignment) would overflow size_t.
    const size_t kAlign = 32;
    const size_t kMaxBytes = static_cast<size_t>(-1) - (kAlign - 1);
    if ((seq_length != 0 && batch_size > kMaxBytes / seq_length) ||
        (hidden_size != 0 && batch_size * seq_length > kMaxBytes / sizeof(float) / hidden_size)) {
        printf("Error: ecall_restore: tensor size overflows size_t\n");
        return -5;
    }
    const size_t batch_seq = batch_size * seq_length;
    const size_t total_bytes = batch_seq * hidden_size * sizeof(float);
    // aligned_alloc needs a size that is a non-zero multiple of the alignment.
    const size_t alloc_bytes = total_bytes == 0 ? kAlign : (total_bytes + kAlign - 1) / kAlign * kAlign;

    float* output_enclave = static_cast<float*>(aligned_alloc(kAlign, alloc_bytes));
    if (output_enclave == nullptr) {
        printf("Error: ecall_restore: out of enclave memory\n");
        return -5;
    }
    memcpy(output_enclave, x, total_bytes);

    start_clock();
    // x = x - bias
    if (bias != nullptr) {
        for (size_t bs = 0; bs < batch_seq; bs++) {
            float* row = output_enclave + bs * hidden_size;
            for (size_t h = 0; h < hidden_size; h++) {
                row[h] -= bias[h];
            }
        }
    }

    restore(output_enclave, g_obf_params[layer_type].params[layer_idx], batch_size, seq_length, hidden_size);

    // x = x + bias
    if (bias != nullptr) {
        for (size_t bs = 0; bs < batch_seq; bs++) {
            float* row = output_enclave + bs * hidden_size;
            for (size_t h = 0; h < hidden_size; h++) {
                row[h] += bias[h];
            }
        }
    }
    ocall_log_clock();

    memcpy(output, output_enclave, total_bytes);

    free(output_enclave);
    return 0;
}

