#include <cuda_fp16.h>
#include <cuda.h>
#include <cstdio>
#include <ctime>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <fstream>
#include "sqllm.h"
#include "lutgemm.h"
#include "dec_config.h"
#include "dec_context.h"
#include "decdec.h"
#include "dec.cuh"
#include "anyprec.h"
#include "typetraits.h"
#include "datatype.h"

#include <cuda_runtime.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <assert.h>

// SQLLM tiling constants.
// NOTE(review): not referenced in the visible part of this file; presumably
// consumed by SQLLM code elsewhere — confirm before changing.
#define SQLLM_BLOCKWIDTH 128
#define SQLLM_BLOCKHEIGHT3 12
#define SQLLM_BLOCKHEIGHT4 16

// LUTGEMM tiling / launch constants.
// NOTE(review): not referenced in the visible part of this file; presumably
// consumed by LUTGEMM code elsewhere — confirm before changing.
#define K_TILE_SIZE 32
#define NUM_THREADS 256
#define M_TILE_SIZE 2048

////////////////////////////////////////////////////////////////////////////////
//                                     ANYPREC
////////////////////////////////////////////////////////////////////////////////

// Issues one Any-Precision GEMV for element type DT.
//
// Problem sizes come straight from the tensors:
//   rows = input.size(0), out_feat = output.size(2), in_feat = input.size(2).
// Nothing is validated here -- anyprec_gemv_stream performs the dtype, rank,
// contiguity and device checks before calling this. The launch is enqueued on
// `stream` and this function does not synchronize.
template<DataType DT>
void anyprec_gemv_templated(
    torch::Tensor input,
    torch::Tensor output,
    torch::Tensor qweight,
    torch::Tensor lut,
    int bitwidth,
    cudaStream_t stream
) {
    // Narrowed to uint32_t to match anyprec_matmul's size parameters.
    uint32_t rows = static_cast<uint32_t>(input.size(0));
    uint32_t out_feat = static_cast<uint32_t>(output.size(2));
    uint32_t in_feat = static_cast<uint32_t>(input.size(2));

    auto *in_ptr = (FP_DTYPE(DT)*)input.data_ptr<ATEN_DTYPE(DT)>();
    auto *out_ptr = (FP_DTYPE(DT)*)output.data_ptr<ATEN_DTYPE(DT)>();
    auto *qw_ptr = (uint32_t*)qweight.data_ptr<int>();
    auto *lut_ptr = (FP_DTYPE(DT)*)lut.data_ptr<ATEN_DTYPE(DT)>();

    anyprec_matmul<DT>(in_ptr, out_ptr, qw_ptr, lut_ptr,
                       rows, out_feat, in_feat,
                       bitwidth, stream);
}

// Validates the arguments of the fixed-bitwidth Any-Precision GEMV, then
// dispatches on dtype to anyprec_gemv_templated. The work is enqueued on
// `stream`, with no synchronization here.
//
//   input, output : contiguous CUDA tensors of shape (batch_size, 1, features)
//                   with the same batch size and dtype (half or bfloat16;
//                   float is rejected).
//   qweight       : contiguous int32 CUDA tensor (its shape check is disabled
//                   below).
//   lut           : contiguous CUDA tensor of shape (output_feat, 2 ** bitwidth)
//                   with the same dtype as input.
//   bitwidth      : must be in [3, 8].
// Every violated precondition raises through TORCH_CHECK.
void anyprec_gemv_stream(
    torch::Tensor input,
    torch::Tensor output,
    torch::Tensor qweight,
    torch::Tensor lut,
    int bitwidth,
    cudaStream_t stream
) {
    TORCH_CHECK(bitwidth >= 3 && bitwidth <= 8, "Bitwidth must be between 3 and 8.");
    TORCH_CHECK(input.scalar_type() == lut.scalar_type() && input.scalar_type() == output.scalar_type(), 
                "Mismatched data types between input, lut, and output tensors.");
    TORCH_CHECK(qweight.scalar_type() == at::kInt, "qweight tensor must be of type int.");
    TORCH_CHECK(input.dim() == 3, "input tensor must be of shape (batch_size, seq_len, hidden_size).");
    TORCH_CHECK(output.dim() == 3, "output tensor must be of shape (batch_size, seq_len, hidden_size).");

    // lut is of shape (output_feat, 2 ** bitwidth). Report sizes() rather than
    // size(0)/size(1): with a lut of the wrong rank, size(1) would itself throw
    // an unrelated dimension-out-of-range error instead of this message.
    TORCH_CHECK(lut.dim() == 2 && lut.size(1) == (1 << bitwidth) && lut.size(0) == output.size(2),
                "lut tensor must be of shape (output_feat, 2 ** bitwidth). Expected (", output.size(2), ", ",
                1 << bitwidth, "), got ", lut.sizes(), ".");

    // qweight is of shape (bitwidth, output_feat, input_feat / 32)
    // NOTE(review): this shape check is disabled; the reason is not visible here -- confirm before re-enabling.
    // TORCH_CHECK(qweight.dim() == 3 && qweight.size(0) == bitwidth && qweight.size(2) == input.size(2) / 32 && qweight.size(1) == output.size(2),
    // "qweight tensor must be of shape (bitwidth, output_feat, input_feat / 32). Expected (", bitwidth, ", ", output.size(2), ", ", input.size(2) / 32, "), got (", qweight.size(0), ", ", qweight.size(1), ", ", qweight.size(2), ").");

    // Check that sequence length is 1
    TORCH_CHECK(input.size(1) == 1, "Only sequence length of 1 is supported.");
    TORCH_CHECK(output.size(1) == 1, "Only sequence length of 1 is supported.");

    // The batch size M is taken from input, so output must have exactly as many
    // rows or the GEMV would write past its end.
    TORCH_CHECK(input.size(0) == output.size(0),
                "input and output batch sizes must match. Got ", input.size(0), " and ", output.size(0), ".");

    // Check that every tensor handed to the GPU GEMV as a raw pointer is on the GPU
    TORCH_CHECK(input.is_cuda() && output.is_cuda(), "input and output tensors must be on GPU.");
    TORCH_CHECK(qweight.is_cuda() && lut.is_cuda(), "qweight and lut tensors must be on GPU.");

    // Check that all tensors are contiguous
    TORCH_CHECK(input.is_contiguous(), "input tensor must be contiguous.");
    TORCH_CHECK(output.is_contiguous(), "output tensor must be contiguous.");
    TORCH_CHECK(qweight.is_contiguous(), "qweight tensor must be contiguous.");
    TORCH_CHECK(lut.is_contiguous(), "lut tensor must be contiguous.");

    auto dtype = input.scalar_type();
    if (dtype == at::kFloat) {
        TORCH_CHECK(false, "Any-Precision GEMV does not support float data type. Please use half or bfloat16.");
    } else if (dtype == at::kHalf) {
        anyprec_gemv_templated<DataType::FP16>(input, output, qweight, lut, bitwidth, stream);
    } else if (dtype == at::kBFloat16) {
        anyprec_gemv_templated<DataType::BF16>(input, output, qweight, lut, bitwidth, stream);
    } else {
        TORCH_CHECK(false, "Unsupported data type.");
    }
}

// Convenience entry point: runs anyprec_gemv_stream on the stream that
// PyTorch currently considers current for the active device.
void anyprec_gemv(
    torch::Tensor input,
    torch::Tensor output,
    torch::Tensor qweight,
    torch::Tensor lut,
    int bitwidth
) {
    anyprec_gemv_stream(input, output, qweight, lut, bitwidth,
                        at::cuda::getCurrentCUDAStream().stream());
}

// Validates the LUT and qweight shapes for the selector-based Any-Precision
// GEMV, then launches anyprec_matmul_sel for element type DT on `stream`.
//
// Problem sizes: M = input.size(0), N = output.size(2), K = input.size(2).
//   lut3..lut6 : lutB must be (output_feat, 2 ** B).
//   qweight    : must be (6, output_feat, input_feat / 32). Presumably one
//                packed plane per bit up to the 6-bit maximum; confirm
//                against the packer.
//   bitwidth   : not used here.
//   bsel       : forwarded as a raw int32 pointer (presumably a device-side
//                precision selector -- confirm).
// Dtype, contiguity and device checks are the caller's responsibility
// (anyprec_gemv_stream_sel).
template<DataType DT>
void anyprec_gemv_templated_sel(
    torch::Tensor input,
    torch::Tensor output,
    torch::Tensor qweight,
    torch::Tensor lut3,
    torch::Tensor lut4,
    torch::Tensor lut5,
    torch::Tensor lut6,
    int bitwidth,
    torch::Tensor bsel,
    cudaStream_t stream
) {
    uint32_t M = input.size(0);
    uint32_t N = output.size(2);
    uint32_t K = input.size(2);

    const int64_t out_feat = output.size(2);

    // Each lutB must be of shape (output_feat, 2 ** B). The message reports the
    // tensor's own sizes(): indexing size(1) of a wrong-rank tensor would throw
    // an unrelated error, and copy-pasted per-LUT messages previously printed
    // the wrong tensor's size.
    auto check_lut = [out_feat](const torch::Tensor &lut, int bits, const char *name) {
        TORCH_CHECK(lut.dim() == 2 && lut.size(1) == (1 << bits) && lut.size(0) == out_feat,
                    name, " tensor must be of shape (output_feat, 2 ** ", bits, "). Expected (",
                    out_feat, ", ", 1 << bits, "), got ", lut.sizes(), ".");
    };
    check_lut(lut3, 3, "lut3");
    check_lut(lut4, 4, "lut4");
    check_lut(lut5, 5, "lut5");
    check_lut(lut6, 6, "lut6");

    // qweight is of shape (6, output_feat, input_feat / 32)
    TORCH_CHECK(qweight.dim() == 3 && qweight.size(0) == 6 && qweight.size(2) == input.size(2) / 32 && qweight.size(1) == out_feat,
                "qweight tensor must be of shape (6, output_feat, input_feat / 32). Expected (", 6, ", ",
                out_feat, ", ", input.size(2) / 32, "), got ", qweight.sizes(), ".");

    anyprec_matmul_sel<DT>(
        (FP_DTYPE(DT)*)input.data_ptr<ATEN_DTYPE(DT)>(),
        (FP_DTYPE(DT)*)output.data_ptr<ATEN_DTYPE(DT)>(),
        (uint32_t*)qweight.data_ptr<int>(),
        (FP_DTYPE(DT)*)lut3.data_ptr<ATEN_DTYPE(DT)>(),
        (FP_DTYPE(DT)*)lut4.data_ptr<ATEN_DTYPE(DT)>(),
        (FP_DTYPE(DT)*)lut5.data_ptr<ATEN_DTYPE(DT)>(),
        (FP_DTYPE(DT)*)lut6.data_ptr<ATEN_DTYPE(DT)>(),
        M, N, K,
        (int32_t *)bsel.data_ptr<int>(),
        stream
    );
}


// Validates the arguments of the selector-based Any-Precision GEMV, then
// dispatches on dtype to anyprec_gemv_templated_sel. The work is enqueued on
// `stream`, with no synchronization here.
//
//   input, output : contiguous CUDA tensors of shape (batch_size, 1, features)
//                   with the same batch size and dtype (half or bfloat16;
//                   float is rejected).
//   qweight       : contiguous int32 CUDA tensor (shape checked by
//                   anyprec_gemv_templated_sel).
//   lut3..lut6    : contiguous CUDA tensors with input's dtype (shapes checked
//                   by anyprec_gemv_templated_sel).
//   bitwidth      : range-checked to [3, 8] but not otherwise used on this path.
//   bsel          : int32 tensor forwarded as a raw pointer.
// Every violated precondition raises through TORCH_CHECK.
void anyprec_gemv_stream_sel(
    torch::Tensor input,
    torch::Tensor output,
    torch::Tensor qweight,
    torch::Tensor lut3,
    torch::Tensor lut4,
    torch::Tensor lut5,
    torch::Tensor lut6,
    int bitwidth,
    torch::Tensor bsel,
    cudaStream_t stream
) {
    TORCH_CHECK(bitwidth >= 3 && bitwidth <= 8, "Bitwidth must be between 3 and 8.");
    TORCH_CHECK(input.scalar_type() == lut3.scalar_type() && input.scalar_type() == output.scalar_type(), 
                "Mismatched data types between input, lut, and output tensors.");
    // All four LUTs are reinterpreted with input's element type downstream.
    TORCH_CHECK(lut4.scalar_type() == input.scalar_type() &&
                lut5.scalar_type() == input.scalar_type() &&
                lut6.scalar_type() == input.scalar_type(),
                "Mismatched data types between input and lut4/lut5/lut6 tensors.");
    TORCH_CHECK(qweight.scalar_type() == at::kInt, "qweight tensor must be of type int.");
    TORCH_CHECK(bsel.scalar_type() == at::kInt, "bsel tensor must be of type int.");
    TORCH_CHECK(input.dim() == 3, "input tensor must be of shape (batch_size, seq_len, hidden_size).");
    TORCH_CHECK(output.dim() == 3, "output tensor must be of shape (batch_size, seq_len, hidden_size).");

    // Check that sequence length is 1
    TORCH_CHECK(input.size(1) == 1, "Only sequence length of 1 is supported.");
    TORCH_CHECK(output.size(1) == 1, "Only sequence length of 1 is supported.");

    // The batch size M is taken from input, so output must have exactly as many
    // rows or the GEMV would write past its end.
    TORCH_CHECK(input.size(0) == output.size(0),
                "input and output batch sizes must match. Got ", input.size(0), " and ", output.size(0), ".");

    // Check that input and output are on the GPU. Weights and LUTs go to the
    // same GEMV as raw pointers, so presumably they must be device-resident too.
    TORCH_CHECK(input.is_cuda() && output.is_cuda(), "input and output tensors must be on GPU.");
    TORCH_CHECK(qweight.is_cuda(), "qweight tensor must be on GPU.");
    TORCH_CHECK(lut3.is_cuda() && lut4.is_cuda() && lut5.is_cuda() && lut6.is_cuda(),
                "lut tensors must be on GPU.");

    // Check that all tensors are contiguous
    TORCH_CHECK(input.is_contiguous(), "input tensor must be contiguous.");
    TORCH_CHECK(output.is_contiguous(), "output tensor must be contiguous.");
    TORCH_CHECK(qweight.is_contiguous(), "qweight tensor must be contiguous.");
    TORCH_CHECK(lut3.is_contiguous(), "lut3 tensor must be contiguous.");
    TORCH_CHECK(lut4.is_contiguous(), "lut4 tensor must be contiguous.");
    TORCH_CHECK(lut5.is_contiguous(), "lut5 tensor must be contiguous.");
    TORCH_CHECK(lut6.is_contiguous(), "lut6 tensor must be contiguous.");

    auto dtype = input.scalar_type();
    if (dtype == at::kFloat) {
        TORCH_CHECK(false, "Any-Precision GEMV does not support float data type. Please use half or bfloat16.");
    } else if (dtype == at::kHalf) {
        anyprec_gemv_templated_sel<DataType::FP16>(input, output, qweight, lut3, lut4, lut5, lut6, bitwidth, bsel, stream);
    } else if (dtype == at::kBFloat16) {
        anyprec_gemv_templated_sel<DataType::BF16>(input, output, qweight, lut3, lut4, lut5, lut6, bitwidth, bsel, stream);
    } else {
        TORCH_CHECK(false, "Unsupported data type.");
    }
}

