/* LUT-GEMM
 * Copyright (c) 2024-present NAVER Cloud Corp. All rights reserved.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tests.h"

#include <cstdio>
#include <cstdlib>

#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>



// Forward declaration of the cuBLAS GEMM wrapper so the test fixture below can
// call it; the definition lives at the end of this file.
// T is the element type of A and B, S the element type of C.
template <typename T, typename S>
inline cublasStatus_t cublas_gemm_ex(T *A,  T *B,  S *C,
                                    int m, int n, int k);
                                    
/*
 * Test/benchmark fixture that compares codeGEMM::matmul on codebook-quantized
 * weights against a cuBLAS fp16 GEMM on the dequantized weights.
 *
 * Template parameters:
 *   M                   rows of the input / output
 *   N                   columns of the weight / output
 *   K                   reduction dimension (input columns, weight rows)
 *   num_codebook        number of codebooks whose entries are summed per weight
 *   len_vector          consecutive K-elements covered by one code
 *   nbits_per_codebook  forwarded to nQWeight_fp16::parsing; the host-side
 *                       generators in this class always produce 8-bit codes
 *   A_GROUP_SIZE        number of K-rows that share one alpha scale per column
 *
 * Instances are large (weight alone holds K*N floats); the test in this file
 * creates them with std::make_shared.
 */
template<int M, int N, int K, int num_codebook,
int len_vector, int nbits_per_codebook, int A_GROUP_SIZE=K>
class int3_col_wise_matmul_fp16{
public:
    // The index arithmetic in makeRandomWeight() and dequantizeFrom_qW() runs
    // past the arrays below unless these divisibility conditions hold.
    static_assert(K % len_vector == 0,
                  "K must be a multiple of len_vector");
    static_assert((K / len_vector) % 4 == 0,
                  "K/len_vector must be a multiple of 4 (four 8-bit codes per packed word)");
    static_assert(K % A_GROUP_SIZE == 0,
                  "K must be a multiple of A_GROUP_SIZE");

    static const int num_groups = K/A_GROUP_SIZE;
    uint8_t     qW[num_codebook][K/len_vector][N]; // 8-bit code (0..255) per codebook / vector / column
    uint32_t  bW[num_codebook][K/len_vector/4][N]; // the same codes, four packed per 32-bit word

    float    codebook[num_codebook][256][len_vector]; // 256 entries of len_vector floats per codebook
    float     alpha[num_groups][N];                    // scale per (group, column)

    float   weight[K][N];           // dequantized float weight
    float    input[M][K];
    float   output[M][N];           // CPU reference written by matmul_cpu()


    __half* d_weight_fp16;          // fp16 copy of weight
    __half*  d_input;               // fp16 copy of input

    __half* d_cu_output;            // cuBLAS result
    __half* d_nq_output;            // codeGEMM result

    codeGEMM::nQWeight_fp16 nqW;

    /*
     * Prints a diagnostic and aborts when a CUDA runtime call failed.
     * `what` names the failing call in the message.
     */
    static void checkCuda(cudaError_t status, const char* what){
        if(status != cudaSuccess){
            fprintf(stderr, "CUDA error in %s: %s\n", what, cudaGetErrorString(status));
            std::abort();
        }
    }

    /*
     * Generates random data, measures the mean absolute difference between the
     * cuBLAS and codeGEMM results, optionally benchmarks either path, and
     * releases the device buffers.
     *
     *   run_cublas   benchmark the cuBLAS GEMM
     *   run_lutgemm  benchmark codeGEMM::matmul
     *   run_gptq     accepted for interface compatibility; not used
     *   iter         number of timed iterations per benchmark
     *
     * Returns the mean absolute error computed by checkErr().
     */
    double run(
            bool run_cublas=true, bool run_lutgemm=false, bool run_gptq=false,
            int iter=16){
        (void)run_gptq;

        // The generators consume random numbers in this exact order.
        alloc_cuda();
        makeRandomInput();
        makeRandomWeight();
        makeRandomAlpha();
        makeRandomCodebook();
        dequantizeFrom_qW();
        copy_cpuToCuda();

        nqW.parsing((uint32_t*)bW, (float*)alpha, K, N, false, num_groups,
            num_codebook, nbits_per_codebook, len_vector, (float*)codebook);
        cudaDeviceSynchronize();

        double meanError = checkErr();
        printf("Error: %.6f\t", meanError);
        cudaDeviceSynchronize();

        if(run_cublas) cublas_latency(M, N, K,
                                      d_input, d_weight_fp16, d_cu_output, iter);
        // The quantized path gets its own output buffer so the cuBLAS
        // reference result is left untouched.
        if(run_lutgemm) lutgemm_latency(nqW, M, N, K,
                                        d_input, d_weight_fp16, d_nq_output, iter);

        free_cuda();
        return meanError;
    }

