#include <cuda.h>
#include <torch/extension.h>
#include <cuda_runtime.h>
#include <iostream>
#include "helper.h"
#include <stdio.h>
#include <vector>
#include "cuda_bf16.h"



// Maps a metadata coordinate (row, col) to its linear offset in the reordered
// metadata buffer. `m` is the number of metadata rows.
//
//   1. Rows are shuffled inside each group of 32: a row moves to
//      group_base + (row % 8) * 4 + (row % 32) / 8.
//   2. Inside each 2x2 tile the two off-diagonal entries trade places
//      (even row / odd col  <->  odd row / even col).
//   3. The result is linearised with column pairs outermost:
//      (col / 2) * 2 * m + row * 2 + (col % 2).
__device__ __forceinline__ int reorder_index(int row, int col, int m){
    const int group_base = (row / 32) * 32;
    int r = group_base + (row % 8) * 4 + (row % 32) / 8;
    int c = col;

    const bool row_is_even = (r % 2) == 0;
    const bool row_is_odd  = (r % 2) == 1;
    const bool col_is_even = (c % 2) == 0;
    const bool col_is_odd  = (c % 2) == 1;

    if (row_is_even && col_is_odd){
        // Upper-right entry of the tile goes to the lower-left slot.
        r += 1;
        c -= 1;
    }else if (row_is_odd && col_is_even){
        // Lower-left entry of the tile goes to the upper-right slot.
        r -= 1;
        c += 1;
    }

    return (c / 2) * 2 * m + r * 2 + (c % 2);
}


// float overload: prunes one row of an (m x k) matrix by keeping one value
// out of every two consecutive values.
//
// The row handled is blockIdx.x; the threads of the block stride over the
// k / 8 metadata words of that row. One 16-bit metadata word covers 8 input
// values, i.e. 4 pairs, with one 4-bit code per pair:
//   4  -> the first value of the pair was kept (first > second)
//   14 -> the second value of the pair was kept (otherwise)
//
// Outputs written per pair:
//   uncompressed_matrix : same shape as the input, dropped value set to 0
//   sparse_matrix       : (m x k/2), only the kept value
// Outputs written per word:
//   metadata            : (m x k/8), row-major
//   metadata_reorder    : same word stored at reorder_index(row, word, m)
__device__ __forceinline__ void Dense2Sparse_(
    int m, int k,
    const float* __restrict__ dense_matrix,
    float* __restrict__ sparse_matrix,
    float* __restrict__ uncompressed_matrix,
    int16_t * __restrict__ metadata,
    int16_t * __restrict__ metadata_reorder)
{
    const int row = blockIdx.x;
    const int words_per_row = k / 8;

    // Row-local views of every buffer.
    const float* src = dense_matrix + row * k;
    float* masked = uncompressed_matrix + row * k;
    float* packed = sparse_matrix + row * (k/2);
    int16_t* meta_row = metadata + row * words_per_row;

    for (int word = threadIdx.x; word < words_per_row; word += blockDim.x){
        int16_t bits = 0;
        for (int pair = 0; pair < 4; ++pair){
            const int offset = 8 * word + 2 * pair;
            const float a = __ldg(src + offset);
            const float b = __ldg(src + offset + 1);

            if (a > b){
                // Keep the first value of the pair.
                bits = bits | (4 << (pair * 4));
                masked[offset] = a;
                masked[offset + 1] = 0.0f;
                packed[4 * word + pair] = a;
            }else{
                // Keep the second value (also taken on ties).
                bits = bits | (14 << (pair * 4));
                masked[offset] = 0.0f;
                masked[offset + 1] = b;
                packed[4 * word + pair] = b;
            }
        }

        // Plain row-major copy of the metadata word ...
        meta_row[word] = bits;
        // ... and the same word at its reordered position.
        metadata_reorder[reorder_index(row, word, m)] = bits;
    }
}