// Groups a sub-stream with its start/end events. The handles are stored
// exactly as given; this constructor does not create any of them.
StreamNevent::StreamNevent(cudaStream_t _sub_stream,
                           cudaEvent_t _start_event,
                           cudaEvent_t _end_event)
    : sub_stream(_sub_stream),
      start_event(_start_event),
      end_event(_end_event) {}

// Releases nothing: the stream and event handles are not destroyed here.
StreamNevent::~StreamNevent() {}


// Groups a sub-stream with its start, end and two intermediate events. The
// handles are stored exactly as given; this constructor does not create any
// of them.
StreamNevent_full::StreamNevent_full(cudaStream_t _sub_stream,
                                     cudaEvent_t _start_event,
                                     cudaEvent_t _end_event,
                                     cudaEvent_t _mid1_event,
                                     cudaEvent_t _mid2_event)
    : sub_stream(_sub_stream),
      start_event(_start_event),
      end_event(_end_event),
      mid1_event(_mid1_event),
      mid2_event(_mid2_event) {}

// Releases nothing: the stream and event handles are not destroyed here.
StreamNevent_full::~StreamNevent_full() {}


// Selector-based Any-Precision GEMV on PyTorch's current CUDA stream.
//
// `sne` is the address of a StreamNevent_full, passed as an integer
// (presumably from the Python side -- confirm). Before the GEMV is enqueued,
// the current stream is made to wait on that object's end_event, so the GEMV
// runs after whatever was recorded there (presumably the work that fills
// `bsel` -- confirm). The remaining arguments go unchanged to
// anyprec_gemv_stream_sel, which validates them.
void anyprec_gemv_sel(
    torch::Tensor input,
    torch::Tensor output,
    torch::Tensor qweight,
    torch::Tensor lut3,
    torch::Tensor lut4,
    torch::Tensor lut5,
    torch::Tensor lut6,
    int bitwidth,
    torch::Tensor bsel,
    uintptr_t sne
) {
    TORCH_CHECK(sne != 0, "anyprec_gemv_sel: sne must be a valid StreamNevent_full handle, got null.");
    auto *handles = reinterpret_cast<StreamNevent_full*>(sne);

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // Check the result. An unchecked failure would stay pending and be reported
    // by some later, unrelated error check.
    const cudaError_t wait_status = cudaStreamWaitEvent(stream, handles->end_event, 0);
    if (wait_status != cudaSuccess) {
        (void)cudaGetLastError();  // reset the runtime's last-error state before raising
        TORCH_CHECK(false, "anyprec_gemv_sel: cudaStreamWaitEvent failed: ", cudaGetErrorString(wait_status));
    }

    anyprec_gemv_stream_sel(input, output, qweight, lut3, lut4, lut5, lut6, bitwidth, bsel, stream);
}

// Stand-in for the selector-based path. It accepts a `bsel` tensor but ignores
// it, always running the fixed-bitwidth anyprec_gemv with the single `lut`.
void anyprec_gemv_sel_fake(
    torch::Tensor input,
    torch::Tensor output,
    torch::Tensor qweight,
    torch::Tensor lut,
    int bitwidth,
    torch::Tensor bsel
) {
    static_cast<void>(bsel);  // selector intentionally unused in this variant
    anyprec_gemv(input, output, qweight, lut, bitwidth);
}

// Compile-time problem shape and launch constants for the kernels below.
#define K_def 4096                // halfs per row of A/B (also the length of X read by thresholdKernel)
#define N_def 64                  // number of rows of B, i.e. dot products per call
#define WARP_SIZE 32
#define WARPS_PER_BLOCK 16        // rows handled per pass over B: one row per warp
#define THREADS_PER_BLOCK (WARP_SIZE * WARPS_PER_BLOCK)  // presumably the intended blockDim.x for the warp1Naive* kernels -- confirm at the launch site
#define UNROLL 16                 // one inner-loop step covers WARP_SIZE * UNROLL half2 elements
#define VEC 2                     // halfs per vector element (half2)
#define K_per_warp (K_def/VEC)    // row stride of B in half2 elements

// Row-dot helper shared by warp1NaiveKernel and warp1NaiveKernelTH.
//
// A is read as K_def halfs and B as N_def row-major rows of K_def halfs. Both
// are reinterpreted as half2, so both pointers must be 4-byte aligned.
// Warp w handles rows w, w + WARPS_PER_BLOCK, w + 2 * WARPS_PER_BLOCK, ... For
// each row it forms dot(A, B[row]) entirely in half precision, and lane 0
// stores the result in c_smem[row].
// The association order is fixed deliberately, because changing it changes
// half-precision rounding:
//   - 16 products per lane per step, combined by a pairwise tree;
//   - a running per-lane sum;
//   - an xor butterfly across the warp.
//
// Must be reached by every thread of the block, since it ends with
// __syncthreads().
static __device__ __forceinline__ void warp1NaiveRowDots(const half *__restrict__ A,
                                                         const half *__restrict__ B,
                                                         half *c_smem) {
    // The reduction tree below is written out for exactly 16 partial products.
    static_assert(UNROLL == 16, "warp1NaiveRowDots hard-codes a 16-way reduction tree.");
    static_assert((K_def / VEC) % (WARP_SIZE * UNROLL) == 0,
                  "K_def / VEC must be a multiple of WARP_SIZE * UNROLL.");

    constexpr unsigned int mask = 0xffffffff;
    const size_t lane_id = threadIdx.x % WARP_SIZE;
    const half2 *A2 = reinterpret_cast<const half2 *>(A);
    const half2 *B2 = reinterpret_cast<const half2 *>(B);

    // Ceil-div plus a ">=" guard means rows past N_def are skipped instead of
    // writing c_smem[N_def]. The guard is warp-uniform (warp_id depends only on
    // the warp), so the break never splits a warp ahead of the full-mask
    // shuffles.
    constexpr int row_iters = (N_def + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK;
    for (int iter = 0; iter < row_iters; iter++) {
        const size_t warp_id = threadIdx.x / WARP_SIZE + iter * WARPS_PER_BLOCK;
        if (warp_id >= N_def) break;

        const half2 *B2_row = B2 + warp_id * K_per_warp;
        __half2 temp = make_half2(__float2half(0.0f), __float2half(0.0f));

        #pragma unroll
        for (int i = 0; i < K_def / VEC; i += WARP_SIZE * UNROLL) {
            __half2 p[UNROLL];
            #pragma unroll
            for (int u = 0; u < UNROLL; u++) {
                const size_t idx = i + lane_id + WARP_SIZE * u;
                p[u] = __hmul2(A2[idx], B2_row[idx]);
            }

            // Pairwise tree: (0+1)(2+3).. -> (0+2)(4+6).. -> (0+4)(8+12) -> (0+8).
            #pragma unroll
            for (int u = 0; u < UNROLL; u += 2) p[u] = __hadd2(p[u], p[u + 1]);
            #pragma unroll
            for (int u = 0; u < UNROLL; u += 4) p[u] = __hadd2(p[u], p[u + 2]);
            #pragma unroll
            for (int u = 0; u < UNROLL; u += 8) p[u] = __hadd2(p[u], p[u + 4]);
            #pragma unroll
            for (int u = 0; u < UNROLL; u += 16) p[u] = __hadd2(p[u], p[u + 8]);

            temp = __hadd2(p[0], temp);
        }

        #pragma unroll
        for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
            temp = __hadd2(__shfl_xor_sync(mask, temp, offset), temp);
        }

        if (lane_id == 0) c_smem[warp_id] = __hadd(temp.x, temp.y);
    }
    __syncthreads();
}

// Computes sum over rows r < N_def of c_smem[r]^2 in half precision, then
// xor-reduces it across the calling warp; lane 0 holds the total.
// Must be called by all 32 lanes of a single warp after c_smem has been filled
// and the block has synchronized.
static __device__ __forceinline__ half warp1NaiveSumSquares(const half *c_smem, size_t lane_id) {
    static_assert(N_def % WARP_SIZE == 0, "N_def must be a multiple of WARP_SIZE.");
    constexpr unsigned int mask = 0xffffffff;

    half csum = __float2half(0.0f);
    #pragma unroll
    for (int i = 0; i < N_def; i += WARP_SIZE) {
        const half c_now = c_smem[i + lane_id];
        csum = __hadd(csum, __hmul(c_now, c_now));
    }

    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        csum = __hadd(__shfl_xor_sync(mask, csum, offset), csum);
    }
    return csum;
}

// Computes C[0] = sum over rows r < N_def of dot(A, B[r])^2, with every
// multiply and add done in half precision.
//
// Launch: blockDim.x should be THREADS_PER_BLOCK (WARPS_PER_BLOCK warps, one
// row per warp per pass). With fewer warps some rows of C_smem are never
// written. blockIdx is not used, so one block suffices; extra blocks would
// redundantly write the same C[0]. Uses N_def halfs of static shared memory.
// NOTE(review): accumulating the squares in half can overflow to inf (the
// largest finite half is 65504) for large activations -- confirm acceptable.
__global__ void warp1NaiveKernel(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C) {
    __shared__ half C_smem[N_def];

    warp1NaiveRowDots(A, B, C_smem);

    // Final reduction over the N_def row results is done by warp 0 alone.
    if (threadIdx.x / WARP_SIZE == 0) {
        const half csum = warp1NaiveSumSquares(C_smem, threadIdx.x % WARP_SIZE);
        if (threadIdx.x == 0) {
            C[0] = csum;
        }
    }
}

// Same reduction as warp1NaiveKernel. Instead of storing the sum, it writes a
// selector: bsel[0] = *high_p if the sum (converted to float) exceeds
// `threshold`, else *low_p.
//
// Launch requirements are the same as for warp1NaiveKernel: blockDim.x of
// THREADS_PER_BLOCK, blockIdx unused, N_def halfs of static shared memory.
// NOTE(review): accumulating the squares in half can overflow to inf, which
// then always selects *high_p -- confirm acceptable.
__global__ void warp1NaiveKernelTH(const half *__restrict__ A, const half *__restrict__ B, float threshold,
                                int *__restrict__ bsel, int *__restrict__ low_p, int *__restrict__ high_p) {
    __shared__ half C_smem[N_def];

    warp1NaiveRowDots(A, B, C_smem);

    if (threadIdx.x / WARP_SIZE == 0) {
        const half csum = warp1NaiveSumSquares(C_smem, threadIdx.x % WARP_SIZE);
        if (threadIdx.x == 0) {
            const int low = *low_p;
            const int high = *high_p;
            bsel[0] = (__half2float(csum) > threshold) ? high : low;
        }
    }
}
// __global__ void warp1NaiveKernelUR8(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C) {
//     // while(1){
//     __shared__ half C_smem[N_def];
//     const size_t lane_id = threadIdx.x % WARP_SIZE;
    
//     for(int iter=0; iter<N_def/WARPS_PER_BLOCK; iter++){
//         const size_t warp_id = threadIdx.x / WARP_SIZE + iter * WARPS_PER_BLOCK;

//         if(warp_id > N_def) break;
    
//         const half2 * A2 = reinterpret_cast<const half2*>(A);
//         const half2 * B2 = reinterpret_cast<const half2*>(B);
        
//         __half2 temp = make_half2(__float2half(0.0), __float2half(0.0));
//         #pragma unroll
//         for(int i=0; i<K_def/VEC; i+=WARP_SIZE*UNROLL){
//             const size_t A2_idx0 = i + lane_id;
//             const size_t B2_idx0 = warp_id * K_per_warp + i + lane_id;
//             __half2 temp0 = __hmul2(A2[A2_idx0], B2[B2_idx0]);
//             const size_t A2_idx1 = i + lane_id + WARP_SIZE * 1;
//             const size_t B2_idx1 = warp_id * K_per_warp + i + lane_id + WARP_SIZE * 1;
//             __half2 temp1 = __hmul2(A2[A2_idx1], B2[B2_idx1]);
//             const size_t A2_idx2 = i + lane_id + WARP_SIZE * 2;
//             const size_t B2_idx2 = warp_id * K_per_warp + i + lane_id + WARP_SIZE * 2;
//             __half2 temp2 = __hmul2(A2[A2_idx2], B2[B2_idx2]);
//             const size_t A2_idx3 = i + lane_id + WARP_SIZE * 3;
//             const size_t B2_idx3 = warp_id * K_per_warp + i + lane_id + WARP_SIZE * 3;
//             __half2 temp3 = __hmul2(A2[A2_idx3], B2[B2_idx3]);
//             const size_t A2_idx4 = i + lane_id + WARP_SIZE * 4;
//             const size_t B2_idx4 = warp_id * K_per_warp + i + lane_id + WARP_SIZE * 4;
//             __half2 temp4 = __hmul2(A2[A2_idx4], B2[B2_idx4]);
//             const size_t A2_idx5 = i + lane_id + WARP_SIZE * 5;
//             const size_t B2_idx5 = warp_id * K_per_warp + i + lane_id + WARP_SIZE * 5;
//             __half2 temp5 = __hmul2(A2[A2_idx5], B2[B2_idx5]);
//             const size_t A2_idx6 = i + lane_id + WARP_SIZE * 6;
//             const size_t B2_idx6 = warp_id * K_per_warp + i + lane_id + WARP_SIZE * 6;
//             __half2 temp6 = __hmul2(A2[A2_idx6], B2[B2_idx6]);
//             const size_t A2_idx7 = i + lane_id + WARP_SIZE * 7;
//             const size_t B2_idx7 = warp_id * K_per_warp + i + lane_id + WARP_SIZE * 7;
//             __half2 temp7 = __hmul2(A2[A2_idx7], B2[B2_idx7]);
//             const size_t A2_idx8 = i + lane_id + WARP_SIZE * 8;
//             const size_t B2_idx8 = warp_id * K_per_warp + i + lane_id + WARP_SIZE * 8;
    
//             temp0 = __hadd2(temp0, temp1);
//             temp2 = __hadd2(temp2, temp3);
//             temp4 = __hadd2(temp4, temp5);
//             temp6 = __hadd2(temp6, temp7);
    
//             temp0 = __hadd2(temp0, temp2);
//             temp4 = __hadd2(temp4, temp6);
    
//             temp0 = __hadd2(temp0, temp4);

//             temp = __hadd2(temp0, temp);
//         }
        
//         constexpr unsigned int mask = 0xffffffff;
//         #pragma unroll
//         for (size_t i = WARP_SIZE / 2; i >= 1; i /= 2) {
//             temp = __hadd2(__shfl_xor_sync(mask, temp, i), temp);
//         }
    
//         if(lane_id == 0) C_smem[warp_id] = __hadd(temp.x,temp.y);
//     }
//     __syncthreads();
//     const size_t warp_id = threadIdx.x / WARP_SIZE;
//     if(warp_id == 0){
//         half csum = __float2half(0.0);

//         #pragma unroll
//         for(int i=0;i<N_def;i+=WARP_SIZE){
//             half c_now = C_smem[i + lane_id];
//             c_now = c_now * c_now;
//             csum = csum + c_now;
//         }

//         constexpr unsigned int mask = 0xffffffff;
//         #pragma unroll
//         for (size_t i = WARP_SIZE / 2; i >= 1; i /= 2) {
//             csum = __hadd(__shfl_xor_sync(mask, csum, i), csum);
//         }

//         if(threadIdx.x == 0){
//             C[0] = csum;
//         }

//     }
        
//     // }
// }

