/* LUT-GEMM
 * Copyright (c) 2024-present NAVER Cloud Corp. All rights reserved.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <stdio.h>
#include <cstdlib>

#include "../include/nQWeight_fp16.h"

#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

namespace codeGEMM{

namespace {

// Prints a diagnostic and aborts when a CUDA runtime call did not succeed.
// The callers below write into freshly allocated managed buffers from host
// code, so continuing after a failed allocation would dereference an invalid
// pointer far away from the real cause.
void checkCuda(cudaError_t status, const char* what){
    if(status != cudaSuccess){
        fprintf(stderr, "codeGEMM::nQWeight_fp16: %s failed: %s\n",
                what, cudaGetErrorString(status));
        std::abort();
    }
}

// Prints `msg` and aborts when a caller-supplied argument is invalid.
void requireArg(bool ok, const char* msg){
    if(!ok){
        fprintf(stderr, "codeGEMM::nQWeight_fp16::parsing: %s\n", msg);
        std::abort();
    }
}

// Allocates `count` __half values with cudaMallocManaged and fills them, on
// the host, with __float2half(src[i]). Returns nullptr (allocating nothing)
// when src is nullptr or count is 0 -- cudaMallocManaged rejects a zero size.
__half* toManagedHalf(const float* src, size_t count){
    if(src == nullptr || count == 0) return nullptr;
    __half* dst = nullptr;
    checkCuda(cudaMallocManaged(&dst, sizeof(__half) * count),
              "cudaMallocManaged");
    for(size_t i = 0; i < count; ++i) dst[i] = __float2half(src[i]);
    return dst;
}

// Allocates managed memory for the quantized code buffer into *dst and copies
// the host codes `src` into it. The byte count is
//     sizeof(CodeT) * kSize * mSize / len_vector * num_codebook
// evaluated left to right in size_t (same order as the original expression).
// A zero byte count leaves *dst null instead of calling cudaMallocManaged.
// DstT is deduced from the destination member, so this works whatever pointer
// type the header declares for it.
template <typename DstT, typename CodeT>
void uploadCodes(DstT** dst, const CodeT* src, size_t kSize, size_t mSize,
                 size_t len_vector, size_t num_codebook){
    const size_t bytes =
        sizeof(CodeT) * kSize * mSize / len_vector * num_codebook;
    *dst = nullptr;
    if(bytes == 0) return;
    checkCuda(cudaMallocManaged(dst, bytes), "cudaMallocManaged(qWeight)");
    checkCuda(cudaMemcpy(*dst, src, bytes, cudaMemcpyHostToDevice),
              "cudaMemcpy(qWeight)");
}

} // namespace

// Loads one quantized weight matrix into CUDA managed memory.
//
// The three overloads differ only in the element type of the code buffer qW
// (uint8_t / uint16_t / uint32_t).
//
//   qW                  host buffer of quantized codes;
//                       sizeof(*qW) * kSize * mSize / len_vector * num_codebook
//                       bytes are copied into managed memory (qWeight).
//   alpha               host array of num_alpha_groups * mSize floats, stored
//                       as __half in this->alpha (presumably per-group scale
//                       factors -- confirm against the kernels).
//   row, col            matrix shape. When is_row_wise_quantize is true,
//                       mSize = row and kSize = col; otherwise they are swapped.
//   num_alpha_groups    must be > 0; group_size = kSize / num_alpha_groups.
//   num_codebook, nbits_per_codebook, len_vector
//                       codebook geometry; len_vector must be > 0.
//   codebook            optional (may be nullptr). Otherwise
//                       (1 << nbits_per_codebook) * len_vector * num_codebook
//                       floats, stored as __half in this->codebook.
//
// Invalid arguments and failed CUDA calls print to stderr and abort.
// NOTE(review): calling parsing() twice on one object leaks the first buffers.
void nQWeight_fp16::parsing(uint8_t *qW, float *alpha, int row, int col, 
        bool is_row_wise_quantize, int num_alpha_groups, int num_codebook,
        int nbits_per_codebook, int len_vector, float* codebook){
    requireArg(num_alpha_groups > 0, "num_alpha_groups must be positive");
    requireArg(len_vector > 0, "len_vector must be positive");
    requireArg(alpha != nullptr, "alpha must not be null");

    this->is_row_wise_quantize = is_row_wise_quantize;
    mSize = is_row_wise_quantize ? row : col;
    kSize = is_row_wise_quantize ? col : row;

    this->num_groups = num_alpha_groups;
    // Must come after kSize is assigned above.
    this->group_size = kSize / num_alpha_groups;
    this->nbits_per_codebook = nbits_per_codebook;
    this->len_vector = len_vector;
    this->num_codebook = num_codebook;

    const size_t num_centroid = static_cast<size_t>(1) << nbits_per_codebook;
    this->codebook = (void*)toManagedHalf(codebook,
        num_centroid * static_cast<size_t>(len_vector)
                     * static_cast<size_t>(num_codebook));
    this->alpha = (void*)toManagedHalf(alpha,
        static_cast<size_t>(num_groups) * static_cast<size_t>(mSize));
    uploadCodes(&qWeight, qW, kSize, mSize, len_vector, num_codebook);
}

// uint16_t-code overload; see the uint8_t overload above for the contract.
void nQWeight_fp16::parsing(uint16_t *qW, float *alpha, int row, int col, 
    bool is_row_wise_quantize, int num_alpha_groups, int num_codebook,
    int nbits_per_codebook, int len_vector, float* codebook){
    requireArg(num_alpha_groups > 0, "num_alpha_groups must be positive");
    requireArg(len_vector > 0, "len_vector must be positive");
    requireArg(alpha != nullptr, "alpha must not be null");

    this->is_row_wise_quantize = is_row_wise_quantize;
    mSize = is_row_wise_quantize ? row : col;
    kSize = is_row_wise_quantize ? col : row;

    this->num_groups = num_alpha_groups;
    // Must come after kSize is assigned above.
    this->group_size = kSize / num_alpha_groups;
    this->nbits_per_codebook = nbits_per_codebook;
    this->len_vector = len_vector;
    this->num_codebook = num_codebook;

    const size_t num_centroid = static_cast<size_t>(1) << nbits_per_codebook;
    this->codebook = (void*)toManagedHalf(codebook,
        num_centroid * static_cast<size_t>(len_vector)
                     * static_cast<size_t>(num_codebook));
    this->alpha = (void*)toManagedHalf(alpha,
        static_cast<size_t>(num_groups) * static_cast<size_t>(mSize));
    uploadCodes(&qWeight, qW, kSize, mSize, len_vector, num_codebook);
}

// uint32_t-code overload; see the uint8_t overload above for the contract.
void nQWeight_fp16::parsing(uint32_t *qW, float *alpha, int row, int col, 
    bool is_row_wise_quantize, int num_alpha_groups, int num_codebook,
    int nbits_per_codebook, int len_vector, float* codebook){
    requireArg(num_alpha_groups > 0, "num_alpha_groups must be positive");
    requireArg(len_vector > 0, "len_vector must be positive");
    requireArg(alpha != nullptr, "alpha must not be null");

    this->is_row_wise_quantize = is_row_wise_quantize;
    mSize = is_row_wise_quantize ? row : col;
    kSize = is_row_wise_quantize ? col : row;

    this->num_groups = num_alpha_groups;
    // Must come after kSize is assigned above.
    this->group_size = kSize / num_alpha_groups;
    this->nbits_per_codebook = nbits_per_codebook;
    this->len_vector = len_vector;
    this->num_codebook = num_codebook;

    const size_t num_centroid = static_cast<size_t>(1) << nbits_per_codebook;
    this->codebook = (void*)toManagedHalf(codebook,
        num_centroid * static_cast<size_t>(len_vector)
                     * static_cast<size_t>(num_codebook));
    this->alpha = (void*)toManagedHalf(alpha,
        static_cast<size_t>(num_groups) * static_cast<size_t>(mSize));
    uploadCodes(&qWeight, qW, kSize, mSize, len_vector, num_codebook);
}

// Releases the managed buffers allocated by parsing(): alpha, qWeight and the
// optional codebook. Each is freed exactly once (alpha was previously freed
// twice while qWeight leaked). cudaFree(nullptr) is a no-op. Errors from
// cudaFree are deliberately ignored -- a destructor has no way to report them.
nQWeight_fp16::~nQWeight_fp16(){
    cudaFree(alpha);
    cudaFree(qWeight);
    if(codebook != nullptr) cudaFree(codebook);
}

}

