#include <cassert>
#include <optional>

#include <pybind11/stl.h>
#include <torch/torch.h>

namespace py = pybind11;

/// Computes `linear(input, W, bias)`, where `W` is a copy of `weight` with the
/// values `dv` scatter-added at the flat (row-major) indices `di`.
///
/// @param input  Activations; cast to W's dtype before the matmul.
/// @param weight Base weight; never modified in place (a contiguous copy is made).
/// @param dv     Delta values; cast to W's dtype. Must be 1-D like `di`.
/// @param di     Flat indices into `weight`; cast to int64. Duplicate indices
///               accumulate (scatter_add_ semantics).
/// @param bias   Optional bias passed through to linear(); absent => no bias.
/// @return       The result of torch::nn::functional::linear.
torch::Tensor sparse_linear_forward(
    torch::Tensor input,
    torch::Tensor weight,
    torch::Tensor dv,
    torch::Tensor di,
    std::optional<torch::Tensor> bias
) {
    // clone() defaults to MemoryFormat::Preserve, so a non-contiguous weight
    // (e.g. a transposed view) would give a non-contiguous copy on which
    // reshape(-1) returns a *new* tensor, silently discarding the scatter.
    // Force a contiguous clone and use view(-1), which errors instead of copying.
    torch::Tensor W = weight.clone(at::MemoryFormat::Contiguous);
    W.view(-1).scatter_add_(0, di.to(torch::kInt64), dv.to(W.dtype()));
    return torch::nn::functional::linear(
        input.to(W.dtype()),
        W,
        bias ? *bias : torch::Tensor()
    );
}

/// Gradients of sparse_linear_forward with respect to (input, weight, dv, di, bias).
///
/// @param output_grad        Gradient of the forward output; all leading dims
///                           are flattened so the math reduces to 2-D matmuls.
/// @param input, weight, dv, di  The tensors that were passed to the forward.
/// @param input_needs_grad   Compute the input gradient.
/// @param weight_needs_grad  Compute the (dense) weight gradient.
/// @param dv_needs_grad      Compute the dv gradient: the dense weight gradient
///                           gathered at the flat indices di.
/// @param bias_needs_grad    Compute the bias gradient (only if bias is engaged).
/// @param bias               Only its presence is used.
/// @return {input_grad, weight_grad, dv_grad, <undefined for di>, bias_grad}.
///         Gradients that were not requested are undefined tensors. Gradients
///         are produced in output_grad's dtype.
torch::autograd::tensor_list sparse_linear_backward(
    torch::Tensor output_grad,
    torch::Tensor input,
    torch::Tensor weight,
    torch::Tensor dv,
    torch::Tensor di,
    bool input_needs_grad,
    bool weight_needs_grad,
    bool dv_needs_grad,
    bool bias_needs_grad,
    std::optional<torch::Tensor> bias = std::nullopt
) {
    torch::Tensor input_grad, weight_grad, dv_grad, bias_grad;
    torch::Tensor output_grad_2d = output_grad.reshape({-1, output_grad.size(-1)});

    if (input_needs_grad) {
        // Rebuild the effective dense weight exactly as the forward did. Use a
        // contiguous clone so view(-1) aliases W (a reshape could copy and drop
        // the scatter).
        torch::Tensor W = weight.clone(at::MemoryFormat::Contiguous);
        W.view(-1).scatter_add_(0, di.to(torch::kInt64), dv.to(W.dtype()));
        input_grad = output_grad_2d.mm(W.to(output_grad_2d.dtype())).view_as(input);
    }

    // The dense weight gradient is needed both for weight itself and as the
    // source that the dv gradient is gathered from.
    if (weight_needs_grad || dv_needs_grad) {
        torch::Tensor input_2d = input.reshape({-1, input.size(-1)});
        torch::Tensor dense_grad =
            output_grad_2d.t().mm(input_2d.to(output_grad_2d.dtype()));
        if (dv_needs_grad)
            dv_grad = dense_grad.view(-1).gather(0, di.to(torch::kInt64));
        if (weight_needs_grad)
            weight_grad = dense_grad;
    }

    if (bias && bias_needs_grad)
        bias_grad = output_grad_2d.sum(0);

    return {input_grad, weight_grad, dv_grad, torch::Tensor(), bias_grad};
}



