#include <torch/extension.h>

// Double quantization kernels
// adjust_scale_dq<D>: declaration only. No definition is visible in this file,
// so the instantiation registered in PYBIND11_MODULE (D = 128) has to be
// provided by another translation unit.
// Takes the codebook as an (idx, scale, offset) triple, where adjust_scale
// takes a single `codebook` tensor.
// NOTE(review): PACKED is not defined in this file; the shape annotations are
// the original authors' notes and should be confirmed against the kernel.
template <int D>
torch::Tensor adjust_scale_dq(
    torch::Tensor packed, // (B, H, L, D//PACKED)
    torch::Tensor codebook_idx,
    torch::Tensor codebook_scale,
    torch::Tensor codebook_offset,
    torch::Tensor norm, // (B, H, L, 1)
    torch::Tensor orig // (B, H, L, D)
);

// adjust_scale_dq_1bit<D>: declaration only; defined outside this file and
// registered in PYBIND11_MODULE for D = 128.
// Parameter list is identical to adjust_scale_dq; presumably the 1-bit
// quantization variant, going by the name -- confirm against the kernel.
// NOTE(review): PACKED is not defined in this file.
template <int D>
torch::Tensor adjust_scale_dq_1bit(
    torch::Tensor packed, // (B, H, L, D//PACKED)
    torch::Tensor codebook_idx,
    torch::Tensor codebook_scale,
    torch::Tensor codebook_offset,
    torch::Tensor norm, // (B, H, L, 1)
    torch::Tensor orig // (B, H, L, D)
);

// restore_quantized_dq<D, window_size>: declaration only; defined outside this
// file. Registered in PYBIND11_MODULE for D = 128 with window_size 64, 32, 128.
// Same inputs as restore_quantized, except that the single `codebook` tensor
// is replaced by the codebook_idx / codebook_scale / codebook_offset triple.
// NOTE(review): tensor shapes and dtypes are not annotated here; presumably
// they follow the annotations on the quantized_* declarations -- confirm.
template<int D, int window_size>
torch::Tensor restore_quantized_dq(
    torch::Tensor A_in,
    torch::Tensor norm,
    torch::Tensor mean,
    torch::Tensor norm2,
    torch::Tensor codebook_idx,
    torch::Tensor codebook_scale,
    torch::Tensor codebook_offset
);

// restore_quantized_dq_1bit<D, window_size>: declaration only; defined outside
// this file. Registered in PYBIND11_MODULE for D = 128 with window_size 64,
// 32, 128. Parameter list is identical to restore_quantized_dq.
// NOTE(review): tensor shapes and dtypes are not annotated here -- confirm
// against the kernel implementation.
template<int D, int window_size>
torch::Tensor restore_quantized_dq_1bit(
    torch::Tensor A_in,
    torch::Tensor norm,
    torch::Tensor mean,
    torch::Tensor norm2,
    torch::Tensor codebook_idx,
    torch::Tensor codebook_scale,
    torch::Tensor codebook_offset
);

// dist_argmin_half_packed_dq_1bit: declaration only; defined outside this file
// and registered in PYBIND11_MODULE under the same name.
// Takes B as an (idx, scale, offset) triple, where
// dist_argmin_half_packed_1bit takes a single tensor B.
// NOTE(review): shapes/dtypes are not annotated -- confirm against the kernel.
torch::Tensor dist_argmin_half_packed_dq_1bit(
    torch::Tensor A,
    torch::Tensor B_idx,
    torch::Tensor B_scale,
    torch::Tensor B_offset
);

// dist_argmin_half_packed_dq: declaration only; defined outside this file and
// registered in PYBIND11_MODULE under the same name.
// Takes B as an (idx, scale, offset) triple, where dist_argmin_half_packed
// takes a single tensor B.
// NOTE(review): shapes/dtypes are not annotated -- confirm against the kernel.
torch::Tensor dist_argmin_half_packed_dq(
    torch::Tensor A,
    torch::Tensor B_idx,
    torch::Tensor B_scale,
    torch::Tensor B_offset
);

