#include "Enclave.h"
#include "Enclave_t.h"
#include <cstddef>
#include <stdlib.h>
#include <string.h>
#include <sgx_trts.h>

#include "DataStructure.h"


// Enclave-owned parameter tables, one slot per parameter type. Each slot is
// (re)populated by the matching ecall_prepare_*_params function in this file.
// Objects with static storage duration are zero-initialised; the explicit `{}`
// makes that empty starting state (count 0, params null) visible at the
// definition.
ObfParamArray g_obf_params[ObfParamType::OBF_COUNT] = {};
NormParamArray g_norm_params[NormParamType::NORM_COUNT] = {};


int ecall_prepare_obf_params(int param_type, ObfParamArray params) {
    if (param_type < 0 || param_type >= ObfParamType::OBF_COUNT) {
        return -1;
    }
    if (params.count <= 0 || params.params == NULL) {
        return -2;
    }

    if (g_obf_params[param_type].params != NULL) {
        for (size_t i = 0; i < g_obf_params[param_type].count; i++) {
            if (g_obf_params[param_type].params[i].permutation != NULL) {
                free(g_obf_params[param_type].params[i].permutation);
            }
            if (g_obf_params[param_type].params[i].blocks != NULL) {
                for (size_t j = 0; j < g_obf_params[param_type].params[i].block_count; j++) {
                    if (g_obf_params[param_type].params[i].blocks[j].data != NULL) {
                        free(g_obf_params[param_type].params[i].blocks[j].data);
                    }
                }
                free(g_obf_params[param_type].params[i].blocks);
            }
        }
        free(g_obf_params[param_type].params);
    }

    g_obf_params[param_type].count = params.count;
    g_obf_params[param_type].params = (ObfParam*)aligned_alloc(32, params.count * sizeof(ObfParam));
    for (size_t i = 0; i < params.count; i++) {
        g_obf_params[param_type].params[i].perm_size = params.params[i].perm_size;
        g_obf_params[param_type].params[i].block_count = params.params[i].block_count;
        g_obf_params[param_type].params[i].blocks = (Block*)aligned_alloc(32, params.params[i].block_count * sizeof(Block));
        for (size_t j = 0; j < params.params[i].block_count; j++) {
            g_obf_params[param_type].params[i].blocks[j].size = params.params[i].blocks[j].size;
            g_obf_params[param_type].params[i].blocks[j].data = (float*)aligned_alloc(32, params.params[i].blocks[j].size * params.params[i].blocks[j].size * sizeof(float));
            memcpy(g_obf_params[param_type].params[i].blocks[j].data, params.params[i].blocks[j].data, params.params[i].blocks[j].size * params.params[i].blocks[j].size * sizeof(float));
        }
        g_obf_params[param_type].params[i].permutation = (int*)aligned_alloc(32, params.params[i].perm_size * sizeof(int));
        memcpy(g_obf_params[param_type].params[i].permutation, params.params[i].permutation, params.params[i].perm_size * sizeof(int));
    }

    return 0;
}

int ecall_prepare_norm_params(int param_type, NormParamArray params) {
    if (param_type < 0 || param_type >= NormParamType::NORM_COUNT) {
        return -1;
    }
    if (params.count <= 0 || params.params == nullptr) {
        return -2;
    }

    if (g_norm_params[param_type].params != nullptr) {
        for (size_t i = 0; i < g_norm_params[param_type].count; i++) {
            if (g_norm_params[param_type].params[i].weight != NULL) {
                free(g_norm_params[param_type].params[i].weight);
            }
            if (g_norm_params[param_type].params[i].bias != NULL) {
                free(g_norm_params[param_type].params[i].bias);
            }
        }
        free(g_norm_params[param_type].params);
    }

    g_norm_params[param_type].count = params.count;
    g_norm_params[param_type].params = (NormParam*)aligned_alloc(32, params.count * sizeof(NormParam));
    for (size_t i = 0; i < params.count; i++) {
        g_norm_params[param_type].params[i].size = params.params[i].size;
        g_norm_params[param_type].params[i].weight = (float*)aligned_alloc(32, params.params[i].size * sizeof(float));
        memcpy(g_norm_params[param_type].params[i].weight, params.params[i].weight, params.params[i].size * sizeof(float));
        if (params.params[i].bias != nullptr) {
            g_norm_params[param_type].params[i].bias = (float*)aligned_alloc(32, params.params[i].size * sizeof(float));
            memcpy(g_norm_params[param_type].params[i].bias, params.params[i].bias, params.params[i].size * sizeof(float));
        } else {
            g_norm_params[param_type].params[i].bias = nullptr;
        }
        g_norm_params[param_type].params[i].eps = params.params[i].eps;
    }

    return 0;
}

