#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cuda.h>
#include <cassert>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <vector>
#include "helper_math.h"
#include <cub/cub.cuh>
#include <stdio.h>
#include "cuda_bf16.h"

#define BlockDim_ 512

// Array types from TurboTransformer
// Row-wise scaled softmax for rows that fit inside one thread block.
//
// Launch layout (see RUN_KERNEL): grid.x = number of rows, block.x = BlockDim
// (the CUB BlockReduce below is specialized for exactly BlockDim threads).
// Block `blockIdx.x` processes row `blockIdx.x` of a tightly packed
// [rows, to_seq_len] float buffer. Each thread owns K elements of its row,
// fetched as LoadSteps vectors of LoadType (ValuesPerLoad floats each). Vector
// v = step * BlockDim + threadIdx.x covers row elements
// [v * ValuesPerLoad, (v + 1) * ValuesPerLoad), so adjacent threads touch
// adjacent vectors.
//
// Output: attn[c] = exp(s * x[c] - m) / (sum_j exp(s * x[j] - m) + 1e-6f),
// with s = scaler and m = max(-1e10f, max_j s * x[j]).
//
// Preconditions: BlockDim * K >= to_seq_len (asserted). Elements past the end
// of the row are never read or written. Vector accesses are used only for
// fully in-range vectors of a row whose start is aligned to sizeof(LoadType);
// everything else uses bounds-checked scalar accesses.
template <int BlockDim, int K, typename LoadType>
__global__ void long_seq_softmax_kernel(
    const float* __restrict__ attn_weight,
    float* __restrict__ attn, const int to_seq_len,
    const float scaler) 
{
    // Specialize BlockReduce for a 1D block on type float
    typedef cub::BlockReduce<float, BlockDim> BlockReduce;

    __shared__ typename BlockReduce::TempStorage temp_storage;

    // Row statistics broadcast from thread 0 (the only holder of a valid
    // BlockReduce result) to the rest of the block.
    __shared__ float row_max;
    __shared__ float row_sum;

    // The length of the vector type
    constexpr int ValuesPerLoad = sizeof(LoadType) / sizeof(float);
    static_assert(ValuesPerLoad > 0 && sizeof(LoadType) % sizeof(float) == 0,
                  "LoadType must pack a whole number of floats");
    static_assert(K % ValuesPerLoad == 0, "The vector type should match the K of each thread");
    // The number of steps to load the operand
    constexpr int LoadSteps = K / ValuesPerLoad;
    // The thread block should be large enough to hold all entries in the row
    assert(BlockDim * K >= to_seq_len);

    // 64-bit row offset: blockIdx.x * to_seq_len can overflow 32 bits.
    const size_t row_offset = static_cast<size_t>(blockIdx.x) * static_cast<size_t>(to_seq_len);
    const float* row_in = attn_weight + row_offset;
    float* row_out = attn + row_offset;

    // A row start is misaligned for vector access whenever to_seq_len is not a
    // multiple of ValuesPerLoad (or the buffer itself is offset).
    const bool in_aligned = reinterpret_cast<uintptr_t>(row_in) % sizeof(LoadType) == 0;
    const bool out_aligned = reinterpret_cast<uintptr_t>(row_out) % sizeof(LoadType) == 0;

    // Private register file that holds this thread's K scaled row values.
    float vals[K];

    // Step 1: load the scaled inputs and track the thread-local maximum.
    // Every lane of the max starts at -1e10f so all-negative rows keep their true max.
    float thread_max = -1e10f;

    #pragma unroll
    for (int i = 0; i < LoadSteps; i++) {
        const int first = (i * BlockDim + static_cast<int>(threadIdx.x)) * ValuesPerLoad;
        LoadType raw;
        float* raw_s = reinterpret_cast<float*>(&raw);
        if (in_aligned && first + ValuesPerLoad <= to_seq_len) {
            raw = __ldg(reinterpret_cast<const LoadType*>(row_in + first));
        } else {
            // Row tail / misaligned row: scalar loads, never past the row end.
            #pragma unroll
            for (int j = 0; j < ValuesPerLoad; j++) {
                raw_s[j] = (first + j < to_seq_len) ? __ldg(row_in + first + j) : 0.0f;
            }
        }
        #pragma unroll
        for (int j = 0; j < ValuesPerLoad; j++) {
            const float v = raw_s[j] * scaler;
            vals[i * ValuesPerLoad + j] = v;
            // Padding slots must not influence the row max.
            if (first + j < to_seq_len) thread_max = fmaxf(thread_max, v);
        }
    }

    // Step 2: Get the maximum attention entry in the row
    const float block_max = BlockReduce(temp_storage).Reduce(thread_max, cub::Max());

    if (threadIdx.x == 0) row_max = block_max;

    // Publishes row_max and makes temp_storage safe to reuse for the sum below.
    __syncthreads();

    const float row_max_l = row_max;

    // Step 3: exponentiate (shifted by the row max for stability) and sum.
    float thread_sum = 0.0f;

    #pragma unroll
    for (int i = 0; i < LoadSteps; i++) {
        const int first = (i * BlockDim + static_cast<int>(threadIdx.x)) * ValuesPerLoad;
        #pragma unroll
        for (int j = 0; j < ValuesPerLoad; j++) {
            float& v = vals[i * ValuesPerLoad + j];
            // Padding slots contribute nothing to the row sum.
            v = (first + j < to_seq_len) ? __expf(v - row_max_l) : 0.0f;
            thread_sum += v;
        }
    }

    const float block_sum = BlockReduce(temp_storage).Sum(thread_sum) + 1e-6f;

    if (threadIdx.x == 0) row_sum = block_sum;

    __syncthreads();

    const float row_sum_l = row_sum;

    // Step 4: normalize and write back only the in-range elements.
    #pragma unroll
    for (int i = 0; i < LoadSteps; i++) {
        const int first = (i * BlockDim + static_cast<int>(threadIdx.x)) * ValuesPerLoad;
        LoadType packed;
        float* packed_s = reinterpret_cast<float*>(&packed);
        #pragma unroll
        for (int j = 0; j < ValuesPerLoad; j++) {
            packed_s[j] = vals[i * ValuesPerLoad + j] / row_sum_l;
        }
        if (out_aligned && first + ValuesPerLoad <= to_seq_len) {
            *reinterpret_cast<LoadType*>(row_out + first) = packed;
        } else {
            #pragma unroll
            for (int j = 0; j < ValuesPerLoad; j++) {
                if (first + j < to_seq_len) row_out[first + j] = packed_s[j];
            }
        }
    }
}


