// Adapted from https://github.com/laekov/fastmoe/blob/v0.3.0/cuda/local_exchange.cu

#include "chunk_function.cuh"
#include <torch/extension.h>

// Adapted from https://github.com/laekov/fastmoe/blob/v0.3.0/cuda/local_exchange.cu#L5
// Host-side entry point: validates the tensors, unpacks their shapes and raw
// device pointers, and forwards everything to chunk_pos_impl.
//
// chunk_cum_count: accessed as `int`; dimension 1 is taken as n_chunk.
// chunk_idx:       accessed as `long`; dimensions 0 and 1 are taken as
//                  n_batch and n_ctx.
// chunk_pos:       accessed as `long`; presumably the output buffer filled by
//                  chunk_pos_impl -- its expected shape is not visible here,
//                  so only device and contiguity are validated.
//
// Raises a c10::Error (via TORCH_CHECK) when a tensor is not a contiguous CUDA
// tensor, when the tensors live on different devices, or when a tensor whose
// shape is read has fewer than two dimensions.
void _chunk_pos(
    torch::Tensor chunk_cum_count,
    torch::Tensor chunk_idx,
    torch::Tensor chunk_pos) {
    // The implementation receives bare pointers, so strided views or host
    // tensors would be read/written with the wrong layout or address space.
    const auto require_dense_cuda = [](const torch::Tensor& t, const char* name) {
        TORCH_CHECK(t.is_cuda(), "_chunk_pos: ", name, " must be a CUDA tensor");
        TORCH_CHECK(t.is_contiguous(), "_chunk_pos: ", name, " must be contiguous");
    };
    require_dense_cuda(chunk_cum_count, "chunk_cum_count");
    require_dense_cuda(chunk_idx, "chunk_idx");
    require_dense_cuda(chunk_pos, "chunk_pos");

    // The stream manager is selected from chunk_cum_count's device, so every
    // pointer handed over must belong to that same device.
    TORCH_CHECK(
        chunk_idx.device() == chunk_cum_count.device() &&
            chunk_pos.device() == chunk_cum_count.device(),
        "_chunk_pos: all tensors must be on the same device");

    // Guard the shape lookups below against out-of-range indexing.
    TORCH_CHECK(chunk_idx.dim() >= 2,
                "_chunk_pos: chunk_idx must have at least 2 dimensions, got ",
                chunk_idx.dim());
    TORCH_CHECK(chunk_cum_count.dim() >= 2,
                "_chunk_pos: chunk_cum_count must have at least 2 dimensions, got ",
                chunk_cum_count.dim());

    auto smgr = getCudaStreamManager(chunk_cum_count.device().index());

    const size_t n_batch = static_cast<size_t>(chunk_idx.size(0));
    const size_t n_ctx = static_cast<size_t>(chunk_idx.size(1));
    const size_t n_chunk = static_cast<size_t>(chunk_cum_count.size(1));

    // data_ptr<T>() itself rejects a tensor whose dtype does not match T.
    // NOTE(review): `long` equals int64 only on LP64 platforms -- confirm
    // before building for a target where `long` is 32-bit.
    chunk_pos_impl(
        chunk_cum_count.data_ptr<int>(),
        chunk_idx.data_ptr<long>(),
        chunk_pos.data_ptr<long>(),
        n_batch, n_ctx, n_chunk, smgr);
}

// Adapted from https://github.com/laekov/fastmoe/blob/v0.3.0/cuda/local_exchange.cu#L22
// Host-side entry point: validates the tensors, unpacks their shapes and raw
// device pointers, and forwards everything to chunk_count_impl.
//
// chunk_idx:   accessed as `long`; dimensions 0 and 1 are taken as n_batch
//              and n_ctx.
// chunk_count: accessed as `int`; dimension 1 is taken as n_chunk. Presumably
//              the output buffer updated by chunk_count_impl -- whether it
//              must be zero-initialised by the caller is not visible here.
//
// Raises a c10::Error (via TORCH_CHECK) when a tensor is not a contiguous CUDA
// tensor, when the tensors live on different devices, or when either tensor
// has fewer than two dimensions.
void _chunk_count(
        torch::Tensor chunk_idx,
        torch::Tensor chunk_count) {
    // The implementation receives bare pointers, so strided views or host
    // tensors would be read/written with the wrong layout or address space.
    const auto require_dense_cuda = [](const torch::Tensor& t, const char* name) {
        TORCH_CHECK(t.is_cuda(), "_chunk_count: ", name, " must be a CUDA tensor");
        TORCH_CHECK(t.is_contiguous(), "_chunk_count: ", name, " must be contiguous");
    };
    require_dense_cuda(chunk_idx, "chunk_idx");
    require_dense_cuda(chunk_count, "chunk_count");

    // The stream manager is selected from chunk_idx's device, so the output
    // pointer must belong to that same device.
    TORCH_CHECK(chunk_count.device() == chunk_idx.device(),
                "_chunk_count: all tensors must be on the same device");

    // Guard the shape lookups below against out-of-range indexing.
    TORCH_CHECK(chunk_idx.dim() >= 2,
                "_chunk_count: chunk_idx must have at least 2 dimensions, got ",
                chunk_idx.dim());
    TORCH_CHECK(chunk_count.dim() >= 2,
                "_chunk_count: chunk_count must have at least 2 dimensions, got ",
                chunk_count.dim());

    auto smgr = getCudaStreamManager(chunk_idx.device().index());

    const size_t n_batch = static_cast<size_t>(chunk_idx.size(0));
    const size_t n_ctx = static_cast<size_t>(chunk_idx.size(1));
    const size_t n_chunk = static_cast<size_t>(chunk_count.size(1));

    // data_ptr<T>() itself rejects a tensor whose dtype does not match T.
    // NOTE(review): `long` equals int64 only on LP64 platforms -- confirm
    // before building for a target where `long` is 32-bit.
    chunk_count_impl(
        chunk_idx.data_ptr<long>(),
        chunk_count.data_ptr<int>(),
        n_batch, n_ctx, n_chunk, smgr);
}