// bfloat16 overload: prunes one row of an (m x k) matrix by keeping two
// values out of every four consecutive values.
//
// The row handled is blockIdx.x; the threads of the block stride over the
// k / 16 metadata words of that row. One 16-bit metadata word covers 16 input
// values, i.e. 4 groups of 4, with one 4-bit code per group.
//
// For each group the pair of positions with the largest sum (computed in
// float) is kept. Candidates are tried in the fixed order below and a later
// candidate only replaces the current best when its sum is strictly greater.
//
// Outputs written per group:
//   uncompressed_matrix : same shape as the input, dropped values set to 0
//   sparse_matrix       : (m x k/2), the two kept values in position order
// Outputs written per word:
//   metadata            : (m x k/16), row-major
//   metadata_reorder    : same word stored at reorder_index(row, word, m)
__device__ void Dense2Sparse_(
    int m, int k,
    const nv_bfloat16* __restrict__ dense_matrix,
    nv_bfloat16* __restrict__ sparse_matrix,
    nv_bfloat16* __restrict__ uncompressed_matrix,
    int16_t * __restrict__ metadata,
    int16_t * __restrict__ metadata_reorder)
{
    // The six ways of keeping two positions out of four, in evaluation order:
    // TTFF, TFTF, TFFT, FTTF, FTFT, FFTT. Candidate c keeps positions
    // keep_a[c] < keep_b[c]; its code equals (keep_b[c] << 2) | keep_a[c].
    const int keep_a[6] = {0, 0, 0, 1, 1, 2};
    const int keep_b[6] = {1, 2, 3, 2, 3, 3};
    const int16_t nibble_code[6] = {4, 8, 12, 9, 13, 14};

    const int row = blockIdx.x;
    const int words_per_row = k / 16;
    const nv_bfloat16 zero = __float2bfloat16(0.0f);

    // Row-local views of every buffer.
    const nv_bfloat16* src = dense_matrix + row * k;
    nv_bfloat16* masked = uncompressed_matrix + row * k;
    nv_bfloat16* packed = sparse_matrix + row * (k/2);
    int16_t* meta_row = metadata + row * words_per_row;

    for (int word = threadIdx.x; word < words_per_row; word += blockDim.x){
        int16_t bits = 0;

        #pragma unroll
        for (int quad = 0; quad < 4; quad++){
            const int offset = 16 * word + 4 * quad;

            // Load the group and widen it to float for the comparisons.
            nv_bfloat16 raw[4];
            float widened[4];
            #pragma unroll
            for (int v = 0; v < 4; v++){
                raw[v] = __ldg(src + offset + v);
                widened[v] = __bfloat162float(raw[v]);
            }

            // Start from TTFF and let every later candidate challenge it.
            int16_t code = nibble_code[0];
            int pos_a = 0;
            int pos_b = 1;
            nv_bfloat16 kept_a = raw[0];
            nv_bfloat16 kept_b = raw[1];
            float best_sum = widened[0] + widened[1];

            #pragma unroll
            for (int c = 1; c < 6; c++){
                const float candidate_sum = widened[keep_a[c]] + widened[keep_b[c]];
                if (candidate_sum > best_sum){
                    best_sum = candidate_sum;
                    code = nibble_code[c];
                    pos_a = keep_a[c];
                    pos_b = keep_b[c];
                    kept_a = raw[keep_a[c]];
                    kept_b = raw[keep_b[c]];
                }
            }

            bits |= code << (quad * 4);

            // Dense-shaped output: kept positions keep their value, the rest is zero.
            #pragma unroll
            for (int v = 0; v < 4; v++){
                masked[offset + v] = (v == pos_a || v == pos_b) ? raw[v] : zero;
            }

            // Compressed output: the two survivors, lower position first.
            packed[8 * word + 2 * quad] = kept_a;
            packed[8 * word + 2 * quad + 1] = kept_b;
        }

        // Plain row-major copy of the metadata word ...
        meta_row[word] = bits;
        // ... and the same word at its reordered position.
        metadata_reorder[reorder_index(row, word, m)] = bits;
    }
}


// Kernel entry point for a single (m x k) matrix.
//
// Pure forwarding wrapper: the Dense2Sparse_ overload matching Scalar does
// all of the work and reads the row index from blockIdx.x.
template <typename Scalar>
__global__ void Dense2Sparse(
    int m, int k,
    const Scalar* __restrict__ dense_matrix,
    Scalar* __restrict__ sparse_matrix,
    Scalar* __restrict__ uncompressed_matrix,
    int16_t * __restrict__ metadata,
    int16_t * __restrict__ metadata_reorder)
{
    Dense2Sparse_(
        m, k,
        dense_matrix,
        sparse_matrix,
        uncompressed_matrix,
        metadata,
        metadata_reorder);
}


