# include"lib/uint8.h"
/*
 * Fill an `outer` x `inner` block of floats (element (o, n) lives at
 * mat[o * inner + n]) with pseudo-random values in [0, 1] and fold every
 * generated value into the running extrema *max_out / *min_out.
 *
 * The stride used for addressing is the same `inner` that bounds the inner
 * loop, so the write index and the index used for the extrema cannot diverge.
 */
static void mat_fill_random_unit(float *mat, int outer, int inner,
                                 float *max_out, float *min_out)
{
   for(int o = 0; o < outer; o++)
   {
      for(int n = 0; n < inner; n++)
      {
         const float value = (float)rand() / RAND_MAX;
         mat[o * inner + n] = value;
         if(value > *max_out)
            *max_out = value;
         else if(value < *min_out)
            *min_out = value;
      }
   }
}

/*
 * Initialise the global float matrices ori_mat1 (FK * FM elements) and
 * ori_mat2 (FN * FK elements) with rand()-based values in [0, 1], and report
 * the extrema seen for each matrix through the four output pointers.
 *
 * max_mat1 / min_mat1: receive the extrema of ori_mat1.
 * max_mat2 / min_mat2: receive the extrema of ori_mat2.
 *
 * All four outputs start at 0.0. Since rand() never returns a negative
 * number, the reported minima therefore always stay at 0.0.
 *
 * Fix: the extrema of ori_mat2 used to be read back with stride FM although
 * the matrix is written with stride FK, so they were taken from the wrong
 * (possibly not yet written or out-of-range) elements.
 */
void mat_init(float *max_mat1, float *min_mat1, float *max_mat2, float *min_mat2)
{
   *max_mat1 = 0.0;
   *min_mat1 = 0.0;
   *max_mat2 = 0.0;
   *min_mat2 = 0.0;

   mat_fill_random_unit(ori_mat1, FK, FM, max_mat1, min_mat1);
   mat_fill_random_unit(ori_mat2, FN, FK, max_mat2, min_mat2);
}

/*
 * Fill a Row x Col uint8 matrix with pseudo-random bytes taken from rand().
 *
 * mat_in: destination buffer holding at least Row * Col elements.
 * Row, Col: matrix dimensions; nothing is written if either is <= 0.
 *
 * Every index in [0, Row * Col) is written exactly once. The previous
 * indexing (i * Row + j) used the wrong stride: it left some elements
 * untouched and wrote past Row * Col whenever Row > Col.
 */
void mat_assign_u8(uint8_t *mat_in, int Row, int Col)
{
   for(int i = 0; i < Row; i++)
   {
      for(int j = 0; j < Col; j++)
      {
         mat_in[i * Col + j] = (uint8_t)rand();
      }
   }
}


/*
 * Map one float onto 0..255 using the affine range [min_val, max_val].
 *
 * - A zero-width range (max_val == min_val) yields 0 instead of dividing by zero.
 * - Results are clamped to [0, 255]; NaN also maps to 0. Converting an
 *   out-of-range or NaN floating value to uint8_t would be undefined behaviour.
 * - For in-range values the result equals round-half-away-from-zero: the
 *   scaled value is non-negative, so adding 0.5 in double precision and
 *   truncating gives the same integer.
 */
static uint8_t quantize_to_u8(float value, float min_val, float max_val)
{
   const float range = max_val - min_val;
   if(range == 0.0f)
      return 0;

   const float scaled = (value - min_val) / range * 255;
   if(!(scaled > 0.0f))      /* negative, zero, or NaN */
      return 0;
   if(scaled >= 255.0f)
      return 255;
   return (uint8_t)((double)scaled + 0.5);
}

/* Quantize rows * cols consecutive floats (element (r, c) at c * rows + r). */
static void quantize_matrix(const float *mat_in, uint8_t *mat_out,
                            int rows, int cols, float min_val, float max_val)
{
   for(int c = 0; c < cols; c++)
   {
      for(int r = 0; r < rows; r++)
      {
         mat_out[c * rows + r] = quantize_to_u8(mat_in[c * rows + r], min_val, max_val);
      }
   }
}