// BF16 variant of the row-wise scaled softmax: reads and writes nv_bfloat16,
// but scales, reduces, exponentiates and normalizes in float.
//
// Launch layout (see RUN_KERNEL_BF16): grid.x = number of rows, block.x =
// BlockDim (the CUB BlockReduce below is specialized for exactly BlockDim
// threads). Block `blockIdx.x` processes row `blockIdx.x` of a tightly packed
// [rows, to_seq_len] bf16 buffer. Each thread owns K elements, fetched as
// LoadSteps vectors of LoadType, each packing ValuesPerLoad =
// sizeof(LoadType) / sizeof(nv_bfloat16) bf16 values. Vector
// v = step * BlockDim + threadIdx.x covers row elements
// [v * ValuesPerLoad, (v + 1) * ValuesPerLoad).
//
// Output: attn[c] = bf16(exp(s * x[c] - m) / (sum_j exp(s * x[j] - m) + 1e-6f)),
// with s = scaler and m = max(-1e10f, max_j s * x[j]).
//
// Preconditions: BlockDim * K >= to_seq_len (asserted). Elements past the end
// of the row are never read or written. Vector accesses are used only for
// fully in-range vectors of a row whose start is aligned to sizeof(LoadType);
// everything else uses bounds-checked scalar accesses.
template <int BlockDim, int K, typename LoadType>
__global__ void long_seq_softmax_kernel_bf16(
    const nv_bfloat16* __restrict__ attn_weight,
    nv_bfloat16* __restrict__ attn, const int to_seq_len,
    const float scaler) 
{
    // Specialize BlockReduce for a 1D block on type float
    typedef cub::BlockReduce<float, BlockDim> BlockReduce;

    __shared__ typename BlockReduce::TempStorage temp_storage;

    // Row statistics broadcast from thread 0 (the only holder of a valid
    // BlockReduce result) to the rest of the block.
    __shared__ float row_max;
    __shared__ float row_sum;

    // The length of the vector type
    constexpr int ValuesPerLoad = sizeof(LoadType) / sizeof(nv_bfloat16);
    static_assert(ValuesPerLoad > 0 && sizeof(LoadType) % sizeof(nv_bfloat16) == 0,
                  "LoadType must pack a whole number of bf16 values");
    static_assert(K % ValuesPerLoad == 0, "The vector type should match the K of each thread");
    // The number of steps to load the operand
    constexpr int LoadSteps = K / ValuesPerLoad;
    // The thread block should be large enough to hold all entries in the row
    assert(BlockDim * K >= to_seq_len);

    // 64-bit row offset: blockIdx.x * to_seq_len can overflow 32 bits.
    const size_t row_offset = static_cast<size_t>(blockIdx.x) * static_cast<size_t>(to_seq_len);
    const nv_bfloat16* row_in = attn_weight + row_offset;
    nv_bfloat16* row_out = attn + row_offset;

    // A row start is misaligned for vector access whenever to_seq_len is not a
    // multiple of ValuesPerLoad (or the buffer itself is offset).
    const bool in_aligned = reinterpret_cast<uintptr_t>(row_in) % sizeof(LoadType) == 0;
    const bool out_aligned = reinterpret_cast<uintptr_t>(row_out) % sizeof(LoadType) == 0;

    // This thread's K row values, widened to float and scaled.
    float vals[K];

    // Step 1: load, widen, scale, and track the thread-local maximum.
    // Every lane of the max starts at -1e10f so all-negative rows keep their true max.
    float thread_max = -1e10f;

    #pragma unroll
    for (int i = 0; i < LoadSteps; i++) {
        const int first = (i * BlockDim + static_cast<int>(threadIdx.x)) * ValuesPerLoad;
        // Stage through a LoadType (not a bare nv_bfloat16[], which is only
        // 2-byte aligned) so the vector path is a properly aligned access.
        LoadType raw;
        nv_bfloat16* raw_h = reinterpret_cast<nv_bfloat16*>(&raw);
        if (in_aligned && first + ValuesPerLoad <= to_seq_len) {
            raw = __ldg(reinterpret_cast<const LoadType*>(row_in + first));
        } else {
            // Row tail / misaligned row: scalar loads, never past the row end.
            #pragma unroll
            for (int j = 0; j < ValuesPerLoad; j++) {
                raw_h[j] = (first + j < to_seq_len) ? row_in[first + j] : __float2bfloat16(0.0f);
            }
        }
        // Slot index includes the step so each load step fills its own slots.
        #pragma unroll
        for (int j = 0; j < ValuesPerLoad; j++) {
            const float v = __bfloat162float(raw_h[j]) * scaler;
            vals[i * ValuesPerLoad + j] = v;
            // Padding slots must not influence the row max.
            if (first + j < to_seq_len) thread_max = fmaxf(thread_max, v);
        }
    }

    // Step 2: Get the maximum attention entry in the row
    const float block_max = BlockReduce(temp_storage).Reduce(thread_max, cub::Max());

    if (threadIdx.x == 0) row_max = block_max;

    // Publishes row_max and makes temp_storage safe to reuse for the sum below.
    __syncthreads();

    const float row_max_l = row_max;

    // Step 3: exponentiate (shifted by the row max for stability) and sum.
    float thread_sum = 0.0f;

    #pragma unroll
    for (int i = 0; i < LoadSteps; i++) {
        const int first = (i * BlockDim + static_cast<int>(threadIdx.x)) * ValuesPerLoad;
        #pragma unroll
        for (int j = 0; j < ValuesPerLoad; j++) {
            float& v = vals[i * ValuesPerLoad + j];
            // Padding slots contribute nothing to the row sum.
            v = (first + j < to_seq_len) ? __expf(v - row_max_l) : 0.0f;
            thread_sum += v;
        }
    }

    const float block_sum = BlockReduce(temp_storage).Sum(thread_sum) + 1e-6f;

    if (threadIdx.x == 0) row_sum = block_sum;

    __syncthreads();

    const float row_sum_l = row_sum;

    // Step 4: normalize, narrow to bf16, and write back only in-range elements.
    #pragma unroll
    for (int i = 0; i < LoadSteps; i++) {
        const int first = (i * BlockDim + static_cast<int>(threadIdx.x)) * ValuesPerLoad;
        LoadType packed;
        nv_bfloat16* packed_h = reinterpret_cast<nv_bfloat16*>(&packed);
        #pragma unroll
        for (int j = 0; j < ValuesPerLoad; j++) {
            packed_h[j] = __float2bfloat16(vals[i * ValuesPerLoad + j] / row_sum_l);
        }
        if (out_aligned && first + ValuesPerLoad <= to_seq_len) {
            *reinterpret_cast<LoadType*>(row_out + first) = packed;
        } else {
            #pragma unroll
            for (int j = 0; j < ValuesPerLoad; j++) {
                if (first + j < to_seq_len) row_out[first + j] = packed_h[j];
            }
        }
    }
}