// Use 1 threadblock of exactly WARP_SIZE threads (a single warp).
//
// Computes s = sum_k X[k]^2 over X[0 .. K_def) in float, evaluates the affine
// model s * a + b, and writes `high` into bsel[0] if it exceeds `threshold`,
// otherwise `low`.
//
// Each lane strides through X by WARP_SIZE, so K_def no longer has to be a
// multiple of WARP_SIZE (the old fixed trip count K_def / WARP_SIZE silently
// skipped the tail). For K_def % WARP_SIZE == 0 the per-lane summation order is
// unchanged. The xor-butterfly shuffle leaves the full sum in every lane.
__global__ void thresholdKernel(const half *__restrict__ X, const float a, const float b, 
                                const float threshold, int *__restrict__ bsel, int low, int high){
    float x = 0.0f;
    for (int idx = threadIdx.x; idx < K_def; idx += WARP_SIZE) {
        const float xnow = __half2float(X[idx]);
        x += xnow * xnow;
    }

    constexpr unsigned int mask = 0xffffffff;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) {
        x = __shfl_xor_sync(mask, x, offset) + x;
    }

    // Only lane 0 publishes the decision.
    if (threadIdx.x == 0) {
        bsel[0] = (x * a + b > threshold) ? high : low;
    }
}


// Single-block kernel: launch as <<<1, WARPS_PER_BLOCK * WARP_SIZE, N * sizeof(half)>>>.
//
// B is read as N consecutive rows of K halves. For every row n the block
// computes d_n = <A, B_n> in half2 arithmetic and stages it in dynamic shared
// memory. Warp 0 then writes C[0] = sum_n d_n^2, accumulated in half.
//
// Preconditions (not checked on device):
//   * blockDim.x == WARPS_PER_BLOCK * WARP_SIZE (warps stride over rows by WARPS_PER_BLOCK);
//   * dynamic shared memory >= N * sizeof(half);
//   * A and B are read through half2 pointers with K / VEC elements per row
//     (assumes VEC == 2, K even, and 4-byte aligned rows).
//
// Fixes relative to the previous version:
//   * the last N % WARPS_PER_BLOCK rows are now computed; they used to be skipped
//     while their uninitialized shared-memory slots were still summed;
//   * the final reduction no longer reads C_smem past index N - 1;
//   * the K loop no longer over-reads when K / VEC is not a multiple of
//     WARP_SIZE * 16, and no longer silently assumes UNROLL == 16.
// For N % WARPS_PER_BLOCK == 0, N % WARP_SIZE == 0 and K / VEC % (WARP_SIZE * 16) == 0,
// the arithmetic (including the order of half additions) is unchanged.
__global__ void warp1NaiveKernelD(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C, 
                                    const int K, const int N) {
    extern __shared__ half C_smem[];
    constexpr unsigned int mask = 0xffffffff;
    constexpr int kChunk = 16;                  // half2 products per lane per unrolled step (tree below assumes 16)
    constexpr int kStep = WARP_SIZE * kChunk;   // half2 elements consumed per unrolled step

    const int lane_id = threadIdx.x % WARP_SIZE;
    const int warp_in_block = threadIdx.x / WARP_SIZE;

    const half2 *A2 = reinterpret_cast<const half2 *>(A);
    const int K2 = K / VEC;                     // half2 elements per row

    // Warp w handles rows w, w + WARPS_PER_BLOCK, ... (all N rows are covered).
    for (int row = warp_in_block; row < N; row += WARPS_PER_BLOCK) {
        const half2 *B2 = reinterpret_cast<const half2 *>(B) + static_cast<size_t>(row) * K2;

        half2 temp = make_half2(__float2half(0.0f), __float2half(0.0f));

        // Main part: 16 independent products per lane, reduced with a fixed pairwise tree.
        int i = 0;
        for (; i + kStep <= K2; i += kStep) {
            half2 t[kChunk];
            #pragma unroll
            for (int u = 0; u < kChunk; ++u) {
                const int idx = i + lane_id + WARP_SIZE * u;
                t[u] = __hmul2(A2[idx], B2[idx]);
            }
            #pragma unroll
            for (int u = 0; u < kChunk; u += 2) t[u] = __hadd2(t[u], t[u + 1]);
            #pragma unroll
            for (int u = 0; u < kChunk; u += 4) t[u] = __hadd2(t[u], t[u + 2]);
            #pragma unroll
            for (int u = 0; u < kChunk; u += 8) t[u] = __hadd2(t[u], t[u + 4]);
            t[0] = __hadd2(t[0], t[8]);

            temp = __hadd2(t[0], temp);
        }
        // Tail: fewer than kStep half2 elements remain; consume them one warp-wide slice at a time.
        for (; i < K2; i += WARP_SIZE) {
            const int idx = i + lane_id;
            if (idx < K2) temp = __hadd2(__hmul2(A2[idx], B2[idx]), temp);
        }

        // Butterfly reduction: every lane ends with the warp-wide half2 sum.
        #pragma unroll
        for (int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) {
            temp = __hadd2(__shfl_xor_sync(mask, temp, offset), temp);
        }

        if (lane_id == 0) C_smem[row] = __hadd(temp.x, temp.y);
    }
    __syncthreads();

    // Warp 0 squares and sums the N staged dot products.
    if (warp_in_block == 0) {
        half csum = __float2half(0.0f);
        for (int n = lane_id; n < N; n += WARP_SIZE) {
            half c_now = C_smem[n];
            c_now = c_now * c_now;
            csum = csum + c_now;
        }

        #pragma unroll
        for (int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) {
            csum = __hadd(__shfl_xor_sync(mask, csum, offset), csum);
        }

        if (lane_id == 0) {
            C[0] = csum;
        }
    }
}
// Single-block kernel: launch as <<<1, WARPS_PER_BLOCK * WARP_SIZE, N * sizeof(half)>>>.
//
// B is read as N consecutive rows of K halves. For every row n the block
// computes d_n = <A, B_n> in half2 arithmetic and stages it in dynamic shared
// memory. Warp 0 forms s = sum_n d_n^2 in half. Thread 0 then writes *high_p
// into bsel[0] if float(s) > threshold, otherwise *low_p.
//
// Preconditions (not checked on device):
//   * blockDim.x == WARPS_PER_BLOCK * WARP_SIZE (warps stride over rows by WARPS_PER_BLOCK);
//   * dynamic shared memory >= N * sizeof(half);
//   * A and B are read through half2 pointers with K / VEC elements per row
//     (assumes VEC == 2, K even, and 4-byte aligned rows).
//
// Fixes relative to the previous version:
//   * the last N % WARPS_PER_BLOCK rows are now computed; they used to be skipped
//     while their uninitialized shared-memory slots were still summed;
//   * the final reduction no longer reads C_smem past index N - 1;
//   * the K loop no longer over-reads when K / VEC is not a multiple of
//     WARP_SIZE * 16, and no longer silently assumes UNROLL == 16.
// For N % WARPS_PER_BLOCK == 0, N % WARP_SIZE == 0 and K / VEC % (WARP_SIZE * 16) == 0,
// the arithmetic (including the order of half additions) is unchanged.
__global__ void warp1NaiveKernelTHD(const half *__restrict__ A, const half *__restrict__ B, float threshold,
                                int *__restrict__ bsel, int *__restrict__ low_p, int *__restrict__ high_p, 
                                const int K, const int N) {
    extern __shared__ half C_smem[];
    constexpr unsigned int mask = 0xffffffff;
    constexpr int kChunk = 16;                  // half2 products per lane per unrolled step (tree below assumes 16)
    constexpr int kStep = WARP_SIZE * kChunk;   // half2 elements consumed per unrolled step

    const int lane_id = threadIdx.x % WARP_SIZE;
    const int warp_in_block = threadIdx.x / WARP_SIZE;

    const half2 *A2 = reinterpret_cast<const half2 *>(A);
    const int K2 = K / VEC;                     // half2 elements per row

    // Warp w handles rows w, w + WARPS_PER_BLOCK, ... (all N rows are covered).
    for (int row = warp_in_block; row < N; row += WARPS_PER_BLOCK) {
        const half2 *B2 = reinterpret_cast<const half2 *>(B) + static_cast<size_t>(row) * K2;

        half2 temp = make_half2(__float2half(0.0f), __float2half(0.0f));

        // Main part: 16 independent products per lane, reduced with a fixed pairwise tree.
        int i = 0;
        for (; i + kStep <= K2; i += kStep) {
            half2 t[kChunk];
            #pragma unroll
            for (int u = 0; u < kChunk; ++u) {
                const int idx = i + lane_id + WARP_SIZE * u;
                t[u] = __hmul2(A2[idx], B2[idx]);
            }
            #pragma unroll
            for (int u = 0; u < kChunk; u += 2) t[u] = __hadd2(t[u], t[u + 1]);
            #pragma unroll
            for (int u = 0; u < kChunk; u += 4) t[u] = __hadd2(t[u], t[u + 2]);
            #pragma unroll
            for (int u = 0; u < kChunk; u += 8) t[u] = __hadd2(t[u], t[u + 4]);
            t[0] = __hadd2(t[0], t[8]);

            temp = __hadd2(t[0], temp);
        }
        // Tail: fewer than kStep half2 elements remain; consume them one warp-wide slice at a time.
        for (; i < K2; i += WARP_SIZE) {
            const int idx = i + lane_id;
            if (idx < K2) temp = __hadd2(__hmul2(A2[idx], B2[idx]), temp);
        }

        // Butterfly reduction: every lane ends with the warp-wide half2 sum.
        #pragma unroll
        for (int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) {
            temp = __hadd2(__shfl_xor_sync(mask, temp, offset), temp);
        }

        if (lane_id == 0) C_smem[row] = __hadd(temp.x, temp.y);
    }
    __syncthreads();

    // Warp 0 squares and sums the N staged dot products, then thread 0 thresholds.
    if (warp_in_block == 0) {
        half csum = __float2half(0.0f);
        for (int n = lane_id; n < N; n += WARP_SIZE) {
            half c_now = C_smem[n];
            c_now = c_now * c_now;
            csum = csum + c_now;
        }

        #pragma unroll
        for (int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) {
            csum = __hadd(__shfl_xor_sync(mask, csum, offset), csum);
        }

        if (lane_id == 0) {
            const int low = *low_p;
            const int high = *high_p;
            bsel[0] = (__half2float(csum) > threshold) ? high : low;
        }
    }
}

// Use 1 threadblock of exactly WARP_SIZE threads (a single warp).
//
// Runtime-K variant of thresholdKernel. It computes s = sum_k X[k]^2 over
// X[0 .. K) in float, evaluates s * a + b, and writes *high_p into bsel[0] if
// that exceeds `threshold`, otherwise *low_p.
//
// Each lane strides through X by WARP_SIZE, so K no longer has to be a multiple
// of WARP_SIZE (the old fixed trip count K / WARP_SIZE silently skipped the
// tail). For K % WARP_SIZE == 0 the per-lane summation order is unchanged. The
// xor-butterfly shuffle leaves the full sum in every lane.
__global__ void thresholdKernelD(const half *__restrict__ X, const float a, const float b, 
                                const float threshold, int *__restrict__ bsel, int *__restrict__ low_p, int *__restrict__ high_p, const int K){
    float x = 0.0f;
    for (int idx = threadIdx.x; idx < K; idx += WARP_SIZE) {
        const float xnow = __half2float(X[idx]);
        x += xnow * xnow;
    }

    constexpr unsigned int mask = 0xffffffff;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) {
        x = __shfl_xor_sync(mask, x, offset) + x;
    }

    // Only lane 0 publishes the decision; it is also the only reader of low/high.
    if (threadIdx.x == 0) {
        const int low = *low_p;
        const int high = *high_p;
        bsel[0] = (x * a + b > threshold) ? high : low;
    }
}

// Launch constants for warp1full3 (one thread block per column).
// Number of warps per thread block.
#define FULL_WARPS_PER_BLOCK 8
// Threads per block.
#define FULL_THREADS_PER_BLOCK (WARP_SIZE * FULL_WARPS_PER_BLOCK)  // = WARP_SIZE * FULL_WARPS_PER_BLOCK
// Unroll factor: half2 products per lane per step (warp1full3 issues 16 per lane).
#define FULL_UNROLL 16
// Halves per vector element (data is read as half2).
#define FULL_VEC 2
// NOTE: despite the name, this is the half2 count of a whole K_def-long column
// (K_def / FULL_VEC), not of one warp's share.
#define FULL_K_per_warp (K_def/FULL_VEC)
// One thread block per column: launch with grid (N_def, 3) and FULL_THREADS_PER_BLOCK
// threads. blockIdx.y selects the operand set (0 -> Ba/Ca, 1 -> Bb/Cb, otherwise Bc/Cc).
//
// B is read as N_def consecutive columns of K_def halves. Block (col, y) computes
// C[col] = <A, B[col]>:
//   1. each of the FULL_WARPS_PER_BLOCK warps reduces its own contiguous K segment
//      (16 half2 products per lane per step, combined with a pairwise tree);
//   2. lane 0 of each warp stages its partial in C_smem;
//   3. after a barrier, thread 0 sums the partials in float and writes C[col].
// This matches reduce3, which reads one half per column.
//
// Fixes relative to the previous version:
//   * every warp of every block used to write its partial to C[warp_id], so all
//     N_def blocks raced on C[0 .. FULL_WARPS_PER_BLOCK) and no per-column result
//     was produced (C_smem was declared for the cross-warp combine but unused);
//   * each warp now covers exactly its segment, so K_def / FULL_VEC no longer has to
//     equal WARP_SIZE * FULL_WARPS_PER_BLOCK * 16 (previously it over-read or
//     double-counted otherwise).
// Requires 4-byte aligned A and B (read as half2).
__global__ void warp1full3(const half *__restrict__ A, 
                            const half *__restrict__ Ba, const half *__restrict__ Bb, const half *__restrict__ Bc,
                            half *__restrict__ Ca, half *__restrict__ Cb,half *__restrict__ Cc) {

    const half *__restrict__ B = (blockIdx.y == 0)?Ba:((blockIdx.y == 1)?Bb:Bc);
    half *__restrict__ C = (blockIdx.y == 0)?Ca:((blockIdx.y == 1)?Cb:Cc);

    __shared__ half C_smem[FULL_WARPS_PER_BLOCK];
    const size_t col_id = blockIdx.x;
    // col_id is uniform across the block, so returning here cannot strand threads at the barrier below.
    if(col_id >= N_def) return;

    constexpr unsigned int mask = 0xffffffff;
    constexpr int kChunk = 16;                               // half2 products per lane per step (tree below assumes 16)
    constexpr int kStep = WARP_SIZE * kChunk;                // half2 elements per warp per step
    constexpr int kCol2 = K_def / FULL_VEC;                  // half2 elements per column
    constexpr int kSeg = kCol2 / FULL_WARPS_PER_BLOCK;       // half2 elements per warp

    const int lane_id = threadIdx.x % WARP_SIZE;
    const int warp_id = threadIdx.x / WARP_SIZE;

    const half2 *A2 = reinterpret_cast<const half2 *>(A);
    const half2 *B2 = reinterpret_cast<const half2 *>(B) + static_cast<size_t>(kCol2) * col_id;

    // The last warp also absorbs any remainder when kCol2 is not divisible by FULL_WARPS_PER_BLOCK.
    const int seg_begin = warp_id * kSeg;
    const int seg_end = (warp_id == FULL_WARPS_PER_BLOCK - 1) ? kCol2 : seg_begin + kSeg;

    const half2 zero = make_half2(__float2half(0.0f), __float2half(0.0f));
    half2 temp = zero;
    for (int base = seg_begin; base < seg_end; base += kStep) {
        half2 t[kChunk];
        #pragma unroll
        for (int u = 0; u < kChunk; ++u) {
            const int idx = base + lane_id + WARP_SIZE * u;
            t[u] = (idx < seg_end) ? __hmul2(A2[idx], B2[idx]) : zero;
        }
        #pragma unroll
        for (int u = 0; u < kChunk; u += 2) t[u] = __hadd2(t[u], t[u + 1]);
        #pragma unroll
        for (int u = 0; u < kChunk; u += 4) t[u] = __hadd2(t[u], t[u + 2]);
        #pragma unroll
        for (int u = 0; u < kChunk; u += 8) t[u] = __hadd2(t[u], t[u + 4]);
        t[0] = __hadd2(t[0], t[8]);

        temp = __hadd2(t[0], temp);
    }

    // Butterfly reduction: every lane ends with this warp's half2 partial.
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) {
        temp = __hadd2(__shfl_xor_sync(mask, temp, offset), temp);
    }

    if(lane_id == 0) C_smem[warp_id] = __hadd(temp.x,temp.y);
    __syncthreads();

    // Combine the per-warp partials into this column's dot product.
    if (threadIdx.x == 0) {
        float acc = 0.0f;
        #pragma unroll
        for (int w = 0; w < FULL_WARPS_PER_BLOCK; ++w) {
            acc += __half2float(C_smem[w]);
        }
        C[col_id] = __float2half(acc);
    }
}