// Kernel entry point for a batch of (m x k) matrices.
//
// blockIdx.z selects the batch entry. Every `*_b` pointer is the base of a
// batched buffer and the matching `*_stride` is the number of elements between
// two consecutive entries of that buffer. After offsetting the pointers the
// work is delegated to the Dense2Sparse_ overload matching Scalar.
template <typename Scalar>
__global__ void batchedDense2Sparse(
    int m, int k,
    const Scalar* __restrict__ dense_matrix_b,
    int dense_stride,
    Scalar* __restrict__ sparse_matrix_b,
    int sparse_stride,
    Scalar* __restrict__ uncompressed_matrix_b,
    int uncompressed_stride,
    int16_t* __restrict__ metadata_b,
    int meta_stride,
    int16_t* __restrict__ metadata_reorder_b,
    int meta_reorder_stride)
{
    const int batch = blockIdx.z;

    Dense2Sparse_(
        m, k,
        dense_matrix_b + batch * dense_stride,
        sparse_matrix_b + batch * sparse_stride,
        uncompressed_matrix_b + batch * uncompressed_stride,
        metadata_b + batch * meta_stride,
        metadata_reorder_b + batch * meta_reorder_stride);
}


// Host launcher: prunes a dense matrix (or a batch of matrices) and returns
// {sparse_matrix, uncompressed_matrix, metadata, metadata_reorder}.
//
// Input
//   dense_matrix : CUDA tensor of dtype float32 or bfloat16 with at least two
//                  dimensions. The last two dimensions are (m, k); all leading
//                  dimensions are flattened into one batch dimension.
//
// Output shapes (the leading batch dimension is only present when the
// flattened batch size is greater than 1):
//   sparse_matrix       : [batch,] m, k / 2            (input dtype)
//   uncompressed_matrix : [batch,] m, k                (input dtype)
//   metadata            : [batch,] m, k / meta_ratio   (int16)
//   metadata_reorder    : [batch,] m, k / meta_ratio   (int16)
// where meta_ratio is 16 for bfloat16 and 8 for float32.
//
// Preconditions (checked, raise c10::Error):
//   * m is a multiple of 32 and k / meta_ratio is even. Otherwise the
//     reordered metadata index computed on the device falls outside the
//     metadata_reorder buffer.
//   * the total element count fits in a 32-bit int (the kernels index with int)
//   * the batch size does not exceed the CUDA gridDim.z limit (65535)
std::vector<torch::Tensor> batched_dense2sparse_cuda(
    torch::Tensor dense_matrix)
{
    TORCH_CHECK(dense_matrix.is_cuda(),
                "batched_dense2sparse_cuda: dense_matrix must be a CUDA tensor");
    TORCH_CHECK(dense_matrix.dim() >= 2,
                "batched_dense2sparse_cuda: dense_matrix must have at least 2 dimensions");

    const bool is_bf16 = (dense_matrix.dtype() == torch::kBFloat16);
    TORCH_CHECK(is_bf16 || dense_matrix.dtype() == torch::kFloat32,
                "batched_dense2sparse_cuda: only float32 and bfloat16 inputs are supported");

    // Number of input values described by one 16-bit metadata word.
    const int meta_ratio = is_bf16 ? 16 : 8;

    const int64_t rows = dense_matrix.size(-2);
    const int64_t cols = dense_matrix.size(-1);
    const int64_t total = dense_matrix.numel();
    const int64_t max_int = 2147483647LL;
    const int64_t max_grid_z = 65535;

    TORCH_CHECK(rows > 0 && cols > 0,
                "batched_dense2sparse_cuda: the last two dimensions must be non-empty");
    TORCH_CHECK(total <= max_int,
                "batched_dense2sparse_cuda: tensor is too large for 32-bit indexing");
    // The metadata reordering permutes rows inside groups of 32 and swaps
    // entries between neighbouring metadata columns, so partial groups or an
    // odd metadata column count would be written out of bounds.
    TORCH_CHECK(rows % 32 == 0,
                "batched_dense2sparse_cuda: the row count must be a multiple of 32, got ", rows);
    TORCH_CHECK(cols % (2 * meta_ratio) == 0,
                "batched_dense2sparse_cuda: the column count must be a multiple of ",
                2 * meta_ratio, ", got ", cols);

    // The kernels assume a densely packed row-major layout.
    dense_matrix = dense_matrix.contiguous();

    // Get problem size
    const int m = static_cast<int>(rows);
    const int k = static_cast<int>(cols);
    const int batch_size = static_cast<int>(total / (rows * cols));
    TORCH_CHECK(batch_size <= max_grid_z,
                "batched_dense2sparse_cuda: batch size ", batch_size,
                " exceeds the CUDA grid z-dimension limit");

    // Initiate output matrices
    auto options_val = torch::TensorOptions().dtype(dense_matrix.dtype()).device(dense_matrix.device());
    auto options_meta = torch::TensorOptions().dtype(torch::kInt16).device(dense_matrix.device());

    torch::Tensor sparse_matrix;
    torch::Tensor uncompressed_matrix;
    torch::Tensor metadata;
    torch::Tensor metadata_reorder;

    const bool batched = batch_size > 1;
    if (batched){
        sparse_matrix = torch::empty({batch_size, m, k/2}, options_val);
        uncompressed_matrix = torch::empty({batch_size, m, k}, options_val);
        metadata = torch::empty({batch_size, m, k/meta_ratio}, options_meta);
        metadata_reorder = torch::empty({batch_size, m, k/meta_ratio}, options_meta);
    }
    else{
        sparse_matrix = torch::empty({m, k/2}, options_val);
        uncompressed_matrix = torch::empty({m, k}, options_val);
        metadata = torch::empty({m, k/meta_ratio}, options_meta);
        metadata_reorder = torch::empty({m, k/meta_ratio}, options_meta);
    }

    // Element distance between two consecutive batch entries of each buffer.
    const int dense_stride = m * k;
    const int sparse_stride = m * k / 2;
    const int uncompressed_stride = m * k;
    const int meta_stride = m * k / meta_ratio;
    const int meta_reorder_stride = m * k / meta_ratio;

    // One thread block per row, one grid z-slice per batch entry.
    dim3 grid;
    grid.x = m;
    grid.z = batch_size;

    dim3 block;
    block.x = 128;

    int16_t* metadata_ptr = metadata.data_ptr<int16_t>();
    int16_t* metadata_reorder_ptr = metadata_reorder.data_ptr<int16_t>();

    // NOTE(review): as in the original code the kernels are launched on the
    // default stream; confirm this is intended when callers use other streams.
    if (is_bf16){
        const nv_bfloat16* dense_ptr = static_cast<const nv_bfloat16*>(dense_matrix.data_ptr());
        nv_bfloat16* sparse_ptr = static_cast<nv_bfloat16*>(sparse_matrix.data_ptr());
        nv_bfloat16* uncompressed_ptr = static_cast<nv_bfloat16*>(uncompressed_matrix.data_ptr());

        if (batched){
            batchedDense2Sparse<nv_bfloat16><<<grid, block>>>(
                m, k,
                dense_ptr, dense_stride,
                sparse_ptr, sparse_stride,
                uncompressed_ptr, uncompressed_stride,
                metadata_ptr, meta_stride,
                metadata_reorder_ptr, meta_reorder_stride
            );
        }
        else{
            Dense2Sparse<nv_bfloat16><<<grid, block>>>(
                m, k,
                dense_ptr, sparse_ptr, uncompressed_ptr,
                metadata_ptr, metadata_reorder_ptr);
        }
    }
    else{
        const float* dense_ptr = dense_matrix.data_ptr<float>();
        float* sparse_ptr = sparse_matrix.data_ptr<float>();
        float* uncompressed_ptr = uncompressed_matrix.data_ptr<float>();

        if (batched){
            batchedDense2Sparse<float><<<grid, block>>>(
                m, k,
                dense_ptr, dense_stride,
                sparse_ptr, sparse_stride,
                uncompressed_ptr, uncompressed_stride,
                metadata_ptr, meta_stride,
                metadata_reorder_ptr, meta_reorder_stride
            );
        }
        else{
            Dense2Sparse<float><<<grid, block>>>(
                m, k,
                dense_ptr, sparse_ptr, uncompressed_ptr,
                metadata_ptr, metadata_reorder_ptr);
        }
    }

    // Kernel launches do not report configuration errors by themselves.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "batched_dense2sparse_cuda: kernel launch failed: ",
                cudaGetErrorString(launch_status));

    return {sparse_matrix, uncompressed_matrix, metadata, metadata_reorder};
}


