#include <torch/extension.h>
#include <cuda_runtime.h>
#include <cuda.h>
#include <vector>
#include <cub/cub.cuh>


// Ceiling division: the number of blocks of T threads needed to cover N
// elements. Both arguments and the whole expansion are parenthesized so the
// macro composes correctly with surrounding operators and with argument
// expressions such as `a << b`. T is expanded twice, so do not pass an
// expression with side effects for it.
#define BLOCKS(N, T) (((N) + (T) - 1) / (T))

/*
The fused implementation of lSELU:
    y = lambda * x                                    if x > 0
    y = lambda * (alpha * exp(x) + beta * x - alpha)  otherwise
*/


// Element-wise lSELU forward pass.
//   y[i] = lambda * x[i]                                       if x[i] > 0
//   y[i] = lambda * (alpha * exp(x[i]) + beta * x[i] - alpha)  otherwise
// x and y are indexed as flat arrays of at least `numel` elements. Every
// thread starts at its global x-index and advances by the total number of
// launched threads until it passes `numel`.
template <typename scalar_t>
__global__ void fuse_lselu_f(
    const scalar_t* __restrict__ x,
    scalar_t* y, float alpha,
    float beta, float lambda, unsigned int numel
){
    const unsigned int first = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int step = gridDim.x * blockDim.x;

    for (unsigned int idx = first; idx < numel; idx += step){
        const scalar_t in = x[idx];
        // Positive inputs are only scaled; everything else (including zero)
        // takes the exponential-plus-linear branch.
        const scalar_t out = (in > 0)
            ? lambda * in
            : lambda * (alpha * exp(in) + beta * in - alpha);
        y[idx] = out;
    }
}


// Host entry point for the lSELU forward pass.
//
// Args:
//   x:                   CUDA tensor of float or double values, any shape.
//   alpha, beta, lambda: activation coefficients forwarded to the kernel.
// Returns:
//   A new contiguous tensor with the shape and dtype of `x` holding the
//   activation of every element.
//
// NOTE(review): the kernel is launched on the default stream of the current
// device; confirm callers never run this under a non-default PyTorch stream
// or with tensors that live on a different device.
torch::Tensor fuse_lselu_f_cuda(
    torch::Tensor x,
    float alpha,
    float beta,
    float lambda
){
    TORCH_CHECK(x.is_cuda(), "fuse_lselu_f_cuda: x must be a CUDA tensor");

    // The kernel indexes its buffers as flat arrays, so operate on a densely
    // packed tensor even when the caller passes a strided view.
    const auto x_c = x.contiguous();
    auto y = torch::empty_like(x_c);

    // The kernel takes a 32-bit element count; reject sizes that would be
    // silently truncated by the conversion.
    const int64_t total = x_c.numel();
    const unsigned int numel = static_cast<unsigned int>(total);
    TORCH_CHECK(static_cast<int64_t>(numel) == total,
                "fuse_lselu_f_cuda: tensor has too many elements for 32-bit indexing");

    // A launch with zero blocks is an invalid configuration; nothing to do.
    if (numel == 0){
        return y;
    }

    // Computed in 64-bit so the rounding-up addition cannot wrap.
    const unsigned int blocks = static_cast<unsigned int>(BLOCKS(total, 1024));

    AT_DISPATCH_FLOATING_TYPES(x_c.scalar_type(), "fuse_lselu_f_a", ([&]{
        fuse_lselu_f<scalar_t><<<blocks, 1024>>>(
            x_c.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
            alpha, beta, lambda, numel
        );
    }));

    // Kernel launches do not report configuration errors by themselves.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "fuse_lselu_f_cuda: kernel launch failed: ", cudaGetErrorString(err));

    return y;
}


// Element-wise lSELU backward pass (input gradient only).
//   dy/dx = lambda                                if x[i] > 0
//   dy/dx = lambda * (alpha * exp(x[i]) + beta)   otherwise
//   grad_x[i] = grad_y[i] * dy/dx
// All three buffers are indexed as flat arrays of at least `numel` elements;
// each thread strides over them by the total number of launched threads.
template <typename scalar_t>
__global__ void fuse_lselu_b(
    const scalar_t* __restrict__ grad_y,
    const scalar_t* __restrict__ x,
    scalar_t* grad_x, float alpha,
    float beta, float lambda, unsigned int numel
){
    const unsigned int first = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int step = gridDim.x * blockDim.x;

    for (unsigned int idx = first; idx < numel; idx += step){
        const scalar_t in = x[idx];
        // Start from the positive-side slope and overwrite it for inputs
        // that are not strictly positive.
        scalar_t slope = lambda;
        if (!(in > 0)){
            slope = lambda * (alpha * exp(in) + beta);
        }
        grad_x[idx] = grad_y[idx] * slope;
    }
}