    /*
     * Times `iter` synchronous codeGEMM::matmul calls (after one warm-up call)
     * and prints min / max / mean latency. n, k and B are not used.
     */
    void lutgemm_latency(
            codeGEMM::nQWeight_fp16 &nqW, 
            int m, int n, int k, __half* A, __half *B, __half *C, int iter=64){
        
        timer tm;
        // Warm-up call, excluded from the statistics.
        codeGEMM::matmul((void*)C, (void*)A, nqW, m);
        cudaDeviceSynchronize();

        for(int i=0;i<iter;i++){
            tm.start();
            codeGEMM::matmul((void*)C, (void*)A, nqW, m);
            cudaDeviceSynchronize();
            tm.end();
        }
        printf("latency min : %.5fms, max : %.5fms, avg:%.5f\n", tm.min(), tm.max(), tm.mean());
    }

    /*
     * Times `iter` synchronous cuBLAS GEMM calls (after one warm-up call) and
     * prints min / max / mean latency. A failing cuBLAS status is reported on
     * stderr because the timings are meaningless in that case.
     */
    void cublas_latency(
            int m, int n, int k, __half* A, __half *B, __half *C, int iter=64){
        
        timer tm;
        // Warm-up call, excluded from the statistics.
        cublasStatus_t status = cublas_gemm_ex(A, B, C,
                                               m, n, k);
        cudaDeviceSynchronize();
        for (int i = 0; i < iter; ++i) {
            tm.start();
            const cublasStatus_t iterStatus = cublas_gemm_ex(A, B, C,
                                                             m, n, k);
            cudaDeviceSynchronize();
            tm.end();
            if(iterStatus != CUBLAS_STATUS_SUCCESS) status = iterStatus;
        }
        if(status != CUBLAS_STATUS_SUCCESS)
            fprintf(stderr, "cuBLAS GEMM failed (status %d); the latency below is not meaningful\n",
                    (int)status);
        printf("latency min : %.5fms, max : %.5fms, avg:%.5f\n", tm.min(), tm.max(), tm.mean());

    }


    /*
     * Runs the cuBLAS GEMM and codeGEMM::matmul on the same input and returns
     * the mean absolute difference of their outputs. The CPU reference in
     * `output` is refreshed as well but does not enter the returned value.
     */
    double checkErr(){
        const cublasStatus_t status =
            cublas_gemm_ex(d_input, d_weight_fp16, d_cu_output, M, N, K);
        if(status != CUBLAS_STATUS_SUCCESS)
            fprintf(stderr, "cuBLAS reference GEMM failed (status %d)\n", (int)status);
        matmul_cpu();
        checkCuda(cudaMemset(d_nq_output, 0, sizeof(float) * M * N),
                  "cudaMemset(d_nq_output)");
        codeGEMM::matmul(d_nq_output, d_input, nqW, M);
        // Both results are read on the host next, so all GPU work must have
        // finished successfully.
        checkCuda(cudaDeviceSynchronize(), "cudaDeviceSynchronize after codeGEMM::matmul");
        return checkOutputMeanError(d_cu_output, d_nq_output);
    }

    /*
     * Mean absolute element-wise difference of two M x N row-major fp16
     * buffers. Both pointers are dereferenced on the host.
     */
    double checkOutputMeanError(__half *o1, __half *o2){
        double err=0;
        for(int m=0;m<M;m++){
            for(int n=0;n<N;n++){
                const int idx = m*N + n;
                // To compare against the CPU reference, use output[m][n] in
                // place of o1[idx].
                err += std::abs(float(o1[idx]) - float(o2[idx]));
            }
        }
        return err/M/N;
    }

