/**
 * @file torch_cuda.cpp
 * @brief Soft Eisner CUDA PyTorch Bindings
 *
 * GPU-accelerated projective dependency parsing with full autograd support.
 */

#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <tuple>
#include <vector>

#include "kernels.cuh"

// =============================================================================
// Helper Macros
// =============================================================================

// Argument-validation helpers built on TORCH_CHECK.
// CHECK_INPUT is wrapped in do { } while (0) so it expands to exactly one
// statement and is safe as the body of an unbraced if/else; the argument is
// parenthesized so expression arguments bind correctly.
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)       \
    do {                     \
        CHECK_CUDA(x);       \
        CHECK_CONTIGUOUS(x); \
    } while (0)

// =============================================================================
// Autograd Function
// =============================================================================

/**
 * Autograd Function wrapping the d2p::eisner CUDA kernels.
 *
 * forward(arc_scores, temperature, lengths) -> {partition, marginals}
 *   arc_scores  : contiguous float32 CUDA tensor of shape [B, n, n].
 *   temperature : single-element tensor; read on the host as a float, so it
 *                 may live on the CPU or on a CUDA device.
 *   lengths     : optional tensor with B elements, converted to contiguous
 *                 int32 on arc_scores' device (presumably per-item sentence
 *                 lengths -- TODO confirm against kernels.cuh).
 *   partition   : [B], written by d2p::eisner::forward.
 *   marginals   : [B, n, n], written by d2p::eisner::backward.
 *
 * backward returns gradients for (arc_scores, temperature, lengths):
 *   arc_scores  : grad_partition * marginals (marginals are recomputed) plus
 *                 the d2p::eisner::hvp product with grad_marginals.
 *   temperature : sum(grad_partition * grad_T), where grad_T is the per-item
 *                 buffer filled by d2p::eisner::backward during forward. The
 *                 result is reshaped/moved to match the temperature input.
 *                 NOTE(review): the dependence of `marginals` on temperature
 *                 is not propagated; only the partition path contributes.
 *   lengths     : undefined (not differentiable).
 */
