// Define M_PI directly for SGX compatibility
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

#include <cstring>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <immintrin.h>
#include <sgx_trts.h>
#include <sgx_thread.h>
#include <pthread.h>

#include "Enclave.h"
#include "Enclave_t.h"
#include "DataStructure.h"
#include "../../third_party/fmath.hpp"
#include "../../third_party/FastMemcpy_Avx.h"

// Parameter tables defined in another translation unit. In this file each table is
// indexed first by parameter type and then by layer index
// (`table[type].params[layer_idx]`, with `table[type].count` used as the bound).
extern ObfParamArray g_obf_params[ObfParamType::OBF_COUNT];
extern NormParamArray g_norm_params[NormParamType::NORM_COUNT];

// Fast approximation of 1 / sqrt(number).
//
// Builds an initial guess by manipulating the IEEE-754 bit pattern of `number`
// (magic constant 0x5f3759df) and refines it with a single Newton-Raphson step.
// The result is an approximation, not a correctly rounded value; callers that
// need full float precision should use 1.0f / sqrtf(x) instead.
//
// The float <-> integer reinterpretation goes through memcpy: reading the
// inactive member of a union is undefined behaviour in C++, whereas a fixed-size
// memcpy is well defined and compiles down to a plain register move.
float Q_rsqrt(float number)
{
    static_assert(sizeof(float) == sizeof(uint32_t), "Q_rsqrt requires a 32-bit float");

    uint32_t bits;
    memcpy(&bits, &number, sizeof(bits));

    // Bit-level initial estimate of number^(-1/2).
    bits = 0x5f3759df - (bits >> 1);

    float estimate;
    memcpy(&estimate, &bits, sizeof(estimate));

    // One Newton-Raphson iteration: y <- y * (1.5 - 0.5 * x * y^2).
    estimate *= 1.5F - (number * 0.5F * estimate * estimate);
    return estimate;
}

// In-place RMS normalisation over the last dimension.
//
// `hidden_states` is treated as (batch_size * seq_length) contiguous rows of
// `hidden_size` floats. Each row is scaled by Q_rsqrt(mean(x^2) + eps), multiplied
// element-wise by `norm_param.weight`, and `norm_param.bias` is added when it is
// non-null. Always returns 0.
int rms_norm(float* hidden_states, NormParam norm_param, size_t batch_size, size_t seq_length, size_t hidden_size) {
    const size_t num_rows = batch_size * seq_length;
    for (size_t r = 0; r < num_rows; ++r) {
        float* const row = hidden_states + r * hidden_size;

        // Accumulate the sum of squares of the row.
        float scale = 0.0f;
        for (size_t c = 0; c < hidden_size; ++c) {
            const float v = row[c];
            scale += v * v;
        }

        // mean(x^2) + eps, then the approximate reciprocal square root.
        scale = scale / hidden_size;
        scale = scale + norm_param.eps;
        scale = Q_rsqrt(scale);

        for (size_t c = 0; c < hidden_size; ++c) {
            // A missing bias still contributes an explicit + 0.0f.
            const float shift = (norm_param.bias != nullptr) ? norm_param.bias[c] : 0.0f;
            row[c] = row[c] * scale * norm_param.weight[c] + shift;
        }
    }
    return 0;
}

// In-place layer normalisation over the last dimension.
//
// `hidden_states` is treated as (batch_size * seq_length) contiguous rows of
// `hidden_size` floats. For each row the mean and the (biased) variance are
// computed, then every element becomes
//   (x - mean) * Q_rsqrt(var + eps) * weight[h] + bias[h]
// where the bias term is 0 when `norm_param.bias` is null. Always returns 0.
int layer_norm(float* hidden_states, NormParam norm_param, size_t batch_size, size_t seq_length, size_t hidden_size) {
    const size_t num_rows = batch_size * seq_length;
    for (size_t r = 0; r < num_rows; ++r) {
        float* const row = hidden_states + r * hidden_size;

        // First pass: mean of the row.
        float mu = 0.0f;
        for (size_t c = 0; c < hidden_size; ++c) {
            mu += row[c];
        }
        mu = mu / hidden_size;

        // Second pass: mean squared deviation from the mean.
        float inv_std = 0.0f;
        for (size_t c = 0; c < hidden_size; ++c) {
            const float centered = row[c] - mu;
            inv_std += centered * centered;
        }
        inv_std = inv_std / hidden_size;
        inv_std = inv_std + norm_param.eps;
        inv_std = Q_rsqrt(inv_std);

        // Third pass: normalise, scale and shift.
        for (size_t c = 0; c < hidden_size; ++c) {
            const float shift = (norm_param.bias != nullptr) ? norm_param.bias[c] : 0.0f;
            row[c] = (row[c] - mu) * inv_std * norm_param.weight[c] + shift;
        }
    }
    return 0;
}