/// Custom autograd Function for the sparse-delta linear layer. The forward
/// delegates to sparse_linear_forward and the backward to sparse_linear_backward.
class SparseLinearFunc : public torch::autograd::Function<SparseLinearFunc> {
  public:

    /// Saves the inputs and returns sparse_linear_forward(input, weight, dv, di, bias).
    ///
    /// Saved layout: {input, weight, dv, di}, plus bias as a 5th entry only when
    /// a *defined* bias was supplied. Backward uses that length to detect a bias.
    static torch::Tensor forward(
        torch::autograd::AutogradContext *ctx,
        torch::Tensor input,
        torch::Tensor weight,
        torch::Tensor dv,
        torch::Tensor di,
        std::optional<torch::Tensor> bias
    ) {
        // An engaged optional holding an undefined tensor behaves like "no bias"
        // in linear(). Normalise it so we do not save a 5th entry and later ask
        // needs_input_grad(4) about an input autograd may not be tracking.
        if (bias && !bias->defined())
            bias = std::nullopt;

        if (bias)
            ctx->save_for_backward({input, weight, dv, di, *bias});
        else
            ctx->save_for_backward({input, weight, dv, di});
        return sparse_linear_forward(input, weight, dv, di, bias);
    }

    /// Returns gradients in forward-argument order: (input, weight, dv, di, bias).
    static torch::autograd::tensor_list backward(
        torch::autograd::AutogradContext *ctx,
        torch::autograd::tensor_list grad_outputs
    ) {
        auto saved = ctx->get_saved_variables();
        const bool has_bias = saved.size() > 4;
        // Only query input index 4 when a bias was actually saved.
        const bool bias_needs_grad = has_bias && ctx->needs_input_grad(4);

        return sparse_linear_backward(
            grad_outputs[0],
            saved[0],  // input
            saved[1],  // weight
            saved[2],  // dv
            saved[3],  // di
            ctx->needs_input_grad(0),
            ctx->needs_input_grad(1),
            ctx->needs_input_grad(2),
            bias_needs_grad,
            has_bias ?
                std::optional<torch::Tensor>(saved[4]) :
                std::optional<torch::Tensor>(std::nullopt)
        );
    }

};

/// Runs the sparse-delta linear layer through SparseLinearFunc, so that its
/// custom backward (sparse_linear_backward) computes the gradients.
///
/// @param bias Optional bias; omitted / std::nullopt means no bias.
torch::Tensor apply_sparse_linear_func(
    torch::Tensor input,
    torch::Tensor weight,
    torch::Tensor dv,
    torch::Tensor di,
    std::optional<torch::Tensor> bias = std::nullopt
) {
    torch::Tensor out = SparseLinearFunc::apply(input, weight, dv, di, bias);
    return out;
}



// pybind11 cannot see C++ default arguments, so every binding declares its
// argument names explicitly and gives `bias` a Python-side default of None.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward",
        &sparse_linear_forward,
        "linear(input, weight + scatter_add(dv at flat indices di), bias), "
        "without the custom autograd Function.",
        py::arg("input"),
        py::arg("weight"),
        py::arg("dv"),
        py::arg("di"),
        py::arg_v("bias", std::nullopt)
    );
    m.def(
        "backward",
        &sparse_linear_backward,
        "Gradients (input, weight, dv, di, bias) of forward; "
        "gradients that are not requested are returned as None.",
        py::arg("output_grad"),
        py::arg("input"),
        py::arg("weight"),
        py::arg("dv"),
        py::arg("di"),
        py::arg("input_needs_grad"),
        py::arg("weight_needs_grad"),
        py::arg("dv_needs_grad"),
        py::arg("bias_needs_grad"),
        py::arg_v("bias", std::nullopt)
    );
    m.def(
        "apply",
        &apply_sparse_linear_func,
        "Sparse-delta linear layer with the custom autograd backward.",
        py::arg("input"),
        py::arg("weight"),
        py::arg("dv"),
        py::arg("di"),
        py::arg_v("bias", std::nullopt)
    );
}