class SoftEisnerCUDAFunction : public torch::autograd::Function<SoftEisnerCUDAFunction> {
public:
    static torch::autograd::tensor_list forward(
        torch::autograd::AutogradContext* ctx,
        torch::Tensor arc_scores,
        torch::Tensor temperature,
        c10::optional<torch::Tensor> lengths_opt
    ) {
        CHECK_INPUT(arc_scores);
        TORCH_CHECK(arc_scores.dim() == 3, "arc_scores must be 3D [B, n, n]");
        // Only n = size(1) is handed to the kernels, so the last two dims must
        // agree; otherwise each slice is misread (or read past its end).
        TORCH_CHECK(arc_scores.size(1) == arc_scores.size(2),
                    "arc_scores must be square in its last two dims [B, n, n]");
        TORCH_CHECK(arc_scores.dtype() == torch::kFloat32, "arc_scores must be float32");
        TORCH_CHECK(temperature.numel() == 1, "temperature must be a scalar tensor");

        // Launch on arc_scores' device, not on whichever device happens to be current.
        const c10::cuda::CUDAGuard device_guard(arc_scores.device());

        const int B = static_cast<int>(arc_scores.size(0));
        const int n = static_cast<int>(arc_scores.size(1));
        const float temp_val = temperature.cpu().item<float>();
        const auto options = arc_scores.options();

        // Lengths must sit on the same device as arc_scores; a bare .cuda()
        // would move them to the *current* device instead.
        const bool has_lengths = lengths_opt.has_value() && lengths_opt->defined();
        torch::Tensor lengths;
        const int* lengths_ptr = nullptr;
        if (has_lengths) {
            lengths = lengths_opt->to(arc_scores.device(), torch::kInt32).contiguous();
            TORCH_CHECK(lengths.numel() == B, "lengths must have exactly B elements");
            lengths_ptr = lengths.data_ptr<int>();
        }

        // Tables filled by the forward kernel (also needed by backward).
        torch::Tensor C_R = torch::zeros({B, n, n}, options);
        torch::Tensor C_L = torch::zeros({B, n, n}, options);
        torch::Tensor I_R = torch::zeros({B, n, n}, options);
        torch::Tensor I_L = torch::zeros({B, n, n}, options);
        torch::Tensor partition = torch::zeros({B}, options);

        d2p::eisner::forward(
            arc_scores.data_ptr<float>(),
            C_R.data_ptr<float>(),
            C_L.data_ptr<float>(),
            I_R.data_ptr<float>(),
            I_L.data_ptr<float>(),
            partition.data_ptr<float>(),
            lengths_ptr,
            B, n, temp_val
        );
        C10_CUDA_KERNEL_LAUNCH_CHECK();

        // Backward kernel: produces the marginals output and per-item grad_T.
        torch::Tensor beta_C_R = torch::zeros({B, n, n}, options);
        torch::Tensor beta_C_L = torch::zeros({B, n, n}, options);
        torch::Tensor beta_I_R = torch::zeros({B, n, n}, options);
        torch::Tensor beta_I_L = torch::zeros({B, n, n}, options);
        torch::Tensor marginals = torch::zeros({B, n, n}, options);
        torch::Tensor grad_T = torch::zeros({B}, options);

        d2p::eisner::backward(
            arc_scores.data_ptr<float>(),
            C_R.data_ptr<float>(),
            C_L.data_ptr<float>(),
            I_R.data_ptr<float>(),
            I_L.data_ptr<float>(),
            beta_C_R.data_ptr<float>(),
            beta_C_L.data_ptr<float>(),
            beta_I_R.data_ptr<float>(),
            beta_I_L.data_ptr<float>(),
            marginals.data_ptr<float>(),
            grad_T.data_ptr<float>(),
            lengths_ptr,
            B, n, temp_val
        );
        C10_CUDA_KERNEL_LAUNCH_CHECK();

        ctx->saved_data["temperature"] = static_cast<double>(temp_val);
        // Remember where/how the temperature input lives so backward can hand
        // back a gradient autograd accepts.
        ctx->saved_data["temperature_sizes"] = temperature.sizes().vec();
        ctx->saved_data["temperature_device"] = temperature.device();
        ctx->saved_data["has_lengths"] = has_lengths;

        // arc_scores is cloned so backward works from the forward-time values.
        torch::autograd::variable_list to_save = {
            arc_scores.clone(), C_R, C_L, I_R, I_L, grad_T};
        if (has_lengths) {
            to_save.push_back(lengths);
        }
        ctx->save_for_backward(to_save);

        return {partition, marginals};
    }