/*
 * Quantize two float matrices to uint8 with a per-matrix linear mapping
 *    q = round((x - min) / (max - min) * 255)
 *
 * mat_in1 / mat_out1: source and destination of Raw1 * Col1 elements.
 * mat_in2 / mat_out2: source and destination of Raw2 * Col2 elements.
 * max_mat1, min_mat1, max_mat2, min_mat2: range of each matrix; they are read
 *    once on entry and never written.
 *
 * Elements are addressed as [col * Raw + row]. Degenerate ranges and
 * out-of-range inputs are handled as described on quantize_to_u8.
 */
void quantization(float *mat_in1, float *mat_in2, 
                  uint8_t *mat_out1, uint8_t *mat_out2, 
                  int Raw1, int Col1, int Raw2, int Col2,
                  float *max_mat1, float *min_mat1, float *max_mat2, float *min_mat2)
{
   quantize_matrix(mat_in1, mat_out1, Raw1, Col1, *min_mat1, *max_mat1);
   quantize_matrix(mat_in2, mat_out2, Raw2, Col2, *min_mat2, *max_mat2);
}

/*
 * Split two uint8 matrices into halves.
 *
 * u8_mat1 (Raw1 rows, Col1 columns, element (r, c) at c * Raw1 + r) is cut
 * along its rows: rows [0, Raw1/2) of every column go to u8_mat1_0 and the
 * remaining rows go to u8_mat1_1. Both halves are stored with a column stride
 * of Raw1 / 2.
 *
 * u8_mat2 (Raw2 rows, Col2 columns) is cut along its columns: columns
 * [0, Col2/2) are copied to u8_mat2_0 and the remaining columns to u8_mat2_1,
 * each keeping the column stride Raw2.
 *
 * NOTE(review): the halves only tile cleanly when Raw1 and Col2 are even;
 * with an odd Raw1 the lower half has one row more than its stride.
 */
void mat_seg(uint8_t *u8_mat1, uint8_t *u8_mat2, 
            uint8_t *u8_mat1_0, uint8_t *u8_mat1_1, 
            uint8_t *u8_mat2_0, uint8_t *u8_mat2_1, 
            int Raw1, int Col1, int Raw2, int Col2)
{
   const int half_rows1 = Raw1 / 2;
   const int half_cols2 = Col2 / 2;

   /* First matrix: upper rows, then lower rows, column by column. */
   for(int col = 0; col < Col1; col++)
   {
      int row = 0;
      for(; row < half_rows1; row++)
      {
         u8_mat1_0[row + col * half_rows1] = u8_mat1[row + col * Raw1];
      }
      for(; row < Raw1; row++)
      {
         u8_mat1_1[(row - half_rows1) + col * half_rows1] = u8_mat1[row + col * Raw1];
      }
   }

   /* Second matrix: left block of columns first, then the right block. */
   int col2 = 0;
   for(; col2 < half_cols2; col2++)
   {
      for(int row = 0; row < Raw2; row++)
      {
         u8_mat2_0[row + col2 * Raw2] = u8_mat2[row + col2 * Raw2];
      }
   }
   for(; col2 < Col2; col2++)
   {
      for(int row = 0; row < Raw2; row++)
      {
         u8_mat2_1[row + (col2 - half_cols2) * Raw2] = u8_mat2[row + col2 * Raw2];
      }
   }
}
/*
 * Assemble four quadrant matrices into one Raw x Col uint16 matrix
 * (element (r, c) stored at c * Raw + r).
 *
 *    columns <  Col/2, rows <  Raw/2  <- u16_out_mat0
 *    columns <  Col/2, rows >= Raw/2  <- u16_out_mat1
 *    columns >= Col/2, rows <  Raw/2  <- u16_out_mat2
 *    columns >= Col/2, rows >= Raw/2  <- u16_out_mat3
 *
 * Inside a quadrant the element is read at
 *    local_row + (local_col * Raw) / 2
 * which equals a column stride of Raw / 2 only when Raw is even.
 * NOTE(review): odd Raw/Col are presumably unsupported - confirm with callers.
 */