// Dispatches long_seq_softmax_kernel on the per-thread element count `k`.
// The expansion site must have `k`, `grid` and `block` in scope; the macro
// arguments are forwarded verbatim to the kernel. k = 1 and 2 load float /
// float2 vectors, k = 4, 8 and 16 load float4 vectors; any other k throws
// std::runtime_error.
#define RUN_KERNEL(...)                                                                         \
    do {                                                                                        \
        if (k == 1) {                                                                           \
            long_seq_softmax_kernel<BlockDim_, 1, float><<<grid, block>>>(__VA_ARGS__);         \
        } else if (k == 2) {                                                                    \
            long_seq_softmax_kernel<BlockDim_, 2, float2><<<grid, block>>>(__VA_ARGS__);        \
        } else if (k == 4) {                                                                    \
            long_seq_softmax_kernel<BlockDim_, 4, float4><<<grid, block>>>(__VA_ARGS__);        \
        } else if (k == 8) {                                                                    \
            long_seq_softmax_kernel<BlockDim_, 8, float4><<<grid, block>>>(__VA_ARGS__);        \
        } else if (k == 16) {                                                                   \
            long_seq_softmax_kernel<BlockDim_, 16, float4><<<grid, block>>>(__VA_ARGS__);       \
        } else {                                                                                \
            throw std::runtime_error("Unsupported Sequence Length.");                           \
        }                                                                                       \
    } while (0)