// quantized_dot_product_fused_dq<D, window_size>: declaration only; defined
// outside this file. Registered in PYBIND11_MODULE for D = 128 with
// window_size 64, 32, 128.
// Compared with quantized_dot_product_fused, `norm`, `mean` and `codebook` are
// each passed as an (idx, scale, offset) triple.
// NOTE(review): DQ_PACK is not defined in this file. `packed` is annotated
// int64 here but int32 on quantized_dot_product_fused_residual_dq -- confirm
// which is correct. `offset` is unannotated here; the residual variant
// annotates it as (B), int. `nh` presumably is the number of query heads per
// KV head, given the (B, H * nh, 1, D) query annotation -- confirm.
template <int D, int window_size>
torch::Tensor quantized_dot_product_fused_dq(
    torch::Tensor packed,      // (B,H,L,D//32), int64
    torch::Tensor norm_idx,      // (B,H,L//window_size,window_size//DQ_PACK),    int32
    torch::Tensor norm_scale,      // (B,H,L//window_size,1),    half
    torch::Tensor norm_offset,      // (B,H,L//window_size,1),    half
    torch::Tensor mean_idx,      // (B,H, L//window_size,1,D//DQ_PACK), int32
    torch::Tensor mean_scale,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor mean_offset,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor codebook_idx,   // (256,1),      int32
    torch::Tensor codebook_scale,   // (256,1),      half
    torch::Tensor codebook_offset,   // (256,1),      half
    torch::Tensor inv_freq, // (D//2), half
    torch::Tensor offset,
    torch::Tensor query_had, // (B, H * nh, 1, D) half
    torch::Tensor query, // (B, H * nh, 1, D)
    int nh
);

// quantized_dot_product_fused_dq_1bit<D, window_size>: declaration only;
// defined outside this file. Registered in PYBIND11_MODULE for D = 128 with
// window_size 64, 32, 128. Parameter list is identical to
// quantized_dot_product_fused_dq.
// NOTE(review): DQ_PACK is not defined in this file. `packed` is annotated
// int64 here but int32 on the *_residual_dq_1bit declaration -- confirm.
// `offset` is unannotated here; the residual variant annotates it as (B), int.
template <int D, int window_size>
torch::Tensor quantized_dot_product_fused_dq_1bit(
    torch::Tensor packed,      // (B,H,L,D//32), int64
    torch::Tensor norm_idx,      // (B,H,L//window_size,window_size//DQ_PACK),    int32
    torch::Tensor norm_scale,      // (B,H,L//window_size,1),    half
    torch::Tensor norm_offset,      // (B,H,L//window_size,1),    half
    torch::Tensor mean_idx,      // (B,H, L//window_size,1,D//DQ_PACK), int32
    torch::Tensor mean_scale,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor mean_offset,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor codebook_idx,   // (256,1),      int32
    torch::Tensor codebook_scale,   // (256,1),      half
    torch::Tensor codebook_offset,   // (256,1),      half
    torch::Tensor inv_freq, // (D//2), half
    torch::Tensor offset,
    torch::Tensor query_had, // (B, H * nh, 1, D) half
    torch::Tensor query, // (B, H * nh, 1, D)
    int nh
);

// quantized_dot_product_fused_residual_dq<D, window_size>: declaration only;
// defined outside this file. Registered in PYBIND11_MODULE for D = 128 with
// window_size 64, 32, 128.
// Extends the quantized_dot_product_fused_dq parameter list with
// residual_cos, residual_sin and residual (inserted before query_had).
// NOTE(review): DQ_PACK is not defined in this file. `packed` is annotated
// int32 here but int64 on quantized_dot_product_fused_dq -- confirm which is
// correct. R in the residual shapes is not defined in this file; presumably
// the number of unquantized residual positions -- confirm.
template <int D, int window_size>
torch::Tensor quantized_dot_product_fused_residual_dq(
    torch::Tensor packed,      // (B,H,L,D//32), int32
    torch::Tensor norm_idx,      // (B,H,L//window_size,window_size//DQ_PACK),    int32
    torch::Tensor norm_scale,      // (B,H,L//window_size,1),    half
    torch::Tensor norm_offset,      // (B,H,L//window_size,1),    half
    torch::Tensor mean_idx,      // (B,H, L//window_size,1,D//DQ_PACK), int32
    torch::Tensor mean_scale,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor mean_offset,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor codebook_idx,   // (256,1),      int32_t,
    torch::Tensor codebook_scale,   // (256,1),      half,
    torch::Tensor codebook_offset,   // (256,1),      half,
    torch::Tensor inv_freq, // (D//2), half
    torch::Tensor offset, // (B), int
    torch::Tensor residual_cos, // (B, R, D) half
    torch::Tensor residual_sin, // (B, R, D) half 
    torch::Tensor residual, // (B, H, R, D) half
    torch::Tensor query_had, // (B, H * nh, 1, D) half
    torch::Tensor query, // (B, H * nh, 1, D)
    int nh
);

