# include"lib/Multi.h"
#include <math.h>
/*
 * Fill the global input matrices with pseudo-random values and clear the
 * result buffers.
 *
 * mat1 receives K*M values and mat2 receives N*K values, each computed as
 * (float)rand() / RAND_MAX, written in ascending linear index order.
 * out_mat and test_mat (N*M elements each) are set to zero.
 */
void mat_init()
{
   const int mat1_len = K * M;
   const int mat2_len = N * K;
   const int out_len = N * M;

   /* mat1 first, then mat2: the rand() call sequence depends on this order. */
   for(int idx = 0; idx < mat1_len; idx++)
   {
      mat1[idx] = (float)rand() / RAND_MAX;
   }

   for(int idx = 0; idx < mat2_len; idx++)
   {
      mat2[idx] = (float)rand() / RAND_MAX;
   }

   for(int idx = 0; idx < out_len; idx++)
   {
      out_mat[idx] = 0;
      test_mat[idx] = 0;
   }
}

/*
 * Print a labelled matrix to stdout.
 *
 * mat_in   - buffer holding the matrix; element (r, c) is read from
 *            mat_in[c * Row + r]
 * Row      - number of output lines (matrix rows)
 * Col      - number of values printed per line (matrix columns)
 * mat_name - label printed before the matrix, followed by ": "
 *
 * A blank line is printed after the last row.
 */
void print_mat_(float *mat_in, int Row, int Col, char *mat_name)
{
   printf("%s: \n", mat_name);

   int r = 0;
   while(r < Row)
   {
      int c = 0;
      while(c < Col)
      {
         printf("%f ", mat_in[c * Row + r]);
         c++;
      }
      printf("\n");
      r++;
   }

   printf("\n");
}

void mat1_L2_norm(float *mat, int Raw, int Col)
{
   float mean = 0.0;
   double variance = 0.0;
   for(int i = 0; i < Col * Raw; i++)
   {
      mean += mat[i];
   }
   mean = mean / (Col * Raw);
   for(int i = 0; i < Col * Raw; i++)
   {
      variance += pow(mat[i] - mean, 2);
   }   
   variance = variance / (Col * Raw);
   for(int i = 0; i < Col * Raw; i++)
   {
      mat[i] = (mat[i] - mean) / variance;
   }
}

/*
 * Normalize a matrix in place: subtract the mean of all elements and divide
 * by their mean absolute deviation.
 *
 * mat - buffer of Raw * Col floats, treated as one flat array
 * Raw - number of rows
 * Col - number of columns
 *
 * The absolute deviation is taken with fabs(); the integer abs() would
 * truncate each deviation to an int before summing. An empty matrix is left
 * untouched. When every element is identical the deviation is zero; the data
 * is then only centered (all zeros) instead of being divided by zero.
 */
void mat1_L1_norm(float *mat, int Raw, int Col)
{
   const long long count = (long long)Raw * Col;
   if(!mat || count <= 0)
   {
      return;
   }

   /* Accumulate in double: a float running sum loses precision on large matrices. */
   double mean = 0.0;
   for(long long i = 0; i < count; i++)
   {
      mean += mat[i];
   }
   mean /= (double)count;

   double mean_abs_dev = 0.0;
   for(long long i = 0; i < count; i++)
   {
      mean_abs_dev += fabs(mat[i] - mean);
   }
   mean_abs_dev /= (double)count;

   for(long long i = 0; i < count; i++)
   {
      const double centered = mat[i] - mean;
      mat[i] = (float)(mean_abs_dev > 0.0 ? centered / mean_abs_dev : centered);
   }
}

/*
 * Split the global inputs into the halves consumed by the worker threads.
 *
 * mat1 is read as K runs of M elements. Within each run the first M/2
 * elements go to mat1_0 and the remaining ones to mat1_1, both packed with a
 * stride of M/2.
 *
 * mat2 is read as N runs of K elements. The first N/2 runs go to mat2_0 and
 * the remaining runs to mat2_1, each run copied whole.
 */
