#include "MatMulUtilities.cuh"
#include <vector>

#define DEBUG 0
#define DEBUG2 0
#define DEBUG1 0



/**
 * Loads the compressed form of two consecutive tiles from global memory into
 * per-thread buffers.
 *
 * For tile i in {0, 1} (global index startTileIdx + i) this function
 *   1. copies the 64-bit bitmap into Registers_bmp[i],
 *   2. copies the tile's offset into the value stream into Registers_nnz[i],
 *   3. reports the bitmap's population count (the number of half values the
 *      tile holds) through nnz_tile0 / nnz_tile1, and
 *   4. copies the tile's packed values, one 128-bit vector (8 halves) at a
 *      time, into Registers_nz starting at word i * 32.
 *
 * @param Registers_nz   Destination for the packed values (2 tiles x 32 words).
 * @param Registers_bmp  Destination for the two bitmaps.
 * @param Registers_nnz  Destination for the two value-stream offsets. The
 *                       offsets are expressed in uint32 units and are divided
 *                       by 4 here to index GlobalPTR_nz in uint4 units.
 * @param GlobalPTR_nz   Packed non-zero values, viewed as uint4 vectors.
 * @param GlobalPTR_bmp  One 64-bit bitmap per tile.
 * @param GlobalPTR_nnz  One value-stream offset per tile.
 * @param nnz_tile0      Receives the popcount of the first bitmap.
 * @param nnz_tile1      Receives the popcount of the second bitmap.
 * @param startTileIdx   Index of the first of the two tiles.
 *
 * The template parameters are not used by the body; they are kept so existing
 * call sites continue to compile.
 */
template<typename TilingConfig, typename SparseKernelConfig>
__device__ __forceinline__ void SpMM_CopyFromGlobalToReg(uint32_t        Registers_nz[64],
                                                         uint64_t*       Registers_bmp,
                                                         uint32_t*       Registers_nnz,
                                                         const uint4*    GlobalPTR_nz,
                                                         const uint64_t* GlobalPTR_bmp,
                                                         const uint32_t* GlobalPTR_nnz,
                                                         uint32_t*       nnz_tile0,
                                                         uint32_t*       nnz_tile1,
                                                         int             startTileIdx)
{
    constexpr int HALF_PER_VEC    = 8;                  // one uint4 = 4 words = 8 packed halves
    constexpr int MAX_VEC_PER_BMP = 64 / HALF_PER_VEC;  // a full 64-bit bitmap needs 8 vectors
    constexpr int WORDS_PER_TILE  = 32;                 // 64 halves = 32 uint32 words per tile

#if DEBUG2
    if (blockIdx.x == 0 && blockIdx.y == 383 && threadIdx.x == 127) {  // probe used for the [7168, 7168, 8] case
        printf("------Check inside Reg load...\n");
        printf("StartTileIdx: %d\n", startTileIdx);
        // The bitmap is 64 bits wide, so it needs a 64-bit conversion specifier.
        printf("bmp0: %llu\n", static_cast<unsigned long long>(GlobalPTR_bmp[startTileIdx]));
        printf("nnz0: %u\n", GlobalPTR_nnz[startTileIdx]);
    }
#endif

    // Each thread handles two bitmaps.
#pragma unroll
    for (int i = 0; i < 2; i++) {
        const int globalTileIdx = startTileIdx + i;

        Registers_bmp[i] = GlobalPTR_bmp[globalTileIdx];
        Registers_nnz[i] = GlobalPTR_nnz[globalTileIdx];

        // Number of half values stored for this tile.
        const uint32_t num_nz_per_bitmap = __popcll(Registers_bmp[i]);
        if (i) {
            *nnz_tile1 = num_nz_per_bitmap;
        }
        else {
            *nnz_tile0 = num_nz_per_bitmap;
        }

        // Number of 128-bit vectors that actually contain values: ceil(nnz / 8).
        // Loading `nnz / 8 + 1` vectors instead would fetch one vector beyond
        // the tile's data whenever nnz is a multiple of 8 (including empty
        // tiles); for the last tile that read can run past the end of the
        // value buffer. SpMM_DecompressFromRegisterToShared only consumes the
        // first nnz halves, so nothing beyond ceil(nnz / 8) vectors is needed.
        // NOTE(review): this relies on every tile's values starting on a uint4
        // boundary, which the truncating `/ 4` below already requires —
        // confirm against the compression code.
        const int num_vec = (static_cast<int>(num_nz_per_bitmap) + HALF_PER_VEC - 1) / HALF_PER_VEC;

        // Registers_nnz is in uint32 units; four of them make one uint4.
        const uint4* tile_values = GlobalPTR_nz + Registers_nnz[i] / 4;
        uint32_t*    tile_regs   = Registers_nz + i * WORDS_PER_TILE;

        // Fixed trip count so the loop can be fully unrolled and the
        // destination indices stay compile-time constants.
#pragma unroll
        for (int j = 0; j < MAX_VEC_PER_BMP; j++) {
            if (j < num_vec) {
                // Read the vector once instead of indexing global memory per component.
                const uint4 v        = tile_values[j];
                tile_regs[j * 4 + 0] = v.x;
                tile_regs[j * 4 + 1] = v.y;
                tile_regs[j * 4 + 2] = v.z;
                tile_regs[j * 4 + 3] = v.w;
            }
        }
    }
}

