/***************************************************************************
 * Copyright 2023 The FLash-LLM Authors. All rights reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ***************************************************************************/

#include "./MatMulUtilities.cuh"
#include "./Reduction_Kernel.cuh"
#include "./SpMM_Kernel.cuh"
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

// Debug helper: decodes the two fp16 values packed into one 32-bit word and prints them.
//
// packed_value: bits [15:0] hold the first half, bits [31:16] hold the second half.
//
// The 16-bit fields are reinterpreted as fp16 bit patterns. A value cast such as
// `(half)(packed_value & 0xFFFF)` would instead convert the integer numerically
// (0x3C00 would print as 15360.0 rather than 1.0).
void print_packed_halfs(uint32_t packed_value) {
    // Lower 16 bits -> first half.
    const half first_half = __ushort_as_half(static_cast<unsigned short>(packed_value & 0xFFFFu));

    // Upper 16 bits -> second half.
    const half second_half = __ushort_as_half(static_cast<unsigned short>((packed_value >> 16) & 0xFFFFu));

    // Widen to float so printf's %f can display the values.
    printf("First half: %f\n", __half2float(first_half));
    printf("Second half: %f\n", __half2float(second_half));
}

// Host-side launcher for SpMM_Kernel<TilingConfig, SparseKernelConfig>.
//
// Steps performed:
//   1. Split-K is disabled: the incoming Split_K is overwritten with 1, and the kernel is always
//      handed a literal 1.
//   2. The kernel is opted in to SHMEM_SZ bytes of dynamic shared memory. SHMEM_SZ is the larger
//      of (TILE_M * TILE_K + TILE_N * TILE_K) * sizeof(half) * 2 and
//      (TILE_M + PADDING_SHARED_MEM_FOR_C) * TILE_N * sizeof(float).
//   3. The kernel is enqueued on `stream` with
//        grid  = ( max(N_Global / TILE_N, 1), M_Global * Split_K / TILE_M, Batch_Size )
//        block = ( WARP_SIZE * BLOCK_WARPS, 1, 1 )
//
// A, bmp, NZ, idx, NZ_offset, B and Reduction_Workspace are forwarded to the kernel untouched.
// No CUDA status is inspected or returned here; callers query it afterwards.
template<typename TilingConfig, typename SparseKernelConfig>
static void SpMM_SplitK_Kernel_Ex(cudaStream_t    stream,
                                  const half*     A,
                                  const uint64_t* bmp,
                                  const uint4*    NZ,
                                  const uint32_t* idx,
                                  const uint32_t* NZ_offset,
                                  const half*     B,
                                  half*           Reduction_Workspace,
                                  const int       M_Global,
                                  const int       N_Global,
                                  const int       K_Global,
                                  int             Split_K,
                                  const int       Batch_Size)
{
    // Whatever the caller requested, this launcher runs without Split-K.
    Split_K = 1;

    // Dynamic shared memory request; evaluated once per template instantiation.
    static int SHMEM_SZ = max((TilingConfig::TILE_M * TILE_K + TilingConfig::TILE_N * TILE_K) * sizeof(half) * 2,
                              (TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C) * TilingConfig::TILE_N * sizeof(float));

    // Launch geometry. The max(..., 1) keeps at least one block along N when N_Global < TILE_N
    // (e.g. N = 8 with TILE_N = 16).
    const int blocksAlongN = max(N_Global / TilingConfig::TILE_N, 1);
    const int blocksAlongM = M_Global * Split_K / TilingConfig::TILE_M;
    const int threadsPerBlock = WARP_SIZE * TilingConfig::BLOCK_WARPS;

    dim3 launchGrid(blocksAlongN, blocksAlongM, Batch_Size);
    dim3 launchBlock(threadsPerBlock, 1, 1);

    // Raise the kernel's dynamic shared memory limit to the size requested at launch.
    cudaFuncSetAttribute(
        SpMM_Kernel<TilingConfig, SparseKernelConfig>, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ);

    SpMM_Kernel<TilingConfig, SparseKernelConfig><<<launchGrid, launchBlock, SHMEM_SZ, stream>>>(
        A, bmp, NZ, idx, NZ_offset,
        B, Reduction_Workspace, M_Global, N_Global, K_Global, 1, Batch_Size);  // Split_K explicitly 1
}

