#include <thrust/device_ptr.h>

#include <thrust/scan.h>

#include <thrust/execution_policy.h>

#include <cub/cub.cuh>

#include "hybrid_sp.h"

#include "perf_instrumentation.h"



#include <iostream>

#include <vector>

#include <random>

#include <cassert>

#include <cmath>

#include <algorithm>

#include <numeric>

#include <unordered_set>

#include <cuda_runtime.h>

#include <cuda_bf16.h>

#include <cuda_runtime_api.h>





// Debug switch: when non-zero, the host wrappers below call
// cudaDeviceSynchronize() after each kernel launch and fail via TORCH_CHECK on
// any CUDA error (serializes execution; debug builds only).
#define ERROR_CHECK 0

// Tile extents: a slab header's (tile_x, tile_y) scaled by these, plus the
// in-tile offsets unpacked from each element, give global (row, col).
#define TILE_DIM_X 128
#define TILE_DIM_Y 128




// Maximum number of entries stored per row in the ELL arrays; further entries
// of a row are written to the tail-dense buffer instead.
#define ELL_WIDTH 128

// bf16 elements per 128-bit (int4) vector in the ELL SpMM kernel.
#define VEC_SIZE 8

// 32-bit words per slab: word 0 (read by lane 0) is the tile header and words
// 1..31 each hold one packed element (a zero word is an empty slot).
#define SLAB_SIZE 32




// Splits a 16-bit packed in-tile coordinate into its two bytes:
//   x <- low byte  (callers use it as the row offset inside the tile)
//   y <- high byte (callers use it as the column offset inside the tile)
__forceinline__ __device__ void get_idx_coords128x128(uint16_t idx, uint32_t& x, uint32_t &y) {
    const uint32_t bits = idx;
    x = bits & 0xFFu;
    y = bits >> 8;
}



// Decodes one packed slab element.
//   bits 31..16 : packed in-tile coordinates (decoded by get_idx_coords128x128)
//   bits 15..0  : raw bf16 value bits
__forceinline__ __device__ void unpack_element(int raw, uint32_t& i, uint32_t& j, __nv_bfloat16& val) {
    const uint32_t word = static_cast<uint32_t>(raw);
    const uint16_t coord_bits = static_cast<uint16_t>(word >> 16);
    const unsigned short value_bits = static_cast<unsigned short>(word & 0xFFFFu);
    get_idx_coords128x128(coord_bits, i, j);
    val = __ushort_as_bfloat16(value_bits);
}



// Reads *p through the atomic path (an atomic add of zero) so the load observes
// values published by other threads' atomics. Note this is a read-modify-write
// on the location, hence the const_cast.
__forceinline__ __device__ int atomic_read_i32(const int* p) {
    int* const target = const_cast<int*>(p);
    const int observed = atomicAdd(target, 0);
    return observed;
}



// Writes one overflow element (g_row, g_col, val) into the tail-dense buffer,
// assigning g_row a tail row the first time the row is seen.
//
// tail_dense_map[g_row] state encoding:
//   >= 0 : index of the tail-dense row assigned to g_row
//   -1   : not yet claimed (the hybrid_sp_t constructor fills the map with -1)
//   -2   : currently being claimed by another thread (transient)
//   -3   : claim failed because TAIL_CAPACITY_ROWS rows were already handed out
//
// The thread whose atomicCAS moves the entry from -1 to -2 takes the next slot
// from *tail_dense_counter, records g_row in tail_dense_map_reverse[slot] and
// publishes the slot. Threads that observe -2 spin until it is published.
// NOTE: the claimer and a spinning thread can be lanes of the same warp, so the
// spin relies on independent thread scheduling (SM70+) for forward progress.
//
// Columns outside [0, dense_ld) are dropped (matching the bound check in
// promote_overflow_rows_ell_into_tail_dense) instead of being written into a
// neighbouring tail row or past the end of the buffer. The row is still
// claimed, so the row's ELL entries can be promoted later.
__forceinline__ __device__ void populate_dense(

    int g_row, int g_col, __nv_bfloat16 val,

    __nv_bfloat16* __restrict__ tail_dense, int dense_ld,

    int* __restrict__ tail_dense_map,

    int* __restrict__ tail_dense_map_reverse,

    int* __restrict__ tail_dense_counter

) {
    int dr = atomic_read_i32(&tail_dense_map[g_row]);

    // Unclaimed row: race to claim it.
    if (dr == -1) {
        const int old = atomicCAS(&tail_dense_map[g_row], -1, -2);
        if (old == -1) {
            // This thread won the claim; allocate a tail row for g_row.
            const int new_dr = atomicAdd(tail_dense_counter, 1);
            if (new_dr >= TAIL_CAPACITY_ROWS) {
                // Out of capacity: mark the row as permanently failed.
                atomicExch(&tail_dense_map[g_row], -3);
                printf("Error\n");
                return;
            }
            atomicExch(&tail_dense_map[g_row], new_dr);
            tail_dense_map_reverse[new_dr] = g_row;
            dr = new_dr;
        } else {
            // Lost the race: continue with whatever state the winner left.
            dr = old;
        }
    }

    // Another thread is mid-claim: wait for it to publish the final state.
    if (dr == -2) {
        do {
            dr = atomic_read_i32(&tail_dense_map[g_row]);
        } while (dr == -2);

        if (dr < 0) {
            printf("Error 2\n");
            return;
        }
    }

    // dr == -3 (capacity exhausted) falls through without writing.
    if (dr >= 0 && (unsigned)g_col < (unsigned)dense_ld) {
        tail_dense[(size_t)dr * (size_t)dense_ld + (size_t)g_col] = val;
    }
}





// Writes (g_row, g_col, val) into the tail-dense buffer for a row whose tail
// row has already been assigned in tail_dense_map. Prints a diagnostic and
// writes nothing when the row has no assigned tail row (negative map entry).
__forceinline__ __device__ void populate_dense_known(

    int g_row, int g_col, __nv_bfloat16 val,

    __nv_bfloat16* __restrict__ tail_dense, int dense_ld,

    const int* __restrict__ tail_dense_map

) {
    const int dense_row = atomic_read_i32(tail_dense_map + g_row);

    if (dense_row < 0) {
        printf("Invalid value!!\n");
        return;
    }

    const size_t offset = static_cast<size_t>(dense_row) * static_cast<size_t>(dense_ld)
                        + static_cast<size_t>(g_col);
    tail_dense[offset] = val;
}



// For every row whose entry count exceeded ELL_WIDTH and that owns a tail-dense
// row (tail_dense_map[row] >= 0), copies all ELL_WIDTH stored ELL entries of
// that row into its tail-dense row. Entries with column >= dense_ld are skipped.
//
// Launch: one block per row (gridDim.x >= M_rows). Each thread handles chunks of
// 8 entries read as one int4 for the indices and one int4 for the values; this
// assumes 16-byte aligned ELL arrays and ELL_WIDTH % 8 == 0.
__global__ void promote_overflow_rows_ell_into_tail_dense(

    const uint16_t*      __restrict__ ell_cols,
    const __nv_bfloat16* __restrict__ ell_vals,
    int*                 __restrict__ row_counters,
    int                  M_rows,

    __nv_bfloat16*       __restrict__ tail_dense,
    int                  dense_ld,
    const int*           __restrict__ tail_dense_map
) {
    const int row = blockIdx.x;

    // Only overflowing rows need promotion.
    if (row >= M_rows || row_counters[row] <= ELL_WIDTH) return;

    const int dense_row = tail_dense_map[row];
    if (dense_row < 0) return;

    const size_t ell_base = (size_t)row * ELL_WIDTH;
    const uint16_t*      cols_row = ell_cols + ell_base;
    const __nv_bfloat16* vals_row = ell_vals + ell_base;
    __nv_bfloat16*       dst_row  = tail_dense + (size_t)dense_row * dense_ld;

    constexpr int kChunk = 8;

    for (int chunk = threadIdx.x * kChunk; chunk < ELL_WIDTH; chunk += blockDim.x * kChunk) {
        const int4 packed_cols = *reinterpret_cast<const int4*>(cols_row + chunk);
        const int4 packed_vals = *reinterpret_cast<const int4*>(vals_row + chunk);

        const uint16_t*      c = reinterpret_cast<const uint16_t*>(&packed_cols);
        const __nv_bfloat16* v = reinterpret_cast<const __nv_bfloat16*>(&packed_vals);

        #pragma unroll
        for (int e = 0; e < kChunk; ++e) {
            const unsigned col = c[e];
            if (col < (unsigned)dense_ld) {
                dst_row[col] = v[e];
            }
        }
    }
}





// Legacy variant of scatter_tail_dense_rows_into_full with row pitch == N_cols.
// For each row r of A with tail_dense_map[r] = dr in [0, tail_rows), copies
// B[dr, 0:N_cols] into A[r, 0:N_cols].
//
// Launch: one block per row of A. Uses int4 copies for the first
// (N_cols / 8) * 8 elements, so it assumes 16-byte aligned row starts in A and B.
// Thread 0 copies the remaining N_cols % 8 elements.
__global__ void scatter_tail_dense_rows_into_full_old(

    __nv_bfloat16*       __restrict__ A,
    int                  M_rows,

    int                  N_cols,

    const __nv_bfloat16* __restrict__ B,
    int                  tail_rows,

    const int*           __restrict__ tail_dense_map
) {
    const int row = blockIdx.x;
    if (row >= M_rows) return;

    const int dense_row = tail_dense_map[row];
    if (dense_row < 0 || (unsigned)dense_row >= (unsigned)tail_rows) return;

    const __nv_bfloat16* src_row = B + (size_t)dense_row * N_cols;
    __nv_bfloat16*       dst_row = A + (size_t)row * N_cols;

    const int full_vectors = N_cols / 8;
    const int4* src_vec = reinterpret_cast<const int4*>(src_row);
    int4*       dst_vec = reinterpret_cast<int4*>(dst_row);

    for (int v = threadIdx.x; v < full_vectors; v += blockDim.x) {
        dst_vec[v] = src_vec[v];
    }

    if (threadIdx.x == 0) {
        for (int c = full_vectors * 8; c < N_cols; ++c) {
            dst_row[c] = src_row[c];
        }
    }
}








// True when ptr is 16-byte aligned, i.e. safe for 128-bit (int4) loads/stores.
__device__ bool is_aligned_16(const void* ptr) {
    const uintptr_t address = reinterpret_cast<uintptr_t>(ptr);
    return (address % 16u) == 0u;
}



// Copies tail-dense rows back into a full matrix: for each row r of A with
// tail_dense_map[r] = dr in [0, tail_rows), sets A[r, 0:N_cols] = B[dr, 0:N_cols].
// Row pitches are stride_A (for A) and stride_B (for B), in elements.
//
// Launch: one block per row of A (gridDim.x >= M_rows), 1-D blocks.
// When both row starts are 16-byte aligned, the bulk is copied with int4 (8 x bf16)
// transfers and the N_cols % 8 tail element-wise. Otherwise the whole row is copied
// element-wise, rather than skipping the row and losing its data.
__global__ void scatter_tail_dense_rows_into_full(

    __nv_bfloat16*       __restrict__ A,
    int                  M_rows,

    int                  N_cols,

    const __nv_bfloat16* __restrict__ B,
    int                  tail_rows,

    const int*           __restrict__ tail_dense_map,
    int                  stride_A,
    int                  stride_B
) {
    const int row = blockIdx.x;
    if (row >= M_rows) return;

    const int dr = tail_dense_map[row];
    if (dr < 0) return;  // row has no tail-dense copy

    if (dr >= tail_rows) {
        // Report from the offending block itself; gating on block 0 would only
        // ever report row 0.
        if (threadIdx.x == 0) {
            printf("Row %d: dr=%d is out of valid range [0,%d)\n", row, dr, tail_rows);
        }
        return;
    }

    const __nv_bfloat16* src = B + (size_t)dr  * stride_B;
    __nv_bfloat16*       dst = A + (size_t)row * stride_A;

    constexpr int vec_elems = 8;
    const int tid = threadIdx.x;

    if (is_aligned_16(src) && is_aligned_16(dst)) {
        // Vectorized bulk: v < N_cols / 8 guarantees the whole vector is in range.
        const int vecs = N_cols / vec_elems;
        const int4* src_vec = reinterpret_cast<const int4*>(src);
        int4*       dst_vec = reinterpret_cast<int4*>(dst);
        for (int v = tid; v < vecs; v += blockDim.x) {
            dst_vec[v] = src_vec[v];
        }

        // Remaining N_cols % 8 elements, spread across the block.
        for (int i = vecs * vec_elems + tid; i < N_cols; i += blockDim.x) {
            dst[i] = src[i];
        }
    } else {
        // Misaligned row start: element-wise fallback so the row is still copied.
        for (int i = tid; i < N_cols; i += blockDim.x) {
            dst[i] = src[i];
        }
    }
}





