/* LUT-GEMM
 * Copyright (c) 2024-present NAVER Cloud Corp. All rights reserved.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tests.h"

#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <type_traits>

#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>



// Forward declaration of the row-major cuBLAS GEMM helper defined at the end
// of this file: C[m x n] = A[m x k] * B[k x n]. Returns the cuBLAS status.
template <typename T, typename S>
inline cublasStatus_t cublas_gemm_ex(T *A,  T *B,  S *C,
                                    int m, int n, int k);
                                    
/**
 * Correctness + latency harness for the AQLM-style dequantization kernel
 * `codeGEMM::dequant_aqlm`.
 *
 * It builds random codes (qW), random codebooks and per-group scales (alpha).
 * It dequantizes them on the host into `weight[K][N]` as a float reference.
 * It then runs the device kernel into `d_weight_fp16` (read back as N x K
 * halves, i.e. element (k, n) at n*K + k) and reports the mean absolute
 * difference plus the kernel latency.
 *
 * Template parameters:
 *   M                  - rows of `input` / `output`; those arrays are declared
 *                        but not used by this harness.
 *   N, K               - reference weight is K x N; device output is N x K.
 *   num_codebook       - number of codebooks whose centroids are summed.
 *   len_vector         - number of consecutive K elements sharing one code.
 *   nbits_per_codebook - code width; each codebook has 2^nbits centroids.
 *   A_GROUP_SIZE       - number of consecutive K rows sharing one alpha
 *                        value per column (default: K, i.e. one group).
 *
 * NOTE: all host arrays are data members, so an instance is very large
 * (weight alone is sizeof(float)*K*N bytes); allocate instances on the heap.
 */
template<int M, int N, int K, int num_codebook,
int len_vector, int nbits_per_codebook, int A_GROUP_SIZE=K>
class int3_col_wise_matmul_fp16{
public:
    // Codes are stored one per uint8_t, and the host reference indexes qW,
    // weight and alpha assuming K splits evenly into vectors and groups.
    static_assert(nbits_per_codebook >= 1 && nbits_per_codebook <= 8,
                  "codes are stored in uint8_t; nbits_per_codebook must be in [1, 8]");
    static_assert(len_vector > 0 && K % len_vector == 0,
                  "K must be a positive multiple of len_vector");
    static_assert(A_GROUP_SIZE > 0 && K % A_GROUP_SIZE == 0,
                  "K must be a multiple of A_GROUP_SIZE");

    static const int num_groups = K/A_GROUP_SIZE;
    uint8_t     qW[N][K/len_vector][num_codebook];   // one code per (column, vector, codebook)
    float    codebook[num_codebook][1 << nbits_per_codebook][len_vector];
    float     alpha[num_groups][N];                  // per-group, per-column scale

    float   weight[K][N];                            // host float reference
    float    input[M][K];                            // unused by this harness
    float   output[M][N];                            // unused by this harness

    int num_centroid = 1 << nbits_per_codebook;
    __half* d_weight_fp16 = nullptr;  // note: weight[N][K], managed memory

    codeGEMM::nQWeight_fp16 nqW;

    // Runs the whole harness once (setup, correctness check, latency) and
    // returns the mean absolute error between device output and host
    // reference. run_cublas / run_lutgemm / run_gptq are accepted for
    // interface compatibility but are not used here.
    double run(
            bool run_cublas=true, bool run_lutgemm=false, bool run_gptq=false,
            int iter=16){
        (void)run_cublas; (void)run_lutgemm; (void)run_gptq;

        alloc_cuda();
        makeRandomWeight();
        makeRandomAlpha();
        makeRandomCodebook();
        dequantizeFrom_qW();
        copy_cpuToCuda();

        nqW.parsing((uint8_t*)qW, (float*)alpha, K, N, false, num_groups,
            num_codebook, nbits_per_codebook, len_vector, (float*)codebook);
        checkCuda(cudaDeviceSynchronize(), "nQWeight_fp16::parsing");

        double meanError = checkErr();   // synchronizes internally
        printf("Error: %.6f\t", meanError);

        aqlm_latency(nqW, M, N, K, d_weight_fp16, iter);

        free_cuda();
        return meanError;
    }

