#ifndef DECDEC_CUH
#define DECDEC_CUH

#include <cassert>
#include <cstdlib>
#include <cuda_fp16.h>
#include <cstdio>
#include <ctime>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <fstream>

#include <torch/extension.h>
#include <cuda_runtime.h>

// ---------------------------------------------------------------------------
// Any-precision GEMV entry points (host-side declarations; bodies live in a
// .cu translation unit not visible here).
// NOTE(review): the descriptions below are inferred from parameter names only
// and must be confirmed against the definitions.
// ---------------------------------------------------------------------------

// GEMV over a quantized weight `qweight` decoded through one lookup table `lut`
// at the given `bitwidth`. Returns void, so `output` is a caller-provided
// tensor — presumably the destination buffer (confirm).
void anyprec_gemv(torch::Tensor input, torch::Tensor output,
                  torch::Tensor qweight, torch::Tensor lut, int bitwidth);

// Selective variant: takes four lookup tables `lut3`..`lut6` (presumably one
// per bitwidth 3..6 — confirm), a `bsel` tensor, and an opaque integer handle
// `sne` (presumably a pointer produced by create_streamNevent_full — confirm).
void anyprec_gemv_sel(torch::Tensor input, torch::Tensor output,
                      torch::Tensor qweight,
                      torch::Tensor lut3, torch::Tensor lut4,
                      torch::Tensor lut5, torch::Tensor lut6,
                      int bitwidth, torch::Tensor bsel, uintptr_t sne);

// Like anyprec_gemv_sel but with a single `lut` and no stream handle.
// NOTE(review): the "fake" suffix suggests a reference/ablation path — confirm.
void anyprec_gemv_sel_fake(torch::Tensor input, torch::Tensor output,
                           torch::Tensor qweight, torch::Tensor lut,
                           int bitwidth, torch::Tensor bsel);

// Variant that additionally receives `jl` and `res` tensors; their roles are
// not visible from this header (TODO confirm against the definition).
void anyprec_gemv_sel_two(torch::Tensor input, torch::Tensor output,
                          torch::Tensor qweight, torch::Tensor lut,
                          int bitwidth, torch::Tensor jl, torch::Tensor res);

// ---------------------------------------------------------------------------
// GEMV + norm + threshold selection routines. All eight variants share one
// signature: (input, jl, bsel, low, high, threshold, sne).
// NOTE(review): from the names, these appear to multiply `input` by `jl` and
// compare against `threshold` to fill `bsel` / `low` / `high` — inferred only,
// confirm against the .cu definitions.
// The suffixes q/k/v/g/u presumably pick a projection (query/key/value/gate/
// up), and qkv/gu the fused groups — also unconfirmed.
// `sne` is an opaque handle passed as an integer (presumably a pointer from
// create_streamNevent_full — confirm).
// ---------------------------------------------------------------------------

void gemvNormTH(torch::Tensor input, torch::Tensor jl, torch::Tensor bsel,
                torch::Tensor low, torch::Tensor high, float threshold,
                uintptr_t sne);

void gemvNormTHq(torch::Tensor input, torch::Tensor jl, torch::Tensor bsel,
                 torch::Tensor low, torch::Tensor high, float threshold,
                 uintptr_t sne);

void gemvNormTHk(torch::Tensor input, torch::Tensor jl, torch::Tensor bsel,
                 torch::Tensor low, torch::Tensor high, float threshold,
                 uintptr_t sne);

void gemvNormTHv(torch::Tensor input, torch::Tensor jl, torch::Tensor bsel,
                 torch::Tensor low, torch::Tensor high, float threshold,
                 uintptr_t sne);

void gemvNormTHg(torch::Tensor input, torch::Tensor jl, torch::Tensor bsel,
                 torch::Tensor low, torch::Tensor high, float threshold,
                 uintptr_t sne);

void gemvNormTHu(torch::Tensor input, torch::Tensor jl, torch::Tensor bsel,
                 torch::Tensor low, torch::Tensor high, float threshold,
                 uintptr_t sne);

void gemvNormTHqkv(torch::Tensor input, torch::Tensor jl, torch::Tensor bsel,
                   torch::Tensor low, torch::Tensor high, float threshold,
                   uintptr_t sne);