/**
 * Issues the asynchronous copies that clear the A-tile region of shared memory
 * (TILE_M rows of TILE_K halves) starting at SharedPTR.
 *
 * The rows are split evenly across the warps of the block. Within a warp every
 * lane owns one 16-byte chunk per pass, and consecutive passes are
 * WARP_SIZE * HALF_PER_128B halves apart.
 *
 * The copies are only issued here; the caller is responsible for committing
 * and waiting on them.
 * NOTE(review): cp_async_ignore_src<16> with a null source is presumably the
 * zero-fill form of the async copy — confirm in MatMulUtilities.cuh.
 */
template<typename TilingConfig>
__device__ __forceinline__ void SpMM_InitSharedMemory(half* __restrict__ SharedPTR)
{
    // Compile-time layout checks.
    static_assert(TilingConfig::TILE_M % TilingConfig::BLOCK_WARPS == 0,
                  "TILE_M must be an integer multiple to BLOCK_WARPS");
    static_assert(TILE_K == 64, "For now, TILE_K is assumed to be 64.\n");

    constexpr int RowsPerWarp  = TilingConfig::TILE_M / TilingConfig::BLOCK_WARPS;
    constexpr int HalfsPerPass = WARP_SIZE * HALF_PER_128B;  // halves covered by one warp-wide pass
    constexpr int RowsPerPass  = HalfsPerPass / TILE_K;      // rows covered by one warp-wide pass

    static_assert(RowsPerWarp % (WARP_SIZE * HALF_PER_128B / TILE_K) == 0,
                  "RowsPerWarp%(WARP_SIZE*HALF_PER_128B/TILE_K) should be 0\n");
    constexpr int PassesPerThread = RowsPerWarp / RowsPerPass;

    const int lane = threadIdx.x % WARP_SIZE;
    const int warp = threadIdx.x / WARP_SIZE;

    // First chunk owned by this lane inside its warp's slice of rows.
    half* const first_chunk = SharedPTR + (warp * RowsPerWarp) * TILE_K + HALF_PER_128B * lane;

#pragma unroll
    for (int pass = 0; pass < PassesPerThread; pass++) {
        cp_async_ignore_src<16>(first_chunk + pass * HalfsPerPass, static_cast<half*>(nullptr));
    }
}


/**
 * Scatters the packed non-zero values of two tiles from per-thread buffers
 * into their positions in shared memory.
 *
 * For tile t in {0, 1}, the k-th packed value is written to the position of
 * the k-th set bit of Registers_bmp[t], counting bit positions from the most
 * significant bit. A bit at position p is stored at
 *     SharedPTR[TB_ROW * 64 * 64 + TB_COL * 2 + t + p * 64].
 * Positions whose bit is clear are not touched, so the destination must have
 * been cleared beforehand.
 *
 * @param SharedPTR      Base of the destination tile in shared memory.
 * @param Registers_nz   Packed values; tile t starts at word t * 32.
 * @param Registers_bmp  The two bitmaps.
 * @param nnz_tile0      Number of values to scatter for tile 0.
 * @param nnz_tile1      Number of values to scatter for tile 1.
 * @param TB_ROW         Selects the 64x64-element slab inside the destination.
 * @param TB_COL         Selects the pair of adjacent elements inside the slab.
 *
 * NOTE(review): there is no guard for a count larger than the number of set
 * bits in the bitmap; callers pass the bitmap's popcount.
 */