void mat_mont(uint16_t *u16_out_mat,
            uint16_t *u16_out_mat0, uint16_t *u16_out_mat1, 
            uint16_t *u16_out_mat2, uint16_t *u16_out_mat3, 
            int Raw, int Col)
{
   const int half_rows = Raw / 2;
   const int half_cols = Col / 2;

   for(int col = 0; col < Col; col++)
   {
      const int right = (col >= half_cols);
      const int local_col = right ? col - half_cols : col;

      for(int row = 0; row < Raw; row++)
      {
         const int bottom = (row >= half_rows);
         const int local_row = bottom ? row - half_rows : row;
         const uint16_t *quadrant = right ? (bottom ? u16_out_mat3 : u16_out_mat2)
                                          : (bottom ? u16_out_mat1 : u16_out_mat0);

         u16_out_mat[row + col * Raw] = quadrant[local_row + (local_col * Raw) / 2];
      }
   }
}

// Thread entry point: unpacks the struct sgemm_args passed through `args`,
// forwards every field to optim_uint8_col_major_sgemm, then ends the calling
// thread with pthread_exit(NULL).
void *run_sgemm(void *args) 
{
    struct sgemm_args *job = args;

    optim_uint8_col_major_sgemm(job->transa, job->transb,
                                job->m, job->n, job->k,
                                job->alpha,
                                job->a, job->lda,
                                job->b, job->ldb,
                                job->beta,
                                job->c, job->ldc);

    pthread_exit(NULL);
}

/*
 * Run the four half-size block products on four threads and wait for them.
 *
 * Every worker calls run_sgemm with m = mat_M / 2, n = mat_N / 2, k = mat_K,
 * lda = ldc = mat_M / 2, ldb = mat_K, alpha = 1, beta = 0 and no transposition.
 * The operand pairing is:
 *    worker 0: t_mat1_0 * t_mat2_0 -> t_out_mat0
 *    worker 1: t_mat1_1 * t_mat2_0 -> t_out_mat1
 *    worker 2: t_mat1_0 * t_mat2_1 -> t_out_mat2
 *    worker 3: t_mat1_1 * t_mat2_1 -> t_out_mat3
 *
 * Each thread is created with a 512 MiB stack. The process exits with -1 if
 * the thread attributes cannot be initialised or a thread cannot be created.
 * A failure to set the stack size or to join a thread is only reported.
 *
 * Compared with the previous version, the attribute object is initialised
 * once (not once per thread) and released with pthread_attr_destroy, and the
 * return codes of the pthread calls are no longer ignored.
 */
void Multi_Thread(int mat_M, int mat_K, int mat_N,
                  uint8_t *t_mat1_0, uint8_t *t_mat1_1, uint8_t *t_mat2_0, uint8_t *t_mat2_1,
                  uint16_t *t_out_mat0, uint16_t *t_out_mat1, uint16_t *t_out_mat2, uint16_t *t_out_mat3)
{
   enum { WORKER_COUNT = 4 };

   /* Per-worker operands; index t selects the block product listed above. */
   uint8_t *const lhs[WORKER_COUNT] = { t_mat1_0, t_mat1_1, t_mat1_0, t_mat1_1 };
   uint8_t *const rhs[WORKER_COUNT] = { t_mat2_0, t_mat2_0, t_mat2_1, t_mat2_1 };
   uint16_t *const dst[WORKER_COUNT] = { t_out_mat0, t_out_mat1, t_out_mat2, t_out_mat3 };

   pthread_t threads[WORKER_COUNT];
   /* One argument record per thread so no worker sees another worker's values. */
   struct sgemm_args args[WORKER_COUNT];
   int rc;

   /* Thread attributes are shared by all workers: only the stack size is set. */
   pthread_attr_t attr;
   const size_t stacksize = 512 * 1024 * 1024;
   rc = pthread_attr_init(&attr);
   if (rc)
   {
      printf("ERROR: return code from pthread_attr_init() is %d\n", rc);
      exit(-1);
   }
   rc = pthread_attr_setstacksize(&attr, stacksize);
   if (rc)
   {
      /* Best effort: continue with the default stack size. */
      printf("WARNING: return code from pthread_attr_setstacksize() is %d\n", rc);
   }

   for (int t = 0; t < WORKER_COUNT; t++)
   {
      args[t].transa = 'n';
      args[t].transb = 'n';
      args[t].m = mat_M / 2;
      args[t].n = mat_N / 2;
      args[t].k = mat_K;
      args[t].alpha = 1.0;
      args[t].a = lhs[t];
      args[t].lda = mat_M / 2;
      args[t].b = rhs[t];
      args[t].ldb = mat_K;
      args[t].beta = 0.0;
      args[t].c = dst[t];
      args[t].ldc = mat_M / 2;

      rc = pthread_create(&threads[t], &attr, run_sgemm, (void*)&args[t]);
      if (rc) 
      {
         printf("ERROR: return code from pthread_create() is %d\n", rc);
         exit(-1);
      }
   }

   /* The attribute object is no longer needed once all threads exist. */
   pthread_attr_destroy(&attr);

   /* Wait for every worker before the caller reads the output blocks. */
   for (int t = 0; t < WORKER_COUNT; t++)
   {
      rc = pthread_join(threads[t], NULL);
      if (rc)
      {
         printf("ERROR: return code from pthread_join() is %d\n", rc);
      }
   }
}


