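// Pre-define the include guards of the AVX-512 intrinsic headers so that
// <immintrin.h> skips them (presumably because this enclave build only targets
// AVX2). NOTE(review): guard names differ between compilers/versions; verify
// they match the toolchain's headers.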
#define __AVX512FP16INTRIN_H_INCLUDED 1
#define __AVX512FP16VLINTRIN_H_INCLUDED 1
#define __AVX512INTRIN_H_INCLUDED 1
#define __AVX512VLINTRIN_H_INCLUDED 1
#define __AVX512CDINTRIN_H_INCLUDED 1
#define __AVX512BWINTRIN_H_INCLUDED 1
#define __AVX512DQINTRIN_H_INCLUDED 1
#define __AVX512ERINTRIN_H_INCLUDED 1
#define __AVX512PFINTRIN_H_INCLUDED 1
#define __AVX512VBMIINTRIN_H_INCLUDED 1
#define __AVX512VPOPCNTDQINTRIN_H_INCLUDED 1

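// Fallback spellings for AVX-512 FP16 type names, presumably still referenced
// after the AVX-512 headers above are suppressed. #ifndef only detects macros,
// and _Float16 / __m128h / __m256h are normally a keyword/typedefs rather than
// macros, so these #defines will normally always take effect.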
#ifndef _Float16
#define _Float16 float
#endif

#ifndef __m128h
#define __m128h __m128i
#endif

#ifndef __m256h
#define __m256h __m256i
#endif

#include "Otp.h"
#include "Internal.h"
#include "Enclave.h"
#include "Enclave_t.h"
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <sgx_trts.h>
#include <sgx_thread.h>
#include <pthread.h>
#include <immintrin.h>
#include <omp.h>


// Transform state defined elsewhere and read by ecall_perform_otp below.
extern VectorListList g_v_list;
extern IndicesList g_indices_list_inv;

// Pads installed by ecall_prepare_otp_params. They start out empty: every
// dimension is zero and the buffer pointers are NULL.
Otp g_otp = {0};
OtpLogits g_otp_logits = {0};


/*
 * Byte size of a dense d0 x d1 x d2 float array, or 0 when any dimension is
 * zero or the size does not fit in size_t. The result is kept at least 31
 * bytes below SIZE_MAX so it can be rounded up to a multiple of 32 (required
 * by C11 aligned_alloc) without wrapping.
 */
static size_t otp_float_tensor_bytes(size_t d0, size_t d1, size_t d2)
{
    const size_t max_elems = (SIZE_MAX - 31) / sizeof(float);

    if (d0 == 0 || d1 == 0 || d2 == 0) {
        return 0;
    }
    if (d1 > max_elems / d0) {
        return 0;
    }
    if (d2 > max_elems / (d0 * d1)) {
        return 0;
    }
    return d0 * d1 * d2 * sizeof(float);
}

/*
 * Allocates a 32-byte-aligned float buffer of at least `bytes` bytes. The
 * request is rounded up to a multiple of 32 because C11 aligned_alloc requires
 * the size to be a multiple of the alignment. `bytes` must come from
 * otp_float_tensor_bytes, which guarantees the rounding cannot wrap.
 */
static float* otp_alloc_aligned32(size_t bytes)
{
    return (float*) aligned_alloc(32, (bytes + 31) & ~(size_t)31);
}

/*
 * Installs the pads used by ecall_perform_otp (g_otp) and
 * ecall_perform_logits_recover (g_otp_logits). The contents of otp.data and
 * otp_logits.logits are copied into freshly allocated 32-byte-aligned buffers.
 *
 * Any previously installed pads are released first. The update is
 * all-or-nothing. On an allocation failure both globals are left empty: all
 * dimensions are 0 and both buffers are NULL. The perform_* entry points then
 * reject requests with a dimension mismatch instead of dereferencing a NULL
 * buffer.
 *
 * Returns:
 *    0  success
 *   -4  non-positive dimension, NULL source buffer, or a size that overflows
 *       size_t (the globals are left untouched)
 *   -1  the otp buffer could not be allocated
 *   -2  the logits buffer could not be allocated
 */
int ecall_prepare_otp_params(Otp otp, OtpLogits otp_logits) {
    if (otp.batch_size <= 0 || otp.logits_to_keep <= 0 || otp.hidden_size <= 0 ||
        otp_logits.batch_size <= 0 || otp_logits.logits_to_keep <= 0 || otp_logits.vocab_size <= 0 ||
        otp.data == NULL || otp_logits.logits == NULL) {
        return -4;
    }

    // Reject sizes that overflow size_t. A wrapped product would under-allocate
    // and the memcpy calls below would then overrun the new buffers.
    const size_t otp_bytes =
        otp_float_tensor_bytes(otp.batch_size, otp.logits_to_keep, otp.hidden_size);
    const size_t logits_bytes =
        otp_float_tensor_bytes(otp_logits.batch_size, otp_logits.logits_to_keep, otp_logits.vocab_size);
    if (otp_bytes == 0 || logits_bytes == 0) {
        return -4;
    }

    // Release the previous pads and reset both globals to the empty state
    // before allocating, so that every failure path below leaves them
    // consistent (no dimensions advertised for a NULL buffer).
    free(g_otp.data);
    g_otp.data = NULL;
    g_otp.batch_size = 0;
    g_otp.logits_to_keep = 0;
    g_otp.hidden_size = 0;

    free(g_otp_logits.logits);
    g_otp_logits.logits = NULL;
    g_otp_logits.batch_size = 0;
    g_otp_logits.logits_to_keep = 0;
    g_otp_logits.vocab_size = 0;

    float* otp_data = otp_alloc_aligned32(otp_bytes);
    if (otp_data == NULL) {
        return -1;
    }
    float* logits_data = otp_alloc_aligned32(logits_bytes);
    if (logits_data == NULL) {
        free(otp_data);
        return -2;
    }

    memcpy(otp_data, otp.data, otp_bytes);
    memcpy(logits_data, otp_logits.logits, logits_bytes);

    // Publish only once both buffers are fully populated.
    g_otp.batch_size = otp.batch_size;
    g_otp.logits_to_keep = otp.logits_to_keep;
    g_otp.hidden_size = otp.hidden_size;
    g_otp.data = otp_data;

    g_otp_logits.batch_size = otp_logits.batch_size;
    g_otp_logits.logits_to_keep = otp_logits.logits_to_keep;
    g_otp_logits.vocab_size = otp_logits.vocab_size;
    g_otp_logits.logits = logits_data;

    return 0;
}

/*
 * Processes `hidden_states`, a dense row-major
 * batch_size x logits_to_keep x hidden_size float array, in four steps:
 *   1. Copy it into a private working buffer.
 *   2. For i = g_v_list.count - 1 down to 0, apply permutate_col with
 *      g_indices_list_inv.indices[i], then house_holder_cal_col with
 *      g_v_list.vectors_list[i].
 *   3. Add the installed pad g_otp element-wise.
 *   4. Copy the result to `output`, which has the same shape.
 *
 * Intermediate (pre-pad) values are only written to the private buffer.
 * `output` receives only the final padded result.
 *
 * Returns:
 *    0  success
 *   -4  NULL pointer, zero dimension, or an element count overflowing size_t
 *   -5  dimensions differ from the installed pad, or no pad is installed
 *   -1  the working buffer could not be allocated
 *
 * NOTE(review): assumes g_indices_list_inv holds at least g_v_list.count
 * entries. That is not checked here.
 */
int ecall_perform_otp(float* hidden_states, size_t batch_size, size_t logits_to_keep, size_t hidden_size, float* output) {
    if (hidden_states == NULL || output == NULL || batch_size == 0 || logits_to_keep == 0 || hidden_size == 0) {
        return -4;
    }

    // Refuse to run without an installed pad rather than dereferencing NULL
    // in the addition loop below.
    if (g_otp.data == NULL ||
        batch_size != g_otp.batch_size || logits_to_keep != g_otp.logits_to_keep || hidden_size != g_otp.hidden_size) {
        return -5;
    }

    // Overflow-checked element count and byte size. 31 bytes of headroom are
    // kept so the allocation can be rounded up to a multiple of 32, which C11
    // aligned_alloc requires.
    if (logits_to_keep > SIZE_MAX / batch_size) {
        return -4;
    }
    size_t count = batch_size * logits_to_keep;
    if (hidden_size > SIZE_MAX / count) {
        return -4;
    }
    count *= hidden_size;
    if (count > (SIZE_MAX - 31) / sizeof(float)) {
        return -4;
    }
    const size_t bytes = count * sizeof(float);

    float* enclave_output = (float*) aligned_alloc(32, (bytes + 31) & ~(size_t)31);
    if (enclave_output == NULL) {
        return -1;
    }

    memcpy(enclave_output, hidden_states, bytes);

    start_clock();
    // Walk the transform lists from the last entry back to the first.
    for (size_t i = g_v_list.count; i > 0; --i) {
        permutate_col(enclave_output, batch_size, logits_to_keep, hidden_size, g_indices_list_inv.indices[i - 1]);
        house_holder_cal_col(enclave_output, batch_size, logits_to_keep, hidden_size, g_v_list.vectors_list[i - 1]);
    }

    // Add the pad. Both buffers are dense row-major arrays with identical
    // dimensions, so one flat loop visits every element exactly once.
    for (size_t k = 0; k < count; ++k) {
        enclave_output[k] += g_otp.data[k];
    }

    end_clock("----------------------------------------------------\nSgx cost(otp): %.6f milliseconds\n----------------------------------------------------\n");

    memcpy(output, enclave_output, bytes);
    free(enclave_output);

    return 0;
}


/*
 * Treats `hidden_states` as dense row-major
 * batch_size x logits_to_keep x hidden_size logits. Despite its name,
 * `hidden_size` here is the vocabulary size and must equal
 * g_otp_logits.vocab_size.
 *
 * The function subtracts the installed pad g_otp_logits element-wise. For
 * every (b, l) row it writes the index of the largest recovered logit to
 * output[b * logits_to_keep + l]. The scan uses a strict '>', so ties resolve
 * to the lowest index. `output` must hold batch_size * logits_to_keep entries.
 *
 * Each input element is read exactly once, and the recovered logits live only
 * in locals, so no full-size temporary buffer is needed.
 *
 * Returns:
 *    0  success
 *   -4  NULL pointer, zero dimension, or an element count overflowing size_t
 *   -5  dimensions differ from the installed pad, or no pad is installed
 */
int ecall_perform_logits_recover(float* hidden_states, size_t batch_size, size_t logits_to_keep, size_t hidden_size, int64_t* output) {
    if (hidden_states == NULL || output == NULL || batch_size == 0 || logits_to_keep == 0 || hidden_size == 0) {
        return -4;
    }

    // Refuse to run without an installed pad rather than dereferencing NULL
    // in the scan below.
    if (g_otp_logits.logits == NULL ||
        batch_size != g_otp_logits.batch_size || logits_to_keep != g_otp_logits.logits_to_keep || hidden_size != g_otp_logits.vocab_size) {
        return -5;
    }

    // Reject dimensions whose element count overflows size_t. Row offsets
    // below are computed as row * hidden_size.
    if (logits_to_keep > SIZE_MAX / batch_size) {
        return -4;
    }
    const size_t rows = batch_size * logits_to_keep;
    if (hidden_size > SIZE_MAX / rows) {
        return -4;
    }

    start_clock();

    // Row r corresponds to (b, l) with r = b * logits_to_keep + l. Remove the
    // pad and track the argmax in a single pass over the row.
    for (size_t r = 0; r < rows; ++r) {
        const float* in = hidden_states + r * hidden_size;
        const float* pad = g_otp_logits.logits + r * hidden_size;

        int64_t max_idx = 0;
        float max_val = in[0] - pad[0];
        for (size_t h = 1; h < hidden_size; ++h) {
            const float v = in[h] - pad[h];
            if (v > max_val) {
                max_val = v;
                max_idx = (int64_t) h;
            }
        }
        output[r] = max_idx;
    }

    end_clock("----------------------------------------------------\nSgx cost(logits recover): %.6f milliseconds\n----------------------------------------------------\n");

    return 0;
}