// Undo the output obfuscation of `y` in place.
//
// For every one of the (batch_size * seq_length) rows this computes
//   y[h] = (y[inv_perm[h]] - dot(x_row, mask) * ratio_mask[h]) / ratio_w[h]
// where inv_perm is the inverse of `obf_param.permutation`
// (inv_perm[permutation[i]] = i).
//
//   x              input rows, hidden_size_x floats each (read only)
//   y              obfuscated rows, hidden_size_y floats each (overwritten)
//   obf_param      supplies mask, ratio_mask, ratio_w and permutation
//
// NOTE(review): permutation entries are used as indices into a buffer of
// hidden_size_y ints without validation; they are assumed to lie in
// [0, hidden_size_y) - confirm where the parameters are loaded.
//
// The function has no error channel, so a failed scratch allocation aborts
// instead of dereferencing a null pointer.
void restore(float* x, float* y, ObfParam obf_param, size_t batch_size, size_t seq_length, size_t hidden_size_x, size_t hidden_size_y) {
    const size_t batch_seq = batch_size * seq_length;
    if (batch_seq == 0 || hidden_size_y == 0) {
        return;  // no element of y would be written
    }

    // aligned_alloc expects the size to be a multiple of the alignment.
    const size_t kAlign = 32;
    const size_t perm_bytes = (hidden_size_y * sizeof(int) + kAlign - 1) & ~(kAlign - 1);
    const size_t row_bytes = (hidden_size_y * sizeof(float) + kAlign - 1) & ~(kAlign - 1);

    int* inv_perm = static_cast<int*>(aligned_alloc(kAlign, perm_bytes));
    // Only one permuted row is live at a time, so a single-row scratch buffer
    // is enough; this keeps the enclave heap footprint independent of batch size.
    float* row_permuted = static_cast<float*>(aligned_alloc(kAlign, row_bytes));
    if (inv_perm == nullptr || row_permuted == nullptr) {
        free(inv_perm);
        free(row_permuted);
        abort();
    }

    // inv_perm = torch.argsort(permutation)
    for (size_t i = 0; i < hidden_size_y; i++) {
        inv_perm[obf_param.permutation[i]] = static_cast<int>(i);
    }

    for (size_t bs = 0; bs < batch_seq; bs++) {
        const float* x_row = x + bs * hidden_size_x;
        float* y_row = y + bs * hidden_size_y;

        // x @ mask.view(-1, 1)
        float x_dot_mask = 0.0f;
        for (size_t h = 0; h < hidden_size_x; h++) {
            x_dot_mask += x_row[h] * obf_param.mask[h];
        }

        // y[:, :, inv_perm] - gathered into scratch because y_row is overwritten below.
        for (size_t h = 0; h < hidden_size_y; h++) {
            row_permuted[h] = y_row[inv_perm[h]];
        }

        // (y[:, :, inv_perm] - x_dot_mask * ratio_mask) / ratio_w
        for (size_t h = 0; h < hidden_size_y; h++) {
            const float term = x_dot_mask * obf_param.ratio_mask[h];
            y_row[h] = (row_permuted[h] - term) / obf_param.ratio_w[h];
        }
    }

    free(inv_perm);
    free(row_permuted);
}


