#include "lib/util.h"

// Element-wise sum of two Row x Col matrices.
//
// Allocates a table of Row row pointers, allocates one array of Col floats
// per row, and fills it with mat1[i][j] + mat2[i][j].
//
// Returns the newly allocated matrix, or NULL if any allocation fails (rows
// allocated before the failure are released first). The caller owns the
// result and must free every row and then the row-pointer table.
float** mat_add_ori(int Row, int Col, float mat1[Row][Col],float mat2[Row][Col])
{
   float **out_mat = malloc(Row * sizeof *out_mat);  // row-pointer table
   if (out_mat == NULL)
   {
      return NULL;
   }

   for (int i = 0; i < Row; i++)
   {
      out_mat[i] = malloc(Col * sizeof *out_mat[i]);  // storage for row i
      if (out_mat[i] == NULL)
      {
         // Roll back the rows allocated so far so nothing leaks.
         while (i-- > 0)
         {
            free(out_mat[i]);
         }
         free(out_mat);
         return NULL;
      }

      for (int j = 0; j < Col; j++)
      {
         out_mat[i][j] = mat1[i][j] + mat2[i][j];
      }
   }
   return out_mat;
}
// Column-major matrix multiply. Column-major: advancing a pointer by one
// moves to the next row of the same column.
//
// Scalar reference SGEMM: dst = alpha * op(A) * op(B) + beta * dst, where
// op(A) is M x K, op(B) is K x N and dst is M x N. Element (m, n) of dst is
// stored at dst[m + n * ldc].
//
// transa / transb: the lowercase character 'n' selects the untransposed
// layout (element (m, k) of op(A) at src_a[m + k * lda], element (k, n) of
// op(B) at src_b[k + n * ldb]); any other value selects the transposed layout
// (src_a[k + m * lda], src_b[n + k * ldb]).
//
// When beta is exactly zero, dst is only written and never read, so NaN/Inf
// or uninitialised values already in dst cannot leak into the result.
void naive_col_major_sgemm(
      char transa,                                          // 'n': A is not transposed; otherwise transposed
      char transb,                                          // 'n': B is not transposed; otherwise transposed
      int M, int N, int K,                                  // problem size
      const float alpha,                                    // scale factor applied to op(A) * op(B)
      const float * src_a, int lda,                         // input A pointer and leading dimension (stride)
      const float * src_b, int ldb,                         // input B pointer and leading dimension (stride)
      const float beta,                                     // scale factor applied to the previous dst
      float * dst, int ldc)                                 // output pointer and leading dimension (stride)
{
   // Element strides along each logical axis of op(A) and op(B).
   const int a_stride_m = (transa == 'n') ? 1 : lda;
   const int a_stride_k = (transa == 'n') ? lda : 1;
   const int b_stride_k = (transb == 'n') ? 1 : ldb;
   const int b_stride_n = (transb == 'n') ? ldb : 1;

   for (int m = 0; m < M; m++)
   {
      for (int n = 0; n < N; n++)
      {
         // Dot product of row m of op(A) with column n of op(B).
         const float *a_ptr = src_a + m * a_stride_m;
         const float *b_ptr = src_b + n * b_stride_n;
         float acc = 0.f;

         for (int k = 0; k < K; k++)
         {
            acc += (*a_ptr) * (*b_ptr);
            a_ptr += a_stride_k;
            b_ptr += b_stride_k;
         }

         float *out = &dst[m + n * ldc];
         if (beta == 0.f)
         {
            // Do not read dst: 0 * NaN would otherwise poison the result.
            *out = alpha * acc;
         }
         else
         {
            *out = alpha * acc + beta * (*out);
         }
      }
   }
}

