/* LUT-GEMM
 * Copyright (c) 2024-present NAVER Cloud Corp. All rights reserved.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tests.h"

#include <cstdio>
#include <cstdlib>

#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>

// Produces a pseudo-random int64_t from two successive rand() draws.
// The first draw supplies the upper 32 bits and the second the lower 32 bits.
// rand() returns values in [0, RAND_MAX] with RAND_MAX <= INT_MAX, so each
// half has its top bit clear and the combined result is never negative.
int64_t rand_int64() {
    const int64_t upper_half = static_cast<int64_t>(rand());  // drawn first
    const int64_t lower_half = static_cast<int64_t>(rand());  // drawn second
    const int64_t shifted_upper = upper_half << 32;
    return lower_half | shifted_upper;
}


// Forward declaration of the cuBLAS GemmEx wrapper defined at the end of this
// file. It computes C = A * B for row-major A[m][k], B[k][n] and C[m][n] by
// calling cublasGemmEx in column-major form with the operands swapped.
template <typename T, typename S>
inline cublasStatus_t cublas_gemm_ex(T *A,  T *B,  S *C,
                                    int m, int n, int k);
                                    
/*
 * Latency benchmark for codeGEMM::matmul_quip_sharp.
 *
 * Template parameters:
 *   M            - rows of `input` (and of `output`); the activation batch.
 *   N            - rows of `qidxs`; columns of `output`.
 *   K            - columns of `input`.
 *   num_codebook - not referenced inside this class.
 *   len_vector   - each row of `qidxs` holds K / len_vector int16 indices
 *                  (presumably one index per len_vector-wide weight vector).
 *   A_GROUP_SIZE - not referenced inside this class (defaults to K).
 *
 * The host arrays are sized from the template parameters and can be large,
 * so instances should live on the heap (the tests use std::make_shared)
 * rather than on the stack.
 */
template<int M, int N, int K, int num_codebook, int len_vector, int A_GROUP_SIZE=K>
class int3_col_wise_matmul_fp16{
public:
    // Host-side random data; copied into the managed buffers before timing.
    int16_t     qidxs[N][K/len_vector];
    int64_t    codebook[256];

    float    input[M][K];
    float   output[M][N];   // not used by this class

    // CUDA managed buffers (cudaMallocManaged). Valid between alloc_cuda()
    // and free_cuda(); nullptr otherwise.
    __half*  d_input = nullptr;      // M x K activations, stored as fp16
    int16_t*  d_qidxs = nullptr;     // N x (K / len_vector) indices
    int64_t*  d_codebook = nullptr;  // 256 entries

    float* d_nq_output = nullptr;    // output buffer handed to the kernel

    /*
     * Allocates the managed buffers, fills them with random data, times
     * matmul_quip_sharp for `iter` iterations and releases the buffers.
     *
     * run_cublas, run_lutgemm and run_gptq are accepted for interface
     * compatibility but are currently ignored: only the QuIP# kernel is timed.
     *
     * Returns 0.0; no accuracy check is performed.
     */
    double run(
            bool run_cublas=true, bool run_lutgemm=false, bool run_gptq=false,
            int iter=16){
        (void)run_cublas;
        (void)run_lutgemm;
        (void)run_gptq;

        alloc_cuda();
        makeRandomInput();
        makeRandomWeight();
        makeRandomCodebook();
        copy_cpuToCuda();

        check_cuda(cudaDeviceSynchronize(), "cudaDeviceSynchronize (before benchmark)");

        quip_sharp_latency(M, N, K,
            d_input, d_qidxs, d_codebook, d_nq_output, iter);

        free_cuda();
        return 0;
    }