// quantized_dot_product_fused_residual_dq_1bit<D, window_size>: declaration
// only; defined outside this file. Registered in PYBIND11_MODULE for D = 128
// with window_size 64, 32, 128. Parameter list is identical to
// quantized_dot_product_fused_residual_dq.
// NOTE(review): DQ_PACK and R are not defined in this file. `packed` is
// annotated int32 here but int64 on quantized_dot_product_fused_dq_1bit --
// confirm which is correct.
template <int D, int window_size>
torch::Tensor quantized_dot_product_fused_residual_dq_1bit(
    torch::Tensor packed,      // (B,H,L,D//32), int32
    torch::Tensor norm_idx,      // (B,H,L//window_size,window_size//DQ_PACK),    int32
    torch::Tensor norm_scale,      // (B,H,L//window_size,1),    half
    torch::Tensor norm_offset,      // (B,H,L//window_size,1),    half
    torch::Tensor mean_idx,      // (B,H, L//window_size,1,D//DQ_PACK), int32
    torch::Tensor mean_scale,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor mean_offset,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor codebook_idx,   // (256,1),      int32_t,
    torch::Tensor codebook_scale,   // (256,1),      half,
    torch::Tensor codebook_offset,   // (256,1),      half,
    torch::Tensor inv_freq, // (D//2), half
    torch::Tensor offset, // (B), int
    torch::Tensor residual_cos, // (B, R, D) half
    torch::Tensor residual_sin, // (B, R, D) half 
    torch::Tensor residual, // (B, H, R, D) half
    torch::Tensor query_had, // (B, H * nh, 1, D) half
    torch::Tensor query, // (B, H * nh, 1, D)
    int nh
);

// quantized_weighted_sum_residual_dq<D, window_size>: declaration only;
// defined outside this file. Registered in PYBIND11_MODULE for D = 128 with
// window_size 64, 32, 128.
// Compared with quantized_weighted_sum_residual, `norm`, `mean` and `codebook`
// are each passed as an (idx, scale, offset) triple.
// NOTE(review): DQ_PACK and R are not defined in this file; `weight` is
// annotated with a trailing dimension of L+R, i.e. presumably covering both
// the quantized and the residual positions -- confirm.
template <int D, int window_size>
torch::Tensor quantized_weighted_sum_residual_dq(
    torch::Tensor packed,      // (B,H,L,D//32), int64
    torch::Tensor norm_idx,      // (B,H,L//window_size,window_size//DQ_PACK),    int32
    torch::Tensor norm_scale,      // (B,H,L//window_size,1),    half
    torch::Tensor norm_offset,      // (B,H,L//window_size,1),    half
    torch::Tensor mean_idx,      // (B,H, L//window_size,1,D//DQ_PACK), int32
    torch::Tensor mean_scale,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor mean_offset,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor residual, // (B,H,R,D), half
    torch::Tensor codebook_idx,   // (256,1),      int32_t,
    torch::Tensor codebook_scale,   // (256,1),      half,
    torch::Tensor codebook_offset,   // (256,1),      half,
    torch::Tensor weight, // (B, H * nh, 1, L+R)
    int nh
);

// quantized_weighted_sum_residual_dq_1bit<D, window_size>: declaration only;
// defined outside this file. Registered in PYBIND11_MODULE for D = 128 with
// window_size 64, 32, 128. Parameter list is identical to
// quantized_weighted_sum_residual_dq.
// NOTE(review): DQ_PACK and R are not defined in this file -- confirm the
// shape annotations against the kernel implementation.
template <int D, int window_size>
torch::Tensor quantized_weighted_sum_residual_dq_1bit(
    torch::Tensor packed,      // (B,H,L,D//32), int64
    torch::Tensor norm_idx,      // (B,H,L//window_size,window_size//DQ_PACK),    int32
    torch::Tensor norm_scale,      // (B,H,L//window_size,1),    half
    torch::Tensor norm_offset,      // (B,H,L//window_size,1),    half
    torch::Tensor mean_idx,      // (B,H, L//window_size,1,D//DQ_PACK), int32
    torch::Tensor mean_scale,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor mean_offset,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor residual, // (B,H,R,D), half
    torch::Tensor codebook_idx,   // (256,1),      int32_t,
    torch::Tensor codebook_scale,   // (256,1),      half,
    torch::Tensor codebook_offset,   // (256,1),      half,
    torch::Tensor weight, // (B, H * nh, 1, L+R)
    int nh
);


