/* LUT-GEMM
 * Copyright (c) 2024-present NAVER Cloud Corp. All rights reserved.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef KERNELS_MV_FP16_BIAS_HPP
#define KERNELS_MV_FP16_BIAS_HPP


namespace kernel{

#include <stdio.h>
#include "./codegemm.hpp"
#include "./aqlm.hpp"
#include "./quip_sharp.hpp"
#include "./qtip.hpp"

// Ceiling division: how many chunks of size y are needed to cover x items.
// Assumes y > 0 (y == 0 divides by zero).
inline int div_roundup(int x , int y){
    const int partial_chunk_pad = y - 1;
    return (x + partial_chunk_pad) / y;
}

/*
 * Launch kernel::_codegemm for one compile-time
 * (k_tile_size, num_codebook, len_vector) instantiation.
 *
 * Launch shape:
 *   grid.x  = ceil(nqW.mSize / m_tile_size)   -- tiles over M
 *   grid.y  = ceil(nqW.kSize / k_tile_size)   -- tiles over K
 *   block.x = num_thraeds                     -- 1-D thread block
 *
 * The final kernel argument is nqW.kSize / nqW.num_groups (presumably the
 * quantization group size along K -- confirm against codegemm.hpp); a zero
 * num_groups divides by zero on the host.
 *
 * The launch goes to the legacy default stream and is not error-checked here.
 */
template<int k_tile_size, int num_codebook, int len_vector>
inline void _excute_nqmv(
        __half *output, nQWeight_fp16 &nqW, __half *input,
        int num_thraeds, int m_tile_size){
    const auto m_tiles = codeGEMM::kernel::div_roundup(nqW.mSize, m_tile_size);
    const auto k_tiles = codeGEMM::kernel::div_roundup(nqW.kSize, k_tile_size);
    const dim3 launch_grid(m_tiles, k_tiles);
    const dim3 launch_block(num_thraeds);

    const auto group_size = nqW.kSize / nqW.num_groups;

    __half *alpha_ptr    = (__half*)nqW.alpha;
    __half *codebook_ptr = (__half*)nqW.codebook;

    kernel::_codegemm<k_tile_size, num_codebook, len_vector>
        <<<launch_grid, launch_block>>>(
            nqW.qWeight, alpha_ptr, codebook_ptr, input, output,
            nqW.mSize, nqW.kSize, m_tile_size, group_size);
}

/*
 * Host entry point for the CodeGEMM matrix-vector product on fp16 data.
 *
 * Dispatches to _excute_nqmv with a compile-time
 * (k_tile_size, num_codebook, len_vector) configuration matching nqW:
 *   num_codebook == 2 && len_vector == 8  ->  _excute_nqmv<32, 2, 8>
 *   num_codebook == 1 && len_vector == 4  ->  _excute_nqmv<32, 1, 4>
 * Both use 2048-row M tiles and 256-thread blocks.
 *
 * Any other layout has no instantiation. Nothing is launched and `output`
 * is left untouched, so that case is reported on stderr instead of
 * passing silently.
 *
 * NOTE(review): despite the name, no bias term is applied in this function.
 *
 * @param output device buffer for the result (presumably nqW.mSize halves -- confirm)
 * @param nqW    quantized weight descriptor
 * @param input  device input vector (presumably nqW.kSize halves -- confirm)
 */
inline void nqmv_bias(__half *output, nQWeight_fp16 &nqW, __half *input){
    // Launch shape shared by every instantiation below.
    const int m_tile_size = 2048;
    const int num_threads = 256;

    if(nqW.num_codebook==2 && nqW.len_vector==8){
        _excute_nqmv<32, 2, 8>(output, nqW, input, num_threads, m_tile_size);
    } else if(nqW.num_codebook==1 && nqW.len_vector==4){
        _excute_nqmv<32, 1, 4>(output, nqW, input, num_threads, m_tile_size);
    } else {
        // Previously a silent no-op; make the unsupported layout visible.
        fprintf(stderr,
                "nqmv_bias: unsupported configuration (num_codebook=%d, len_vector=%d); no kernel launched\n",
                (int)nqW.num_codebook, (int)nqW.len_vector);
    }
}

// Upper bound on thread_m (output rows per block) when sizing the AQLM
// launches; each such row is given 32 threads.
constexpr int THREAD_M = 16;