    static torch::autograd::tensor_list backward(
        torch::autograd::AutogradContext* ctx,
        torch::autograd::tensor_list grad_outputs
    ) {
        auto saved = ctx->get_saved_variables();
        torch::Tensor arc_scores = saved[0];
        torch::Tensor C_R = saved[1];
        torch::Tensor C_L = saved[2];
        torch::Tensor I_R = saved[3];
        torch::Tensor I_L = saved[4];
        torch::Tensor grad_T_fwd = saved[5];

        const float temp_val = static_cast<float>(ctx->saved_data["temperature"].toDouble());
        const bool has_lengths = ctx->saved_data["has_lengths"].toBool();

        const c10::cuda::CUDAGuard device_guard(arc_scores.device());

        // `saved` owns the lengths tensor, so the raw pointer stays valid below.
        const int* lengths_ptr = nullptr;
        if (has_lengths && saved.size() > 6) {
            lengths_ptr = saved[6].data_ptr<int>();
        }

        const int B = static_cast<int>(arc_scores.size(0));
        const int n = static_cast<int>(arc_scores.size(1));
        const auto options = arc_scores.options();

        torch::Tensor grad_partition = grad_outputs[0];
        torch::Tensor grad_marginals = grad_outputs[1];

        torch::Tensor grad_arc = torch::zeros({B, n, n}, options);
        torch::Tensor total_grad_T = torch::zeros({1}, options);

        // Contribution of the partition output.
        if (grad_partition.defined() && grad_partition.numel() > 0) {
            // marginals are not saved from forward, so rebuild them here. The
            // kernel also requires a grad_T buffer; its contents are unused
            // because grad_T_fwd already holds the same quantity.
            torch::Tensor beta_C_R = torch::zeros({B, n, n}, options);
            torch::Tensor beta_C_L = torch::zeros({B, n, n}, options);
            torch::Tensor beta_I_R = torch::zeros({B, n, n}, options);
            torch::Tensor beta_I_L = torch::zeros({B, n, n}, options);
            torch::Tensor marginals = torch::zeros({B, n, n}, options);
            torch::Tensor grad_T = torch::zeros({B}, options);

            d2p::eisner::backward(
                arc_scores.data_ptr<float>(),
                C_R.data_ptr<float>(),
                C_L.data_ptr<float>(),
                I_R.data_ptr<float>(),
                I_L.data_ptr<float>(),
                beta_C_R.data_ptr<float>(),
                beta_C_L.data_ptr<float>(),
                beta_I_R.data_ptr<float>(),
                beta_I_L.data_ptr<float>(),
                marginals.data_ptr<float>(),
                grad_T.data_ptr<float>(),
                lengths_ptr,
                B, n, temp_val
            );
            C10_CUDA_KERNEL_LAUNCH_CHECK();

            // reshape (not view): incoming grads may be expanded/strided.
            grad_arc += grad_partition.reshape({B, 1, 1}) * marginals;
            total_grad_T += (grad_partition * grad_T_fwd).sum().reshape({1});
        }

        // Contribution of the marginals output via the HVP kernel.
        if (grad_marginals.defined() && grad_marginals.numel() > 0) {
            // The kernel reads a dense float buffer.
            grad_marginals = grad_marginals.contiguous().to(torch::kFloat32);

            torch::Tensor d_C_R = torch::zeros({B, n, n}, options);
            torch::Tensor d_C_L = torch::zeros({B, n, n}, options);
            torch::Tensor d_I_R = torch::zeros({B, n, n}, options);
            torch::Tensor d_I_L = torch::zeros({B, n, n}, options);
            torch::Tensor beta_C_R = torch::zeros({B, n, n}, options);
            torch::Tensor beta_C_L = torch::zeros({B, n, n}, options);
            torch::Tensor beta_I_R = torch::zeros({B, n, n}, options);
            torch::Tensor beta_I_L = torch::zeros({B, n, n}, options);
            torch::Tensor d_beta_C_R = torch::zeros({B, n, n}, options);
            torch::Tensor d_beta_C_L = torch::zeros({B, n, n}, options);
            torch::Tensor d_beta_I_R = torch::zeros({B, n, n}, options);
            torch::Tensor d_beta_I_L = torch::zeros({B, n, n}, options);
            torch::Tensor HVP = torch::zeros({B, n, n}, options);

            d2p::eisner::hvp(
                arc_scores.data_ptr<float>(),
                grad_marginals.data_ptr<float>(),
                C_R.data_ptr<float>(),
                C_L.data_ptr<float>(),
                I_R.data_ptr<float>(),
                I_L.data_ptr<float>(),
                d_C_R.data_ptr<float>(),
                d_C_L.data_ptr<float>(),
                d_I_R.data_ptr<float>(),
                d_I_L.data_ptr<float>(),
                beta_C_R.data_ptr<float>(),
                beta_C_L.data_ptr<float>(),
                beta_I_R.data_ptr<float>(),
                beta_I_L.data_ptr<float>(),
                d_beta_C_R.data_ptr<float>(),
                d_beta_C_L.data_ptr<float>(),
                d_beta_I_R.data_ptr<float>(),
                d_beta_I_L.data_ptr<float>(),
                HVP.data_ptr<float>(),
                lengths_ptr,
                B, n, temp_val
            );
            C10_CUDA_KERNEL_LAUNCH_CHECK();

            grad_arc += HVP;
        }

        // Give the temperature gradient the input's shape and device; e.g. a
        // CPU temperature of shape [1] would otherwise be rejected by autograd
        // with a device mismatch.
        torch::Tensor grad_temperature = total_grad_T
            .reshape(ctx->saved_data["temperature_sizes"].toIntVector())
            .to(ctx->saved_data["temperature_device"].toDevice());

        return {grad_arc, grad_temperature, torch::Tensor()};
    }
};

