// Adapted from https://github.com/laekov/fastmoe/blob/v0.3.0/cuda/local_exchange.cuh

#include "stream_manager.h"
#include "helper_cuda.h"

// Integer ceiling division for non-negative operands: CEIL(x, y) == ceil(x / y).
// Written as (x + y - 1) / y instead of (x - 1) / y + 1 so that x == 0 yields 0
// even for unsigned types (the old form wrapped (size_t)0 - 1 around to SIZE_MAX).
// Caveat: x + y - 1 must not overflow the operand type.
#define CEIL(_x_,_y_) (((_x_) + (_y_) - 1) / (_y_))

// Adapted from https://github.com/laekov/fastmoe/blob/v0.3.0/cuda/local_exchange.cuh#L5
//
// Scatters entry positions into per-chunk buckets.
//   Launch: grid = (CEIL(n_ctx, blockDim.x), n_batch); one thread per entry of a
//   batch row (row b = blockIdx.y).
//   For entry idx of row b with chunk id c = chunk_idx[b * n_ctx + idx]:
//     - c outside [0, n_chunk) (e.g. -1) is skipped;
//     - otherwise chunk_cum_count[b * n_chunk + c] is atomically decremented and
//       idx is written to chunk_pos[b * n_ctx + old_value - 1].
//   NOTE(review): the `old_value - 1` indexing suggests chunk_cum_count holds the
//   per-row inclusive prefix sum of chunk counts, so each chunk fills its own
//   contiguous slice of chunk_pos — confirm against the caller.
//   chunk_cum_count is modified in place. The order of positions inside one
//   chunk's slice depends on atomic scheduling and is not deterministic.
__global__
void chunk_pos_kernel(int* chunk_cum_count, const long* chunk_idx, long* chunk_pos,
        size_t n_ctx, size_t n_chunk) {
    // Widen before multiplying so large grids cannot overflow 32-bit arithmetic.
    const size_t idx = threadIdx.x + (size_t)blockIdx.x * blockDim.x;
    if (idx >= n_ctx) {
        return;
    }
    const size_t batch = blockIdx.y;
    const long index = chunk_idx[batch * n_ctx + idx];
    // Only ids in [0, n_chunk) are valid. chunk_count_kernel never counts the
    // others, so decrementing for them would touch another row's counter (or run
    // past the array) and could drive p to 0, writing chunk_pos out of bounds.
    if (index < 0 || (size_t)index >= n_chunk) {
        return;
    }
    const int p = atomicSub(chunk_cum_count + batch * n_chunk + index, 1);
    chunk_pos[batch * n_ctx + p - 1] = (long)idx;
}

// Adapted from https://github.com/laekov/fastmoe/blob/v0.3.0/cuda/local_exchange.cuh#L18
//
// Host launcher for chunk_pos_kernel on smgr->stream(0), followed by smgr->sync(1)
// (presumably waits for that stream's work — TODO confirm CudaStreamManager::sync).
// Grid: x covers the n_ctx entries of a row in blocks of `threads`, y = batch row.
// NOTE(review): gridDim.y is limited to 65535, so n_batch must not exceed that.
// NOTE(review): the launch result is not checked (no cudaGetLastError()); add a
// check that follows the project's error-reporting convention.
void chunk_pos_impl(
        int* chunk_cum_count, const long* chunk_idx, long* chunk_pos,
        const size_t n_batch, const size_t n_ctx, const size_t n_chunk, CudaStreamManager* smgr) {
    const unsigned int threads = 256;
    // A zero-sized grid dimension is an invalid launch configuration. With no
    // rows or no entries there is nothing to scatter, so skip the launch but still
    // perform the same sync as before.
    if (n_batch > 0 && n_ctx > 0) {
        dim3 grid(CEIL(n_ctx, threads), n_batch);
        chunk_pos_kernel
            <<<grid, threads, 0, smgr->stream(0)>>>
            (chunk_cum_count, chunk_idx, chunk_pos, n_ctx, n_chunk);
    }
    smgr->sync(1);
}

// Width of the chunk-id range handled by one block of chunk_count_kernel: block
// x covers ids [x * PER_THREAD_CHUNKS, (x + 1) * PER_THREAD_CHUNKS), clamped to
// n_chunk, and chunk_count_impl launches CEIL(n_chunk, PER_THREAD_CHUNKS) blocks
// along x. The name is kept from the fastmoe original.
#define PER_THREAD_CHUNKS 256
// Number of lanes in a warp.
#define WARP_SIZE 32