// Gathers rows of A into C by a row map: for each source row sr in [0, M) with
// dr = map[sr] in [0, R), copies A[sr, 0:cols] into C[dr, 0:cols]. Both matrices
// are row-major with a row pitch of `cols` elements. Rows with map[sr] < 0 are
// skipped. Rows with map[sr] >= R are reported and skipped: writing them would
// overrun C.
//
// Launch: one block per source row (gridDim.x >= M), 1-D blocks.
// Uses int4 (8 x bf16) copies when both row starts are 16-byte aligned; this
// requires cols % 8 == 0 for rows other than the first. Otherwise it copies
// element-wise.
__global__ void gather_rows_by_map_bf16_int4(

    const __nv_bfloat16* __restrict__ A,
    __nv_bfloat16*       __restrict__ C,
    const int*           __restrict__ map,
    int M,

    int R,

    int cols

) {
    const int sr = blockIdx.x;
    if (sr >= M) return;

    const int dr = map[sr];
    if (dr < 0) return;  // source row is not mapped into C

    if (dr >= R) {
        if (threadIdx.x == 0) printf("Wrong tail dr %d\n", dr);
        return;  // destination row does not exist in C
    }

    const __nv_bfloat16* __restrict__ src = A + (size_t)sr * cols;
    __nv_bfloat16*       __restrict__ dst = C + (size_t)dr * cols;

    // The vector path is only legal when both row starts are 16-byte aligned.
    const int vec_cols = (is_aligned_16(src) && is_aligned_16(dst)) ? (cols & ~7) : 0;

    for (int j = threadIdx.x * 8; j < vec_cols; j += blockDim.x * 8) {
        const int4 v = *reinterpret_cast<const int4*>(src + j);
        *reinterpret_cast<int4*>(dst + j) = v;
    }

    // Scalar tail (or the whole row when vectorization is not possible).
    for (int j = vec_cols + threadIdx.x; j < cols; j += blockDim.x) {
        dst[j] = src[j];
    }
}









// Scatters packed slab elements into the ELL arrays. A row's entries beyond
// ELL_WIDTH go to the tail-dense buffer via populate_dense. When l0 != nullptr,
// it also accumulates two statistics:
//   *l0 += (number of non-empty elements)   / M_rows
//   *l1 += (sum of the element values)      / M_rows
// Values are summed with their sign; no absolute value is taken. Elements whose
// global row is >= M_rows are counted in the statistics but not stored.
//
// Launch: any 1-D grid. blockDim.x must be a multiple of 32 and <= 1024, because
// the *_sync shuffles use a full-warp mask. Each warp processes one slab at a
// time, striding by the total number of warps. The slab count is
// *num_slabs_ptr / 32 (presumably *num_slabs_ptr holds the number of 32-bit
// words; confirm with the producer).
// *l0 / *l1 are updated with atomicAdd and must be zeroed by the caller.
__global__ void convert_slabs_to_ell_with_l0(

    const int*           __restrict__ slabs_data,

    uint16_t*            __restrict__ ell_col_indices,

    __nv_bfloat16*       __restrict__ ell_values,

    int*                 __restrict__ row_counters,

    int*                 __restrict__ num_slabs_ptr,

    int                                  M_rows,

    int* __restrict__ overflow_counter,

    __nv_bfloat16*       __restrict__ tail_dense,
    int                  dense_ld,
    int*           __restrict__ tail_dense_map,
    int*           __restrict__ tail_dense_map_reverse,
    float* __restrict__ l0,

    float* __restrict__ l1

) {
    const int lane_id         = threadIdx.x & 31;
    const int warp_in_block   = threadIdx.x >> 5;
    const int warps_per_block = blockDim.x >> 5;
    const int global_warp_id  = blockIdx.x * warps_per_block + warp_in_block;
    const int total_warps     = gridDim.x * warps_per_block;

    const int num_slabs = *num_slabs_ptr / 32;

    float thread_l0 = 0.0f;
    float thread_l1 = 0.0f;

    // One partial per warp of this block (at most 32 warps per block).
    __shared__ float block_l0[32];
    __shared__ float block_l1[32];

    for (int warp_id = global_warp_id; warp_id < num_slabs; warp_id += total_warps) {
        const int raw_val    = slabs_data[warp_id * SLAB_SIZE + lane_id];
        // Lane 0 holds the slab header: tile_x in the high half, tile_y in the low half.
        const int header_raw = __shfl_sync(0xFFFFFFFF, raw_val, 0);

        const uint16_t tile_x = (header_raw >> 16) & 0xFFFF;
        const uint16_t tile_y = header_raw & 0xFFFF;

        // Lanes 1..31 carry elements; a zero word marks an empty slot.
        if (lane_id > 0 && raw_val != 0) {
            uint32_t i, j;
            __nv_bfloat16 val;
            unpack_element(raw_val, i, j, val);

            const int g_row = tile_x * TILE_DIM_X + i;
            const int g_col = tile_y * TILE_DIM_Y + j;

            if (l0 != nullptr) {
                const float a = __bfloat162float(val);
                thread_l0 += 1.f / M_rows;
                thread_l1 += a / M_rows;
            }

            if (g_row < M_rows) {
                // Reserve the next ELL slot of this row; spill past ELL_WIDTH.
                const int k = atomicAdd(&row_counters[g_row], 1);
                if (k < ELL_WIDTH) {
                    const size_t addr = (size_t)g_row * ELL_WIDTH + k;
                    ell_col_indices[addr] = (uint16_t)g_col;
                    ell_values[addr]      = val;
                } else {
                    populate_dense(g_row, g_col, val, tail_dense, dense_ld,
                                   tail_dense_map, tail_dense_map_reverse, overflow_counter);
                }
            }
        }
    }

    // l0 is a kernel argument, so this branch (and the barrier inside) is block-uniform.
    if (l0 != nullptr) {
        // Warp-level reduction of the per-thread partials.
        for (int offset = 16; offset > 0; offset >>= 1) {
            thread_l0 += __shfl_down_sync(0xffffffff, thread_l0, offset);
        }
        for (int offset = 16; offset > 0; offset >>= 1) {
            thread_l1 += __shfl_down_sync(0xffffffff, thread_l1, offset);
        }

        if (lane_id == 0) {
            block_l0[warp_in_block] = thread_l0;
            block_l1[warp_in_block] = thread_l1;
        }

        __syncthreads();

        // Warp 0 reduces the per-warp partials. Only slots written by a live
        // warp are read. This removes the need for a separate zero-initialisation
        // of the shared arrays, which would otherwise have to be ordered before
        // the writes above with an extra barrier.
        float l0_sum = 0.0f;
        float l1_sum = 0.0f;
        if (warp_in_block == 0 && lane_id < warps_per_block) {
            l0_sum = block_l0[lane_id];
            l1_sum = block_l1[lane_id];
        }

        for (int offset = 16; offset > 0; offset >>= 1) {
            l0_sum += __shfl_down_sync(0xffffffff, l0_sum, offset);
        }
        for (int offset = 16; offset > 0; offset >>= 1) {
            l1_sum += __shfl_down_sync(0xffffffff, l1_sum, offset);
        }

        if (threadIdx.x == 0) {
            atomicAdd(l0, l0_sum);
            atomicAdd(l1, l1_sum);
        }
    }
}





// Scatters packed slab elements into the ELL arrays. Each row keeps at most
// ELL_WIDTH entries there; later entries of that row go to the tail-dense buffer
// via populate_dense.
//
// Launch: any 1-D grid. blockDim.x must be a multiple of 32, because the header
// broadcast uses a full-warp mask. Each warp handles one slab at a time, striding
// by the total warp count. Word 0 of a slab is the (tile_x, tile_y) header;
// words 1..31 are packed elements, and zero words are empty slots.
__global__ void convert_slabs_to_ell(

    const int*           __restrict__ slabs_data,

    uint16_t*            __restrict__ ell_col_indices,

    __nv_bfloat16*       __restrict__ ell_values,

    int*                 __restrict__ row_counters,

    int*                 __restrict__ num_slabs_ptr,

    int                                  M_rows,

    int* __restrict__ overflow_counter,

    __nv_bfloat16*       __restrict__ tail_dense,
    int                  dense_ld,
    int*           __restrict__ tail_dense_map,
    int*           __restrict__ tail_dense_map_reverse
) {
    const int lane          = threadIdx.x % 32;
    const int warps_per_blk = blockDim.x / 32;
    const int first_slab    = blockIdx.x * warps_per_blk + threadIdx.x / 32;
    const int slab_stride   = gridDim.x * warps_per_blk;
    const int slab_count    = *num_slabs_ptr / 32;

    for (int slab = first_slab; slab < slab_count; slab += slab_stride) {
        const int word   = slabs_data[slab * SLAB_SIZE + lane];
        const int header = __shfl_sync(0xFFFFFFFF, word, 0);

        // Lane 0 only carries the header; zero words are empty slots.
        if (lane == 0 || word == 0) continue;

        const uint16_t tile_x = (uint16_t)((header >> 16) & 0xFFFF);
        const uint16_t tile_y = (uint16_t)(header & 0xFFFF);

        uint32_t r_off, c_off;
        __nv_bfloat16 value;
        unpack_element(word, r_off, c_off, value);

        const int g_row = tile_x * TILE_DIM_X + r_off;
        const int g_col = tile_y * TILE_DIM_Y + c_off;

        if (g_row >= M_rows) continue;

        const int slot = atomicAdd(&row_counters[g_row], 1);
        if (slot >= ELL_WIDTH) {
            populate_dense(g_row, g_col, value, tail_dense, dense_ld,
                           tail_dense_map, tail_dense_map_reverse, overflow_counter);
            continue;
        }

        const size_t addr = (size_t)g_row * ELL_WIDTH + slot;
        ell_col_indices[addr] = (uint16_t)g_col;
        ell_values[addr]      = value;
    }
}



#include <cuda_runtime.h>

#include <cuda_bf16.h>




// Reinterpretation helper for the ELL SpMM kernel: eight bf16 results are
// written through `bf` and then stored to global memory as one 128-bit value
// through `vec`.
union OutputPack {

    // Eight bf16 output values (16 bytes total).
    __nv_bfloat16 bf[8];

    // The same 16 bytes viewed as a single 128-bit vector.
    int4 vec;

};



