#include <stdlib.h>

#include "lib/util.h"
/*
 * Pack an 8-row panel of the column-major matrix A for the NEON kernel.
 *
 * For each column j in [0, k), the eight values a_in[m + r + j * lda]
 * (r = 0..7) are written contiguously, so that on return
 *     a_out[j * 8 + r] == a_in[m + r + j * lda].
 * Exactly 8 * k floats are written.  If fewer than 8 rows remain starting at
 * row m (M - m < 8), nothing is written.
 */
void PackMatrixA_8x8(int k, int lda, int m, int M, const float *a_in, float *a_out)
{
    if (M - m < 8)
    {
        return;   /* partial panels are not packed */
    }

    for (int col = 0; col < k; col++, a_out += 8)
    {
        const float *column = &a_in[m + col * lda];   /* A(m, col) */
        for (int row = 0; row < 8; row++)
        {
            a_out[row] = column[row];
        }
    }
}

/*
 * Column-major SGEMM kernel using NEON 8x8 register tiles: dst = A * B.
 *
 * Layout (all column-major, as indexed below):
 *   A   is M x K, element (m, k) at src_a[m + k * lda]
 *   B   is K x N, element (k, n) at src_b[k + n * ldb]
 *   dst is M x N, element (m, n) at dst[m + n * ldc]
 * dst is overwritten and never read.
 *
 * NOTE(review): transa, transb, alpha and beta are part of the BLAS-style
 * signature but are ignored by this implementation: A and B are always read
 * as non-transposed and dst receives the plain product (as if alpha == 1 and
 * beta == 0).
 *
 * Work decomposition:
 *   - every full 8-row panel of A is packed once into a contiguous heap
 *     buffer (if the allocation fails, A is read in place instead);
 *   - for each block of 8 columns of B, each 8x8 tile of dst is accumulated
 *     in 16 NEON registers; rows left over (M % 8) use scalar code;
 *   - columns left over (N % 8) are done one at a time, 8 rows per NEON step
 *     plus a scalar tail for the leftover rows.
 */