/*
 * Launch the AQLM matrix-vector kernel on fp16 data.
 *
 * Kernel selection:
 *   - nqW.nbits_per_codebook == 8 -> kernel::Code2x8MatVec<false>, with
 *     16 * (2*256*8 + 32*9) bytes of dynamic shared memory. That is above
 *     the default 48 KB per-block limit, hence the cudaFuncSetAttribute
 *     opt-in before the launch.
 *   - anything else               -> kernel::Code1x16MatVec<false, 8>, with
 *     16 * 32 * (8 + 1) bytes of dynamic shared memory.
 *
 * Launch shape: a 1-D grid of `blocks` blocks. Each block covers `thread_m`
 * output rows using 32 threads per row. `waves` is the smallest number of
 * SM-count-sized rounds for which
 *   thread_m = ceil(mSize / (waves * sms)) <= THREAD_M,
 * where sms is the SM count of the *current* device (the device the launch
 * goes to).
 *
 * @param output device pointer, reinterpreted as int4* for the kernel
 *               (presumably must be 16-byte aligned -- confirm)
 * @param nqW    quantized weights: qWeight, codebook, mSize, kSize, nbits_per_codebook
 * @param input  device pointer, reinterpreted as const int4* for the kernel
 *
 * Returns without launching when nqW.mSize <= 0 (nothing to compute), or
 * when the SM count cannot be queried (message on stderr). Launch errors
 * are not checked here; callers should use cudaGetLastError().
 */
inline void  aqlm_gemv(
    __half *output, nQWeight_fp16 &nqW, __half *input
) {
    // No output rows: nothing to do. Without this guard thread_m becomes 0
    // below and div_roundup(mSize, thread_m) divides by zero.
    if (nqW.mSize <= 0) return;

    // Size the launch for the device the kernel will actually run on, and
    // never use an uninitialized SM count.
    int device = 0;
    int sms = 0;
    cudaError_t err = cudaGetDevice(&device);
    if (err == cudaSuccess) {
        err = cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, device);
    }
    if (err != cudaSuccess || sms <= 0) {
        fprintf(stderr, "aqlm_gemv: could not determine SM count (%s); no kernel launched\n",
                cudaGetErrorString(err));
        return;
    }

    // Grow the number of waves until each block handles at most THREAD_M rows.
    int waves = 0;
    int thread_m;
    do {
        waves++;
        thread_m = codeGEMM::kernel::div_roundup(nqW.mSize, waves * sms);
    } while (thread_m > THREAD_M);

    const int blocks  = codeGEMM::kernel::div_roundup(nqW.mSize, thread_m);
    const int threads = 32 * thread_m;   // 32 threads per row handled by a block

    if(nqW.nbits_per_codebook==8){
        const int shared = 16 * (2 * 256 * 8 + 32 * 9);
        // Opt in to more than the default 48 KB of dynamic shared memory.
        cudaFuncSetAttribute(
            kernel::Code2x8MatVec<false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
        );

        kernel::Code2x8MatVec<false><<<blocks, threads, shared>>>(
            (const int4*) nqW.qWeight,
            (const int4*) input,
            (int4*) output,
            (const int4*) nqW.codebook,
            nqW.mSize,
            nqW.kSize
        );
    } else{
        kernel::Code1x16MatVec<false, 8><<<blocks, threads, 16*32*(8 + 1)>>>(
            (const int4*) nqW.qWeight,
            (const int4*) input,
            (int4*) output,
            (const int4*) nqW.codebook,
            nqW.mSize,
            nqW.kSize
        );
    }
}

/*
 * Dequantize an AQLM-packed weight matrix into a dense fp16 buffer.
 *
 * Kernel selection:
 *   - nqW.len_vector == 8 && nqW.num_codebook == 2
 *       -> kernel::Code2x8Dequant<false> on int4-reinterpreted buffers, with
 *          16 * (2*256*8 + 32*9) bytes of dynamic shared memory;
 *   - every other layout
 *       -> kernel::Code1x4Dequant on uint8_t codes / __half codebook, with
 *          16 * (1*256*4 + 32*9) bytes of dynamic shared memory.
 *   NOTE(review): the fallback does not verify that the layout really is
 *   1x4 -- confirm that no other layouts reach this function.
 *
 * Launch shape: a 1-D grid of `blocks` blocks. Each block covers `thread_m`
 * rows using 32 threads per row. `waves` is the smallest number of
 * SM-count-sized rounds for which
 *   thread_m = ceil(mSize / (waves * sms)) <= THREAD_M,
 * where sms is the SM count of the current device.
 *
 * @param output device buffer receiving the dequantized weights
 * @param nqW    quantized weights: qWeight, codebook, mSize, kSize, len_vector, num_codebook
 *
 * Returns without launching when nqW.mSize <= 0 (nothing to compute), or
 * when the SM count cannot be queried (message on stderr). Launch errors
 * are not checked here; callers should use cudaGetLastError().
 */