// ELL sparse x dense product: C[row, :] = sum_k A_vals[row, k] * B[A_idxs[row, k], :]
// over the first row_counts[row] ELL entries of each row. B and C are row-major
// with N_cols columns. Rows with 0 entries, or with more than ELL_WIDTH entries
// (which live in the tail-dense buffer), are left untouched. K_rows is not read
// by this kernel.
//
// Launch: one block per row (gridDim.x >= M_rows), 1-D blocks.
// When N_cols % VEC_SIZE == 0 and B and C are 16-byte aligned, each thread
// produces VEC_SIZE consecutive outputs with int4 loads/stores. Otherwise each
// thread produces one output at a time, so narrow or odd-width B/C no longer
// overrun rows or issue misaligned vector accesses.
__global__ void ell_spmm_rowmajor_b_rowwise_optimized(

    const __nv_bfloat16* __restrict__ A_vals,
    const uint16_t* __restrict__ A_idxs,
    const int* __restrict__ row_counts,
    const __nv_bfloat16* __restrict__ B,


    __nv_bfloat16* __restrict__ C,
    int M_rows,

    int K_rows,

    int N_cols

) {
    static_assert(VEC_SIZE == 8, "vector path assumes 8 bf16 per int4");

    const int row = blockIdx.x;
    if (row >= M_rows) return;

    // Block-uniform exits (every thread reads the same count), so the
    // __syncthreads() below is still reached by all remaining threads.
    const int nnz = row_counts[row];
    if (nnz <= 0 || nnz > ELL_WIDTH) return;

    // Stage this row's ELL entries in shared memory.
    __shared__ __nv_bfloat16 sh_vals[ELL_WIDTH];
    __shared__ uint16_t sh_idxs[ELL_WIDTH];

    const __nv_bfloat16* A_row_vals_g = A_vals + (size_t)row * ELL_WIDTH;
    const uint16_t* A_row_idxs_g = A_idxs + (size_t)row * ELL_WIDTH;

    for (int k = threadIdx.x; k < nnz; k += blockDim.x) {
        sh_idxs[k] = A_row_idxs_g[k];
        sh_vals[k] = A_row_vals_g[k];
    }

    __syncthreads();

    __nv_bfloat16* C_row = C + (size_t)row * N_cols;

    const bool vec_ok =
        (N_cols % VEC_SIZE) == 0 &&
        ((reinterpret_cast<uintptr_t>(B) | reinterpret_cast<uintptr_t>(C)) & 0xF) == 0;

    if (vec_ok) {
        for (int n_out = threadIdx.x * VEC_SIZE; n_out < N_cols; n_out += VEC_SIZE * blockDim.x) {
            float2 acc[4];
            #pragma unroll
            for (int i = 0; i < 4; ++i) acc[i] = make_float2(0.f, 0.f);

            for (int k = 0; k < nnz; ++k) {
                const float a = __bfloat162float(sh_vals[k]);
                const uint16_t col_idx = sh_idxs[k];

                const __nv_bfloat16* B_row_ptr = B + (size_t)col_idx * N_cols + n_out;
                int4 b_vec_raw = *reinterpret_cast<const int4*>(B_row_ptr);
                const __nv_bfloat162* b_pairs = reinterpret_cast<const __nv_bfloat162*>(&b_vec_raw);

                #pragma unroll
                for (int p = 0; p < 4; ++p) {
                    const float2 b_f32 = __bfloat1622float2(b_pairs[p]);
                    acc[p].x = fmaf(a, b_f32.x, acc[p].x);
                    acc[p].y = fmaf(a, b_f32.y, acc[p].y);
                }
            }

            OutputPack out;
            #pragma unroll
            for (int p = 0; p < 4; ++p) {
                out.bf[2 * p]     = __float2bfloat16(acc[p].x);
                out.bf[2 * p + 1] = __float2bfloat16(acc[p].y);
            }

            *reinterpret_cast<int4*>(C_row + n_out) = out.vec;
        }
    } else {
        // Scalar fallback: one output column per thread iteration.
        for (int n = threadIdx.x; n < N_cols; n += blockDim.x) {
            float acc = 0.f;
            for (int k = 0; k < nnz; ++k) {
                const float a = __bfloat162float(sh_vals[k]);
                const float b = __bfloat162float(B[(size_t)sh_idxs[k] * N_cols + n]);
                acc = fmaf(a, b, acc);
            }
            C_row[n] = __float2bfloat16(acc);
        }
    }
}








// Best-effort: installs an L2 access-policy window on `stream` marking
// [ptr, ptr + num_bytes) as persisting (hit_ratio of accesses), with misses
// streaming.
// - The window size is clamped to the device's cudaDevAttrMaxAccessPolicyWindowSize.
// - hit_ratio is clamped to [0, 1].
// - On devices without access-policy windows (max size 0) nothing is done.
// - Failures never propagate: the CUDA last-error state is cleared on every
//   failure path, including cudaErrorNotSupported. Otherwise a later
//   cudaGetLastError()-based launch check would report a spurious error.
// Assumes `stream` belongs to the current device (TODO confirm for multi-GPU callers).
static void set_l2_persist_for_matrix(

    const void* ptr,

    size_t num_bytes,

    cudaStream_t stream,

    float hit_ratio = 1.0f

) {
    int device = 0;
    int max_window = 0;
    if (cudaGetDevice(&device) != cudaSuccess ||
        cudaDeviceGetAttribute(&max_window, cudaDevAttrMaxAccessPolicyWindowSize, device) != cudaSuccess ||
        max_window <= 0) {
        cudaGetLastError();  // clear any pending error; persistence is optional
        return;
    }

    cudaStreamAttrValue attr = {};
    attr.accessPolicyWindow.base_ptr  = const_cast<void*>(ptr);
    attr.accessPolicyWindow.num_bytes = std::min(num_bytes, static_cast<size_t>(max_window));
    attr.accessPolicyWindow.hitRatio  = std::min(std::max(hit_ratio, 0.0f), 1.0f);
    attr.accessPolicyWindow.hitProp   = cudaAccessPropertyPersisting;
    attr.accessPolicyWindow.missProp  = cudaAccessPropertyStreaming;

    if (cudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &attr) != cudaSuccess) {
        cudaGetLastError();  // clear the error; persistence is a hint only
    }
}



// Removes the L2 access-policy window from `stream` by installing a zero-sized
// window with normal hit/miss properties. This is best-effort: a failure is
// cleared from the CUDA last-error state instead of being left pending for an
// unrelated later error check.
static void reset_l2_persist(cudaStream_t stream) {
    cudaStreamAttrValue reset_attr = {};
    reset_attr.accessPolicyWindow.base_ptr  = nullptr;
    reset_attr.accessPolicyWindow.num_bytes = 0;
    reset_attr.accessPolicyWindow.hitRatio  = 0.0f;
    reset_attr.accessPolicyWindow.hitProp   = cudaAccessPropertyNormal;
    reset_attr.accessPolicyWindow.missProp  = cudaAccessPropertyNormal;

    if (cudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &reset_attr) != cudaSuccess) {
        cudaGetLastError();
    }
}















// Best-effort: marks [ptr, ptr + size) as L2-persisting for accesses issued on
// `stream` (hit ratio 1.0, misses streaming). The attribute struct is
// value-initialised, and a failure (e.g. unsupported device or oversized
// window) is cleared from the CUDA last-error state rather than left pending.
void set_matrix_b_persistence(const void* ptr, size_t size, cudaStream_t stream) {
    cudaStreamAttrValue attr = {};
    attr.accessPolicyWindow.base_ptr  = const_cast<void*>(ptr);
    attr.accessPolicyWindow.num_bytes = size;
    attr.accessPolicyWindow.hitRatio  = 1.0f;
    attr.accessPolicyWindow.hitProp   = cudaAccessPropertyPersisting;
    attr.accessPolicyWindow.missProp  = cudaAccessPropertyStreaming;

    if (cudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &attr) != cudaSuccess) {
        cudaGetLastError();
    }
}



// Allocates an empty hybrid sparse container for an M x N matrix on `device`:
//   - ELL part: M * ELL_WIDTH column indices (int16 storage, read as uint16_t by
//     the kernels) and bf16 values. Both are uninitialised; entries are valid
//     only up to row_counters[row].
//   - row_counters: per-row entry counts, zeroed.
//   - tail-dense part: TAIL_CAPACITY_ROWS x N bf16 rows (zeroed) for rows that
//     overflow the ELL width. tail_dense_map (size M, -1 = unassigned) and
//     tail_dense_map_reverse (size TAIL_CAPACITY_ROWS, -1 = unused) link matrix
//     rows and tail rows. overflow_counter (zeroed) hands out tail rows.
//   - hN: one pinned host int32.
hybrid_sp_t::hybrid_sp_t(int M, int N, torch::Device device) {

    auto options_tint16 = at::TensorOptions().dtype(torch::kInt16).device(device);

    auto options_tint32 = at::TensorOptions().dtype(torch::kInt32).device(device);

    auto options_bf16 = at::TensorOptions().dtype(torch::kBFloat16).device(device);

    // Compute the ELL size in 64-bit: M * ELL_WIDTH overflows int for M >= 2^24.
    const int64_t ell_slots = static_cast<int64_t>(M) * ELL_WIDTH;

    this->_ell_col_indices = at::empty({ell_slots}, options_tint16);

    this->_ell_values = at::empty({ell_slots}, options_bf16);

    this->_row_counters = at::zeros({M}, options_tint32);

    this->_overflow_counter = at::zeros({}, options_tint32);

    this->_tail_dense = at::zeros({TAIL_CAPACITY_ROWS, N}, options_bf16);
    this->_tail_dense_map = at::full({M}, -1, options_tint32);

    this->_tail_dense_map_reverse = at::full({TAIL_CAPACITY_ROWS}, -1, options_tint32);

    this->_dense_active_rows = TAIL_CAPACITY_ROWS;

    this->hN = at::empty({1}, at::TensorOptions().device(at::kCPU).dtype(at::kInt).pinned_memory(true));

}



// Shallow copy: the tensor members are copied as handles, so the new object
// shares device (and pinned host) storage with `sp`. This presumes the members
// are at::Tensor as declared in hybrid_sp.h; confirm. hN, the pinned host
// buffer allocated by the main constructor, is copied too, so copies are not
// left with a default-constructed hN.
hybrid_sp_t::hybrid_sp_t(const hybrid_sp_t& sp) {

    this->_ell_col_indices = sp._ell_col_indices;

    this->_ell_values = sp._ell_values;

    this->_row_counters = sp._row_counters;

    this->_overflow_counter = sp._overflow_counter;

    this->_tail_dense = sp._tail_dense;

    this->_tail_dense_map = sp._tail_dense_map;

    this->_tail_dense_map_reverse = sp._tail_dense_map_reverse;

    this->_dense_active_rows = sp._dense_active_rows;

    this->hN = sp.hN;

}




// Replaces the ELL value array with a freshly allocated, zero-filled 1-D bf16
// tensor of the same element count on the same device. The tensor is rebound,
// not zeroed in place, so any other object still holding the previous tensor
// (the copy constructor shares tensors) keeps its values.
void hybrid_sp_t::reset_vals() {
    const auto target_device = this->_ell_values.device();
    const int64_t count = this->_ell_values.numel();
    const auto zero_options = at::TensorOptions().dtype(torch::kBFloat16).device(target_device);
    this->_ell_values = at::zeros({count}, zero_options);
}










