# include"lib/main.h"
/*
 * Populate the benchmark operands.
 *
 * mat1 (K*M values) and mat2 (N*K values) receive pseudo-random floats in
 * [0, 1] produced by rand() / RAND_MAX; out_mat and test_mat (N*M values
 * each) are cleared to 0. The indexing matches the column-major layout used
 * by the SGEMM call: mat1 is M x K (leading dimension M), mat2 is K x N
 * (leading dimension K), and both outputs are M x N (leading dimension M).
 * This function does not call srand().
 */
void mat_init()
{
   /* A: walk column by column, rows contiguous. */
   for (int col = 0; col < K; ++col)
   {
      for (int row = 0; row < M; ++row)
      {
         mat1[col * M + row] = (float)rand() / RAND_MAX;
      }
   }

   /* B: same walk with leading dimension K. */
   for (int col = 0; col < N; ++col)
   {
      for (int row = 0; row < K; ++row)
      {
         mat2[col * K + row] = (float)rand() / RAND_MAX;
      }
   }

   /* Result and reference buffers start at zero. */
   for (int col = 0; col < N; ++col)
   {
      for (int row = 0; row < M; ++row)
      {
         out_mat[col * M + row] = 0;
         test_mat[col * M + row] = 0;
      }
   }
}

/*
 * Issue one C = 1 * A * B + 0 * C single-precision GEMM on the global
 * operands (column-major: mat1 is M x K, mat2 is K x N, out_mat is M x N).
 * This is the single place to swap the kernel under test, e.g. for
 * optim_v1_col_major_sgemm('n', 'n', M, N, K, 1, mat1, M, mat2, K, 0, out_mat, M).
 */
static void run_sgemm_once(void)
{
   cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, M, N, K,
               1, mat1, M, mat2, K, 0, out_mat, M);
}

/*
 * Time infer_times GEMM calls, issued after warm_times untimed calls, and
 * print the achieved throughput in GFLOP/s plus the elapsed time in seconds.
 *
 * warm_times  - number of untimed calls made before timing starts
 *               (presumably to warm caches / trigger lazy initialization).
 * infer_times - number of timed calls the GFLOP/s figure is averaged over.
 *
 * Side effects: overwrites out_mat and writes the globals start, end,
 * seconds, microseconds and cpu_time_used. Despite its name, cpu_time_used
 * holds wall-clock time measured with gettimeofday().
 */
void GFLOPS_Compute(int warm_times, int infer_times)
{
   for (int i = 0; i < warm_times; i++)
   {
      run_sgemm_once();
   }

   gettimeofday(&start, NULL);
   for (int i = 0; i < infer_times; i++)
   {
      run_sgemm_once();
   }
   gettimeofday(&end, NULL);

   /* A negative tv_usec difference is fine: it is folded into the sum below. */
   seconds = end.tv_sec - start.tv_sec;
   microseconds = end.tv_usec - start.tv_usec;
   cpu_time_used = seconds + microseconds / 1000000.0;

   /* gettimeofday() has microsecond resolution: with no timed runs or a tiny
    * problem the interval can be 0, and the rate below would be inf/NaN. */
   if (infer_times <= 0 || cpu_time_used <= 0)
   {
      printf("optim Gflops: n/a (no measurable elapsed time)\ncpu_time_used: %f\n", cpu_time_used);
      return;
   }

   /* One M x N x K GEMM performs 2*M*N*K floating-point operations; starting
    * the product from 2.0 keeps it in double so it cannot overflow int. */
   double gflop_per_call = 2.0 * M * N * K * 1.0e-09;
   double flops = gflop_per_call * infer_times / cpu_time_used;
   printf("optim Gflops: %f \ncpu_time_used: %f\n", flops, cpu_time_used);
}

/*
 * Check out_mat against a naive triple-loop reference product and report
 * the largest element-wise relative error |out - ref| / |ref|.
 *
 * precision - largest relative error that is still reported as "Success".
 *
 * The reference for element (i, j) uses the same column-major indexing as
 * the SGEMM call: mat1 is M x K (ld M), mat2 is K x N (ld K), and the
 * result is M x N (ld M). It is accumulated in double and then stored into
 * test_mat, which print_mat() displays. Accumulating in float would add
 * rounding error of up to roughly K * FLT_EPSILON to the reference itself.
 * For large K that is comparable to typical tolerances such as 5e-5 and
 * can cause spurious failures.
 *
 * A NaN in the comparison forces "Failed". NaN compares false against any
 * threshold, so without this check it would be silently ignored. An element
 * whose reference is exactly 0 but whose output is not yields an infinite
 * relative error.
 */
void test_matmul(double precision)
{
   double max_error = 0;
   int saw_nan = 0;
   const char *test_flag = "Success";

   for (int i = 0; i < M; i++)
   {
      for (int j = 0; j < N; j++)
      {
         double ref = 0;
         for (int k = 0; k < K; k++)
         {
            ref += (double)*(mat1 + k * M + i) * *(mat2 + j * K + k);
         }
         *(test_mat + j * M + i) = ref;

         double diff = (double)*(out_mat + j * M + i) - ref;
         if (diff == 0)
         {
            continue; /* exact match; also avoids 0/0 when ref == 0 */
         }
         double abs_diff = diff < 0 ? -diff : diff;
         double abs_ref = ref < 0 ? -ref : ref;
         double err = abs_diff / abs_ref;
         if (err != err) /* NaN */
         {
            saw_nan = 1;
         }
         else if (err > max_error)
         {
            max_error = err;
         }
      }
   }

   if (saw_nan || max_error > precision)
   {
      test_flag = "Failed";
   }
   printf("relative error =  %g\n", max_error);
   if (saw_nan)
   {
      printf("NaN encountered in result or reference\n");
   }
   printf("Precision = %g\n", precision);
   printf("Test %s!\n", test_flag);
}

/*
 * Dump test_mat followed by out_mat to stdout, one matrix row per line.
 * Both buffers are read column-major with leading dimension M: element
 * (row, col) lives at index col * M + row, for M rows and N columns.
 * A blank line separates the two matrices; no newline follows the second.
 */
void print_mat()
{
   printf("test_mat: \n");
   for (int row = 0; row < M; ++row)
   {
      for (int col = 0; col < N; ++col)
      {
         printf("%f ", test_mat[col * M + row]);
      }
      putchar('\n');
   }
   putchar('\n');

   printf("out_mat: \n");
   for (int row = 0; row < M; ++row)
   {
      for (int col = 0; col < N; ++col)
      {
         printf("%f ", out_mat[col * M + row]);
      }
      putchar('\n');
   }
}

/*
 * Benchmark entry point: initialize the operands, then time the GEMM kernel.
 *
 * Optional extras:
 *  - Calling goto_set_num_threads(1) or openblas_set_num_threads(1) before
 *    benchmarking presumably pins the BLAS library to one thread.
 *  - test_matmul(5e-5) verifies the result.
 *  - print_mat() dumps the matrices (only sensible for small sizes).
 */
int main()
{
   enum { WARMUP_RUNS = 10, TIMED_RUNS = 20 };

   mat_init();
   GFLOPS_Compute(WARMUP_RUNS, TIMED_RUNS);
   return 0;
}