/*
 * Quantized matrix product carried out as four half-size block products.
 *
 * Steps, in order:
 *   1. mat_seg      - split Qmat_in1 (Q_M, Q_K) and Qmat_in2 (Q_K, Q_N) into
 *                     the caller-provided half buffers.
 *   2. Multi_Thread - compute the four block products into Qmat_out0..3.
 *   3. mat_mont     - assemble the four blocks into Qmat_out (Q_M, Q_N).
 *
 * All scratch and output buffers are supplied by the caller; nothing is
 * allocated here.
 */
void QMat_Mul(uint8_t *Qmat_in1, uint8_t *Qmat_in2, int Q_M, int Q_K, int Q_N,
            uint8_t *Qmat_in1_0, uint8_t *Qmat_in1_1, uint8_t *Qmat_in2_0, uint8_t *Qmat_in2_1,
            uint16_t *Qmat_out0, uint16_t *Qmat_out1, uint16_t *Qmat_out2, uint16_t *Qmat_out3,
            uint16_t *Qmat_out)
{
   /* 1. split the operands */
   mat_seg(Qmat_in1, Qmat_in2,
           Qmat_in1_0, Qmat_in1_1,
           Qmat_in2_0, Qmat_in2_1,
           Q_M, Q_K, Q_K, Q_N);

   /* 2. multiply the blocks in parallel */
   Multi_Thread(Q_M, Q_K, Q_N,
                Qmat_in1_0, Qmat_in1_1, Qmat_in2_0, Qmat_in2_1,
                Qmat_out0, Qmat_out1, Qmat_out2, Qmat_out3);

   /* 3. stitch the result together */
   mat_mont(Qmat_out,
            Qmat_out0, Qmat_out1, Qmat_out2, Qmat_out3,
            Q_M, Q_N);
}

void GFLOPS_Compute(int warm_times, int infer_times)
{
   for(int i = 0; i < warm_times; i++)
   {
      quantization(ori_mat1, ori_mat2, matA, matW, 
            FM, FK, FK, FN, &max_mat1, &min_mat1, &max_mat2, &min_mat2);
      QMat_Mul(matA, matW, FM, FK, FN, matA_0, matA_1, matW_0, matW_1,
               matB0, matB1, matB2, matB3, matB);
      QMat_Mul(matGB, matWT, FM, FN, FK, matGB_0, matGB_1, matWT_0, matWT_1,
               matGA0, matGA1, matGA2, matGA3, matGA);
      QMat_Mul(matAT, matGB, FK, FM, FN, matAT_0, matAT_1, matGB_0, matGB_1,
               matGW0, matGW1, matGW2, matGW3, matGW);
      // optim_uint8_col_major_sgemm('n', 'n', FM, FN, FK, 1, matA, FM, matW, FK, 0, matB, FM);
   }
   gettimeofday(&start, NULL);
   for(int i = 0; i < infer_times; i++)
   {
      quantization(ori_mat1, ori_mat2, matA, matW, 
            FM, FK, FK, FN, &max_mat1, &min_mat1, &max_mat2, &min_mat2);
      QMat_Mul(matA, matW, FM, FK, FN, matA_0, matA_1, matW_0, matW_1,
               matB0, matB1, matB2, matB3, matB);
      QMat_Mul(matGB, matWT, FM, FN, FK, matGB_0, matGB_1, matWT_0, matWT_1,
               matGA0, matGA1, matGA2, matGA3, matGA);
      QMat_Mul(matAT, matGB, FK, FM, FN, matAT_0, matAT_1, matGB_0, matGB_1,
               matGW0, matGW1, matGW2, matGW3, matGW);
      // optim_uint8_col_major_sgemm('n', 'n', FM, FN, FK, 1, matA, FM, matW, FK, 0, matB, FM);
   }
   gettimeofday(&end, NULL);
   seconds = end.tv_sec - start.tv_sec;
   microseconds = end.tv_usec - start.tv_usec;
   cpu_time_used = seconds + microseconds / 1000000.0;
   double flops = 3.0 * 2.0 * FM * FN * FK * 1.0e-09;
   flops = flops * infer_times / cpu_time_used;
   printf("optim Gflops: %f \ncpu_time_used: %f\n", flops, cpu_time_used);
}

