#ifndef GEMM_MASK_FLOAT_H
#define GEMM_MASK_FLOAT_H

/*
 * Element-access macros. Each one indexes a flat array in row-major order:
 * element (i, j) is found at offset i * <leading dimension> + j.
 *
 * These macros are not self-contained: they expand to the bare identifiers
 * a/lda, b/ldb, c/ldc and mask/ldm, so variables with exactly those names
 * must be in scope wherever a macro is used.
 *
 * NOTE(review): the short names A, B and C are defined before the #include
 * lines of this header, so they are also active while those headers are
 * parsed -- confirm that no included header is affected.
 */

#define A(i, j) a[(i) * lda + (j)]
#define B(i, j) b[(i) * ldb + (j)]
#define C(i, j) c[(i) * ldc + (j)]
#define Mask(i, j) mask[(i) * ldm + (j)]

/*
 * Block sizes.
 * NOTE(review): presumably the tile extents used for blocking along the m
 * and k dimensions (mc and kc respectively) -- confirm against the
 * implementation file.
 */
#define mc 256
#define kc 128
/* Alternative, larger block sizes; currently disabled. */
// #define mc 512
// #define kc 256

/*
 * Smaller of i and j.
 * Both arguments are parenthesized, but the selected argument is evaluated
 * twice, so do not pass expressions with side effects (e.g. min(x++, y)).
 */
#define min(i, j) ((i) < (j) ? (i) : (j))

#include <immintrin.h> // AVX/AVX2
#include <stdlib.h>
#include <cblas.h>
#include <stdbool.h>

/*
 * Routine for computing C = alpha * A * B.T + beta * C.
 *
 * Declaration only; the definition is not part of this header.
 *
 *   a, lda / b, ldb / c, ldc  float arrays with their leading dimensions
 *                             (the names match what the A/B/C macros expect)
 *   mask, ldm                 int array with its leading dimension
 *   count                     int pointer; NOTE(review): its role is not
 *                             visible from the prototype -- confirm
 *   m, n, k                   presumably the extents of the product -- confirm
 *   alpha, beta               the scalars of the formula above
 *   i_start, j_start          presumably the row/column offset of this call
 *                             within a larger problem -- confirm
 */

void matmul_kernel(float *a, int lda, float *b, int ldb, float *c, int ldc, int *mask, int ldm, int *count, 
    int m, int n, int k, float alpha, float beta, int i_start, int j_start);

/*
 * Declaration only; the definition is not part of this header.
 * Takes two float arrays (X, Cen), an int mask, three int sizes
 * (num, dim, clu) and a float pointer D.
 * NOTE(review): the names suggest X holds num rows of dim floats, Cen holds
 * clu rows of dim floats, and D receives a num x clu result -- confirm
 * against the implementation before relying on this.
 */
void masked_gemm_hybrid(float *X, float *Cen, int *mask, int num, int dim, int clu, float *D);
/*
 * Mask-taking counterpart of gemm_mul_p: the leading parameters
 * (m, n, k, alpha, a/lda, b/ldb, c/ldc) are the same, followed by an int
 * mask with leading dimension ldm, an int pointer count, and a scalar beta.
 * NOTE(review): how mask and count are used is not visible from the
 * prototype -- confirm against the implementation.
 */
void gemm_mul_p_mask  (int m, int n, int k, float alpha, float *a, int lda, float *b, int ldb, float *c, int ldc, int *mask, int ldm, int *count, float beta);
/*
 * Mask-taking counterpart of macro_kernel: same leading parameters
 * (m, n, k, alpha, a/lda, b/ldb, c/ldc) plus mask/ldm, count and beta.
 * The parameter list is identical to that of gemm_mul_p_mask.
 * NOTE(review): presumably processes one block of the blocked product --
 * confirm against the implementation.
 */
void macro_kernel_mask(int m, int n, int k, float alpha, float *a, int lda, float *b, int ldb, float *c, int ldc, int *mask, int ldm, int *count, float beta);

/*
 * Mask-taking counterpart of kernel_ori: same leading parameters
 * (a/lda, b/ldb, c/ldc, k, alpha) plus an int mask, two ints i and j,
 * an int pointer count and a scalar beta.
 * Note that no leading dimension for mask (ldm) is passed here.
 * NOTE(review): i and j are presumably the position of the tile being
 * processed -- confirm against the implementation.
 */
