#ifndef _DATASTRUCTURE_H_
#define _DATASTRUCTURE_H_

#include <stdint.h>
#include <stddef.h>

/**
 * Parameter bundle for a single obfuscation entry.
 *
 * NOTE(review): the role of each buffer below is inferred from field names
 * only; confirm against the code that fills and consumes this struct.
 * This header does not establish who allocates or frees the buffers.
 */
typedef struct ObfParam {
    size_t mask_size;   /* element count, presumably of `mask` -- confirm */
    size_t ratio_size;  /* element count, presumably of `ratio_mask`/`ratio_w` -- confirm */
    float *mask;
    float *ratio_mask;
    float *ratio_w;
    int *permutation;   /* length is not stored here; TODO confirm which size bounds it */
} ObfParam;

/**
 * Counted collection of ObfParam entries.
 *
 * NOTE(review): presumably `params` points at `count` contiguous elements;
 * this is inferred from the field pairing and should be confirmed.
 */
typedef struct ObfParamArray {
    size_t count;       /* number of entries reachable through `params` (assumed) */
    ObfParam *params;
} ObfParamArray;

/**
 * Parameters of a single normalization layer.
 *
 * NOTE(review): inferred from field names only. Presumably `weight` and
 * `bias` each hold `size` floats and `eps` is the norm's stabilizing
 * epsilon; confirm against producers and consumers. This header does not
 * establish buffer ownership.
 */
typedef struct NormParam {
    size_t size;        /* element count, presumably of `weight`/`bias` -- confirm */
    float *weight;
    float *bias;
    float eps;
} NormParam;

/**
 * Counted collection of NormParam entries.
 *
 * NOTE(review): presumably `params` points at `count` contiguous elements;
 * this is inferred from the field pairing and should be confirmed.
 */
typedef struct NormParamArray {
    size_t count;       /* number of entries reachable through `params` (assumed) */
    NormParam *params;
} NormParamArray;

/**
 * Kinds of normalization layer.
 *
 * Values are explicit and contiguous from 0, so NORM_COUNT equals the number
 * of real kinds. NOTE(review): whether callers use these values to index
 * per-kind tables is not visible here. The names appear to mirror model
 * checkpoint tensor names (e.g. ln_1/ln_2/ln_f next to *_layernorm) -- confirm.
 */
enum NormParamType {
    INPUT_LAYERNORM            = 0,
    POST_ATTENTION_LAYERNORM   = 1,
    Q_NORM                     = 2,
    K_NORM                     = 3,
    LN_1                       = 4,
    LN_2                       = 5,
    LN_F                       = 6,
    NORM                       = 7,
    PRE_FEEDFORWARD_LAYERNORM  = 8,
    POST_FEEDFORWARD_LAYERNORM = 9,

    /* Sentinel: number of kinds listed above. Keep it last and keep it
     * equal to (largest value + 1) when adding kinds. */
    NORM_COUNT                 = 10,
};


/**
 * Kinds of projection that an ObfParam can be attached to.
 *
 * Values are explicit and contiguous from 0, so OBF_COUNT equals the number
 * of real kinds. NOTE(review): whether callers use these values to index
 * per-kind tables is not visible here. The names appear to mirror model
 * checkpoint tensor names (attention q/k/v/o, c_fc/c_proj, gate/up/down
 * MLP projections) -- confirm.
 */
enum ObfParamType {
    Q_PROJ    = 0,
    K_PROJ    = 1,
    V_PROJ    = 2,
    O_PROJ    = 3,
    C_FC      = 4,
    C_PROJ    = 5,
    GATE_PROJ = 6,
    UP_PROJ   = 7,
    DOWN_PROJ = 8,

    /* Sentinel: number of kinds listed above. Keep it last and keep it
     * equal to (largest value + 1) when adding kinds. */
    OBF_COUNT = 9,
};

#endif /* !_DATASTRUCTURE_H_ */