// Normalise `hidden_states` and write the result to `output`.
//
// The input is copied into an enclave-side working buffer, normalised there
// with rms_norm or layer_norm (selected by `norm_type`) using
// g_norm_params[norm_type].params[layer_idx], and copied to `output`.
//
//   layer_idx   parameter slot; forced to 0 for NORM and LN_F
//   norm_type   one of the NormParamType values handled below
//   batch_size, seq_length, hidden_size
//               both buffers hold batch_size * seq_length * hidden_size floats
//
// Returns the kernel's return value on success and -1 on failure: null
// pointer, negative or out-of-range layer_idx, allocation failure, or an
// unsupported norm_type. For an unsupported norm_type the input is still
// copied through to `output` unchanged; the other failures leave `output`
// untouched.
int ecall_norm(float* hidden_states, int layer_idx, int norm_type, size_t batch_size, size_t seq_length, size_t hidden_size, float* output) {
    if (hidden_states == nullptr || output == nullptr) {
        return -1;
    }

    if (norm_type == NormParamType::NORM || norm_type == NormParamType::LN_F) {
        layer_idx = 0;
    }
    if (layer_idx < 0) {
        return -1;
    }

    // Decide which kernel applies before touching any resource, so that every
    // validation failure can return without cleanup.
    enum class Kernel { kUnsupported, kRms, kLayerNorm };
    Kernel kernel = Kernel::kUnsupported;
    switch (norm_type) {
        case NormParamType::INPUT_LAYERNORM:
        case NormParamType::POST_ATTENTION_LAYERNORM:
        case NormParamType::Q_NORM:
        case NormParamType::K_NORM:
        case NormParamType::PRE_FEEDFORWARD_LAYERNORM:
        case NormParamType::POST_FEEDFORWARD_LAYERNORM:
        case NormParamType::NORM:
            kernel = Kernel::kRms;
            break;

        case NormParamType::LN_1:
        case NormParamType::LN_2:
        case NormParamType::LN_F:
            kernel = Kernel::kLayerNorm;
            break;

        default:
            break;
    }

    // Bounds-check the parameter slot for both kernel families; norm_type is a
    // known enumerator here, so indexing g_norm_params is in range.
    if (kernel != Kernel::kUnsupported && layer_idx >= g_norm_params[norm_type].count) {
        return -1;
    }

    const size_t bytes = batch_size * seq_length * hidden_size * sizeof(float);
    // aligned_alloc expects the size to be a multiple of the alignment.
    const size_t alloc_bytes = (bytes + 31) & ~static_cast<size_t>(31);
    float* output_enclave = static_cast<float*>(aligned_alloc(32, alloc_bytes));
    if (output_enclave == nullptr && bytes != 0) {
        return -1;
    }

    // Work on a private copy of the input.
    memcpy_fast(output_enclave, hidden_states, bytes);

    start_clock();
    int ret = 0;
    switch (kernel) {
        case Kernel::kRms:
            ret = rms_norm(output_enclave, g_norm_params[norm_type].params[layer_idx], batch_size, seq_length, hidden_size);
            break;
        case Kernel::kLayerNorm:
            ret = layer_norm(output_enclave, g_norm_params[norm_type].params[layer_idx], batch_size, seq_length, hidden_size);
            break;
        case Kernel::kUnsupported:
            ret = -1;
            printf("ecall_norm: norm_type %d is not supported\n", norm_type);
            break;
    }
    ocall_log_clock();

    memcpy_fast(output, output_enclave, bytes);

    free(output_enclave);
    return ret;
}


// Apply SiLU, x / (1 + exp(-x)), element-wise.
//
// `hidden_states` and `output` each hold batch_size * seq_length * hidden_size
// floats. The input is copied into an enclave-side working buffer, transformed
// there (8 floats at a time with AVX, then a scalar tail) and copied to `output`.
//
// Returns 0 on success, -1 if a pointer is null or the working buffer cannot
// be allocated.
int ecall_silu_activation(float* hidden_states, size_t batch_size, size_t seq_length, size_t hidden_size, float* output) {
    if (hidden_states == nullptr || output == nullptr) {
        return -1;
    }

    const size_t total_elements = batch_size * seq_length * hidden_size;
    const size_t bytes = total_elements * sizeof(float);
    // aligned_alloc expects the size to be a multiple of the alignment.
    const size_t alloc_bytes = (bytes + 31) & ~static_cast<size_t>(31);
    float* output_enclave = static_cast<float*>(aligned_alloc(32, alloc_bytes));
    if (output_enclave == nullptr && bytes != 0) {
        return -1;
    }
    memcpy_fast(output_enclave, hidden_states, bytes);

    start_clock();
    const size_t packet_size = 8;
    const __m256 one_256 = _mm256_set1_ps(1.f);
    const __m256 zero_256 = _mm256_setzero_ps();

    // `index + packet_size <= total_elements` instead of
    // `index <= total_elements - packet_size`: the latter wraps around for
    // fewer than 8 elements (size_t is unsigned) and would run off the buffer.
    size_t index = 0;
    for (; index + packet_size <= total_elements; index += packet_size) {
        float* lane = output_enclave + index;
        __m256 p = _mm256_loadu_ps(lane);
        p = _mm256_div_ps(p, _mm256_add_ps(one_256, fmath::exp_ps256(_mm256_sub_ps(zero_256, p))));
        _mm256_storeu_ps(lane, p);
    }

    // Handle remaining elements
    for (; index < total_elements; index++) {
        float input = output_enclave[index];
        output_enclave[index] = input / (1.0f + expf(-input));
    }

    ocall_log_clock();

    memcpy_fast(output, output_enclave, bytes);

    free(output_enclave);
    return 0;
}