// quantized_weighted_sum_dq<D, window_size>: declaration only; defined outside
// this file. Registered in PYBIND11_MODULE for D = 128 with window_size 64,
// 32, 128.
// Compared with quantized_weighted_sum, `norm`, `mean` and `codebook` are each
// passed as an (idx, scale, offset) triple.
// NOTE(review): DQ_PACK is not defined in this file. `weight` is annotated as
// 3-D (B, H * nh, L) here but 4-D (B, H * nh, 1, L+R) on the residual
// variant -- confirm the expected rank.
template <int D, int window_size>
torch::Tensor quantized_weighted_sum_dq(
    torch::Tensor packed,      // (B,H,L,D//32), int64
    torch::Tensor norm_idx,      // (B,H,L//window_size,window_size//DQ_PACK),    int32
    torch::Tensor norm_scale,      // (B,H,L//window_size,1),    half
    torch::Tensor norm_offset,      // (B,H,L//window_size,1),    half
    torch::Tensor mean_idx,      // (B,H, L//window_size,1,D//DQ_PACK), int32
    torch::Tensor mean_scale,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor mean_offset,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor codebook_idx,   // (256,1),      int32_t,
    torch::Tensor codebook_scale,   // (256,1),      half,
    torch::Tensor codebook_offset,   // (256,1),      half,
    torch::Tensor weight, // (B, H * nh, L)
    int nh
);

// quantized_weighted_sum_dq_1bit<D, window_size>: declaration only; defined
// outside this file. Registered in PYBIND11_MODULE for D = 128 with
// window_size 64, 32, 128. Parameter list is identical to
// quantized_weighted_sum_dq.
// NOTE(review): DQ_PACK is not defined in this file -- confirm the shape
// annotations against the kernel implementation.
template <int D, int window_size>
torch::Tensor quantized_weighted_sum_dq_1bit(
    torch::Tensor packed,      // (B,H,L,D//32), int64
    torch::Tensor norm_idx,      // (B,H,L//window_size,window_size//DQ_PACK),    int32
    torch::Tensor norm_scale,      // (B,H,L//window_size,1),    half
    torch::Tensor norm_offset,      // (B,H,L//window_size,1),    half
    torch::Tensor mean_idx,      // (B,H, L//window_size,1,D//DQ_PACK), int32
    torch::Tensor mean_scale,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor mean_offset,      // (B,H, L//window_size,1,D//32), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor codebook_idx,   // (256,1),      int32_t,
    torch::Tensor codebook_scale,   // (256,1),      half,
    torch::Tensor codebook_offset,   // (256,1),      half,
    torch::Tensor weight, // (B, H * nh, L)
    int nh
);

// Other kernels
// dist_argmin_half_batched<N>: declaration only; defined outside this file.
// Registered in PYBIND11_MODULE for N = 4, 8, 9 and 10 under the names
// dist_argmin_half_batched_d4 / _d8 / _d9 / _d10.
// NOTE(review): shapes/dtypes of A and B are not annotated; presumably half
// precision given the name -- confirm against the kernel.
template <int N>
torch::Tensor dist_argmin_half_batched(
    torch::Tensor A,
    torch::Tensor B
);

// dist_argmin_half: declaration only; defined outside this file and
// registered in PYBIND11_MODULE under the same name.
// NOTE(review): shapes/dtypes of A and B are not annotated -- confirm against
// the kernel implementation.
torch::Tensor dist_argmin_half(
    torch::Tensor A,
    torch::Tensor B
);

// dist_argmin_half_packed: declaration only; defined outside this file and
// registered in PYBIND11_MODULE under the same name.
// NOTE(review): shapes/dtypes of A and B are not annotated -- confirm against
// the kernel implementation.
torch::Tensor dist_argmin_half_packed(
    torch::Tensor A,
    torch::Tensor B
);

// dist_argmin_half_packed_1bit: declaration only; defined outside this file
// and registered in PYBIND11_MODULE under the same name. Same parameter list
// as dist_argmin_half_packed.
// NOTE(review): shapes/dtypes of A and B are not annotated -- confirm against
// the kernel implementation.
torch::Tensor dist_argmin_half_packed_1bit(
    torch::Tensor A,
    torch::Tensor B
);

