#include "lib/util.h"
/*
 * Copy one 16-row panel of A into a contiguous buffer.
 *
 * For each of the k columns, the 16 elements a_in[m .. m+15 + col * lda]
 * are appended to a_out, so the result holds k groups of 16 bytes.
 *
 *   k     number of columns to copy
 *   lda   distance (in elements) between consecutive columns of a_in
 *   m     first row of the panel
 *   M     total number of rows in A
 *   a_in  source matrix
 *   a_out destination; must have room for 16 * k bytes
 *
 * If fewer than 16 rows remain (M - m < 16) nothing is written.
 */
void PackMatrixA_16x8(int k, int lda, int m, int M, const uint8_t *a_in, uint8_t *a_out)
{
    /* A partial panel is not packed at all. */
    if (M - m < 16)
    {
        return;
    }

    for (int col = 0; col < k; col++)
    {
        const uint8_t *column = &a_in[m + col * lda];

        for (int row = 0; row < 16; row++)
        {
            *a_out++ = column[row];
        }
    }
}

void optim_uint8_col_major_sgemm(
    char transa,                                          //a是否转置
    char transb,                                          //b是否转置
    int M, int N, int K,                                  //input size
    const float alpha,                                    //偏置
    const uint8_t * src_a, int lda,                         //输入a指针 步长
    const uint8_t * src_b, int ldb,                         //输入b指针 步长
    const float beta,                                     //偏置
    uint16_t * dst, int ldc)                                 //输出dst指针 步长
{
    int a_stride_m = transa == 'n' ? 1 : lda;
    int a_stride_k = transa == 'n' ? lda : 1;
    int b_stride_k = transb == 'n' ? 1 : ldb;
    int b_stride_n = transb == 'n' ? ldb : 1;
    int m;
    int n;
    uint8_t packedA[M * K];                       //函数内声明变量有size限制——>M和K过大会导致掉核心(需要对每个线程分配适量堆栈内存)

    for(n=0;n<=N-8;n+=8) 
    {
        for(m=0;m<=M-16;m+=16) 
        {
            uint16x8_t buf0_0 = vdupq_n_u16(0);
            uint16x8_t buf0_1 = vdupq_n_u16(0);
            uint16x8_t buf0_2 = vdupq_n_u16(0);
            uint16x8_t buf0_3 = vdupq_n_u16(0);
            uint16x8_t buf0_4 = vdupq_n_u16(0);
            uint16x8_t buf0_5 = vdupq_n_u16(0);
            uint16x8_t buf0_6 = vdupq_n_u16(0);
            uint16x8_t buf0_7 = vdupq_n_u16(0);
            uint16x8_t buf8_0 = vdupq_n_u16(0);
            uint16x8_t buf8_1 = vdupq_n_u16(0);
            uint16x8_t buf8_2 = vdupq_n_u16(0);
            uint16x8_t buf8_3 = vdupq_n_u16(0);
            uint16x8_t buf8_4 = vdupq_n_u16(0);
            uint16x8_t buf8_5 = vdupq_n_u16(0);
            uint16x8_t buf8_6 = vdupq_n_u16(0);
            uint16x8_t buf8_7 = vdupq_n_u16(0);
            if (n == 0)
            {
                PackMatrixA_16x8(K, lda, m, M, src_a, &packedA[m * K]);
            }
            for(int k=0;k<K;k++) 
            {
                uint8x8_t va0 = vld1_u8(&packedA[m * K + k * 16]);
                uint8x8_t va1 = vld1_u8(&packedA[m * K + 8 + k * 16]);
                register uint8_t vb0 = src_b[k + n * ldb];
                register uint8_t vb1 = src_b[k + (n + 1) * ldb];
                register uint8_t vb2 = src_b[k + (n + 2) * ldb];
                register uint8_t vb3 = src_b[k + (n + 3) * ldb];
                register uint8_t vb4 = src_b[k + (n + 4) * ldb];
                register uint8_t vb5 = src_b[k + (n + 5) * ldb];
                register uint8_t vb6 = src_b[k + (n + 6) * ldb];
                register uint8_t vb7 = src_b[k + (n + 7) * ldb];
                uint8x8_t hvb0 = vdup_n_u8(vb0);
                uint8x8_t hvb1 = vdup_n_u8(vb1);
                uint8x8_t hvb2 = vdup_n_u8(vb2);
                uint8x8_t hvb3 = vdup_n_u8(vb3);
                uint8x8_t hvb4 = vdup_n_u8(vb4);
                uint8x8_t hvb5 = vdup_n_u8(vb5);
                uint8x8_t hvb6 = vdup_n_u8(vb6);
                uint8x8_t hvb7 = vdup_n_u8(vb7);
                buf0_0 = vmlal_u8(buf0_0, va0, hvb0);
                buf0_1 = vmlal_u8(buf0_1, va0, hvb1);
                buf0_2 = vmlal_u8(buf0_2, va0, hvb2);
                buf0_3 = vmlal_u8(buf0_3, va0, hvb3);
                buf0_4 = vmlal_u8(buf0_4, va0, hvb4);
                buf0_5 = vmlal_u8(buf0_5, va0, hvb5);
                buf0_6 = vmlal_u8(buf0_6, va0, hvb6);
                buf0_7 = vmlal_u8(buf0_7, va0, hvb7);
                buf8_0 = vmlal_u8(buf8_0, va1, hvb0);
                buf8_1 = vmlal_u8(buf8_1, va1, hvb1);
                buf8_2 = vmlal_u8(buf8_2, va1, hvb2);
                buf8_3 = vmlal_u8(buf8_3, va1, hvb3);
                buf8_4 = vmlal_u8(buf8_4, va1, hvb4);
                buf8_5 = vmlal_u8(buf8_5, va1, hvb5);
                buf8_6 = vmlal_u8(buf8_6, va1, hvb6);
                buf8_7 = vmlal_u8(buf8_7, va1, hvb7);

            }
                // buf0_0 = vshrq_n_u16(buf0_0, 2);        //buf0_0中所有元素右移2位
                vst1q_u16(&dst[m + n * ldc], buf0_0);
                vst1q_u16(&dst[m + (n + 1) * ldc], buf0_1);
                vst1q_u16(&dst[m + (n + 2) * ldc], buf0_2);
                vst1q_u16(&dst[m + (n + 3) * ldc], buf0_3);
                vst1q_u16(&dst[m + (n + 4) * ldc], buf0_4);
                vst1q_u16(&dst[m + (n + 5) * ldc], buf0_5);
                vst1q_u16(&dst[m + (n + 6) * ldc], buf0_6);
                vst1q_u16(&dst[m + (n + 7) * ldc], buf0_7);
                vst1q_u16(&dst[m + 8 + n * ldc], buf8_0);
                vst1q_u16(&dst[m + 8 + (n + 1) * ldc], buf8_1);
                vst1q_u16(&dst[m + 8 + (n + 2) * ldc], buf8_2);
                vst1q_u16(&dst[m + 8 + (n + 3) * ldc], buf8_3);
                vst1q_u16(&dst[m + 8 + (n + 4) * ldc], buf8_4);
                vst1q_u16(&dst[m + 8 + (n + 5) * ldc], buf8_5);
                vst1q_u16(&dst[m + 8 + (n + 6) * ldc], buf8_6);
                vst1q_u16(&dst[m + 8 + (n + 7) * ldc], buf8_7);
        }
      //行数不是8的倍数时的补充
        for (; m < M; m++)
        {
            if (n == 0)
            {
                PackMatrixA_16x8(K, lda, m, M, src_a, &packedA[m * K]);
            }
            uint16_t dst_0, dst_1, dst_2, dst_3, dst_4, dst_5, dst_6, dst_7;
            dst_0 = 0;
            dst_1 = 0;
            dst_2 = 0;
            dst_3 = 0;
            dst_4 = 0;
            dst_5 = 0;
            dst_6 = 0;
            dst_7 = 0;         
            for(int k=0;k<K;k++)
            {
                dst_0 = dst_0 + (uint16_t)src_a[m + k * lda] * src_b[k + n * ldb];
                dst_1 = dst_1 + (uint16_t)src_a[m + k * lda] * src_b[k + (n + 1) * ldb];
                dst_2 = dst_2 + (uint16_t)src_a[m + k * lda] * src_b[k + (n + 2) * ldb];
                dst_3 = dst_3 + (uint16_t)src_a[m + k * lda] * src_b[k + (n + 3) * ldb];
                dst_4 = dst_4 + (uint16_t)src_a[m + k * lda] * src_b[k + (n + 4) * ldb];
                dst_5 = dst_5 + (uint16_t)src_a[m + k * lda] * src_b[k + (n + 5) * ldb];
                dst_6 = dst_6 + (uint16_t)src_a[m + k * lda] * src_b[k + (n + 6) * ldb];
                dst_7 = dst_7 + (uint16_t)src_a[m + k * lda] * src_b[k + (n + 7) * ldb];            
            }
            dst[m + n * ldc] = dst_0;
            dst[m + (n + 1) * ldc] = dst_1;
            dst[m + (n + 2) * ldc] = dst_2;
            dst[m + (n + 3) * ldc] = dst_3;
            dst[m + (n + 4) * ldc] = dst_4;
            dst[m + (n + 5) * ldc] = dst_5;
            dst[m + (n + 6) * ldc] = dst_6;
            dst[m + (n + 7) * ldc] = dst_7;         
        }
    }
   //列数不是8的倍数时的补充
    for (; n < N; n++)
    {
        for(m=0;m<=M-16;m+=16) 
        {
            uint16x8_t buf0 = vdupq_n_u16(0);
            uint16x8_t buf1 = vdupq_n_u16(0);
            for(int k=0;k<K;k++) 
            {
                uint8x8_t va0 = vld1_u8(&packedA[m * K + k * 16]);
                uint8x8_t va1 = vld1_u8(&packedA[m * K + 8 + k * 16]);
                register uint8_t vb = src_b[k + n * ldb];
                uint8x8_t hvb = vdup_n_u8(vb);
                buf0 = vmlal_u8(buf0, va0, hvb);
                buf1 = vmlal_u8(buf1, va1, hvb);
            }
        vst1q_u16(&dst[m + n * ldc], buf0);
        vst1q_u16(&dst[m + 8 + n * ldc], buf1);
        }
        for(;m<M;m++)
        {
            uint16_t temp = 0;
            for (int k = 0; k < K; k++){
            temp += (uint16_t)src_a[m + k * lda] * src_b[k + n * ldb];
            }
            dst[m + n * ldc] = temp;
        }
    }
}