    /*
     * Calls codeGEMM::matmul_quip_sharp(D, A, B, C, m, n, k) once untimed as
     * a warm-up, then `iter` more times under `timer`. Each timed iteration
     * includes a device-wide synchronize, so the measurement covers kernel
     * completion. Prints min/max/mean as reported by `timer` (the format
     * string labels them as milliseconds).
     *
     * Aborts with a message if the warm-up launch or any synchronize fails.
     */
    void quip_sharp_latency(int m, int n, int k, 
        __half* A, int16_t *B, int64_t *C, float *D, int iter=64){
        
        timer tm;
        codeGEMM::matmul_quip_sharp((void*)D, (void*)A, (void*)B, (void*)C, m, n, k);
        // Surface launch-configuration and execution errors before timing.
        check_cuda(cudaGetLastError(), "matmul_quip_sharp launch (warm-up)");
        check_cuda(cudaDeviceSynchronize(), "matmul_quip_sharp (warm-up)");

        for(int i=0;i<iter;i++){
            tm.start();
            codeGEMM::matmul_quip_sharp((void*)D, (void*)A, (void*)B, (void*)C, m, n, k);
            const cudaError_t sync_err = cudaDeviceSynchronize();
            tm.end();
            // Checked after tm.end() so the check is outside the timed region.
            check_cuda(sync_err, "matmul_quip_sharp (timed)");
        }
        printf("latency min : %.5fms, max : %.5fms, avg:%.5f\n", tm.min(), tm.max(), tm.mean());
    }

    // Fills `input` using rand_fp32().
    void makeRandomInput(){
        for(int m=0;m<M;m++)
            for(int k=0;k<K;k++)
                input[m][k] = rand_fp32(); // (-1.0, 1.0) / 2^b
                //input[m][k] = 0.01; // (-1.0, 1.0) / 2^b
    }

    // Fills all 256 codebook entries with rand_int64() shifted right by one.
    void makeRandomCodebook(){
        for(int i=0;i<256;i++)
            codebook[i] = rand_int64()>>1;
    }

    // Fills `qidxs` with random indices. The int expression is narrowed to
    // int16_t. On common compilers (and by definition since C++20) only its
    // low 16 bits are kept.
    void makeRandomWeight(){
        for(int n=0; n<N; n++){
            for(int k=0; k<K/len_vector; k++){ 
                qidxs[n][k] = (rand() >> 2) - 0x8000;
            }
        }
    }

    // Allocates the managed buffers; aborts with a message on failure.
    void alloc_cuda(){
        // d_input holds __half values, so size it with sizeof(__half).
        check_cuda(cudaMallocManaged(&d_input, sizeof(__half) * M * K),
                   "cudaMallocManaged(d_input)");
        check_cuda(cudaMallocManaged(&d_qidxs, sizeof(int16_t) * N * K / len_vector),
                   "cudaMallocManaged(d_qidxs)");
        check_cuda(cudaMallocManaged(&d_codebook, sizeof(int64_t) * 256),
                   "cudaMallocManaged(d_codebook)");
        check_cuda(cudaMallocManaged(&d_nq_output, sizeof(float) * M * N),
                   "cudaMallocManaged(d_nq_output)");
    }
    
    // Releases the managed buffers and resets the pointers, so a repeated
    // call is harmless (cudaFree(nullptr) is a no-op).
    void free_cuda(){
        cudaFree(d_input);
        cudaFree(d_qidxs);
        cudaFree(d_codebook);
        cudaFree(d_nq_output);
        d_input = nullptr;
        d_qidxs = nullptr;
        d_codebook = nullptr;
        d_nq_output = nullptr;
    }

    // Copies the host arrays into the managed buffers from the host side,
    // converting `input` from float to __half.
    void copy_cpuToCuda(){
        
        fhCpy(d_input , (float*)input , M * K);
        aCpy(d_qidxs, (int16_t*)qidxs, N * K / len_vector);
        bCpy(d_codebook, (int64_t*)codebook, 256);

        check_cuda(cudaDeviceSynchronize(), "cudaDeviceSynchronize (after copy)");
    }

    // Element-wise float -> __half copy of `size` elements.
    void fhCpy(__half* a, float* b, int size){
        for(int i=0;i<size;i++) a[i] = __float2half(b[i]);
     }