// Builds the hybrid ELL + tail-dense representation in `sp` from packed slabs,
// and optionally accumulates the l0 / l1 statistics computed by
// convert_slabs_to_ell_with_l0.
//
// Launches, in order on `stream`:
//   1. convert_slabs_to_ell_with_l0: fills the ELL arrays; a row's entries beyond
//      ELL_WIDTH go to the tail-dense buffer.
//   2. promote_overflow_rows_ell_into_tail_dense: copies the ELL part of every
//      overflowing row into that row's tail-dense row.
//
// l0 / l1: single-element float32 accumulators updated with atomicAdd, so
// callers must zero them first. Pass undefined tensors for both to skip the
// statistics (the kernel treats a null l0 pointer as "disabled").
// sp's counters and maps are not reset here; presumably sp is freshly
// constructed. Confirm with callers.
void create_hybrid_sparse(at::Tensor& slabs_data_d, at::Tensor& num_slabs_d, hybrid_sp_t* sp, at::Tensor& l0, at::Tensor& l1, int M, int N, cudaStream_t stream) {

    TORCH_CHECK(l0.defined() == l1.defined(),
                "create_hybrid_sparse: l0 and l1 must both be defined or both undefined");

    // data_ptr<float>() also validates that the accumulators are float32.
    float* l0_ptr = l0.defined() ? l0.data_ptr<float>() : nullptr;
    float* l1_ptr = l1.defined() ? l1.data_ptr<float>() : nullptr;

    auto device = slabs_data_d.get_device();

    // An unchecked failure here would leave num_sms uninitialised and size the grid from garbage.
    int num_sms = 0;
    TORCH_CHECK(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device) == cudaSuccess && num_sms > 0,
                "create_hybrid_sparse: failed to query SM count for device ", device);

    PERF_START("create_hybrid_sparse_total", stream);

    int block_slabs = 256;

    int grid_slabs  = num_sms * 4;

    PERF_START("create_hybrid_sparse:convert_slabs_to_ell", stream);

    convert_slabs_to_ell_with_l0<<<grid_slabs, block_slabs, 0, stream>>>(
        static_cast<int32_t*>(slabs_data_d.data_ptr()),
        sp->ell_col_indices(),
        sp->ell_values(),
        sp->row_counters(),
        static_cast<int32_t*>(num_slabs_d.data_ptr()),
        M,
        sp->overflow_counter(),
        sp->tail_dense(),
        N,
        sp->tail_dense_map(),
        sp->tail_dense_map_reverse(),
        l0_ptr,
        l1_ptr
    );

    PERF_STOP("create_hybrid_sparse:convert_slabs_to_ell");

#if ERROR_CHECK

    cudaError_t __err = cudaDeviceSynchronize();

    if (__err != cudaSuccess) {

        fprintf(stderr, "INIT-ELL: Fatal error: (%s at %s:%d)\n", cudaGetErrorString(__err), __FILE__, __LINE__);

        TORCH_CHECK(false, "error on hybrid");

    }

#endif

    PERF_START("create_hybrid_sparse:promote_overflow", stream);

    // A zero-sized grid is an invalid launch configuration; with no rows there is nothing to promote.
    if (M > 0) {
        promote_overflow_rows_ell_into_tail_dense<<<M, 128, 0, stream>>>(
            sp->ell_col_indices(),
            sp->ell_values(),
            sp->row_counters(),
            M,
            sp->tail_dense(),
            N,
            sp->tail_dense_map()
        );
    }

    PERF_STOP("create_hybrid_sparse:promote_overflow");

#if ERROR_CHECK

    __err = cudaDeviceSynchronize();

    if (__err != cudaSuccess) {

        fprintf(stderr, "INIT-Dense: Fatal error: (%s at %s:%d)\n", cudaGetErrorString(__err), __FILE__, __LINE__);

        TORCH_CHECK(false, "error on hybrid");

    }

#endif

    PERF_STOP("create_hybrid_sparse_total");

}



// Builds the hybrid ELL + tail-dense representation in `sp` from packed slabs
// (no l0/l1 statistics).
//
// Launches, in order on `stream`:
//   1. convert_slabs_to_ell: fills the ELL arrays; a row's entries beyond
//      ELL_WIDTH go to the tail-dense buffer.
//   2. promote_overflow_rows_ell_into_tail_dense: copies the ELL part of every
//      overflowing row into that row's tail-dense row.
// sp's counters and maps are not reset here; presumably sp is freshly
// constructed. Confirm with callers.
void create_hybrid_sparse(at::Tensor& slabs_data_d, at::Tensor& num_slabs_d, hybrid_sp_t* sp, int M, int N, cudaStream_t stream) {

    auto device = slabs_data_d.get_device();

    // An unchecked failure here would leave num_sms uninitialised and size the grid from garbage.
    int num_sms = 0;
    TORCH_CHECK(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device) == cudaSuccess && num_sms > 0,
                "create_hybrid_sparse: failed to query SM count for device ", device);

    int block_slabs = 256;

    int grid_slabs  = num_sms * 4;

    convert_slabs_to_ell<<<grid_slabs, block_slabs, 0, stream>>>(
        static_cast<int32_t*>(slabs_data_d.data_ptr()),
        sp->ell_col_indices(),
        sp->ell_values(),
        sp->row_counters(),
        static_cast<int32_t*>(num_slabs_d.data_ptr()),
        M,
        sp->overflow_counter(),
        sp->tail_dense(),
        N,
        sp->tail_dense_map(),
        sp->tail_dense_map_reverse()
    );

#if ERROR_CHECK

    cudaError_t __err = cudaDeviceSynchronize();

    if (__err != cudaSuccess) {

        fprintf(stderr, "INIT-ELL: Fatal error: (%s at %s:%d)\n", cudaGetErrorString(__err), __FILE__, __LINE__); 

        TORCH_CHECK(false, "error on hybrid");

    }

#endif

    // A zero-sized grid is an invalid launch configuration; with no rows there is nothing to promote.
    if (M > 0) {
        promote_overflow_rows_ell_into_tail_dense<<<M, 128, 0, stream>>>(
            sp->ell_col_indices(),
            sp->ell_values(),
            sp->row_counters(),
            M,
            sp->tail_dense(),
            N,
            sp->tail_dense_map()
        );
    }

#if ERROR_CHECK

    __err = cudaDeviceSynchronize();

    if (__err != cudaSuccess) {

        fprintf(stderr, "INIT-Dense: Fatal error: (%s at %s:%d)\n", cudaGetErrorString(__err), __FILE__, __LINE__); 

        TORCH_CHECK(false, "error on hybrid");

    }

#endif

}






// Sampled dense x dense product on an ELL sparsity pattern. For each stored ELL
// entry (row, out_idx) with column col = ell_cols[row, out_idx] < N_cols, it writes
//   C_ell_vals[row, out_idx] = dot(A[row, 0:K], B_T[col, 0:K]) + (*init_val or 0)
// A is M_rows x K and B_T is N_cols x K, both row-major with row pitch K.
// Rows with 0 entries, or with more than ELL_WIDTH entries, are skipped.
//
// Launch: one block per row (gridDim.x >= M_rows). blockDim.x must be a multiple
// of 32 (one warp per output entry, full-mask shuffles). Dynamic shared memory is
// K * sizeof(__nv_bfloat16), used to stage A's row.
//
// init_val is added exactly once per output, after the warp reduction. This
// matches tail_dense_masked_add_inplace (init + d). Starting every lane at
// init_val would add it 32 times.
// When K % 8 == 0 and A / B_T are 16-byte aligned, int4 loads are used.
// Otherwise all loads are scalar, avoiding misaligned vector accesses.
__global__ void dense_ab_as_sparse_hybrid_unified_optimized(

    const __nv_bfloat16* __restrict__ A,
    const __nv_bfloat16* __restrict__ B_T,
    const uint16_t* __restrict__ ell_cols,
    const int* __restrict__ row_counts,
    __nv_bfloat16* __restrict__ C_ell_vals,
    const float* __restrict__ init_val,
    int M_rows,

    int K,

    int N_cols

) {
    const int row = blockIdx.x;
    if (row >= M_rows) return;

    // Block-uniform exits (every thread reads the same count), so the
    // __syncthreads() below is still reached by all remaining threads.
    const int nnz_total = row_counts[row];
    if (nnz_total <= 0 || nnz_total > ELL_WIDTH) return;

    extern __shared__ __nv_bfloat16 sh_A[];

    const __nv_bfloat16* A_row_ptr = A + (size_t)row * K;

    // int4 loads need every row start of A and B_T to be 16-byte aligned.
    const bool vec_ok =
        (K % 8) == 0 &&
        ((reinterpret_cast<uintptr_t>(A) | reinterpret_cast<uintptr_t>(B_T)) & 0xF) == 0;
    const int num_vec_k    = vec_ok ? K / 8 : 0;
    const int scalar_start = num_vec_k * 8;

    // Stage A's row in shared memory.
    int4* sh_A_vec = reinterpret_cast<int4*>(sh_A);
    const int4* A_row_vec = reinterpret_cast<const int4*>(A_row_ptr);
    for (int i = threadIdx.x; i < num_vec_k; i += blockDim.x) {
        sh_A_vec[i] = A_row_vec[i];
    }
    for (int i = scalar_start + threadIdx.x; i < K; i += blockDim.x) {
        sh_A[i] = A_row_ptr[i];
    }

    __syncthreads();

    const int lane_id   = threadIdx.x & 31;
    const int warp_id   = threadIdx.x >> 5;
    const int num_warps = blockDim.x >> 5;

    const float base_init = (init_val) ? *init_val : 0.0f;

    // One warp per output entry.
    for (int out_idx = warp_id; out_idx < nnz_total; out_idx += num_warps) {
        // ell_cols is unsigned, so only the upper bound needs checking; the
        // condition is warp-uniform, so skipping keeps the shuffles full-warp.
        const int col = (int)ell_cols[(size_t)row * ELL_WIDTH + out_idx];
        if (col >= N_cols) continue;

        const __nv_bfloat16* B_row_ptr = B_T + (size_t)col * K;

        // Per-lane partial dot product starts from zero; init is added once below.
        float acc = 0.0f;

        for (int k_vec_idx = lane_id; k_vec_idx < num_vec_k; k_vec_idx += 32) {
            int4 a_raw = *reinterpret_cast<const int4*>(&sh_A[k_vec_idx * 8]);
            int4 b_raw = *reinterpret_cast<const int4*>(&B_row_ptr[k_vec_idx * 8]);

            const __nv_bfloat162* a2 = reinterpret_cast<const __nv_bfloat162*>(&a_raw);
            const __nv_bfloat162* b2 = reinterpret_cast<const __nv_bfloat162*>(&b_raw);

            #pragma unroll
            for (int t = 0; t < 4; ++t) {
                const float2 af = __bfloat1622float2(a2[t]);
                const float2 bf = __bfloat1622float2(b2[t]);
                acc = fmaf(af.x, bf.x, acc);
                acc = fmaf(af.y, bf.y, acc);
            }
        }

        // Scalar remainder (the whole row when vectorization is not possible).
        for (int k_idx = scalar_start + lane_id; k_idx < K; k_idx += 32) {
            const float a = __bfloat162float(sh_A[k_idx]);
            const float b = __bfloat162float(B_row_ptr[k_idx]);
            acc = fmaf(a, b, acc);
        }

        // Butterfly reduction: every lane ends up with the full dot product.
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            acc += __shfl_xor_sync(0xffffffff, acc, offset);
        }

        if (lane_id == 0) {
            C_ell_vals[(size_t)row * ELL_WIDTH + out_idx] = __float2bfloat16(acc + base_init);
        }
    }
}



// True when v is not (+/-) zero: the sign bit is masked off before the test,
// so -0.0 also counts as zero. NaN and Inf count as non-zero.
__device__ __forceinline__ bool bf16_nonzero_mask(__nv_bfloat16 v) {
    const unsigned short magnitude_bits = __bfloat16_as_ushort(v) & 0x7FFFu;
    return magnitude_bits != 0;
}



