/* LUT-GEMM
 * Copyright (c) 2024-present NAVER Cloud Corp. All rights reserved.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "../include/kernels.h"
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>

namespace codeGEMM{

#include "../src/cuda/kernels/cublas.h"
#include "../src/cuda/kernels/mv_fp16.hpp"

// Entry points defined later in this file. Every buffer is passed as void*
// and reinterpreted inside the definition (e.g. __half*, float*, uint2*,
// half2*, uint32_t*), so callers must pass device pointers of the matching
// element type.

// Weight-times-input product using kernel::nqmv_bias; `n` is the number of
// input columns (see the definition for which values are supported).
void matmul(void* output, nQWeight_fp16 &nqW, void* input, int n);
// Input-times-weight product using kernel::nqmv_bias over `m` input rows.
void matmul(void* output, void* input, nQWeight_fp16 &nqW, int m);
// Same batching scheme as matmul, but launching kernel::aqlm_gemv.
void matmul_aqlm(void* output, void* input, nQWeight_fp16 &nqW, int m);
// QuIP#-style product via kernel::quip_sharp_gemv; output rows hold `n`
// floats, input rows hold `k` activations, `m` is the batch size.
void matmul_quip_sharp(void* output, void* input, void* qidxs, void* codebook, int m, int n, int k);
// QTIP-style product via kernel::qtip_gemv<M, K>; the shape is fixed at
// compile time and `n` is the batch size. Only the (M, K) pairs explicitly
// instantiated at the bottom of this file are available to other units.
template<int M, int K>
void matmul_qtip(void* output, void* input, void* d_compressed, void* codebook, 
    int n);
// Materialises the AQLM-quantised weight into `output` (mSize * kSize __half).
void dequant_aqlm(void* output, nQWeight_fp16 &nqW);

/* float16 */
// inline void matmul_useCublas(__half* output, nQWeight_fp16 &nqW, __half* input, int n);
// inline void matmul_useCublas(__half* output, __half* input, nQWeight_fp16 &nqW, int m);
/************************** float16 ***********************/

/* Weight-times-input product: output = W * input for the quantised weight nqW.
 *
 *   output : device buffer of nqW.mSize __half values; zeroed here before the
 *            kernel launch (kernel::nqmv_bias presumably accumulates into it
 *            -- confirm against the kernel).
 *   input  : device buffer of __half activations.
 *   n      : number of input columns. Only n == 1 is implemented; the cuBLAS
 *            fallback for larger n is currently disabled.
 *
 * For any other n, nothing is launched and `output` is left untouched. An
 * error is reported on stderr so the caller does not silently consume an
 * uninitialised buffer.
 */
void matmul(void* output, nQWeight_fp16 &nqW, void* input, int n){
    if (n == 1) {
        cudaMemset(output, 0, sizeof(__half) * nqW.mSize);
        kernel::nqmv_bias((__half*)output, nqW, (__half*)input);
        return;
    }
    // Previously a silent no-op: make the unsupported shape visible.
    fprintf(stderr,
            "codeGEMM::matmul: unsupported n=%d (only n == 1 is implemented)\n",
            n);
}
/* Batched input-times-weight product: for each of the `m` input rows, launch
 * kernel::nqmv_bias on that row.
 *
 *   output : m rows of nqW.mSize __half values (row r at output + r*mSize),
 *            zeroed in one cudaMemset before any launch.
 *   input  : m rows of nqW.kSize __half values (row r at input + r*kSize).
 *   m      : number of rows; m == 1 is simply the single-iteration case.
 */
void matmul(void* output, void* input, nQWeight_fp16 &nqW, int m){
    __half* const out_base = static_cast<__half*>(output);
    __half* const in_base  = static_cast<__half*>(input);

    cudaMemset(output, 0, sizeof(__half) * nqW.mSize * m);

    for (int row = 0; row < m; ++row) {
        __half* const out_row = out_base + nqW.mSize * row;
        __half* const in_row  = in_base + nqW.kSize * row;
        kernel::nqmv_bias(out_row, nqW, in_row);
    }
}

/* Batched AQLM product: runs kernel::aqlm_gemv once per input row.
 *
 *   output : m rows of nqW.mSize __half values (row r at output + r*mSize),
 *            cleared with a single cudaMemset up front.
 *   input  : m rows of nqW.kSize __half values (row r at input + r*kSize).
 *   m      : number of rows to process; m == 1 needs no special handling.
 */
void matmul_aqlm(void* output, void* input, nQWeight_fp16 &nqW, int m){
    __half* const dst = static_cast<__half*>(output);
    __half* const src = static_cast<__half*>(input);

    cudaMemset(output, 0, sizeof(__half) * nqW.mSize * m);

    for (int r = 0; r < m; ++r) {
        kernel::aqlm_gemv(dst + nqW.mSize * r, nqW, src + nqW.kSize * r);
    }
}

