

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <c10/cuda/CUDAGuard.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <limits>


// Element-wise fused "add bias, apply activation variant, scale" kernel.
//
// Launch layout: 1-D grid of 1-D blocks. Each block owns a contiguous tile of
// loop_x * blockDim.x elements; a thread starts at its threadIdx.x offset in
// that tile and advances by blockDim.x per iteration, so neighbouring threads
// touch neighbouring elements. No shared memory is used.
//
// Parameters:
//   out      - output buffer; element xi is written for every xi < size_x.
//   p_x      - input buffer with size_x elements.
//   p_b      - bias buffer; only dereferenced when use_bias != 0, indexed by
//              (xi / step_b) % size_b.
//   p_ref    - reference buffer; only dereferenced when use_ref != 0, at the
//              same index as p_x. Its sign selects the slope for act=3, grad=1.
//   act,grad - combined into the selector act * 10 + grad:
//                10, 11 : y = x
//                12     : y = 0
//                30     : y = x > 0 ? x : x * alpha
//                31     : y = ref > 0 ? x : x * alpha
//                32     : y = 0
//              Any other combination falls through to the y = x case.
//              NOTE(review): grad presumably selects the derivative order
//              (0 = forward, 1 = first, 2 = second) - confirm with the caller.
//   alpha    - negative-side slope used by the act=3 variants.
//   scale    - factor applied to y before it is stored.
//   loop_x   - number of elements each thread processes at most.
//   size_x   - total number of elements in p_x / out.
//   step_b   - number of consecutive elements sharing one bias entry.
//   size_b   - number of bias entries.
template <typename scalar_t>
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
    // Keep the running position in 64 bits. With a 32-bit index, both the
    // initial product and the "+= blockDim.x" step can wrap to a negative
    // value when size_x is close to INT_MAX; a negative index still passes
    // the "xi < size_x" test and would access memory out of bounds.
    long long xi = static_cast<long long>(blockIdx.x) * loop_x * blockDim.x + threadIdx.x;

    scalar_t zero = 0.0;

    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
        scalar_t x = p_x[xi];

        if (use_bias) {
            // xi < size_x holds here, so narrowing is lossless and keeps the
            // division / modulo in cheaper 32-bit arithmetic.
            const int flat = static_cast<int>(xi);
            x += p_b[(flat / step_b) % size_b];
        }

        scalar_t ref = use_ref ? p_ref[xi] : zero;

        scalar_t y;

        switch (act * 10 + grad) {
            default:  // unknown selector: behave like the identity
            case 10: y = x; break;
            case 11: y = x; break;
            case 12: y = 0.0; break;

            case 30: y = (x > 0.0) ? x : x * alpha; break;
            case 31: y = (ref > 0.0) ? x : x * alpha; break;
            case 32: y = 0.0; break;
        }

        out[xi] = y * scale;
    }
}


// Host entry point: launches fused_bias_act_kernel over `input` and returns a
// new tensor with the same shape and dtype.
//
// Parameters:
//   input - CUDA tensor of a floating type (float, double or half). A
//           contiguous copy is made if it is not already contiguous.
//   bias  - bias values; an empty tensor (numel() == 0) disables the bias add.
//           When used it must live on the same device and have the same dtype
//           as `input`. One bias entry is shared by the product of input's
//           dimensions from index 2 onwards, cycling over bias.numel() entries.
//   refer - reference values; an empty tensor disables them. When used it must
//           live on the same device, have the same dtype, and provide at least
//           input.numel() elements.
//   act, grad    - forwarded to the kernel, which switches on act * 10 + grad.
//   alpha, scale - forwarded to the kernel after conversion to the tensor dtype.
//
// The kernel is enqueued on the current stream of input's device; no host
// synchronisation is performed. Throws a c10::Error if a precondition is
// violated or if the launch reports an error.
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale) {
    TORCH_CHECK(input.is_cuda(), "fused_bias_act_op: input must be a CUDA tensor");

    // Make input's device current for the rest of this call so that the
    // allocation, the stream and the launch all target the same GPU, even when
    // the caller's current device is a different one.
    const c10::cuda::CUDAGuard device_guard(input.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    auto x = input.contiguous();
    auto b = bias.contiguous();
    auto ref = refer.contiguous();

    const int use_bias = b.numel() ? 1 : 0;
    const int use_ref = ref.numel() ? 1 : 0;

    // The kernel takes its sizes as int; reject anything that would be
    // silently truncated.
    const int64_t int_max = static_cast<int64_t>(std::numeric_limits<int>::max());
    TORCH_CHECK(x.numel() <= int_max, "fused_bias_act_op: input has too many elements for 32-bit indexing");
    TORCH_CHECK(b.numel() <= int_max, "fused_bias_act_op: bias has too many elements for 32-bit indexing");

    // Device pointers from another device (or the host) must not reach the kernel.
    TORCH_CHECK(!use_bias || b.device() == x.device(), "fused_bias_act_op: bias must be on the same device as input");
    TORCH_CHECK(!use_ref || ref.device() == x.device(), "fused_bias_act_op: refer must be on the same device as input");
    // The kernel reads refer at every input index.
    TORCH_CHECK(!use_ref || ref.numel() >= x.numel(), "fused_bias_act_op: refer must have at least as many elements as input");

    auto y = torch::empty_like(x);

    // Nothing to compute: skip the launch entirely.
    if (x.numel() == 0) {
        return y;
    }

    const int size_x = static_cast<int>(x.numel());
    const int size_b = static_cast<int>(b.numel());

    // Number of consecutive elements that share one bias entry: the product of
    // all dimensions after the second one. It divides numel() (non-zero here),
    // so it fits in int once accumulated.
    int64_t inner = 1;
    for (int64_t d = 2; d < x.dim(); d++) {
        inner *= x.size(d);
    }
    const int step_b = static_cast<int>(inner);

    const int loop_x = 4;            // elements handled per thread
    const int block_size = 4 * 32;   // threads per block (four warps)
    const int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
        // Unused tensors are passed as null so that an empty placeholder of a
        // different dtype does not trip data_ptr's dtype check.
        const scalar_t* b_ptr = use_bias ? b.data_ptr<scalar_t>() : nullptr;
        const scalar_t* ref_ptr = use_ref ? ref.data_ptr<scalar_t>() : nullptr;

        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(),
            x.data_ptr<scalar_t>(),
            b_ptr,
            ref_ptr,
            act,
            grad,
            alpha,
            scale,
            loop_x,
            size_x,
            step_b,
            size_b,
            use_bias,
            use_ref
        );
    });

    // Kernel launches do not return a status; surface configuration errors here.
    AT_CUDA_CHECK(cudaGetLastError());

    return y;
}