/* LUT-GEMM
 * Copyright (c) 2024-present NAVER Cloud Corp. All rights reserved.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tests.h"

#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>

#include <cstdio>
#include <cstdlib>

/*
 * Builds a 64-bit integer out of two consecutive rand() draws.
 *
 * The first draw is placed in bits 32 and above, the second draw is OR-ed
 * into the low bits. Because rand() never returns a negative value, the
 * result is never negative either.
 */
int64_t rand_int64() {
    // Draw order matters for reproducibility under a fixed srand() seed:
    // upper half first, lower half second.
    const int64_t upper = static_cast<int64_t>(rand());
    const int64_t lower = static_cast<int64_t>(rand());
    const int64_t shifted_upper = upper << 32;
    return shifted_upper | lower;
}


// Forward declaration of the cublasGemmEx wrapper; the definition appears
// later in this file. T is the element type of the A and B operands, S is the
// element type of the C operand; m, n and k are the GEMM dimensions.
template <typename T, typename S>
inline cublasStatus_t cublas_gemm_ex(T *A,  T *B,  S *C,
                                    int m, int n, int k);
               
/*
 * Latency benchmark fixture around codeGEMM::matmul_qtip.
 *
 * Template parameters size every buffer owned by the fixture:
 *   M, K - forwarded to matmul_qtip as compile-time arguments; the packed
 *          weight buffer holds 2*M*K/32 32-bit words (2*M*K bits in total).
 *   N    - forwarded to matmul_qtip as its runtime `n` argument; the input
 *          staging buffer holds K*N values and the output buffer M*N values.
 *
 * The host staging arrays are plain data members, so one instance can occupy
 * many megabytes. Create instances on the heap rather than on the stack.
 *
 * Any CUDA runtime call that reports an error prints a diagnostic to stderr
 * and aborts the process, so a reported latency always belongs to a run that
 * completed without a CUDA error.
 */
template<int M, int N, int K>
class int3_col_wise_matmul_fp16{
public:
    // Number of codebook entries that are generated and uploaded.
    static constexpr int kCodebookSize = 1 << (9 + 1);
    // Number of 32-bit words of packed weight data.
    static constexpr int kCompressedWords = 2 * M * K / 32;

    // Host-side staging buffers.
    int32_t    compressed[kCompressedWords];
    float    codebook[kCodebookSize];

    float    input[K][N];
    float   output[M][N];   // not read or written by any method of this class


    // Managed-memory buffers handed to matmul_qtip. They are valid only
    // between alloc_cuda() and free_cuda().
    __half*  d_input = nullptr;
    int32_t*  d_compressed = nullptr;
    __half*  d_codebook = nullptr;

    float* d_nq_output = nullptr;

    /*
     * Allocates the buffers, fills them with random data, runs the latency
     * measurement and releases the buffers again.
     *
     * run_cublas, run_lutgemm and run_gptq are accepted so existing call
     * sites keep compiling, but they are not consulted: only the QTIP path
     * is measured. `iter` is the number of timed repetitions.
     *
     * Always returns 0; no numerical error is computed.
     */
    double run(
            bool run_cublas=true, bool run_lutgemm=false, bool run_gptq=false,
            int iter=16){
        (void)run_cublas;
        (void)run_lutgemm;
        (void)run_gptq;

        alloc_cuda();
        makeRandomInput();
        makeRandomWeight();
        makeRandomCodebook();
        copy_cpuToCuda();

        check_cuda(cudaDeviceSynchronize(), "cudaDeviceSynchronize before benchmark");

        qtip_latency(M, N, K,
            d_input, d_compressed, d_codebook, d_nq_output, iter);

        free_cuda();
        return 0;
    }