// Launch with 3 blocks of WARP_SIZE threads each. Block b handles operand set b:
//   0 -> (Ca, threshold1, bsel1), 1 -> (Cb, threshold2, bsel2), otherwise (Cc, threshold3, bsel3).
//
// Each lane loads one half2 of its block's vector, so the vector must hold exactly
// WARP_SIZE half2 values (WARP_SIZE * FULL_VEC halves). The block squares them,
// sums everything in half precision with an xor-butterfly shuffle, and lane 0
// stores *high_p into the selected bsel if the sum exceeds the selected threshold,
// otherwise *low_p.
__global__ void reduce3(const half *__restrict__ Ca,const half *__restrict__ Cb,const half *__restrict__ Cc, 
                        const float threshold1, const float threshold2, const float threshold3, 
                        int *__restrict__ bsel1, int *__restrict__ bsel2, int *__restrict__ bsel3, 
                        int *__restrict__ low_p, int *__restrict__ high_p){
    const half *vec;
    float limit;
    int *out;
    switch (blockIdx.x) {
        case 0:  vec = Ca; limit = threshold1; out = bsel1; break;
        case 1:  vec = Cb; limit = threshold2; out = bsel2; break;
        default: vec = Cc; limit = threshold3; out = bsel3; break;
    }

    const half2 pair = reinterpret_cast<const half2 *>(vec)[threadIdx.x];
    half2 acc = __hmul2(pair, pair);

    #pragma unroll
    for (int lanes = WARP_SIZE >> 1; lanes > 0; lanes >>= 1) {
        acc = __hadd2(__shfl_xor_sync(0xffffffffu, acc, lanes), acc);
    }

    if (threadIdx.x != 0) return;

    const float energy = __half2float(__hadd(acc.x, acc.y));
    *out = (energy > limit) ? *high_p : *low_p;
}


#define RMS_WARPS_PER_BLOCK 4
#define RMS_THREADS_PER_BLOCK (WARP_SIZE * RMS_WARPS_PER_BLOCK)
// Launch as a single block of exactly RMS_THREADS_PER_BLOCK threads. The
// cross-warp sum reads all RMS_WARPS_PER_BLOCK slots of xnorm, so fewer threads
// would read uninitialized shared memory.
//
// Computes y[k] = half(x[k] / sqrt(sum_j x[j]^2)) * W[k] for k in [0, K_def),
// with the squared norm accumulated in float.
// NOTE(review): despite the name, this divides by the plain L2 norm rather than
// sqrt(mean(x^2) + eps) as in standard RMSNorm, and an all-zero x divides by
// zero. The scaling is kept as-is because callers may depend on it; confirm intent.
__global__ void rmsnorm(const half *__restrict__ W, const half *__restrict__ x, half *__restrict__ y){
    __shared__ float x_smem[K_def];
    __shared__ float xnorm[RMS_WARPS_PER_BLOCK];
    const size_t warp_id = threadIdx.x / WARP_SIZE;
    const size_t lane_id = threadIdx.x % WARP_SIZE;

    // Pass 1: stage x as float and accumulate this thread's share of sum(x^2).
    // The bounds check protects the last iteration when K_def is not a multiple
    // of RMS_THREADS_PER_BLOCK.
    float sqsum = 0.0f;
    #pragma unroll
    for (int i = 0; i < K_def; i += RMS_THREADS_PER_BLOCK) {
        const size_t addr = i + warp_id * WARP_SIZE + lane_id;
        if (addr < static_cast<size_t>(K_def)) {
            const float xnow = __half2float(x[addr]);
            x_smem[addr] = xnow;
            sqsum += xnow * xnow;
        }
    }

    constexpr unsigned int mask = 0xffffffff;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) {
        sqsum += __shfl_xor_sync(mask, sqsum, offset);
    }

    if (lane_id == 0) xnorm[warp_id] = sqsum;

    __syncthreads();

    float total = 0.0f;
    #pragma unroll
    for (int w = 0; w < RMS_WARPS_PER_BLOCK; w++) {
        total += xnorm[w];
    }
    // Single-precision math: the former `1.0 / sqrt(...)` forced a double division.
    const float inv_norm = 1.0f / sqrtf(total);

    // Pass 2: scale the staged values and apply the weight. Each thread reads back
    // only the x_smem entries it wrote itself.
    #pragma unroll
    for (int i = 0; i < K_def; i += RMS_THREADS_PER_BLOCK) {
        const size_t addr = i + warp_id * WARP_SIZE + lane_id;
        if (addr < static_cast<size_t>(K_def)) {
            y[addr] = __hmul(__float2half(x_smem[addr] * inv_norm), W[addr]);
        }
    }
}

// Launches warp1NaiveKernel(input, jl -> res) on a private side stream, then runs
// the any-precision GEMV (input -> output) through anyprec_gemv. The two can overlap.
//
// The side stream used to be created on every call and never destroyed, leaking
// one stream per call. It is now destroyed right after the launch.
// cudaStreamDestroy returns immediately and the runtime releases the stream once
// its queued work has finished, so the kernel still runs to completion.
// Throws (TORCH_CHECK) if stream creation or the kernel launch fails.
// NOTE(review): cudaStreamCreate makes a blocking stream, which is implicitly
// ordered only against the legacy default stream. If `input`/`jl` are produced on
// another stream, confirm that this ordering is sufficient.
void anyprec_gemv_sel_two(
    torch::Tensor input,
    torch::Tensor output,
    torch::Tensor qweight,
    torch::Tensor lut,
    int bitwidth,
    torch::Tensor jl,
    torch::Tensor res
) {
    // Resolve tensor pointers first so a dtype error cannot leak the stream.
    half *x = (half *)input.data_ptr<ATEN_DTYPE(DataType::FP16)>();
    half *proj = (half *)jl.data_ptr<ATEN_DTYPE(DataType::FP16)>();
    half *res_ptr = (half *)res.data_ptr<ATEN_DTYPE(DataType::FP16)>();

    cudaStream_t stream;
    cudaError_t err = cudaStreamCreate(&stream);
    TORCH_CHECK(err == cudaSuccess, "anyprec_gemv_sel_two: cudaStreamCreate failed: ", cudaGetErrorString(err));

    warp1NaiveKernel<<<1,THREADS_PER_BLOCK,0,stream>>>(x, proj, res_ptr);
    err = cudaGetLastError();
    cudaStreamDestroy(stream);
    TORCH_CHECK(err == cudaSuccess, "anyprec_gemv_sel_two: warp1NaiveKernel launch failed: ", cudaGetErrorString(err));

    anyprec_gemv(input, output, qweight, lut, bitwidth);
}

// Forks the StreamNevent side stream off the current PyTorch stream by recording
// start_event there and making the side stream wait on it. It then runs
// warp1NaiveKernelTH (one block, THREADS_PER_BLOCK threads, no dynamic shared
// memory) on the side stream, and finally records end_event on the side stream
// so later work can join on it.
void gemvNormTH(
    torch::Tensor input,
    torch::Tensor jl,
    torch::Tensor bsel,
    torch::Tensor low,
    torch::Tensor high,
    float threshold,
    uintptr_t sne
) {
    auto *handles = reinterpret_cast<StreamNevent *>(sne);
    cudaStream_t side = handles->sub_stream;
    cudaStream_t producer = at::cuda::getCurrentCUDAStream();
    cudaEvent_t fork_evt = handles->start_event;
    cudaEvent_t join_evt = handles->end_event;

    // Fork: everything already queued on the producer stream happens-before the kernel.
    cudaEventRecord(fork_evt, producer);
    cudaStreamWaitEvent(side, fork_evt, 0);

    half *x = (half *)input.data_ptr<ATEN_DTYPE(DataType::FP16)>();
    half *proj = (half *)jl.data_ptr<ATEN_DTYPE(DataType::FP16)>();
    int *sel = (int *)bsel.data_ptr<int>();
    int *lo = (int *)low.data_ptr<int>();
    int *hi = (int *)high.data_ptr<int>();
    warp1NaiveKernelTH<<<1, THREADS_PER_BLOCK, 0, side>>>(x, proj, threshold, sel, lo, hi);

    // Join point for consumers of bsel.
    cudaEventRecord(join_evt, side);
}

// q->k->v
// First stage of the q -> k -> v chain. Forks the side stream off the current
// PyTorch stream via start_event, then launches warp1NaiveKernelTHD (single block)
// to set bsel from `input` and the N-row matrix `jl`. It records mid1_event on the
// side stream for gemvNormTHk to wait on.
//
// warp1NaiveKernelTHD stores one half per row of jl in dynamic shared memory, so it
// needs N * sizeof(half) bytes. The previous N * 1 bytes was half of that and let the
// kernel write past its shared-memory allocation.
// Throws (TORCH_CHECK) if the launch fails, e.g. when the shared-memory request
// exceeds the per-block limit.
void gemvNormTHq(
    torch::Tensor input,
    torch::Tensor jl,
    torch::Tensor bsel,
    torch::Tensor low,
    torch::Tensor high,
    float threshold,
    uintptr_t sne
) {
    auto sne_ = reinterpret_cast<StreamNevent_full*>(sne);
    cudaStream_t now_stream = sne_->sub_stream;
    cudaStream_t main_stream = at::cuda::getCurrentCUDAStream();

    cudaEvent_t start_event_ = sne_->start_event;
    cudaEvent_t mid1_event_ = sne_->mid1_event;

    cudaEventRecord(start_event_, main_stream);
    cudaStreamWaitEvent(now_stream, start_event_, 0);
    const int K = input.size(2);
    const int N = jl.size(0);
    const size_t smem_bytes = static_cast<size_t>(N) * sizeof(half);
    warp1NaiveKernelTHD<<<1,THREADS_PER_BLOCK,smem_bytes,now_stream>>>((half *)input.data_ptr<ATEN_DTYPE(DataType::FP16)>(),
                                                        (half *)jl.data_ptr<ATEN_DTYPE(DataType::FP16)>(),
                                                        threshold,
                                                        (int *)bsel.data_ptr<int>(), 
                                                        (int *)low.data_ptr<int>(),
                                                        (int *)high.data_ptr<int>(),
                                                        K, N);
    // Capture the launch status first, but still record the event so the chain stays consistent.
    const cudaError_t launch_err = cudaGetLastError();
    cudaEventRecord(mid1_event_, now_stream);
    TORCH_CHECK(launch_err == cudaSuccess, "gemvNormTHq: warp1NaiveKernelTHD launch failed: ", cudaGetErrorString(launch_err));
}

// Second stage of the q -> k -> v chain. The side stream waits on mid1_event
// (recorded by the q stage), then runs warp1NaiveKernelTHD (single block) to set
// bsel from `input` and the N-row matrix `jl`. It records mid2_event for
// gemvNormTHv to wait on.
//
// warp1NaiveKernelTHD stores one half per row of jl in dynamic shared memory, so it
// needs N * sizeof(half) bytes. The previous N * 1 bytes was half of that and let the
// kernel write past its shared-memory allocation. The unused main-stream handle
// was removed.
// Throws (TORCH_CHECK) if the launch fails.
void gemvNormTHk(
    torch::Tensor input,
    torch::Tensor jl,
    torch::Tensor bsel,
    torch::Tensor low,
    torch::Tensor high,
    float threshold,
    uintptr_t sne
) {
    auto sne_ = reinterpret_cast<StreamNevent_full*>(sne);
    cudaStream_t now_stream = sne_->sub_stream;

    cudaEvent_t mid1_event_ = sne_->mid1_event;
    cudaEvent_t mid2_event_ = sne_->mid2_event;

    cudaStreamWaitEvent(now_stream, mid1_event_, 0);
    const int K = input.size(2);
    const int N = jl.size(0);
    const size_t smem_bytes = static_cast<size_t>(N) * sizeof(half);
    warp1NaiveKernelTHD<<<1,THREADS_PER_BLOCK,smem_bytes,now_stream>>>((half *)input.data_ptr<ATEN_DTYPE(DataType::FP16)>(),
                                                        (half *)jl.data_ptr<ATEN_DTYPE(DataType::FP16)>(),
                                                        threshold,
                                                        (int *)bsel.data_ptr<int>(), 
                                                        (int *)low.data_ptr<int>(),
                                                        (int *)high.data_ptr<int>(),
                                                        K, N);
    // Capture the launch status first, but still record the event so the chain stays consistent.
    const cudaError_t launch_err = cudaGetLastError();
    cudaEventRecord(mid2_event_, now_stream);
    TORCH_CHECK(launch_err == cudaSuccess, "gemvNormTHk: warp1NaiveKernelTHD launch failed: ", cudaGetErrorString(launch_err));
}

// Final stage of the q -> k -> v chain. The side stream waits on mid2_event
// (recorded by the k stage), then runs warp1NaiveKernelTHD (single block) to set
// bsel from `input` and the N-row matrix `jl`. It records end_event as the join
// point.
//
// warp1NaiveKernelTHD stores one half per row of jl in dynamic shared memory, so it
// needs N * sizeof(half) bytes. The previous N * 1 bytes was half of that and let the
// kernel write past its shared-memory allocation. The unused main-stream handle
// was removed.
// Throws (TORCH_CHECK) if the launch fails.
void gemvNormTHv(
    torch::Tensor input,
    torch::Tensor jl,
    torch::Tensor bsel,
    torch::Tensor low,
    torch::Tensor high,
    float threshold,
    uintptr_t sne
) {
    auto sne_ = reinterpret_cast<StreamNevent_full*>(sne);
    cudaStream_t now_stream = sne_->sub_stream;

    cudaEvent_t mid2_event_ = sne_->mid2_event;
    cudaEvent_t end_event_ = sne_->end_event;

    cudaStreamWaitEvent(now_stream, mid2_event_, 0);
    const int K = input.size(2);
    const int N = jl.size(0);
    const size_t smem_bytes = static_cast<size_t>(N) * sizeof(half);
    warp1NaiveKernelTHD<<<1,THREADS_PER_BLOCK,smem_bytes,now_stream>>>((half *)input.data_ptr<ATEN_DTYPE(DataType::FP16)>(),
                                                        (half *)jl.data_ptr<ATEN_DTYPE(DataType::FP16)>(),
                                                        threshold,
                                                        (int *)bsel.data_ptr<int>(), 
                                                        (int *)low.data_ptr<int>(),
                                                        (int *)high.data_ptr<int>(),
                                                        K, N);
    // Capture the launch status first, but still record the event so the chain stays consistent.
    const cudaError_t launch_err = cudaGetLastError();
    cudaEventRecord(end_event_, now_stream);
    TORCH_CHECK(launch_err == cudaSuccess, "gemvNormTHv: warp1NaiveKernelTHD launch failed: ", cudaGetErrorString(launch_err));
}