/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Prune the matrix with block sparsity
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////

// Gathers the selected column blocks of one row of an (m x n) matrix.
//
// The row handled is blockIdx.x; the threads of the block stride over the
// `nnz` kept columns of that row. `indices` holds, for every block-row
// (row / block_size_m), nnz / block_size_n column-block ids; id b selects
// source columns [b * block_size_n, (b + 1) * block_size_n).
//
// Each gathered value is written twice:
//   output_data  : (m x nnz), kept columns packed next to each other
//   uncompressed : (m x n), value placed at its original column; positions
//                  that are not selected are left untouched by this function
template <typename Scalar_t>
__device__ void blockEll_kernel_(
    int block_size_m,
    int block_size_n,
    int m, int n, int nnz,
    const Scalar_t* __restrict__ input_data,
    const int* __restrict__ indices,
    Scalar_t* __restrict__ output_data,
    Scalar_t* __restrict__ uncompressed)
{
    const int row = blockIdx.x;

    // Column-block ids that apply to this row.
    const int blocks_per_row = nnz / block_size_n;
    const int* block_ids = indices + (row / block_size_m) * blocks_per_row;

    // Row-local views of the data buffers.
    const Scalar_t* src_row = input_data + row * n;
    Scalar_t* dense_row = uncompressed + row * n;
    Scalar_t* packed_row = output_data + row * nnz;

    for (int packed_col = threadIdx.x; packed_col < nnz; packed_col += blockDim.x){
        const int which_block = packed_col / block_size_n;
        const int within_block = packed_col % block_size_n;

        // Column of this element in the original row.
        const int src_col = block_ids[which_block] * block_size_n + within_block;

        const Scalar_t element = src_row[src_col];
        dense_row[src_col] = element;
        packed_row[packed_col] = element;
    }
}