// Element-wise masked update over a flat buffer of rows * N bf16 values:
//   out_tail[e] = (tail[e] != +/-0) ? bf16(init + dense_out[e]) : 0
// where init = *init_val (or 0 when init_val is null). out_tail is write-only,
// so it may be uninitialised (e.g. from empty_like) and is never read.
//
// Launch: any 1-D grid. The kernel walks the int4 (8 x bf16) vectors with a
// grid-stride loop, and global thread 0 also handles the final
// (rows * N) % 8 elements. Any launch size therefore covers the whole buffer.
// The vector path assumes 16-byte aligned base pointers (true for fresh
// PyTorch allocations).
__global__ void tail_dense_masked_add_inplace(

    __nv_bfloat16*       __restrict__ out_tail,
    const __nv_bfloat16*       __restrict__ tail,
    const __nv_bfloat16* __restrict__ dense_out,
    int rows,

    int N,

    const float*         __restrict__ init_val
) {
    const int64_t total_elems = (int64_t)rows * (int64_t)N;
    if (total_elems <= 0) return;

    const int64_t total_vec = total_elems / 8;
    const float init = init_val ? *init_val : 0.0f;

    const int64_t first  = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
    const int64_t stride = (int64_t)gridDim.x * blockDim.x;

    for (int64_t v = first; v < total_vec; v += stride) {
        const int64_t elem0 = v * 8;

        int4 tail_raw = *reinterpret_cast<const int4*>(tail + elem0);
        int4 den_raw  = *reinterpret_cast<const int4*>(dense_out + elem0);
        int4 out_raw;

        const __nv_bfloat16* tail8 = reinterpret_cast<const __nv_bfloat16*>(&tail_raw);
        const __nv_bfloat16* den8  = reinterpret_cast<const __nv_bfloat16*>(&den_raw);
        __nv_bfloat16*       out8  = reinterpret_cast<__nv_bfloat16*>(&out_raw);

        #pragma unroll
        for (int i = 0; i < 8; ++i) {
            if (bf16_nonzero_mask(tail8[i])) {
                out8[i] = __float2bfloat16(init + __bfloat162float(den8[i]));
            } else {
                out8[i] = __float2bfloat16(0.0f);
            }
        }

        *reinterpret_cast<int4*>(out_tail + elem0) = out_raw;
    }

    // Trailing elements that do not fill a whole int4.
    if (first == 0) {
        for (int64_t e = total_vec * 8; e < total_elems; ++e) {
            if (bf16_nonzero_mask(tail[e])) {
                out_tail[e] = __float2bfloat16(init + __bfloat162float(dense_out[e]));
            } else {
                out_tail[e] = __float2bfloat16(0.0f);
            }
        }
    }
}



// Computes A * B^T + init on the sparsity pattern stored in `out`:
//   1. ELL part: dense_ab_as_sparse_hybrid_unified_optimized writes
//      dot(A[row], B[col]) + init into out's ELL values for each stored entry.
//   2. Tail-dense part (when out->_dense_active_rows > 0): rows of A mapped by
//      tail_dense_map are gathered into `dense`, multiplied by B^T with
//      torch::matmul, and tail_dense_masked_add_inplace writes init + product
//      wherever the previous tail_dense value was non-zero (zero elsewhere) into
//      a newly allocated out->_tail_dense.
// a: M x K bf16, b: N x K bf16 (used as B^T by the kernel), init_val: float32 scalar.
// All raw pointers handed to asynchronous kernels come from tensors that stay
// alive for the whole function.
void new_product_as_sparse_sma(hybrid_sp_t* out, at::Tensor const& a, at::Tensor const& b, at::Tensor const& init_val, int M, int N, int K, cudaStream_t stream) {

    PERF_START("new_product_as_sparse_total", stream);

    // `a.contiguous().data_ptr()` on a non-contiguous `a` would return a pointer
    // into a temporary destroyed at the end of that statement. Keep the
    // contiguous tensors as named locals instead. `b` is indexed as a dense
    // N x K row-major buffer, so it must be contiguous too.
    const at::Tensor a_c = a.contiguous();
    const at::Tensor b_c = b.contiguous();

    const __nv_bfloat16* A_ptr  = reinterpret_cast<const __nv_bfloat16*>(a_c.data_ptr());
    const __nv_bfloat16* BT_ptr = reinterpret_cast<const __nv_bfloat16*>(b_c.data_ptr());

    const float* init_ptr = init_val.data_ptr<float>();

    dim3 block(256);
    dim3 grid(M);

    const size_t smem_bytes = (size_t)K * sizeof(__nv_bfloat16);

    // Dynamic shared memory above the default 48 KB requires an explicit opt-in.
    if (smem_bytes > 48 * 1024) {
        cudaFuncSetAttribute(dense_ab_as_sparse_hybrid_unified_optimized,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             (int)smem_bytes);
    }

    // Ask L2 to keep B resident while the ELL kernel streams it repeatedly.
    const size_t BT_size = (size_t)N * K * sizeof(__nv_bfloat16);
    const bool persist_bt = BT_size < 4 * 1024 * 1024;
    if (persist_bt) {
        set_l2_persist_for_matrix(BT_ptr, BT_size, stream, 1.0f);
    }

    PERF_START("new_product_as_sparse:dense_ab_kernel", stream);

    if (M > 0) {  // a zero-sized grid is an invalid launch configuration
        dense_ab_as_sparse_hybrid_unified_optimized<<<grid, block, smem_bytes, stream>>>(
            A_ptr, BT_ptr,
            out->ell_col_indices(),
            out->row_counters(),
            out->ell_values(),
            init_ptr,
            M,
            K,
            N
        );
    }

    PERF_STOP("new_product_as_sparse:dense_ab_kernel");

    if (persist_bt) {
        reset_l2_persist(stream);
    }

#if ERROR_CHECK

    cudaError_t __err = cudaDeviceSynchronize();

    if (__err != cudaSuccess) {

        fprintf(stderr, "Dense_ab_as_sparse: Fatal error: (%s at %s:%d)\n", cudaGetErrorString(__err), __FILE__, __LINE__); 

        TORCH_CHECK(false, "error on hybrid");

    }

#endif

    if (out->_dense_active_rows > 0) {

        const int active_rows = out->_dense_active_rows;

        auto options_bf16 = at::TensorOptions().dtype(torch::kBFloat16).device(a.device());

        // Unmapped rows stay uninitialised. Their products are discarded by the
        // mask step because the matching tail_dense rows are zero.
        at::Tensor dense = torch::empty({TAIL_CAPACITY_ROWS, K}, options_bf16);

        __nv_bfloat16* dense_ptr = reinterpret_cast<__nv_bfloat16*>(dense.data_ptr());

        PERF_START("new_product_as_sparse:gather_rows", stream);

        if (M > 0) {
            gather_rows_by_map_bf16_int4<<<M, 128, 0, stream>>>(
                A_ptr, dense_ptr, out->tail_dense_map(),
                M, TAIL_CAPACITY_ROWS, K);
        }

        PERF_STOP("new_product_as_sparse:gather_rows");

#if ERROR_CHECK

        __err = cudaDeviceSynchronize();

        if (__err != cudaSuccess) {

            fprintf(stderr, "gather_rows: Fatal error: (%s at %s:%d)\n", cudaGetErrorString(__err), __FILE__, __LINE__);

            TORCH_CHECK(false, "error on hybrid");

        }

#endif

        PERF_START("new_product_as_sparse:dense_matmul", stream);

        at::Tensor dense_active = dense.narrow(0, 0, active_rows);
        at::Tensor dense_out = torch::matmul(dense_active, b_c.transpose(0, 1));

        PERF_STOP("new_product_as_sparse:dense_matmul");

        TORCH_CHECK(out->_tail_dense.is_contiguous(), "out_tail_dense must be contiguous");

        TORCH_CHECK(dense_out.is_contiguous(), "dense_out must be contiguous");

        // The kernel only reads the previous tail values, so the old tensor can be
        // used directly as the mask source (no clone needed). The new tensor is
        // zero-filled when only a prefix of its rows will be written.
        at::Tensor tail_dense = out->_tail_dense;
        out->_tail_dense = (tail_dense.size(0) > active_rows) ? torch::zeros_like(tail_dense)
                                                              : torch::empty_like(tail_dense);

        const int threads = 256;
        // Same vector count the kernel uses (rows * N / 8).
        const int64_t total_elems = (int64_t)active_rows * (int64_t)N;
        const int64_t total_vec   = total_elems / 8;

        PERF_START("new_product_as_sparse:masked_add", stream);

        if (total_elems > 0) {
            const int blocks = (int)std::max<int64_t>(1, (total_vec + threads - 1) / threads);

            tail_dense_masked_add_inplace<<<blocks, threads, 0, stream>>>(
                reinterpret_cast<__nv_bfloat16*>(out->_tail_dense.data_ptr()),
                reinterpret_cast<const __nv_bfloat16*>(tail_dense.data_ptr()),
                reinterpret_cast<const __nv_bfloat16*>(dense_out.data_ptr()),
                active_rows,
                N,
                init_ptr
            );
        }

        PERF_STOP("new_product_as_sparse:masked_add");

#if ERROR_CHECK

        __err = cudaDeviceSynchronize();

        if (__err != cudaSuccess) {

            fprintf(stderr,

                    "masked_add_inplace: Fatal error: (%s at %s:%d)\n",

                    cudaGetErrorString(__err), __FILE__, __LINE__);

            TORCH_CHECK(false, "error in new_product_as_sparse_sma");

        }

#endif

    }

    PERF_STOP("new_product_as_sparse_total");

}





__device__ __forceinline__ bool bf16_is_pos_nonzero(uint16_t u)
{
    // Classifies a raw bf16 bit pattern: true when the sign bit (bit 15) is
    // clear and at least one exponent/mantissa bit is set. +0.0 and every
    // negative pattern (including -0.0) yield false.
    const bool sign_clear = !(u & 0x8000u);
    const bool magnitude_nonzero = (u & 0x7fffu) != 0u;
    return sign_clear && magnitude_nonzero;
}







