#include <torch/extension.h>
#include <vector>

// Declared here, defined elsewhere in the extension (not in this file).
std::vector<torch::Tensor> batched_dense2sparse_cuda(
    torch::Tensor dense_tensor);

// Python-facing wrapper around batched_dense2sparse_cuda.
//
// Validates the argument, then forwards it unchanged and returns whatever
// batched_dense2sparse_cuda returns.
//
// Params:
//   dense_tensor - tensor handed to the CUDA implementation; must reside on
//                  a CUDA device.
// Returns:
//   The vector of tensors produced by batched_dense2sparse_cuda. Their count
//   and meaning are defined by that implementation, not by this wrapper.
// Throws:
//   c10::Error (surfaced to Python as RuntimeError) if dense_tensor is not a
//   CUDA tensor.
std::vector<torch::Tensor> batched_dense2sparse(
    torch::Tensor dense_tensor)
{
    // Fail with a readable exception at the Python boundary instead of
    // letting a host pointer reach device code.
    // NOTE(review): assumes the *_cuda implementation does not itself accept
    // CPU tensors - confirm against the CUDA source.
    TORCH_CHECK(dense_tensor.is_cuda(),
                "batched_dense2sparse: dense_tensor must be a CUDA tensor");

    return batched_dense2sparse_cuda(dense_tensor);
}

// Declared here, defined elsewhere in the extension (not in this file).
std::vector<torch::Tensor> block_ell_cuda(
    torch::Tensor input_data,
    torch::Tensor indices,
    int block_size_n);

// Python-facing wrapper around block_ell_cuda.
//
// Validates the arguments, then forwards them unchanged and returns whatever
// block_ell_cuda returns.
//
// Params:
//   input_data   - tensor handed to the CUDA implementation; must reside on
//                  a CUDA device.
//   indices      - tensor handed to the CUDA implementation; must reside on
//                  a CUDA device.
//   block_size_n - block size forwarded to the CUDA implementation; must be
//                  positive.
// Returns:
//   The vector of tensors produced by block_ell_cuda. Their count and meaning
//   are defined by that implementation, not by this wrapper.
// Throws:
//   c10::Error (surfaced to Python as RuntimeError) if either tensor is not a
//   CUDA tensor or block_size_n is not positive.
std::vector<torch::Tensor> block_ell(
    torch::Tensor input_data,
    torch::Tensor indices,
    int block_size_n)
{
    // Fail with a readable exception at the Python boundary instead of
    // letting a host pointer reach device code.
    // NOTE(review): assumes block_ell_cuda expects both tensors on the GPU -
    // confirm against the CUDA source.
    TORCH_CHECK(input_data.is_cuda(),
                "block_ell: input_data must be a CUDA tensor");
    TORCH_CHECK(indices.is_cuda(),
                "block_ell: indices must be a CUDA tensor");
    // NOTE(review): assumes no non-positive sentinel value is in use for
    // block_size_n - confirm against callers.
    TORCH_CHECK(block_size_n > 0,
                "block_ell: block_size_n must be positive, got ", block_size_n);

    return block_ell_cuda(input_data, indices, block_size_n);
}


// Declared here, defined elsewhere in the extension (not in this file).
std::vector<torch::Tensor> meta_ell_cuda(
    torch::Tensor input_data,
    torch::Tensor indices,
    int block_size_n);

// Python-facing wrapper around meta_ell_cuda.
//
// Validates the arguments, then forwards them unchanged and returns whatever
// meta_ell_cuda returns.
//
// Params:
//   input_data   - tensor handed to the CUDA implementation; must reside on
//                  a CUDA device.
//   indices      - tensor handed to the CUDA implementation; must reside on
//                  a CUDA device.
//   block_size_n - block size forwarded to the CUDA implementation; must be
//                  positive.
// Returns:
//   The vector of tensors produced by meta_ell_cuda. Their count and meaning
//   are defined by that implementation, not by this wrapper.
// Throws:
//   c10::Error (surfaced to Python as RuntimeError) if either tensor is not a
//   CUDA tensor or block_size_n is not positive.
std::vector<torch::Tensor> meta_ell(
    torch::Tensor input_data,
    torch::Tensor indices,
    int block_size_n)
{
    // Fail with a readable exception at the Python boundary instead of
    // letting a host pointer reach device code.
    // NOTE(review): assumes meta_ell_cuda expects both tensors on the GPU -
    // confirm against the CUDA source.
    TORCH_CHECK(input_data.is_cuda(),
                "meta_ell: input_data must be a CUDA tensor");
    TORCH_CHECK(indices.is_cuda(),
                "meta_ell: indices must be a CUDA tensor");
    // NOTE(review): assumes no non-positive sentinel value is in use for
    // block_size_n - confirm against callers.
    TORCH_CHECK(block_size_n > 0,
                "meta_ell: block_size_n must be positive, got ", block_size_n);

    return meta_ell_cuda(input_data, indices, block_size_n);
}


// Python module definition: registers the three host-side wrappers from this
// file under their Python-visible names, each with a one-line docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // Note the Python name differs from the C++ one: `bdense2sparse` binds
    // `batched_dense2sparse`.
    m.def(
        "bdense2sparse",
        &batched_dense2sparse,
        "Convert dense matrix to sparse");

    // `block_ell` keeps the same name on both sides.
    m.def(
        "block_ell",
        &block_ell,
        "prune a dense matrix to block-ell");

    // `meta_ell` keeps the same name on both sides.
    m.def(
        "meta_ell",
        &meta_ell,
        "prune a meta data to block-ell");
}