/* Expands the AQLM-quantised weight nqW into `output`.
 *
 *   output : device buffer of nqW.mSize * nqW.kSize __half values; zeroed
 *            here before kernel::aqlm_dequant is launched on it.
 */
void dequant_aqlm(void* output, nQWeight_fp16 &nqW){
    const size_t bytes = sizeof(__half) * nqW.mSize * nqW.kSize;
    cudaMemset(output, 0, bytes);

    __half* const dst = static_cast<__half*>(output);
    kernel::aqlm_dequant(dst, nqW);
}

/* Batched QuIP#-style product: one kernel::quip_sharp_gemv launch per row.
 *
 *   output   : m rows of n floats (row r at output + r*n); the whole buffer
 *              is zeroed before any launch.
 *   input    : m rows of k activations, reinterpreted as uint2.
 *              NOTE(review): assumes each row is k fp16 values packed four
 *              per uint2 (so k must be a multiple of 4) -- confirm.
 *   qidxs    : compressed weight indices, forwarded unchanged as uint2*.
 *   codebook : codebook, forwarded unchanged as uint32_t*.
 *   m, n, k  : batch size, output length per row, input length per row.
 */
void matmul_quip_sharp(void* output, void* input, void* qidxs, void* codebook, 
    int m, int n, int k){
    // Nothing to compute for an empty (or nonsensical) batch.
    if (m <= 0) return;

    // One uint2 is 8 bytes, i.e. four 16-bit activations.
    constexpr int kActivationsPerUint2 = 4;

    float* const out_base = static_cast<float*>(output);
    uint2* const in_base = static_cast<uint2*>(input);
    uint2* const weights = static_cast<uint2*>(qidxs);
    uint32_t* const cb = static_cast<uint32_t*>(codebook);

    // Use size_t strides so row offsets cannot overflow int.
    const size_t out_stride = static_cast<size_t>(n);
    const size_t in_stride = static_cast<size_t>(k) / kActivationsPerUint2;

    cudaMemset(output, 0, sizeof(float) * n * m);
    for (int row = 0; row < m; ++row) {
        // Each launch must target its own output row and read its own input
        // row. Reusing the base pointers recomputes row 0 m times and leaves
        // rows 1..m-1 zero.
        kernel::quip_sharp_gemv(out_base + out_stride * row,
                                in_base + in_stride * row,
                                weights, cb, n, k);
    }
}

/* Batched QTIP-style product with a compile-time shape (M outputs, K inputs):
 * one kernel::qtip_gemv<M, K> launch per batch row.
 *
 *   output       : n rows of M floats (row r at output + r*M); the whole
 *                  buffer is zeroed before any launch.
 *   input        : n rows of K activations viewed as half2, i.e. K/2 half2
 *                  values per row (row r at input + r*(K/2)).
 *   d_compressed : compressed weights, forwarded unchanged as uint32_t*.
 *   codebook     : codebook, forwarded unchanged as half2*.
 *   n            : batch size.
 */
template<int M, int K>
void matmul_qtip(void* output, void* input, void* d_compressed, void* codebook, 
    int n){
    // The half2 row stride below is K/2, which is only exact for even K.
    static_assert(K % 2 == 0, "matmul_qtip: K must be even (input is read as half2)");

    // Nothing to compute for an empty (or nonsensical) batch.
    if (n <= 0) return;

    float* const out_base = static_cast<float*>(output);
    half2* const in_base = static_cast<half2*>(input);
    uint32_t* const weights = static_cast<uint32_t*>(d_compressed);
    half2* const cb = static_cast<half2*>(codebook);

    // Use size_t strides so row offsets cannot overflow int.
    const size_t out_stride = static_cast<size_t>(M);
    const size_t in_stride = static_cast<size_t>(K / 2);

    cudaMemset(output, 0, sizeof(float) * n * M);
    for (int row = 0; row < n; ++row) {
        // Advance both pointers per row. Reusing the base pointers computes
        // row 0 n times and leaves the remaining rows zero.
        kernel::qtip_gemv<M, K>(out_base + out_stride * row,
                                in_base + in_stride * row,
                                weights, cb);
    }
}


}



namespace codeGEMM {
    // Explicit instantiations of matmul_qtip<M, K>. The template body lives
    // in this translation unit, so other units can only link against the
    // (M, K) shapes listed here.

    // Shapes built around 4096.
    template void matmul_qtip<4096, 4096>(void*, void*, void*, void*, int);
    template void matmul_qtip<6144, 4096>(void*, void*, void*, void*, int);
    template void matmul_qtip<14336, 4096>(void*, void*, void*, void*, int);
    template void matmul_qtip<4096, 14336>(void*, void*, void*, void*, int);

    // Shapes built around 8192.
    template void matmul_qtip<8192, 8192>(void*, void*, void*, void*, int);
    template void matmul_qtip<10240, 8192>(void*, void*, void*, void*, int);
    template void matmul_qtip<28672, 8192>(void*, void*, void*, void*, int);
    template void matmul_qtip<8192, 28672>(void*, void*, void*, void*, int);
}