// lSELU backward pass that also accumulates the gradient w.r.t. lambda.
//   grad_x[i]       = grad_y[i] * dy/dx
//   grad_lambda[0] += sum_i grad_y[i] * dy/dlambda
// with
//   dy/dx = lambda,                               dy/dlambda = x        (x > 0)
//   dy/dx = lambda * (alpha * exp(x) + beta),
//   dy/dlambda = alpha * exp(x) + beta * x - alpha                      (otherwise)
// Each thread keeps a compensated (Kahan) partial sum over the elements it
// visits; the partials are combined per block with cub::BlockReduce and
// thread 0 of every block adds the block total to grad_lambda[0] with
// atomicAdd. The kernel only adds to grad_lambda[0]; it never resets it.
// `blockSize` is the thread count cub::BlockReduce is specialised for.
template <typename scalar_t, unsigned int blockSize>
__global__ void fuse_lselu_b_affine(
    const scalar_t* __restrict__ grad_y,
    const scalar_t* __restrict__ x,
    scalar_t* grad_x, scalar_t* grad_lambda,
    float alpha, float beta, float lambda, unsigned int numel
){
    using Reducer = cub::BlockReduce<scalar_t, blockSize>;
    __shared__ typename Reducer::TempStorage reduce_scratch;

    scalar_t partial = 0;   // running per-thread sum
    scalar_t comp = 0;      // Kahan compensation for lost low-order bits

    const unsigned int first = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int step = gridDim.x * blockDim.x;

    for (unsigned int idx = first; idx < numel; idx += step){
        const scalar_t in = x[idx];
        const scalar_t upstream = grad_y[idx];

        scalar_t slope;     // dy/dx
        scalar_t dlam;      // dy/dlambda
        if (in > 0){
            slope = lambda;
            dlam = in;
        }
        else{
            const scalar_t e = exp(in);
            slope = lambda * (alpha * e + beta);
            dlam = alpha * e + beta * in - alpha;
        }

        // Kahan step: fold the previous rounding error into the new term.
        const scalar_t term = upstream * dlam - comp;
        const scalar_t updated = partial + term;
        comp = (updated - partial) - term;
        partial = updated;

        grad_x[idx] = upstream * slope;
    }

    __syncthreads();
    const scalar_t block_total = Reducer(reduce_scratch).Sum(partial);
    if (threadIdx.x == 0){
        atomicAdd(&grad_lambda[0], block_total);
    }
}