    /*
     * Times codeGEMM::matmul_qtip<M, K>(D, A, B, C, n).
     *
     * One untimed warm-up call is made first. Then `iter` calls are timed,
     * each including the device synchronisation that follows it, and the
     * min / max / mean reported by the timer are printed.
     *
     * m and k are unused: the shape passed to matmul_qtip comes from the
     * class template arguments M and K.
     */
    void qtip_latency(int m, int n, int k, 
        __half* A, int32_t *B, __half *C, float *D, int iter=64){
        (void)m;
        (void)k;

        timer tm;
        // Warm-up call, excluded from the statistics.
        codeGEMM::matmul_qtip<M, K>((void*)D, (void*)A, (void*)B, (void*)C, n);
        check_cuda(cudaDeviceSynchronize(), "matmul_qtip warm-up");
        check_cuda(cudaGetLastError(), "matmul_qtip warm-up launch");

        for(int i=0;i<iter;i++){
            tm.start();
            codeGEMM::matmul_qtip<M, K>((void*)D, (void*)A, (void*)B, (void*)C, n);
            const cudaError_t sync_status = cudaDeviceSynchronize();
            tm.end();
            // Error inspection happens after tm.end() so it does not add to
            // the measured interval.
            check_cuda(sync_status, "matmul_qtip timed run");
            check_cuda(cudaGetLastError(), "matmul_qtip timed launch");
        }
        printf("latency min : %.5fms, max : %.5fms, avg:%.5f\n", tm.min(), tm.max(), tm.mean());
    }

    /*
     * Fills the host input buffer with values from rand_fp32().
     * The n-outer / k-inner visiting order is kept as is so that a fixed seed
     * keeps assigning the same draw to the same element.
     */
    void makeRandomInput(){
        for(int n=0;n<N;n++)
            for(int k=0;k<K;k++)
                input[k][n] = rand_fp32();
    }

    /* Fills the host codebook with rand_fp32()/16 clamped to [-1, 1]. */
    void makeRandomCodebook(){
        for(int i=0;i<kCodebookSize;i++)
            codebook[i] = std::clamp(rand_fp32()/16, -1.0f, 1.0f);
    }

    /*
     * Fills the packed weight buffer with random 32-bit patterns.
     *
     * Each word takes three rand() draws: the low 16 bits of the first two
     * form the upper and lower halves, and the parity of the third decides
     * whether the word is negated. The work is done on uint32_t so that
     * shifting into bit 31 and negating 0x80000000 are well defined; the
     * resulting bit patterns match what two's-complement signed arithmetic
     * would give.
     */
    void makeRandomWeight() {
        for (int i = 0; i < kCompressedWords; ++i) {
            const uint32_t high = static_cast<uint32_t>(rand()) & 0xFFFFu;
            const uint32_t low  = static_cast<uint32_t>(rand()) & 0xFFFFu;
            uint32_t bits = (high << 16) | low;

            if (rand() % 2) bits = 0u - bits;   // negation modulo 2^32
            compressed[i] = static_cast<int32_t>(bits);
        }
    }

    /*
     * Allocates the four managed-memory buffers.
     *
     * NOTE(review): d_input and d_codebook hold __half elements but are sized
     * with sizeof(float), i.e. twice the bytes the element counts require.
     * The byte counts are kept unchanged because matmul_qtip is not visible
     * here and may depend on the extra space - confirm before shrinking.
     */
    void alloc_cuda(){
        check_cuda(cudaMallocManaged(&d_input    , sizeof(float) * N * K),
                   "cudaMallocManaged(d_input)");
        check_cuda(cudaMallocManaged(&d_compressed, sizeof(int32_t) * kCompressedWords),
                   "cudaMallocManaged(d_compressed)");
        check_cuda(cudaMallocManaged(&d_codebook, sizeof(float) * kCodebookSize),
                   "cudaMallocManaged(d_codebook)");
        check_cuda(cudaMallocManaged(&d_nq_output, sizeof(float) * M * N),
                   "cudaMallocManaged(d_nq_output)");
    }
    
    /*
     * Releases the managed buffers and resets the pointers, so that calling
     * this twice is harmless (cudaFree accepts a null pointer).
     */
    void free_cuda(){
        check_cuda(cudaFree(d_input), "cudaFree(d_input)");
        d_input = nullptr;
        check_cuda(cudaFree(d_compressed), "cudaFree(d_compressed)");
        d_compressed = nullptr;
        check_cuda(cudaFree(d_codebook), "cudaFree(d_codebook)");
        d_codebook = nullptr;
        check_cuda(cudaFree(d_nq_output), "cudaFree(d_nq_output)");
        d_nq_output = nullptr;
    }

