#include <torch/torch.h>
#include <c10/cuda/CUDAStream.h>


// Forward declaration of the launcher invoked by cvt_bf16_e2m1(); its
// definition is not visible in this file.
// Parameter names mirror the arguments the wrapper passes (y's data pointer,
// x's data pointer, x's element count, a CUDA stream). The meaning of the int
// return value is not shown here.
int cvt_bf16_e2m1_cuda(
    void* y_data,
    const void* x_data,
    const int x_numel,
    cudaStream_t stream
);


// Forwards the raw buffers of `y` and `x` to cvt_bf16_e2m1_cuda, using the
// current CUDA stream of x's device.
//
// y: destination tensor; only its data pointer is read here.
// x: source tensor; its element count is the size handed to the launcher.
//
// Throws c10::Error (via TORCH_CHECK) if x.numel() does not fit in the
// launcher's `int` size parameter, or if the launcher returns a non-zero
// status.
void cvt_bf16_e2m1(
    torch::Tensor &y,
    const torch::Tensor &x
) {
    // The launcher takes the size as `int`; refuse to truncate silently.
    const int64_t numel = x.numel();
    const int numel_int = static_cast<int>(numel);
    TORCH_CHECK(
        static_cast<int64_t>(numel_int) == numel,
        "cvt_bf16_e2m1: x.numel() = ", numel, " does not fit in int"
    );

    const int err = cvt_bf16_e2m1_cuda(
        y.data_ptr(),
        x.data_ptr(),
        numel_int,
        at::cuda::getCurrentCUDAStream(x.device().index())
    );
    // NOTE(review): assumes the launcher follows the 0-on-success convention;
    // confirm against its definition.
    TORCH_CHECK(
        err == 0,
        "cvt_bf16_e2m1: cvt_bf16_e2m1_cuda failed with status ", err
    );
}


// Forward declaration of the launcher invoked by cvt_e2m1_bf16(); its
// definition is not visible in this file.
// Parameter names mirror the arguments the wrapper passes (y's data pointer,
// x's data pointer, y's element count, a CUDA stream). The meaning of the int
// return value is not shown here.
int cvt_e2m1_bf16_cuda(
    void* y_data,
    const void* x_data,
    const int y_numel,
    cudaStream_t stream
);


// Forwards the raw buffers of `y` and `x` to cvt_e2m1_bf16_cuda, using the
// current CUDA stream of x's device.
//
// y: destination tensor; its element count (not x's) is the size handed to
//    the launcher.
// x: source tensor; only its data pointer and device are read here.
//
// Throws c10::Error (via TORCH_CHECK) if y.numel() does not fit in the
// launcher's `int` size parameter, or if the launcher returns a non-zero
// status.
void cvt_e2m1_bf16(
    torch::Tensor &y,
    const torch::Tensor &x
) {
    // The launcher takes the size as `int`; refuse to truncate silently.
    const int64_t numel = y.numel();
    const int numel_int = static_cast<int>(numel);
    TORCH_CHECK(
        static_cast<int64_t>(numel_int) == numel,
        "cvt_e2m1_bf16: y.numel() = ", numel, " does not fit in int"
    );

    const int err = cvt_e2m1_bf16_cuda(
        y.data_ptr(),
        x.data_ptr(),
        numel_int,
        at::cuda::getCurrentCUDAStream(x.device().index())
    );
    // NOTE(review): assumes the launcher follows the 0-on-success convention;
    // confirm against its definition.
    TORCH_CHECK(
        err == 0,
        "cvt_e2m1_bf16: cvt_e2m1_bf16_cuda failed with status ", err
    );
}


// Forward declaration of the launcher invoked by forward_bf16(); its
// definition is not visible in this file.
// Parameter names mirror the arguments the wrapper passes: data pointers of
// x, h, xh_e2m1, xh_e8m0 and clip_mask, then x's element count and a CUDA
// stream. The meaning of the int return value is not shown here.
int forward_bf16_cuda(
    const void* x_data,
    const void* h_data,
    void* xh_e2m1_data,
    void* xh_e8m0_data,
    void* clip_mask_data,
    const int x_numel,
    cudaStream_t stream
);