void gemvNormTHgu(torch::Tensor input, torch::Tensor jl, torch::Tensor bsel,
                  torch::Tensor low, torch::Tensor high, float threshold,
                  uintptr_t sne);

// ---------------------------------------------------------------------------
// Norm + threshold routines without a GEMV operand. All seven variants share
// one signature: (input, a, b, threshold, bsel, low, high, sne).
// NOTE(review): the meaning of the scalars `a` and `b` is not visible here;
// the suffixes presumably mirror the gemvNormTH* family (q/k/v/g/u, qkv, gu).
// Confirm against the definitions.
// ---------------------------------------------------------------------------

void normTHq(torch::Tensor input, float a, float b, float threshold,
             torch::Tensor bsel, torch::Tensor low, torch::Tensor high,
             uintptr_t sne);

void normTHk(torch::Tensor input, float a, float b, float threshold,
             torch::Tensor bsel, torch::Tensor low, torch::Tensor high,
             uintptr_t sne);

void normTHv(torch::Tensor input, float a, float b, float threshold,
             torch::Tensor bsel, torch::Tensor low, torch::Tensor high,
             uintptr_t sne);

void normTHg(torch::Tensor input, float a, float b, float threshold,
             torch::Tensor bsel, torch::Tensor low, torch::Tensor high,
             uintptr_t sne);

void normTHu(torch::Tensor input, float a, float b, float threshold,
             torch::Tensor bsel, torch::Tensor low, torch::Tensor high,
             uintptr_t sne);

void normTHqkv(torch::Tensor input, float a, float b, float threshold,
               torch::Tensor bsel, torch::Tensor low, torch::Tensor high,
               uintptr_t sne);

void normTHgu(torch::Tensor input, float a, float b, float threshold,
              torch::Tensor bsel, torch::Tensor low, torch::Tensor high,
              uintptr_t sne);


// ---------------------------------------------------------------------------
// Multi-target GEMV + norm + threshold variants. Each numbered target k has
// its own jl<k>, bsel<k> and threshold<k>; all variants share one `input` and
// one `sne` handle.
// NOTE(review): per-parameter semantics are inferred from names only — confirm.
// ---------------------------------------------------------------------------

// Two targets.
void gemvNormTH2(torch::Tensor input,
                 torch::Tensor jl1, torch::Tensor jl2,
                 torch::Tensor bsel1, torch::Tensor bsel2,
                 float threshold1, float threshold2,
                 uintptr_t sne);

// Three targets.
void gemvNormTH3(torch::Tensor input,
                 torch::Tensor jl1, torch::Tensor jl2, torch::Tensor jl3,
                 torch::Tensor bsel1, torch::Tensor bsel2, torch::Tensor bsel3,
                 float threshold1, float threshold2, float threshold3,
                 uintptr_t sne);

// Three targets, each with an extra `res<k>` tensor whose role is not visible
// from this header (TODO confirm).
void gemvNormTH3Full(torch::Tensor input,
                     torch::Tensor jl1, torch::Tensor jl2, torch::Tensor jl3,
                     torch::Tensor res1, torch::Tensor res2, torch::Tensor res3,
                     torch::Tensor bsel1, torch::Tensor bsel2,
                     torch::Tensor bsel3,
                     float threshold1, float threshold2, float threshold3,
                     uintptr_t sne);

// ---------------------------------------------------------------------------
// Generic norm + threshold routines, plus variants with an "ln" prefix that
// also take `normW` and `res` tensors.
// NOTE(review): "ln" presumably means a layer-norm step with weight `normW`
// and residual `res` — inferred from names only, confirm against definitions.
// ---------------------------------------------------------------------------

// Single target: scalars a/b and a threshold, result selector `bsel`.
void normTH(torch::Tensor input, float a, float b, float threshold,
            torch::Tensor bsel, uintptr_t sne);

// Two targets. Note the parameter order: a1, a2, then b1, b2.
void normTH2(torch::Tensor input,
             float a1, float a2,
             float b1, float b2,
             float threshold1, float threshold2,
             torch::Tensor bsel1, torch::Tensor bsel2,
             uintptr_t sne);