// Gathers the selected column blocks of a metadata matrix.
//
// No uncompressed copy is produced for the metadata: only the gathered
// output is written and the input is left as is.
//
// Both input_meta and output_meta are addressed as column * m + row, i.e.
// each column is a contiguous run of m entries. The output column handled is
// blockIdx.x; the threads of the block stride over its m rows. `indices`
// holds, for every block-row (row / block_size_m), nnz / block_size_n
// column-block ids; id b selects input columns
// [b * block_size_n, (b + 1) * block_size_n).
__device__ void metaEll_kernel_(
    int block_size_m, 
    int block_size_n,
    int m, int n, int nnz,
    const int16_t* __restrict__ input_meta,
    const int* __restrict__ indices,
    int16_t* __restrict__ output_meta)
{
    const int out_col = blockIdx.x;

    const int blocks_per_row = nnz / block_size_n;
    const int col_in_block = out_col % block_size_n;

    // Entry of the first block-row for the column block this column lives in;
    // later block-rows are blocks_per_row entries apart.
    const int* block_ids = indices + out_col / block_size_n;

    for (int row = threadIdx.x; row < m; row += blockDim.x){
        const int src_block = block_ids[(row / block_size_m) * blocks_per_row];

        // Absolute position of the source element in input_meta.
        const int src = (src_block * block_size_n + col_in_block) * m + row;

        output_meta[out_col * m + row] = input_meta[src];
    }
}


// Kernel entry point for block-ELL gathering of a single (m x n) matrix.
//
// Pure forwarding wrapper around blockEll_kernel_<Scalar_t>, which reads the
// row index from blockIdx.x.
template <typename Scalar_t>
__global__ void blockEll_kernel(
    int block_size_m,
    int block_size_n,
    int m, int n, int nnz,
    const Scalar_t* __restrict__ input_data,
    const int* __restrict__ indices,
    Scalar_t* __restrict__ output_data,
    Scalar_t* __restrict__ uncompressed)
{
    blockEll_kernel_<Scalar_t>(
        block_size_m, block_size_n,
        m, n, nnz,
        input_data,
        indices,
        output_data,
        uncompressed);
}


// Kernel entry point for block-ELL gathering of a batch of (m x n) matrices.
//
// blockIdx.z selects the batch entry. Every `b_*` pointer is the base of a
// batched buffer and the `stride_*` that follows it is the number of elements
// between two consecutive entries. After offsetting the pointers the work is
// delegated to blockEll_kernel_<Scalar_t>.
template <typename Scalar_t>
__global__ void batched_blockEll_kernel(
    int block_size_m, 
    int block_size_n,
    int m, int n, int nnz,
    const Scalar_t* __restrict__ b_input_data, int stride_input_data,
    const int* __restrict__ b_indices, int stride_indices,
    Scalar_t* __restrict__ b_output_data, int stride_output,
    Scalar_t* __restrict__ b_uncompressed, int stride_uncompressed)
{
    const int batch = blockIdx.z;

    blockEll_kernel_<Scalar_t>(
        block_size_m, block_size_n,
        m, n, nnz,
        b_input_data + batch * stride_input_data,
        b_indices + batch * stride_indices,
        b_output_data + batch * stride_output,
        b_uncompressed + batch * stride_uncompressed);
}


// Kernel entry point for metadata gathering of a single matrix.
//
// Pure forwarding wrapper around metaEll_kernel_, which reads the output
// column index from blockIdx.x.
__global__ void metaEll_kernel(
    int block_size_m, 
    int block_size_n,
    int m, int n, int nnz,
    const int16_t* __restrict__ input_meta,
    const int* __restrict__ indices,
    int16_t* __restrict__ output_meta)
{
    metaEll_kernel_(
        block_size_m, block_size_n,
        m, n, nnz,
        input_meta,
        indices,
        output_meta);
}


