/* LUT-GEMM
 * Copyright (c) 2024-present NAVER Cloud Corp. All rights reserved.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef KERNELS_CODEGEMM_HPP
#define KERNELS_CODEGEMM_HPP


// Standard headers must not be included inside a namespace (C++ [using.headers]);
// kept here, outside `kernel`, in case dependents rely on it being pulled in.
#include <stdio.h>

namespace kernel{

/*
 * _codegemm: codebook-quantized GEMV with a per-block lookup table (LUT).
 *
 * For every input vector v (len_vector consecutive elements of `input`) and every
 * codebook k, an 8-bit code c selects codeword codebook[k][c][0..len_vector).
 * Each block:
 *   1. builds lut[v][k][c] = dot(codebook[k][c], input[v]) for all 256 codes
 *      of every vector in its K tile (shared memory);
 *   2. for each output column m in its M tile, sums the LUT entries selected
 *      by the packed codes, scales by alpha, and atomically accumulates into
 *      output[m] (two adjacent columns per half2 atomic).
 *
 * Memory layouts (as implied by the indexing below):
 *   codebook : [num_codebook][256][len_vector]  (__half)
 *   qw       : [num_codebook][K / (4*len_vector)][M] uint32 words, each packing
 *              four 8-bit codes for four consecutive vectors (byte i%4 = vector i)
 *   alpha    : indexed alpha[group * M + m], group = K-tile start / group_size
 *              (presumably [K / group_size][M] — confirm against caller)
 *
 * Launch layout:
 *   blockIdx.x selects the M tile (M_TILE_SIZE columns), blockIdx.y the K tile
 *   (K_TILE_SIZE input elements); gridDim.y is expected to be K / K_TILE_SIZE.
 *   Any blockDim.x / blockDim.y works; threads stride over codes/vectors when
 *   building the LUT and over column pairs when accumulating.
 *
 * Shared memory: static, (K_TILE_SIZE/len_vector) * num_codebook * 256 * sizeof(__half).
 *
 * Preconditions (not checked at runtime):
 *   - M and M_TILE_SIZE are even (columns are processed in pairs, and the
 *     half2 atomicAdd on &output[m] needs 4-byte alignment);
 *   - K % K_TILE_SIZE == 0;
 *   - group_size % K_TILE_SIZE == 0 (a K tile uses a single alpha row);
 *   - `output` is initialised by the caller: each K tile adds its partial sum
 *     with atomicAdd, so nothing here clears it.
 *
 * Requires SM60+ (atomicAdd on half2). All LUT values and accumulators are __half.
 */
template<int K_TILE_SIZE, int num_codebook, int len_vector>
__global__ void _codegemm(
    uint32_t *qw,
    __half *alpha,
    __half *codebook,
    __half *input, __half *output,
    int M, int K, int M_TILE_SIZE, int group_size){

    static_assert(len_vector > 0 && num_codebook > 0,
                  "len_vector and num_codebook must be positive");
    static_assert(K_TILE_SIZE % 32 == 0 && K_TILE_SIZE % (4 * len_vector) == 0,
                  "a K tile must cover a whole number of packed 32-bit code words");

    constexpr int kCodes = 256;             // 8-bit codes -> 256 codewords per codebook
    constexpr int kCodesPerWord = 4;        // four 8-bit codes packed per uint32_t
    constexpr int kVecsPerTile = K_TILE_SIZE / len_vector;
    constexpr int kWordRowsPerTile = kVecsPerTile / kCodesPerWord;

    __shared__ __half lut[kVecsPerTile][num_codebook][kCodes];

    // ---- Phase 1: build the LUT for this block's K tile ----------------------
    // threadIdx.y strides over vectors and threadIdx.x over codes, so every
    // entry is written exactly once for any block shape.
    const __half* tile_input = input + static_cast<size_t>(blockIdx.y) * K_TILE_SIZE;
    for (int i = threadIdx.y; i < kVecsPerTile; i += blockDim.y) {
        const __half* inp = tile_input + i * len_vector;
        for (int c = threadIdx.x; c < kCodes; c += blockDim.x) {
            for (int k = 0; k < num_codebook; k++) {
                const __half* cw =
                    codebook + (static_cast<size_t>(k) * kCodes + c) * len_vector;
                __half acc = 0.0f;
                for (int j = 0; j < len_vector; j++) {
                    acc += cw[j] * inp[j];
                }
                lut[i][k][c] = acc;
            }
        }
    }
    __syncthreads();  // LUT must be complete before any thread reads it

    // ---- Phase 2: accumulate output column pairs ----------------------------
    // Flatten the block so threadIdx.y rows split the work instead of each
    // re-adding the same partial sums.
    const int tid = threadIdx.y * blockDim.x + threadIdx.x;
    const int m_start = static_cast<int>(blockIdx.x) * M_TILE_SIZE + 2 * tid;
    const int m_end = min(static_cast<int>(blockIdx.x + 1) * M_TILE_SIZE, M);
    const int m_step = 2 * blockDim.x * blockDim.y;

    // 64-bit offsets: M * K can exceed INT_MAX for large layers.
    const size_t plane_stride =
        static_cast<size_t>(K / len_vector / kCodesPerWord) * M;  // one codebook's qw plane
    const uint32_t* tile_qw =
        qw + static_cast<size_t>(blockIdx.y) * kWordRowsPerTile * M;
    const int group_idx = static_cast<int>(blockIdx.y) * K_TILE_SIZE / group_size;
    const __half* alpha_row = alpha + static_cast<size_t>(group_idx) * M;

    for (int m = m_start; m < m_end; m += m_step) {
        __half reg_o0 = 0.0f;
        __half reg_o1 = 0.0f;

        const __half reg_a0 = alpha_row[m];
        const __half reg_a1 = alpha_row[m + 1];

        for (int i = 0; i < kVecsPerTile; i++) {
            // Vector i's code sits in word row i/4, byte i%4 (four codes per word).
            const uint32_t* row = tile_qw + static_cast<size_t>(i / kCodesPerWord) * M + m;
            const int shift = 8 * (i % kCodesPerWord);
            for (int k = 0; k < num_codebook; k++) {
                const uint32_t* words = row + plane_stride * k;
                reg_o0 += lut[i][k][(words[0] >> shift) & 255];
                reg_o1 += lut[i][k][(words[1] >> shift) & 255];
            }
        }

        reg_o0 *= reg_a0;
        reg_o1 *= reg_a1;

        // Other K-tile blocks add into the same columns, hence the atomic.
        atomicAdd(reinterpret_cast<half2*>(&output[m]), __halves2half2(reg_o0, reg_o1));
    }
}

} // namespace kernel

#endif //KERNELS_CODEGEMM_HPP