    // Element-wise float copy of `size` elements.
    void ffCpy(float* a, float* b, int size){
        for(int i=0;i<size;i++) a[i] = (b[i]);
     }
     // Element-wise int16_t copy of `size` elements.
     void aCpy(int16_t* a, int16_t* b, int size){
        for(int i=0;i<size;i++) a[i] = (b[i]);
     }
     // Element-wise int64_t copy of `size` elements.
     void bCpy(int64_t* a, int64_t* b, int size){
        for(int i=0;i<size;i++) a[i] = (b[i]);
     }

private:
    // Prints the failing operation and the CUDA error string, then aborts.
    // Continuing after a failed allocation or launch would dereference
    // invalid pointers or report meaningless timings.
    static void check_cuda(cudaError_t err, const char* what){
        if(err != cudaSuccess){
            fprintf(stderr, "CUDA error in %s: %s\n", what, cudaGetErrorString(err));
            std::abort();
        }
    }

};
// Layer shapes used by the benchmarks below; keep exactly one preset active.
// H, A and I appear to be the hidden size, the fused Q/K/V projection width
// and the MLP intermediate size of the named model (inferred from the values
// — TODO confirm).

// Llama-3.1-8B:  H = 4096, A = 6144,  I = 14336
// constexpr int H = 4096;
// constexpr int A = 6144;
// constexpr int I = 14336;

// Llama-3.2-1B:  H = 2048, A = 3072,  I = 8192
// constexpr int H = 2048;
// constexpr int A = 3072;
// constexpr int I = 8192;

// Llama-3.2-3B:  H = 3072, A = 5120,  I = 8192   (active)
constexpr int H = 3072;
constexpr int A = 5120;
constexpr int I = 8192;

// Llama-3.1-70B: H = 8192, A = 10240, I = 28672
// constexpr int H = 8192;
// constexpr int A = 10240;
// constexpr int I = 28672;

// Passed to int3_col_wise_matmul_fp16 as num_codebook (M) and len_vector (V).
// Note: this M is unrelated to the class template's own M parameter.
constexpr int M = 2;
constexpr int V = 8;


// Builds one benchmark instance on the heap and returns what its run() reports.
// Batch, OutFeatures and InFeatures become the class template's M, N and K.
// GroupSize is forwarded as A_GROUP_SIZE. The file-level constants M and V are
// passed as num_codebook and len_vector.
template <int Batch, int OutFeatures, int InFeatures, int GroupSize>
static double run_quip_case(bool use_cublas, bool use_lutgemm, bool use_gptq) {
    auto bench = std::make_shared<
        int3_col_wise_matmul_fp16<Batch, OutFeatures, InFeatures, M, V, GroupSize>>();
    return bench->run(use_cublas, use_lutgemm, use_gptq);
}

// Runs the QuIP# latency benchmark over the layer shapes derived from H, A, I.
TEST(int3_col_wise_matmul_fp16, layer_175b){
    double total_error = 0;
    int e_cnt = 0;
    // Adds one run's result to the running total and counts the run.
    auto record = [&total_error, &e_cnt](double value) {
        total_error += value;
        ++e_cnt;
    };

    const char* const dash_rule = "----------------------------------------------------------------\n";
    const char* const star_rule = "****************************************************************\n";

    printf("Start!\n");
    printf("M = 1, N = %d, K = %d\n", H, H);

    // Manual switch: true runs the group-size sweep at batch 1; false runs
    // the batch-size sweep at group size 1024.
    const bool group_size_sweep = false;

    if (group_size_sweep) {
        printf("QuIP# group size (BS=1)\n");
        record(run_quip_case<1, H, H, 1024>(true, false, false));
        printf("%s", dash_rule);
        record(run_quip_case<1, H, H, 128>(false, true, false));
        record(run_quip_case<1, H, H, 1024>(false, true, false));
        printf("%s", dash_rule);
        record(run_quip_case<1, A, H, 1024>(true, false, false));
        printf("%s", dash_rule);
        record(run_quip_case<1, A, H, 128>(false, true, false));
        record(run_quip_case<1, A, H, 1024>(false, true, false));
        printf("%s", dash_rule);
        record(run_quip_case<1, H, I, 1024>(true, false, false));
        printf("%s", dash_rule);
        record(run_quip_case<1, H, I, 128>(false, true, false));
        record(run_quip_case<1, H, I, 1024>(false, true, false));
        printf("%s", dash_rule);
        record(run_quip_case<1, I, H, 1024>(true, false, false));
        printf("%s", dash_rule);
        record(run_quip_case<1, I, H, 128>(false, true, false));
        record(run_quip_case<1, I, H, 1024>(false, true, false));
        printf("%s", dash_rule);
        return;
    }

    printf("QuIP batch size (g=1024)\n");

    // N = H, K = H
    record(run_quip_case<1,  H, H, 1024>(true, false, false));
    record(run_quip_case<4,  H, H, 1024>(true, false, false));
    record(run_quip_case<8,  H, H, 1024>(true, false, false));
    record(run_quip_case<16, H, H, 1024>(true, false, false));
    printf("%s", star_rule);

    // N = A, K = H
    record(run_quip_case<1,  A, H, 1024>(true, false, false));
    record(run_quip_case<4,  A, H, 1024>(true, false, false));
    record(run_quip_case<8,  A, H, 1024>(true, false, false));
    record(run_quip_case<16, A, H, 1024>(true, false, false));
    printf("%s", star_rule);

    // N = H, K = I
    record(run_quip_case<1,  H, I, 1024>(false, true, false));
    record(run_quip_case<4,  H, I, 1024>(false, true, false));
    record(run_quip_case<8,  H, I, 1024>(false, true, false));
    record(run_quip_case<16, H, I, 1024>(false, true, false));
    printf("%s", star_rule);

    // N = I, K = H
    record(run_quip_case<1,  I, H, 1024>(false, true, false));
    record(run_quip_case<4,  I, H, 1024>(false, true, false));
    record(run_quip_case<8,  I, H, 1024>(false, true, false));
    record(run_quip_case<16, I, H, 1024>(false, true, false));
    printf("%s", star_rule);
}