// Packs the entries of a row-major `rows x cols` bf16 matrix A that pass
// bf16_is_pos_nonzero (sign bit clear, non-zero magnitude) into 32-word slabs.
//
// Launch layout: one block per TILE_DIM_X x TILE_DIM_Y tile,
//   grid  = (ceil(cols / TILE_DIM_Y), ceil(rows / TILE_DIM_X)),
//   block = 2-D; blockDim.x * blockDim.y must be a multiple of 32 because the
//           emit phase uses full-warp __shfl_sync.
// Shared memory: static, CAP 32-bit words plus one int counter.
// A's row stride is assumed to be `cols`; the 16-byte vector load is used only
// when the address is 16-byte aligned.
//
// Each slab written to `out` is 32 consecutive int32 words:
//   word 0      : (tile_y << 16) | tile_x
//   words 1..31 : bf16 bits | (row_in_tile << 16) | (col_in_tile << 24),
//                 or 0 as padding after the tile's last element.
// `g_counter` is a global word counter; each warp reserves 32 words per slab.
// A slab whose reservation would end past `out_size` is not written and that
// warp stops emitting (the counter has still advanced, so the host can detect
// overflow as *g_counter > out_size).
// Elements beyond the per-tile staging capacity CAP are silently dropped.
__global__ void pack_matrix_vec8_atomic128b(const __nv_bfloat16* __restrict__ A,
                                            int rows, int cols, int64_t out_size,
                                            int32_t* __restrict__ out,
                                            int32_t* __restrict__ g_counter)
{
    constexpr int kSlabWords   = 32;              // one warp writes one slab
    constexpr int kSlabPayload = kSlabWords - 1;  // word 0 is the tile header

    const int tile_x = blockIdx.x;
    const int tile_y = blockIdx.y;

    const int r0 = tile_y * TILE_DIM_X;
    const int c0 = tile_x * TILE_DIM_Y;

    const int th = max(0, min(TILE_DIM_X, rows - r0));
    const int tw = max(0, min(TILE_DIM_Y, cols - c0));
    // Depends only on blockIdx, so the whole block leaves together and the
    // __syncthreads() calls below remain uniform.
    if (th == 0 || tw == 0) return;

    const int tid_linear = threadIdx.x + blockDim.x * threadIdx.y;
    const int lane_id    = tid_linear & 31;
    const int w_id       = tid_linear >> 5;
    const int num_warps  = (blockDim.x * blockDim.y + 31) >> 5;

    // Per-tile staging buffer and its fill counter.
    __shared__ int idx;
    constexpr int CAP = TILE_DIM_X * (TILE_DIM_Y - 33);
    __shared__ uint32_t data[CAP];

    if (tid_linear == 0) idx = 0;
    __syncthreads();

    // Phase 1: each thread examines 8 consecutive columns of one row at a time.
    for (int tr = threadIdx.y; tr < th; tr += blockDim.y) {
        const int irow = r0 + tr;
        // 64-bit offset: irow * cols can exceed INT_MAX for large matrices.
        const __nv_bfloat16* row = A + (size_t)irow * (size_t)cols;

        for (int tc8 = threadIdx.x * 8; tc8 < tw; tc8 += blockDim.x * 8) {
            const int j = c0 + tc8;

            uint16_t b[8] = {0, 0, 0, 0, 0, 0, 0, 0};
            bool in[8];
#pragma unroll
            for (int e = 0; e < 8; ++e) in[e] = (tc8 + e) < tw;

            // Fast path: all 8 columns in range and 16-byte aligned -> one 128-bit load.
            const bool full    = in[0] && in[7];
            const bool aligned = (reinterpret_cast<uintptr_t>(row + j) & 15u) == 0u;
            if (full && aligned) {
                const uint4 v = *reinterpret_cast<const uint4*>(row + j);
                const uint32_t w[4] = {v.x, v.y, v.z, v.w};
#pragma unroll
                for (int e = 0; e < 8; ++e) b[e] = (uint16_t)(w[e >> 1] >> ((e & 1) * 16));
            } else {
#pragma unroll
                for (int e = 0; e < 8; ++e) {
                    if (in[e]) b[e] = (uint16_t)__bfloat16_as_ushort(row[j + e]);
                }
            }

            uint32_t m = 0;
#pragma unroll
            for (int e = 0; e < 8; ++e) {
                m |= (uint32_t)(in[e] && bf16_is_pos_nonzero(b[e])) << e;
            }

            const int k = __popc(m);
            if (k) {
                // Reservations are never rolled back, so the reserved ranges tile
                // [0, idx) contiguously and every slot in [0, min(idx, CAP)) is
                // written by exactly one thread. Slots at or beyond CAP are
                // dropped. (Undoing overflowing reservations with atomicSub could
                // leave unwritten holes below the final idx under contention.)
                int slot = atomicAdd(&idx, k);
#pragma unroll
                for (int e = 0; e < 8; ++e) {
                    if (m & (1u << e)) {
                        if (slot < CAP) {
                            const uint32_t ii = (uint32_t)tr;          // row within tile
                            const uint32_t tt = (uint32_t)(tc8 + e);   // column within tile
                            data[slot] = (uint32_t)b[e] | (ii << 16) | (tt << 24);
                        }
                        ++slot;
                    }
                }
            }
        }
    }

    __syncthreads();

    // Phase 2: each warp emits slabs of up to kSlabPayload staged elements.
    const uint32_t n     = (uint32_t)min(idx, CAP);
    const uint32_t slabs = (n + (uint32_t)kSlabPayload - 1u) / (uint32_t)kSlabPayload;

    for (uint32_t s = (uint32_t)w_id; s < slabs; s += (uint32_t)num_warps) {
        int64_t out_base = 0;
        if (lane_id == 0) out_base = (int64_t)atomicAdd(g_counter, kSlabWords);
        // Broadcast first so the bound check below is warp-uniform; every lane
        // must still be present for this full-mask shuffle.
        out_base = __shfl_sync(0xffffffff, out_base, 0);

        // The slab occupies words [out_base, out_base + kSlabWords).
        if (out_base + kSlabWords > out_size) return;

        const int current_elem = (int)(s * (uint32_t)kSlabPayload) + lane_id - 1;
        uint32_t val;
        if (lane_id == 0) {
            val = ((uint32_t)tile_y << 16) | (uint32_t)tile_x;
        } else {
            val = ((uint32_t)current_elem < n) ? data[current_elem] : 0u;
        }
        out[out_base + lane_id] = (int32_t)val;
    }
}



// Packs the sign-bit-clear, non-zero entries of the bf16 matrix `sp`
// (M x N, read from sp.data_ptr() with row stride N) into `out` using
// pack_matrix_vec8_atomic128b: one TILE_DIM_X x TILE_DIM_Y tile per block,
// 32-word slabs.
//   out      : int32 slab buffer with capacity `out_size` words.
//   out_n    : int32 device counter advanced by the kernel (32 per slab).
//              NOTE(review): presumably zeroed by the caller — confirm.
// The launch is asynchronous on `stream`; only launch-configuration errors are
// detected here unless ERROR_CHECK is enabled.
void pack_sparse(at::Tensor out, at::Tensor out_n, at::Tensor sp, int M, int N, int64_t out_size, cudaStream_t stream) {
    // The kernel reinterprets sp's storage as bf16.
    TORCH_CHECK(sp.scalar_type() == at::kBFloat16, "pack_sparse: sp must be a bfloat16 tensor");

    PERF_START("pack_sparse", stream);

    // Must match the kernel's tiling: blockIdx.x steps TILE_DIM_Y columns,
    // blockIdx.y steps TILE_DIM_X rows.
    const int tiles_x = (N + TILE_DIM_Y - 1) / TILE_DIM_Y;
    const int tiles_y = (M + TILE_DIM_X - 1) / TILE_DIM_X;
    // 16 x 32 = 512 threads (16 full warps, as the kernel's shuffles require);
    // each thread reads 8 columns, so one pass of threadIdx.x spans 128 columns.
    const dim3 block(16, 32);
    const dim3 grid(tiles_x, tiles_y);

    // An empty matrix would produce a zero-sized (invalid) grid.
    if (M > 0 && N > 0) {
        pack_matrix_vec8_atomic128b<<<grid, block, 0, stream>>>(
            reinterpret_cast<const __nv_bfloat16*>(sp.data_ptr()), M, N, out_size,
            out.data_ptr<int32_t>(), out_n.data_ptr<int32_t>());
        const cudaError_t launch_err = cudaGetLastError();
        TORCH_CHECK(launch_err == cudaSuccess, "pack_sparse: kernel launch failed: ",
                    cudaGetErrorString(launch_err));
    }

#if ERROR_CHECK

    // Capacity was exceeded only if the word counter ended strictly past
    // out_size; an exact fill is not an overflow.
    if (out_n.item<int>() > out_size) {

        std::cout<<"Created "<<out_n<<" limit "<<out_size<<std::endl;

        TORCH_CHECK(false, "overflow when packing sparse");

    }

    cudaError_t __err = cudaDeviceSynchronize();

    if (__err != cudaSuccess) {

        fprintf(stderr, "Pack-Sparse: Fatal error: (%s at %s:%d)\n", cudaGetErrorString(__err), __FILE__, __LINE__);

        TORCH_CHECK(false, "error on pack_sparse");

    }

#endif

    PERF_STOP("pack_sparse");
}



// Appends one entry (out_row, out_col, val) of AT. The entry takes the next ELL
// slot of out_row, claimed with atomicAdd on AT_row_counts[out_row]. Once that
// row's ELL_WIDTH slots are taken, it is written into AT's dense tail
// (row stride dense_ld) through either the precomputed row map or the
// on-the-fly map maintained by populate_dense.
static __device__ __forceinline__ void transpose_hybrid_emit_at_entry(
    int out_row, int out_col, __nv_bfloat16 val,
    uint16_t*       AT_ell_cols,
    __nv_bfloat16*  AT_ell_vals,
    int*            AT_row_counts,
    int*            AT_overflow_counter,
    __nv_bfloat16*  tail_dense,
    int             dense_ld,
    int*            tail_dense_map,
    int*            tail_dense_map_reverse,
    const int*      precomputed_tail_dense_map,
    bool            use_precomputed_maps)
{
    const int pos = atomicAdd(&AT_row_counts[out_row], 1);
    if (pos < ELL_WIDTH) {
        const size_t addr = (size_t)out_row * ELL_WIDTH + pos;
        AT_ell_cols[addr] = (uint16_t)out_col;
        AT_ell_vals[addr] = val;
    } else if (use_precomputed_maps) {
        populate_dense_known(out_row, out_col, val, tail_dense, dense_ld,
                             precomputed_tail_dense_map);
    } else {
        populate_dense(out_row, out_col, val, tail_dense, dense_ld,
                       tail_dense_map, tail_dense_map_reverse, AT_overflow_counter);
    }
}

// Transposes a hybrid (ELL + dense tail) matrix A of M_rows x N_cols into AT
// (N_cols x M_rows) in the same hybrid layout. Entry A(r, c) becomes AT(c, r).
//
// Launch: 1-D grid and 1-D block of any size. Rows are visited with stride
// gridDim.x and columns with stride blockDim.x, so any grid size is correct.
// NOTE(review): AT_row_counts and, in the non-precomputed mode, tail_dense_map /
// AT_overflow_counter are assumed to be initialised by the caller — confirm.
//
// Phase 1: A rows with 0 < A_row_counts[row] <= ELL_WIDTH, read from A's ELL arrays.
// Phase 2: A rows held in A's dense tail. Dense row d holds A row
//          A_tail_dense_map_reverse[d]; the first *A_tail_dense_rows rows are used,
//          capped at TAIL_CAPACITY_ROWS. Only non-zero bit patterns are transposed.
//
// Strides:
//   A_dense_ld  row stride of A_tail_dense; each row holds N_cols logical columns.
//   dense_ld    row stride of AT's dense tail `tail_dense` (the host passes M_rows).
__global__ void transpose_hybrid_ell_dense(
    const uint16_t*      __restrict__ A_ell_cols,
    const __nv_bfloat16* __restrict__ A_ell_vals,
    const int*           __restrict__ A_row_counts,
    const __nv_bfloat16* __restrict__ A_tail_dense,
    const int*           __restrict__ A_tail_dense_map_reverse,
    int*                 __restrict__ A_tail_dense_rows,
    int                  A_dense_ld,
    int M_rows,
    int N_cols,
    uint16_t*            __restrict__ AT_ell_cols,
    __nv_bfloat16*       __restrict__ AT_ell_vals,
    int*                 __restrict__ AT_row_counts,
    int*                 __restrict__ AT_overflow_counter,
    __nv_bfloat16*       __restrict__ tail_dense,
    int                  dense_ld,
    int*                 __restrict__ tail_dense_map,
    int*                 __restrict__ tail_dense_map_reverse,
    const int*           __restrict__ precomputed_tail_dense_map,
    bool                 use_precomputed_maps
) {
    // Phase 1: ELL-resident rows of A. Rows with more than ELL_WIDTH entries
    // live in A's dense tail and are handled in phase 2.
    for (int row = blockIdx.x; row < M_rows; row += gridDim.x) {
        const int nnz_row = A_row_counts[row];
        if (nnz_row <= 0 || nnz_row > ELL_WIDTH) continue;

        const size_t row_base = (size_t)row * ELL_WIDTH;
        for (int k = threadIdx.x; k < nnz_row; k += blockDim.x) {
            const int out_row = (int)A_ell_cols[row_base + k];
            if ((unsigned)out_row >= (unsigned)N_cols) continue;
            transpose_hybrid_emit_at_entry(out_row, row, A_ell_vals[row_base + k],
                                           AT_ell_cols, AT_ell_vals, AT_row_counts,
                                           AT_overflow_counter, tail_dense, dense_ld,
                                           tail_dense_map, tail_dense_map_reverse,
                                           precomputed_tail_dense_map, use_precomputed_maps);
        }
    }

    // Phase 2: dense-tail rows of A.
    // populate_dense keeps incrementing A's row counter after the tail is full,
    // but only TAIL_CAPACITY_ROWS dense rows and reverse-map entries exist.
    const int a_dense_rows = min(*A_tail_dense_rows, (int)TAIL_CAPACITY_ROWS);
    // A dense row of A spans N_cols columns. This bound must come from A's own
    // stride, not from AT's dense_ld.
    const int src_cols = min(A_dense_ld, N_cols);

    for (int dense_row = blockIdx.x; dense_row < a_dense_rows; dense_row += gridDim.x) {
        const int src_row = A_tail_dense_map_reverse[dense_row];
        // The unsigned compare also rejects unassigned (negative) entries.
        if ((unsigned)src_row >= (unsigned)M_rows) continue;

        const __nv_bfloat16* __restrict__ src = A_tail_dense + (size_t)dense_row * A_dense_ld;

        for (int col0 = threadIdx.x * 8; col0 < src_cols; col0 += blockDim.x * 8) {
            const bool vec_ok = (col0 + 8 <= src_cols) &&
                                ((reinterpret_cast<uintptr_t>(src + col0) & 15u) == 0u);
            if (vec_ok) {
                // One 16-byte load covers 8 bf16 values; skip all-zero groups.
                const int4 raw = *reinterpret_cast<const int4*>(src + col0);
                if ((raw.x | raw.y | raw.z | raw.w) == 0) continue;

                const uint32_t w[4] = {(uint32_t)raw.x, (uint32_t)raw.y,
                                       (uint32_t)raw.z, (uint32_t)raw.w};
                for (int t = 0; t < 8; ++t) {
                    const uint16_t bits = (uint16_t)(w[t >> 1] >> ((t & 1) * 16));
                    if (bits == 0) continue;
                    transpose_hybrid_emit_at_entry(col0 + t, src_row, __ushort_as_bfloat16(bits),
                                                   AT_ell_cols, AT_ell_vals, AT_row_counts,
                                                   AT_overflow_counter, tail_dense, dense_ld,
                                                   tail_dense_map, tail_dense_map_reverse,
                                                   precomputed_tail_dense_map, use_precomputed_maps);
                }
            } else {
                // Scalar path: the ragged end of the row, or a group whose
                // address is not 16-byte aligned.
                const int c_end = min(col0 + 8, src_cols);
                for (int c = col0; c < c_end; ++c) {
                    const __nv_bfloat16 val = src[c];
                    if (__bfloat16_as_ushort(val) == 0) continue;
                    transpose_hybrid_emit_at_entry(c, src_row, val,
                                                   AT_ell_cols, AT_ell_vals, AT_row_counts,
                                                   AT_overflow_counter, tail_dense, dense_ld,
                                                   tail_dense_map, tail_dense_map_reverse,
                                                   precomputed_tail_dense_map, use_precomputed_maps);
                }
            }
        }
    }
}