    // Times dequant_aqlm into B. There is one untimed warm-up call, then
    // `iter` timed calls, each followed by a device sync inside the timed
    // region (presumably the kernel launch is asynchronous). m, n, k are
    // accepted for interface compatibility and are not used.
    void aqlm_latency(
            codeGEMM::nQWeight_fp16 &nqW,
            int m, int n, int k, __half *B, int iter=64){
        (void)m; (void)n; (void)k;

        timer tm;
        codeGEMM::dequant_aqlm((void*)B, nqW);
        checkCuda(cudaGetLastError(), "dequant_aqlm warm-up launch");
        checkCuda(cudaDeviceSynchronize(), "dequant_aqlm warm-up");

        if(iter <= 0) return;  // nothing was timed; avoid reporting empty stats

        for(int i=0;i<iter;i++){
            tm.start();
            codeGEMM::dequant_aqlm((void*)B, nqW);
            cudaError_t status = cudaDeviceSynchronize();
            tm.end();
            // Checked after tm.end() so the timed region is unchanged.
            checkCuda(status, "dequant_aqlm timed run");
        }
        checkCuda(cudaGetLastError(), "dequant_aqlm timed launch");
        printf("latency min : %.5fms, max : %.5fms, avg:%.5f\n", tm.min(), tm.max(), tm.mean());
    }

    // Zeroes the device buffer, runs dequant_aqlm into it and returns the
    // mean absolute error against the host reference `weight`.
    double checkErr(){
        checkCuda(cudaMemset(d_weight_fp16, 0, weightBytes()), "cudaMemset");
        codeGEMM::dequant_aqlm(d_weight_fp16, nqW);
        checkCuda(cudaGetLastError(), "dequant_aqlm launch");
        checkCuda(cudaDeviceSynchronize(), "dequant_aqlm");
        return checkOutputMeanError(d_weight_fp16);
    }

    // Mean |o1 - weight| over all K*N elements. o1 is read on the host as an
    // N x K half matrix (element (k, n) at n*K + k), so it must be
    // host-accessible (here: managed memory after a device sync).
    double checkOutputMeanError(__half *o1){
        double err=0;
        for(int k=0;k<K;k++){
            for(int n=0;n<N;n++){
                // size_t index: n*K can exceed INT_MAX for large layers.
                const size_t idx = size_t(n) * size_t(K) + size_t(k);
                err += std::abs(float(o1[idx]) - float(weight[k][n]));
            }
        }
        return err/K/N;
    }

    // Per-group scales; currently fixed to 1 (random scales disabled).
    void makeRandomAlpha(){
        for(int g=0;g<num_groups;g++)
            for(int n=0;n<N;n++){
                // alpha[g][n] = rand_fp32(); // (-1.0, 1.0) / 2^b
                alpha[g][n] = 1;
            }
    }

    // Fills every centroid component with rand_fp32().
    void makeRandomCodebook(){
        for(int k=0;k<num_codebook;k++)
            for(int i=0;i<num_centroid;i++)
                for(int j=0;j<len_vector;j++){
                    codebook[k][i][j] = rand_fp32();
                }
    }

    // Fills qW with random codes. Only the low nbits_per_codebook bits are set,
    // so every code is < num_centroid and is a valid codebook row index.
    // (Previously all 8 bits were drawn, which overflowed the codebook when
    // nbits_per_codebook < 8.)
    void makeRandomWeight(){
        for(int n=0; n<N; n++){
            for(int k=0; k<K/len_vector; k++){
                for(int j=0; j<num_codebook; j++){
                    uint8_t code = 0;
                    for(int t=0; t<nbits_per_codebook; t++){
                        if(rand_bool()){
                            code |= uint8_t(1u << t);
                        }
                    }
                    qW[n][k][j] = code;
                }
            }
        }
    }