// Host entry point for the lSELU backward pass.
//
// Args:
//   grad_y:              CUDA tensor with the gradient of the loss w.r.t. the
//                        activation output (float or double).
//   x:                   CUDA tensor with the forward input; same dtype and
//                        element count as `grad_y`.
//   alpha, beta, lambda: activation coefficients forwarded to the kernel.
//   lambda_train:        when true, the gradient w.r.t. lambda is accumulated
//                        as well; otherwise it is returned as zero.
// Returns:
//   {grad_x, grad_lambda}: grad_x is a contiguous tensor with the shape and
//   dtype of `grad_y`; grad_lambda is a one-element float32 tensor.
//
// NOTE(review): kernels are launched on the default stream of the current
// device; confirm callers never run this under a non-default PyTorch stream
// or with tensors that live on a different device.
std::vector<torch::Tensor> fuse_lselu_b_cuda(
    torch::Tensor grad_y,
    torch::Tensor x,
    float alpha,
    float beta,
    float lambda,
    bool lambda_train
){
    TORCH_CHECK(grad_y.is_cuda() && x.is_cuda(),
                "fuse_lselu_b_cuda: grad_y and x must be CUDA tensors");
    // The kernels read x[i] for every i < grad_y.numel(); a shorter x would
    // be read out of bounds.
    TORCH_CHECK(grad_y.numel() == x.numel(),
                "fuse_lselu_b_cuda: grad_y and x must have the same number of elements");

    // The kernels index their buffers as flat arrays, so both inputs must be
    // densely packed in the same element order.
    const auto grad_y_c = grad_y.contiguous();
    const auto x_c = x.contiguous();
    auto grad_x = torch::empty_like(grad_y_c);

    // The kernels take a 32-bit element count; reject sizes that would be
    // silently truncated by the conversion.
    const int64_t total = grad_y_c.numel();
    const unsigned int numel = static_cast<unsigned int>(total);
    TORCH_CHECK(static_cast<int64_t>(numel) == total,
                "fuse_lselu_b_cuda: tensor has too many elements for 32-bit indexing");

    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(x.device());
    auto grad_lambda = torch::zeros({1}, options);

    // A launch with zero blocks is an invalid configuration; nothing to do.
    if (numel == 0){
        return {grad_x, grad_lambda};
    }

    // Computed in 64-bit so the rounding-up addition cannot wrap.
    const unsigned int blocks = static_cast<unsigned int>(BLOCKS(total, 1024));

    if (lambda_train){
        // The kernel accumulates with atomicAdd on scalar_t, so the
        // accumulator must share the input dtype; a float32 buffer cannot be
        // accessed as double. It is converted back to float32 afterwards.
        auto grad_lambda_acc = torch::zeros(
            {1}, options.dtype(grad_y_c.scalar_type()));

        // Blocks beyond what the data needs would only contribute zeros, so
        // cap the fixed grid of 6400 blocks by the blocks actually required.
        const unsigned int reduce_blocks = blocks < 6400u ? blocks : 6400u;

        AT_DISPATCH_FLOATING_TYPES(grad_y_c.scalar_type(), "fuse_lselu_b_affine", ([&]{
            // The template block size must equal the launched block size.
            fuse_lselu_b_affine<scalar_t, 1024><<<reduce_blocks, 1024>>>(
                grad_y_c.data_ptr<scalar_t>(), x_c.data_ptr<scalar_t>(),
                grad_x.data_ptr<scalar_t>(), grad_lambda_acc.data_ptr<scalar_t>(),
                alpha, beta, lambda, numel
            );
        }));
        grad_lambda = grad_lambda_acc.to(torch::kFloat32);
    }
    else{
        AT_DISPATCH_FLOATING_TYPES(grad_y_c.scalar_type(), "fuse_lselu_b", ([&]{
            fuse_lselu_b<scalar_t><<<blocks, 1024>>>(
                grad_y_c.data_ptr<scalar_t>(), x_c.data_ptr<scalar_t>(),
                grad_x.data_ptr<scalar_t>(), alpha, beta, lambda, numel
            );
        }));
    }

    // Kernel launches do not report configuration errors by themselves.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "fuse_lselu_b_cuda: kernel launch failed: ", cudaGetErrorString(err));

    return {grad_x, grad_lambda};
}


/*
The fused implementation of sSELU:
    y = lambda * x                                if x > 0
    y = lambda * (alpha * exp(beta * x) - alpha)  otherwise
*/


// Element-wise sSELU forward pass.
//   y[i] = lambda * x[i]                                 if x[i] > 0
//   y[i] = lambda * (alpha * exp(beta * x[i]) - alpha)   otherwise
// x and y are indexed as flat arrays of at least `numel` elements. Every
// thread starts at its global x-index and advances by the total number of
// launched threads until it passes `numel`.
template <typename scalar_t>
__global__ void fuse_sselu_f(
    const scalar_t* __restrict__ x,
    scalar_t* y, float alpha,
    float beta, float lambda, unsigned int numel
){
    const unsigned int first = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int step = gridDim.x * blockDim.x;

    for (unsigned int idx = first; idx < numel; idx += step){
        const scalar_t in = x[idx];
        // Positive inputs are only scaled; everything else (including zero)
        // takes the scaled-exponential branch.
        const scalar_t out = (in > 0)
            ? lambda * in
            : lambda * (alpha * exp(in * beta) - alpha);
        y[idx] = out;
    }
}