// Dispatches long_seq_softmax_kernel_bf16 on the per-thread element count `k`.
// The expansion site must have `k`, `grid` and `block` in scope; the macro
// arguments are forwarded verbatim to the kernel. The LoadType packs k bf16
// values for k = 2 (float), 4 (float2) and 8 (float4); k = 16 uses two float4
// loads. k = 1 and every other value throw std::runtime_error.
#define RUN_KERNEL_BF16(...)                                                                    \
    do {                                                                                        \
        if (k == 2) {                                                                           \
            long_seq_softmax_kernel_bf16<BlockDim_, 2, float><<<grid, block>>>(__VA_ARGS__);    \
        } else if (k == 4) {                                                                    \
            long_seq_softmax_kernel_bf16<BlockDim_, 4, float2><<<grid, block>>>(__VA_ARGS__);   \
        } else if (k == 8) {                                                                    \
            long_seq_softmax_kernel_bf16<BlockDim_, 8, float4><<<grid, block>>>(__VA_ARGS__);   \
        } else if (k == 16) {                                                                   \
            long_seq_softmax_kernel_bf16<BlockDim_, 16, float4><<<grid, block>>>(__VA_ARGS__);  \
        } else {                                                                                \
            throw std::runtime_error("Unsupported Sequence Length.");                           \
        }                                                                                       \
    } while (0)

