// Adapted from https://github.com/laekov/fastmoe/blob/v0.3.0/cuda/fmoe_cuda.cpp

#include <iostream>
#include <vector>
#include <torch/extension.h>

// Adapted from https://github.com/laekov/fastmoe/blob/v0.3.0/cuda/fmoe_cuda.cpp#L25
//
// Forward declaration only: the definition is not in this file and must be
// supplied by another translation unit at link time (presumably a CUDA
// source, given the "(CUDA)" docstring used in the binding below).
//
// Takes three tensors by value and returns nothing. Exposed to Python as
// `chunk_pos` by the PYBIND11_MODULE block in this file.
//
// NOTE(review): since the function returns void, the result is presumably
// written in place into one of the arguments (the name suggests `chunk_pos`)
// -- confirm against the definition, along with the expected dtypes, shapes
// and devices of all three tensors.
void _chunk_pos(
        torch::Tensor chunk_cum_count,
        torch::Tensor chunk_idx,
        torch::Tensor chunk_pos);

// Adapted from https://github.com/laekov/fastmoe/blob/v0.3.0/cuda/fmoe_cuda.cpp#L29
//
// Forward declaration only: the definition is not in this file and must be
// supplied by another translation unit at link time (presumably a CUDA
// source, given the "(CUDA)" docstring used in the binding below).
//
// Takes two tensors by value and returns nothing. Exposed to Python as
// `chunk_count` by the PYBIND11_MODULE block in this file.
//
// NOTE(review): since the function returns void, the result is presumably
// written in place into one of the arguments (the name suggests
// `chunk_count`) -- confirm against the definition, along with the expected
// dtypes, shapes and devices of both tensors.
void _chunk_count(
        torch::Tensor chunk_idx,
        torch::Tensor chunk_count);

// Adapted from https://github.com/laekov/fastmoe/blob/v0.3.0/cuda/fmoe_cuda.cpp#L59
//
// Python bindings for this extension. Registers the two functions declared
// earlier in this file on module `m` under the Python names `chunk_count`
// and `chunk_pos`.
//
// Each binding names its arguments after the corresponding C++ parameter so
// that Python callers may pass them by keyword and the generated signature
// shows real names rather than pybind11's positional placeholders. Existing
// positional calls are unaffected. The number and order of the
// pybind11::arg annotations must stay in sync with the declarations.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("chunk_count", &_chunk_count, "Chunk count (CUDA)",
          pybind11::arg("chunk_idx"),
          pybind11::arg("chunk_count"));
    m.def("chunk_pos", &_chunk_pos, "Chunk pos (CUDA)",
          pybind11::arg("chunk_cum_count"),
          pybind11::arg("chunk_idx"),
          pybind11::arg("chunk_pos"));
}