/*
SpMM_SplitK_API: host entry point that launches the bitmap SpMM kernel and, when Split_K != 1,
the SplitK_Reduction kernel on the same stream.

A, bmp, NZ, idx, NZ_offset, B:  forwarded unchanged to SpMM_SplitK_Kernel_Ex
                                (presumably device pointers -- confirm against the kernel).
half* C:                    Final output. Written directly by the SpMM kernel when Split_K == 1,
                            otherwise written by SplitK_Reduction.
half* Reduction_Workspace:  1. Extra device-memory buffer holding the unreduced intermediate output tensors
                            2. Reduction_Workspace_Size = max( Split_K * M_Global * N_Global ) * sizeof(fp16)
int Split_K:                Number of parts the K dimension is split into (callers are expected to pass 1).
int Batch_Size:             Forwarded to the kernel launcher, which uses it as the grid's z extent.

Returns: cudaErrorInvalidValue when N_Global has no kernel instantiation (only N_Global == 8 is
         supported); otherwise the status reported by cudaGetLastError() after the last launch.

NOTE(review): SpMM_SplitK_Kernel_Ex forces Split_K to 1 internally, so with Split_K != 1 only the
              reduction below honours the requested split -- confirm that path is still intended.
NOTE(review): the reduction grid is sized from M_Global * N_Global only and does not take
              Batch_Size into account -- confirm against SplitK_Reduction.
*/
cudaError_t SpMM_SplitK_API(cudaStream_t stream,
                            const half*  A,
                            const uint64_t* bmp, 
                            const uint4* NZ,
                            const uint32_t* idx,
                            const uint32_t* NZ_offset,
                            const half*  B,
                            half*        C,
                            const int    M_Global,
                            const int    N_Global,
                            const int    K_Global,
                            half*        Reduction_Workspace,  // Identical workspace for all SpMM kernel launches
                            int          Split_K, //given that this is always 1. 
                            const int    Batch_Size)
{
#ifdef DEBUG_MODE
    printf("--- SpMM_API.cu/SpMM_SplitK_API(): Entering SpMM_SplitK_API----\n");
    printf(
        "SpMM_API.cu->SpMM_SplitK_API():  M: %d, N: %d, K: %d, SplitK: %d \n", M_Global, N_Global, K_Global, Split_K);
    assert(K_Global % TILE_K == 0);
    assert(M_Global % 256 == 0);
#endif
    // Without Split-K the kernel writes straight into C; otherwise it writes partial results
    // into the workspace for the reduction kernel to combine.
    half* SpMM_SplitK_OutputPTR = (Split_K == 1) ? C : Reduction_Workspace;

    // Batched SpMM: dispatch on N to the matching tiling configuration.
    switch (N_Global) {

        case 8:
            SpMM_SplitK_Kernel_Ex<TilingConfig<4, 1, 1, 1>, SparseKernelConfig<64>>(
                stream, A, bmp, NZ, idx, NZ_offset,
                B, SpMM_SplitK_OutputPTR, M_Global, N_Global, K_Global, Split_K, Batch_Size);
            break;

        default:
            // No kernel is instantiated for this N. Report it instead of returning cudaSuccess
            // with the output buffer never written.
            return cudaErrorInvalidValue;
    }

    cudaError_t Error = cudaGetLastError();
    if (Error != cudaSuccess)
        return Error;

    // Nothing to reduce when the K dimension was not split.
    if (Split_K == 1)
        return Error;

    dim3 GridDim((M_Global * N_Global) / 256, 1, 1);
    dim3 BlockDim(WARP_SIZE, 1, 1);
    SplitK_Reduction<<<GridDim, BlockDim, 0, stream>>>(C, Reduction_Workspace, M_Global, N_Global, Split_K);
    return cudaGetLastError();
}