// Column-major matrix multiply. Column-major: advancing a pointer by one
// moves to the next row of the same column.
//
// NEON SGEMM kernel, version 1: computes dst = A * B, producing four
// consecutive rows of one dst column per vector accumulator.
//
// Element (m, k) of A is read from src_a[m + k * lda], element (k, n) of B
// from src_b[k + n * ldb], and element (m, n) of the result is stored to
// dst[m + n * ldc]. transa, transb, alpha and beta are accepted but not used:
// the operands are always treated as untransposed and dst is overwritten.
//
// Rows left over when M is not a multiple of 4 are computed with scalar code
// so that no load or store runs past the end of a column.
void neom_v1_col_major_sgemm(
      char transa,                                          // unused by this kernel
      char transb,                                          // unused by this kernel
      int M, int N, int K,                                  // problem size
      const float alpha,                                    // unused by this kernel
      const float * src_a, int lda,                         // input A pointer and leading dimension (stride)
      const float * src_b, int ldb,                         // input B pointer and leading dimension (stride)
      const float beta,                                     // unused by this kernel
      float * dst, int ldc)                                 // output pointer and leading dimension (stride)
{
   (void)transa;
   (void)transb;
   (void)alpha;
   (void)beta;

   for (int n = 0; n < N; n++)
   {
      int m = 0;

      // Vector part: four rows at a time.
      for (; m + 4 <= M; m += 4)
      {
         float32x4_t acc = vdupq_n_f32(0.0f);
         for (int k = 0; k < K; k++)
         {
            const float32x4_t a_vec = vld1q_f32(&src_a[m + k * lda]);
            const float b_val = src_b[k + n * ldb];
            acc = vmlaq_n_f32(acc, a_vec, b_val);
         }
         vst1q_f32(&dst[m + n * ldc], acc);
      }

      // Scalar tail: the last M % 4 rows of this column.
      for (; m < M; m++)
      {
         float sum = 0.0f;
         for (int k = 0; k < K; k++)
         {
            sum += src_a[m + k * lda] * src_b[k + n * ldb];
         }
         dst[m + n * ldc] = sum;
      }
   }
}

// NEON SGEMM kernel, version 2: computes dst = A * B on column-major data,
// producing a 4-row x 4-column tile of dst per inner iteration.
//
// Element (m, k) of A is read from src_a[m + k * lda], element (k, n) of B
// from src_b[k + n * ldb], and element (m, n) of the result is stored to
// dst[m + n * ldc]. transa, transb, alpha and beta are accepted but not used:
// the operands are always treated as untransposed and dst is overwritten.
//
// Rows and columns left over when M or N is not a multiple of 4 are handled
// by the tail loops below.
void neom_v2_col_major_sgemm(
      char transa,                                          // unused by this kernel
      char transb,                                          // unused by this kernel
      int M, int N, int K,                                  // problem size
      const float alpha,                                    // unused by this kernel
      const float * src_a, int lda,                         // input A pointer and leading dimension (stride)
      const float * src_b, int ldb,                         // input B pointer and leading dimension (stride)
      const float beta,                                     // unused by this kernel
      float * dst, int ldc)                                 // output pointer and leading dimension (stride)
{
   (void)transa;
   (void)transb;
   (void)alpha;
   (void)beta;

   int m;
   int n;

   // Full groups of four columns.
   for (n = 0; n <= N - 4; n += 4)
   {
      // 4x4 tiles: one vector accumulator (four rows) per column.
      for (m = 0; m <= M - 4; m += 4)
      {
         float32x4_t acc0 = vdupq_n_f32(0.0f);
         float32x4_t acc1 = vdupq_n_f32(0.0f);
         float32x4_t acc2 = vdupq_n_f32(0.0f);
         float32x4_t acc3 = vdupq_n_f32(0.0f);
         for (int k = 0; k < K; k++)
         {
            const float32x4_t a_vec = vld1q_f32(&src_a[m + k * lda]);
            acc0 = vmlaq_n_f32(acc0, a_vec, src_b[k + n * ldb]);
            acc1 = vmlaq_n_f32(acc1, a_vec, src_b[k + (n + 1) * ldb]);
            acc2 = vmlaq_n_f32(acc2, a_vec, src_b[k + (n + 2) * ldb]);
            acc3 = vmlaq_n_f32(acc3, a_vec, src_b[k + (n + 3) * ldb]);
         }
         vst1q_f32(&dst[m + n * ldc], acc0);
         vst1q_f32(&dst[m + (n + 1) * ldc], acc1);
         vst1q_f32(&dst[m + (n + 2) * ldc], acc2);
         vst1q_f32(&dst[m + (n + 3) * ldc], acc3);
      }

      // Row tail (M not a multiple of 4): scalar, still four columns at once.
      for (; m < M; m++)
      {
         float sum0 = 0.0f;
         float sum1 = 0.0f;
         float sum2 = 0.0f;
         float sum3 = 0.0f;
         for (int k = 0; k < K; k++)
         {
            const float a_val = src_a[m + k * lda];
            sum0 = sum0 + a_val * src_b[k + n * ldb];
            sum1 = sum1 + a_val * src_b[k + (n + 1) * ldb];
            sum2 = sum2 + a_val * src_b[k + (n + 2) * ldb];
            sum3 = sum3 + a_val * src_b[k + (n + 3) * ldb];
         }
         dst[m + n * ldc] = sum0;
         dst[m + (n + 1) * ldc] = sum1;
         dst[m + (n + 2) * ldc] = sum2;
         dst[m + (n + 3) * ldc] = sum3;
      }
   }

   // Column tail (N not a multiple of 4): one column at a time.
   for (; n < N; n++)
   {
      for (m = 0; m <= M - 4; m += 4)
      {
         float32x4_t acc = vdupq_n_f32(0.0f);
         for (int k = 0; k < K; k++)
         {
            const float32x4_t a_vec = vld1q_f32(&src_a[m + k * lda]);
            acc = vmlaq_n_f32(acc, a_vec, src_b[k + n * ldb]);
         }
         vst1q_f32(&dst[m + n * ldc], acc);
      }

      // Strictly less than M: row index M is one past the last valid row.
      for (; m < M; m++)
      {
         float sum = 0.0f;
         for (int k = 0; k < K; k++)
         {
            sum += src_a[m + k * lda] * src_b[k + n * ldb];
         }
         dst[m + n * ldc] = sum;
      }
   }
}