// Builds AT = transpose(A) for a hybrid (ELL + dense tail) matrix A with
// M_rows rows and N_cols columns. All work is enqueued on `stream`:
//   1. optionally, the precomputed dense-tail row maps are copied into AT:
//      N_cols ints into AT.tail_dense_map() and TAIL_CAPACITY_ROWS ints into
//      AT.tail_dense_map_reverse(), both device-to-device;
//   2. transpose_hybrid_ell_dense scatters A's ELL and dense-tail entries into AT;
//   3. promote_overflow_rows_ell_into_tail_dense runs with one 128-thread
//      block per AT row.
// If precomputed_tail_dense_map is non-null, the kernel routes overflow entries
// through it, and precomputed_tail_dense_map_reverse must also be non-null.
// NOTE(review): AT's counters/maps are assumed to be initialised by the caller.
void transpose_hybrid_dense(
    const hybrid_sp_t& A,
    hybrid_sp_t&       AT,
    int M_rows,
    int N_cols,
    cudaStream_t stream,
    const int* precomputed_tail_dense_map,
    const int* precomputed_tail_dense_map_reverse)
{
    const bool use_precomputed = (precomputed_tail_dense_map != nullptr);
    TORCH_CHECK(!use_precomputed || precomputed_tail_dense_map_reverse != nullptr,
                "transpose_hybrid_dense: precomputed_tail_dense_map_reverse is required "
                "when precomputed_tail_dense_map is given");

    PERF_START("transpose_hybrid_dense_total", stream);

    if (use_precomputed) {
        cudaError_t copy_err = cudaMemcpyAsync(AT.tail_dense_map(),
                                               precomputed_tail_dense_map,
                                               (size_t)N_cols * sizeof(int),
                                               cudaMemcpyDeviceToDevice,
                                               stream);
        TORCH_CHECK(copy_err == cudaSuccess,
                    "transpose_hybrid_dense: copying tail_dense_map failed: ",
                    cudaGetErrorString(copy_err));

        copy_err = cudaMemcpyAsync(AT.tail_dense_map_reverse(),
                                   precomputed_tail_dense_map_reverse,
                                   TAIL_CAPACITY_ROWS * sizeof(int),
                                   cudaMemcpyDeviceToDevice,
                                   stream);
        TORCH_CHECK(copy_err == cudaSuccess,
                    "transpose_hybrid_dense: copying tail_dense_map_reverse failed: ",
                    cudaGetErrorString(copy_err));
    }

    // Zero-sized grids are invalid launch configurations, and an empty matrix
    // has nothing to transpose.
    if (M_rows > 0 && N_cols > 0) {
        const int device = (int)A._ell_col_indices.get_device();
        int num_sms = 0;
        const cudaError_t attr_err =
            cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device);
        TORCH_CHECK(attr_err == cudaSuccess,
                    "transpose_hybrid_dense: cudaDeviceGetAttribute failed: ",
                    cudaGetErrorString(attr_err));

        const dim3 block(128);
        // The kernel loops over rows with stride gridDim.x, so capping the
        // grid at a few blocks per SM is safe.
        const dim3 grid(min(M_rows, num_sms * 4));

        PERF_START("transpose_hybrid_dense:transpose_kernel", stream);
        transpose_hybrid_ell_dense<<<grid, block, 0, stream>>>(
            A.ell_col_indices(),
            A.ell_values(),
            A.row_counters(),
            A.tail_dense(),
            A.tail_dense_map_reverse(),
            A.overflow_counter(),
            N_cols,                       // A_dense_ld: row stride of A's dense tail
            M_rows,
            N_cols,
            AT.ell_col_indices(),
            AT.ell_values(),
            AT.row_counters(),
            AT.overflow_counter(),
            AT.tail_dense(),
            M_rows,                       // dense_ld: row stride of AT's dense tail
            AT.tail_dense_map(),
            AT.tail_dense_map_reverse(),
            precomputed_tail_dense_map,   // nullptr when not using precomputed maps
            use_precomputed
        );
        cudaError_t err = cudaGetLastError();
        PERF_STOP("transpose_hybrid_dense:transpose_kernel");
        TORCH_CHECK(err == cudaSuccess, "transpose_hybrid_ell_dense launch failed: ",
                    cudaGetErrorString(err));

#if ERROR_CHECK
        err = cudaDeviceSynchronize();
        if (err != cudaSuccess) {
            fprintf(stderr, "Transpose: Fatal error: (%s at %s:%d)\n", cudaGetErrorString(err), __FILE__, __LINE__);
            TORCH_CHECK(false, "error on transpose");
        }
#endif

        PERF_START("transpose_hybrid_dense:promote_overflow", stream);
        promote_overflow_rows_ell_into_tail_dense<<<N_cols, 128, 0, stream>>>(
            AT.ell_col_indices(),
            AT.ell_values(),
            AT.row_counters(),
            N_cols,
            AT.tail_dense(),
            M_rows,
            AT.tail_dense_map()
        );
        err = cudaGetLastError();
        PERF_STOP("transpose_hybrid_dense:promote_overflow");
        TORCH_CHECK(err == cudaSuccess, "promote_overflow_rows_ell_into_tail_dense launch failed: ",
                    cudaGetErrorString(err));

#if ERROR_CHECK
        err = cudaDeviceSynchronize();
        if (err != cudaSuccess) {
            fprintf(stderr, "Transpose Promote 1: Fatal error: (%s at %s:%d)\n", cudaGetErrorString(err), __FILE__, __LINE__);
            TORCH_CHECK(false, "error on transpose");
        }
#endif
    }

    PERF_STOP("transpose_hybrid_dense_total");
}





// Count of dense-tail scatters performed by sparse_dense_gemm_hybrid_dense.
// It is only incremented and printed when ERROR_CHECK is enabled.
int op_id = 0;

// out = A * B for a hybrid (ELL + dense tail) sparse A with M rows and a bf16 B
// viewed as K x N.
//   1. ell_spmm_rowmajor_b_rowwise_optimized handles A's ELL part
//      (one 256-thread block per row of A) and writes into `out` (bf16).
//   2. Only when A->_dense_active_rows > 0: the first _dense_active_rows rows of
//      A->_tail_dense are multiplied by B with torch::matmul, and
//      scatter_tail_dense_rows_into_full_old places the result into `out`
//      using A->tail_dense_map().
// When B occupies less than 4 MiB, an L2 persisting window is requested for it
// around the ELL kernel and reset afterwards.
// `transpose_dense_part` and `B_fp32_cache` are currently unused.
void sparse_dense_gemm_hybrid_dense(at::Tensor& out, hybrid_sp_t* A, const at::Tensor& B, int M, int N, int K, bool transpose_dense_part, cudaStream_t stream, const at::Tensor& B_fp32_cache) {
    PERF_START("sparse_dense_gemm_total", stream);

    const size_t B_size = B.numel() * B.element_size();
    const bool persist_B = B_size < 4 * 1024 * 1024;
    if (persist_B) {
        set_l2_persist_for_matrix(B.data_ptr(), B_size, stream, 1.0f);
    }

    // A zero-sized grid would be an invalid launch configuration.
    if (M > 0) {
        PERF_START("sparse_dense_gemm:ell_spmm", stream);
        ell_spmm_rowmajor_b_rowwise_optimized<<<M, 256, 0, stream>>>(
            A->ell_values(), A->ell_col_indices(), A->row_counters(),
            static_cast<__nv_bfloat16*>(B.data_ptr()),
            static_cast<__nv_bfloat16*>(out.data_ptr()),
            M, K, N
        );
        PERF_STOP("sparse_dense_gemm:ell_spmm");
    }

    if (persist_B) {
        reset_l2_persist(stream);
    }

#if ERROR_CHECK
    cudaError_t __err = cudaDeviceSynchronize();
    if (__err != cudaSuccess) {
        fprintf(stderr, "ELL Prod 1: Fatal error: (%s at %s:%d)\n", cudaGetErrorString(__err), __FILE__, __LINE__);
        TORCH_CHECK(false, "error on hybrid ell");
    }
#endif

    if (A->_dense_active_rows > 0 && M > 0) {
        // START/STOP are kept inside the same branch so they stay balanced
        // when A has no dense-tail rows.
        PERF_START("sparse_dense_gemm:dense_matmul", stream);

        const at::Tensor A_tail_bf16 = A->_tail_dense.narrow(0, 0, A->_dense_active_rows);
        const at::Tensor B_bf16 = B.reshape({K, N});
        // torch::matmul allocates its own (dense_active_rows x N) result.
        // NOTE(review): matmul runs on PyTorch's current CUDA stream while the
        // scatter below runs on `stream`; this assumes they are the same stream.
        const at::Tensor dense = torch::matmul(A_tail_bf16, B_bf16);

        PERF_STOP("sparse_dense_gemm:dense_matmul");

#if ERROR_CHECK
        __err = cudaDeviceSynchronize();
        if (__err != cudaSuccess) {
            int device = B.get_device();
            fprintf(stderr, "OP %d, device %d Error in dense-matmul ell (%d, %d) -> %d\n",op_id, device, M, N, A->_dense_active_rows);
            fprintf(stderr, "DENSEMATMUL ELL: Fatal error: (%s at %s:%d)\n", cudaGetErrorString(__err), __FILE__, __LINE__);
            TORCH_CHECK(false, "error on hybrid dense");
        }
#endif

        PERF_START("sparse_dense_gemm:scatter", stream);
        scatter_tail_dense_rows_into_full_old<<<M, 128, 0, stream>>>(
            static_cast<__nv_bfloat16*>(out.data_ptr()),
            M, N,
            static_cast<__nv_bfloat16*>(dense.data_ptr()),
            A->_dense_active_rows,
            A->tail_dense_map()
        );
        PERF_STOP("sparse_dense_gemm:scatter");

#if ERROR_CHECK
        __err = cudaDeviceSynchronize();
        op_id++;
        if (__err != cudaSuccess) {
            int device = B.get_device();
            fprintf(stderr, "OP %d, device %d Error in scatter ell (%d, %d) -> %d\n",op_id, device, M, N, A->_dense_active_rows);
            fprintf(stderr, "scatter ELL: Fatal error: (%s at %s:%d)\n", cudaGetErrorString(__err), __FILE__, __LINE__);
            TORCH_CHECK(false, "error on hybrid scatter");
        }
#endif
    }

    PERF_STOP("sparse_dense_gemm_total");
}