// g->u
// First stage of the g -> u chain. Forks the side stream off the current PyTorch
// stream via start_event, then launches warp1NaiveKernelTHD (single block) to set
// bsel from `input` and the N-row matrix `jl`. It records mid1_event for
// gemvNormTHu to wait on.
//
// warp1NaiveKernelTHD stores one half per row of jl in dynamic shared memory, so it
// needs N * sizeof(half) bytes. The previous N * 1 bytes was half of that and let the
// kernel write past its shared-memory allocation.
// Throws (TORCH_CHECK) if the launch fails.
void gemvNormTHg(
    torch::Tensor input,
    torch::Tensor jl,
    torch::Tensor bsel,
    torch::Tensor low,
    torch::Tensor high,
    float threshold,
    uintptr_t sne
) {
    auto sne_ = reinterpret_cast<StreamNevent_full*>(sne);
    cudaStream_t now_stream = sne_->sub_stream;
    cudaStream_t main_stream = at::cuda::getCurrentCUDAStream();

    cudaEvent_t start_event_ = sne_->start_event;
    cudaEvent_t mid1_event_ = sne_->mid1_event;

    cudaEventRecord(start_event_, main_stream);
    cudaStreamWaitEvent(now_stream, start_event_, 0);
    const int K = input.size(2);
    const int N = jl.size(0);
    const size_t smem_bytes = static_cast<size_t>(N) * sizeof(half);
    warp1NaiveKernelTHD<<<1,THREADS_PER_BLOCK,smem_bytes,now_stream>>>((half *)input.data_ptr<ATEN_DTYPE(DataType::FP16)>(),
                                                        (half *)jl.data_ptr<ATEN_DTYPE(DataType::FP16)>(),
                                                        threshold,
                                                        (int *)bsel.data_ptr<int>(), 
                                                        (int *)low.data_ptr<int>(),
                                                        (int *)high.data_ptr<int>(),
                                                        K, N);
    // Capture the launch status first, but still record the event so the chain stays consistent.
    const cudaError_t launch_err = cudaGetLastError();
    cudaEventRecord(mid1_event_, now_stream);
    TORCH_CHECK(launch_err == cudaSuccess, "gemvNormTHg: warp1NaiveKernelTHD launch failed: ", cudaGetErrorString(launch_err));
}

// Second stage of the g -> u chain. The side stream waits on mid1_event
// (recorded by the g stage), then runs warp1NaiveKernelTHD (single block) to set
// bsel from `input` and the N-row matrix `jl`. It records end_event as the join
// point.
//
// warp1NaiveKernelTHD stores one half per row of jl in dynamic shared memory, so it
// needs N * sizeof(half) bytes. The previous N * 1 bytes was half of that and let the
// kernel write past its shared-memory allocation. The unused main-stream handle
// was removed.
// Throws (TORCH_CHECK) if the launch fails.
void gemvNormTHu(
    torch::Tensor input,
    torch::Tensor jl,
    torch::Tensor bsel,
    torch::Tensor low,
    torch::Tensor high,
    float threshold,
    uintptr_t sne
) {
    auto sne_ = reinterpret_cast<StreamNevent_full*>(sne);
    cudaStream_t now_stream = sne_->sub_stream;

    cudaEvent_t mid1_event_ = sne_->mid1_event;
    cudaEvent_t end_event_ = sne_->end_event;

    cudaStreamWaitEvent(now_stream, mid1_event_, 0);
    const int K = input.size(2);
    const int N = jl.size(0);
    const size_t smem_bytes = static_cast<size_t>(N) * sizeof(half);
    warp1NaiveKernelTHD<<<1,THREADS_PER_BLOCK,smem_bytes,now_stream>>>((half *)input.data_ptr<ATEN_DTYPE(DataType::FP16)>(),
                                                        (half *)jl.data_ptr<ATEN_DTYPE(DataType::FP16)>(),
                                                        threshold,
                                                        (int *)bsel.data_ptr<int>(), 
                                                        (int *)low.data_ptr<int>(),
                                                        (int *)high.data_ptr<int>(),
                                                        K, N);
    // Capture the launch status first, but still record the event so the chain stays consistent.
    const cudaError_t launch_err = cudaGetLastError();
    cudaEventRecord(end_event_, now_stream);
    TORCH_CHECK(launch_err == cudaSuccess, "gemvNormTHu: warp1NaiveKernelTHD launch failed: ", cudaGetErrorString(launch_err));
}


// qkv fused
// Single-stage variant for a fused qkv projection. Forks the side stream off the
// current PyTorch stream via start_event, then launches warp1NaiveKernelTHD (single
// block) to set bsel from `input` and the N-row matrix `jl`. It records end_event
// as the join point.
//
// warp1NaiveKernelTHD stores one half per row of jl in dynamic shared memory, so it
// needs N * sizeof(half) bytes. The previous N * 1 bytes was half of that and let the
// kernel write past its shared-memory allocation.
// Throws (TORCH_CHECK) if the launch fails.
void gemvNormTHqkv(
    torch::Tensor input,
    torch::Tensor jl,
    torch::Tensor bsel,
    torch::Tensor low,
    torch::Tensor high,
    float threshold,
    uintptr_t sne
) {
    auto sne_ = reinterpret_cast<StreamNevent_full*>(sne);
    cudaStream_t now_stream = sne_->sub_stream;
    cudaStream_t main_stream = at::cuda::getCurrentCUDAStream();

    cudaEvent_t start_event_ = sne_->start_event;
    cudaEvent_t end_event_ = sne_->end_event;

    cudaEventRecord(start_event_, main_stream);
    cudaStreamWaitEvent(now_stream, start_event_, 0);
    const int K = input.size(2);
    const int N = jl.size(0);
    const size_t smem_bytes = static_cast<size_t>(N) * sizeof(half);
    warp1NaiveKernelTHD<<<1,THREADS_PER_BLOCK,smem_bytes,now_stream>>>((half *)input.data_ptr<ATEN_DTYPE(DataType::FP16)>(),
                                                        (half *)jl.data_ptr<ATEN_DTYPE(DataType::FP16)>(),
                                                        threshold,
                                                        (int *)bsel.data_ptr<int>(), 
                                                        (int *)low.data_ptr<int>(),
                                                        (int *)high.data_ptr<int>(),
                                                        K, N);
    // Capture the launch status first, but still record the event so the chain stays consistent.
    const cudaError_t launch_err = cudaGetLastError();
    cudaEventRecord(end_event_, now_stream);
    TORCH_CHECK(launch_err == cudaSuccess, "gemvNormTHqkv: warp1NaiveKernelTHD launch failed: ", cudaGetErrorString(launch_err));
}

// gu fused
// Single-stage variant for a fused gate/up projection. Forks the side stream off
// the current PyTorch stream via start_event, then launches warp1NaiveKernelTHD
// (single block) to set bsel from `input` and the N-row matrix `jl`. It records
// end_event as the join point.
//
// warp1NaiveKernelTHD stores one half per row of jl in dynamic shared memory, so it
// needs N * sizeof(half) bytes. The previous N * 1 bytes was half of that and let the
// kernel write past its shared-memory allocation.
// Throws (TORCH_CHECK) if the launch fails.
void gemvNormTHgu(
    torch::Tensor input,
    torch::Tensor jl,
    torch::Tensor bsel,
    torch::Tensor low,
    torch::Tensor high,
    float threshold,
    uintptr_t sne
) {
    auto sne_ = reinterpret_cast<StreamNevent_full*>(sne);
    cudaStream_t now_stream = sne_->sub_stream;
    cudaStream_t main_stream = at::cuda::getCurrentCUDAStream();

    cudaEvent_t start_event_ = sne_->start_event;
    cudaEvent_t end_event_ = sne_->end_event;

    cudaEventRecord(start_event_, main_stream);
    cudaStreamWaitEvent(now_stream, start_event_, 0);
    const int K = input.size(2);
    const int N = jl.size(0);
    const size_t smem_bytes = static_cast<size_t>(N) * sizeof(half);
    warp1NaiveKernelTHD<<<1,THREADS_PER_BLOCK,smem_bytes,now_stream>>>((half *)input.data_ptr<ATEN_DTYPE(DataType::FP16)>(),
                                                        (half *)jl.data_ptr<ATEN_DTYPE(DataType::FP16)>(),
                                                        threshold,
                                                        (int *)bsel.data_ptr<int>(), 
                                                        (int *)low.data_ptr<int>(),
                                                        (int *)high.data_ptr<int>(),
                                                        K, N);
    // Capture the launch status first, but still record the event so the chain stays consistent.
    const cudaError_t launch_err = cudaGetLastError();
    cudaEventRecord(end_event_, now_stream);
    TORCH_CHECK(launch_err == cudaSuccess, "gemvNormTHgu: warp1NaiveKernelTHD launch failed: ", cudaGetErrorString(launch_err));
}


// q->k->v
/**
 * First stage of the q -> k -> v threshold chain on the side stream of `sne`.
 *
 * Records sne->start_event on the current (main) stream and makes
 * sne->sub_stream wait on it. It then launches thresholdKernelD (1 block of
 * WARP_SIZE threads) on the sub-stream and records sne->mid1_event there;
 * normTHk waits on that event.
 *
 * @param input      FP16 tensor; input.size(2) is passed as K
 *                   (assumes a 3-D layout — TODO confirm).
 * @param a, b, threshold  forwarded unchanged to thresholdKernelD.
 * @param bsel, low, high  int32 tensors handed to the kernel (presumably its
 *                   outputs — confirm against dec.cuh).
 * @param sne        address of a StreamNevent_full (see create_streamNevent_full).
 * @throws c10::Error if the kernel launch reports an error.
 */
void normTHq(
    torch::Tensor input,
    float a,
    float b,
    float threshold,
    torch::Tensor bsel,
    torch::Tensor low,
    torch::Tensor high,
    uintptr_t sne
) {
    auto sne_ = reinterpret_cast<StreamNevent_full*>(sne);
    cudaStream_t now_stream = sne_->sub_stream;
    cudaStream_t main_stream = at::cuda::getCurrentCUDAStream();

    cudaEvent_t start_event_ = sne_->start_event;
    cudaEvent_t mid1_event_ = sne_->mid1_event;

    // Fork off the main stream: the side stream starts only after the work
    // already queued there.
    cudaEventRecord(start_event_, main_stream);
    cudaStreamWaitEvent(now_stream, start_event_, 0);
    int K = input.size(2);
    thresholdKernelD<<<1,WARP_SIZE,0,now_stream>>>((half *)input.data_ptr<ATEN_DTYPE(DataType::FP16)>(), 
                                                    a, b, threshold, 
                                                    (int *)bsel.data_ptr<int>(), 
                                                    (int *)low.data_ptr<int>(),
                                                    (int *)high.data_ptr<int>(),
                                                    K);
    // Launch errors are only observable through cudaGetLastError.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "CUDA Error in normTHq: ", cudaGetErrorString(err));

    // Stage boundary consumed by normTHk.
    cudaEventRecord(mid1_event_, now_stream);
}

/**
 * Second stage of the q -> k -> v threshold chain.
 *
 * Makes sne->sub_stream wait on sne->mid1_event (recorded by normTHq), launches
 * thresholdKernelD (1 block of WARP_SIZE threads) on the sub-stream, then
 * records sne->mid2_event, which normTHv waits on.
 *
 * This stage does NOT synchronize with the main stream. It assumes `input` was
 * already complete when normTHq forked off the main stream (TODO confirm with
 * callers).
 *
 * @param input      FP16 tensor; input.size(2) is passed as K.
 * @param a, b, threshold  forwarded unchanged to thresholdKernelD.
 * @param bsel, low, high  int32 tensors handed to the kernel (presumably its
 *                   outputs — confirm against dec.cuh).
 * @param sne        address of a StreamNevent_full (see create_streamNevent_full).
 * @throws c10::Error if the kernel launch reports an error.
 */
void normTHk(
    torch::Tensor input,
    float a,
    float b,
    float threshold,
    torch::Tensor bsel,
    torch::Tensor low,
    torch::Tensor high,
    uintptr_t sne
) {
    auto sne_ = reinterpret_cast<StreamNevent_full*>(sne);
    cudaStream_t now_stream = sne_->sub_stream;

    cudaEvent_t mid1_event_ = sne_->mid1_event;
    cudaEvent_t mid2_event_ = sne_->mid2_event;

    // If the same `sne` was given to normTHq, mid1_event was recorded on this
    // very sub-stream. Stream order then already serializes the two launches,
    // and the wait only makes the stage boundary explicit.
    cudaStreamWaitEvent(now_stream, mid1_event_, 0);
    int K = input.size(2);
    thresholdKernelD<<<1,WARP_SIZE,0,now_stream>>>((half *)input.data_ptr<ATEN_DTYPE(DataType::FP16)>(), 
                                                    a, b, threshold, 
                                                    (int *)bsel.data_ptr<int>(),
                                                    (int *)low.data_ptr<int>(),
                                                    (int *)high.data_ptr<int>(),
                                                    K);
    // Launch errors are only observable through cudaGetLastError.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "CUDA Error in normTHk: ", cudaGetErrorString(err));

    // Stage boundary consumed by normTHv.
    cudaEventRecord(mid2_event_, now_stream);
}

/**
 * Final stage of the q -> k -> v threshold chain.
 *
 * Makes sne->sub_stream wait on sne->mid2_event (recorded by normTHk), launches
 * thresholdKernelD (1 block of WARP_SIZE threads) on the sub-stream, then
 * records sne->end_event so consumers can join on the whole chain.
 *
 * This stage does NOT synchronize with the main stream. It assumes `input` was
 * already complete when normTHq forked off the main stream (TODO confirm with
 * callers).
 *
 * @param input      FP16 tensor; input.size(2) is passed as K.
 * @param a, b, threshold  forwarded unchanged to thresholdKernelD.
 * @param bsel, low, high  int32 tensors handed to the kernel (presumably its
 *                   outputs — confirm against dec.cuh).
 * @param sne        address of a StreamNevent_full (see create_streamNevent_full).
 * @throws c10::Error if the kernel launch reports an error.
 */
void normTHv(
    torch::Tensor input,
    float a,
    float b,
    float threshold,
    torch::Tensor bsel,
    torch::Tensor low,
    torch::Tensor high,
    uintptr_t sne
) {
    auto sne_ = reinterpret_cast<StreamNevent_full*>(sne);
    cudaStream_t now_stream = sne_->sub_stream;

    cudaEvent_t mid2_event_ = sne_->mid2_event;
    cudaEvent_t end_event_ = sne_->end_event;

    // If the same `sne` was given to normTHk, mid2_event was recorded on this
    // very sub-stream, so this wait only makes the stage boundary explicit.
    cudaStreamWaitEvent(now_stream, mid2_event_, 0);
    int K = input.size(2);
    thresholdKernelD<<<1,WARP_SIZE,0,now_stream>>>((half *)input.data_ptr<ATEN_DTYPE(DataType::FP16)>(), 
                                                    a, b, threshold, 
                                                    (int *)bsel.data_ptr<int>(),
                                                    (int *)low.data_ptr<int>(),
                                                    (int *)high.data_ptr<int>(),
                                                    K);
    // Launch errors are only observable through cudaGetLastError.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "CUDA Error in normTHv: ", cudaGetErrorString(err));

    // Join point for consumers of the q/k/v results.
    cudaEventRecord(end_event_, now_stream);
}