    /*
     * Copies the host staging buffers into the managed buffers from the host
     * side, converting input and codebook values from float to __half, then
     * synchronises the device.
     */
    void copy_cpuToCuda(){
        fhCpy(d_input , (float*)input , N * K);
        aCpy(d_compressed, (int32_t*)compressed, kCompressedWords);
        fhCpy(d_codebook, (float*)codebook, kCodebookSize);

        check_cuda(cudaDeviceSynchronize(), "cudaDeviceSynchronize after upload");
    }

    /* Element-wise copy of `size` floats from b into a, converted to __half. */
    void fhCpy(__half* a, float* b, int size){
        for(int i=0;i<size;i++) a[i] = __float2half(b[i]);
     }

    /* Element-wise copy of `size` floats from b into a. */
    void ffCpy(float* a, float* b, int size){
        for(int i=0;i<size;i++) a[i] = (b[i]);
     }
     /* Element-wise copy of `size` 32-bit integers from b into a. */
     void aCpy(int32_t* a, int32_t* b, int size){
        for(int i=0;i<size;i++) a[i] = (b[i]);
     }
     /* Element-wise copy of `size` floats from b into a (same as ffCpy). */
     void bCpy(float* a, float* b, int size){
        for(int i=0;i<size;i++) a[i] = (b[i]);
     }

private:
    /*
     * Prints a diagnostic naming the failed step and aborts when `status` is
     * not cudaSuccess; otherwise does nothing.
     */
    static void check_cuda(cudaError_t status, const char* what){
        if(status == cudaSuccess) return;
        fprintf(stderr, "CUDA error during %s: %s\n", what, cudaGetErrorString(status));
        std::abort();
    }

};
// Active size preset: Llama-3.1-8B.
// H, A and I are used below as the M / K template arguments of the benchmark.
// NOTE(review): presumably hidden size, attention projection width and MLP
// intermediate size respectively - confirm against the model config.
constexpr int H = 4096;
constexpr int A = 6144;
constexpr int I = 14336;

// Llama-3.2-1B
// const int H = 2048;
// const int A = 3072;
// const int I = 8192;

// Llama-3.2-3B
// const int H = 3072;
// const int A = 5120;
// const int I = 8192;

// Llama-3.1-70B
// const int H = 8192;
// const int A = 10240;
// const int I = 28672;


/*
 * Builds one benchmark fixture for the given <M, N, K> shape on the heap, runs
 * it with the given backend flags and returns whatever run() returns. The
 * fixture is destroyed before this function returns, so only one fixture is
 * alive at a time.
 */
template <int M, int N, int K>
static double qtip_run_case(bool run_cublas, bool run_lutgemm, bool run_gptq) {
    auto fixture = std::make_shared<int3_col_wise_matmul_fp16<M, N, K>>();
    return fixture->run(run_cublas, run_lutgemm, run_gptq);
}

/*
 * Runs the <M, K> shape at N = 1, 4, 8 and 16 in that order, adds each
 * result to `error_sum`, bumps `case_count` once per case and finally prints
 * the separator line.
 */
template <int M, int K>
static void qtip_batch_sweep(bool run_cublas, bool run_lutgemm,
                             double &error_sum, int &case_count) {
    error_sum += qtip_run_case<M, 1,  K>(run_cublas, run_lutgemm, false); case_count++;
    error_sum += qtip_run_case<M, 4,  K>(run_cublas, run_lutgemm, false); case_count++;
    error_sum += qtip_run_case<M, 8,  K>(run_cublas, run_lutgemm, false); case_count++;
    error_sum += qtip_run_case<M, 16, K>(run_cublas, run_lutgemm, false); case_count++;
    printf("****************************************************************\n");
}