// Apply the tanh approximation of GELU element-wise:
//   0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
//
// `hidden_states` and `output` each hold batch_size * seq_length * hidden_size
// floats. The input is copied into an enclave-side working buffer, transformed
// there (8 floats at a time with AVX, then a scalar tail) and copied to `output`.
// In the vector path tanh(z) is evaluated as (exp(2z) - 1) / (exp(2z) + 1).
//
// NOTE(review): if fmath::exp_ps256 can return +inf for large arguments, the
// vector tanh becomes inf / inf = NaN - confirm its saturation behaviour.
//
// Returns 0 on success, -1 if a pointer is null or the working buffer cannot
// be allocated.
int ecall_new_gelu_activation(float* hidden_states, size_t batch_size, size_t seq_length, size_t hidden_size, float* output) {
    if (hidden_states == nullptr || output == nullptr) {
        return -1;
    }

    const size_t total_elements = batch_size * seq_length * hidden_size;
    const size_t bytes = total_elements * sizeof(float);
    // aligned_alloc expects the size to be a multiple of the alignment.
    const size_t alloc_bytes = (bytes + 31) & ~static_cast<size_t>(31);
    float* output_enclave = static_cast<float*>(aligned_alloc(32, alloc_bytes));
    if (output_enclave == nullptr && bytes != 0) {
        return -1;
    }
    memcpy_fast(output_enclave, hidden_states, bytes);

    start_clock();
    const size_t packet_size = 8;

    const __m256 half_256 = _mm256_set1_ps(0.5f);
    const __m256 one_256 = _mm256_set1_ps(1.0f);
    const __m256 sqrt_2_pi_256 = _mm256_set1_ps(sqrtf(2.0f / M_PI));
    const __m256 coeff_256 = _mm256_set1_ps(0.044715f);

    // `index + packet_size <= total_elements` instead of
    // `index <= total_elements - packet_size`: the latter wraps around for
    // fewer than 8 elements (size_t is unsigned) and would run off the buffer.
    size_t index = 0;
    for (; index + packet_size <= total_elements; index += packet_size) {
        float* lane = output_enclave + index;
        __m256 input = _mm256_loadu_ps(lane);
        __m256 input_cubed = _mm256_mul_ps(_mm256_mul_ps(input, input), input);
        __m256 term = _mm256_add_ps(input, _mm256_mul_ps(coeff_256, input_cubed));
        term = _mm256_mul_ps(sqrt_2_pi_256, term);

        // tanh(term) = (exp(2 * term) - 1) / (exp(2 * term) + 1)
        __m256 term_doubled = _mm256_add_ps(term, term);
        __m256 exp_2x = fmath::exp_ps256(term_doubled);
        __m256 tanh_val = _mm256_div_ps(_mm256_sub_ps(exp_2x, one_256), _mm256_add_ps(exp_2x, one_256));

        __m256 result = _mm256_mul_ps(half_256, _mm256_mul_ps(input, _mm256_add_ps(one_256, tanh_val)));
        _mm256_storeu_ps(lane, result);
    }

    // Handle remaining elements
    for (; index < total_elements; index++) {
        float input = output_enclave[index];
        output_enclave[index] = 0.5f * input * 
                (1.0f + tanhf(sqrt(2.0 / M_PI) * (input + 0.044715 * input * input * input)));
    }

    ocall_log_clock();

    memcpy_fast(output, output_enclave, bytes);

    free(output_enclave);
    return 0;
}