// g->u
/**
 * First stage of the g -> u threshold chain on the side stream of `sne`.
 *
 * Records sne->start_event on the current (main) stream and makes
 * sne->sub_stream wait on it. It then launches thresholdKernelD (1 block of
 * WARP_SIZE threads) on the sub-stream and records sne->mid1_event there;
 * normTHu waits on that event.
 *
 * @param input      FP16 tensor; input.size(2) is passed as K
 *                   (assumes a 3-D layout — TODO confirm).
 * @param a, b, threshold  forwarded unchanged to thresholdKernelD.
 * @param bsel, low, high  int32 tensors handed to the kernel (presumably its
 *                   outputs — confirm against dec.cuh).
 * @param sne        address of a StreamNevent_full (see create_streamNevent_full).
 * @throws c10::Error if the kernel launch reports an error.
 */
void normTHg(
    torch::Tensor input,
    float a,
    float b,
    float threshold,
    torch::Tensor bsel,
    torch::Tensor low,
    torch::Tensor high,
    uintptr_t sne
) {
    auto sne_ = reinterpret_cast<StreamNevent_full*>(sne);
    cudaStream_t now_stream = sne_->sub_stream;
    cudaStream_t main_stream = at::cuda::getCurrentCUDAStream();

    cudaEvent_t start_event_ = sne_->start_event;
    cudaEvent_t mid1_event_ = sne_->mid1_event;

    // Fork off the main stream: the side stream starts only after the work
    // already queued there.
    cudaEventRecord(start_event_, main_stream);
    cudaStreamWaitEvent(now_stream, start_event_, 0);
    int K = input.size(2);
    thresholdKernelD<<<1,WARP_SIZE,0,now_stream>>>((half *)input.data_ptr<ATEN_DTYPE(DataType::FP16)>(), 
                                                    a, b, threshold, 
                                                    (int *)bsel.data_ptr<int>(),
                                                    (int *)low.data_ptr<int>(),
                                                    (int *)high.data_ptr<int>(),
                                                    K);
    // Launch errors are only observable through cudaGetLastError.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "CUDA Error in normTHg: ", cudaGetErrorString(err));

    // Stage boundary consumed by normTHu.
    cudaEventRecord(mid1_event_, now_stream);
}

/**
 * Second (final) stage of the g -> u threshold chain.
 *
 * Makes sne->sub_stream wait on sne->mid1_event (recorded by normTHg), launches
 * thresholdKernelD (1 block of WARP_SIZE threads) on the sub-stream, then
 * records sne->end_event so consumers can join on the chain.
 *
 * This stage does NOT synchronize with the main stream. It assumes `input` was
 * already complete when normTHg forked off the main stream (TODO confirm with
 * callers).
 *
 * @param input      FP16 tensor; input.size(2) is passed as K.
 * @param a, b, threshold  forwarded unchanged to thresholdKernelD.
 * @param bsel, low, high  int32 tensors handed to the kernel (presumably its
 *                   outputs — confirm against dec.cuh).
 * @param sne        address of a StreamNevent_full (see create_streamNevent_full).
 * @throws c10::Error if the kernel launch reports an error.
 */
void normTHu(
    torch::Tensor input,
    float a,
    float b,
    float threshold,
    torch::Tensor bsel,
    torch::Tensor low,
    torch::Tensor high,
    uintptr_t sne
) {
    auto sne_ = reinterpret_cast<StreamNevent_full*>(sne);
    cudaStream_t now_stream = sne_->sub_stream;

    cudaEvent_t mid1_event_ = sne_->mid1_event;
    cudaEvent_t end_event_ = sne_->end_event;

    // If the same `sne` was given to normTHg, mid1_event was recorded on this
    // very sub-stream, so this wait only makes the stage boundary explicit.
    cudaStreamWaitEvent(now_stream, mid1_event_, 0);
    int K = input.size(2);
    thresholdKernelD<<<1,WARP_SIZE,0,now_stream>>>((half *)input.data_ptr<ATEN_DTYPE(DataType::FP16)>(), 
                                                    a, b, threshold, 
                                                    (int *)bsel.data_ptr<int>(),
                                                    (int *)low.data_ptr<int>(),
                                                    (int *)high.data_ptr<int>(),
                                                    K);
    // Launch errors are only observable through cudaGetLastError.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "CUDA Error in normTHu: ", cudaGetErrorString(err));

    // Join point for consumers of the g/u results.
    cudaEventRecord(end_event_, now_stream);
}


// qkv fused
/**
 * Single-launch (fused q/k/v) threshold stage on the side stream of `sne`.
 *
 * Records sne->start_event on the current (main) stream and makes
 * sne->sub_stream wait on it. It then launches thresholdKernelD (1 block of
 * WARP_SIZE threads) on the sub-stream and records sne->end_event there.
 *
 * @param input      FP16 tensor; input.size(2) is passed as K
 *                   (assumes a 3-D layout — TODO confirm).
 * @param a, b, threshold  forwarded unchanged to thresholdKernelD.
 * @param bsel, low, high  int32 tensors handed to the kernel (presumably its
 *                   outputs — confirm against dec.cuh).
 * @param sne        address of a StreamNevent_full (see create_streamNevent_full).
 * @throws c10::Error if the kernel launch reports an error.
 */
void normTHqkv(
    torch::Tensor input,
    float a,
    float b,
    float threshold,
    torch::Tensor bsel,
    torch::Tensor low,
    torch::Tensor high,
    uintptr_t sne
) {
    auto sne_ = reinterpret_cast<StreamNevent_full*>(sne);
    cudaStream_t now_stream = sne_->sub_stream;
    cudaStream_t main_stream = at::cuda::getCurrentCUDAStream();

    cudaEvent_t start_event_ = sne_->start_event;
    cudaEvent_t end_event_ = sne_->end_event;

    // Fork off the main stream.
    cudaEventRecord(start_event_, main_stream);
    cudaStreamWaitEvent(now_stream, start_event_, 0);
    int K = input.size(2);
    thresholdKernelD<<<1,WARP_SIZE,0,now_stream>>>((half *)input.data_ptr<ATEN_DTYPE(DataType::FP16)>(), 
                                                    a, b, threshold, 
                                                    (int *)bsel.data_ptr<int>(),
                                                    (int *)low.data_ptr<int>(),
                                                    (int *)high.data_ptr<int>(),
                                                    K);
    // Launch errors are only observable through cudaGetLastError.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "CUDA Error in normTHqkv: ", cudaGetErrorString(err));

    // Join point for consumers.
    cudaEventRecord(end_event_, now_stream);
}

// gu fused
/**
 * Single-launch (fused g/u) threshold stage on the side stream of `sne`.
 *
 * Records sne->start_event on the current (main) stream and makes
 * sne->sub_stream wait on it. It then launches thresholdKernelD (1 block of
 * WARP_SIZE threads) on the sub-stream and records sne->end_event there.
 *
 * @param input      FP16 tensor; input.size(2) is passed as K
 *                   (assumes a 3-D layout — TODO confirm).
 * @param a, b, threshold  forwarded unchanged to thresholdKernelD.
 * @param bsel, low, high  int32 tensors handed to the kernel (presumably its
 *                   outputs — confirm against dec.cuh).
 * @param sne        address of a StreamNevent_full (see create_streamNevent_full).
 * @throws c10::Error if the kernel launch reports an error.
 */
void normTHgu(
    torch::Tensor input,
    float a,
    float b,
    float threshold,
    torch::Tensor bsel,
    torch::Tensor low,
    torch::Tensor high,
    uintptr_t sne
) {
    auto sne_ = reinterpret_cast<StreamNevent_full*>(sne);
    cudaStream_t now_stream = sne_->sub_stream;
    cudaStream_t main_stream = at::cuda::getCurrentCUDAStream();

    cudaEvent_t start_event_ = sne_->start_event;
    cudaEvent_t end_event_ = sne_->end_event;

    // Fork off the main stream.
    cudaEventRecord(start_event_, main_stream);
    cudaStreamWaitEvent(now_stream, start_event_, 0);
    int K = input.size(2);
    thresholdKernelD<<<1,WARP_SIZE,0,now_stream>>>((half *)input.data_ptr<ATEN_DTYPE(DataType::FP16)>(), 
                                                    a, b, threshold, 
                                                    (int *)bsel.data_ptr<int>(),
                                                    (int *)low.data_ptr<int>(),
                                                    (int *)high.data_ptr<int>(),
                                                    K);
    // Launch errors are only observable through cudaGetLastError.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "CUDA Error in normTHgu: ", cudaGetErrorString(err));

    // Join point for consumers.
    cudaEventRecord(end_event_, now_stream);
}



/**
 * Disabled entry point: its two warp1NaiveKernelTH launches (jl1/bsel1 and
 * jl2/bsel2) are commented out, so it cannot produce any result.
 *
 * The previous guard was `assert(false)`, placed after start_event had already
 * been recorded and the side stream made to wait:
 *   - with NDEBUG defined the assert vanished, and the call silently "succeeded"
 *     while leaving bsel1/bsel2 untouched;
 *   - with asserts enabled it aborted the whole process.
 * TORCH_CHECK fails loudly and recoverably in every build, before any event or
 * stream is touched.
 *
 * @throws c10::Error on every call.
 */
void gemvNormTH2(
    torch::Tensor input,
    torch::Tensor jl1,
    torch::Tensor jl2,
    torch::Tensor bsel1,
    torch::Tensor bsel2,
    float threshold1,
    float threshold2,
    uintptr_t sne
) {
    TORCH_CHECK(false, "gemvNormTH2 is disabled: its kernel launches are commented out.");
    // Former body, kept for reference (ran on StreamNevent::sub_stream after a
    // start_event fork from the current stream, then recorded end_event):
    // warp1NaiveKernelTH<<<1,THREADS_PER_BLOCK,0,now_stream>>>(input, jl1, threshold1, bsel1, 3, 4);
    // warp1NaiveKernelTH<<<1,THREADS_PER_BLOCK,0,now_stream>>>(input, jl2, threshold2, bsel2, 3, 4);
}

/**
 * Disabled entry point: its three warp1NaiveKernelTH launches (jl1/bsel1,
 * jl2/bsel2, jl3/bsel3) are commented out, so it cannot produce any result.
 *
 * The previous guard was `assert(false)` placed after the stream fork. It was
 * a silent no-op under NDEBUG (bsel1..3 left untouched) and a process abort
 * otherwise. TORCH_CHECK raises a catchable error in every build, before any
 * event or stream is touched.
 *
 * @throws c10::Error on every call.
 */
void gemvNormTH3(
    torch::Tensor input,
    torch::Tensor jl1,
    torch::Tensor jl2,
    torch::Tensor jl3,
    torch::Tensor bsel1,
    torch::Tensor bsel2,
    torch::Tensor bsel3,
    float threshold1,
    float threshold2,
    float threshold3,
    uintptr_t sne
) {
    TORCH_CHECK(false, "gemvNormTH3 is disabled: its kernel launches are commented out.");
    // Former body, kept for reference (ran on StreamNevent::sub_stream between
    // start_event and end_event):
    // warp1NaiveKernelTH<<<1,THREADS_PER_BLOCK,0,now_stream>>>(input, jl1, threshold1, bsel1, 3, 4);
    // warp1NaiveKernelTH<<<1,THREADS_PER_BLOCK,0,now_stream>>>(input, jl2, threshold2, bsel2, 3, 4);
    // warp1NaiveKernelTH<<<1,THREADS_PER_BLOCK,0,now_stream>>>(input, jl3, threshold3, bsel3, 3, 4);
}

/**
 * Performs the fork/join event bookkeeping of the side stream without
 * launching any kernel. sne->start_event is recorded on the current stream,
 * the side stream waits on it, and sne->end_event is recorded on the side
 * stream immediately afterwards. Presumably used when a consumer always waits
 * on end_event but no work is needed this step (confirm with callers).
 *
 * @param sne address of a StreamNevent_full (see create_streamNevent_full).
 */
void fakeTrigger(
    uintptr_t sne
) {
    auto* handles = reinterpret_cast<StreamNevent_full*>(sne);
    cudaStream_t side = handles->sub_stream;
    cudaStream_t current = at::cuda::getCurrentCUDAStream();

    // Fork, then join immediately: end_event completes as soon as the side
    // stream has caught up with the current stream.
    cudaEventRecord(handles->start_event, current);
    cudaStreamWaitEvent(side, handles->start_event, 0);
    cudaEventRecord(handles->end_event, side);
}

/**
 * Disabled entry point: its warp1full3 + reduce3 launches are commented out,
 * so it cannot produce res1..3 or bsel1..3.
 *
 * The previous guard was `assert(false)` placed after the stream fork. It was
 * a silent no-op under NDEBUG and a process abort otherwise. TORCH_CHECK raises
 * a catchable error in every build, before any event or stream is touched. The
 * unused grid/block locals of the old body are dropped.
 *
 * @throws c10::Error on every call.
 */
void gemvNormTH3Full(
    torch::Tensor input,
    torch::Tensor jl1,
    torch::Tensor jl2,
    torch::Tensor jl3,
    torch::Tensor res1,
    torch::Tensor res2,
    torch::Tensor res3,
    torch::Tensor bsel1,
    torch::Tensor bsel2,
    torch::Tensor bsel3,
    float threshold1,
    float threshold2,
    float threshold3,
    uintptr_t sne
) {
    TORCH_CHECK(false, "gemvNormTH3Full is disabled: its kernel launches are commented out.");
    // Former body, kept for reference (ran on StreamNevent::sub_stream between
    // start_event and end_event):
    // dim3 grid(N_def, 3, 1), block(FULL_THREADS_PER_BLOCK, 1, 1);
    // warp1full3<<<grid,block,0,now_stream>>>(input, jl1, jl2, jl3, res1, res2, res3);
    // reduce3<<<3,WARP_SIZE,0,now_stream>>>(res1, res2, res3,
    //                                       threshold1, threshold2, threshold3,
    //                                       bsel1, bsel2, bsel3, 3, 4);
}

/**
 * Disabled entry point: its thresholdKernel launch is commented out, so bsel
 * is never written.
 *
 * The previous guard was `assert(false)` placed after the stream fork. It was
 * a silent no-op under NDEBUG and a process abort otherwise. TORCH_CHECK raises
 * a catchable error in every build, before any event or stream is touched.
 *
 * @throws c10::Error on every call.
 */
void normTH(
    torch::Tensor input,
    float a,
    float b,
    float threshold,
    torch::Tensor bsel,
    uintptr_t sne
){
    TORCH_CHECK(false, "normTH is disabled: its kernel launch is commented out.");
    // Former body, kept for reference (ran on StreamNevent::sub_stream between
    // start_event and end_event):
    // thresholdKernel<<<1,WARP_SIZE,0,now_stream>>>(input, a, b, threshold, bsel, 3, 4);
}