// Builds a synthetic bitmap-compressed sparse matrix for a fixed 7168 x 7168 problem.
//
// Three host arrays are malloc'ed and handed back through the output pointers (the caller owns
// them and releases them with free()):
//   *idx : (7168 * 7168 / 64) uint32_t entries, entry i set to i * 32
//   *bmp : (7168 * 7168 / 64) uint64_t entries, every entry 0xAAAAAAAAAAAAAAAA (alternating bits)
//   *NZ  : (7168 * 7168 / 4)  uint32_t entries, each packing two fp16 values:
//          0.01 in the low 16 bits and 0.03 in the high 16 bits
//
// On allocation failure an error is printed to stderr and the process exits, matching
// InitSparseMatrixA_API_bmp_real. Any value previously stored in the output pointers is
// overwritten without being freed.
//
// Returns 0 always.
__host__ int InitSparseMatrixA_API_bmp_colwise(
                                             uint32_t** idx,  // CPU PTR
                                             uint64_t** bmp,
                                             uint32_t** NZ)        // CPU_PTR
{
    const size_t matrix_m = 7168;
    const size_t matrix_n = 7168;

    const size_t intArrayLength = matrix_m * matrix_n / 64;
    const size_t uint64ArrayLength = matrix_m * matrix_n / 64;
    const size_t fp16ArrayLength = matrix_m * matrix_n / 4; //50% sparsity /pack two fp16 into a 32b.

    *idx = (uint32_t*)malloc(intArrayLength * sizeof(uint32_t));
    *bmp = (uint64_t*)malloc(uint64ArrayLength * sizeof(uint64_t));
    *NZ = (uint32_t*)malloc(fp16ArrayLength * sizeof(uint32_t));

    // The fill loops below dereference all three buffers, so a failed allocation must not
    // get past this point.
    if (*idx == NULL || *bmp == NULL || *NZ == NULL) {
        fprintf(stderr, "Failed to allocate memory for the synthetic sparse matrix.\n");
        exit(EXIT_FAILURE);
    }

    for (size_t i = 0; i < intArrayLength; ++i) {
        (*idx)[i] = static_cast<uint32_t>(i * 32); 
    }

    for (size_t i = 0; i < uint64ArrayLength; ++i) {
        (*bmp)[i] = 0xAAAAAAAAAAAAAAAA;  // 1010 repeated
    }

    // Two fp16 values per 32-bit word: 0.01 in the low half (the first element), 0.03 in the
    // high half. Built in unsigned arithmetic so the shift never touches a sign bit.
    const uint32_t low_bits  = __half_as_ushort(__float2half_rn(0.01f));
    const uint32_t high_bits = __half_as_ushort(__float2half_rn(0.03f));
    const uint32_t fp16_1s   = (high_bits << 16) | low_bits;
    for (size_t i = 0; i < fp16ArrayLength; ++i) {
        (*NZ)[i] = fp16_1s;
    }

    std::cout << "----InitSparseMat_bmp(): Compression Complete----" << std::endl;
    return 0;
}

// Debug helper: prints column 0 of the first (at most) ten rows of a row-major fp16 matrix.
//
// matrix: host-readable array indexed as matrix[row * cols + col]
// rows:   number of rows available; fewer than ten rows prints only those
// cols:   row stride in elements
void print_first_column(half* matrix, int rows, int cols) {
    const int shown_rows = (rows < 10) ? rows : 10;
    for (int r = 0; r < shown_rows; ++r) {
        // The first element of row r sits r * cols elements past the base pointer.
        const half* row_begin = matrix + r * cols;
        printf("Row %d, Col 0: %.4f\n", r, __half2float(*row_begin));
    }
}

// Debug helper: prints the first (at most) ten words of an array of packed fp16 pairs.
//
// packed_array: host-readable array; each uint32_t carries one fp16 value in bits [15:0] and
//               another in bits [31:16]
// num_elements: number of uint32_t words available in packed_array
void print_packed_fp16(uint32_t* packed_array, int num_elements) {
    for (int i = 0; i < 10 && i < num_elements; ++i) {
        const uint32_t packed_value = packed_array[i];

        // Split the word into its two 16-bit fp16 bit patterns.
        const uint16_t lower_bits = static_cast<uint16_t>(packed_value & 0xFFFF);  // Lower 16 bits
        const uint16_t upper_bits = static_cast<uint16_t>(packed_value >> 16);     // Upper 16 bits

        // Reinterpret the raw bits as half values. __ushort_as_half avoids the strict-aliasing
        // violation of casting a uint16_t* to half*.
        const float lower_value = __half2float(__ushort_as_half(lower_bits));
        const float upper_value = __half2float(__ushort_as_half(upper_bits));

        printf("Packed element %d: Lower half = %.4f, Upper half = %.4f\n", i, lower_value, upper_value);
    }
}