// restore_quantized<D, window_size>: declaration only; defined outside this
// file. Registered in PYBIND11_MODULE for D = 128 with window_size 64, 32, 128.
// NOTE(review): tensor shapes and dtypes are not annotated here; presumably
// they match the norm / mean / norm2 / codebook annotations on the
// quantized_* declarations -- confirm.
template<int D, int window_size>
torch::Tensor restore_quantized(
    torch::Tensor A_in,
    torch::Tensor norm,
    torch::Tensor mean,
    torch::Tensor norm2,
    torch::Tensor codebook
);

// restore_quantized_1bit<D, window_size>: declaration only; defined outside
// this file. Registered in PYBIND11_MODULE for D = 128 with window_size 64,
// 32, 128. Parameter list is identical to restore_quantized.
// NOTE(review): tensor shapes and dtypes are not annotated here -- confirm
// against the kernel implementation.
template<int D, int window_size>
torch::Tensor restore_quantized_1bit(
    torch::Tensor A_in,
    torch::Tensor norm,
    torch::Tensor mean,
    torch::Tensor norm2,
    torch::Tensor codebook
);


// quantized_weighted_sum<D, window_size>: declaration only; defined outside
// this file. Registered in PYBIND11_MODULE for D = 128 with window_size 64,
// 32, 128.
// NOTE(review): shape annotations are the original authors' notes. `nh`
// presumably is the number of query heads per KV head, given the
// (B, H * nh, L) weight annotation -- confirm against the kernel.
template <int D, int window_size>
torch::Tensor quantized_weighted_sum(
    torch::Tensor packed,      // (B,H,L,D//32), int64
    torch::Tensor norm,      // (B,H,L,1),    half
    torch::Tensor mean,      // (B,H, L//window_size,1,D), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor codebook,   // (256,8),      half,
    torch::Tensor weight, // (B, H * nh, L)
    int nh
);

// quantized_weighted_sum_1bit<D, window_size>: declaration only; defined
// outside this file. Registered in PYBIND11_MODULE for D = 128 with
// window_size 64, 32, 128. Parameter list is identical to
// quantized_weighted_sum.
// NOTE(review): the annotations are copied verbatim from the non-1bit
// declaration; confirm that the packed/codebook shapes also hold for the
// 1-bit kernel.
template <int D, int window_size>
torch::Tensor quantized_weighted_sum_1bit(
    torch::Tensor packed,      // (B,H,L,D//32), int64
    torch::Tensor norm,      // (B,H,L,1),    half
    torch::Tensor mean,      // (B,H, L//window_size,1,D), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor codebook,   // (256,8),      half,
    torch::Tensor weight, // (B, H * nh, L)
    int nh
);

// quantized_dot_product<D, window_size>: declaration only; defined outside
// this file. Registered in PYBIND11_MODULE for D = 128 with window_size 64,
// 32, 128.
// Unlike most kernels declared here, no *_1bit counterpart of this function is
// declared or registered in this file.
// NOTE(review): shape annotations are the original authors' notes -- confirm
// against the kernel implementation.
template <int D, int window_size>
torch::Tensor quantized_dot_product(
    torch::Tensor packed,      // (B,H,L,D//32), int64
    torch::Tensor norm,      // (B,H,L,1),    half
    torch::Tensor mean,      // (B,H, L//window_size,1,D), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor codebook,   // (256,8),      half,
    torch::Tensor query, // (B, H * nh, 1, D) half
    int nh
);

// quantized_dot_product_fused<D, window_size>: declaration only; defined
// outside this file. Registered in PYBIND11_MODULE for D = 128 with
// window_size 64, 32, 128.
// Extends the quantized_dot_product parameter list with inv_freq, offset and
// query_had.
// NOTE(review): `offset` is unannotated here; the residual variant annotates
// it as (B), int. inv_freq/offset presumably feed a rotary position embedding
// and query_had presumably is a Hadamard-transformed query -- inferred from
// names only, confirm against the kernel.
template <int D, int window_size>
torch::Tensor quantized_dot_product_fused(
    torch::Tensor packed,      // (B,H,L,D//32), int64
    torch::Tensor norm,      // (B,H,L,1),    half
    torch::Tensor mean,      // (B,H, L//window_size,1,D), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor codebook,   // (256,8),      half,
    torch::Tensor inv_freq, // (D//2), half
    torch::Tensor offset,
    torch::Tensor query_had, // (B, H * nh, 1, D) half
    torch::Tensor query, // (B, H * nh, 1, D)
    int nh
);

