#include <torch/extension.h>

// Defined in a separate translation unit (presumably the companion .cu file,
// which is not visible here). Declared so the wrapper below can forward to it.
torch::Tensor long_seq_softmax_cuda(
    torch::Tensor attn_weight,
    int dim, float scaler);


/// @brief Python-facing entry point for the long-sequence softmax.
///
/// Rejects non-CUDA inputs up front, then forwards all arguments unchanged
/// to long_seq_softmax_cuda.
///
/// @param attn_weight Input tensor. Must live on a CUDA device; any further
///        shape/dtype/layout requirements are imposed by the CUDA
///        implementation (not visible here).
/// @param dim    Forwarded as-is; presumably the axis the softmax reduces
///               over — TODO confirm against the kernel.
/// @param scaler Forwarded as-is; presumably a scale applied to the inputs
///               of the softmax — TODO confirm against the kernel.
/// @return The tensor produced by long_seq_softmax_cuda.
/// @throws c10::Error (raised in Python as RuntimeError) if attn_weight is
///         not a CUDA tensor.
torch::Tensor long_seq_softmax(
    torch::Tensor attn_weight,
    int dim, float scaler)
{
    // Fail with a readable error instead of handing a host tensor to a
    // CUDA-only implementation.
    TORCH_CHECK(attn_weight.is_cuda(),
                "long_seq_softmax: attn_weight must be a CUDA tensor");
    return long_seq_softmax_cuda(attn_weight, dim, scaler);
}

// Defined in a separate translation unit (presumably the companion .cu file,
// which is not visible here). Declared so the wrapper below can forward to it.
torch::Tensor long_seq_softmax_bf16_cuda(
    torch::Tensor attn_weight,
    int dim, float scaler);


/// @brief Python-facing entry point for the bf16 long-sequence softmax.
///
/// Rejects non-CUDA inputs up front, then forwards all arguments unchanged
/// to long_seq_softmax_bf16_cuda.
///
/// NOTE(review): the name suggests bfloat16 input, but dtype is not checked
/// here. Confirm whether the CUDA side validates or converts it.
///
/// @param attn_weight Input tensor. Must live on a CUDA device; any further
///        shape/dtype/layout requirements are imposed by the CUDA
///        implementation (not visible here).
/// @param dim    Forwarded as-is; presumably the axis the softmax reduces
///               over — TODO confirm against the kernel.
/// @param scaler Forwarded as-is; presumably a scale applied to the inputs
///               of the softmax — TODO confirm against the kernel.
/// @return The tensor produced by long_seq_softmax_bf16_cuda.
/// @throws c10::Error (raised in Python as RuntimeError) if attn_weight is
///         not a CUDA tensor.
torch::Tensor long_seq_softmax_bf16(
    torch::Tensor attn_weight,
    int dim, float scaler)
{
    // Fail with a readable error instead of handing a host tensor to a
    // CUDA-only implementation.
    TORCH_CHECK(attn_weight.is_cuda(),
                "long_seq_softmax_bf16: attn_weight must be a CUDA tensor");
    return long_seq_softmax_bf16_cuda(attn_weight, dim, scaler);
}

// Python bindings. Both entry points take (attn_weight, dim, scaler) and
// return a tensor; each forwards to its own CUDA implementation. The
// docstrings are kept distinct so help() can tell the two apart.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
    m.def("long_seq_softmax", &long_seq_softmax, "Softmax for long sequence modeling");
    m.def("long_seq_softmax_bf16", &long_seq_softmax_bf16,
          "Softmax for long sequence modeling (bf16 variant)");
}
