#include <torch/extension.h>
#include <vector>


// Forward pass of the fused "lselu" activation. Defined in a separate
// (presumably CUDA) translation unit; only the declaration is visible here.
torch::Tensor fuse_lselu_f_cuda(
    torch::Tensor x,
    float alpha,
    float beta,
    float lambda
);

/// @brief Python-facing entry point for the fused lselu forward pass.
///
/// Validates that the input lives on a CUDA device, then forwards every
/// argument unchanged to fuse_lselu_f_cuda.
///
/// @param x      Input tensor; must be a CUDA tensor.
/// @param alpha  Activation parameter, passed through unchanged.
/// @param beta   Activation parameter, passed through unchanged.
/// @param lambda Activation parameter, passed through unchanged.
/// @return       The tensor produced by fuse_lselu_f_cuda
///               (presumably the activated output — confirm in the CUDA source).
/// @throws c10::Error (raised in Python as RuntimeError) if x is not on CUDA.
torch::Tensor fuse_lselu_f(
    torch::Tensor x,
    float alpha,
    float beta,
    float lambda
){
    // The callee is a *_cuda implementation that presumably launches a kernel
    // on x's data pointer; reject host tensors here with a readable error.
    TORCH_CHECK(x.is_cuda(), "fuse_lselu_f: x must be a CUDA tensor");
    return fuse_lselu_f_cuda(x, alpha, beta, lambda);
}

// Backward pass of the fused "lselu" activation. Defined in a separate
// (presumably CUDA) translation unit; only the declaration is visible here.
std::vector<torch::Tensor> fuse_lselu_b_cuda(
    torch::Tensor grad_y,
    torch::Tensor x,
    float alpha,
    float beta,
    float lambda,
    bool lambda_train
);

/// @brief Python-facing entry point for the fused lselu backward pass.
///
/// Validates that both tensors live on the same CUDA device, then forwards
/// every argument unchanged to fuse_lselu_b_cuda.
///
/// @param grad_y       Upstream gradient; must be a CUDA tensor on x's device.
/// @param x            Forward-pass input; must be a CUDA tensor.
/// @param alpha        Activation parameter, passed through unchanged.
/// @param beta         Activation parameter, passed through unchanged.
/// @param lambda       Activation parameter, passed through unchanged.
/// @param lambda_train Flag passed through unchanged (presumably selects whether
///                     a gradient for lambda is produced — confirm in the CUDA source).
/// @return             The tensors produced by fuse_lselu_b_cuda
///                     (presumably the input/parameter gradients — TODO confirm).
/// @throws c10::Error (raised in Python as RuntimeError) on a device mismatch
///         or when a tensor is not on CUDA.
std::vector<torch::Tensor> fuse_lselu_b(
    torch::Tensor grad_y,
    torch::Tensor x,
    float alpha,
    float beta,
    float lambda,
    bool lambda_train
){
    // The *_cuda callee presumably reads both tensors from device code, so both
    // must be CUDA tensors on the same device.
    TORCH_CHECK(grad_y.is_cuda(), "fuse_lselu_b: grad_y must be a CUDA tensor");
    TORCH_CHECK(x.is_cuda(), "fuse_lselu_b: x must be a CUDA tensor");
    TORCH_CHECK(grad_y.device() == x.device(),
                "fuse_lselu_b: grad_y and x must be on the same device");
    return fuse_lselu_b_cuda(grad_y, x, alpha, beta, lambda, lambda_train);
}


// Forward pass of the fused "sselu" activation. Defined in a separate
// (presumably CUDA) translation unit; only the declaration is visible here.
torch::Tensor fuse_sselu_f_cuda(
    torch::Tensor x,
    float alpha,
    float beta,
    float lambda
);

/// @brief Python-facing entry point for the fused sselu forward pass.
///
/// Validates that the input lives on a CUDA device, then forwards every
/// argument unchanged to fuse_sselu_f_cuda.
///
/// @param x      Input tensor; must be a CUDA tensor.
/// @param alpha  Activation parameter, passed through unchanged.
/// @param beta   Activation parameter, passed through unchanged.
/// @param lambda Activation parameter, passed through unchanged.
/// @return       The tensor produced by fuse_sselu_f_cuda
///               (presumably the activated output — confirm in the CUDA source).
/// @throws c10::Error (raised in Python as RuntimeError) if x is not on CUDA.
torch::Tensor fuse_sselu_f(
    torch::Tensor x,
    float alpha,
    float beta,
    float lambda
){
    // The callee is a *_cuda implementation that presumably launches a kernel
    // on x's data pointer; reject host tensors here with a readable error.
    TORCH_CHECK(x.is_cuda(), "fuse_sselu_f: x must be a CUDA tensor");
    return fuse_sselu_f_cuda(x, alpha, beta, lambda);
}

// Backward pass of the fused "sselu" activation. Defined in a separate
// (presumably CUDA) translation unit; only the declaration is visible here.
std::vector<torch::Tensor> fuse_sselu_b_cuda(
    torch::Tensor grad_y,
    torch::Tensor x,
    float alpha,
    float beta,
    float lambda,
    bool lambda_train
);

/// @brief Python-facing entry point for the fused sselu backward pass.
///
/// Validates that both tensors live on the same CUDA device, then forwards
/// every argument unchanged to fuse_sselu_b_cuda.
///
/// @param grad_y       Upstream gradient; must be a CUDA tensor on x's device.
/// @param x            Forward-pass input; must be a CUDA tensor.
/// @param alpha        Activation parameter, passed through unchanged.
/// @param beta         Activation parameter, passed through unchanged.
/// @param lambda       Activation parameter, passed through unchanged.
/// @param lambda_train Flag passed through unchanged (presumably selects whether
///                     a gradient for lambda is produced — confirm in the CUDA source).
/// @return             The tensors produced by fuse_sselu_b_cuda
///                     (presumably the input/parameter gradients — TODO confirm).
/// @throws c10::Error (raised in Python as RuntimeError) on a device mismatch
///         or when a tensor is not on CUDA.
std::vector<torch::Tensor> fuse_sselu_b(
    torch::Tensor grad_y,
    torch::Tensor x,
    float alpha,
    float beta,
    float lambda,
    bool lambda_train
){
    // The *_cuda callee presumably reads both tensors from device code, so both
    // must be CUDA tensors on the same device.
    TORCH_CHECK(grad_y.is_cuda(), "fuse_sselu_b: grad_y must be a CUDA tensor");
    TORCH_CHECK(x.is_cuda(), "fuse_sselu_b: x must be a CUDA tensor");
    TORCH_CHECK(grad_y.device() == x.device(),
                "fuse_sselu_b: grad_y and x must be on the same device");
    return fuse_sselu_b_cuda(grad_y, x, alpha, beta, lambda, lambda_train);
}


// Python module registration. TORCH_EXTENSION_NAME is supplied by the PyTorch
// extension build; each entry maps a Python-visible name to a C++ wrapper.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod){
    // lselu variant: forward and backward.
    mod.def("fuse_lselu_f", &fuse_lselu_f, "fused lselu forward");
    mod.def("fuse_lselu_b", &fuse_lselu_b, "fused lselu backward");

    // sselu variant: forward and backward.
    mod.def("fuse_sselu_f", &fuse_sselu_f, "fused sselu forward");
    mod.def("fuse_sselu_b", &fuse_sselu_b, "fused sselu backward");
}