/**
 * Disabled entry point: its two thresholdKernel launches are commented out, so
 * bsel1/bsel2 are never written.
 *
 * The previous guard was `assert(false)` placed after the stream fork. It was
 * a silent no-op under NDEBUG and a process abort otherwise. TORCH_CHECK raises
 * a catchable error in every build, before any event or stream is touched.
 *
 * @throws c10::Error on every call.
 */
void normTH2(
    torch::Tensor input,
    float a1,
    float a2,
    float b1,
    float b2,
    float threshold1,
    float threshold2,
    torch::Tensor bsel1,
    torch::Tensor bsel2,
    uintptr_t sne
){
    TORCH_CHECK(false, "normTH2 is disabled: its kernel launches are commented out.");
    // Former body, kept for reference (ran on StreamNevent::sub_stream between
    // start_event and end_event):
    // thresholdKernel<<<1,WARP_SIZE,0,now_stream>>>(input, a1, b1, threshold1, bsel1, 3, 4);
    // thresholdKernel<<<1,WARP_SIZE,0,now_stream>>>(input, a2, b2, threshold2, bsel2, 3, 4);
}

/**
 * Disabled entry point: its rmsnorm + two thresholdKernel launches are
 * commented out, so res, bsel1 and bsel2 are never written.
 *
 * The previous guard was `assert(false)` placed after the stream fork. It was
 * a silent no-op under NDEBUG and a process abort otherwise. TORCH_CHECK raises
 * a catchable error in every build, before any event or stream is touched.
 *
 * @throws c10::Error on every call.
 */
void lnNormTH2(
    torch::Tensor input,
    torch::Tensor normW,
    torch::Tensor res,
    float a1,
    float a2,
    float b1,
    float b2,
    float threshold1,
    float threshold2,
    torch::Tensor bsel1,
    torch::Tensor bsel2,
    uintptr_t sne
){
    TORCH_CHECK(false, "lnNormTH2 is disabled: its kernel launches are commented out.");
    // Former body, kept for reference (ran on StreamNevent::sub_stream between
    // start_event and end_event):
    // rmsnorm<<<1,RMS_THREADS_PER_BLOCK,0,now_stream>>>(normW, input, res);
    // thresholdKernel<<<1,WARP_SIZE,0,now_stream>>>(res, a1, b1, threshold1, bsel1, 3, 4);
    // thresholdKernel<<<1,WARP_SIZE,0,now_stream>>>(res, a2, b2, threshold2, bsel2, 3, 4);
}

/**
 * Disabled entry point: its rmsnorm + thresholdKernel launches are commented
 * out, so res, bsel1 and bsel2 are never written.
 *
 * NOTE(review): the signature takes a3/b3/threshold3 but no bsel3. The former
 * body only used the first two parameter sets.
 *
 * The previous guard was `assert(false)` placed after the stream fork. It was
 * a silent no-op under NDEBUG and a process abort otherwise. TORCH_CHECK raises
 * a catchable error in every build, before any event or stream is touched.
 *
 * @throws c10::Error on every call.
 */
void lnNormTH3(
    torch::Tensor input,
    torch::Tensor normW,
    torch::Tensor res,
    float a1,
    float a2,
    float a3,
    float b1,
    float b2,
    float b3,
    float threshold1,
    float threshold2,
    float threshold3,
    torch::Tensor bsel1,
    torch::Tensor bsel2,
    uintptr_t sne
){
    TORCH_CHECK(false, "lnNormTH3 is disabled: its kernel launches are commented out.");
    // Former body, kept for reference (ran on StreamNevent::sub_stream between
    // start_event and end_event):
    // rmsnorm<<<1,RMS_THREADS_PER_BLOCK,0,now_stream>>>(normW, input, res);
    // thresholdKernel<<<1,WARP_SIZE,0,now_stream>>>(res, a1, b1, threshold1, bsel1, 3, 4);
    // thresholdKernel<<<1,WARP_SIZE,0,now_stream>>>(res, a2, b2, threshold2, bsel2, 3, 4);
}


/**
 * Disabled entry point: its rmsnorm + warp1NaiveKernelTH launches are
 * commented out, so res and bsel are never written.
 *
 * The previous guard was `assert(false)` placed after the stream fork. It was
 * a silent no-op under NDEBUG and a process abort otherwise. TORCH_CHECK raises
 * a catchable error in every build, before any event or stream is touched.
 *
 * @throws c10::Error on every call.
 */
void lnGemvNormTH(
    torch::Tensor input,
    torch::Tensor normW,
    torch::Tensor res,
    torch::Tensor jl,
    torch::Tensor bsel,
    float threshold,
    uintptr_t sne
) {
    TORCH_CHECK(false, "lnGemvNormTH is disabled: its kernel launches are commented out.");
    // Former body, kept for reference (ran on StreamNevent::sub_stream between
    // start_event and end_event):
    // rmsnorm<<<1,RMS_THREADS_PER_BLOCK,0,now_stream>>>(normW, input, res);
    // warp1NaiveKernelTH<<<1,THREADS_PER_BLOCK,0,now_stream>>>(res, jl, threshold, bsel, 3, 4);
}


// uintptr_t create_streamNevent (
// ) {
//     cudaStream_t sub_stream;
//     cudaStreamCreateWithFlags(&sub_stream, cudaStreamNonBlocking);

//     cudaEvent_t start_event;
//     cudaEventCreate(&start_event);
//     cudaEvent_t end_event;
//     cudaEventCreate(&end_event);

//     auto* sne = new StreamNevent(
//         sub_stream,
//         start_event,
//         end_event
//     );
//     return reinterpret_cast<uintptr_t>(sne);
// }

// uintptr_t create_streamNevent_intra (
// ) {
//     cudaStream_t sub_stream;
//     cudaStreamCreateWithFlags(&sub_stream, cudaStreamNonBlocking);

//     cudaEvent_t start_event;
//     cudaEventCreate(&start_event);
//     cudaEvent_t end_event;
//     cudaEventCreate(&end_event);
//     cudaEvent_t mid_event;
//     cudaEventCreate(&mid_event);

//     auto* sne = new StreamNevent_intra(
//         sub_stream,
//         start_event,
//         end_event,
//         mid_event
//     );
//     return reinterpret_cast<uintptr_t>(sne);
// }

/**
 * Creates the side stream and four events used by the gemvNormTH* / normTH*
 * helpers, and packs them into a heap-allocated StreamNevent_full returned as
 * an opaque integer handle.
 *
 * The stream is created with cudaStreamNonBlocking, so it does not implicitly
 * synchronize with the legacy default stream. Ordering with the main stream
 * comes only from the explicit event record/wait pairs in the helpers.
 *
 * Ownership: nothing in this file releases the returned object or its CUDA
 * resources. Presumably they live for the rest of the process (TODO confirm,
 * or add a matching destroy function).
 *
 * @returns address of the new StreamNevent_full as uintptr_t.
 * @throws c10::Error if any CUDA create call (or the allocation) fails. Handles
 *         created before the failure are destroyed first so nothing leaks.
 */
uintptr_t create_streamNevent_full (
) {
    cudaStream_t sub_stream = nullptr;
    cudaEvent_t start_event = nullptr;
    cudaEvent_t end_event = nullptr;
    cudaEvent_t mid1_event = nullptr;
    cudaEvent_t mid2_event = nullptr;

    // Releases every handle created so far; null handles are skipped.
    auto destroy_all = [&]() {
        if (mid2_event) cudaEventDestroy(mid2_event);
        if (mid1_event) cudaEventDestroy(mid1_event);
        if (end_event) cudaEventDestroy(end_event);
        if (start_event) cudaEventDestroy(start_event);
        if (sub_stream) cudaStreamDestroy(sub_stream);
    };
    // Turns a failed CUDA runtime call into a c10::Error after cleanup.
    auto check = [&](cudaError_t err, const char* what) {
        if (err != cudaSuccess) {
            destroy_all();
            TORCH_CHECK(false, "create_streamNevent_full: ", what, " failed: ", cudaGetErrorString(err));
        }
    };

    check(cudaStreamCreateWithFlags(&sub_stream, cudaStreamNonBlocking), "cudaStreamCreateWithFlags");
    check(cudaEventCreate(&start_event), "cudaEventCreate(start_event)");
    check(cudaEventCreate(&end_event), "cudaEventCreate(end_event)");
    check(cudaEventCreate(&mid1_event), "cudaEventCreate(mid1_event)");
    check(cudaEventCreate(&mid2_event), "cudaEventCreate(mid2_event)");

    StreamNevent_full* sne = nullptr;
    try {
        sne = new StreamNevent_full(
            sub_stream,
            start_event,
            end_event,
            mid1_event,
            mid2_event
        );
    } catch (...) {
        // Don't leak the CUDA handles if the allocation/constructor throws.
        destroy_all();
        throw;
    }
    return reinterpret_cast<uintptr_t>(sne);
}

// void anyprec_gemv_sel_two(
//     torch::Tensor input,
//     torch::Tensor output,
//     torch::Tensor qweight,
//     torch::Tensor lut,
//     int bitwidth,
//     torch::Tensor jl,
//     torch::Tensor res
// ) {
//     cudaStream_t stream;
//     cudaStreamCreate(&stream);
//     cudaEvent_t event_1;
//     cudaEventCreate(&event_1);
//     cudaEvent_t event_2;
//     cudaEventCreate(&event_2);
    
//     cudaStream_t main_stream = at::cuda::getCurrentCUDAStream();
//     cudaEventRecord(event_1, stream);
//     warp1NaiveKernelUR8<<<1,THREADS_PER_BLOCK,0,stream>>>((half *)input.data_ptr<ATEN_DTYPE(DataType::FP16)>(),
//                                                         (half *)jl.data_ptr<ATEN_DTYPE(DataType::FP16)>(),
//                                                         (half *)res.data_ptr<ATEN_DTYPE(DataType::FP16)>());

//     cudaStreamWaitEvent(main_stream, event_1, 0);
//     anyprec_gemv(input, output, qweight, lut, bitwidth);
//     cudaEventRecord(event_2, main_stream);
//     cudaStreamWaitEvent(stream, event_2, 0);
// }

////////////////////////////////////////////////////////////////////////////////
//                               ANYPREC DEQUANT
////////////////////////////////////////////////////////////////////////////////

/**
 * Dequantizes an Any-Precision packed weight into a dense (N, K) tensor.
 *
 * @param qweight  3-D int32 CUDA tensor. N = qweight.size(1) and
 *                 K = qweight.size(2) * 32.
 * @param lut      half/bfloat16 lookup table on the same device as qweight.
 *                 Its dtype becomes the output dtype and must match DT (the
 *                 dispatcher selects DT from lut.scalar_type()).
 * @param bitwidth quantization bitwidth, 2..8 inclusive.
 * @param stream   CUDA stream that anyprec_dequant_kbit is given.
 * @returns        freshly allocated (N, K) tensor on qweight's device.
 * @throws c10::Error on invalid arguments.
 *
 * Validation uses TORCH_CHECK rather than assert. assert is compiled out under
 * NDEBUG, which would let malformed inputs reach the kernel unchecked, and in
 * debug builds it aborts the process instead of raising.
 */
template<DataType DT>
torch::Tensor anyprec_dequant_templated(
    torch::Tensor qweight,
    torch::Tensor lut,
    int bitwidth,
    cudaStream_t stream
) {
    TORCH_CHECK(qweight.ndimension() == 3, "qweight tensor must be 3-dimensional, got ", qweight.ndimension(), " dimensions.");
    TORCH_CHECK(qweight.dtype() == torch::kInt, "qweight tensor must be of type int.");
    TORCH_CHECK(lut.dtype() == torch::kHalf || lut.dtype() == torch::kBFloat16, "lut tensor must be of type half or bfloat16.");
    TORCH_CHECK(qweight.is_cuda(), "qweight tensor must be on GPU.");
    TORCH_CHECK(qweight.device() == lut.device(), "qweight and lut tensors must be on the same device.");
    TORCH_CHECK(bitwidth >= 2 && bitwidth <= 8, "Bitwidth must be between 2 and 8.");

    const int N = qweight.size(1);
    const int K = qweight.size(2) * 32;

    auto options = torch::TensorOptions().dtype(lut.dtype()).device(qweight.device());
    at::Tensor weight = torch::empty({N, K}, options);

    anyprec_dequant_kbit<DT>(
        (uint32_t *)qweight.data_ptr<int>(),
        N, K,
        (FP_DTYPE(DT) *)lut.data_ptr<ATEN_DTYPE(DT)>(),
        (FP_DTYPE(DT) *)weight.data_ptr<ATEN_DTYPE(DT)>(),
        bitwidth,
        stream
    );

    return weight;
}

/**
 * Selects the anyprec_dequant_templated instantiation from lut's dtype and
 * runs it on `stream`.
 *
 * Supported dtypes: half and bfloat16. float32 gets a dedicated error message
 * (no FP32 dequant path is instantiated); any other dtype is rejected as
 * unsupported. Both rejections throw c10::Error via TORCH_CHECK.
 */
torch::Tensor anyprec_dequant_stream(
    torch::Tensor qweight,
    torch::Tensor lut,
    int bitwidth,
    cudaStream_t stream
) {
    const auto lut_type = lut.scalar_type();

    // FP32 is intentionally not dispatched.
    TORCH_CHECK(lut_type != at::kFloat, "Any-Precision Dequantization does not support float data type. Please use half or bfloat16.");

    switch (lut_type) {
        case at::ScalarType::Half:
            return anyprec_dequant_templated<DataType::FP16>(qweight, lut, bitwidth, stream);
        case at::ScalarType::BFloat16:
            return anyprec_dequant_templated<DataType::BF16>(qweight, lut, bitwidth, stream);
        default:
            TORCH_CHECK(false, "Unsupported data type.");
    }
}

/**
 * Convenience overload of anyprec_dequant_stream that runs on PyTorch's
 * current CUDA stream.
 */
torch::Tensor anyprec_dequant(
    torch::Tensor qweight,
    torch::Tensor lut,
    int bitwidth
) {
    const cudaStream_t current_stream = at::cuda::getCurrentCUDAStream();
    return anyprec_dequant_stream(qweight, lut, bitwidth, current_stream);
}

////////////////////////////////////////////////////////////////////////////////
//                                     LUTGEMM
////////////////////////////////////////////////////////////////////////////////

/**
 * Launches the LUT-GEMM nqmv_bias kernel for a single-token GEMV.
 *
 * Launch layout: a 2-D grid with one block per M_TILE_SIZE output features
 * (x) and per K_TILE_SIZE input features (y), each block NUM_THREADS wide,
 * with no dynamic shared memory.
 *
 * @param input      (.., .., K) tensor of dtype DT; K is read from size(2).
 * @param q_weight   packed int32 weights.
 * @param alpha      per-group scales of dtype DT.
 * @param q_bias     bias terms of dtype DT.
 * @param output     (.., .., M) tensor of dtype DT; M is read from size(2).
 * @param bitwidth   number of quantization bits.
 * @param group_size number of input features per scale group.
 * @param stream     CUDA stream for the launch.
 * @throws c10::Error if the launch reports a CUDA error.
 */