// Host entry point for the sSELU forward pass.
//
// Args:
//   x:                   CUDA tensor of float or double values, any shape.
//   alpha, beta, lambda: activation coefficients forwarded to the kernel.
// Returns:
//   A new contiguous tensor with the shape and dtype of `x` holding the
//   activation of every element.
//
// NOTE(review): the kernel is launched on the default stream of the current
// device; confirm callers never run this under a non-default PyTorch stream
// or with tensors that live on a different device.
torch::Tensor fuse_sselu_f_cuda(
    torch::Tensor x,
    float alpha,
    float beta,
    float lambda
){
    TORCH_CHECK(x.is_cuda(), "fuse_sselu_f_cuda: x must be a CUDA tensor");

    // The kernel indexes its buffers as flat arrays, so operate on a densely
    // packed tensor even when the caller passes a strided view.
    const auto x_c = x.contiguous();
    auto y = torch::empty_like(x_c);

    // The kernel takes a 32-bit element count; reject sizes that would be
    // silently truncated by the conversion.
    const int64_t total = x_c.numel();
    const unsigned int numel = static_cast<unsigned int>(total);
    TORCH_CHECK(static_cast<int64_t>(numel) == total,
                "fuse_sselu_f_cuda: tensor has too many elements for 32-bit indexing");

    // A launch with zero blocks is an invalid configuration; nothing to do.
    if (numel == 0){
        return y;
    }

    // Computed in 64-bit so the rounding-up addition cannot wrap.
    const unsigned int blocks = static_cast<unsigned int>(BLOCKS(total, 1024));

    AT_DISPATCH_FLOATING_TYPES(x_c.scalar_type(), "fuse_sselu_f", ([&]{
        fuse_sselu_f<scalar_t><<<blocks, 1024>>>(
            x_c.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
            alpha, beta, lambda, numel
        );
    }));

    // Kernel launches do not report configuration errors by themselves.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "fuse_sselu_f_cuda: kernel launch failed: ", cudaGetErrorString(err));

    return y;
}


// Element-wise sSELU backward pass (input gradient only).
// The forward pass is y = lambda * x for x > 0 and
// y = lambda * (alpha * exp(beta * x) - alpha) otherwise, hence
//   dy/dx = lambda                                  if x[i] > 0
//   dy/dx = lambda * alpha * beta * exp(beta * x)   otherwise
//   grad_x[i] = grad_y[i] * dy/dx
// All three buffers are indexed as flat arrays of at least `numel` elements;
// each thread strides over them by the total number of launched threads.
template <typename scalar_t>
__global__ void fuse_sselu_b(
    const scalar_t* __restrict__ grad_y,
    const scalar_t* __restrict__ x,
    scalar_t* grad_x, float alpha,
    float beta, float lambda, unsigned int numel
){
    const unsigned int stride = gridDim.x * blockDim.x;
    for (unsigned int i=blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += stride){
        scalar_t x_ = x[i];
        scalar_t dydx;
        if (x_ > 0){
            dydx = lambda;
        }
        else{
            // beta only appears inside the exponent of the forward pass, so
            // the derivative has no additive beta term (this matches the
            // expression used by fuse_sselu_b_affine).
            dydx = lambda * alpha * beta * exp(beta * x_);
        }
        grad_x[i] = grad_y[i] * dydx;
    }
}


// sSELU backward pass that also accumulates the gradient w.r.t. lambda.
//   grad_x[i]       = grad_y[i] * dy/dx
//   grad_lambda[0] += sum_i grad_y[i] * dy/dlambda
// with
//   dy/dx = lambda,                                dy/dlambda = x       (x > 0)
//   dy/dx = lambda * alpha * beta * exp(beta * x),
//   dy/dlambda = alpha * exp(beta * x) - alpha                          (otherwise)
// Each thread keeps a compensated (Kahan) partial sum over the elements it
// visits; the partials are combined per block with cub::BlockReduce and
// thread 0 of every block adds the block total to grad_lambda[0] with
// atomicAdd. The kernel only adds to grad_lambda[0]; it never resets it.
// `blockSize` is the thread count cub::BlockReduce is specialised for.
template <typename scalar_t, unsigned int blockSize>
__global__ void fuse_sselu_b_affine(
    const scalar_t* __restrict__ grad_y,
    const scalar_t* __restrict__ x,
    scalar_t* grad_x, scalar_t* grad_lambda,
    float alpha, float beta, float lambda, unsigned int numel
){
    using Reducer = cub::BlockReduce<scalar_t, blockSize>;
    __shared__ typename Reducer::TempStorage reduce_scratch;

    scalar_t partial = 0;   // running per-thread sum
    scalar_t comp = 0;      // Kahan compensation for lost low-order bits

    const unsigned int first = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int step = gridDim.x * blockDim.x;

    for (unsigned int idx = first; idx < numel; idx += step){
        const scalar_t in = x[idx];
        const scalar_t upstream = grad_y[idx];

        scalar_t slope;     // dy/dx
        scalar_t dlam;      // dy/dlambda
        if (in > 0){
            slope = lambda;
            dlam = in;
        }
        else{
            const scalar_t e = exp(beta * in);
            slope = lambda * alpha * beta * e;
            dlam = alpha * e - alpha;
        }

        // Kahan step: fold the previous rounding error into the new term.
        const scalar_t term = upstream * dlam - comp;
        const scalar_t updated = partial + term;
        comp = (updated - partial) - term;
        partial = updated;

        grad_x[idx] = upstream * slope;
    }

    __syncthreads();
    const scalar_t block_total = Reducer(reduce_scratch).Sum(partial);
    if (threadIdx.x == 0){
        atomicAdd(&grad_lambda[0], block_total);
    }
}