// NEON SGEMM kernel, version 3: computes dst = A * B on column-major data,
// producing a 4-row x 4-column tile of dst per inner iteration with the K
// loop additionally unrolled by four.
//
// Element (m, k) of A is read from src_a[m + k * lda], element (k, n) of B
// from src_b[k + n * ldb], and element (m, n) of the result is stored to
// dst[m + n * ldc]. transa, transb, alpha and beta are accepted but not used:
// the operands are always treated as untransposed and dst is overwritten.
//
// Leftover K iterations, rows and columns (sizes not a multiple of 4) are
// handled by the tail loops below.
void neom_v3_col_major_sgemm(
      char transa,                                          // unused by this kernel
      char transb,                                          // unused by this kernel
      int M, int N, int K,                                  // problem size
      const float alpha,                                    // unused by this kernel
      const float * src_a, int lda,                         // input A pointer and leading dimension (stride)
      const float * src_b, int ldb,                         // input B pointer and leading dimension (stride)
      const float beta,                                     // unused by this kernel
      float * dst, int ldc)                                 // output pointer and leading dimension (stride)
{
   (void)transa;
   (void)transb;
   (void)alpha;
   (void)beta;

   int m;
   int n;
   int k;

   // Full groups of four columns.
   for (n = 0; n <= N - 4; n += 4)
   {
      // 4x4 tiles: one vector accumulator (four rows) per column.
      for (m = 0; m <= M - 4; m += 4)
      {
         float32x4_t acc0 = vdupq_n_f32(0.0f);
         float32x4_t acc1 = vdupq_n_f32(0.0f);
         float32x4_t acc2 = vdupq_n_f32(0.0f);
         float32x4_t acc3 = vdupq_n_f32(0.0f);

         // Main K loop, four steps per iteration.
         for (k = 0; k <= K - 4; k += 4)
         {
            // Four consecutive columns of A, rows m..m+3.
            const float32x4_t a0 = vld1q_f32(&src_a[m + k * lda]);
            const float32x4_t a1 = vld1q_f32(&src_a[m + (k + 1) * lda]);
            const float32x4_t a2 = vld1q_f32(&src_a[m + (k + 2) * lda]);
            const float32x4_t a3 = vld1q_f32(&src_a[m + (k + 3) * lda]);

            // Rows k..k+3 of the four B columns n..n+3.
            const float *b0 = &src_b[k + n * ldb];
            const float *b1 = &src_b[k + (n + 1) * ldb];
            const float *b2 = &src_b[k + (n + 2) * ldb];
            const float *b3 = &src_b[k + (n + 3) * ldb];

            acc0 = vmlaq_n_f32(acc0, a0, b0[0]);
            acc0 = vmlaq_n_f32(acc0, a1, b0[1]);
            acc0 = vmlaq_n_f32(acc0, a2, b0[2]);
            acc0 = vmlaq_n_f32(acc0, a3, b0[3]);

            acc1 = vmlaq_n_f32(acc1, a0, b1[0]);
            acc1 = vmlaq_n_f32(acc1, a1, b1[1]);
            acc1 = vmlaq_n_f32(acc1, a2, b1[2]);
            acc1 = vmlaq_n_f32(acc1, a3, b1[3]);

            acc2 = vmlaq_n_f32(acc2, a0, b2[0]);
            acc2 = vmlaq_n_f32(acc2, a1, b2[1]);
            acc2 = vmlaq_n_f32(acc2, a2, b2[2]);
            acc2 = vmlaq_n_f32(acc2, a3, b2[3]);

            acc3 = vmlaq_n_f32(acc3, a0, b3[0]);
            acc3 = vmlaq_n_f32(acc3, a1, b3[1]);
            acc3 = vmlaq_n_f32(acc3, a2, b3[2]);
            acc3 = vmlaq_n_f32(acc3, a3, b3[3]);
         }

         // K tail (K not a multiple of 4): one step at a time.
         for (; k < K; k++)
         {
            const float32x4_t a_vec = vld1q_f32(&src_a[m + k * lda]);
            acc0 = vmlaq_n_f32(acc0, a_vec, src_b[k + n * ldb]);
            acc1 = vmlaq_n_f32(acc1, a_vec, src_b[k + (n + 1) * ldb]);
            acc2 = vmlaq_n_f32(acc2, a_vec, src_b[k + (n + 2) * ldb]);
            acc3 = vmlaq_n_f32(acc3, a_vec, src_b[k + (n + 3) * ldb]);
         }

         vst1q_f32(&dst[m + n * ldc], acc0);
         vst1q_f32(&dst[m + (n + 1) * ldc], acc1);
         vst1q_f32(&dst[m + (n + 2) * ldc], acc2);
         vst1q_f32(&dst[m + (n + 3) * ldc], acc3);
      }

      // Row tail (M not a multiple of 4): scalar, still four columns at once.
      for (; m < M; m++)
      {
         float sum0 = 0.0f;
         float sum1 = 0.0f;
         float sum2 = 0.0f;
         float sum3 = 0.0f;
         for (k = 0; k < K; k++)
         {
            const float a_val = src_a[m + k * lda];
            sum0 = sum0 + a_val * src_b[k + n * ldb];
            sum1 = sum1 + a_val * src_b[k + (n + 1) * ldb];
            sum2 = sum2 + a_val * src_b[k + (n + 2) * ldb];
            sum3 = sum3 + a_val * src_b[k + (n + 3) * ldb];
         }
         dst[m + n * ldc] = sum0;
         dst[m + (n + 1) * ldc] = sum1;
         dst[m + (n + 2) * ldc] = sum2;
         dst[m + (n + 3) * ldc] = sum3;
      }
   }

   // Column tail (N not a multiple of 4): one column at a time.
   for (; n < N; n++)
   {
      for (m = 0; m <= M - 4; m += 4)
      {
         float32x4_t acc = vdupq_n_f32(0.0f);
         for (k = 0; k < K; k++)
         {
            const float32x4_t a_vec = vld1q_f32(&src_a[m + k * lda]);
            acc = vmlaq_n_f32(acc, a_vec, src_b[k + n * ldb]);
         }
         vst1q_f32(&dst[m + n * ldc], acc);
      }

      // Strictly less than M: row index M is one past the last valid row.
      for (; m < M; m++)
      {
         float sum = 0.0f;
         for (k = 0; k < K; k++)
         {
            sum += src_a[m + k * lda] * src_b[k + n * ldb];
         }
         dst[m + n * ldc] = sum;
      }
   }
}