inline void  aqlm_dequant(
    __half *output, nQWeight_fp16 &nqW
) {
    // No rows: nothing to do. Without this guard thread_m becomes 0 below and
    // div_roundup(mSize, thread_m) divides by zero.
    if (nqW.mSize <= 0) return;

    // Size the launch for the current device, and never use an
    // uninitialized SM count.
    int device = 0;
    int sms = 0;
    cudaError_t err = cudaGetDevice(&device);
    if (err == cudaSuccess) {
        err = cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, device);
    }
    if (err != cudaSuccess || sms <= 0) {
        fprintf(stderr, "aqlm_dequant: could not determine SM count (%s); no kernel launched\n",
                cudaGetErrorString(err));
        return;
    }

    // Grow the number of waves until each block handles at most THREAD_M rows.
    int waves = 0;
    int thread_m;
    do {
        waves++;
        thread_m = codeGEMM::kernel::div_roundup(nqW.mSize, waves * sms);
    } while (thread_m > THREAD_M);

    const int blocks  = codeGEMM::kernel::div_roundup(nqW.mSize, thread_m);
    const int threads = 32 * thread_m;   // 32 threads per row handled by a block

    if(nqW.len_vector==8 && nqW.num_codebook==2){
        const int shared = 16 * (2 * 256 * 8 + 32 * 9);
        // Opt in to more than the default 48 KB of dynamic shared memory.
        cudaFuncSetAttribute(
            kernel::Code2x8Dequant<false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
        );

        kernel::Code2x8Dequant<false><<<blocks, threads, shared>>>(
            (const int4*) nqW.qWeight,
            (int4*) output,
            (const int4*) nqW.codebook,
            nqW.mSize,
            nqW.kSize
        );
    } else{
        const int shared = 16 * (1 * 256 * 4 + 32 * 9);
        cudaFuncSetAttribute(
            kernel::Code1x4Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
        );
        kernel::Code1x4Dequant<<<blocks, threads, shared>>>(
            (const uint8_t*) nqW.qWeight,
            (__half*) output,
            (const __half*) nqW.codebook,
            nqW.mSize,
            nqW.kSize
        );
    }
}

/*
 * Launch kernel::decode_matvec_e8p_kernel (the QuIP#-style decode +
 * mat-vec kernel, judging by the names -- see quip_sharp.hpp).
 *
 * Launch shape is fixed: 108 blocks of 32x32 = 1024 threads, no dynamic
 * shared memory, legacy default stream.
 * NOTE(review): 108 equals the SM count of an A100; the value is presumably
 * tuned for that GPU -- confirm before relying on it elsewhere.
 *
 * n and k are forwarded unchanged (presumably the matrix dimensions --
 * confirm against the kernel). The launch is not error-checked here.
 */
inline void  quip_sharp_gemv(
    float* output, uint2* input, uint2* qidxs, uint32_t* codebook,
    int n, int k
) {
    constexpr int kNumBlocks = 108;
    const dim3 threads_per_block(32, 32);

    const uint2    *vec_in   = input;
    const uint2    *code_idx = qidxs;
    const uint32_t *grid_cb  = codebook;

    kernel::decode_matvec_e8p_kernel<<<kNumBlocks, threads_per_block>>>(
        output, vec_in, code_idx, grid_cb, n, k);
}

// Launch shape used by qtip_gemv below: BLOCK_COUNT blocks of BLOCK_SIZE threads.
// NOTE(review): preprocessor macros ignore `namespace kernel`, so these
// unprefixed names are visible to everything included after this header and
// may collide with other BLOCK_SIZE / BLOCK_COUNT definitions.
#define BLOCK_COUNT             128
#define BLOCK_SIZE              1024

/*
 * Launch kernel::kernel_decompress_matvec<16U, 9U, 2U, 1U, M, 1U, K> with
 * BLOCK_COUNT blocks of BLOCK_SIZE threads on the legacy default stream.
 *
 * M and K are forwarded unchanged as the kernel's 5th and 7th template
 * arguments (presumably the output and reduction sizes -- confirm against
 * qtip.hpp).
 *
 * The kernel receives 2^(9+5+1+1) = 65536 bytes of dynamic shared memory
 * (presumably for the codebook -- confirm). That exceeds the default 48 KB
 * per-block limit, so the kernel is opted in with cudaFuncSetAttribute
 * before the launch. Neither call is error-checked here.
 */
template<int M, int K>
inline void  qtip_gemv(
    float* output, half2* input, uint32_t* d_compressed, half2* codebook
) {
    constexpr uint32_t codebook_smem_bytes = 1u << (9 + 5 + 1 + 1);

    cudaFuncSetAttribute(
        kernel::kernel_decompress_matvec<16U, 9U, 2U, 1U, M, 1U, K>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        codebook_smem_bytes);

    const dim3 grid_dim(BLOCK_COUNT);
    const dim3 block_dim(BLOCK_SIZE);

    kernel::kernel_decompress_matvec<16U, 9U, 2U, 1U, M, 1U, K>
        <<<grid_dim, block_dim, codebook_smem_bytes>>>(
            output,
            static_cast<const uint32_t*>(d_compressed),
            static_cast<const half2*>(input),
            static_cast<const half2*>(codebook));
}

}

#endif // KERNELS_MV_FP16_BIAS_HPP