// Adapted from https://github.com/laekov/fastmoe/blob/v0.3.0/cuda/local_exchange.cuh#L37
//
// Per-row histogram of chunk ids.
//   Launch: grid = (CEIL(n_chunk, PER_THREAD_CHUNKS), n_batch); any blockDim.x >= 1.
//   Static shared memory: PER_THREAD_CHUNKS ints.
//   Block (bx, b) counts the entries of row b of chunk_idx (n_ctx longs starting at
//   b * n_ctx) whose id lies in [bx * PER_THREAD_CHUNKS, min(.., n_chunk)). It adds
//   those counts to chunk_count[b * n_chunk + id]. Ids outside [0, n_chunk),
//   including -1, are ignored.
//   chunk_count is accumulated into with atomicAdd, not overwritten —
//   presumably the caller zeroes it beforehand (confirm).
//
// The counts are first gathered in a block-shared histogram. This replaces the
// original per-thread 1 KB local array plus full-mask warp shuffles, which
// implicitly required blockDim.x % 32 == 0. Each chunk then needs a single global
// atomic, issued only when its count is nonzero. Integer results are identical.
__global__
void chunk_count_kernel(const long* chunk_idx, int* chunk_count,
        const size_t n_ctx, const size_t n_chunk) {
    __shared__ int block_count[PER_THREAD_CHUNKS];

    const long chunk_min = (long)blockIdx.x * PER_THREAD_CHUNKS;
    long chunk_max = chunk_min + PER_THREAD_CHUNKS;
    if (chunk_max > (long)n_chunk) {
        chunk_max = (long)n_chunk;
    }
    const long* row_idx = chunk_idx + (size_t)blockIdx.y * n_ctx;
    int* row_count = chunk_count + (size_t)blockIdx.y * n_chunk;

    for (int i = threadIdx.x; i < PER_THREAD_CHUNKS; i += blockDim.x) {
        block_count[i] = 0;
    }
    __syncthreads();

    // size_t counter: n_ctx is size_t, and an int counter could overflow.
    for (size_t i = threadIdx.x; i < n_ctx; i += blockDim.x) {
        const long idx = row_idx[i];
        // Negative ids (e.g. -1) and ids owned by other blocks fall outside the range.
        if (idx >= chunk_min && idx < chunk_max) {
            atomicAdd(&block_count[idx - chunk_min], 1);
        }
    }
    __syncthreads();

    for (long c = chunk_min + threadIdx.x; c < chunk_max; c += blockDim.x) {
        const int x = block_count[c - chunk_min];
        if (x != 0) {
            atomicAdd(row_count + c, x);
        }
    }
}

// Adapted from https://github.com/laekov/fastmoe/blob/v0.3.0/cuda/local_exchange.cuh#L72
//
// Host launcher for chunk_count_kernel on smgr->stream(0), followed by
// smgr->sync(1) (presumably waits for that stream's work — TODO confirm).
// Grid: x covers n_chunk ids in spans of PER_THREAD_CHUNKS, y = batch row.
// chunk_count (n_batch * n_chunk ints) is accumulated into with atomicAdd, so
// it is expected to be zeroed by the caller — confirm.
// NOTE(review): gridDim.y is limited to 65535, so n_batch must not exceed that.
// NOTE(review): the launch result is not checked (no cudaGetLastError()).
void chunk_count_impl(
        const long* chunk_idx, int* chunk_count,
        const size_t n_batch, const size_t n_ctx, const size_t n_chunk,
        CudaStreamManager* smgr) {
    // Threads per block; kept a multiple of the warp size.
    const unsigned int threads = 256;
    // A zero-sized grid dimension is an invalid launch configuration. With no
    // entries the kernel would only add zeros, so every empty case can skip the
    // launch while keeping the sync.
    if (n_batch > 0 && n_chunk > 0 && n_ctx > 0) {
        dim3 grid(CEIL(n_chunk, PER_THREAD_CHUNKS), n_batch);
        chunk_count_kernel
            <<<grid, threads, 0, smgr->stream(0)>>>
            (chunk_idx, chunk_count, n_ctx, n_chunk);
    }
    smgr->sync(1);
}