// Forwards the raw buffers of the five tensors to forward_bf16_cuda, using
// the current CUDA stream of h's device.
//
// x, h:                        read-only inputs (passed as const pointers).
// xh_e2m1, xh_e8m0, clip_mask: tensors whose buffers are passed as mutable
//                              pointers to the launcher.
// The size handed to the launcher is x.numel().
//
// Throws c10::Error (via TORCH_CHECK) if x.numel() does not fit in the
// launcher's `int` size parameter, or if the launcher returns a non-zero
// status.
void forward_bf16(
    const torch::Tensor& x,
    const torch::Tensor& h,
    torch::Tensor& xh_e2m1,
    torch::Tensor& xh_e8m0,
    torch::Tensor& clip_mask
) {
    // The launcher takes the size as `int`; refuse to truncate silently.
    const int64_t numel = x.numel();
    const int numel_int = static_cast<int>(numel);
    TORCH_CHECK(
        static_cast<int64_t>(numel_int) == numel,
        "forward_bf16: x.numel() = ", numel, " does not fit in int"
    );

    const int err = forward_bf16_cuda(
        x.data_ptr(),
        h.data_ptr(),
        xh_e2m1.data_ptr(),
        xh_e8m0.data_ptr(),
        clip_mask.data_ptr(),
        numel_int,
        at::cuda::getCurrentCUDAStream(h.device().index())
    );
    // NOTE(review): assumes the launcher follows the 0-on-success convention;
    // confirm against its definition.
    TORCH_CHECK(
        err == 0,
        "forward_bf16: forward_bf16_cuda failed with status ", err
    );
}


// Forward declaration of the launcher invoked by backward_bf16(); its
// definition is not visible in this file.
// Parameter names mirror the arguments the wrapper passes: data pointers of
// x, h, xh_e2m1 and xh_e8m0, then x's element count and a CUDA stream. The
// meaning of the int return value is not shown here.
int backward_bf16_cuda(
    const void* x_data,
    const void* h_data,
    void* xh_e2m1_data,
    void* xh_e8m0_data,
    int x_numel,
    cudaStream_t stream
);


// Forwards the raw buffers of the four tensors to backward_bf16_cuda, using
// the current CUDA stream of h's device.
//
// x, h:             read-only inputs (passed as const pointers).
// xh_e2m1, xh_e8m0: tensors whose buffers are passed as mutable pointers to
//                   the launcher.
// The size handed to the launcher is x.numel().
//
// Throws c10::Error (via TORCH_CHECK) if x.numel() does not fit in the
// launcher's `int` size parameter, or if the launcher returns a non-zero
// status.
void backward_bf16(
    const torch::Tensor& x,
    const torch::Tensor& h,
    torch::Tensor& xh_e2m1,
    torch::Tensor& xh_e8m0
) {
    // The launcher takes the size as `int`; refuse to truncate silently.
    const int64_t numel = x.numel();
    const int numel_int = static_cast<int>(numel);
    TORCH_CHECK(
        static_cast<int64_t>(numel_int) == numel,
        "backward_bf16: x.numel() = ", numel, " does not fit in int"
    );

    const int err = backward_bf16_cuda(
        x.data_ptr(),
        h.data_ptr(),
        xh_e2m1.data_ptr(),
        xh_e8m0.data_ptr(),
        numel_int,
        at::cuda::getCurrentCUDAStream(h.device().index())
    );
    // NOTE(review): assumes the launcher follows the 0-on-success convention;
    // confirm against its definition.
    TORCH_CHECK(
        err == 0,
        "backward_bf16: backward_bf16_cuda failed with status ", err
    );
}


// Forward declaration of the launcher invoked by backward_t_bf16(); its
// definition is not visible in this file.
// Parameter names mirror the arguments the wrapper passes: data pointers of
// x, h, xh_e2m1 and xh_e8m0, then x.size(-1), x.size(-2), the number of
// elements of x divided by the product of those two sizes, and a CUDA stream.
// The meaning of the int return value is not shown here.
int backward_t_bf16_cuda(
    const void* x_data,
    const void* h_data,
    void* xh_e2m1_data,
    void* xh_e8m0_data,
    const int x_last_dim,
    const int x_second_last_dim,
    const int x_leading_count,
    cudaStream_t stream
);


// Forwards the raw buffers of the four tensors to backward_t_bf16_cuda, using
// the current CUDA stream of h's device. Three sizes derived from x are
// passed: x.size(-1), x.size(-2), and x.numel() divided by their product
// (the number of trailing 2-D slices in x).
//
// x, h:             read-only inputs (passed as const pointers).
// xh_e2m1, xh_e8m0: tensors whose buffers are passed as mutable pointers to
//                   the launcher.
//
// Throws c10::Error (via TORCH_CHECK) if either trailing dimension of x is
// zero, if any derived size does not fit in the launcher's `int` parameters,
// or if the launcher returns a non-zero status.
void backward_t_bf16(
    const torch::Tensor& x,
    const torch::Tensor& h,
    torch::Tensor& xh_e2m1,
    torch::Tensor& xh_e8m0
) {
    const int64_t last = x.size(-1);
    const int64_t second_last = x.size(-2);
    // Guard the division below: a zero-size trailing dimension would otherwise
    // be an integer division by zero.
    TORCH_CHECK(
        last > 0 && second_last > 0,
        "backward_t_bf16: trailing dimensions of x must be non-zero, got ",
        second_last, " x ", last
    );
    const int64_t leading = x.numel() / (second_last * last);

    // The launcher takes its sizes as `int`; refuse to truncate silently.
    const auto as_int = [](int64_t value, const char* what) {
        const int narrowed = static_cast<int>(value);
        TORCH_CHECK(
            static_cast<int64_t>(narrowed) == value,
            "backward_t_bf16: ", what, " = ", value, " does not fit in int"
        );
        return narrowed;
    };
    const int last_int = as_int(last, "x.size(-1)");
    const int second_last_int = as_int(second_last, "x.size(-2)");
    const int leading_int = as_int(leading, "leading slice count of x");

    const int err = backward_t_bf16_cuda(
        x.data_ptr(),
        h.data_ptr(),
        xh_e2m1.data_ptr(),
        xh_e8m0.data_ptr(),
        last_int,
        second_last_int,
        leading_int,
        at::cuda::getCurrentCUDAStream(h.device().index())
    );
    // NOTE(review): assumes the launcher follows the 0-on-success convention;
    // confirm against its definition.
    TORCH_CHECK(
        err == 0,
        "backward_t_bf16: backward_t_bf16_cuda failed with status ", err
    );
}