    /*
     * CPU reference: output = input * weight in fp32.
     * The loops are ordered m, k, n so the inner loop walks weight and output
     * contiguously; each output element still accumulates in ascending k.
     */
    void matmul_cpu(){
        for(int m=0;m<M;m++){
            for(int n=0;n<N;n++) output[m][n] = 0;
            for(int k=0;k<K;k++){
                const float a = input[m][k];
                for(int n=0;n<N;n++){
                    output[m][n] += a * weight[k][n];
                }
            }
        }
    }

    // Fills input with values from rand_fp32() (defined outside this file).
    void makeRandomInput(){
        for(int m=0;m<M;m++)
            for(int k=0;k<K;k++)
                input[m][k] = rand_fp32();
    }

    // Fills the per-group, per-column scales with values from rand_fp32().
    void makeRandomAlpha(){
        for(int g=0;g<num_groups;g++)
            for(int n=0;n<N;n++){
                alpha[g][n] = rand_fp32();
            }
    }

    // Fills every codebook entry with values from rand_fp32().
    void makeRandomCodebook(){
        for(int k=0;k<num_codebook;k++)
            for(int i=0;i<256;i++)
                for(int j=0;j<len_vector;j++){
                    codebook[k][i][j] = rand_fp32();
                }
    }

    /*
     * Draws random 8-bit codes and stores them twice: one per byte in qW and
     * four per 32-bit word in bW. `k` counts code bits along K (8 per code),
     * so one step of 32 bits covers four codes; code (k+i)/8 occupies bits
     * [i, i+8) of word k/32.
     */
    void makeRandomWeight(){
        for(int n=0; n<N; n++){
            for(int k=0; k<K*8/len_vector; k+=32){ 
                for(int j=0; j<num_codebook; j++){
                    uint32_t packed = 0;
                    for(int i=0;i<32;i+=8){
                        uint8_t code = 0;
                        for(int t=0;t<8;t++){
                            if(rand_bool()){
                                // Unsigned shift: t+i reaches 31, which would
                                // overflow a signed int.
                                packed |= uint32_t(1) << (t+i);
                                code   |= uint8_t(1u << t);
                            }
                        }
                        qW[j][(k+i)/8][n] = code;
                    }
                    bW[j][k/32][n] = packed;
                }
            }
        }
    }


    /*
     * Rebuilds the float weight from the codes:
     *   weight[len_vector*k + j][n] =
     *       alpha[group][n] * sum_i codebook[i][qW[i][k][n]][j]
     * with group = (len_vector*k + j) / A_GROUP_SIZE.
     */
    void dequantizeFrom_qW(){
        for(int n=0; n<N; n++){
            for(int k=0; k<K/len_vector; k++){ 
                for(int j=0; j<len_vector; j++){
                    const int row = len_vector*k+j;
                    weight[row][n] = 0;
                    for(int i=0; i<num_codebook; i++){
                        uint8_t code = qW[i][k][n];
                        weight[row][n] += (codebook[i][code][j]);
                    }
                    weight[row][n] *= alpha[row/A_GROUP_SIZE][n];
                }
            }
        }
    }

    /*
     * Allocates the managed buffers and aborts if any allocation fails, since
     * the host code writes through these pointers right afterwards.
     * NOTE(review): the buffers hold __half elements but are sized with
     * sizeof(float), i.e. twice the space the host code here touches. Kept
     * as-is because the requirements of codeGEMM::matmul are not visible
     * here — confirm before shrinking.
     */
    void alloc_cuda(){
        checkCuda(cudaMallocManaged(&d_input      , sizeof(float) * M * K), "cudaMallocManaged(d_input)");
        checkCuda(cudaMallocManaged(&d_weight_fp16, sizeof(float) * K * N), "cudaMallocManaged(d_weight_fp16)");

        checkCuda(cudaMallocManaged(&d_cu_output  , sizeof(float) * M * N), "cudaMallocManaged(d_cu_output)");
        checkCuda(cudaMallocManaged(&d_nq_output  , sizeof(float) * M * N), "cudaMallocManaged(d_nq_output)");
    }
    