__forceinline__ __device__ bool int4_all_zero(const int4& v) {
    // True exactly when all 128 bits of v are zero: OR the four 32-bit lanes
    // together and test the combined bit pattern.
    const int any_bits = v.x | v.y | v.z | v.w;
    return !any_bits;
}









// Element-wise product of two hybrid matrices that use A's row layout:
//   out = A .* B + init,  init = *acc_init (0 when acc_init is null),
// computed in bf16 (multiply, then add).
//
// Launch: 1-D blocks whose size is a multiple of 32. There is one warp per row
// (row = blockIdx.x * warps_per_block + warp), so gridDim.x must cover
// ceil(M / warps_per_block).
//
// ELL rows (0 < row_counts[row] <= ELL_WIDTH): slots [0, round_up(nnz, 8))
//   of out_ell are written. Padding slots past nnz in the last 8-group also
//   receive a*b + init.
// Dense-tail rows: row tail_dense_map[row] (row stride N) of A_dense, B_dense
//   and out_dense. Any 8-element group where A or B is all-zero bits is written
//   as 0.
// NOTE(review): the ELL path assumes A and B store the same column in each
//   slot, and the dense path assumes both use the same dense row index (a
//   mismatch is only reported via printf) — confirm with callers.
__global__ void sparse_elementwise_product(
    __nv_bfloat16*       __restrict__ out_ell,
    __nv_bfloat16*       __restrict__ out_dense,
    const __nv_bfloat16* __restrict__ A_ell,
    const __nv_bfloat16* __restrict__ A_dense,
    const __nv_bfloat16* __restrict__ B_ell,
    const __nv_bfloat16* __restrict__ B_dense,
    const int*           __restrict__ row_counts,
    const int*           __restrict__ tail_dense_map,
    const int*           __restrict__ b_tail_dense_map,
    int M,
    int N,
    const float* __restrict__ acc_init
) {
    constexpr int VEC = 8;  // bf16 values per 16-byte vector
    const int lane  = threadIdx.x & 31;
    const int warp  = threadIdx.x >> 5;
    const int warps_per_block = blockDim.x >> 5;

    const int row = (int)blockIdx.x * warps_per_block + warp;
    if (row >= M) return;

    const int nnz_total = row_counts[row];
    if (nnz_total <= 0) return;

    const float init = acc_init ? *acc_init : 0.0f;
    const __nv_bfloat162 init2 = __bfloat162bfloat162(__float2bfloat16(init));

    if (nnz_total <= ELL_WIDTH) {
        const __nv_bfloat16* a_row = A_ell   + (size_t)row * ELL_WIDTH;
        const __nv_bfloat16* b_row = B_ell   + (size_t)row * ELL_WIDTH;
        __nv_bfloat16*       o_row = out_ell + (size_t)row * ELL_WIDTH;

        // Round the occupied slots up to whole vectors (rows are ELL_WIDTH wide).
        int nnz_vec = (nnz_total + (VEC - 1)) & ~(VEC - 1);
        if (nnz_vec > ELL_WIDTH) nnz_vec = ELL_WIDTH;

        for (int base = lane * VEC; base < nnz_vec; base += 32 * VEC) {
            const int4 a_raw = *reinterpret_cast<const int4*>(a_row + base);
            const int4 b_raw = *reinterpret_cast<const int4*>(b_row + base);

            int4 o_raw;
            const __nv_bfloat162* a2 = reinterpret_cast<const __nv_bfloat162*>(&a_raw);
            const __nv_bfloat162* b2 = reinterpret_cast<const __nv_bfloat162*>(&b_raw);
            __nv_bfloat162* o2       = reinterpret_cast<__nv_bfloat162*>(&o_raw);

            #pragma unroll
            for (int t = 0; t < 4; ++t) {
                o2[t] = a2[t] * b2[t] + init2;
            }

            *reinterpret_cast<int4*>(o_row + base) = o_raw;
        }
        return;
    }

    // Row lives in the dense tail.
    const int dr = tail_dense_map[row];
    if (dr < 0) return;
    if (lane == 0 && dr != b_tail_dense_map[row]) printf("Invalid dense maps\n");

    const size_t row_off = (size_t)dr * (size_t)N;
    const __nv_bfloat16* a_row = A_dense   + row_off;
    const __nv_bfloat16* b_row = B_dense   + row_off;
    __nv_bfloat16*       o_row = out_dense + row_off;

    if ((N % VEC) == 0) {
        // Whole vectors only. Every row starts at a multiple of 8 elements, so
        // rows stay 16-byte aligned given an aligned base pointer.
        for (int base = lane * VEC; base < N; base += 32 * VEC) {
            const int4 a_raw = *reinterpret_cast<const int4*>(a_row + base);
            const int4 b_raw = *reinterpret_cast<const int4*>(b_row + base);

            if (int4_all_zero(a_raw) || int4_all_zero(b_raw)) {
                *reinterpret_cast<int4*>(o_row + base) = make_int4(0, 0, 0, 0);
                continue;
            }

            int4 o_raw;
            const __nv_bfloat162* a2 = reinterpret_cast<const __nv_bfloat162*>(&a_raw);
            const __nv_bfloat162* b2 = reinterpret_cast<const __nv_bfloat162*>(&b_raw);
            __nv_bfloat162* o2       = reinterpret_cast<__nv_bfloat162*>(&o_raw);

            #pragma unroll
            for (int t = 0; t < 4; ++t) {
                o2[t] = a2[t] * b2[t] + init2;
            }

            *reinterpret_cast<int4*>(o_row + base) = o_raw;
        }
    } else {
        // N is not a multiple of 8: rows are not 16-byte aligned and the last
        // group is partial, so use scalar accesses with the same per-group rules.
        const __nv_bfloat16 init1 = __float2bfloat16(init);
        for (int base = lane * VEC; base < N; base += 32 * VEC) {
            const int end = min(base + VEC, N);
            bool a_zero = true;
            bool b_zero = true;
            for (int c = base; c < end; ++c) {
                a_zero = a_zero && (__bfloat16_as_ushort(a_row[c]) == 0);
                b_zero = b_zero && (__bfloat16_as_ushort(b_row[c]) == 0);
            }
            for (int c = base; c < end; ++c) {
                o_row[c] = (a_zero || b_zero)
                               ? __ushort_as_bfloat16((unsigned short)0)
                               : __hadd(__hmul(a_row[c], b_row[c]), init1);
            }
        }
    }
}



// Element-wise product out = A .* B of two hybrid matrices, using A's row
// counters and both operands' dense-tail maps (sparse_elementwise_product,
// one warp per row, no additive init).
// out's ELL values and dense tail are first replaced with fresh zero tensors
// (torch::zeros_like), so slots the kernel does not write read as 0.
// NOTE(review): if `out` is the same object as A or B, that operand's values are
// replaced by zeros before the kernel reads them — confirm callers never alias.
void sparse_elementwise(hybrid_sp_t* out, hybrid_sp_t* A, hybrid_sp_t* B, int M, int N, cudaStream_t stream) {
    PERF_START("sparse_elementwise", stream);

    constexpr int THREADS = 128;
    constexpr int WARPS_PER_BLOCK = THREADS / 32;

    out->_ell_values = torch::zeros_like(out->_ell_values);
    out->_tail_dense = torch::zeros_like(out->_tail_dense);

    // A zero-sized grid would be an invalid launch configuration.
    if (M > 0) {
        const dim3 block(THREADS);
        const dim3 grid((int)((M + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK));

        sparse_elementwise_product<<<grid, block, 0, stream>>>(
            out->ell_values(),
            out->tail_dense(),
            A->ell_values(),
            A->tail_dense(),
            B->ell_values(),
            B->tail_dense(),
            A->row_counters(),
            A->tail_dense_map(),
            B->tail_dense_map(),
            M,
            N,
            nullptr
        );
        const cudaError_t launch_err = cudaGetLastError();
        TORCH_CHECK(launch_err == cudaSuccess, "sparse_elementwise: kernel launch failed: ",
                    cudaGetErrorString(launch_err));
    }

#if ERROR_CHECK
    cudaError_t __err = cudaDeviceSynchronize();
    if (__err != cudaSuccess) {
        fprintf(stderr, "Sparse-elemwise-product: Fatal error: (%s at %s:%d)\n", cudaGetErrorString(__err), __FILE__, __LINE__);
        TORCH_CHECK(false, "error on sparse_elementwise");
    }
#endif

    PERF_STOP("sparse_elementwise");
}



// Element-wise product with an additive scalar: out = A .* B + acc_init[0],
// using A's row counters and both operands' dense-tail maps
// (sparse_elementwise_product, one warp per row).
//   acc_init : float32 tensor with at least one element. Its first element is
//              read on the device, so its storage must be device-accessible.
// out's ELL values and dense tail are first replaced with fresh zero tensors
// (torch::zeros_like), so slots the kernel does not write read as 0.
// NOTE(review): if `out` is the same object as A or B, that operand's values are
// replaced by zeros before the kernel reads them — confirm callers never alias.
void sparse_elementwise(hybrid_sp_t* out, hybrid_sp_t* A, hybrid_sp_t* B, int M, int N, at::Tensor& acc_init, cudaStream_t stream) {
    // The kernel dereferences acc_init as a single float.
    TORCH_CHECK(acc_init.scalar_type() == at::kFloat,
                "sparse_elementwise: acc_init must be a float32 tensor");
    TORCH_CHECK(acc_init.numel() >= 1, "sparse_elementwise: acc_init must not be empty");

    PERF_START("sparse_elementwise_with_acc", stream);

    constexpr int THREADS = 128;
    constexpr int WARPS_PER_BLOCK = THREADS / 32;

    out->_ell_values = torch::zeros_like(out->_ell_values);
    out->_tail_dense = torch::zeros_like(out->_tail_dense);

    // A zero-sized grid would be an invalid launch configuration.
    if (M > 0) {
        const dim3 block(THREADS);
        const dim3 grid((int)((M + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK));

        sparse_elementwise_product<<<grid, block, 0, stream>>>(
            out->ell_values(),
            out->tail_dense(),
            A->ell_values(),
            A->tail_dense(),
            B->ell_values(),
            B->tail_dense(),
            A->row_counters(),
            A->tail_dense_map(),
            B->tail_dense_map(),
            M,
            N,
            reinterpret_cast<const float*>(acc_init.data_ptr())
        );
        const cudaError_t launch_err = cudaGetLastError();
        TORCH_CHECK(launch_err == cudaSuccess, "sparse_elementwise: kernel launch failed: ",
                    cudaGetErrorString(launch_err));
    }

#if ERROR_CHECK
    cudaError_t __err = cudaDeviceSynchronize();
    if (__err != cudaSuccess) {
        fprintf(stderr, "Sparse-elemwise-product: Fatal error: (%s at %s:%d)\n", cudaGetErrorString(__err), __FILE__, __LINE__);
        TORCH_CHECK(false, "error on sparse_elementwise");
    }
#endif

    PERF_STOP("sparse_elementwise_with_acc");
}