template<typename TilingConfig, typename SparseKernelConfig>
__device__ __forceinline__ void SpMM_DecompressFromRegisterToShared(half* __restrict__ SharedPTR,
                                                                    uint32_t  Registers_nz[64],
                                                                    uint64_t* Registers_bmp,
                                                                    uint32_t* nnz_tile0,
                                                                    uint32_t* nnz_tile1,
                                                                    int       TB_ROW,
                                                                    int       TB_COL)
{
    // First destination element owned by this thread.
    const int thread_base = TB_ROW * 64 * 64 + TB_COL * 2;

#pragma unroll
    for (int tile = 0; tile < 2; tile++) {
        // View this tile's 32 words as 64 half values.
        const half* packed = reinterpret_cast<half*>(Registers_nz + tile * 32);

        uint64_t       remaining   = Registers_bmp[tile];  // bits that still have to be placed
        const int      tile_base   = thread_base + tile;
        const uint32_t value_count = tile ? *nnz_tile1 : *nnz_tile0;

        // Fixed upper bound of 64 keeps the loop unrollable; it stops early
        // once all of this tile's values have been written.
#pragma unroll
        for (int k = 0; k < 64 && k != value_count; k++) {
            // Distance of the next set bit from the MSB side.
            const int lead = __clzll(remaining);
            // Clear that bit so the next iteration finds the following one.
            remaining &= ~(0x8000000000000000 >> lead);

            SharedPTR[tile_base + (lead << 6)] = packed[k];
        }
    }
}




/**
 * Batched SpMM kernel: multiplies a bitmap-compressed sparse matrix A by a
 * dense matrix B and writes half-precision results to Reduction_Workspace.
 *
 * Work decomposition, as derived from the index arithmetic below:
 *   - blockIdx.z selects the batch entry (mustafar_batch_id).
 *   - blockIdx.y encodes both the split-K slice (BatchID) and the M tile (y).
 *   - blockIdx.x selects the N tile (x).
 *
 * Pipeline: every K step stages one compressed A tile (bitmaps + packed
 * values) through per-thread buffers and scatters it into a cleared shared
 * memory buffer, while the matching B tile is brought in through the
 * asynchronous-copy helpers (cp_async_*). Two shared-memory stage buffers of
 * (TILE_M * TILE_K + TILE_K * TILE_N) halves each are alternated so that the
 * next tile is loaded while PipelinedCoreComputations consumes the current
 * one. After the last tile the accumulators are staged through shared memory
 * (reinterpreted as float) and written out as half.
 *
 * @param A                    Not referenced in the kernel body.
 * @param bmp                  64-bit bitmaps of A; M_Global * K_Global / 64 entries per batch.
 * @param NZ                   Packed non-zero values of A as uint4 vectors.
 * @param idx                  Per-bitmap offsets into the value stream; 1 + M_Global * K_Global / 64 entries per batch.
 * @param NZ_offset            Per-batch start position inside NZ (added to the uint4 pointer, i.e. in uint4 units).
 * @param B                    Dense operand; K_Global * N_Global halves per batch.
 * @param Reduction_Workspace  Output buffer.
 * @param M_Global             Problem size M.
 * @param N_Global             Problem size N.
 * @param K_Global             Problem size K.
 * @param Split_K              Number of slices the K dimension is split into.
 * @param Batch_Size           Not referenced in the kernel body.
 *
 * NOTE(review): the indexing assumes M_Global is a multiple of TILE_M and
 * K_Global a multiple of TILE_K; neither is checked here. The thread-index
 * arithmetic (tid / 32 combined with the factor 4 in BaseTileIdx) presumably
 * expects 128 threads per block — confirm against the launch code.
 */