// Forward declaration of the launcher invoked by backward_qt_bf16(); its
// definition is not visible in this file.
// Parameter names mirror the arguments the wrapper passes: data pointers of
// x_e2m1, x_e8m0 and h, the scalar alpha, data pointers of xh_e2m1 and
// xh_e8m0, then twice x_e2m1.size(-1), x_e2m1.size(-2), the number of
// elements of x_e2m1 divided by size(-2) * size(-1), and a CUDA stream.
// The meaning of the int return value is not shown here.
int backward_qt_bf16_cuda(
    const void* x_e2m1_data,
    const void* x_e8m0_data,
    const void* h_data,
    const float alpha,
    void* xh_e2m1_data,
    void* xh_e8m0_data,
    const int x_e2m1_last_dim_times_two,
    const int x_e2m1_second_last_dim,
    const int x_e2m1_leading_count,
    cudaStream_t stream
);


// Forwards the raw buffers of the five tensors and the scalar `alpha` to
// backward_qt_bf16_cuda, using the current CUDA stream of h's device. Three
// sizes derived from x_e2m1 are passed: twice x_e2m1.size(-1),
// x_e2m1.size(-2), and x_e2m1.numel() divided by size(-2) * size(-1).
//
// x_e2m1, x_e8m0, h: read-only inputs (passed as const pointers).
// alpha:             scalar forwarded unchanged to the launcher.
// xh_e2m1, xh_e8m0:  tensors whose buffers are passed as mutable pointers to
//                    the launcher.
//
// Throws c10::Error (via TORCH_CHECK) if either trailing dimension of x_e2m1
// is zero, if any derived size does not fit in the launcher's `int`
// parameters, or if the launcher returns a non-zero status.
void backward_qt_bf16(
    const torch::Tensor& x_e2m1,
    const torch::Tensor& x_e8m0,
    const torch::Tensor& h,
    const float alpha,
    torch::Tensor& xh_e2m1,
    torch::Tensor& xh_e8m0
) {
    const int64_t last = x_e2m1.size(-1);
    const int64_t second_last = x_e2m1.size(-2);
    // Guard the division below: a zero-size trailing dimension would otherwise
    // be an integer division by zero.
    TORCH_CHECK(
        last > 0 && second_last > 0,
        "backward_qt_bf16: trailing dimensions of x_e2m1 must be non-zero, got ",
        second_last, " x ", last
    );
    const int64_t leading = x_e2m1.numel() / (second_last * last);

    // The launcher takes its sizes as `int`; refuse to truncate silently.
    const auto as_int = [](int64_t value, const char* what) {
        const int narrowed = static_cast<int>(value);
        TORCH_CHECK(
            static_cast<int64_t>(narrowed) == value,
            "backward_qt_bf16: ", what, " = ", value, " does not fit in int"
        );
        return narrowed;
    };
    // NOTE(review): the last dimension is doubled before being passed,
    // presumably because two e2m1 values share one stored element; confirm.
    const int doubled_last_int = as_int(last * 2, "x_e2m1.size(-1) * 2");
    const int second_last_int = as_int(second_last, "x_e2m1.size(-2)");
    const int leading_int = as_int(leading, "leading slice count of x_e2m1");

    const int err = backward_qt_bf16_cuda(
        x_e2m1.data_ptr(),
        x_e8m0.data_ptr(),
        h.data_ptr(),
        alpha,
        xh_e2m1.data_ptr(),
        xh_e8m0.data_ptr(),
        doubled_last_int,
        second_last_int,
        leading_int,
        at::cuda::getCurrentCUDAStream(h.device().index())
    );
    // NOTE(review): assumes the launcher follows the 0-on-success convention;
    // confirm against its definition.
    TORCH_CHECK(
        err == 0,
        "backward_qt_bf16: backward_qt_bf16_cuda failed with status ", err
    );
}