// quantized_dot_product_fused_1bit<D, window_size>: declaration only; defined
// outside this file. Registered in PYBIND11_MODULE for D = 128 with
// window_size 64, 32, 128. Parameter list is identical to
// quantized_dot_product_fused.
// NOTE(review): `offset` is unannotated here; the residual variant annotates
// it as (B), int. Confirm that the packed/codebook annotations, which match
// the non-1bit declaration verbatim, also hold for the 1-bit kernel.
template <int D, int window_size>
torch::Tensor quantized_dot_product_fused_1bit(
    torch::Tensor packed,      // (B,H,L,D//32), int64
    torch::Tensor norm,      // (B,H,L,1),    half
    torch::Tensor mean,      // (B,H, L//window_size,1,D), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor codebook,   // (256,8),      half,
    torch::Tensor inv_freq, // (D//2), half
    torch::Tensor offset,
    torch::Tensor query_had, // (B, H * nh, 1, D) half
    torch::Tensor query, // (B, H * nh, 1, D)
    int nh
);

// window_rope_dot_product<D, window_size>: declaration only; defined outside
// this file. Registered in PYBIND11_MODULE for D = 128 with window_size 64,
// 32, 128.
// Note that `mean` precedes `norm` here, the reverse of the other
// declarations in this file.
// NOTE(review): `query` is annotated with a trailing dimension of L, whereas
// every other declaration annotates query as (B, H * nh, 1, D) -- confirm
// whether this is intentional or a stale annotation.
template <int D, int window_size>
torch::Tensor window_rope_dot_product(
    torch::Tensor mean,      // (B,H, L//window_size,1,D), half
    torch::Tensor norm,      // (B,H,L,1),    half
    torch::Tensor query, // (B, H * nh, 1, L)
    torch::Tensor inv_freq, // (D//2)
    torch::Tensor offset, // (B)
    int nh
);

// adjust_scale<D>: declaration only; defined outside this file and registered
// in PYBIND11_MODULE for D = 128.
// NOTE(review): PACKED is not defined in this file and `codebook` carries no
// shape annotation here (other declarations annotate it as (256,8), half) --
// confirm against the kernel implementation.
template <int D>
torch::Tensor adjust_scale(
    torch::Tensor packed, // (B, H, L, D//PACKED)
    torch::Tensor codebook,
    torch::Tensor norm, // (B, H, L, 1)
    torch::Tensor orig // (B, H, L, D)
);

// adjust_scale_1bit<D>: declaration only; defined outside this file and
// registered in PYBIND11_MODULE for D = 128. Parameter list is identical to
// adjust_scale.
// NOTE(review): PACKED is not defined in this file and `codebook` carries no
// shape annotation -- confirm against the kernel implementation.
template <int D>
torch::Tensor adjust_scale_1bit(
    torch::Tensor packed, // (B, H, L, D//PACKED)
    torch::Tensor codebook,
    torch::Tensor norm, // (B, H, L, 1)
    torch::Tensor orig // (B, H, L, D)
);


// quantized_dot_product_fused_residual<D, window_size>: declaration only;
// defined outside this file. Registered in PYBIND11_MODULE for D = 128 with
// window_size 64, 32, 128.
// Extends the quantized_dot_product_fused parameter list with residual_cos,
// residual_sin and residual (inserted before query_had).
// NOTE(review): `packed` is annotated int32 here but int64 on
// quantized_dot_product_fused -- confirm which is correct. R is not defined
// in this file; presumably the number of unquantized residual positions.
template <int D, int window_size>
torch::Tensor quantized_dot_product_fused_residual(
    torch::Tensor packed,      // (B,H,L,D//32), int32
    torch::Tensor norm,      // (B,H,L,1),    half
    torch::Tensor mean,      // (B,H, L//window_size,1,D), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor codebook,   // (256,8),      half,
    torch::Tensor inv_freq, // (D//2), half
    torch::Tensor offset, // (B), int
    torch::Tensor residual_cos, // (B, R, D) half
    torch::Tensor residual_sin, // (B, R, D) half 
    torch::Tensor residual, // (B, H, R, D) half
    torch::Tensor query_had, // (B, H * nh, 1, D) half
    torch::Tensor query, // (B, H * nh, 1, D)
    int nh
);