// =============================================================================
// Wrapper Functions
// =============================================================================

/**
 * Differentiable Soft Eisner entry point.
 *
 * Delegates to SoftEisnerCUDAFunction so gradients can flow back to
 * arc_scores and temperature.
 *
 * @return {partition, marginals} exactly as produced by the autograd Function.
 */
std::vector<torch::Tensor> soft_eisner_cuda(
    torch::Tensor arc_scores,
    torch::Tensor temperature,
    c10::optional<torch::Tensor> lengths
) {
    torch::autograd::tensor_list outputs =
        SoftEisnerCUDAFunction::apply(arc_scores, temperature, lengths);
    return outputs;
}

/**
 * Differentiable Soft Eisner with a plain floating-point temperature.
 *
 * @param arc_scores  forwarded to SoftEisnerCUDAFunction
 * @param temperature cast to float and wrapped in a one-element tensor
 * @param lengths     forwarded unchanged
 * @return the partition output (first output of the autograd Function)
 */
torch::Tensor soft_eisner_float_cuda(
    torch::Tensor arc_scores,
    double temperature,
    c10::optional<torch::Tensor> lengths
) {
    // The forward reads the temperature on the host (.cpu().item()), so build
    // it as a CPU tensor: a device tensor would cost a host-to-device copy here
    // plus a blocking device-to-host copy inside forward. It never requires
    // grad, so autograd produces no gradient for it.
    auto temp_t = torch::tensor({static_cast<float>(temperature)},
                                torch::dtype(torch::kFloat32));
    auto results = SoftEisnerCUDAFunction::apply(arc_scores, temp_t, lengths);
    return results[0];
}

/**
 * Runs the Soft Eisner forward and backward kernels directly (no autograd).
 *
 * @param arc_scores  contiguous float32 CUDA tensor of shape [B, n, n]
 * @param temperature forwarded to the kernels (cast to float)
 * @param lengths_opt optional tensor with B elements; converted to contiguous
 *                    int32 on arc_scores' device
 * @return (partition [B], marginals [B, n, n])
 */