// Two targets, preceded by the `normW` / `res` inputs.
void lnNormTH2(torch::Tensor input, torch::Tensor normW, torch::Tensor res,
               float a1, float a2,
               float b1, float b2,
               float threshold1, float threshold2,
               torch::Tensor bsel1, torch::Tensor bsel2,
               uintptr_t sne);

// Single target that takes a `jl` tensor as well as `normW` / `res`.
void lnGemvNormTH(torch::Tensor input, torch::Tensor normW, torch::Tensor res,
                  torch::Tensor jl, torch::Tensor bsel, float threshold,
                  uintptr_t sne);

// Takes only the opaque `sne` handle; its effect is not visible from this
// header (TODO confirm against the definition).
void fakeTrigger(uintptr_t sne);

// Returns an integer-encoded handle of the kind the `sne` parameters above
// accept. NOTE(review): presumably a heap-allocated StreamNevent_full cast to
// uintptr_t — confirm, including who is responsible for releasing it.
uintptr_t create_streamNevent_full();

// Decodes `qweight` through `lut` at `bitwidth`. It is the only declaration
// here that returns a tensor; the result presumably holds the dequantized
// weights (confirm).
torch::Tensor anyprec_dequant(torch::Tensor qweight, torch::Tensor lut,
                              int bitwidth);

// LUT-GEMM-style GEMV. Note that `output` comes fifth in this signature, after
// the quantization tensors `q_weight` / `alpha` / `q_bias`.
// NOTE(review): `group_size` presumably sets the quantization group length —
// confirm.
void lutgemm_gemv(torch::Tensor input, torch::Tensor q_weight,
                  torch::Tensor alpha, torch::Tensor q_bias,
                  torch::Tensor output, int bitwidth, int group_size);

// SqueezeLLM-style GEMV (inferred from the name). Here `output` comes after
// `qweight` and `lut`.
void sqllm_gemv(torch::Tensor input, torch::Tensor qweight, torch::Tensor lut,
                torch::Tensor output, int bitwidth);


// Bundles a CUDA stream handle with a pair of CUDA event handles.
// NOTE(review): by name, the events presumably bracket work on `sub_stream`
// so it can be ordered against another stream — confirm.
// The constructor and destructor are defined out of view. The destructor
// presumably releases the handles (confirm).
//
// Copying is disabled. The class has a user-declared destructor and holds raw
// handles. An implicit copy would share the same stream/events, and if the
// destructor releases them, destroying both copies would release them twice.
class StreamNevent {
public:
    cudaStream_t sub_stream;
    cudaEvent_t start_event;
    cudaEvent_t end_event;

    StreamNevent (
        cudaStream_t _sub_stream,
        cudaEvent_t _start_event,
        cudaEvent_t _end_event
    );

    ~StreamNevent();

    // Non-copyable (see class comment). Declaring these also suppresses the
    // implicit move operations.
    StreamNevent(const StreamNevent&) = delete;
    StreamNevent& operator=(const StreamNevent&) = delete;
};

// Bundles a CUDA stream handle with four CUDA event handles: start, end, and
// two intermediate events (mid1, mid2).
// NOTE(review): presumably this is the object behind the integer handle
// returned by create_streamNevent_full and passed around as `sne` — confirm.
// The constructor and destructor are defined out of view. The destructor
// presumably releases the handles (confirm).
//
// Copying is disabled. The class has a user-declared destructor and holds raw
// handles. An implicit copy would share the same stream/events, and if the
// destructor releases them, destroying both copies would release them twice.
class StreamNevent_full {
public:
    cudaStream_t sub_stream;
    cudaEvent_t start_event;
    cudaEvent_t end_event;
    cudaEvent_t mid1_event;
    cudaEvent_t mid2_event;

    StreamNevent_full (
        cudaStream_t _sub_stream,
        cudaEvent_t _start_event,
        cudaEvent_t _end_event,
        cudaEvent_t _mid1_event,
        cudaEvent_t _mid2_event
    );

    ~StreamNevent_full();

    // Non-copyable (see class comment). Declaring these also suppresses the
    // implicit move operations.
    StreamNevent_full(const StreamNevent_full&) = delete;
    StreamNevent_full& operator=(const StreamNevent_full&) = delete;
};

#endif // DECDEC_CUH