template<typename TilingConfig, typename SparseKernelConfig>
__global__ void 
// A __maxnreg__(255) register cap was tried here and is currently disabled.
SpMM_Kernel(const half*  A,
                            const uint64_t* bmp, 
                            const uint4* NZ, 
                            const uint32_t* idx,
                            const uint32_t* NZ_offset,
                            const half*  B,
                            half*        Reduction_Workspace,
                            const int    M_Global,
                            const int    N_Global,
                            const int    K_Global,
                            int          Split_K,
                            const int    Batch_Size)
{    
    #if DEBUG
                
        if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0){
            // The trailing numbers are the values the original author observed for one configuration.
            printf("-------SpMM_Kernel() debugging...\n");
            printf("M_Global: %d\n", M_Global);
            printf("N_Global: %d\n", N_Global);
            printf("K_Global: %d\n", K_Global);
            printf("TILE_M: %d\n", TilingConfig::TILE_M); //256
            printf("TILE_N: %d\n", TilingConfig::TILE_N); //16
            printf("TILE_K: %d\n", TILE_K); //defined as (MMA_K * BLOCK_K_TENSORS) in TilingConfig.h //64
            printf("Register_For_SparseTiles: %d\n", SparseKernelConfig::NUM_REG_FOR_SPARSE_KERNEL); //64
            printf("BLOCK_ROW_WARPS: %d\n", TilingConfig::BLOCK_ROW_WARPS); //4
            printf("BLOCK_COL_WARPS: %d\n", TilingConfig::BLOCK_COL_WARPS); //1
            printf("Block_WARPS: %d\n", TilingConfig::BLOCK_WARPS); //4
            printf("Vector_Size: %d\n", SparseKernelConfig::VECTOR_SIZE); //4
            printf("Split_K: %d\n", Split_K); //follows input

            printf("PipelinedCoreComputations debugging...\n");
            // WARP_COL_TENSORS is not printed here; per the original author it is 1 in this configuration, not 4.
            printf("WARP_ROW_TENSOR is already set as 4\n");
            printf("BLOCK_K_TENSOR is already set as 4\n");
            printf("N8 for double mmaM16N8K16: %d\n", TilingConfig::N8); //1
            printf("Tiling config N2: %d\n", TilingConfig::TILE_N2); //8
            printf("REG_PER_C_TENSOR_16_16: %d\n", REG_PER_C_TENSOR_16_16); //8
        }
    #endif

    // Batch entry handled by this block, taken from the grid's Z dimension.
    const int mustafar_batch_id = blockIdx.z;
    // Per-batch views of the compressed A operand.
    // NZ_offset is added to the uint4 pointer, so it is consumed in uint4 units.
    const uint4* NZ_batch = NZ + NZ_offset[mustafar_batch_id];
    // idx carries one extra element per batch compared to bmp.
    const uint32_t* idx_batch = idx + mustafar_batch_id * (1 + M_Global * K_Global / 64);
    const uint64_t* bmp_batch = bmp + mustafar_batch_id * (M_Global * K_Global / 64);

    // Per-batch views of B and of the output buffer.
    // NOTE(review): all of these offset products are evaluated in int; large
    // batch counts or matrix sizes could overflow — confirm the supported range.
    const half* B_batch = B + mustafar_batch_id * K_Global * N_Global;
    half* C_batch = Reduction_Workspace + mustafar_batch_id * M_Global * N_Global;

    #if DEBUG
        if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
            
            printf("Output location: %p for batch %d\n", C_batch, mustafar_batch_id);
            printf("Processing batch %d (blockIdx.z = %d)\n", mustafar_batch_id, blockIdx.z);
            printf("Grid dimensions: %d x %d x %d\n", gridDim.x, gridDim.y, gridDim.z);
            printf("Block dimensions: %d x %d x %d\n", blockDim.x, blockDim.y, blockDim.z);
            // NZ_batch elements are uint4, so each component is printed separately.
            
            printf("NZ_offset[%d]: %d\n", mustafar_batch_id, NZ_offset[mustafar_batch_id]);
            printf("NZ_batch[0] for batch %d: x=%u, y=%u, z=%u, w=%u\n", 
                   mustafar_batch_id, 
                   NZ_batch[0].x, NZ_batch[0].y, NZ_batch[0].z, NZ_batch[0].w);
            printf("NZ_batch[1] for batch %d: x=%u, y=%u, z=%u, w=%u\n", 
                   mustafar_batch_id, 
                   NZ_batch[1].x, NZ_batch[1].y, NZ_batch[1].z, NZ_batch[1].w);
            // Bitmaps are printed in hex.
            printf("bmp_batch[0] for batch %d: 0x%016lx\n", 
                    mustafar_batch_id, 
                    bmp_batch[0]);
            printf("bmp_batch[1] for batch %d: 0x%016lx\n", 
                    mustafar_batch_id, 
                    bmp_batch[1]);
                printf("idx_batch[0] for batch %d: %u\n", 
                   mustafar_batch_id, 
                   idx_batch[0]);
            printf("idx_batch[1] for batch %d: %u\n", 
                   mustafar_batch_id, 
                   idx_batch[1]);
            printf("idx_batch[2] for batch %d: %u\n", 
                   mustafar_batch_id, 
                   idx_batch[2]);
            printf("idx_batch[M_Global*K_Global/64] for batch %d: %u\n", 
                   mustafar_batch_id, 
                   idx_batch[M_Global*K_Global/64]);
            
        }
    #endif

    // Decode blockIdx.y into (split-K slice, M tile). M_Global / TILE_M is the
    // number of tile rows along M; BatchID is the split-K slice index and is
    // distinct from mustafar_batch_id above.
    const int BatchID     = blockIdx.y / (M_Global / TilingConfig::TILE_M);
    const int IsLastBatch = (BatchID == (Split_K - 1));
    const int x           = blockIdx.x; // N tile index (original author note: grid X is 1 for skinny matrices)
    const int y           = blockIdx.y % (M_Global / TilingConfig::TILE_M);  // M tile index: blockIdx.y wrapped around the number of M tile rows
        // i.e. blocks 0, (num M tile rows), 2 * (num M tile rows), ... all handle the first M tile row.
    //
    // Split the K tiles evenly across the Split_K slices; the last slice
    // absorbs the rounding and therefore runs fewer iterations.
    const int NumKBlock        = K_Global / TILE_K;  // assumes K_Global % TILE_K == 0 (not checked)
    const int AverageNumKBlock = (NumKBlock - 1) / Split_K + 1;
    const int RoundedKBlock    = AverageNumKBlock * Split_K;
    const int PaddingKBlock    = RoundedKBlock - NumKBlock;
    int       NumIter          = 0;
    if (IsLastBatch)
        NumIter = AverageNumKBlock - PaddingKBlock;
    else
        NumIter = AverageNumKBlock;
    #if DEBUG
        if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { // values below were observed for the [7168, 7168, 8] case
            printf("------K dimension related Debugging info...\n");
            printf("NumKBlock: %d\n", NumKBlock); //112: how many iterations it takes to finish computing that output tile
            printf("AverageNumKBlock: %d\n", AverageNumKBlock); //16: 
            printf("RoundedKBlock: %d\n", RoundedKBlock); //112: related to the padding
            printf("PaddingKBlock: %d\n", PaddingKBlock); //0: related to the padding
            printf("NumIter: %d\n", NumIter); //16: thus the final conclusion
        }
    #endif

    // Per-thread staging buffers for the compressed A tiles (intended to be kept in registers).
    uint64_t Registers_bmp[2];  // two 64-bit bitmaps
    uint32_t Registers_nnz[2];  // two offsets into the value stream
    uint32_t Registers_nz[64];  // packed half values for two tiles (32 words each)
    uint32_t nnz_tile0;         // popcount of Registers_bmp[0]
    uint32_t nnz_tile1;         // popcount of Registers_bmp[1]

    extern __shared__ __align__(128) half smem[];  // dynamic shared memory, declared 128-byte aligned

    // Warp and lane identification.
    const unsigned int warpId       = threadIdx.x / WARP_SIZE;
    const int          Tile_Start_M = y * TilingConfig::TILE_M;
    const int          Tile_Start_N = x * TilingConfig::TILE_N;
    /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    // Compute a grid of C matrix tiles in each warp.
    int Warp_i         = warpId / TilingConfig::BLOCK_COL_WARPS;
    int Warp_j         = warpId % TilingConfig::BLOCK_COL_WARPS;
    int warp_start_row = WARP_ROW_TENSORS * MMA_M * Warp_i;
    int warp_start_col = TilingConfig::WARP_COL_TENSORS * MMA_N * Warp_j;
    // Original author note: WARP_COL_TENSORS is 1 for sub-16 N tiles and 2 for 32.
    // Fragment storage handed to PipelinedCoreComputations.
    uint32_t __restrict__ a[WARP_ROW_TENSORS * 2][4];
    uint32_t __restrict__ b[TilingConfig::WARP_COL_TENSORS * 2][4];
    // Start of this block's B tile, taking the split-K slice into account.
    // Original author note: B was supposed to be column-major.
    const half* BTileGlobalPTR = B_batch + Tile_Start_N * K_Global + BatchID * AverageNumKBlock * TILE_K;

    // Index of the first bitmap of this block's A tile. Bitmaps are addressed
    // per 64x1 column tile; the factor 4 presumably equals TILE_M / 64 — TODO confirm.
    // NOTE(review): the slice offset here is BatchID * K_Global / Split_K, while
    // B uses BatchID * AverageNumKBlock * TILE_K; the two agree only when
    // NumKBlock is divisible by Split_K — confirm this is guaranteed.
    int BaseTileIdx = y * (4 * K_Global) + BatchID * K_Global / Split_K;
    // Column-wise bitmap format: each group of 32 threads covers one 64-row
    // slab, and each thread owns two adjacent bitmaps inside it.
    int tid = threadIdx.x;
    int TB_Row = tid / 32;
    int TB_Col = tid % 32;
    int StartTileIdx = BaseTileIdx + TB_Row * K_Global + TB_Col * 2;


    #if DEBUG
        if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0){ //Debugging 256x64 for rectangular sanity
                printf("---Going into SpMM_CopyFromGlobalToReg()...\n \
                For thread %d, blockIdx.x: %d, blockIdx.y: %d, mustafar_batch_id: %d\n \
                StartTileIdx, the access index for bmp and nnz: %d\n", threadIdx.x, blockIdx.x, blockIdx.y, mustafar_batch_id, StartTileIdx);
            }
        __syncthreads(); // only for debugging
    #endif

    // Prologue: stage the first A tile into the per-thread buffers.
    SpMM_CopyFromGlobalToReg<TilingConfig, SparseKernelConfig>(Registers_nz,
                                                                Registers_bmp,
                                                                Registers_nnz,
                                                                NZ_batch,
                                                                bmp_batch,
                                                                idx_batch,
                                                                &nnz_tile0, 
                                                                &nnz_tile1,
                                                                StartTileIdx); 
    #if DEBUG
        if ( threadIdx.x == 32 && blockIdx.x == 0 && blockIdx.y == 0){ //Debugging 256x64 for rectangular sanity
                printf("---Going into SpMM_InitSharedMemory(Line 500)...\n \
                For thread %d, blockIdx.x: %d, blockIdx.y: %d, mustafar_batch_id: %d\n \
                StartTileIdx, the access index for bmp and nnz: %d\n", threadIdx.x, blockIdx.x, blockIdx.y, mustafar_batch_id, StartTileIdx);
            }
        __syncthreads(); // only for debugging
    #endif

    // Async group 1: clear the A region of stage buffer 0.
    SpMM_InitSharedMemory<TilingConfig>(smem);
    cp_async_group_commit();
    // Async group 2: copy the first B tile right behind the A region
    // (CopyTileFromGlobalToShared_X_64 is defined in MatMulUtilities.cuh).
    CopyTileFromGlobalToShared_X_64<TilingConfig::TILE_N2, TilingConfig>( 
        smem + TilingConfig::TILE_M * TILE_K, BTileGlobalPTR, K_Global);
    cp_async_group_commit();
    
    // Initializing the C accumulators to zero.
    float c[WARP_ROW_TENSORS * TilingConfig::WARP_COL_TENSORS][REG_PER_C_TENSOR_16_16];
    for (int i = 0; i < WARP_ROW_TENSORS * TilingConfig::WARP_COL_TENSORS; i++)
        for (int j = 0; j < REG_PER_C_TENSOR_16_16; j++)
            c[i][j] = 0.0f;
    // Presumably waits until at most one committed group is still pending, i.e.
    // the clear has landed while the B copy may still be in flight; the barrier
    // then makes the cleared buffer visible to the whole block before the scatter.
    cp_async_wait_group<1>();
    __syncthreads();
     #if DEBUG
        if ( threadIdx.x == 127 && blockIdx.x == 0 && blockIdx.y == 0){ //Debugging 256x64 for rectangular sanity
                printf("---Going into SpMM_Decompress from Reg to smem()...\n \
                For thread %d, blockIdx.x: %d, blockIdx.y: %d, mustafar_batch_id: %d\n \
                StartTileIdx, the access index for bmp and nnz: %d\n", threadIdx.x, blockIdx.x, blockIdx.y, mustafar_batch_id, StartTileIdx);
            }
        __syncthreads(); // only for debugging
    #endif
    // Scatter the staged non-zeros of the first A tile into stage buffer 0.
    SpMM_DecompressFromRegisterToShared<TilingConfig, SparseKernelConfig>(
                                                                    smem,
                                                                    Registers_nz,
                                                                    Registers_bmp,
                                                                    &nnz_tile0, 
                                                                    &nnz_tile1,
                                                                    TB_Row, 
                                                                    TB_Col);
    // Wait for the B tile as well, then publish the complete stage to the block.
    cp_async_wait_group<0>();
    __syncthreads();
     #if DEBUG
        if ( (threadIdx.x == 31 | threadIdx.x == 127) && blockIdx.x == 0 && blockIdx.y == 0){ //Debugging 256x64 for rectangular sanity
                printf("---Exit SpMM Decompression...\n \
                For thread %d, blockIdx.x: %d, blockIdx.y: %d, mustafar_batch_id: %d\n \
                StartTileIdx, the access index for bmp and nnz: %d\n", threadIdx.x, blockIdx.x, blockIdx.y, mustafar_batch_id, StartTileIdx);
            }
        __syncthreads(); // only for debugging
    #endif
    // Advance to the next K tile: 32 lanes x 2 bitmaps each cover 64 bitmaps
    // per slab, so the next tile starts 64 bitmaps further on.
    StartTileIdx +=64;