    // Host reference dequantization:
    //   weight[v*len_vector + j][n] =
    //       alpha[group][n] * sum_i codebook[i][qW[n][v][i]][j]
    // where group = (v*len_vector + j) / A_GROUP_SIZE.
    void dequantizeFrom_qW(){
        for(int n=0; n<N; n++){
            for(int k=0; k<K/len_vector; k++){
                for(int j=0; j<len_vector; j++){
                    const int row = len_vector*k + j;
                    weight[row][n] = 0;
                    for(int i=0; i<num_codebook; i++){
                        uint8_t code = qW[n][k][i];
                        weight[row][n] += (codebook[i][code][j]);
                    }
                    weight[row][n] *= alpha[row/A_GROUP_SIZE][n];
                }
            }
        }
    }

    // Bytes of the device output buffer: N*K halves.
    static size_t weightBytes(){
        return sizeof(__half) * size_t(K) * size_t(N);
    }

    void alloc_cuda(){
        checkCuda(cudaMallocManaged(&d_weight_fp16, weightBytes()), "cudaMallocManaged");
    }

    void free_cuda(){
        checkCuda(cudaFree(d_weight_fp16), "cudaFree");
        d_weight_fp16 = nullptr;  // guard against double free / stale use
    }

    // No host->device copies are currently performed (presumably the kernel
    // reads only the parsed nqW); this only synchronizes the device.
    void copy_cpuToCuda(){
        // fhCpy(d_input , (float*)input  ,M * K);
        // fhCpy(d_weight_fp16, (float*)weight ,K * N);

        checkCuda(cudaDeviceSynchronize(), "copy_cpuToCuda");
    }

    // Element-wise half -> float copy of `size` elements (host loop).
    void hfCpy(float* a, __half* b, int size){
       for(int i=0;i<size;i++) a[i] = float(b[i]);
    }
    // Element-wise float -> half copy of `size` elements (host loop).
    void fhCpy(__half* a, float* b, int size){
       for(int i=0;i<size;i++) a[i] = __float2half(b[i]);
    }

    // Aborts with a message when a CUDA runtime call reports an error, so
    // failures surface here instead of as a host segfault later on.
    static void checkCuda(cudaError_t status, const char *what){
        if(status != cudaSuccess){
            fprintf(stderr, "CUDA error (%s): %s\n", what, cudaGetErrorString(status));
            std::abort();
        }
    }

};
// Benchmark layer dimensions. Keep exactly one model's set active; the others
// stay commented out so they can be swapped in quickly.
//   H - presumably the model hidden size (used as both N and K of a square layer)
//   A - presumably a fused Q/K/V projection width (e.g. 4096 + 2*1024 = 6144)
//   I - presumably the MLP intermediate size

// Llama-3.1-8B
constexpr int H = 4096;
constexpr int A = 6144;
constexpr int I = 14336;

// Llama-3.2-1B
// constexpr int H = 2048;
// constexpr int A = 3072;
// constexpr int I = 8192;

// Llama-3.2-3B
// constexpr int H = 3072;
// constexpr int A = 5120;
// constexpr int I = 8192;

// Llama-3.1-70B
// constexpr int H = 8192;
// constexpr int A = 10240;
// constexpr int I = 28672;

// Quantization configuration consumed by the TEST below.
// M is passed as the num_codebook template argument (not a GEMM row count),
// V is passed as len_vector (weights per code).
constexpr int M = 1;
constexpr int V = 4;


// Runs the dequantization harness `repeats` times for an N x K weight with the
// file-level configuration (num_codebook = M, len_vector = V, 8-bit codes,
// alpha group size 1024). Each instance is released before the next one is
// created, so only one set of large host buffers is alive at a time.
// Prints the real shape, adds the number of runs to `count` and returns the
// sum of the per-run mean errors.
template <int N, int K>
static double run_aqlm_shape(int &count, int repeats = 2) {
    printf("M = 1, N = %d, K = %d\n", N, K);
    double err_sum = 0;
    for (int r = 0; r < repeats; r++) {
        auto t = std::make_shared<int3_col_wise_matmul_fp16<1, N, K, M, V, 8, 1024>>();
        err_sum += t->run(false, true, false);
        count++;
    }
    printf("----------------------------------------------------------------\n");
    return err_sum;
}