// Host entry point: scaled softmax of a float32 CUDA tensor along `dim`.
//
// attn_weight: float32 CUDA tensor. It is made contiguous first, so `dim` must
//              be the innermost non-trivial dimension: every dimension after it
//              must have size 1 (typically dim == -1). Negative `dim` is wrapped.
// scaler:      factor applied to every input element before the softmax.
// Returns a new contiguous tensor with attn_weight's shape.
// Throws c10::Error on invalid inputs or a failed launch, and
// std::runtime_error("Unsupported Sequence Length.") when RUN_KERNEL has no
// instantiation for k = ceil(len(dim) / BlockDim_).
//
// One block of BlockDim_ threads handles one row of length len(dim).
// NOTE(review): RUN_KERNEL launches on the legacy default stream rather than
// PyTorch's current stream — confirm no caller relies on custom streams.
torch::Tensor long_seq_softmax_cuda(
    torch::Tensor attn_weight,
    int dim, float scaler)
{
    TORCH_CHECK(attn_weight.is_cuda(), "long_seq_softmax_cuda: attn_weight must be a CUDA tensor");
    TORCH_CHECK(attn_weight.scalar_type() == torch::kFloat32,
                "long_seq_softmax_cuda: attn_weight must be float32, got ", attn_weight.scalar_type());

    const int64_t ndim = attn_weight.dim();
    TORCH_CHECK(ndim > 0, "long_seq_softmax_cuda: attn_weight must have at least one dimension");
    const int64_t wrapped_dim = dim < 0 ? dim + ndim : dim;
    TORCH_CHECK(wrapped_dim >= 0 && wrapped_dim < ndim,
                "long_seq_softmax_cuda: dim ", dim, " is out of range for a ", ndim, "-d tensor");
    // The kernel treats each row as a contiguous run, which only holds when
    // every dimension after `dim` is trivial.
    for (int64_t d = wrapped_dim + 1; d < ndim; ++d) {
        TORCH_CHECK(attn_weight.size(d) == 1,
                    "long_seq_softmax_cuda: softmax is only supported over the innermost dimension, got dim=", dim);
    }

    // Row indexing in the kernel assumes a dense row-major layout.
    attn_weight = attn_weight.contiguous();
    auto attn = torch::empty_like(attn_weight);

    // Empty input: nothing to launch (also avoids numel / 0 and a zero-sized grid).
    if (attn_weight.numel() == 0) {
        return attn;
    }

    const int64_t seq_len64 = attn_weight.size(wrapped_dim);
    TORCH_CHECK(seq_len64 <= std::numeric_limits<int>::max(),
                "long_seq_softmax_cuda: softmax dimension is too long: ", seq_len64);
    const int to_seq_len = static_cast<int>(seq_len64);

    const int64_t high_dim_size = attn_weight.numel() / to_seq_len;
    // grid.x is limited to 2^31 - 1 blocks.
    TORCH_CHECK(high_dim_size <= std::numeric_limits<int>::max(),
                "long_seq_softmax_cuda: too many rows: ", high_dim_size);

    // Per-thread element count; RUN_KERNEL dispatches on it.
    int k = (to_seq_len + BlockDim_ - 1) / BlockDim_;

    dim3 block, grid;

    block.x = BlockDim_;
    grid.x = static_cast<unsigned int>(high_dim_size);

    // Launch on the tensor's device even if the caller's current device differs.
    const c10::cuda::CUDAGuard device_guard(attn_weight.device());

    RUN_KERNEL(attn_weight.data_ptr<float>(), attn.data_ptr<float>(), to_seq_len, scaler);

    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "long_seq_softmax_cuda: kernel launch failed: ", cudaGetErrorString(launch_err));

    return attn;
}