/*
 * Check a uint16 matrix product against a straightforward reference.
 *
 * precision: largest relative error that still counts as a pass.
 * in_mat1:   left operand,  TM x TK, element (i, k) at k * TM + i.
 * in_mat2:   right operand, TK x TN, element (k, j) at j * TK + k.
 * test_mat:  output buffer of TM * TN elements; receives the reference
 *            product (accumulated in uint16_t, i.e. modulo 65536).
 * out_mat:   product under test, same layout as test_mat.
 *
 * Prints the largest relative error found, the precision, and
 * "Test Success!" / "Test Failed!".
 *
 * Fixes relative to the previous version:
 *  - the error was compared using integer division, so most mismatches
 *    produced a quotient of 0 and were never recorded; it is now computed in
 *    double precision;
 *  - a reference element of 0 no longer causes an integer division by zero;
 *    in that case the absolute difference is used as the error.
 */
void test_matmul(double precision, uint8_t *in_mat1, uint8_t *in_mat2, 
               uint16_t *test_mat, uint16_t *out_mat, int TM, int TK, int TN)
{
   double max_error = 0;
   const char *test_flag = "Success";

   for(int i = 0; i < TM; i++)
   {
      for(int j = 0; j < TN; j++)
      {
         /* Reference dot product, wrapping at 16 bits like the stored result. */
         uint16_t expected = 0;
         for(int k = 0; k < TK; k++)
         {
            expected = (uint16_t)(expected + (unsigned)in_mat1[k * TM + i] * in_mat2[j * TK + k]);
         }
         test_mat[j * TM + i] = expected;

         const uint16_t actual = out_mat[j * TM + i];
         if(expected != actual)
         {
            const double diff = (expected > actual) ? (double)(expected - actual)
                                                    : (double)(actual - expected);
            const double rel_error = (expected != 0) ? diff / expected : diff;
            if(rel_error > max_error)
            {
               max_error = rel_error;
            }
         }
      }
   }

   if(max_error > precision)
   {
      test_flag = "Failed";
   }
   printf("relative error =  %g\n", max_error);
   printf("Precision = %g\n", precision);
   printf("Test %s!\n", test_flag);
}


/*
 * Program entry: initialise the float inputs and the three uint8 operand
 * matrices with random data, then run the benchmark.
 *
 * To verify the products, call test_matmul after the benchmark, e.g.
 *    test_matmul(5e-5, matA,  matW,  test_mat, matB,  FM, FK, FN);
 *    test_matmul(5e-5, matGB, matWT, test_mat, matGA, FM, FN, FK);
 *    test_matmul(5e-5, matAT, matGB, test_mat, matGW, FK, FM, FN);
 */
int main(void)
{
   const int warmup_rounds = 10;
   const int timed_rounds = 20;

   /* Float sources for quantization, plus their value ranges. */
   mat_init(&max_mat1, &min_mat1, &max_mat2, &min_mat2);

   /* Operands that are used directly as uint8 data. */
   mat_assign_u8(matGB, FM, FN);
   mat_assign_u8(matWT, FN, FK);
   mat_assign_u8(matAT, FK, FM);

   GFLOPS_Compute(warmup_rounds, timed_rounds);
   return 0;
}