// Kernel entry point for metadata gathering of a batch of matrices.
//
// blockIdx.z selects the batch entry. Every `b_*` pointer is the base of a
// batched buffer and the `stride_*` that follows it is the number of elements
// between two consecutive entries. After offsetting the pointers the work is
// delegated to metaEll_kernel_.
__global__ void batched_metaEll_kernel(
    int block_size_m,
    int block_size_n,
    int m, int n, int nnz,
    const int16_t* __restrict__ b_input_meta, int stride_meta,
    const int* __restrict__ b_indices, int stride_indices,
    int16_t* __restrict__ b_output_meta, int stride_output)
{
    const int batch = blockIdx.z;

    metaEll_kernel_(
        block_size_m, block_size_n,
        m, n, nnz,
        b_input_meta + batch * stride_meta,
        b_indices + batch * stride_indices,
        b_output_meta + batch * stride_output);
}


// Host launcher: gathers the column blocks listed in `indices` from
// `input_data` and returns {output_data, uncompressed}.
//
// Inputs
//   input_data   : CUDA tensor, float32 or any 16-bit dtype (the 16-bit path
//                  copies values bit-for-bit through nv_bfloat16). The last
//                  two dimensions are (m, n); leading dimensions are flattened
//                  into one batch dimension.
//   indices      : int32 CUDA tensor whose last two dimensions are
//                  (m_block, n_block): for every block of m / m_block rows,
//                  the ids of the n_block column blocks to keep.
//   block_size_n : width of one column block, must be positive.
//
// Outputs
//   output_data  : [batch,] m, n_block * block_size_n -- kept columns packed
//                  (batch dimension only present when batch size > 1)
//   uncompressed : same shape as input_data, zero everywhere except at the
//                  kept positions
//
// Preconditions (checked, raise c10::Error): non-empty matrix, m divisible by
// m_block, enough index entries for every batch entry, sizes within 32-bit
// indexing and the CUDA gridDim.z limit (65535). The values stored in
// `indices` are not validated here.
std::vector<torch::Tensor> block_ell_cuda(
    torch::Tensor input_data,
    torch::Tensor indices, 
    int block_size_n)
{
    TORCH_CHECK(input_data.is_cuda() && indices.is_cuda(),
                "block_ell_cuda: input_data and indices must be CUDA tensors");
    TORCH_CHECK(input_data.dim() >= 2 && indices.dim() >= 2,
                "block_ell_cuda: input_data and indices must have at least 2 dimensions");
    TORCH_CHECK(indices.dtype() == torch::kInt32,
                "block_ell_cuda: indices must be an int32 tensor");
    TORCH_CHECK(block_size_n > 0,
                "block_ell_cuda: block_size_n must be positive, got ", block_size_n);

    // Anything that is not float32 is copied through nv_bfloat16, which is
    // only valid for 2-byte element types.
    const bool is_fp32 = (input_data.dtype() == torch::kFloat32);
    TORCH_CHECK(is_fp32 || static_cast<int64_t>(input_data.element_size()) == static_cast<int64_t>(sizeof(nv_bfloat16)),
                "block_ell_cuda: input_data must be float32 or a 16-bit dtype");

    // The kernels assume densely packed row-major buffers.
    input_data = input_data.contiguous();
    indices = indices.contiguous();

    const int64_t rows = input_data.size(-2);
    const int64_t cols = input_data.size(-1);
    const int64_t row_blocks = indices.size(-2);
    const int64_t col_blocks = indices.size(-1);
    const int64_t max_int = 2147483647LL;
    const int64_t max_grid_z = 65535;

    TORCH_CHECK(rows > 0 && cols > 0,
                "block_ell_cuda: the last two dimensions of input_data must be non-empty");
    // Without exact divisibility the last rows would index past the final
    // block-row of `indices`.
    TORCH_CHECK(row_blocks > 0 && rows % row_blocks == 0,
                "block_ell_cuda: the row count (", rows,
                ") must be a positive multiple of indices.size(-2) (", row_blocks, ")");

    const int64_t batch = input_data.numel() / (rows * cols);
    const int64_t kept_cols = col_blocks * static_cast<int64_t>(block_size_n);
    TORCH_CHECK(batch <= max_grid_z,
                "block_ell_cuda: batch size ", batch, " exceeds the CUDA grid z-dimension limit");
    TORCH_CHECK(input_data.numel() <= max_int && batch * rows * kept_cols <= max_int,
                "block_ell_cuda: tensors are too large for 32-bit indexing");
    TORCH_CHECK(indices.numel() / batch >= row_blocks * col_blocks,
                "block_ell_cuda: indices does not provide a block table for every batch entry");

    // Get problem size
    const int m = static_cast<int>(rows);
    const int n = static_cast<int>(cols);
    const int m_block = static_cast<int>(row_blocks);
    const int batch_size = static_cast<int>(batch);

    // Get block size
    const int block_size_m = m / m_block;
    const int nnz = static_cast<int>(kept_cols);

    auto options_val = torch::TensorOptions().dtype(input_data.dtype()).device(input_data.device());

    const bool batched = batch_size > 1;
    torch::Tensor output_data = batched
        ? torch::empty({batch_size, m, nnz}, options_val)
        : torch::empty({m, nnz}, options_val);
    // Positions that are not selected must read as zero.
    torch::Tensor uncompressed = torch::zeros_like(input_data);

    // Element distance between two consecutive batch entries of each buffer.
    const int stride_input_data = m * n;
    const int stride_indices = static_cast<int>(indices.numel() / batch_size);
    const int stride_output = m * nnz;
    const int stride_uncompressed = m * n;

    // One thread block per row, one grid z-slice per batch entry.
    dim3 grid;
    grid.x = m;
    grid.z = batch_size;

    dim3 block;
    block.x = 128;

    const int* indices_ptr = indices.data_ptr<int>();

    if (is_fp32){
        const float* in_ptr = input_data.data_ptr<float>();
        float* out_ptr = output_data.data_ptr<float>();
        float* dense_ptr = uncompressed.data_ptr<float>();

        if (batched){
            batched_blockEll_kernel<float><<<grid, block>>>(
                block_size_m, block_size_n, m, n, nnz,
                in_ptr, stride_input_data,
                indices_ptr, stride_indices,
                out_ptr, stride_output,
                dense_ptr, stride_uncompressed
            );
        }
        else{
            blockEll_kernel<float><<<grid, block>>>(
                block_size_m, block_size_n, m, n, nnz,
                in_ptr, indices_ptr, out_ptr, dense_ptr
            );
        }
    }
    else{
        const nv_bfloat16* in_ptr = static_cast<const nv_bfloat16*>(input_data.data_ptr());
        nv_bfloat16* out_ptr = static_cast<nv_bfloat16*>(output_data.data_ptr());
        nv_bfloat16* dense_ptr = static_cast<nv_bfloat16*>(uncompressed.data_ptr());

        if (batched){
            batched_blockEll_kernel<nv_bfloat16><<<grid, block>>>(
                block_size_m, block_size_n, m, n, nnz,
                in_ptr, stride_input_data,
                indices_ptr, stride_indices,
                out_ptr, stride_output,
                dense_ptr, stride_uncompressed
            );
        }
        else{
            blockEll_kernel<nv_bfloat16><<<grid, block>>>(
                block_size_m, block_size_n, m, n, nnz,
                in_ptr, indices_ptr, out_ptr, dense_ptr
            );
        }
    }

    // Kernel launches do not report configuration errors by themselves.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "block_ell_cuda: kernel launch failed: ",
                cudaGetErrorString(launch_status));

    return {output_data, uncompressed};
}