// Apply the tanh approximation of GELU element-wise:
//   x * 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
//
// `hidden_states` and `output` each hold batch_size * seq_length * hidden_size
// floats. The input is copied into an enclave-side working buffer, transformed
// there (8 floats at a time with AVX, then a scalar tail) and copied to `output`.
// In the vector path tanh(z) is evaluated as (exp(2z) - 1) / (exp(2z) + 1).
//
// NOTE(review): if fmath::exp_ps256 can return +inf for large arguments, the
// vector tanh becomes inf / inf = NaN - confirm its saturation behaviour.
//
// Returns 0 on success, -1 if a pointer is null or the working buffer cannot
// be allocated.
int ecall_gelu_tanh_activation(float* hidden_states, size_t batch_size, size_t seq_length, size_t hidden_size, float* output) {
    if (hidden_states == nullptr || output == nullptr) {
        return -1;
    }

    const size_t total_elements = batch_size * seq_length * hidden_size;
    const size_t bytes = total_elements * sizeof(float);
    // aligned_alloc expects the size to be a multiple of the alignment.
    const size_t alloc_bytes = (bytes + 31) & ~static_cast<size_t>(31);
    float* output_enclave = static_cast<float*>(aligned_alloc(32, alloc_bytes));
    if (output_enclave == nullptr && bytes != 0) {
        return -1;
    }
    memcpy_fast(output_enclave, hidden_states, bytes);

    start_clock();
    const size_t packet_size = 8;

    const __m256 half_256 = _mm256_set1_ps(0.5f);
    const __m256 one_256 = _mm256_set1_ps(1.0f);
    const __m256 sqrt_2_pi_256 = _mm256_set1_ps(sqrtf(2.0f / M_PI));
    const __m256 coeff_256 = _mm256_set1_ps(0.044715f);

    // `index + packet_size <= total_elements` instead of
    // `index <= total_elements - packet_size`: the latter wraps around for
    // fewer than 8 elements (size_t is unsigned) and would run off the buffer.
    size_t index = 0;
    for (; index + packet_size <= total_elements; index += packet_size) {
        float* lane = output_enclave + index;
        __m256 input = _mm256_loadu_ps(lane);
        __m256 input_cubed = _mm256_mul_ps(_mm256_mul_ps(input, input), input);
        __m256 term = _mm256_add_ps(input, _mm256_mul_ps(coeff_256, input_cubed));
        term = _mm256_mul_ps(sqrt_2_pi_256, term);

        // tanh(term) = (exp(2 * term) - 1) / (exp(2 * term) + 1)
        __m256 term_doubled = _mm256_add_ps(term, term);
        __m256 exp_2x = fmath::exp_ps256(term_doubled);
        __m256 tanh_val = _mm256_div_ps(_mm256_sub_ps(exp_2x, one_256), _mm256_add_ps(exp_2x, one_256));

        __m256 result = _mm256_mul_ps(input, _mm256_mul_ps(half_256, _mm256_add_ps(one_256, tanh_val)));
        _mm256_storeu_ps(lane, result);
    }

    // Handle remaining elements
    for (; index < total_elements; index++) {
        float input = output_enclave[index];
        output_enclave[index] = input * 0.5f *
                (1.0f + tanhf(sqrt(2.0 / M_PI) * (input + 0.044715 * input * input * input)));
    }

    ocall_log_clock();

    memcpy_fast(output, output_enclave, bytes);

    free(output_enclave);
    return 0;
}

int ecall_restore(float* x, float* y, int layer_idx, int layer_type, float* bias, size_t batch_size, size_t seq_length, size_t hidden_size_x, size_t hidden_size_y, float* output) {
    if (x == nullptr || y == nullptr || output == nullptr) {
        printf("Error: ecall_restore: x, y, output cannot be nullptr\n");
        return -1;
    }

    if (layer_type < 0 || layer_type >= ObfParamType::OBF_COUNT) {
        printf("Error: ecall_restore: layer_type %d is not supported\n", layer_type);
        return -2;
    }

    if (layer_idx < 0 || layer_idx >= g_obf_params[layer_type].count) {
        printf("Error: ecall_restore: layer_idx %d out of range [0, %d)\n", layer_idx, g_obf_params[layer_type].count);
        return -3;
    }

    if (hidden_size_x != g_obf_params[layer_type].params[layer_idx].mask_size) {
        return -4;
    }

    if (hidden_size_y != g_obf_params[layer_type].params[layer_idx].ratio_size) {
        return -5;
    }

    float* output_enclave = (float*)aligned_alloc(32, batch_size * seq_length * hidden_size_y * sizeof(float));
    memcpy_fast(output_enclave, y, batch_size * seq_length * hidden_size_y * sizeof(float));
    
    start_clock();
    if (bias != nullptr) {
        // y = y - bias
        size_t batch_seq = batch_size * seq_length;
        for (size_t bs = 0; bs < batch_seq; bs++) {
            size_t base_idx = bs * hidden_size_y;
            for (size_t h = 0; h < hidden_size_y; h++) {
                output_enclave[base_idx + h] -= bias[h];
            }
        }
    }

    // x, y, mask, ratio_mask, ratio_w, permutation

    restore(x, output_enclave, g_obf_params[layer_type].params[layer_idx], batch_size, seq_length, hidden_size_x, hidden_size_y);
    
    // y = y + bias
    if (bias != nullptr) {
        // y = y + bias
        size_t batch_seq = batch_size * seq_length;
        for (size_t bs = 0; bs < batch_seq; bs++) {
            size_t base_idx = bs * hidden_size_y;
            for (size_t h = 0; h < hidden_size_y; h++) {
                output_enclave[base_idx + h] += bias[h];
            }
        }
    }
    ocall_log_clock();

    memcpy_fast(output, output_enclave, batch_size * seq_length * hidden_size_y * sizeof(float));

    free(output_enclave);
    return 0;
}