// quantized_dot_product_fused_residual_1bit<D, window_size>: declaration only;
// defined outside this file. Registered in PYBIND11_MODULE for D = 128 with
// window_size 64, 32, 128. Parameter list is identical to
// quantized_dot_product_fused_residual.
// NOTE(review): `packed` is annotated int32 here but int64 on
// quantized_dot_product_fused_1bit -- confirm which is correct. R is not
// defined in this file.
template <int D, int window_size>
torch::Tensor quantized_dot_product_fused_residual_1bit(
    torch::Tensor packed,      // (B,H,L,D//32), int32
    torch::Tensor norm,      // (B,H,L,1),    half
    torch::Tensor mean,      // (B,H, L//window_size,1,D), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor codebook,   // (256,8),      half,
    torch::Tensor inv_freq, // (D//2), half
    torch::Tensor offset, // (B), int
    torch::Tensor residual_cos, // (B, R, D) half
    torch::Tensor residual_sin, // (B, R, D) half 
    torch::Tensor residual, // (B, H, R, D) half
    torch::Tensor query_had, // (B, H * nh, 1, D) half
    torch::Tensor query, // (B, H * nh, 1, D)
    int nh
);

// quantized_weighted_sum_residual<D, window_size>: declaration only; defined
// outside this file. Registered in PYBIND11_MODULE for D = 128 with
// window_size 64, 32, 128.
// Compared with quantized_weighted_sum it takes an extra `residual` tensor
// (placed before `codebook`), and `weight` is annotated with a trailing
// dimension of L+R instead of L.
// NOTE(review): R is not defined in this file; presumably the number of
// unquantized residual positions -- confirm against the kernel.
template <int D, int window_size>
torch::Tensor quantized_weighted_sum_residual(
    torch::Tensor packed,      // (B,H,L,D//32), int64
    torch::Tensor norm,      // (B,H,L,1),    half
    torch::Tensor mean,      // (B,H, L//window_size,1,D), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor residual, // (B,H,R,D), half
    torch::Tensor codebook,   // (256,8),      half,
    torch::Tensor weight, // (B, H * nh, 1, L+R)
    int nh
);

// quantized_weighted_sum_residual_1bit<D, window_size>: declaration only;
// defined outside this file. Registered in PYBIND11_MODULE for D = 128 with
// window_size 64, 32, 128. Parameter list is identical to
// quantized_weighted_sum_residual.
// NOTE(review): R is not defined in this file. Confirm that the
// packed/codebook annotations, which match the non-1bit declaration verbatim,
// also hold for the 1-bit kernel.
template <int D, int window_size>
torch::Tensor quantized_weighted_sum_residual_1bit(
    torch::Tensor packed,      // (B,H,L,D//32), int64
    torch::Tensor norm,      // (B,H,L,1),    half
    torch::Tensor mean,      // (B,H, L//window_size,1,D), half
    torch::Tensor norm2,     // (B,H,L,1),    half
    torch::Tensor residual, // (B,H,R,D), half
    torch::Tensor codebook,   // (256,8),      half,
    torch::Tensor weight, // (B, H * nh, 1, L+R)
    int nh
);

