/* LUT-GEMM
 * Copyright (c) 2024-present NAVER Cloud Corp. All rights reserved.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef N_Q_WEIGHT_FP16_H
#define N_Q_WEIGHT_FP16_H

#include <stdint.h>

namespace codeGEMM{

// Forward declaration so the free functions below can take the class by
// reference before its full definition appears.
class nQWeight_fp16;

/// Dequantizes `nqw` into the buffer `d_fW`.
/// `d_fW` is presumably a device (GPU) pointer -- confirm against the
/// implementation. `algo` selects an implementation variant whose meaning is
/// defined by the implementation (not visible from this header); default 0.
void dequantize_gpu(nQWeight_fp16& nqw,
                    void* d_fW,
                    int algo = 0);

/// Dequantizes `nqw` into the buffer `fW`.
/// `fW` is presumably a host pointer -- confirm against the implementation.
void dequantize_cpu(nQWeight_fp16& nqw,
                    void* fW);

/**
 * Holder for a quantized weight matrix: the packed quantized weight buffer,
 * an `alpha` buffer, a `codebook` buffer, and the dimensions / quantization
 * parameters that describe them.
 *
 * Every member has a default initializer so that a default-constructed
 * object is in a well-defined empty state (null pointers, zero sizes) before
 * `parsing()` is called and when the destructor runs.
 *
 * NOTE(review): the destructor is user-declared but copy/move operations are
 * implicitly generated. If the destructor releases the buffers below, copying
 * an instance would lead to a double release -- confirm against the
 * implementation and consider deleting or defining the copy operations.
 */
class nQWeight_fp16{
public:
    // Packed quantized weight. Layout noted by the original author:
    // Weight[kSize/32][nb][mSize]. (An earlier revision used unsigned char*.)
    unsigned int* qWeight = nullptr;
    // Layout noted by the original author: alpha[num_alpha_groups][nb][mSize].
    // Element type is not visible here (presumably fp16, given the class name).
    void* alpha = nullptr;
    // Codebook buffer. The original comment read "q_bias[num_alpha_groups][mSize]",
    // which looks stale for a codebook -- TODO confirm the real layout.
    void* codebook = nullptr;
    int num_groups = 0;
    int group_size = 0;
    int mSize = 0;
    int kSize = 0;
    int nbits_per_codebook = 0;
    int len_vector = 0;
    int num_codebook = 0;
    bool is_row_wise_quantize = false;

    /// Creates an empty object: all pointers null, all sizes zero.
    nQWeight_fp16() = default;

    /// Builds the object from 8-bit quantized codes by forwarding to the
    /// matching `parsing()` overload.
    /// `num_bits` is accepted for interface compatibility but is not
    /// forwarded to `parsing()`.
    nQWeight_fp16(uint8_t *qW, float *A, int row, int col, int num_bits, 
        bool is_row_wise_quantize, int num_alpha_groups, int num_codebook,
        int nbits_per_codebook, int len_vector, float* codebook= nullptr){
        parsing(qW, A, row, col, is_row_wise_quantize, num_alpha_groups, 
            num_codebook, nbits_per_codebook, len_vector, codebook);
    }

    /// Same as above for 16-bit quantized codes (`num_bits` is likewise unused).
    nQWeight_fp16(uint16_t *qW, float *A, int row, int col, int num_bits, 
        bool is_row_wise_quantize, int num_alpha_groups, int num_codebook,
        int nbits_per_codebook, int len_vector, float* codebook= nullptr){
        parsing(qW, A, row, col, is_row_wise_quantize, num_alpha_groups, 
            num_codebook, nbits_per_codebook, len_vector, codebook);
    }

    /// Same as above for 32-bit quantized codes (`num_bits` is likewise unused).
    nQWeight_fp16(uint32_t *qW, float *A, int row, int col, int num_bits, 
        bool is_row_wise_quantize, int num_alpha_groups, int num_codebook,
        int nbits_per_codebook, int len_vector, float* codebook= nullptr){
        parsing(qW, A, row, col, is_row_wise_quantize, num_alpha_groups, 
            num_codebook, nbits_per_codebook, len_vector, codebook);
    }

    // The three `parsing()` overloads differ only in the element width of the
    // quantized input `qW`. They are defined out of line; presumably they
    // populate the members above from the given buffers -- confirm against
    // the implementation.
    void parsing(uint8_t *qW, float *A, int row, int col,
        bool is_row_wise_quantize, int num_alpha_groups, int num_codebook, 
        int nbits_per_codebook, int len_vector, float* codebook);
    
    void parsing(uint16_t *qW, float *A, int row, int col,
        bool is_row_wise_quantize, int num_alpha_groups, int num_codebook, 
        int nbits_per_codebook, int len_vector, float* codebook);

    void parsing(uint32_t *qW, float *A, int row, int col,
        bool is_row_wise_quantize, int num_alpha_groups, int num_codebook, 
        int nbits_per_codebook, int len_vector, float* codebook);

    // Defined out of line.
    ~nQWeight_fp16();
    
    /// Returns a pointer to the dequantized weight. `onGPU` presumably selects
    /// between the GPU and CPU dequantization paths, and ownership of the
    /// returned buffer is not visible from this header -- confirm both against
    /// the implementation. (The spelling "Dequantied" is kept for API
    /// compatibility.)
    void* getDequantiedWeight(bool onGPU=true);
};

}
#endif // N_Q_WEIGHT_FP16_H
