#pragma once



#include <torch/all.h>

#include <cuda_runtime.h>



// Compile-time row capacity constant (2048).
// NOTE(review): judging by the name, this bounds the number of rows held in
// the dense "tail" storage of hybrid_sp_t -- its use is not visible in this
// header, so confirm against the implementation file.
#define TAIL_CAPACITY_ROWS 2048



// Container for a hybrid sparse matrix, derived from torch::CustomClassHolder.
//
// All buffers are held as at::Tensor members; the accessor methods at the
// bottom return raw typed pointers into those tensors. The accessors call the
// untyped data_ptr() and static_cast the result, so no dtype check is
// performed: callers must make sure each tensor was allocated with the element
// type its accessor casts to.
//
// NOTE(review): the member names suggest an ELL-style sparse part plus a dense
// "tail" for overflowing rows; the actual layout is defined by the
// constructors and kernels, which are not in this header -- confirm there.
struct hybrid_sp_t : torch::CustomClassHolder {

    // Column indices of the ELL part; read as uint16_t by ell_col_indices().
    at::Tensor _ell_col_indices;

    // Values of the ELL part; read as __nv_bfloat16 by ell_values().
    at::Tensor _ell_values;

    // Per-row counters; read as int32_t by row_counters().
    at::Tensor _row_counters;

    // Overflow counter; read as int32_t by overflow_counter().
    at::Tensor _overflow_counter;

    // Dense tail storage; read as __nv_bfloat16 by tail_dense().
    at::Tensor _tail_dense;
    // Index maps for the dense tail; both are read as int32_t.
    // NOTE(review): presumably forward and inverse row mappings -- confirm.
    at::Tensor _tail_dense_map;
    at::Tensor _tail_dense_map_reverse;

    // CUDA event handle. Initialised to nullptr so that a default-constructed
    // object never carries an indeterminate handle (cudaEvent_t is a pointer
    // type). Creation/destruction of the event is not visible in this header.
    cudaEvent_t _counter_copy_ev = nullptr;

    // NOTE(review): purpose not evident from this header; no accessor exists.
    at::Tensor hN;

    // Initialised to 0 so that reading it from a default-constructed object is
    // well defined; the out-of-line constructors may overwrite it.
    int _dense_active_rows = 0;

    // Default constructor: tensors are default-constructed, and the two scalar
    // members take the initializers given above.
    hybrid_sp_t() {}

    // Defined elsewhere. NOTE(review): presumably allocates the buffers for an
    // M x N matrix on `device` -- confirm against the definition.
    hybrid_sp_t(int M, int N, torch::Device device);

    // User-declared copy constructor; defined elsewhere.
    hybrid_sp_t(const hybrid_sp_t& sp);

    // Defined elsewhere.
    void reset_vals();



    // Raw-pointer accessors. Each returns the tensor's data pointer cast to the
    // element type named in the return type; no validation is done here.

    inline uint16_t* ell_col_indices() const {
        return static_cast<uint16_t*>(_ell_col_indices.data_ptr());
    }

    inline __nv_bfloat16* ell_values() const {
        return static_cast<__nv_bfloat16*>(_ell_values.data_ptr());
    }

    inline int32_t* row_counters() const {
        return static_cast<int32_t*>(_row_counters.data_ptr());
    }

    inline int32_t* overflow_counter() const {
        return static_cast<int32_t*>(_overflow_counter.data_ptr());
    }

    inline __nv_bfloat16* tail_dense() const {
        return static_cast<__nv_bfloat16*>(_tail_dense.data_ptr());
    }

    inline int32_t* tail_dense_map() const {
        return static_cast<int32_t*>(_tail_dense_map.data_ptr());
    }

    inline int32_t* tail_dense_map_reverse() const {
        return static_cast<int32_t*>(_tail_dense_map_reverse.data_ptr());
    }

};





