#ifndef __UTIL_H__
#define __UTIL_H__
#include <stdio.h> 
#include <stdlib.h>
#include "arm_neon.h"
#endif

void naive_col_major_sgemm(
      char transa,                                          //a是否转置
      char transb,                                          //b是否转置
      int M, int N, int K,                                  //input size
      const float alpha,                                    //偏置
      const float * src_a, int lda,                         //输入a指针 步长
      const float * src_b, int ldb,                         //输入b指针 步长
      const float beta,                                     //偏置
      float * dst, int ldc);
//利用128位定长寄存器对mat1四个一组计算
void neom_v1_col_major_sgemm(
      char transa,                                          //a是否转置
      char transb,                                          //b是否转置
      int M, int N, int K,                                  //input size
      const float alpha,                                    //偏置
      const float * src_a, int lda,                         //输入a指针 步长
      const float * src_b, int ldb,                         //输入b指针 步长
      const float beta,                                     //偏置
      float * dst, int ldc);
//对mat2同时4列访问
void neom_v2_col_major_sgemm(
      char transa,                                          //a是否转置
      char transb,                                          //b是否转置
      int M, int N, int K,                                  //input size
      const float alpha,                                    //偏置
      const float * src_a, int lda,                         //输入a指针 步长
      const float * src_b, int ldb,                         //输入b指针 步长
      const float beta,                                     //偏置
      float * dst, int ldc);
//对k的循环四个四个进行展开
void neom_v3_col_major_sgemm(
      char transa,                                          //a是否转置
      char transb,                                          //b是否转置
      int M, int N, int K,                                  //input size
      const float alpha,                                    //偏置
      const float * src_a, int lda,                         //输入a指针 步长
      const float * src_b, int ldb,                         //输入b指针 步长
      const float beta,                                     //偏置
      float * dst, int ldc);

//4*4扩展为8*8
void neom_v4_col_major_sgemm(
      char transa,                                          //a是否转置
      char transb,                                          //b是否转置
      int M, int N, int K,                                  //input size
      const float alpha,                                    //偏置
      const float * src_a, int lda,                         //输入a指针 步长
      const float * src_b, int ldb,                         //输入b指针 步长
      const float beta,                                     //偏置
      float * dst, int ldc);

//packing
// #define T 2048
// float packedA[T * T];
void optim_v1_col_major_sgemm(
      char transa,                                          //a是否转置
      char transb,                                          //b是否转置
      int M, int N, int K,                                  //input size
      const float alpha,                                    //偏置
      const float * src_a, int lda,                         //输入a指针 步长
      const float * src_b, int ldb,                         //输入b指针 步长
      const float beta,                                     //偏置
      float * dst, int ldc);

//uint8类型矩阵乘
void optim_uint8_col_major_sgemm(
    char transa,                                          //a是否转置
    char transb,                                          //b是否转置
    int M, int N, int K,                                  //input size
    const float alpha,                                    //偏置
    const uint8_t * src_a, int lda,                         //输入a指针 步长
    const uint8_t * src_b, int ldb,                         //输入b指针 步长
    const float beta,                                     //偏置
    uint16_t * dst, int ldc); 

//量化优化方法
void optim_QMat_col_major_sgemm(
    char transa,                                          //a是否转置
    char transb,                                          //b是否转置
    int M, int N, int K,                                  //input size
    const float alpha,                                    //偏置
    const uint8_t * src_a, int lda,                         //输入a指针 步长
    const uint8_t * src_b, int ldb,                         //输入b指针 步长
    const float beta,                                     //偏置
    uint16_t * dst, int ldc, int * QShift); 