/*
 * Row-major GEMM through cublasGemmEx: C[m][n] = A[m][k] * B[k][n].
 *
 * cuBLAS is column-major. A row-major X[r][c] buffer is the column-major
 * matrix X^T[c][r] with leading dimension c, so the product is issued as
 * C^T = B^T * A^T (B first, then A) without explicit transposes.
 *
 * Supported element types:
 *   T = float  -> FP32 storage, CUBLAS_COMPUTE_32F_FAST_TF32
 *   T = __half -> FP16 storage, CUBLAS_COMPUTE_16F
 *   T = int8_t -> INT8 inputs, INT32 output, CUBLAS_COMPUTE_32I
 * alpha/beta have type S. cuBLAS reads them in the compute type's scale
 * type, so S should be float, __half or int32_t respectively.
 *
 * Returns CUBLAS_STATUS_NOT_SUPPORTED for any other T (without touching
 * cuBLAS), the cublasCreate status if the handle cannot be created, and
 * otherwise the cublasGemmEx status. The handle is created lazily, once per
 * template instantiation, and is never destroyed.
 */
template <typename T, typename S>
inline cublasStatus_t cublas_gemm_ex(T *A,  T *B,  S *C,
                                    int m, int n, int k) {
    static S alpha = 1;
    static S beta  = 0;
    static cublasHandle_t handle = nullptr;

    cudaDataType_t AType, BType, CType;
    cublasComputeType_t  ComputeType;
    if (std::is_same<T, float>::value) {
        AType = BType = CType = CUDA_R_32F;
        ComputeType = CUBLAS_COMPUTE_32F_FAST_TF32;
    } else if (std::is_same<T, __half>::value) {
        AType = BType = CType = CUDA_R_16F;
        ComputeType = CUBLAS_COMPUTE_16F;
    } else if (std::is_same<T, int8_t>::value) {
        AType = BType = CUDA_R_8I;
        CType = CUDA_R_32I;
        ComputeType = CUBLAS_COMPUTE_32I;
    } else {
        return CUBLAS_STATUS_NOT_SUPPORTED;
    }

    if (handle == nullptr) {
        const cublasStatus_t create_status = cublasCreate(&handle);
        if (create_status != CUBLAS_STATUS_SUCCESS) {
            handle = nullptr;  // leave unset so a later call can retry
            return create_status;
        }
    }

    // Column-major C^T (n x m, ld n) = B^T (n x k, ld n) * A^T (k x m, ld k).
    return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                        n, m, k,
                        &alpha,
                        B, BType, n,
                        A, AType, k,
                        &beta,
                        C, CType, n,
                        ComputeType,
                        CUBLAS_GEMM_DEFAULT);
}