std::tuple<torch::Tensor, torch::Tensor> soft_eisner_with_grads_cuda(
    torch::Tensor arc_scores,
    double temperature,
    c10::optional<torch::Tensor> lengths_opt
) {
    CHECK_INPUT(arc_scores);
    TORCH_CHECK(arc_scores.dim() == 3, "arc_scores must be 3D [B, n, n]");
    // Only n = size(1) is handed to the kernels, so the last two dims must agree.
    TORCH_CHECK(arc_scores.size(1) == arc_scores.size(2),
                "arc_scores must be square in its last two dims [B, n, n]");
    TORCH_CHECK(arc_scores.dtype() == torch::kFloat32, "arc_scores must be float32");

    // Launch on arc_scores' device, not on whichever device happens to be current.
    const c10::cuda::CUDAGuard device_guard(arc_scores.device());

    const int B = static_cast<int>(arc_scores.size(0));
    const int n = static_cast<int>(arc_scores.size(1));
    const float T = static_cast<float>(temperature);
    const auto options = arc_scores.options();

    // Lengths must sit on arc_scores' device; a bare .cuda() would use the
    // current device instead.
    torch::Tensor lengths;
    const int* lengths_ptr = nullptr;
    if (lengths_opt.has_value() && lengths_opt->defined()) {
        lengths = lengths_opt->to(arc_scores.device(), torch::kInt32).contiguous();
        TORCH_CHECK(lengths.numel() == B, "lengths must have exactly B elements");
        lengths_ptr = lengths.data_ptr<int>();
    }

    torch::Tensor C_R = torch::zeros({B, n, n}, options);
    torch::Tensor C_L = torch::zeros({B, n, n}, options);
    torch::Tensor I_R = torch::zeros({B, n, n}, options);
    torch::Tensor I_L = torch::zeros({B, n, n}, options);
    torch::Tensor partition = torch::zeros({B}, options);

    d2p::eisner::forward(
        arc_scores.data_ptr<float>(),
        C_R.data_ptr<float>(),
        C_L.data_ptr<float>(),
        I_R.data_ptr<float>(),
        I_L.data_ptr<float>(),
        partition.data_ptr<float>(),
        lengths_ptr,
        B, n, T
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // grad_T is a required output buffer of the backward kernel; it is not
    // returned here (soft_eisner_backward_full_cuda exposes it).
    torch::Tensor beta_C_R = torch::zeros({B, n, n}, options);
    torch::Tensor beta_C_L = torch::zeros({B, n, n}, options);
    torch::Tensor beta_I_R = torch::zeros({B, n, n}, options);
    torch::Tensor beta_I_L = torch::zeros({B, n, n}, options);
    torch::Tensor marginals = torch::zeros({B, n, n}, options);
    torch::Tensor grad_T = torch::zeros({B}, options);

    d2p::eisner::backward(
        arc_scores.data_ptr<float>(),
        C_R.data_ptr<float>(),
        C_L.data_ptr<float>(),
        I_R.data_ptr<float>(),
        I_L.data_ptr<float>(),
        beta_C_R.data_ptr<float>(),
        beta_C_L.data_ptr<float>(),
        beta_I_R.data_ptr<float>(),
        beta_I_L.data_ptr<float>(),
        marginals.data_ptr<float>(),
        grad_T.data_ptr<float>(),
        lengths_ptr,
        B, n, T
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return std::make_tuple(partition, marginals);
}

/**
 * Runs the Soft Eisner forward kernel, then d2p::eisner::hvp with vector V.
 *
 * @param arc_scores  contiguous float32 CUDA tensor of shape [B, n, n]
 * @param V           contiguous float32 CUDA tensor with arc_scores' shape and
 *                    device (presumably the vector of a Hessian-vector
 *                    product -- TODO confirm against kernels.cuh)
 * @param temperature forwarded to the kernels (cast to float)
 * @param lengths_opt optional tensor with B elements; converted to contiguous
 *                    int32 on arc_scores' device
 * @return the [B, n, n] HVP buffer written by d2p::eisner::hvp
 */
torch::Tensor soft_eisner_hvp_cuda(
    torch::Tensor arc_scores,
    torch::Tensor V,
    double temperature,
    c10::optional<torch::Tensor> lengths_opt
) {
    CHECK_INPUT(arc_scores);
    CHECK_INPUT(V);
    TORCH_CHECK(arc_scores.dim() == 3, "arc_scores must be 3D [B, n, n]");
    // Only n = size(1) is handed to the kernels, so the last two dims must agree.
    TORCH_CHECK(arc_scores.size(1) == arc_scores.size(2),
                "arc_scores must be square in its last two dims [B, n, n]");
    TORCH_CHECK(arc_scores.dtype() == torch::kFloat32, "arc_scores must be float32");
    // The kernel reads V with arc_scores' layout; a smaller V would be read
    // past its end.
    TORCH_CHECK(V.sizes() == arc_scores.sizes(), "V must have the same shape as arc_scores");
    TORCH_CHECK(V.device() == arc_scores.device(), "V must be on the same device as arc_scores");
    TORCH_CHECK(V.dtype() == torch::kFloat32, "V must be float32");

    // Launch on arc_scores' device, not on whichever device happens to be current.
    const c10::cuda::CUDAGuard device_guard(arc_scores.device());

    const int B = static_cast<int>(arc_scores.size(0));
    const int n = static_cast<int>(arc_scores.size(1));
    const float T = static_cast<float>(temperature);
    const auto options = arc_scores.options();

    // Lengths must sit on arc_scores' device; a bare .cuda() would use the
    // current device instead.
    torch::Tensor lengths;
    const int* lengths_ptr = nullptr;
    if (lengths_opt.has_value() && lengths_opt->defined()) {
        lengths = lengths_opt->to(arc_scores.device(), torch::kInt32).contiguous();
        TORCH_CHECK(lengths.numel() == B, "lengths must have exactly B elements");
        lengths_ptr = lengths.data_ptr<int>();
    }

    torch::Tensor C_R = torch::zeros({B, n, n}, options);
    torch::Tensor C_L = torch::zeros({B, n, n}, options);
    torch::Tensor I_R = torch::zeros({B, n, n}, options);
    torch::Tensor I_L = torch::zeros({B, n, n}, options);
    torch::Tensor partition = torch::zeros({B}, options);

    d2p::eisner::forward(
        arc_scores.data_ptr<float>(),
        C_R.data_ptr<float>(),
        C_L.data_ptr<float>(),
        I_R.data_ptr<float>(),
        I_L.data_ptr<float>(),
        partition.data_ptr<float>(),
        lengths_ptr,
        B, n, T
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    torch::Tensor d_C_R = torch::zeros({B, n, n}, options);
    torch::Tensor d_C_L = torch::zeros({B, n, n}, options);
    torch::Tensor d_I_R = torch::zeros({B, n, n}, options);
    torch::Tensor d_I_L = torch::zeros({B, n, n}, options);
    torch::Tensor beta_C_R = torch::zeros({B, n, n}, options);
    torch::Tensor beta_C_L = torch::zeros({B, n, n}, options);
    torch::Tensor beta_I_R = torch::zeros({B, n, n}, options);
    torch::Tensor beta_I_L = torch::zeros({B, n, n}, options);
    torch::Tensor d_beta_C_R = torch::zeros({B, n, n}, options);
    torch::Tensor d_beta_C_L = torch::zeros({B, n, n}, options);
    torch::Tensor d_beta_I_R = torch::zeros({B, n, n}, options);
    torch::Tensor d_beta_I_L = torch::zeros({B, n, n}, options);
    torch::Tensor HVP = torch::zeros({B, n, n}, options);

    d2p::eisner::hvp(
        arc_scores.data_ptr<float>(),
        V.data_ptr<float>(),
        C_R.data_ptr<float>(),
        C_L.data_ptr<float>(),
        I_R.data_ptr<float>(),
        I_L.data_ptr<float>(),
        d_C_R.data_ptr<float>(),
        d_C_L.data_ptr<float>(),
        d_I_R.data_ptr<float>(),
        d_I_L.data_ptr<float>(),
        beta_C_R.data_ptr<float>(),
        beta_C_L.data_ptr<float>(),
        beta_I_R.data_ptr<float>(),
        beta_I_L.data_ptr<float>(),
        d_beta_C_R.data_ptr<float>(),
        d_beta_C_L.data_ptr<float>(),
        d_beta_I_R.data_ptr<float>(),
        d_beta_I_L.data_ptr<float>(),
        HVP.data_ptr<float>(),
        lengths_ptr,
        B, n, T
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return HVP;
}

/**
 * Runs the Soft Eisner forward and backward kernels and returns every output
 * the host wrapper exposes (no autograd).
 *
 * @param arc_scores  contiguous float32 CUDA tensor of shape [B, n, n]
 * @param temperature forwarded to the kernels (cast to float)
 * @param lengths_opt optional tensor with B elements; converted to contiguous
 *                    int32 on arc_scores' device
 * @return (partition [B], marginals [B, n, n], grad_T [B]) as written by
 *         d2p::eisner::forward / d2p::eisner::backward
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> soft_eisner_backward_full_cuda(
    torch::Tensor arc_scores,
    double temperature,
    c10::optional<torch::Tensor> lengths_opt
) {
    CHECK_INPUT(arc_scores);
    TORCH_CHECK(arc_scores.dim() == 3, "arc_scores must be 3D [B, n, n]");
    // Only n = size(1) is handed to the kernels, so the last two dims must agree.
    TORCH_CHECK(arc_scores.size(1) == arc_scores.size(2),
                "arc_scores must be square in its last two dims [B, n, n]");
    TORCH_CHECK(arc_scores.dtype() == torch::kFloat32, "arc_scores must be float32");

    // Launch on arc_scores' device, not on whichever device happens to be current.
    const c10::cuda::CUDAGuard device_guard(arc_scores.device());

    const int B = static_cast<int>(arc_scores.size(0));
    const int n = static_cast<int>(arc_scores.size(1));
    const float T = static_cast<float>(temperature);
    const auto options = arc_scores.options();

    // Lengths must sit on arc_scores' device; a bare .cuda() would use the
    // current device instead.
    torch::Tensor lengths;
    const int* lengths_ptr = nullptr;
    if (lengths_opt.has_value() && lengths_opt->defined()) {
        lengths = lengths_opt->to(arc_scores.device(), torch::kInt32).contiguous();
        TORCH_CHECK(lengths.numel() == B, "lengths must have exactly B elements");
        lengths_ptr = lengths.data_ptr<int>();
    }

    torch::Tensor C_R = torch::zeros({B, n, n}, options);
    torch::Tensor C_L = torch::zeros({B, n, n}, options);
    torch::Tensor I_R = torch::zeros({B, n, n}, options);
    torch::Tensor I_L = torch::zeros({B, n, n}, options);
    torch::Tensor partition = torch::zeros({B}, options);

    d2p::eisner::forward(
        arc_scores.data_ptr<float>(),
        C_R.data_ptr<float>(),
        C_L.data_ptr<float>(),
        I_R.data_ptr<float>(),
        I_L.data_ptr<float>(),
        partition.data_ptr<float>(),
        lengths_ptr,
        B, n, T
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    torch::Tensor beta_C_R = torch::zeros({B, n, n}, options);
    torch::Tensor beta_C_L = torch::zeros({B, n, n}, options);
    torch::Tensor beta_I_R = torch::zeros({B, n, n}, options);
    torch::Tensor beta_I_L = torch::zeros({B, n, n}, options);
    torch::Tensor marginals = torch::zeros({B, n, n}, options);
    torch::Tensor grad_T = torch::zeros({B}, options);

    d2p::eisner::backward(
        arc_scores.data_ptr<float>(),
        C_R.data_ptr<float>(),
        C_L.data_ptr<float>(),
        I_R.data_ptr<float>(),
        I_L.data_ptr<float>(),
        beta_C_R.data_ptr<float>(),
        beta_C_L.data_ptr<float>(),
        beta_I_R.data_ptr<float>(),
        beta_I_L.data_ptr<float>(),
        marginals.data_ptr<float>(),
        grad_T.data_ptr<float>(),
        lengths_ptr,
        B, n, T
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return std::make_tuple(partition, marginals, grad_T);
}

// =============================================================================
// Library Registration
// =============================================================================

#ifdef USE_TORCH_LIBRARY

// CUDA backend kernels for the d2p Soft Eisner operators.
// Registration order inside a TORCH_LIBRARY_IMPL block does not matter.
TORCH_LIBRARY_IMPL(d2p, CUDA, m) {
    // Wrappers that route through SoftEisnerCUDAFunction.
    m.impl("soft_eisner", soft_eisner_cuda);
    m.impl("soft_eisner_float", soft_eisner_float_cuda);

    // Direct kernel drivers that build no autograd graph.
    m.impl("soft_eisner_backward_full", soft_eisner_backward_full_cuda);
    m.impl("soft_eisner_hvp", soft_eisner_hvp_cuda);
    m.impl("soft_eisner_with_grads", soft_eisner_with_grads_cuda);
}

// The two SoftEisnerCUDAFunction-based wrappers are also bound to the
// AutogradCUDA key, so calls dispatched there run the custom autograd Function.
TORCH_LIBRARY_IMPL(d2p, AutogradCUDA, m) {
    m.impl("soft_eisner_float", soft_eisner_float_cuda);
    m.impl("soft_eisner", soft_eisner_cuda);
}

#endif  // USE_TORCH_LIBRARY