void mat_seg()
{
   const int half_m = M / 2;
   const int half_n = N / 2;

   for(int run = 0; run < K; run++)
   {
      const float *src = mat1 + run * M;
      float *first_half = mat1_0 + run * half_m;
      float *second_half = mat1_1 + run * half_m;

      for(int pos = 0; pos < half_m; pos++)
      {
         first_half[pos] = src[pos];
      }
      for(int pos = half_m; pos < M; pos++)
      {
         second_half[pos - half_m] = src[pos];
      }
   }

   for(int run = 0; run < N; run++)
   {
      const float *src = mat2 + run * K;
      float *dst;

      if(run < half_n)
      {
         dst = mat2_0 + run * K;
      }
      else
      {
         dst = mat2_1 + (run - half_n) * K;
      }

      for(int pos = 0; pos < K; pos++)
      {
         dst[pos] = src[pos];
      }
   }
}
/*
 * Assemble out_mat from the four partial results out_mat0..out_mat3.
 *
 * out_mat is addressed as out_mat[j + i * M] with i in [0, N) and j in
 * [0, M). The source buffer is chosen by which half i and j fall in:
 *
 *    i <  N/2, j <  M/2  ->  out_mat0
 *    i <  N/2, j >= M/2  ->  out_mat1
 *    i >= N/2, j <  M/2  ->  out_mat2
 *    i >= N/2, j >= M/2  ->  out_mat3
 *
 * Inside a source buffer the element is read at
 * (j relative to its half) + (i relative to its half) * M / 2.
 */
void mat_mont()
{
   const int half_m = M / 2;
   const int half_n = N / 2;

   for(int i = 0; i < N; i++)
   {
      const int upper_i = (i >= half_n);
      const int local_i = upper_i ? i - half_n : i;
      /* Same evaluation order as "x * M / 2": multiply first, then halve. */
      const int base = local_i * M / 2;

      for(int j = 0; j < M; j++)
      {
         const int upper_j = (j >= half_m);
         const int local_j = upper_j ? j - half_m : j;
         const float *part;

         if(upper_i)
         {
            part = upper_j ? out_mat3 : out_mat2;
         }
         else
         {
            part = upper_j ? out_mat1 : out_mat0;
         }

         out_mat[j + i * M] = part[local_j + base];
      }
   }
}


/*
 * Argument bundle for one optim_v1_col_major_sgemm() call, so the call can be
 * handed to a thread through a single void pointer. run_sgemm() forwards the
 * fields to that function in declaration order.
 *
 * NOTE(review): field names follow the usual BLAS sgemm parameter list
 * (transpose flags, dimensions, scalars, matrices with leading dimensions) -
 * confirm the exact meaning against optim_v1_col_major_sgemm's declaration.
 */
struct sgemm_args {
    char transa;   /* Multi_Thread() passes 'n' */
    char transb;   /* Multi_Thread() passes 'n' */
    int m;
    int n;
    int k;
    float alpha;
    float *a;
    int lda;
    float *b;
    int ldb;
    float beta;
    float *c;
    int ldc;
};

// 包装函数，用于调用 sgemm 函数
void *run_sgemm(void *args) 
{
    struct sgemm_args *sgemm_args = (struct sgemm_args *)args;
    optim_v1_col_major_sgemm(
        sgemm_args->transa, sgemm_args->transb, 
        sgemm_args->m, sgemm_args->n, sgemm_args->k, 
        sgemm_args->alpha, sgemm_args->a, sgemm_args->lda, 
        sgemm_args->b, sgemm_args->ldb, 
        sgemm_args->beta, sgemm_args->c, sgemm_args->ldc
    );
    pthread_exit(NULL);
}

void Multi_Thread()
{
   pthread_t threads[4];
   int rc;
   struct sgemm_args args[4];       //避免多线程多次赋值问题
   for (int t = 0; t < 4; t++)
   {
      args[t].transa = 'n';
      args[t].transb = 'n';
      args[t].m = M / 2;
      args[t].n = N / 2;
      args[t].k = K;
      args[t].alpha = 1.0;
      args[t].lda = M / 2;
      args[t].ldb = K;
      args[t].beta = 0.0;
      args[t].ldc = M / 2;
      switch (t)
      {
      case 0:
         args[t].a = mat1_0;
         args[t].b = mat2_0;
         args[t].c = out_mat0;
         break;
      case 1:
         args[t].a = mat1_1;
         args[t].b = mat2_0;
         args[t].c = out_mat1;     
         break; 
      case 2:
         args[t].a = mat1_0;
         args[t].b = mat2_1;
         args[t].c = out_mat2;    
         break;
      case 3:
         args[t].a = mat1_1;
         args[t].b = mat2_1;
         args[t].c = out_mat3;
         break;
      default:
         break;
      }
      rc = pthread_create(&threads[t], NULL, run_sgemm, (void*)&args[t]);
      if (rc) 
      {
         printf("ERROR: return code from pthread_create() is %d\n", rc);
         exit(-1);
      }
   }
   pthread_join(threads[0], NULL); // 等待线程结束
   pthread_join(threads[1], NULL); // 等待线程结束
   pthread_join(threads[2], NULL); // 等待线程结束
   pthread_join(threads[3], NULL); // 等待线程结束
}