void kernel_ori_mask(float *a, int lda, float *b, int ldb, float *c, int ldc, int k, float alpha, int *mask, int i, int j, int *count, float beta);
/*
 * Variant of the masked kernel taking a/lda, b/ldb, c/ldc, k, alpha, an int
 * mask and beta; unlike kernel_ori_mask it has no i, j or count parameters.
 * NOTE(review): the "_le8" suffix presumably refers to a mask-dependent
 * case of at most 8 -- confirm against the implementation.
 */
void kernel_ori_mask_le8(float *a, int lda, float *b, int ldb, float *c, int ldc, int k, float alpha, int *mask, float beta);
/*
 * Variant of the masked kernel with the same parameter list as
 * kernel_ori_mask_le8 (a/lda, b/ldb, c/ldc, k, alpha, mask, beta).
 * NOTE(review): the "_8" suffix presumably refers to a mask-dependent
 * case of exactly 8 -- confirm against the implementation.
 */
void kernel_ori_mask_8(float *a, int lda, float *b, int ldb, float *c, int ldc, int k, float alpha, int *mask, float beta);

/*
 * Unmasked GEMM routine taking the sizes m, n, k, the scalar alpha and the
 * three float matrices with their leading dimensions. There is no beta
 * parameter in this variant.
 * The per-parameter notes below mark a and b as "aligned".
 * NOTE(review): the required alignment is not stated here -- confirm
 * against the implementation.
 */
void gemm_mul_p(int m, int n, int k, float alpha, 
          float *a, int lda,   // matrix A, aligned
          float *b, int ldb,   // matrix B, aligned
          float *c, int ldc);  // matrix C

/*
 * Unmasked GEMM routine with the same parameter list as gemm_mul_p:
 * sizes m, n, k, scalar alpha, and a/lda, b/ldb, c/ldc.
 * NOTE(review): how it differs from gemm_mul_p is not visible from the
 * prototype -- confirm against the implementation.
 */
void gemm(int m, int n, int k, float alpha, float *a, int lda, float *b, int ldb, float *c, int ldc);

/*
 * Unmasked macro kernel: sizes m, n, k, scalar alpha, and a/lda, b/ldb,
 * c/ldc.
 * NOTE(review): presumably processes one block of the blocked product --
 * confirm against the implementation.
 */
void macro_kernel(int m, int n, int k, float alpha, float *a, int lda, float *b, int ldb, float *c, int ldc);

/*
 * Unmasked micro kernel. Its first two int parameters are named i and j
 * (not m and n as in macro_kernel), followed by k, alpha and a/lda, b/ldb,
 * c/ldc.
 * NOTE(review): i and j are presumably the position of the tile to
 * compute -- confirm against the implementation.
 */
void micro_kernel(int i, int j, int k, float alpha, float *a, int lda, float *b, int ldb, float *c, int ldc);
/*
 * Innermost unmasked kernel: takes the three float matrices with their
 * leading dimensions, the int k and the scalar alpha.
 * NOTE(review): the tile shape it handles is not visible from the
 * prototype -- confirm against the implementation.
 */
void kernel_ori(
    float *a, int lda, 
    float *b, int ldb, 
    float *c, int ldc, 
    /* Disabled parameters (eight b and eight c __m256 vectors), kept for reference: */
    // __m256 b0, __m256 b1, __m256 b2, __m256 b3, __m256 b4, __m256 b5, __m256 b6, __m256 b7,
    // __m256 c0, __m256 c1, __m256 c2, __m256 c3, __m256 c4, __m256 c5, __m256 c6, __m256 c7,
    int k, float alpha);

/*
 * Takes two groups of eight __m256 vectors, c0..c7 and b0..b7.
 * NOTE(review): the name suggests an 8x8 float transpose between the two
 * groups -- confirm which group is the input.
 *
 * NOTE(review): all sixteen vectors are passed by value and the function
 * returns void, so no result can reach the caller through these
 * parameters. If one group is meant to be an output it would have to be
 * passed by pointer -- check the definition and its callers.
 */
void transpose_8x8_avx(
    __m256 c0, __m256 c1, __m256 c2, __m256 c3, __m256 c4, __m256 c5, __m256 c6, __m256 c7,
    __m256 b0, __m256 b1, __m256 b2, __m256 b3, __m256 b4, __m256 b5, __m256 b6, __m256 b7);

#endif