// Host entry point: scaled softmax of a bfloat16 CUDA tensor along `dim`.
//
// attn_weight: bfloat16 CUDA tensor. It is made contiguous first, so `dim` must
//              be the innermost non-trivial dimension: every dimension after it
//              must have size 1 (typically dim == -1). Negative `dim` is wrapped.
// scaler:      factor applied to every input element before the softmax.
// Returns a new contiguous bfloat16 tensor with attn_weight's shape.
// Throws c10::Error on invalid inputs or a failed launch, and
// std::runtime_error("Unsupported Sequence Length.") when RUN_KERNEL_BF16 has
// no instantiation for k = ceil(len(dim) / BlockDim_). That includes k == 1,
// i.e. a softmax dimension of at most BlockDim_ elements.
//
// One block of BlockDim_ threads handles one row of length len(dim).
// NOTE(review): RUN_KERNEL_BF16 launches on the legacy default stream rather
// than PyTorch's current stream — confirm no caller relies on custom streams.
torch::Tensor long_seq_softmax_bf16_cuda(
    torch::Tensor attn_weight,
    int dim, float scaler)
{
    TORCH_CHECK(attn_weight.is_cuda(), "long_seq_softmax_bf16_cuda: attn_weight must be a CUDA tensor");
    TORCH_CHECK(attn_weight.scalar_type() == torch::kBFloat16,
                "long_seq_softmax_bf16_cuda: attn_weight must be bfloat16, got ", attn_weight.scalar_type());

    const int64_t ndim = attn_weight.dim();
    TORCH_CHECK(ndim > 0, "long_seq_softmax_bf16_cuda: attn_weight must have at least one dimension");
    const int64_t wrapped_dim = dim < 0 ? dim + ndim : dim;
    TORCH_CHECK(wrapped_dim >= 0 && wrapped_dim < ndim,
                "long_seq_softmax_bf16_cuda: dim ", dim, " is out of range for a ", ndim, "-d tensor");
    // The kernel treats each row as a contiguous run, which only holds when
    // every dimension after `dim` is trivial.
    for (int64_t d = wrapped_dim + 1; d < ndim; ++d) {
        TORCH_CHECK(attn_weight.size(d) == 1,
                    "long_seq_softmax_bf16_cuda: softmax is only supported over the innermost dimension, got dim=", dim);
    }

    // Row indexing in the kernel assumes a dense row-major layout.
    attn_weight = attn_weight.contiguous();
    auto attn = torch::empty_like(attn_weight);

    // Empty input: nothing to launch (also avoids numel / 0 and a zero-sized grid).
    if (attn_weight.numel() == 0) {
        return attn;
    }

    const int64_t seq_len64 = attn_weight.size(wrapped_dim);
    TORCH_CHECK(seq_len64 <= std::numeric_limits<int>::max(),
                "long_seq_softmax_bf16_cuda: softmax dimension is too long: ", seq_len64);
    const int to_seq_len = static_cast<int>(seq_len64);

    const int64_t high_dim_size = attn_weight.numel() / to_seq_len;
    // grid.x is limited to 2^31 - 1 blocks.
    TORCH_CHECK(high_dim_size <= std::numeric_limits<int>::max(),
                "long_seq_softmax_bf16_cuda: too many rows: ", high_dim_size);

    // Per-thread element count; RUN_KERNEL_BF16 dispatches on it.
    int k = (to_seq_len + BlockDim_ - 1) / BlockDim_;

    dim3 block, grid;

    block.x = BlockDim_;
    grid.x = static_cast<unsigned int>(high_dim_size);

    // Launch on the tensor's device even if the caller's current device differs.
    const c10::cuda::CUDAGuard device_guard(attn_weight.device());

    // at::BFloat16 and nv_bfloat16 share the same 16-bit storage layout.
    RUN_KERNEL_BF16(reinterpret_cast<nv_bfloat16*>(attn_weight.data_ptr<at::BFloat16>()),
                    reinterpret_cast<nv_bfloat16*>(attn.data_ptr<at::BFloat16>()),
                    to_seq_len, scaler);

    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "long_seq_softmax_bf16_cuda: kernel launch failed: ", cudaGetErrorString(launch_err));

    return attn;
}

#undef RUN_KERNEL
#undef RUN_KERNEL_BF16