// Debug helper: prints the first (at most) ten entries of a uint32_t array.
//
// array:        host-readable array
// num_elements: number of entries available in array
void print_first_10_elements(uint32_t* array, int num_elements) {
    for (int i = 0; i < 10 && i < num_elements; ++i) {
        // The entries are unsigned: %d would show values >= 2^31 as negative numbers.
        printf("Element %d: %u\n", i, static_cast<unsigned int>(array[i]));
    }
}

//The problem here is that I let GPT code this one. Should've looked into this myself. 
//__host__ int InitSparseMatrixA_API_bmp_real(half *matrix, int cols, int rows, 
// Compresses a dense row-major fp16 matrix (indexed as matrix[row * cols + col]) into a
// bitmap + packed-nonzero layout.
//
// Tiling: each tile is 64 consecutive rows of a single column. Tiles are emitted block-row by
// block-row (outer loop over groups of 64 rows, inner loop over columns), so there are
// cols * (rows / 64) tiles. Rows beyond the last full group of 64 are NOT encoded.
//
// Outputs (each slot is malloc'ed here when it holds NULL on entry):
//   *bitmaps         : one uint64_t per tile; bit (63 - i) is set when the i-th row of the tile
//                      is non-zero (the most significant bit is the tile's first row).
//   *packed_nonzeros : the non-zero values of each tile in row order, two fp16 per uint32_t
//                      (first value in the low 16 bits). Every tile is zero-padded to a multiple
//                      of 8 halves (4 words, i.e. one uint4). The buffer is zeroed over
//                      rows * cols words and then shrunk with realloc to the words in use.
//   *num_nonzeros    : tiles + 1 entries; entry t is the word offset of tile t inside
//                      *packed_nonzeros and the last entry is the total (padding included).
//
// A caller-supplied (non-NULL) *packed_nonzeros must come from malloc and hold at least
// rows * cols words, because it is memset and realloc'ed here. Caller-supplied *bitmaps and
// *num_nonzeros must hold at least tiles and tiles + 1 entries respectively.
//
// Returns the total number of packed uint32_t words, padding included. Invalid arguments and
// allocation failures print a message to stderr and terminate the process.
__host__ int InitSparseMatrixA_API_bmp_real(half *matrix, int rows, int cols, 
                     uint64_t **bitmaps, uint32_t **packed_nonzeros, uint32_t **num_nonzeros) {
    if (bitmaps == NULL || packed_nonzeros == NULL || num_nonzeros == NULL || rows < 0 || cols < 0) {
        fprintf(stderr, "Invalid arguments to InitSparseMatrixA_API_bmp_real.\n");
        exit(EXIT_FAILURE);
    }

    const int tile_size = 64;
    const int num_tiles_row = cols;              // Number of tiles per row
    const int num_tiles_col = rows / tile_size;  // Number of tiles per column
    // Sizes and offsets use size_t so large matrices cannot overflow int arithmetic.
    const size_t total_tiles = (size_t)num_tiles_row * (size_t)num_tiles_col;
    printf("num_tiles_per_row : %d\n", num_tiles_row);
    printf("num_tiles_per_col : %d\n", num_tiles_col);

    // Capacity of the packed buffer in words; a tile needs at most 32 words, so this is ample.
    const size_t nnz_capacity = (size_t)rows * (size_t)cols;

    // Allocate any output the caller did not supply. At least one element is requested so that
    // an empty problem is not mistaken for an allocation failure (malloc(0) may return NULL).
    if (*bitmaps == NULL) {
        *bitmaps = (uint64_t*)malloc((total_tiles > 0 ? total_tiles : 1) * sizeof(uint64_t));
        if (*bitmaps == NULL) {
            fprintf(stderr, "Failed to allocate memory for bitmaps.\n");
            exit(EXIT_FAILURE);
        }
    }
    if (*packed_nonzeros == NULL) {
        *packed_nonzeros = (uint32_t*)malloc((nnz_capacity > 0 ? nnz_capacity : 1) * sizeof(uint32_t));
        if (*packed_nonzeros == NULL) {
            fprintf(stderr, "Failed to allocate memory for packed nonzeros.\n");
            exit(EXIT_FAILURE);
        }
    }
    memset(*packed_nonzeros, 0, nnz_capacity * sizeof(uint32_t));
    if (*num_nonzeros == NULL) {
        *num_nonzeros = (uint32_t*)malloc((total_tiles + 1) * sizeof(uint32_t));
        if (*num_nonzeros == NULL) {
            fprintf(stderr, "Failed to allocate memory for num_nonzeros.\n");
            exit(EXIT_FAILURE);
        }
    }

    size_t tile_count = 0;  // Index of the tile being emitted
    size_t nnz_count = 0;   // Words written to *packed_nonzeros so far

    (*num_nonzeros)[0] = 0;

    for (int tile_col = 0; tile_col < num_tiles_col; ++tile_col) {      // groups of 64 rows
        for (int tile_row = 0; tile_row < num_tiles_row; ++tile_row) {  // one tile per column
            uint64_t bitmap = 0;          // Occupancy bitmap of this 64x1 tile
            int non_zero_count = 0;       // Halves emitted for this tile (padding included)
            uint32_t packed_value = 0;    // Word currently being assembled
            bool low_half_filled = false; // True while packed_value holds only its low half

            for (int i = 0; i < tile_size; ++i) {
                // num_tiles_col is rows / 64 rounded down, so row is always < rows and
                // tile_row is always < cols: no bounds check is required here.
                const size_t row = (size_t)tile_col * tile_size + (size_t)i;
                const half value = matrix[row * (size_t)cols + (size_t)tile_row];
                if (__half2float(value) == 0.0f) {
                    continue;
                }

                bitmap |= (1ULL << (63 - i));  // MSB corresponds to the tile's first row
                const uint32_t raw_half_value = __half_as_ushort(value);  // fp16 bit pattern

                if (!low_half_filled) {
                    // First value of a pair goes to the low 16 bits (clears the high bits).
                    packed_value = raw_half_value;
                    low_half_filled = true;
                } else {
                    // Second value completes the word in the high 16 bits.
                    packed_value |= (raw_half_value << 16);
                    (*packed_nonzeros)[nnz_count++] = packed_value;
                    low_half_filled = false;
                }
                non_zero_count++;
            }

            // Odd number of non-zeros: flush the half-filled word; its high half is a zero pad.
            if (low_half_filled) {
                (*packed_nonzeros)[nnz_count++] = packed_value;
                non_zero_count++;
            }

            // Pad the tile to a multiple of 8 halves (our vector size is uint4).
            const int padding_needed = (non_zero_count % 8 == 0) ? 0 : (8 - (non_zero_count % 8));
            for (int pad = 0; pad < padding_needed; pad += 2) {
                (*packed_nonzeros)[nnz_count++] = 0;  // two zero halves per word
            }

            (*bitmaps)[tile_count] = bitmap;
            // Running offset in words; it covers the padding, the bitmap does not.
            (*num_nonzeros)[tile_count + 1] = (*num_nonzeros)[tile_count] + (non_zero_count + padding_needed) / 2;
            tile_count++;
        }
    }

    // Shrink the packed buffer to the words actually used. A zero-sized realloc is skipped
    // (it may free the buffer and return NULL), and a failed shrink keeps the original buffer.
    const size_t used_words = (*num_nonzeros)[tile_count];
    if (used_words > 0) {
        uint32_t* shrunk = (uint32_t*)realloc(*packed_nonzeros, used_words * sizeof(uint32_t));
        if (shrunk != NULL) {
            *packed_nonzeros = shrunk;
        }
    }

    printf("Tile Count within bmp transformation: %d\n", (int)tile_count);
    // num_nonzeros already counts packed uint32_t words, so its last entry is the total.
    return (int)used_words;
}
