#include "Enclave.h"
#include "Enclave_t.h"
#include <cstddef>
#include <stdlib.h>
#include <string.h>
#include <sgx_trts.h>

#include "DataStructure.h"


// Obfuscation parameter sets, one slot per ObfParamType value (indexed by the
// integer param_type). ecall_prepare_obf_params() replaces a slot with heap
// copies of the buffers it is given and frees whatever the slot held before.
// NOTE(review): the "params != NULL" check on first use presumes these globals
// start out zero-initialized — confirm ObfParamArray has no user constructor.
ObfParamArray g_obf_params[ObfParamType::OBF_COUNT];
// Normalization parameter sets, one slot per NormParamType value (indexed by
// the integer param_type). Filled and replaced by ecall_prepare_norm_params().
NormParamArray g_norm_params[NormParamType::NORM_COUNT];


int ecall_prepare_obf_params(int param_type, ObfParamArray params) {
    if (param_type < 0 || param_type >= ObfParamType::OBF_COUNT) {
        return -1;
    }
    if (params.count <= 0 || params.params == NULL) {
        return -2;
    }

    if (g_obf_params[param_type].params != NULL) {
        for (size_t i = 0; i < g_obf_params[param_type].count; i++) {
            if (g_obf_params[param_type].params[i].mask != NULL) {
                free(g_obf_params[param_type].params[i].mask);
            }
            if (g_obf_params[param_type].params[i].ratio_mask != NULL) {
                free(g_obf_params[param_type].params[i].ratio_mask);
            }
            if (g_obf_params[param_type].params[i].ratio_w != NULL) {
                free(g_obf_params[param_type].params[i].ratio_w);
            }
            if (g_obf_params[param_type].params[i].permutation != NULL) {
                free(g_obf_params[param_type].params[i].permutation);
            }
        }
        free(g_obf_params[param_type].params);
    }

    g_obf_params[param_type].count = params.count;
    g_obf_params[param_type].params = (ObfParam*)aligned_alloc(32, params.count * sizeof(ObfParam));
    for (size_t i = 0; i < params.count; i++) {
        g_obf_params[param_type].params[i].mask_size = params.params[i].mask_size;
        g_obf_params[param_type].params[i].ratio_size = params.params[i].ratio_size;

        g_obf_params[param_type].params[i].mask = (float*)aligned_alloc(32, params.params[i].mask_size * sizeof(float));
        memcpy(g_obf_params[param_type].params[i].mask, params.params[i].mask, params.params[i].mask_size * sizeof(float));
        
        g_obf_params[param_type].params[i].ratio_mask = (float*)aligned_alloc(32, params.params[i].ratio_size * sizeof(float));
        memcpy(g_obf_params[param_type].params[i].ratio_mask, params.params[i].ratio_mask, params.params[i].ratio_size * sizeof(float));
        g_obf_params[param_type].params[i].ratio_w = (float*)aligned_alloc(32, params.params[i].ratio_size * sizeof(float));
        memcpy(g_obf_params[param_type].params[i].ratio_w, params.params[i].ratio_w, params.params[i].ratio_size * sizeof(float));
        g_obf_params[param_type].params[i].permutation = (int*)aligned_alloc(32, params.params[i].ratio_size * sizeof(int));
        memcpy(g_obf_params[param_type].params[i].permutation, params.params[i].permutation, params.params[i].ratio_size * sizeof(int));
    }

    return 0;
}

int ecall_prepare_norm_params(int param_type, NormParamArray params) {
    if (param_type < 0 || param_type >= NormParamType::NORM_COUNT) {
        return -1;
    }
    if (params.count <= 0 || params.params == nullptr) {
        return -2;
    }

    if (g_norm_params[param_type].params != nullptr) {
        for (size_t i = 0; i < g_norm_params[param_type].count; i++) {
            if (g_norm_params[param_type].params[i].weight != NULL) {
                free(g_norm_params[param_type].params[i].weight);
            }
            if (g_norm_params[param_type].params[i].bias != NULL) {
                free(g_norm_params[param_type].params[i].bias);
            }
        }
        free(g_norm_params[param_type].params);
    }

    g_norm_params[param_type].count = params.count;
    g_norm_params[param_type].params = (NormParam*)aligned_alloc(32, params.count * sizeof(NormParam));
    for (size_t i = 0; i < params.count; i++) {
        g_norm_params[param_type].params[i].size = params.params[i].size;
        g_norm_params[param_type].params[i].weight = (float*)aligned_alloc(32, params.params[i].size * sizeof(float));
        memcpy(g_norm_params[param_type].params[i].weight, params.params[i].weight, params.params[i].size * sizeof(float));
        if (params.params[i].bias != nullptr) {
            g_norm_params[param_type].params[i].bias = (float*)aligned_alloc(32, params.params[i].size * sizeof(float));
            memcpy(g_norm_params[param_type].params[i].bias, params.params[i].bias, params.params[i].size * sizeof(float));
        } else {
            g_norm_params[param_type].params[i].bias = nullptr;
        }
        g_norm_params[param_type].params[i].eps = params.params[i].eps;
    }

    return 0;
}