/*
 * Prints QTIP latencies for four weight shapes - (H, H), (A, H), (I, H) and
 * (H, I) as (M, K) - each at N = 1, 4, 8 and 16.
 *
 * The accumulated sum and case count are kept for parity with the original
 * test body but are never reported, and the test makes no assertions.
 * The backend flags are passed through as before; run() as written in this
 * file does not consult them.
 */
TEST(int3_col_wise_matmul_fp16, layer_175b){
    double error_sum = 0;
    int case_count = 0;

    printf("Start!\n");
    printf("M = 1, N = %d, K = %d\n", H, H);

    printf("QTIP batch size (g=1024)\n");
    qtip_batch_sweep<H*1, H>(true,  false, error_sum, case_count);
    qtip_batch_sweep<A,   H>(true,  false, error_sum, case_count);
    qtip_batch_sweep<I,   H>(false, true,  error_sum, case_count);
    qtip_batch_sweep<H,   I>(false, true,  error_sum, case_count);

}




/*
 * Thin wrapper around cublasGemmEx for row-major operands.
 *
 * Computes C = A * B with A of size m x k, B of size k x n and C of size
 * m x n, all stored row-major without padding. cuBLAS is column-major, so the
 * operands are passed swapped (B first, then A) with dimensions (n, m, k);
 * this produces the same memory layout as the row-major product.
 *
 * Supported type pairs (T = element type of A and B, S = element type of C
 * and of the alpha/beta scalars):
 *   T = float,  S = float    -> CUBLAS_COMPUTE_32F_FAST_TF32
 *   T = __half, S = __half   -> CUBLAS_COMPUTE_16F
 *   T = int8_t, S = int32_t  -> CUBLAS_COMPUTE_32I
 * Any other pair returns CUBLAS_STATUS_NOT_SUPPORTED before cuBLAS is called.
 * The data-type tags given to cuBLAS are derived from T, so a different S
 * would make cuBLAS interpret C, alpha and beta with the wrong element type.
 *
 * A cuBLAS handle is created lazily on first use and kept for the lifetime of
 * the process (one per template instantiation). It is never destroyed, and
 * the lazy creation is not synchronised between threads.
 *
 * Returns the status of handle creation if that fails, otherwise the status
 * returned by cublasGemmEx.
 */
template <typename T, typename S>
inline cublasStatus_t cublas_gemm_ex(T *A,  T *B,  S *C,
                                    int m, int n, int k) {
    cudaDataType_t AType, BType, CType;
    cublasComputeType_t  ComputeType;
    bool output_type_matches = false;

    if (std::is_same<T, float>::value) {
        AType = BType = CType = CUDA_R_32F;
        ComputeType = CUBLAS_COMPUTE_32F_FAST_TF32;
        output_type_matches = std::is_same<S, float>::value;
    } else if (std::is_same<T, __half>::value) {
        AType = BType = CType = CUDA_R_16F;
        ComputeType = CUBLAS_COMPUTE_16F;
        output_type_matches = std::is_same<S, __half>::value;
    } else if (std::is_same<T, int8_t>::value) {
        AType = BType = CUDA_R_8I;
        CType = CUDA_R_32I;
        ComputeType = CUBLAS_COMPUTE_32I;
        output_type_matches = std::is_same<S, int32_t>::value;
    } else {
        return CUBLAS_STATUS_NOT_SUPPORTED;
    }
    // Reject pairs where S disagrees with the CType / scalar type implied by T.
    if (!output_type_matches) {
        return CUBLAS_STATUS_NOT_SUPPORTED;
    }

    static S alpha = 1;
    static S beta  = 0;
    static cublasHandle_t handle = nullptr;
    if (handle == nullptr) {
        const cublasStatus_t created = cublasCreate(&handle);
        if (created != CUBLAS_STATUS_SUCCESS) {
            handle = nullptr;   // allow a later call to retry creation
            return created;
        }
    }

    return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                          n, m, k,
                          &alpha,
                          B, BType, n,
                          A, AType, k,
                          &beta,
                          C, CType, n,
                          ComputeType,
                          CUBLAS_GEMM_DEFAULT);
}