    // Releases the buffers obtained in alloc_cuda().
    void free_cuda(){
        cudaFree(d_input);
        cudaFree(d_weight_fp16);
        cudaFree(d_cu_output);
        cudaFree(d_nq_output);
    }

    // Converts the host fp32 input and weight into the managed fp16 buffers.
    void copy_cpuToCuda(){
        fhCpy(d_input , (float*)input, M * K);
        fhCpy(d_weight_fp16, (float*)weight, K * N);

        cudaDeviceSynchronize();
    }

    // Element-wise fp16 -> fp32 conversion performed on the host.
    void hfCpy(float* a, __half* b, int size){
       for(int i=0;i<size;i++) a[i] = float(b[i]);
    }
    // Element-wise fp32 -> fp16 conversion performed on the host.
    void fhCpy(__half* a, float* b, int size){
       for(int i=0;i<size;i++) a[i] = __float2half(b[i]);
    }

};
// Layer dimensions used as the N / K template arguments of the test cases below.
// Active preset: Llama-3.1-8B.
// NOTE(review): H, A and I look like the hidden size, the fused attention
// projection width and the MLP intermediate size — confirm against the model config.
const int H = 4096;
const int A = 6144;
const int I = 14336;

// Alternative preset: Llama-3.2-1B
// const int H = 2048;
// const int A = 3072;
// const int I = 8192;

// Alternative preset: Llama-3.2-3B
// const int H = 3072;
// const int A = 5120;
// const int I = 8192;

// Alternative preset: Llama-3.1-70B
// const int H = 8192;
// const int A = 10240;
// const int I = 28672;

// Quantization settings for the test cases below: M is passed as the
// num_codebook template argument and V as len_vector. M here is not the batch
// size, although the fixture's first template parameter is also called M.
const int M = 2;
const int V = 8;


// Builds one fixture on the heap (it embeds large host arrays), runs it and
// returns the mean error reported by run(). M and V are the file-level
// codebook count and vector length; codes are 8 bits wide.
template <int BS, int N_DIM, int K_DIM, int GROUP>
static double run_codegemm_case(bool run_cublas, bool run_lutgemm){
    auto t = std::make_shared<int3_col_wise_matmul_fp16<BS, N_DIM, K_DIM, M, V, 8, GROUP>>();
    return t->run(run_cublas, run_lutgemm, false);
}

// Batch size 1: one cuBLAS baseline, then codeGEMM with group sizes 32, 128
// and 1024 for an N_DIM x K_DIM layer.
template <int N_DIM, int K_DIM>
static void sweep_group_sizes(double &total_error, int &e_cnt){
    total_error += run_codegemm_case<1, N_DIM, K_DIM, 1024>(true, false);  e_cnt++;
    printf("----------------------------------------------------------------\n");
    total_error += run_codegemm_case<1, N_DIM, K_DIM, 32>(false, true);    e_cnt++;
    total_error += run_codegemm_case<1, N_DIM, K_DIM, 128>(false, true);   e_cnt++;
    total_error += run_codegemm_case<1, N_DIM, K_DIM, 1024>(false, true);  e_cnt++;
    printf("----------------------------------------------------------------\n");
}

// Group size 1024: cuBLAS and then codeGEMM for batch sizes 1, 4, 8 and 16 on
// an N_DIM x K_DIM layer.
template <int N_DIM, int K_DIM>
static void sweep_batch_sizes(double &total_error, int &e_cnt){
    total_error += run_codegemm_case<1,  N_DIM, K_DIM, 1024>(true, false); e_cnt++;
    total_error += run_codegemm_case<4,  N_DIM, K_DIM, 1024>(true, false); e_cnt++;
    total_error += run_codegemm_case<8,  N_DIM, K_DIM, 1024>(true, false); e_cnt++;
    total_error += run_codegemm_case<16, N_DIM, K_DIM, 1024>(true, false); e_cnt++;
    printf("----------------------------------------------------------------\n");
    total_error += run_codegemm_case<1,  N_DIM, K_DIM, 1024>(false, true); e_cnt++;
    total_error += run_codegemm_case<4,  N_DIM, K_DIM, 1024>(false, true); e_cnt++;
    total_error += run_codegemm_case<8,  N_DIM, K_DIM, 1024>(false, true); e_cnt++;
    total_error += run_codegemm_case<16, N_DIM, K_DIM, 1024>(false, true); e_cnt++;
    printf("****************************************************************\n");
}