/*
 * Run the full pipeline `times` times: normalize mat1, split the inputs,
 * multiply on the worker threads, and stitch the result into out_mat.
 */
static void run_pipeline(int times)
{
   for(int pass = 0; pass < times; pass++)
   {
      mat1_L2_norm(mat1, M, K);
      mat_seg();
      Multi_Thread();
      mat_mont();
   }
}

/*
 * Benchmark the threaded pipeline and print its throughput.
 *
 * warm_times  - untimed passes executed first
 * infer_times - timed passes
 *
 * Wall-clock time around the timed passes is taken with gettimeofday() and
 * stored in the globals start, end, seconds, microseconds and cpu_time_used.
 * The reported figure counts 2 * M * N * K floating point operations per
 * pass; the timed region covers the whole pipeline, not only the
 * multiplication.
 */
void GFLOPS_Compute(int warm_times, int infer_times)
{
   run_pipeline(warm_times);

   gettimeofday(&start, NULL);
   run_pipeline(infer_times);
   gettimeofday(&end, NULL);

   seconds = end.tv_sec - start.tv_sec;
   microseconds = end.tv_usec - start.tv_usec;
   cpu_time_used = seconds + microseconds / 1000000.0;

   const double gflop_per_pass = 2.0 * M * N * K * 1.0e-09;
   const double gflops = gflop_per_pass * infer_times / cpu_time_used;

   printf("optim Gflops: %f \ncpu_time_used: %f\n", gflops, cpu_time_used);
}

/*
 * Check out_mat against a plain triple-loop reference product.
 *
 * precision - largest relative error that still counts as a pass
 *
 * For every (i, j) with i in [0, M) and j in [0, N), the reference value
 *    sum over k in [0, K) of mat1[k * M + i] * mat2[j * K + k]
 * is accumulated into test_mat[j * M + i] and compared with
 * out_mat[j * M + i]. The error of an element is
 * |reference - output| / |reference|; the maximum over all elements is
 * printed together with the precision and "Success" or "Failed".
 *
 * A NaN on either side counts as an infinite error. Without that, every
 * ordered comparison against NaN is false and a NaN result would pass.
 */
void test_matmul(double precision)
{
   double max_error = 0;
   const char *test_flag = "Success";

   /* j outermost so test_mat and out_mat are walked in storage order;
      the maximum does not depend on the visiting order. */
   for(int j = 0; j < N; j++)
   {
      for(int i = 0; i < M; i++)
      {
         float *ref = test_mat + j * M + i;

         *ref = 0;
         for(int k = 0; k < K; k++)
         {
            *ref += *(mat1 + k * M + i) * *(mat2 + j * K + k);
         }

         const float got = *(out_mat + j * M + i);
         if(*ref != got)
         {
            double err = fabs((double)(*ref - got) / *ref);
            if(isnan(err))
            {
               err = INFINITY;
            }
            if(err > max_error)
            {
               // printf("(%d,%d) ", i, j);
               max_error = err;
            }
         }
      }
   }

   if(max_error > precision)
   {
      test_flag = "Failed";
   }
   printf("relative error =  %g\n", max_error);
   printf("Precision = %g\n", precision);
   printf("Test %s!\n", test_flag);
}

/*
 * Print test_mat and then out_mat to stdout.
 *
 * Each matrix is printed as M lines of N values, where the value at line i,
 * position j is read from index j * M + i. A blank line separates the two
 * matrices; none follows the second one.
 */
void print_mat()
{
   const float *matrices[2] = { test_mat, out_mat };
   const char *headers[2] = { "test_mat: \n", "out_mat: \n" };

   for(int which = 0; which < 2; which++)
   {
      const float *mat = matrices[which];

      printf("%s", headers[which]);
      for(int row = 0; row < M; row++)
      {
         for(int col = 0; col < N; col++)
         {
            printf("%f ", mat[col * M + row]);
         }
         printf("\n");
      }

      /* Only the first matrix is followed by a separating blank line. */
      if(which == 0)
      {
         printf("\n");
      }
   }
}


/*
 * Benchmark driver: initialize the matrices, then run 10 warm-up passes and
 * 20 timed passes of the threaded pipeline.
 *
 * Optional debugging steps, disabled by default:
 *    print_mat_(out_mat0 .. out_mat3, M / 2, N / 2, "<name>")  dump the partial results
 *    test_matmul(5e-5)                                         verify out_mat against the reference
 *    print_mat()                                               dump test_mat and out_mat
 */
int main(void)
{
   const int warmup_passes = 10;
   const int timed_passes = 20;

   mat_init();
   GFLOPS_Compute(warmup_passes, timed_passes);

   return 0;
}