void optim_v1_col_major_sgemm(
    char transa,                                          // transpose flag for A (ignored)
    char transb,                                          // transpose flag for B (ignored)
    int M, int N, int K,                                  // dst is M x N; inner dimension K
    const float alpha,                                    // scale factor (ignored)
    const float * src_a, int lda,                         // input A and its leading dimension
    const float * src_b, int ldb,                         // input B and its leading dimension
    const float beta,                                     // scale factor (ignored)
    float * dst, int ldc)                                 // output and its leading dimension
{
    (void)transa;
    (void)transb;
    (void)alpha;
    (void)beta;

    if (M <= 0 || N <= 0)
    {
        return;
    }

    /* Rows [0, m_full) are covered by full 8-row panels; the rest use scalar code. */
    const int m_full = M - M % 8;

    /*
     * Pack all full panels up front so that both the 8-column path and the
     * leftover-column path can read them.  (Packing lazily inside the first
     * 8-column block left the buffer uninitialized whenever N < 8.)
     * The buffer is heap-allocated: a stack VLA of M * K floats overflows the
     * stack for large inputs, and a zero-length VLA is undefined behaviour.
     * Panel m occupies packed_a[m * K .. m * K + 8 * K), with
     * packed_a[m * K + k * 8 + r] == A(m + r, k).
     */
    float *packed_a = NULL;
    if (m_full > 0 && K > 0)
    {
        packed_a = malloc((size_t)m_full * (size_t)K * sizeof *packed_a);
        if (packed_a != NULL)
        {
            for (int m = 0; m < m_full; m += 8)
            {
                PackMatrixA_8x8(K, lda, m, M, src_a, &packed_a[(size_t)m * K]);
            }
        }
    }
    /*
     * Distance between consecutive k inside a panel: 8 when packed, lda when
     * reading A in place (rows m..m+7 of one column of A are contiguous, so the
     * same 4-wide loads work on the unpacked matrix).
     */
    const int a_step = packed_a != NULL ? 8 : lda;

    int n;
    for (n = 0; n <= N - 8; n += 8)
    {
        int m;
        for (m = 0; m < m_full; m += 8)
        {
            const float *a_panel = packed_a != NULL ? packed_a + (size_t)m * K : src_a + m;

            /* buf0c holds rows m..m+3 and buf4c rows m+4..m+7 of dst column n + c. */
            float32x4_t buf00 = vdupq_n_f32(0.0f);
            float32x4_t buf01 = vdupq_n_f32(0.0f);
            float32x4_t buf02 = vdupq_n_f32(0.0f);
            float32x4_t buf03 = vdupq_n_f32(0.0f);
            float32x4_t buf04 = vdupq_n_f32(0.0f);
            float32x4_t buf05 = vdupq_n_f32(0.0f);
            float32x4_t buf06 = vdupq_n_f32(0.0f);
            float32x4_t buf07 = vdupq_n_f32(0.0f);
            float32x4_t buf40 = vdupq_n_f32(0.0f);
            float32x4_t buf41 = vdupq_n_f32(0.0f);
            float32x4_t buf42 = vdupq_n_f32(0.0f);
            float32x4_t buf43 = vdupq_n_f32(0.0f);
            float32x4_t buf44 = vdupq_n_f32(0.0f);
            float32x4_t buf45 = vdupq_n_f32(0.0f);
            float32x4_t buf46 = vdupq_n_f32(0.0f);
            float32x4_t buf47 = vdupq_n_f32(0.0f);

            for (int k = 0; k < K; k++)
            {
                const float *a_k = a_panel + k * a_step;
                float32x4_t va0 = vld1q_f32(a_k);        /* A(m..m+3, k)   */
                float32x4_t va1 = vld1q_f32(a_k + 4);    /* A(m+4..m+7, k) */
                float vb0 = src_b[k + n * ldb];
                float vb1 = src_b[k + (n + 1) * ldb];
                float vb2 = src_b[k + (n + 2) * ldb];
                float vb3 = src_b[k + (n + 3) * ldb];
                float vb4 = src_b[k + (n + 4) * ldb];
                float vb5 = src_b[k + (n + 5) * ldb];
                float vb6 = src_b[k + (n + 6) * ldb];
                float vb7 = src_b[k + (n + 7) * ldb];
                buf00 = vmlaq_n_f32(buf00, va0, vb0);
                buf01 = vmlaq_n_f32(buf01, va0, vb1);
                buf02 = vmlaq_n_f32(buf02, va0, vb2);
                buf03 = vmlaq_n_f32(buf03, va0, vb3);
                buf04 = vmlaq_n_f32(buf04, va0, vb4);
                buf05 = vmlaq_n_f32(buf05, va0, vb5);
                buf06 = vmlaq_n_f32(buf06, va0, vb6);
                buf07 = vmlaq_n_f32(buf07, va0, vb7);
                buf40 = vmlaq_n_f32(buf40, va1, vb0);
                buf41 = vmlaq_n_f32(buf41, va1, vb1);
                buf42 = vmlaq_n_f32(buf42, va1, vb2);
                buf43 = vmlaq_n_f32(buf43, va1, vb3);
                buf44 = vmlaq_n_f32(buf44, va1, vb4);
                buf45 = vmlaq_n_f32(buf45, va1, vb5);
                buf46 = vmlaq_n_f32(buf46, va1, vb6);
                buf47 = vmlaq_n_f32(buf47, va1, vb7);
            }

            vst1q_f32(&dst[m + n * ldc], buf00);
            vst1q_f32(&dst[m + (n + 1) * ldc], buf01);
            vst1q_f32(&dst[m + (n + 2) * ldc], buf02);
            vst1q_f32(&dst[m + (n + 3) * ldc], buf03);
            vst1q_f32(&dst[m + (n + 4) * ldc], buf04);
            vst1q_f32(&dst[m + (n + 5) * ldc], buf05);
            vst1q_f32(&dst[m + (n + 6) * ldc], buf06);
            vst1q_f32(&dst[m + (n + 7) * ldc], buf07);
            vst1q_f32(&dst[m + 4 + n * ldc], buf40);
            vst1q_f32(&dst[m + 4 + (n + 1) * ldc], buf41);
            vst1q_f32(&dst[m + 4 + (n + 2) * ldc], buf42);
            vst1q_f32(&dst[m + 4 + (n + 3) * ldc], buf43);
            vst1q_f32(&dst[m + 4 + (n + 4) * ldc], buf44);
            vst1q_f32(&dst[m + 4 + (n + 5) * ldc], buf45);
            vst1q_f32(&dst[m + 4 + (n + 6) * ldc], buf46);
            vst1q_f32(&dst[m + 4 + (n + 7) * ldc], buf47);
        }

        /* Rows left over when M is not a multiple of 8: scalar dot products. */
        for (; m < M; m++)
        {
            float dst_0 = 0;
            float dst_1 = 0;
            float dst_2 = 0;
            float dst_3 = 0;
            float dst_4 = 0;
            float dst_5 = 0;
            float dst_6 = 0;
            float dst_7 = 0;
            for (int k = 0; k < K; k++)
            {
                const float a = src_a[m + k * lda];
                dst_0 = dst_0 + a * src_b[k + n * ldb];
                dst_1 = dst_1 + a * src_b[k + (n + 1) * ldb];
                dst_2 = dst_2 + a * src_b[k + (n + 2) * ldb];
                dst_3 = dst_3 + a * src_b[k + (n + 3) * ldb];
                dst_4 = dst_4 + a * src_b[k + (n + 4) * ldb];
                dst_5 = dst_5 + a * src_b[k + (n + 5) * ldb];
                dst_6 = dst_6 + a * src_b[k + (n + 6) * ldb];
                dst_7 = dst_7 + a * src_b[k + (n + 7) * ldb];
            }
            dst[m + n * ldc] = dst_0;
            dst[m + (n + 1) * ldc] = dst_1;
            dst[m + (n + 2) * ldc] = dst_2;
            dst[m + (n + 3) * ldc] = dst_3;
            dst[m + (n + 4) * ldc] = dst_4;
            dst[m + (n + 5) * ldc] = dst_5;
            dst[m + (n + 6) * ldc] = dst_6;
            dst[m + (n + 7) * ldc] = dst_7;
        }
    }

    /* Columns left over when N is not a multiple of 8: one column at a time. */
    for (; n < N; n++)
    {
        int m;
        for (m = 0; m < m_full; m += 8)
        {
            const float *a_panel = packed_a != NULL ? packed_a + (size_t)m * K : src_a + m;
            float32x4_t buf0 = vdupq_n_f32(0.0f);   /* rows m..m+3   */
            float32x4_t buf1 = vdupq_n_f32(0.0f);   /* rows m+4..m+7 */
            for (int k = 0; k < K; k++)
            {
                const float *a_k = a_panel + k * a_step;
                float32x4_t va0 = vld1q_f32(a_k);
                float32x4_t va1 = vld1q_f32(a_k + 4);
                float vb = src_b[k + n * ldb];
                buf0 = vmlaq_n_f32(buf0, va0, vb);
                buf1 = vmlaq_n_f32(buf1, va1, vb);
            }
            vst1q_f32(&dst[m + n * ldc], buf0);
            vst1q_f32(&dst[m + 4 + n * ldc], buf1);
        }
        for (; m < M; m++)
        {
            float temp = 0;
            for (int k = 0; k < K; k++)
            {
                temp += src_a[m + k * lda] * src_b[k + n * ldb];
            }
            dst[m + n * ldc] = temp;
        }
    }

    free(packed_a);
}