// Benchmarks cuBLAS and codeGEMM over four layer shapes (N x K): H x H, A x H,
// H x I and I x H. The `if(true)` toggle selects the group-size sweep; switch
// it to `if(false)` for the batch-size sweep. The average of the per-case mean
// errors is printed at the end.
TEST(int3_col_wise_matmul_fp16, layer_175b){
    double total_error = 0;
    int e_cnt = 0;
    printf("Start!\n");
    printf("M = 1, N = %d, K = %d\n", H, H);
    if(true){
    // if(false){
        printf("CodeGEMM group size (BS=1)\n");
        sweep_group_sizes<H, H>(total_error, e_cnt);
        sweep_group_sizes<A, H>(total_error, e_cnt);
        sweep_group_sizes<H, I>(total_error, e_cnt);
        sweep_group_sizes<I, H>(total_error, e_cnt);
    } else{
        printf("CodeGEMM batch size (g=1024)\n");
        sweep_batch_sizes<H, H>(total_error, e_cnt);
        sweep_batch_sizes<A, H>(total_error, e_cnt);
        sweep_batch_sizes<H, I>(total_error, e_cnt);
        sweep_batch_sizes<I, H>(total_error, e_cnt);
    }

    // Report the aggregate so the accumulated error is visible in the log.
    if(e_cnt > 0)
        printf("Mean error over %d runs: %.6f\n", e_cnt, total_error / e_cnt);
}




/*
 * Row-major GEMM through cublasGemmEx: C[m x n] = A[m x k] * B[k x n].
 *
 * cuBLAS is column-major, so the operands are swapped: the call computes the
 * n x m column-major product B * A, which is the same memory as the m x n
 * row-major C.
 *
 *   T  element type of A and B: float, __half or int8_t
 *   S  element type of C; it is also used for the alpha / beta scalars, so it
 *      has to be the scalar type cuBLAS expects for the selected compute type
 *
 * Each <T, S> instantiation lazily creates its own cuBLAS handle on first use
 * and keeps it for the lifetime of the process.
 *
 * Returns the status of cublasCreate if handle creation fails (a later call
 * retries), CUBLAS_STATUS_NOT_SUPPORTED for any other T, and otherwise the
 * status of cublasGemmEx.
 */
template <typename T, typename S>
inline cublasStatus_t cublas_gemm_ex(T *A,  T *B,  S *C,
                                    int m, int n, int k) {
    static S alpha = 1;
    static S beta  = 0;
    static cublasHandle_t handle = nullptr;
    if(handle == nullptr){
        const cublasStatus_t created = cublasCreate(&handle);
        if(created != CUBLAS_STATUS_SUCCESS){
            // Leave the handle unset so the next call attempts creation again.
            handle = nullptr;
            return created;
        }
    }
    
    cudaDataType_t AType, BType, CType;
    cublasComputeType_t  ComputeType;
    if (std::is_same<T, float>::value) {
        AType = BType = CType = CUDA_R_32F;
        ComputeType = CUBLAS_COMPUTE_32F_FAST_TF32;
    } else if (std::is_same<T, __half>::value) {
        AType = BType = CType = CUDA_R_16F;
        ComputeType = CUBLAS_COMPUTE_16F;
    } else if (std::is_same<T, int8_t>::value) {
        AType = BType = CUDA_R_8I;
        CType = CUDA_R_32I;
        ComputeType = CUBLAS_COMPUTE_32I;
    } else {
        return CUBLAS_STATUS_NOT_SUPPORTED;
    }
    // Operands swapped and dimensions passed as (n, m, k): see the header
    // comment for the row-major / column-major mapping.
    return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                          n, m, k, 
                          &alpha,
                          B, BType, n,
                          A, AType, k,
                          &beta,
                          C, CType, n,
                          ComputeType,
                          CUBLAS_GEMM_DEFAULT);
}