// Host entry point for the sSELU backward pass.
//
// Args:
//   grad_y:              CUDA tensor with the gradient of the loss w.r.t. the
//                        activation output (float or double).
//   x:                   CUDA tensor with the forward input; same dtype and
//                        element count as `grad_y`.
//   alpha, beta, lambda: activation coefficients forwarded to the kernel.
//   lambda_train:        when true, the gradient w.r.t. lambda is accumulated
//                        as well; otherwise it is returned as zero.
// Returns:
//   {grad_x, grad_lambda}: grad_x is a contiguous tensor with the shape and
//   dtype of `grad_y`; grad_lambda is a one-element float32 tensor.
//
// NOTE(review): kernels are launched on the default stream of the current
// device; confirm callers never run this under a non-default PyTorch stream
// or with tensors that live on a different device.
std::vector<torch::Tensor> fuse_sselu_b_cuda(
    torch::Tensor grad_y,
    torch::Tensor x,
    float alpha,
    float beta,
    float lambda,
    bool lambda_train
){
    TORCH_CHECK(grad_y.is_cuda() && x.is_cuda(),
                "fuse_sselu_b_cuda: grad_y and x must be CUDA tensors");
    // The kernels read x[i] for every i < grad_y.numel(); a shorter x would
    // be read out of bounds.
    TORCH_CHECK(grad_y.numel() == x.numel(),
                "fuse_sselu_b_cuda: grad_y and x must have the same number of elements");

    // The kernels index their buffers as flat arrays, so both inputs must be
    // densely packed in the same element order.
    const auto grad_y_c = grad_y.contiguous();
    const auto x_c = x.contiguous();
    auto grad_x = torch::empty_like(grad_y_c);

    // The kernels take a 32-bit element count; reject sizes that would be
    // silently truncated by the conversion.
    const int64_t total = grad_y_c.numel();
    const unsigned int numel = static_cast<unsigned int>(total);
    TORCH_CHECK(static_cast<int64_t>(numel) == total,
                "fuse_sselu_b_cuda: tensor has too many elements for 32-bit indexing");

    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(x.device());
    auto grad_lambda = torch::zeros({1}, options);

    // A launch with zero blocks is an invalid configuration; nothing to do.
    if (numel == 0){
        return {grad_x, grad_lambda};
    }

    // Computed in 64-bit so the rounding-up addition cannot wrap.
    const unsigned int blocks = static_cast<unsigned int>(BLOCKS(total, 1024));

    if (lambda_train){
        // The kernel accumulates with atomicAdd on scalar_t, so the
        // accumulator must share the input dtype; a float32 buffer cannot be
        // accessed as double. It is converted back to float32 afterwards.
        auto grad_lambda_acc = torch::zeros(
            {1}, options.dtype(grad_y_c.scalar_type()));

        // Blocks beyond what the data needs would only contribute zeros, so
        // cap the fixed grid of 6400 blocks by the blocks actually required.
        const unsigned int reduce_blocks = blocks < 6400u ? blocks : 6400u;

        AT_DISPATCH_FLOATING_TYPES(grad_y_c.scalar_type(), "fuse_sselu_b_affine", ([&]{
            // The template block size must equal the launched block size.
            fuse_sselu_b_affine<scalar_t, 1024><<<reduce_blocks, 1024>>>(
                grad_y_c.data_ptr<scalar_t>(), x_c.data_ptr<scalar_t>(),
                grad_x.data_ptr<scalar_t>(), grad_lambda_acc.data_ptr<scalar_t>(),
                alpha, beta, lambda, numel
            );
        }));
        grad_lambda = grad_lambda_acc.to(torch::kFloat32);
    }
    else{
        AT_DISPATCH_FLOATING_TYPES(grad_y_c.scalar_type(), "fuse_sselu_b", ([&]{
            fuse_sselu_b<scalar_t><<<blocks, 1024>>>(
                grad_y_c.data_ptr<scalar_t>(), x_c.data_ptr<scalar_t>(),
                grad_x.data_ptr<scalar_t>(), alpha, beta, lambda, numel
            );
        }));
    }

    // Kernel launches do not report configuration errors by themselves.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "fuse_sselu_b_cuda: kernel launch failed: ", cudaGetErrorString(err));

    return {grad_x, grad_lambda};
}