// create_hybrid_sparse -- two overloads, declared here and defined elsewhere.
//
// Both take the tensors `slabs_data_d` and `num_slabs_d` by mutable reference,
// a destination `sp`, the dimensions M and N, and a CUDA stream. The first
// overload additionally takes the tensors `l0` and `l1`.
// NOTE(review): judging by the names, these populate `*sp` from slab buffers
// and enqueue their work on `stream`; the roles of `l0` / `l1` are not visible
// from this header -- confirm against the definitions.
void create_hybrid_sparse(
    at::Tensor& slabs_data_d,
    at::Tensor& num_slabs_d,
    hybrid_sp_t* sp,
    at::Tensor& l0,
    at::Tensor& l1,
    int M,
    int N,
    cudaStream_t stream);

void create_hybrid_sparse(
    at::Tensor& slabs_data_d,
    at::Tensor& num_slabs_d,
    hybrid_sp_t* sp,
    int M,
    int N,
    cudaStream_t stream);

// Declared here, defined elsewhere. Takes read-only tensors `a`, `b` and
// `init_val`, the dimensions M, N, K, and a CUDA stream; writes into `out`.
// NOTE(review): the name suggests it computes a product of `a` and `b` and
// stores it in hybrid sparse form in `*out`, seeded by `init_val` -- confirm
// against the definition, including the expected operand shapes.
void new_product_as_sparse_sma(
    hybrid_sp_t* out,
    at::Tensor const& a,
    at::Tensor const& b,
    at::Tensor const& init_val,
    int M,
    int N,
    int K,
    cudaStream_t stream);

// Declared here, defined elsewhere. Unlike the neighbouring declarations, all
// three tensors are taken by value, and `sp` is an at::Tensor rather than a
// hybrid_sp_t.
// NOTE(review): presumably packs the contents of `sp` into `out` (capacity
// `out_size`) and reports a count through `out_n` -- confirm against the
// definition.
void pack_sparse(
    at::Tensor out,
    at::Tensor out_n,
    at::Tensor sp,
    int M,
    int N,
    int64_t out_size,
    cudaStream_t stream);

// Declared here, defined elsewhere. `A` is taken by const reference and `AT`
// by mutable reference; the two trailing pointer parameters are optional and
// default to nullptr.
// NOTE(review): the name suggests `AT` receives the transpose of `A`, with the
// optional pointers supplying ready-made tail maps for the result; whether
// they must be device pointers is not visible here -- confirm against the
// definition.
void transpose_hybrid_dense(
    const hybrid_sp_t& A,
    hybrid_sp_t& AT,
    int M_rows,
    int N_cols,
    cudaStream_t stream,
    const int* precomputed_tail_dense_map = nullptr,
    const int* precomputed_tail_dense_map_reverse = nullptr);

// Declared here, defined elsewhere. Writes into `out`; `B` is read-only.
// `B_fp32_cache` is optional and defaults to a default-constructed at::Tensor.
// NOTE(review): the name suggests a GEMM of the hybrid sparse matrix `*A` with
// the dense matrix `B`; the exact meaning of `transpose_dense_part` and of the
// M/N/K dimensions is not visible here -- confirm against the definition.
void sparse_dense_gemm_hybrid_dense(
    at::Tensor& out,
    hybrid_sp_t* A,
    const at::Tensor& B,
    int M,
    int N,
    int K,
    bool transpose_dense_part,
    cudaStream_t stream,
    const at::Tensor& B_fp32_cache = at::Tensor());

// sparse_elementwise -- two overloads, declared here and defined elsewhere.
//
// Both take a destination `out`, two hybrid sparse operands `A` and `B`, the
// dimensions M and N, and a CUDA stream. The second overload additionally
// takes the tensor `acc_init` by mutable reference, placed before `stream`.
// NOTE(review): which element-wise operation is applied, and how `acc_init`
// is used, is not visible from this header -- confirm against the definitions.
void sparse_elementwise(
    hybrid_sp_t* out,
    hybrid_sp_t* A,
    hybrid_sp_t* B,
    int M,
    int N,
    cudaStream_t stream);

void sparse_elementwise(
    hybrid_sp_t* out,
    hybrid_sp_t* A,
    hybrid_sp_t* B,
    int M,
    int N,
    at::Tensor& acc_init,
    cudaStream_t stream);