// NEON SGEMM kernel, version 4: computes dst = A * B on column-major data,
// producing an 8-row x 8-column tile of dst per inner iteration (two
// four-row vector accumulators for each of eight columns).
//
// Element (m, k) of A is read from src_a[m + k * lda], element (k, n) of B
// from src_b[k + n * ldb], and element (m, n) of the result is stored to
// dst[m + n * ldc]. transa, transb, alpha and beta are accepted but not used:
// the operands are always treated as untransposed and dst is overwritten.
//
// Rows and columns left over when M or N is not a multiple of 8 are handled
// by the tail loops below.
void neom_v4_col_major_sgemm(
      char transa,                                          // unused by this kernel
      char transb,                                          // unused by this kernel
      int M, int N, int K,                                  // problem size
      const float alpha,                                    // unused by this kernel
      const float * src_a, int lda,                         // input A pointer and leading dimension (stride)
      const float * src_b, int ldb,                         // input B pointer and leading dimension (stride)
      const float beta,                                     // unused by this kernel
      float * dst, int ldc)                                 // output pointer and leading dimension (stride)
{
   (void)transa;
   (void)transb;
   (void)alpha;
   (void)beta;

   int m;
   int n;

   // Full groups of eight columns.
   for (n = 0; n <= N - 8; n += 8)
   {
      // 8x8 tiles. lo* accumulate rows m..m+3, hi* accumulate rows m+4..m+7;
      // the digit is the column offset from n.
      for (m = 0; m <= M - 8; m += 8)
      {
         float32x4_t lo0 = vdupq_n_f32(0.0f);
         float32x4_t lo1 = vdupq_n_f32(0.0f);
         float32x4_t lo2 = vdupq_n_f32(0.0f);
         float32x4_t lo3 = vdupq_n_f32(0.0f);
         float32x4_t lo4 = vdupq_n_f32(0.0f);
         float32x4_t lo5 = vdupq_n_f32(0.0f);
         float32x4_t lo6 = vdupq_n_f32(0.0f);
         float32x4_t lo7 = vdupq_n_f32(0.0f);
         float32x4_t hi0 = vdupq_n_f32(0.0f);
         float32x4_t hi1 = vdupq_n_f32(0.0f);
         float32x4_t hi2 = vdupq_n_f32(0.0f);
         float32x4_t hi3 = vdupq_n_f32(0.0f);
         float32x4_t hi4 = vdupq_n_f32(0.0f);
         float32x4_t hi5 = vdupq_n_f32(0.0f);
         float32x4_t hi6 = vdupq_n_f32(0.0f);
         float32x4_t hi7 = vdupq_n_f32(0.0f);

         for (int k = 0; k < K; k++)
         {
            const float32x4_t a_lo = vld1q_f32(&src_a[m + k * lda]);
            const float32x4_t a_hi = vld1q_f32(&src_a[m + 4 + k * lda]);
            const float b0 = src_b[k + n * ldb];
            const float b1 = src_b[k + (n + 1) * ldb];
            const float b2 = src_b[k + (n + 2) * ldb];
            const float b3 = src_b[k + (n + 3) * ldb];
            const float b4 = src_b[k + (n + 4) * ldb];
            const float b5 = src_b[k + (n + 5) * ldb];
            const float b6 = src_b[k + (n + 6) * ldb];
            const float b7 = src_b[k + (n + 7) * ldb];

            lo0 = vmlaq_n_f32(lo0, a_lo, b0);
            lo1 = vmlaq_n_f32(lo1, a_lo, b1);
            lo2 = vmlaq_n_f32(lo2, a_lo, b2);
            lo3 = vmlaq_n_f32(lo3, a_lo, b3);
            lo4 = vmlaq_n_f32(lo4, a_lo, b4);
            lo5 = vmlaq_n_f32(lo5, a_lo, b5);
            lo6 = vmlaq_n_f32(lo6, a_lo, b6);
            lo7 = vmlaq_n_f32(lo7, a_lo, b7);

            hi0 = vmlaq_n_f32(hi0, a_hi, b0);
            hi1 = vmlaq_n_f32(hi1, a_hi, b1);
            hi2 = vmlaq_n_f32(hi2, a_hi, b2);
            hi3 = vmlaq_n_f32(hi3, a_hi, b3);
            hi4 = vmlaq_n_f32(hi4, a_hi, b4);
            hi5 = vmlaq_n_f32(hi5, a_hi, b5);
            hi6 = vmlaq_n_f32(hi6, a_hi, b6);
            hi7 = vmlaq_n_f32(hi7, a_hi, b7);
         }

         vst1q_f32(&dst[m + n * ldc], lo0);
         vst1q_f32(&dst[m + (n + 1) * ldc], lo1);
         vst1q_f32(&dst[m + (n + 2) * ldc], lo2);
         vst1q_f32(&dst[m + (n + 3) * ldc], lo3);
         vst1q_f32(&dst[m + (n + 4) * ldc], lo4);
         vst1q_f32(&dst[m + (n + 5) * ldc], lo5);
         vst1q_f32(&dst[m + (n + 6) * ldc], lo6);
         vst1q_f32(&dst[m + (n + 7) * ldc], lo7);
         vst1q_f32(&dst[m + 4 + n * ldc], hi0);
         vst1q_f32(&dst[m + 4 + (n + 1) * ldc], hi1);
         vst1q_f32(&dst[m + 4 + (n + 2) * ldc], hi2);
         vst1q_f32(&dst[m + 4 + (n + 3) * ldc], hi3);
         vst1q_f32(&dst[m + 4 + (n + 4) * ldc], hi4);
         vst1q_f32(&dst[m + 4 + (n + 5) * ldc], hi5);
         vst1q_f32(&dst[m + 4 + (n + 6) * ldc], hi6);
         vst1q_f32(&dst[m + 4 + (n + 7) * ldc], hi7);
      }

      // Row tail (M not a multiple of 8): scalar, still eight columns at once.
      for (; m < M; m++)
      {
         float sum[8] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
         for (int k = 0; k < K; k++)
         {
            const float a_val = src_a[m + k * lda];
            for (int c = 0; c < 8; c++)
            {
               sum[c] = sum[c] + a_val * src_b[k + (n + c) * ldb];
            }
         }
         for (int c = 0; c < 8; c++)
         {
            dst[m + (n + c) * ldc] = sum[c];
         }
      }
   }

   // Column tail (N not a multiple of 8): one column at a time.
   for (; n < N; n++)
   {
      for (m = 0; m <= M - 8; m += 8)
      {
         float32x4_t acc_lo = vdupq_n_f32(0.0f);
         float32x4_t acc_hi = vdupq_n_f32(0.0f);
         for (int k = 0; k < K; k++)
         {
            const float32x4_t a_lo = vld1q_f32(&src_a[m + k * lda]);
            const float32x4_t a_hi = vld1q_f32(&src_a[m + 4 + k * lda]);
            const float b_val = src_b[k + n * ldb];
            acc_lo = vmlaq_n_f32(acc_lo, a_lo, b_val);
            acc_hi = vmlaq_n_f32(acc_hi, a_hi, b_val);
         }
         vst1q_f32(&dst[m + n * ldc], acc_lo);
         vst1q_f32(&dst[m + 4 + n * ldc], acc_hi);
      }

      // Strictly less than M: row index M is one past the last valid row.
      for (; m < M; m++)
      {
         float sum = 0.0f;
         for (int k = 0; k < K; k++)
         {
            sum += src_a[m + k * lda] * src_b[k + n * ldb];
         }
         dst[m + n * ldc] = sum;
      }
   }
}