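/*
 * AVX2/FMA single-precision GEMM kernels (C = alpha * A * B^T [+ beta * C]),
 * including a variant that evaluates only mask-selected columns.
 */

/* The matrix accessor macros A, B, C and Mask used below are not defined in
 * this file; presumably they come from gemm_mask_float.h and index row-major
 * storage through a/lda, b/ldb, c/ldc and mask/ldm. */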
#include "gemm_mask_float.h"
#include "utils.h"

#include <stddef.h>
#include <stdio.h>

// Masked GEMM: C = alpha * A x B.T + beta * C, evaluated only at the B-row /
// C-column indices stored in `mask`.
//
// A: row major, m x k, lda = k; k must be a multiple of 8 (AVX width).
// B: row major, n x k, ldb = k; rows are addressed through the mask entries.
// C: row major, m x n, ldc = n.
// mask: for each row i, groups of 8 indices starting at Mask(i, j), j a
//       multiple of 8; kernel_ori_mask treats -1 as "no more valid entries".
// count: per-row value that kernel_ori_mask compares against the column
//        offset j to skip groups past the end of the row.
//
// On invalid input (n or k not a multiple of 8, or A/B not 32-byte aligned)
// a message is printed to stderr and C is left untouched. Continuing would
// issue aligned AVX loads on misaligned rows (a hardware fault) or read mask
// groups that run past the end of a row.
void gemm_mul_p_mask(int m, int n, int k, float alpha, 
          float *a, int lda,   // matrix A, aligned
          float *b, int ldb,   // matrix B, aligned
          float *c, int ldc,   // matrix C
          int *mask, int ldm, int *count, float beta) {

    if (n % 8 != 0 || k % 8 != 0) {
        fprintf(stderr, "n or k is not a multiple of 8\n");
        return;
    }

    if (!aligned(a) || !aligned(b)) {
        fprintf(stderr, "input arr a or b is not aligned\n");
        return;
    }

    // Tile sizes: rows of A/C by mask columns.
    enum { ROW_BLOCK_A = 64, ROW_BLOCK_B = 64 };

    for (int i_block = 0; i_block < m; i_block += ROW_BLOCK_A) {
        const int i_end = (i_block + ROW_BLOCK_A < m) ? i_block + ROW_BLOCK_A : m;

        for (int j_block = 0; j_block < n; j_block += ROW_BLOCK_B) {
            const int j_end = (j_block + ROW_BLOCK_B < n) ? j_block + ROW_BLOCK_B : n;

            for (int i = i_block; i < i_end; i++) {
                // Each call handles one group of 8 mask entries of row i.
                for (int j = j_block; j < j_end; j += 8) {
                    kernel_ori_mask(&A(i, 0), lda, &B(0, 0), ldb, &C(i, 0), ldc,
                                    k, alpha, &Mask(i, j), i, j, count, beta);
                }
            }
        }
    }
}

// Masked kernel driver over one tile.
//
// A: m x k, B: n x k, C: alpha * A * B.T + beta * C (computed per group of 8
// mask entries by kernel_ori_mask).
// Visits rows [i_start, i_start + m) and, for each row, mask columns
// [j_start, j_start + n) in steps of 8.
void matmul_kernel(float *a, int lda, float *b, int ldb, float *c, int ldc, int *mask, int ldm, int *count, 
    int m, int n, int k, float alpha, float beta, int i_start, int j_start){

    const int row_stop = i_start + m;
    const int col_stop = j_start + n;

    for (int row = i_start; row < row_stop; ++row) {
        for (int col = j_start; col < col_stop; col += 8) {
            kernel_ori_mask(&A(row, 0), lda, &B(0, 0), ldb, &C(row, 0), ldc,
                            k, alpha, &Mask(row, col), row, col, count, beta);
        }
    }
}

// Dispatches one group of 8 mask entries of row i, starting at column j.
//
// The group is skipped when its first entry is -1 or when j exceeds count[i].
// If it is skipped while j is still below count[i], a diagnostic is printed,
// because that combination looks inconsistent.
// Otherwise a group whose last entry is -1 (partially filled) goes to
// kernel_ori_mask_le8, and a full group goes to kernel_ori_mask_8.
void kernel_ori_mask(
    float *a, int lda, 
    float *b, int ldb, 
    float *c, int ldc, 
    int k, float alpha, int *mask, int i, int j, int *count, float beta) {

    const int limit = count[i];
    const int skip = (mask[0] == -1) || (j > limit);

    if (skip) {
        if (j < limit) {
            printf("count[%d] = %d, mask[%d, %d] = -1\t", i, limit, i, j);
        }
        return;
    }

    if (mask[7] != -1) {
        kernel_ori_mask_8(a, lda, b, ldb, c, ldc, k, alpha, mask, beta);
    } else {
        kernel_ori_mask_le8(a, lda, b, ldb, c, ldc, k, alpha, mask, beta);
    }
}

// Masked dot-product kernel for a partially filled group of at most 8 entries.
//
// For t = 0, 1, ... while t < 8 and mask[t] >= 0:
//   c[mask[t]] = beta * c[mask[t]] + alpha * dot(a[0 .. k), b[mask[t] * ldb .. + k))
//
// `a` and every referenced row of `b` must be 32-byte aligned, because
// _mm256_load_ps is an aligned load. k is consumed 8 floats at a time and any
// remainder k % 8 is ignored. lda and ldc are not used; they are kept so the
// signature matches kernel_ori_mask_8.
//
// The accumulators are a fixed 8-element array instead of VLAs sized by the
// number of valid entries. A VLA of length 0 (mask[0] < 0) is undefined
// behavior, and the upper bound is known to be 8 anyway.
void kernel_ori_mask_le8(
    float *a, int lda, 
    float *b, int ldb, 
    float *c, int ldc, 
    int k, float alpha, int *mask, float beta) {

    // Number of leading valid (non-negative) indices in this group.
    int num_b = 0;
    while (num_b < 8 && mask[num_b] >= 0) {
        num_b++;
    }
    if (num_b == 0) {
        return;  // nothing to update
    }

    __m256 c_vec[8];
    for (int i = 0; i < num_b; i++) {
        c_vec[i] = _mm256_setzero_ps();
    }

    for (int p = 0; p + 8 <= k; p += 8) {
        const __m256 a0 = _mm256_load_ps(a + p);
        for (int i = 0; i < num_b; i++) {
            const __m256 b_vec = _mm256_load_ps(b + mask[i] * ldb + p);
            c_vec[i] = _mm256_fmadd_ps(a0, b_vec, c_vec[i]);
        }
    }

    // Horizontal sum of each accumulator (lanes added left to right), then
    // scale by alpha and blend with the existing C value.
    for (int i = 0; i < num_b; i++) {
        const int c_id = mask[i];
        const float dot = c_vec[i][0] + c_vec[i][1] + c_vec[i][2] + c_vec[i][3]
                        + c_vec[i][4] + c_vec[i][5] + c_vec[i][6] + c_vec[i][7];
        c[c_id] = beta * c[c_id] + alpha * dot;
    }
}


// Masked dot-product kernel for a full group of 8 mask entries.
//
// For t = 0..7:
//   c[mask[t]] = beta * c[mask[t]] + alpha * dot(a[0 .. k), b[mask[t] * ldb .. + k))
//
// All 8 mask entries are used as row indices into b, so they must all be
// valid (non-negative). kernel_ori_mask only calls this when mask[7] != -1.
// `a` and every referenced row of `b` must be 32-byte aligned, because
// _mm256_load_ps is an aligned load. k is consumed 8 floats at a time and any
// remainder k % 8 is ignored. lda and ldc are not used.
void kernel_ori_mask_8(
    float *a, int lda, 
    float *b, int ldb, 
    float *c, int ldc, 
    int k, float alpha, int *mask, float beta) {

    __m256 b0, b1, b2, b3, b4, b5, b6, b7,
           c0, c1, c2, c3, c4, c5, c6, c7;

    __m256 a0;

    // Zero the 8 accumulators (one per mask entry).
    c0  = _mm256_setzero_ps();
    c1  = _mm256_setzero_ps();
    c2  = _mm256_setzero_ps();
    c3  = _mm256_setzero_ps();
    c4  = _mm256_setzero_ps();
    c5  = _mm256_setzero_ps();
    c6  = _mm256_setzero_ps();
    c7  = _mm256_setzero_ps();

    // Lane-wise partial dot products: ct accumulates a * b[mask[t]] over k.
    int p;
    for (p = 0; p + 8 <= k; p += 8) {
        b0  = _mm256_load_ps((float *)(b + mask[0] * ldb));
        b1  = _mm256_load_ps((float *)(b + mask[1] * ldb));
        b2  = _mm256_load_ps((float *)(b + mask[2] * ldb));
        b3  = _mm256_load_ps((float *)(b + mask[3] * ldb));
        b4  = _mm256_load_ps((float *)(b + mask[4] * ldb));
        b5  = _mm256_load_ps((float *)(b + mask[5] * ldb));
        b6  = _mm256_load_ps((float *)(b + mask[6] * ldb));
        b7  = _mm256_load_ps((float *)(b + mask[7] * ldb));

        a0 = _mm256_load_ps((float *)(a));

        c0  = _mm256_fmadd_ps(a0, b0,  c0);
        c1  = _mm256_fmadd_ps(a0, b1,  c1);
        c2  = _mm256_fmadd_ps(a0, b2,  c2);
        c3  = _mm256_fmadd_ps(a0, b3,  c3);
        c4  = _mm256_fmadd_ps(a0, b4,  c4);
        c5  = _mm256_fmadd_ps(a0, b5,  c5);
        c6  = _mm256_fmadd_ps(a0, b6,  c6);
        c7  = _mm256_fmadd_ps(a0, b7,  c7);

        // Advance along k; b stays the base pointer for the mask-indexed rows.
        a += 8;
        b += 8;
    }

    __m256 scalar_vec = _mm256_set1_ps(alpha); // broadcast alpha to all lanes

    // Horizontal reduction of c0..c7 via an 8x8 transpose:
    // after the transpose, lane t of (b0 + ... + b7) is the sum of all lanes of ct.
    // Step 1: interleave the low/high halves of row pairs.
    b0 = _mm256_unpacklo_ps(c0, c1); // [r00, r10, r01, r11, ...]
    b1 = _mm256_unpackhi_ps(c0, c1); // [r02, r12, r03, r13, ...]
    b2 = _mm256_unpacklo_ps(c2, c3);
    b3 = _mm256_unpackhi_ps(c2, c3);
    b4 = _mm256_unpacklo_ps(c4, c5);
    b5 = _mm256_unpackhi_ps(c4, c5);
    b6 = _mm256_unpacklo_ps(c6, c7);
    b7 = _mm256_unpackhi_ps(c6, c7);

    // Step 2: combine pairs within each 128-bit lane.
    c0 = _mm256_shuffle_ps(b0, b2, _MM_SHUFFLE(1, 0, 1, 0)); // [r00, r10, r20, r30, ...]
    c1 = _mm256_shuffle_ps(b0, b2, _MM_SHUFFLE(3, 2, 3, 2));
    c2 = _mm256_shuffle_ps(b1, b3, _MM_SHUFFLE(1, 0, 1, 0));
    c3 = _mm256_shuffle_ps(b1, b3, _MM_SHUFFLE(3, 2, 3, 2));
    c4 = _mm256_shuffle_ps(b4, b6, _MM_SHUFFLE(1, 0, 1, 0));
    c5 = _mm256_shuffle_ps(b4, b6, _MM_SHUFFLE(3, 2, 3, 2));
    c6 = _mm256_shuffle_ps(b5, b7, _MM_SHUFFLE(1, 0, 1, 0));
    c7 = _mm256_shuffle_ps(b5, b7, _MM_SHUFFLE(3, 2, 3, 2));

    // Step 3: exchange 128-bit halves across registers.
    b0 = _mm256_permute2f128_ps(c0, c4, 0x20); // [r00, r10, r20, r30, r40, r50, r60, r70]
    b1 = _mm256_permute2f128_ps(c1, c5, 0x20);
    b2 = _mm256_permute2f128_ps(c2, c6, 0x20);
    b3 = _mm256_permute2f128_ps(c3, c7, 0x20);
    b4 = _mm256_permute2f128_ps(c0, c4, 0x31);
    b5 = _mm256_permute2f128_ps(c1, c5, 0x31);
    b6 = _mm256_permute2f128_ps(c2, c6, 0x31);
    b7 = _mm256_permute2f128_ps(c3, c7, 0x31);

    // Sum the transposed columns: lane t now holds the full dot product for mask[t].
    b0 = _mm256_add_ps(b1, b0);
    b0 = _mm256_add_ps(b2, b0);
    b0 = _mm256_add_ps(b3, b0);
    b0 = _mm256_add_ps(b4, b0);
    b0 = _mm256_add_ps(b5, b0);
    b0 = _mm256_add_ps(b6, b0);
    b0 = _mm256_add_ps(b7, b0);

    b0 = _mm256_mul_ps(b0, scalar_vec);
    // Scatter: each result goes to the C column named by its mask entry.
    int c_id;
    for (p = 0; p < 8; p++){
        c_id = mask[p];
        c[c_id] = b0[p] + beta * c[c_id];
    }
}


// Plain GEMM: C = alpha * A x B.T. C is overwritten; there is no beta term.
//
// A: row major, m x k, lda = k; k must be a multiple of 8.
// B: row major, n x k, ldb = k.
// C: row major, m x n, ldc = n.
//
// k is deliberately NOT split into panels. The underlying kernel (kernel_ori)
// stores its 8 results into C instead of accumulating, so running
// macro_kernel once per k-panel would keep only the last panel's
// contribution. That was the case for any k > 512.
//
// On invalid input (n or k not a multiple of 8, or A/B not 32-byte aligned)
// a message is printed to stderr and C is left untouched.
void gemm_mul_p(int m, int n, int k, float alpha, 
          float *a, int lda,   // matrix A, aligned
          float *b, int ldb,   // matrix B, aligned
          float *c, int ldc) { // matrix C

    if (n % 8 != 0 || k % 8 != 0) {
        fprintf(stderr, "n or k is not a multiple of 8\n");
        return;
    }

    if (!aligned(a) || !aligned(b)) {
        fprintf(stderr, "input arr a or b is not aligned\n");
        return;
    }

    macro_kernel(m, n, k, alpha, a, lda, b, ldb, c, ldc);
}

// Blocked dense product D = -2 * X * Cen^T (the "masked" part is currently disabled).
//
// X:   num x dim, row major.
// Cen: clu x dim, row major.
// D:   num x clu, row major; every entry is overwritten.
// mask: not read. The original code had its skip test
//       (mask[i * clu + j] == 1 -> continue) commented out, so every entry
//       is computed. TODO(review): confirm whether masking should be re-enabled.
//
// Work is tiled into 64 x 64 (row, column) blocks. Offsets are computed in
// size_t so that num * clu or clu * dim larger than INT_MAX does not
// overflow int.
void masked_gemm_hybrid(float *X, float *Cen, int *mask, int num, int dim, int clu, float *D) {
    enum { BLOCK_SIZE = 64 };
    (void)mask;  // intentionally unused, see above

    for (int i_block = 0; i_block < num; i_block += BLOCK_SIZE) {
        const int i_end = (i_block + BLOCK_SIZE < num) ? i_block + BLOCK_SIZE : num;

        for (int j_block = 0; j_block < clu; j_block += BLOCK_SIZE) {
            const int j_end = (j_block + BLOCK_SIZE < clu) ? j_block + BLOCK_SIZE : clu;

            for (int i = i_block; i < i_end; i++) {
                const float *x_row = X + (size_t)i * dim;

                for (int j = j_block; j < j_end; j++) {
                    const float *c_row = Cen + (size_t)j * dim;
                    float sum = 0.0f;
                    for (int t = 0; t < dim; t++) {
                        sum += x_row[t] * c_row[t];
                    }
                    D[(size_t)i * clu + j] = -2.0f * sum;
                }
            }
        }
    }
}

// Drives kernel_ori over an m x n block of C = alpha * A x B.T.
// Full 8-row bands go through micro_kernel; the leftover m % 8 rows are
// handled one row at a time. Columns are visited in groups of 8, and trailing
// n % 8 columns are not computed.
void macro_kernel(int m, int n, int k, float alpha, float *a, int lda, float *b, int ldb, float *c, int ldc){
    int row = 0;

    // Full 8-row bands.
    while (row + 8 <= m) {
        for (int col = 0; col + 8 <= n; col += 8) {
            micro_kernel(row, col, k, alpha, a, lda, b, ldb, c, ldc);
        }
        row += 8;
    }

    // Remaining rows, one at a time.
    while (row < m) {
        for (int col = 0; col + 8 <= n; col += 8) {
            kernel_ori(&A(row, 0), lda, &B(col, 0), ldb, &C(row, col), ldc, k, alpha);
        }
        ++row;
    }
}

// Computes an 8 x 8 tile of C at (i, j). It makes one kernel_ori call per row
// i .. i+7, each writing the 8 outputs C(i + r, j .. j + 7) from rows
// B(j .. j + 7).
void micro_kernel(int i, int j, int k, float alpha, float *a, int lda, float *b, int ldb, float *c, int ldc){
    for (int r = 0; r < 8; ++r) {
        kernel_ori(&A(i + r, 0), lda, &B(j, 0), ldb, &C(i + r, j), ldc, k, alpha);
    }
}


// Plain GEMM: C = alpha * A x B.T. C is overwritten; there is no beta term.
//
// A: row major, m x k, lda = k; k must be a multiple of 8.
// B: row major, n x k, ldb = k.
// C: row major, m x n, ldc = n.
//
// Rows are processed in bands of 16. For each 8-column group j, the same rows
// B(j .. j + 7) are passed to 16 consecutive kernel_ori calls, presumably for
// cache reuse. Leftover rows (m % 16) are done one at a time. Trailing
// n % 8 columns are not computed.
//
// On invalid input (n or k not a multiple of 8, or A/B not 32-byte aligned)
// a message is printed to stderr and C is left untouched. Continuing would
// issue aligned AVX loads on misaligned data.
void gemm(int m, int n, int k, float alpha, 
          float *a, int lda,   // matrix A, aligned
          float *b, int ldb,   // matrix B, aligned
          float *c, int ldc) { // matrix C

    if (n % 8 != 0 || k % 8 != 0) {
        fprintf(stderr, "n or k is not a multiple of 8\n");
        return;
    }

    if (!aligned(a) || !aligned(b)) {
        fprintf(stderr, "input arr a or b is not aligned\n");
        return;
    }

    enum { ROW_BAND = 16 };

    int i = 0;
    for (; i + ROW_BAND <= m; i += ROW_BAND) {
        for (int j = 0; j + 8 <= n; j += 8) {
            for (int r = 0; r < ROW_BAND; r++) {
                kernel_ori(&A(i + r, 0), lda, &B(j, 0), ldb, &C(i + r, j), ldc, k, alpha);
            }
        }
    }
    for (; i < m; i++) {
        for (int j = 0; j + 8 <= n; j += 8) {
            kernel_ori(&A(i, 0), lda, &B(j, 0), ldb, &C(i, j), ldc, k, alpha);
        }
    }
}

// Computes 8 dot products of one row of A against 8 consecutive rows of B
// and stores them:
//   c[t] = alpha * dot(a[0 .. k), b[t * ldb .. + k))   for t = 0..7
//
// c[0..7] is OVERWRITTEN with an unaligned store; nothing is accumulated into
// existing C values. `a` and the 8 rows of `b` must be 32-byte aligned,
// because _mm256_load_ps is an aligned load. k is consumed 8 floats at a time
// and any remainder k % 8 is ignored. lda and ldc are not used.
void kernel_ori(
    float *a, int lda, 
    float *b, int ldb, 
    float *c, int ldc, 
    int k, float alpha) {
    
    __m256 b0, b1, b2, b3, b4, b5, b6, b7,
           c0, c1, c2, c3, c4, c5, c6, c7;

    __m256 a0;

    // Zero the 8 accumulators (one per row of B).
    c0  = _mm256_setzero_ps();
    c1  = _mm256_setzero_ps();
    c2  = _mm256_setzero_ps();
    c3  = _mm256_setzero_ps();
    c4  = _mm256_setzero_ps();
    c5  = _mm256_setzero_ps();
    c6  = _mm256_setzero_ps();
    c7  = _mm256_setzero_ps();

    // Lane-wise partial dot products: ct accumulates a * b[t] over k.
    int p;
    for (p = 0; p + 8 <= k; p += 8) {
        b0  = _mm256_load_ps((float *)(b + 0 * ldb));
        b1  = _mm256_load_ps((float *)(b + 1 * ldb));
        b2  = _mm256_load_ps((float *)(b + 2 * ldb));
        b3  = _mm256_load_ps((float *)(b + 3 * ldb));
        b4  = _mm256_load_ps((float *)(b + 4 * ldb));
        b5  = _mm256_load_ps((float *)(b + 5 * ldb));
        b6  = _mm256_load_ps((float *)(b + 6 * ldb));
        b7  = _mm256_load_ps((float *)(b + 7 * ldb));

        a0 = _mm256_load_ps((float *)(a));

        c0  = _mm256_fmadd_ps(a0, b0,  c0);
        c1  = _mm256_fmadd_ps(a0, b1,  c1);
        c2  = _mm256_fmadd_ps(a0, b2,  c2);
        c3  = _mm256_fmadd_ps(a0, b3,  c3);
        c4  = _mm256_fmadd_ps(a0, b4,  c4);
        c5  = _mm256_fmadd_ps(a0, b5,  c5);
        c6  = _mm256_fmadd_ps(a0, b6,  c6);
        c7  = _mm256_fmadd_ps(a0, b7,  c7);

        a += 8;
        b += 8;
    }

    __m256 scalar_vec = _mm256_broadcast_ss(&alpha); // broadcast alpha to all lanes

    // Horizontal reduction of c0..c7 via an 8x8 transpose:
    // after the transpose, lane t of (b0 + ... + b7) is the sum of all lanes of ct.
    // Step 1: interleave the low/high halves of row pairs.
    b0 = _mm256_unpacklo_ps(c0, c1); // [r00, r10, r01, r11, ...]
    b1 = _mm256_unpackhi_ps(c0, c1); // [r02, r12, r03, r13, ...]
    b2 = _mm256_unpacklo_ps(c2, c3);
    b3 = _mm256_unpackhi_ps(c2, c3);
    b4 = _mm256_unpacklo_ps(c4, c5);
    b5 = _mm256_unpackhi_ps(c4, c5);
    b6 = _mm256_unpacklo_ps(c6, c7);
    b7 = _mm256_unpackhi_ps(c6, c7);

    // Step 2: combine pairs within each 128-bit lane.
    c0 = _mm256_shuffle_ps(b0, b2, _MM_SHUFFLE(1, 0, 1, 0)); // [r00, r10, r20, r30, ...]
    c1 = _mm256_shuffle_ps(b0, b2, _MM_SHUFFLE(3, 2, 3, 2));
    c2 = _mm256_shuffle_ps(b1, b3, _MM_SHUFFLE(1, 0, 1, 0));
    c3 = _mm256_shuffle_ps(b1, b3, _MM_SHUFFLE(3, 2, 3, 2));
    c4 = _mm256_shuffle_ps(b4, b6, _MM_SHUFFLE(1, 0, 1, 0));
    c5 = _mm256_shuffle_ps(b4, b6, _MM_SHUFFLE(3, 2, 3, 2));
    c6 = _mm256_shuffle_ps(b5, b7, _MM_SHUFFLE(1, 0, 1, 0));
    c7 = _mm256_shuffle_ps(b5, b7, _MM_SHUFFLE(3, 2, 3, 2));

    // Step 3: exchange 128-bit halves across registers.
    b0 = _mm256_permute2f128_ps(c0, c4, 0x20); // [r00, r10, r20, r30, r40, r50, r60, r70]
    b1 = _mm256_permute2f128_ps(c1, c5, 0x20);
    b2 = _mm256_permute2f128_ps(c2, c6, 0x20);
    b3 = _mm256_permute2f128_ps(c3, c7, 0x20);
    b4 = _mm256_permute2f128_ps(c0, c4, 0x31);
    b5 = _mm256_permute2f128_ps(c1, c5, 0x31);
    b6 = _mm256_permute2f128_ps(c2, c6, 0x31);
    b7 = _mm256_permute2f128_ps(c3, c7, 0x31);

    // Sum the transposed columns: lane t now holds the full dot product for row t of B.
    b0 = _mm256_add_ps(b1, b0);
    b0 = _mm256_add_ps(b2, b0);
    b0 = _mm256_add_ps(b3, b0);
    b0 = _mm256_add_ps(b4, b0);
    b0 = _mm256_add_ps(b5, b0);
    b0 = _mm256_add_ps(b6, b0);
    b0 = _mm256_add_ps(b7, b0);

    b0 = _mm256_mul_ps(b0, scalar_vec);
    // Overwrite (not accumulate) the 8 outputs; c need not be 32-byte aligned.
    _mm256_storeu_ps(c, b0);
}

// Intended as an 8x8 transpose-and-reduce helper: it transposes the matrix
// whose rows are c0..c7 and sums the transposed rows into b0.
//
// NOTE(review): this function has NO observable effect. All sixteen
// parameters are passed by value and the return type is void, so every
// assignment below modifies local copies only, and the caller never sees b0.
// The same sequence is inlined in kernel_ori and kernel_ori_mask_8. To be
// usable, this would need to return the reduced vector or take the outputs by
// pointer. That is a signature change, so the declaration in the project
// header would have to change with it.
void transpose_8x8_avx(
    __m256 c0, __m256 c1, __m256 c2, __m256 c3, __m256 c4, __m256 c5, __m256 c6, __m256 c7,
    __m256 b0, __m256 b1, __m256 b2, __m256 b3, __m256 b4, __m256 b5, __m256 b6, __m256 b7){

    // Step 1: interleave the low/high halves of row pairs.
    b0 = _mm256_unpacklo_ps(c0, c1); // [r00, r10, r01, r11, ...]
    b1 = _mm256_unpackhi_ps(c0, c1); // [r02, r12, r03, r13, ...]
    b2 = _mm256_unpacklo_ps(c2, c3);
    b3 = _mm256_unpackhi_ps(c2, c3);
    b4 = _mm256_unpacklo_ps(c4, c5);
    b5 = _mm256_unpackhi_ps(c4, c5);
    b6 = _mm256_unpacklo_ps(c6, c7);
    b7 = _mm256_unpackhi_ps(c6, c7);

    // Step 2: combine pairs within each 128-bit lane.
    c0 = _mm256_shuffle_ps(b0, b2, _MM_SHUFFLE(1, 0, 1, 0)); // [r00, r10, r20, r30, ...]
    c1 = _mm256_shuffle_ps(b0, b2, _MM_SHUFFLE(3, 2, 3, 2));
    c2 = _mm256_shuffle_ps(b1, b3, _MM_SHUFFLE(1, 0, 1, 0));
    c3 = _mm256_shuffle_ps(b1, b3, _MM_SHUFFLE(3, 2, 3, 2));
    c4 = _mm256_shuffle_ps(b4, b6, _MM_SHUFFLE(1, 0, 1, 0));
    c5 = _mm256_shuffle_ps(b4, b6, _MM_SHUFFLE(3, 2, 3, 2));
    c6 = _mm256_shuffle_ps(b5, b7, _MM_SHUFFLE(1, 0, 1, 0));
    c7 = _mm256_shuffle_ps(b5, b7, _MM_SHUFFLE(3, 2, 3, 2));

    // Step 3: exchange 128-bit halves across registers.
    b0 = _mm256_permute2f128_ps(c0, c4, 0x20); // [r00, r10, r20, r30, r40, r50, r60, r70]
    b1 = _mm256_permute2f128_ps(c1, c5, 0x20);
    b2 = _mm256_permute2f128_ps(c2, c6, 0x20);
    b3 = _mm256_permute2f128_ps(c3, c7, 0x20);
    b4 = _mm256_permute2f128_ps(c0, c4, 0x31);
    b5 = _mm256_permute2f128_ps(c1, c5, 0x31);
    b6 = _mm256_permute2f128_ps(c2, c6, 0x31);
    b7 = _mm256_permute2f128_ps(c3, c7, 0x31);

    // Sum the transposed rows into b0 (lost on return, see NOTE above).
    b0 = _mm256_add_ps(b1, b0);
    b0 = _mm256_add_ps(b2, b0);
    b0 = _mm256_add_ps(b3, b0);
    b0 = _mm256_add_ps(b4, b0);
    b0 = _mm256_add_ps(b5, b0);
    b0 = _mm256_add_ps(b6, b0);
    b0 = _mm256_add_ps(b7, b0);
}