// Host launcher: gathers the column blocks listed in `indices` from a
// metadata tensor and returns {output_data, input_data}.
//
// Inputs
//   input_data   : int16 CUDA tensor whose last two dimensions are
//                  (rows, cols) with cols even. The kernel addresses it as a
//                  matrix with m = 2 * rows entries per column and
//                  n = cols / 2 columns. Leading dimensions are flattened
//                  into one batch dimension.
//   indices      : int32 CUDA tensor whose last two dimensions are
//                  (m_block, n_block).
//   block_size_n : column block width; must be even and at least 2 because
//                  the kernel is launched with block_size_n / 2.
//
// Outputs
//   output_data  : [batch,] rows, n_block * block_size_n  (int16; batch
//                  dimension only present when batch size > 1)
//   input_data   : the tensor that was passed in, returned unchanged
//
// Preconditions (checked, raise c10::Error): non-empty matrix, m divisible by
// m_block, enough index entries for every batch entry, sizes within 32-bit
// indexing and the CUDA gridDim.z limit (65535). The values stored in
// `indices` are not validated here.
std::vector<torch::Tensor> meta_ell_cuda(
    torch::Tensor input_data,
    torch::Tensor indices, 
    int block_size_n)
{
    TORCH_CHECK(input_data.is_cuda() && indices.is_cuda(),
                "meta_ell_cuda: input_data and indices must be CUDA tensors");
    TORCH_CHECK(input_data.dim() >= 2 && indices.dim() >= 2,
                "meta_ell_cuda: input_data and indices must have at least 2 dimensions");
    TORCH_CHECK(input_data.dtype() == torch::kInt16,
                "meta_ell_cuda: input_data must be an int16 tensor");
    TORCH_CHECK(indices.dtype() == torch::kInt32,
                "meta_ell_cuda: indices must be an int32 tensor");
    // The kernel divides by block_size_n / 2, so 0, 1 and odd widths are invalid.
    TORCH_CHECK(block_size_n >= 2 && block_size_n % 2 == 0,
                "meta_ell_cuda: block_size_n must be an even number >= 2, got ", block_size_n);

    const int64_t rows = input_data.size(-2);
    const int64_t cols = input_data.size(-1);
    const int64_t row_blocks = indices.size(-2);
    const int64_t col_blocks = indices.size(-1);
    const int64_t max_int = 2147483647LL;
    const int64_t max_grid_z = 65535;

    TORCH_CHECK(rows > 0 && cols > 0 && cols % 2 == 0,
                "meta_ell_cuda: the last two dimensions of input_data must be non-empty "
                "and the last one must be even");

    // Kernel-side view of the metadata: m entries per column, n columns.
    const int64_t m64 = rows * 2;
    const int64_t n64 = cols / 2;

    // Without exact divisibility the last rows would index past the final
    // block-row of `indices`.
    TORCH_CHECK(row_blocks > 0 && m64 % row_blocks == 0,
                "meta_ell_cuda: 2 * rows (", m64,
                ") must be a positive multiple of indices.size(-2) (", row_blocks, ")");

    const int64_t batch = input_data.numel() / (m64 * n64);
    const int64_t nnz64 = col_blocks * static_cast<int64_t>(block_size_n) / 2;
    TORCH_CHECK(batch <= max_grid_z,
                "meta_ell_cuda: batch size ", batch, " exceeds the CUDA grid z-dimension limit");
    TORCH_CHECK(input_data.numel() <= max_int && batch * m64 * nnz64 <= max_int,
                "meta_ell_cuda: tensors are too large for 32-bit indexing");
    TORCH_CHECK(indices.numel() / batch >= row_blocks * col_blocks,
                "meta_ell_cuda: indices does not provide a block table for every batch entry");

    // The kernels assume densely packed buffers; the caller's tensor handle
    // is kept so that it can be returned unchanged.
    torch::Tensor meta_in = input_data.contiguous();
    indices = indices.contiguous();

    // Get problem size
    const int m = static_cast<int>(m64);
    const int n = static_cast<int>(n64);
    const int m_block = static_cast<int>(row_blocks);
    const int batch_size = static_cast<int>(batch);

    // Get block size
    const int block_size_m = m / m_block;
    const int nnz = static_cast<int>(nnz64);

    auto options_val = torch::TensorOptions().dtype(input_data.dtype()).device(input_data.device());

    const bool batched = batch_size > 1;
    torch::Tensor output_data = batched
        ? torch::empty({batch_size, m/2, nnz*2}, options_val)
        : torch::empty({m/2, nnz*2}, options_val);

    // Nothing to gather; a grid with zero blocks is not a valid launch.
    if (nnz == 0){
        return {output_data, input_data};
    }

    // Element distance between two consecutive batch entries of each buffer.
    const int stride_meta = m * n;
    const int stride_indices = static_cast<int>(indices.numel() / batch_size);
    const int stride_output = m * nnz;

    // One thread block per output column, one grid z-slice per batch entry.
    dim3 grid;
    grid.x = nnz;
    grid.z = batch_size;

    dim3 block;
    block.x = 128;

    const int16_t* meta_ptr = meta_in.data_ptr<int16_t>();
    const int* indices_ptr = indices.data_ptr<int>();
    int16_t* out_ptr = output_data.data_ptr<int16_t>();

    if (batched){
        batched_metaEll_kernel<<<grid, block>>>(
            block_size_m, block_size_n / 2, m, n, nnz,
            meta_ptr, stride_meta,
            indices_ptr, stride_indices,
            out_ptr, stride_output
        );
    }
    else{
        metaEll_kernel<<<grid, block>>>(
            block_size_m, block_size_n / 2, m, n, nnz,
            meta_ptr,
            indices_ptr,
            out_ptr
        );
    }

    // Kernel launches do not report configuration errors by themselves.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "meta_ell_cuda: kernel launch failed: ",
                cudaGetErrorString(launch_status));

    return {output_data, input_data};
}