// Benchmarks AQLM dequantization for the four projection shapes of the model
// selected above (each run twice) and reports the overall mean error.
TEST(int3_col_wise_matmul_fp16, layer_175b){
    double total_error = 0;
    int e_cnt = 0;
    printf("Start!\n");
    printf("AQLM 1x4 Dequantization\n");
    printf("----------------------------------------------------------------\n");
    total_error += run_aqlm_shape<H, H>(e_cnt);
    total_error += run_aqlm_shape<A, H>(e_cnt);
    total_error += run_aqlm_shape<H, I>(e_cnt);
    total_error += run_aqlm_shape<I, H>(e_cnt);

    // Previously accumulated but never reported.
    if (e_cnt > 0) {
        printf("Mean error over %d runs: %.6f\n", e_cnt, total_error / e_cnt);
    }
}




// Row-major GEMM wrapper around cublasGemmEx: C[m x n] = A[m x k] * B[k x n].
// cuBLAS is column-major, and a row-major matrix read column-major is its
// transpose. The call therefore computes C^T = B^T * A^T by swapping the
// operands and passing (n, m, k), so no explicit transposes are needed.
//
// Element type T -> cuBLAS types:
//   float  -> CUDA_R_32F, CUBLAS_COMPUTE_32F_FAST_TF32 (TF32 math)
//   __half -> CUDA_R_16F, CUBLAS_COMPUTE_16F
//   int8_t -> CUDA_R_8I inputs, CUDA_R_32I output, CUBLAS_COMPUTE_32I
// Any other T returns CUBLAS_STATUS_NOT_SUPPORTED. S is the output/scale type;
// per the cuBLAS docs it must match the compute type (float, __half, int32
// respectively). This is not checked here.
//
// One cuBLAS handle per (T, S) instantiation is created lazily on first use
// and intentionally never destroyed. Returns the status of handle creation
// (on failure) or of the GEMM.
template <typename T, typename S>
inline cublasStatus_t cublas_gemm_ex(T *A,  T *B,  S *C,
                                    int m, int n, int k) {
    static S alpha = 1;
    static S beta  = 0;
    static cublasHandle_t handle = nullptr;
    if (handle == nullptr) {
        cublasStatus_t create_status = cublasCreate(&handle);
        if (create_status != CUBLAS_STATUS_SUCCESS) {
            handle = nullptr;  // do not use a bad handle; retry on the next call
            return create_status;
        }
    }

    cudaDataType_t AType, BType, CType;
    cublasComputeType_t  ComputeType;
    if (std::is_same<T, float>::value) {
        AType = BType = CType = CUDA_R_32F;
        ComputeType = CUBLAS_COMPUTE_32F_FAST_TF32;
    } else if (std::is_same<T, __half>::value) {
        AType = BType = CType = CUDA_R_16F;
        ComputeType = CUBLAS_COMPUTE_16F;
    } else if (std::is_same<T, int8_t>::value) {
        AType = BType = CUDA_R_8I;
        CType = CUDA_R_32I;
        ComputeType = CUBLAS_COMPUTE_32I;
    } else {
        return CUBLAS_STATUS_NOT_SUPPORTED;
    }
    // CUBLAS_GEMM_DFALT is a deprecated alias; CUBLAS_GEMM_DEFAULT is the
    // supported name for the default algorithm selection.
    return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                          n, m, k,
                          &alpha,
                          B, BType, n,
                          A, AType, k,
                          &beta,
                          C, CType, n,
                          ComputeType,
                          CUBLAS_GEMM_DEFAULT);
}