// Python bindings.
//
// The helper macros below synthesize both the exported Python name and the
// docstring from the C++ identifier and its template arguments, so a name can
// never drift from the instantiation it is bound to:
//   BIND_KERNEL(fn)             -> "fn"              bound to fn
//   BIND_KERNEL_D(fn, d)        -> "fn_d<d>"         bound to fn<d>
//   BIND_KERNEL_WS_D(fn, ws, d) -> "fn_ws<ws>_d<d>"  bound to fn<d, ws>
// Every docstring is the exported name followed by " (CUDA)".
// Caveat: the exported names are assembled by the preprocessor, so grepping
// for e.g. "restore_quantized_ws64_d128" will not find a literal here.
// Note the template argument order is <D, window_size>, while the exported
// name spells the window size first.
#define BIND_KERNEL(fn) m.def(#fn, &fn, #fn " (CUDA)")
#define BIND_KERNEL_D(fn, d) m.def(#fn "_d" #d, &fn<d>, #fn "_d" #d " (CUDA)")
#define BIND_KERNEL_WS_D(fn, ws, d) \
    m.def(#fn "_ws" #ws "_d" #d, &fn<d, ws>, #fn "_ws" #ws "_d" #d " (CUDA)")

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // ---- Double quantization kernels: window_size = 64 (plus the kernels
    // ---- that do not depend on the window size) ----
    BIND_KERNEL_D(adjust_scale_dq, 128);
    BIND_KERNEL_D(adjust_scale_dq_1bit, 128);
    BIND_KERNEL_WS_D(restore_quantized_dq, 64, 128);
    BIND_KERNEL_WS_D(restore_quantized_dq_1bit, 64, 128);
    BIND_KERNEL(dist_argmin_half_packed_dq_1bit);
    BIND_KERNEL(dist_argmin_half_packed_dq);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_dq, 64, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_dq_1bit, 64, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_residual_dq, 64, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_residual_dq_1bit, 64, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_dq, 64, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_dq_1bit, 64, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_residual_dq, 64, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_residual_dq_1bit, 64, 128);

    // ---- Double quantization kernels: window_size = 32 ----
    BIND_KERNEL_WS_D(restore_quantized_dq, 32, 128);
    BIND_KERNEL_WS_D(restore_quantized_dq_1bit, 32, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_dq, 32, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_dq_1bit, 32, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_residual_dq, 32, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_residual_dq_1bit, 32, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_dq, 32, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_dq_1bit, 32, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_residual_dq, 32, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_residual_dq_1bit, 32, 128);

    // ---- Double quantization kernels: window_size = 128 ----
    BIND_KERNEL_WS_D(restore_quantized_dq, 128, 128);
    BIND_KERNEL_WS_D(restore_quantized_dq_1bit, 128, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_dq, 128, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_dq_1bit, 128, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_residual_dq, 128, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_residual_dq_1bit, 128, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_dq, 128, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_dq_1bit, 128, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_residual_dq, 128, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_residual_dq_1bit, 128, 128);

    // ---- Other kernels: window-independent entries and window_size = 64 ----
    BIND_KERNEL(dist_argmin_half);
    BIND_KERNEL_D(dist_argmin_half_batched, 4);
    BIND_KERNEL_D(dist_argmin_half_batched, 8);
    BIND_KERNEL_D(dist_argmin_half_batched, 9);
    BIND_KERNEL_D(dist_argmin_half_batched, 10);
    BIND_KERNEL(dist_argmin_half_packed);
    BIND_KERNEL(dist_argmin_half_packed_1bit);
    BIND_KERNEL_WS_D(restore_quantized, 64, 128);
    BIND_KERNEL_WS_D(restore_quantized_1bit, 64, 128);
    BIND_KERNEL_WS_D(quantized_dot_product, 64, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused, 64, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_1bit, 64, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_residual, 64, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_residual_1bit, 64, 128);
    BIND_KERNEL_WS_D(window_rope_dot_product, 64, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum, 64, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_1bit, 64, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_residual, 64, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_residual_1bit, 64, 128);
    BIND_KERNEL_D(adjust_scale, 128);
    BIND_KERNEL_D(adjust_scale_1bit, 128);

    // ---- Other kernels: window_size = 32 ----
    BIND_KERNEL_WS_D(restore_quantized, 32, 128);
    BIND_KERNEL_WS_D(restore_quantized_1bit, 32, 128);
    BIND_KERNEL_WS_D(quantized_dot_product, 32, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused, 32, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_1bit, 32, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_residual, 32, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_residual_1bit, 32, 128);
    BIND_KERNEL_WS_D(window_rope_dot_product, 32, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum, 32, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_1bit, 32, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_residual, 32, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_residual_1bit, 32, 128);

    // ---- Other kernels: window_size = 128 ----
    BIND_KERNEL_WS_D(restore_quantized, 128, 128);
    BIND_KERNEL_WS_D(restore_quantized_1bit, 128, 128);
    BIND_KERNEL_WS_D(quantized_dot_product, 128, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused, 128, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_1bit, 128, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_residual, 128, 128);
    BIND_KERNEL_WS_D(quantized_dot_product_fused_residual_1bit, 128, 128);
    BIND_KERNEL_WS_D(window_rope_dot_product, 128, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum, 128, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_1bit, 128, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_residual, 128, 128);
    BIND_KERNEL_WS_D(quantized_weighted_sum_residual_1bit, 128, 128);
}

// Keep the registration helpers local to this block.
#undef BIND_KERNEL_WS_D
#undef BIND_KERNEL_D
#undef BIND_KERNEL