//
// Go through the global K dimension by a fixed step at a time.
// write buffer[1] first, read buffer[0] first
#pragma unroll(1) // unroll factor 1, i.e. keep this loop rolled
    for (int tile_id_k = 0; tile_id_k < NumIter-1; tile_id_k++) { // the last tile is consumed in the epilogue after the loop
        
        #if DEBUG
            if ( threadIdx.x == 32 && blockIdx.x == 0 && blockIdx.y == 0){ //Debugging 256x64 for rectangular sanity
                    printf("---This prob shouldn't print...\n \
                    For thread %d, blockIdx.x: %d, blockIdx.y: %d\n \
                    StartTileIdx, the access index for bmp and nnz: %d\n", threadIdx.x, blockIdx.x, blockIdx.y, StartTileIdx);
                }
            __syncthreads(); // only for debugging
        #endif
        

        // B tile for the NEXT K step. This declaration shadows the
        // BTileGlobalPTR defined before the loop.
        const half* BTileGlobalPTR = B_batch + Tile_Start_N * K_Global + BatchID * AverageNumKBlock * TILE_K + ((tile_id_k + 1) * TILE_K);

        // Double buffering: the stage being filled and the stage being consumed
        // alternate every iteration. Each stage holds the A tile followed by the B tile.
        half* __restrict__ smem_write_PTR = smem;
        half* __restrict__ smem_read_PTR  = smem;
        smem_write_PTR = smem + ((tile_id_k + 1) % 2) * (TilingConfig::TILE_M * TILE_K + TILE_K * TilingConfig::TILE_N);
        smem_read_PTR  = smem + ((tile_id_k) % 2) * (TilingConfig::TILE_M * TILE_K + TILE_K * TilingConfig::TILE_N);
        // Always true inside this loop, because the loop bound is NumIter - 1.
        bool GlobalCopy = (tile_id_k + 1) < NumIter;

        // Async group 1: clear the A region of the stage being filled.
        SpMM_InitSharedMemory<TilingConfig>(smem_write_PTR);
        cp_async_group_commit();
        #if DEBUG1
            // Probe for M_Global 11008 with Split_K 8. Other blockIdx.y probes used
            // during bring-up: 15 (M 4096, SplitK 1), 127 (M 4096, SplitK 8),
            // 287 (M 9216, SplitK 8), 42 (M 11008, SplitK 1),
            // 255 (M 8192, SplitK 8), 31 (M 8192, SplitK 1).
            if ((threadIdx.x == 0 | threadIdx.x == 127)  && blockIdx.x == 0 && blockIdx.y == 343){
                printf("---Start of the main loop, post- SpMM_Init Smem (L606):  ...\n \
                For thread %d, blockIdx.x: %d, blockIdx.y: %d, mustafar_batch_id: %d\n \
                StartTileIdx, the access index for bmp and nnz: %d\n", threadIdx.x, blockIdx.x, blockIdx.y, mustafar_batch_id, StartTileIdx);
                printf("--------> NumIter(of last blk?):  %d\n", NumIter);
            }
            __syncthreads(); // only for debugging -> a thread-block-level synchronization
        #endif
        // Stage the next A tile into the per-thread buffers.
        SpMM_CopyFromGlobalToReg<TilingConfig, SparseKernelConfig>(Registers_nz,
                                                                Registers_bmp,
                                                                Registers_nnz,
                                                                NZ_batch,
                                                                bmp_batch,
                                                                idx_batch,
                                                                &nnz_tile0,
                                                                &nnz_tile1, 
                                                                StartTileIdx); 

        // Async group 2: copy the next B tile behind the A region of the write stage.
        CopyTileFromGlobalToShared_X_64<TilingConfig::TILE_N2, TilingConfig>(
            smem_write_PTR + TilingConfig::TILE_M * TILE_K, BTileGlobalPTR, K_Global, GlobalCopy);
        cp_async_group_commit();
        #if DEBUG
            if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0){
                printf("--- Start of the main loop, post- CopyTileFromGlobalToShared_X_64(L.630)...\n \
                For thread %d, blockIdx.x: %d, blockIdx.y: %d, mustafar_batch_id: %d\n \
                StartTileIdx, the access index for bmp and nnz: %d\n", threadIdx.x, blockIdx.x, blockIdx.y, mustafar_batch_id, StartTileIdx);
            }
            __syncthreads(); // only for debugging
        #endif


        // Compute on the stage that was completed in the previous step.
        PipelinedCoreComputations<TilingConfig>(c, a, b, smem_read_PTR, warp_start_row, warp_start_col);
        //

        // Original author note: ensures the clear of the write stage has
        // completed; the async copy of the B tile may not have finished yet.
        cp_async_wait_group<1>();
        __syncthreads();
        // Scatter the staged non-zeros into the (now cleared) write stage.
        SpMM_DecompressFromRegisterToShared<TilingConfig, SparseKernelConfig>(
                                                                    smem_write_PTR,
                                                                    Registers_nz,
                                                                    Registers_bmp,
                                                                    &nnz_tile0,
                                                                    &nnz_tile1,
                                                                    TB_Row, 
                                                                    TB_Col);
            
        cp_async_wait_group<0>();  // ensure the B tile has finished loading into shared memory
        __syncthreads();
        // Advance to the next K tile (64 bitmaps further, as in the prologue).
        StartTileIdx += 64;

    }
    
    // Main loop finished; the last staged tile is still unconsumed.
    #if DEBUG
        if ( (threadIdx.x == 31 | threadIdx.x == 127) && blockIdx.x == 0 && blockIdx.y == 0){ //Debugging 256x64 for rectangular sanity
                printf("---Exit Main Loop, entering one Compute Pipeline...\n \
                For thread %d, blockIdx.x: %d, blockIdx.y: %d, mustafar_batch_id: %d\n \
                StartTileIdx, the access index for bmp and nnz: %d\n", threadIdx.x, blockIdx.x, blockIdx.y, mustafar_batch_id, StartTileIdx);
            }
        __syncthreads(); // only for debugging
    #endif
    
    // Epilogue: consume the final tile, which sits in stage (NumIter - 1) % 2.
    half* __restrict__ smem_read_PTR  = smem;
    smem_read_PTR  = smem + ((NumIter-1) % 2) * (TilingConfig::TILE_M * TILE_K + TILE_K * TilingConfig::TILE_N);
    PipelinedCoreComputations<TilingConfig>(c, a, b, smem_read_PTR, warp_start_row, warp_start_col);
    // All warps must be done reading the stage buffers before shared memory is reused for C below.
    __syncthreads();
    // End of epilogue.
    /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    #if DEBUG
        if ( (threadIdx.x == 31 | threadIdx.x == 127) && blockIdx.x == 0 && blockIdx.y == 0){ //Debugging 256x64 for rectangular sanity
                printf("---Exit Computation, entering StoreToSharedMemoryFromRegister()...\n \
                For thread %d, blockIdx.x: %d, blockIdx.y: %d, mustafar_batch_id: %d\n \
                StartTileIdx, the access index for bmp and nnz: %d\n", threadIdx.x, blockIdx.x, blockIdx.y, mustafar_batch_id, StartTileIdx);
            }
        __syncthreads(); // only for debugging
    #endif
    /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    // Store the C fragments to shared memory. The same shared buffer is
    // reinterpreted as rows of (TILE_M + PADDING_SHARED_MEM_FOR_C) floats.
    float(*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C] =
        reinterpret_cast<float(*)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C]>(smem);
    StoreToSharedMemoryFromRegister<TilingConfig>(smem_CFrag, c);
    __syncthreads();
    // Now that shared memory contains the whole output tile, stream it to global memory.
    // NOTE(review): C_batch already includes mustafar_batch_id * M_Global * N_Global,
    // and BatchID * (M_Global * N_Global) is added on top. With Split_K > 1 and
    // more than one batch entry, (batch b, slice k) and (batch b + k, slice 0)
    // resolve to the same address — verify the workspace layout against the
    // host-side reduction.
    half* BlockGlobalPTR =
        C_batch + BatchID * (M_Global * N_Global) + Tile_Start_M + Tile_Start_N * M_Global;

    #if DEBUG
        if ( (threadIdx.x == 31 | threadIdx.x == 127) && blockIdx.x == 0 && blockIdx.y == 0){ //Debugging 256x64 for rectangular sanity
                printf("---Exit StoreToSharedMemoryFromRegister(), Entering write to global memory...\n \
                For thread %d, blockIdx.x: %d, blockIdx.y: %d, mustafar_batch_id: %d\n \
                StartTileIdx, the access index for bmp and nnz: %d\n", threadIdx.x, blockIdx.x, blockIdx.y, mustafar_batch_id, StartTileIdx);
            }
        __syncthreads(); // only for debugging
    #endif
    
    // Each warp writes whole output columns; the lanes of the warp stride over
    // the rows of a column, and the float accumulators are rounded to half on the way out.
#pragma unroll
    for (int i = warpId; i < TilingConfig::TILE_N2; i += TilingConfig::BLOCK_WARPS)  // i-th column
#pragma unroll
        for (int j = threadIdx.x % WARP_SIZE; j < TilingConfig::TILE_M; j += WARP_SIZE)  // j-th row
            BlockGlobalPTR[j + i * M_Global] = __float2half_rn((*(smem_CFrag + i))[j]);
}