template<DataType DT>
void lutgemm_gemv_templated(
    torch::Tensor input,
    torch::Tensor q_weight,
    torch::Tensor alpha,
    torch::Tensor q_bias,
    torch::Tensor output,
    int bitwidth,
    int group_size,
    cudaStream_t stream
) {
    const uint32_t in_features = input.size(2);
    const uint32_t out_features = output.size(2);

    const uint32_t tiles_over_m = (out_features + M_TILE_SIZE - 1) / M_TILE_SIZE;
    const uint32_t tiles_over_k = (in_features + K_TILE_SIZE - 1) / K_TILE_SIZE;
    const dim3 launch_grid(tiles_over_m, tiles_over_k);
    const dim3 launch_block(NUM_THREADS);

    nqmv_bias<DT><<<launch_grid, launch_block, 0, stream>>>(
        (uint32_t*) q_weight.data_ptr<int32_t>(),
        (FP_DTYPE(DT)*) alpha.data_ptr<ATEN_DTYPE(DT)>(),
        (FP_DTYPE(DT)*) q_bias.data_ptr<ATEN_DTYPE(DT)>(),
        (FP_DTYPE(DT)*) input.data_ptr<ATEN_DTYPE(DT)>(),
        (FP_DTYPE(DT)*) output.data_ptr<ATEN_DTYPE(DT)>(),
        out_features, in_features, bitwidth, group_size
    );

    // Launch-configuration errors are only reported through cudaGetLastError.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess, "CUDA Error: ", cudaGetErrorString(launch_status));
}

/**
 * Validates arguments for a single-token LUT-GEMM GEMV and dispatches on the
 * input dtype (float32, half, bfloat16).
 *
 * @param input      (1, 1, K) contiguous CUDA tensor.
 * @param q_weight   (K / 32, bitwidth, M) contiguous CUDA tensor of packed bits.
 * @param alpha      (K / group_size, bitwidth, M) contiguous CUDA tensor, same
 *                   dtype as input.
 * @param q_bias     contiguous CUDA tensor, same dtype as input (its shape is
 *                   not validated here).
 * @param output     (1, 1, M) contiguous CUDA tensor, same dtype as input.
 * @param bitwidth   1..8 inclusive.
 * @param group_size number of input features per scale group; must be > 0.
 * @param stream     CUDA stream for the launch.
 * @throws c10::Error on any validation failure or unsupported dtype.
 */
void lutgemm_gemv_stream(
    torch::Tensor input,
    torch::Tensor q_weight,
    torch::Tensor alpha,
    torch::Tensor q_bias,
    torch::Tensor output,
    int bitwidth,
    int group_size,
    cudaStream_t stream
) {
    TORCH_CHECK(bitwidth >= 1 && bitwidth <= 8, "Bitwidth must be between 1 and 8.");
    // Guard the `kSize / group_size` below: 0 would be an integer division by
    // zero (UB / SIGFPE), and a negative value would wrap when converted to
    // unsigned.
    TORCH_CHECK(group_size > 0, "group_size must be positive, got ", group_size, ".");
    TORCH_CHECK(input.scalar_type() == alpha.scalar_type() && input.scalar_type() == q_bias.scalar_type() && input.scalar_type() == output.scalar_type(), "Mismatched data types between input, alpha, q_bias, and output tensors.");
    // Check that input is of shape (batch_size, seq_len, input_feat)
    TORCH_CHECK(input.dim() == 3, "input tensor must be of shape (batch_size, seq_len, input_feat).");
    // Check that output is of shape (batch_size, seq_len, output_feat)
    TORCH_CHECK(output.dim() == 3, "output tensor must be of shape (batch_size, seq_len, output_feat).");

    // Only allow single batch size and sequence length
    TORCH_CHECK(input.size(0) == 1, "Batch size must be 1 for input tensor.");
    TORCH_CHECK(input.size(1) == 1, "Sequence length must be 1 for input tensor.");
    TORCH_CHECK(output.size(0) == 1, "Batch size must be 1 for output tensor.");
    TORCH_CHECK(output.size(1) == 1, "Sequence length must be 1 for output tensor.");

    // The kernel dereferences every tensor's data pointer on the device, so
    // all of them (not only input/output) must live on the GPU.
    TORCH_CHECK(input.is_cuda() && output.is_cuda(), "input and output tensors must be on GPU.");
    TORCH_CHECK(q_weight.is_cuda() && alpha.is_cuda() && q_bias.is_cuda(), "q_weight, alpha, and q_bias tensors must be on GPU.");

    // Check that all tensors are contiguous
    TORCH_CHECK(input.is_contiguous(), "input tensor must be contiguous.");
    TORCH_CHECK(output.is_contiguous(), "output tensor must be contiguous.");
    TORCH_CHECK(q_weight.is_contiguous(), "q_weight tensor must be contiguous.");
    TORCH_CHECK(alpha.is_contiguous(), "alpha tensor must be contiguous.");
    TORCH_CHECK(q_bias.is_contiguous(), "q_bias tensor must be contiguous.");

    uint32_t kSize = input.size(2);
    uint32_t mSize = output.size(2);
    uint32_t num_groups = kSize / group_size;

    // check that q_weight is of shape (input_feat / 32, bitwidth, output_feat)
    TORCH_CHECK(q_weight.dim() == 3 && q_weight.size(0) == kSize / 32 && q_weight.size(1) == bitwidth && q_weight.size(2) == mSize, "q_weight tensor must be of shape (input_feat / 32, bitwidth, output_feat). Expected (", kSize / 32, ", ", bitwidth, ", ", mSize, "), got (", q_weight.size(0), ", ", q_weight.size(1), ", ", q_weight.size(2), ").");
    // check that alpha is of shape (num_groups, bitwidth, mSize)
    TORCH_CHECK(alpha.dim() == 3 && alpha.size(0) == num_groups && alpha.size(1) == bitwidth && alpha.size(2) == mSize, 
                "alpha tensor must be of shape (num_groups, bitwidth, output_feat). Expected (", num_groups, ", ", bitwidth, ", ", mSize, "), got (", alpha.size(0), ", ", alpha.size(1), ", ", alpha.size(2), ").");

    auto dtype = input.scalar_type();
    if (dtype == at::kFloat) {
        lutgemm_gemv_templated<DataType::FP32>(input, q_weight, alpha, q_bias, output, bitwidth, group_size, stream);
    } else if (dtype == at::kHalf) {
        lutgemm_gemv_templated<DataType::FP16>(input, q_weight, alpha, q_bias, output, bitwidth, group_size, stream);
    } else if (dtype == at::kBFloat16) {
        lutgemm_gemv_templated<DataType::BF16>(input, q_weight, alpha, q_bias, output, bitwidth, group_size, stream);
    } else {
        TORCH_CHECK(false, "Unsupported data type.");
    }
}

/**
 * Convenience overload of lutgemm_gemv_stream that runs on PyTorch's current
 * CUDA stream. Arguments and errors are those of lutgemm_gemv_stream.
 */
void lutgemm_gemv(
    torch::Tensor input,
    torch::Tensor q_weight,
    torch::Tensor alpha,
    torch::Tensor q_bias,
    torch::Tensor output,
    int bitwidth,
    int group_size
) {
    const cudaStream_t current_stream = at::cuda::getCurrentCUDAStream();
    lutgemm_gemv_stream(input, q_weight, alpha, q_bias, output,
                        bitwidth, group_size, current_stream);
}

////////////////////////////////////////////////////////////////////////////////
//                                     SQLLM
////////////////////////////////////////////////////////////////////////////////

/**
 * Launches the SqueezeLLM per-channel non-uniform GEMV kernel for bitwidth 3
 * or 4. Any other bitwidth launches nothing; sqllm_gemv_stream only admits 3
 * and 4.
 *
 * Launch layout: SQLLM_BLOCKWIDTH threads per block. grid.x covers the rows of
 * packed 32-bit words, height / 32 * bitwidth of them, in chunks of
 * SQLLM_BLOCKHEIGHT3 (3-bit) or SQLLM_BLOCKHEIGHT4 (4-bit). grid.y covers the
 * output columns in chunks of SQLLM_BLOCKWIDTH.
 *
 * @param input    tensor of dtype DT; height = input.size(2).
 * @param qweight  packed int32 weights; width = qweight.size(1).
 * @param lut      per-output lookup table of dtype DT.
 * @param output   tensor of dtype DT receiving the result.
 * @param bitwidth 3 or 4.
 * @param stream   CUDA stream for the launch.
 * @throws c10::Error if the launch reports a CUDA error.
 */
template<DataType DT>
void sqllm_gemv_templated(
    torch::Tensor input,
    torch::Tensor qweight,
    torch::Tensor lut,
    torch::Tensor output,
    int bitwidth,
    cudaStream_t stream
) {
    const uint32_t height = input.size(2);
    const uint32_t width = qweight.size(1);
    const uint32_t packed_rows = height / 32 * bitwidth;
    const uint32_t rows_per_block = (bitwidth == 3) ? SQLLM_BLOCKHEIGHT3 : SQLLM_BLOCKHEIGHT4;

    const dim3 threads(SQLLM_BLOCKWIDTH);
    const dim3 grid(
        (packed_rows + rows_per_block - 1) / rows_per_block,
        (width + SQLLM_BLOCKWIDTH - 1) / SQLLM_BLOCKWIDTH
    );

    switch (bitwidth) {
        case 3:
            VecQuant3MatMulKernelNUQPerChannel<DT><<<grid, threads, 0, stream>>>(
                (FP_DTYPE(DT)*)input.data_ptr<ATEN_DTYPE(DT)>(),
                (int*)qweight.data_ptr<int>(),
                (FP_DTYPE(DT)*)output.data_ptr<ATEN_DTYPE(DT)>(),
                (FP_DTYPE(DT)*)lut.data_ptr<ATEN_DTYPE(DT)>(),
                static_cast<int>(height), static_cast<int>(width)
            );
            break;
        case 4:
            VecQuant4MatMulKernelNUQPerChannel<DT><<<grid, threads, 0, stream>>>(
                (FP_DTYPE(DT)*)input.data_ptr<ATEN_DTYPE(DT)>(),
                (int*)qweight.data_ptr<int>(),
                (FP_DTYPE(DT)*)output.data_ptr<ATEN_DTYPE(DT)>(),
                (FP_DTYPE(DT)*)lut.data_ptr<ATEN_DTYPE(DT)>(),
                static_cast<int>(height), static_cast<int>(width)
            );
            break;
        default:
            // No kernel for this bitwidth; nothing is launched.
            break;
    }

    // Launch-configuration errors are only reported through cudaGetLastError.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess, "CUDA Error in sqllm_gemv: ", cudaGetErrorString(launch_status));
}

// Validates the arguments of the SqueezeLLM-style quantized GEMV and dispatches
// to sqllm_gemv_templated for the input's floating-point type, launching on `stream`.
//
// Enforced preconditions:
//   input   : (1, 1, input_feat), float32 / float16 / bfloat16, contiguous, CUDA
//   qweight : (input_feat * bitwidth / 32, output_feat), int32, contiguous, CUDA
//   lut     : (output_feat, 2 ** bitwidth), same dtype as input, contiguous, CUDA
//   output  : (1, 1, output_feat), same dtype as input, contiguous, CUDA
//   all four tensors on the same device; bitwidth is 3 or 4.
// Raises (TORCH_CHECK) on any violated precondition or an unsupported dtype.
void sqllm_gemv_stream(
    torch::Tensor input,
    torch::Tensor qweight,
    torch::Tensor lut,
    torch::Tensor output,
    int bitwidth,
    cudaStream_t stream
) {
    TORCH_CHECK(bitwidth == 3 || bitwidth == 4, "Bitwidth must be 3 or 4.");
    TORCH_CHECK(input.scalar_type() == lut.scalar_type() && input.scalar_type() == output.scalar_type(), 
                "Mismatched data types between input, lut, and output tensors.");
    TORCH_CHECK(qweight.scalar_type() == at::kInt, "qweight tensor must be of type int.");
    TORCH_CHECK(input.dim() == 3, "input tensor must be of shape (batch_size, seq_len, hidden_size).");
    TORCH_CHECK(output.dim() == 3, "output tensor must be of shape (batch_size, seq_len, hidden_size).");

    // Only allow single batch size and sequence length
    TORCH_CHECK(input.size(0) == 1, "Batch size must be 1 for input tensor.");
    TORCH_CHECK(input.size(1) == 1, "Sequence length must be 1 for input tensor.");
    TORCH_CHECK(output.size(0) == 1, "Batch size must be 1 for output tensor.");
    TORCH_CHECK(output.size(1) == 1, "Sequence length must be 1 for output tensor.");

    // Check that input and output are both on GPU
    TORCH_CHECK(input.is_cuda() && output.is_cuda(), "input and output tensors must be on GPU.");
    // The kernel dereferences qweight and lut on the device as well; a CPU tensor
    // here would hand a host pointer to the kernel (illegal address at run time).
    TORCH_CHECK(qweight.is_cuda() && lut.is_cuda(), "qweight and lut tensors must be on GPU.");
    TORCH_CHECK(input.device() == output.device() && input.device() == qweight.device() && input.device() == lut.device(),
                "input, qweight, lut, and output tensors must be on the same device.");

    // Check that lut is of shape (output_feat, 2 ** bitwidth)
    TORCH_CHECK(lut.dim() == 2 && lut.size(1) == (1 << bitwidth) && lut.size(0) == output.size(2),
    "lut tensor must be of shape (output_feat, 2 ** bitwidth). Expected (", output.size(2), ", ", 1 << bitwidth, "), got (", lut.size(0), ", ", lut.size(1), ").");

    // Check that qweight is of shape (input_feat * bitwidth / 32, output_feat)
    TORCH_CHECK(qweight.dim() == 2 && qweight.size(1) == lut.size(0) && qweight.size(0) == input.size(2) * bitwidth / 32,
    "qweight tensor must be of shape (input_feat * bitwidth / 32, output_feat). Expected (", input.size(2) * bitwidth / 32, ", ", lut.size(0), "), got (", qweight.size(0), ", ", qweight.size(1), ").");

    // Check that all tensors are contiguous (kernels index raw data pointers)
    TORCH_CHECK(input.is_contiguous(), "input tensor must be contiguous.");
    TORCH_CHECK(output.is_contiguous(), "output tensor must be contiguous.");
    TORCH_CHECK(qweight.is_contiguous(), "qweight tensor must be contiguous.");
    TORCH_CHECK(lut.is_contiguous(), "lut tensor must be contiguous.");

    // Dispatch on the floating-point element type shared by input, lut, and output.
    auto dtype = input.scalar_type();
    if (dtype == at::kFloat) {
        sqllm_gemv_templated<DataType::FP32>(input, qweight, lut, output, bitwidth, stream);
    } else if (dtype == at::kHalf) {
        sqllm_gemv_templated<DataType::FP16>(input, qweight, lut, output, bitwidth, stream);
    } else if (dtype == at::kBFloat16) {
        sqllm_gemv_templated<DataType::BF16>(input, qweight, lut, output, bitwidth, stream);
    } else {
        TORCH_CHECK(false, "Unsupported data type.");
    }
}

// Convenience entry point: runs sqllm_gemv_stream on PyTorch's current CUDA
// stream for the device that holds `input`.
//
// A device guard makes `input`'s device current for the duration of the call.
// This makes the stream lookup and the kernel launch target that GPU rather
// than whichever device happened to be current. Non-CUDA inputs skip the
// guard so that sqllm_gemv_stream can report its own validation error.
void sqllm_gemv(
    torch::Tensor input,
    torch::Tensor qweight,
    torch::Tensor lut,
    torch::Tensor output,
    int bitwidth
) {
    at::cuda::OptionalCUDAGuard device_guard;
    if (input.is_cuda()) {
        device_guard.set_device(input.device());
    }
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    sqllm_gemv_stream(input, qweight, lut, output, bitwidth, stream);
}