/**
 * @file torch_cuda.cpp
 * @brief Soft CKY CUDA Extension with PyTorch Autograd
 *
 * GPU-accelerated soft CKY parsing with:
 *   - Span-based DP (constituency parsing)
 *   - Temperature-scaled logsumexp for differentiability
 *   - Full gradient support through PyTorch autograd
 *
 * Operations:
 *   - soft_cky: Main autograd function (tensor params) -> (logZ, posteriors)
 *   - soft_cky_float: Convenience function (float params)
 *   - soft_cky_with_grads: Explicit gradients for debugging
 *   - soft_cky_hvp: Hessian-vector product
 *   - soft_cky_param_jacobian: dP/dT (posteriors derivative w.r.t. temperature)
 *   - soft_cky_backward_full: Complete backward given grad_posteriors
 */

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>

#include <tuple>
#include <vector>

// Shared utilities
#include "common/torch_utils.h"
#include "common/cuda_utils.h"

// CUDA kernel declarations
#include "cky/kernels.cuh"

using namespace d2p::common;

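// =============================================================================
// Soft CKY CUDA Autograd Function
//
// Forward: (merge_scores, leaf_scores, temperature) -> (logZ, posteriors)
// Backward:
//   - logZ path: merge scores receive the posteriors (recomputed with
//     cky_backward); leaf scores and temperature use the per-batch gradients
//     saved from the forward pass.
//   - posterior path: scores receive a Hessian-vector product (cky_hvp);
//     temperature receives dP/dT (cky_param_grad) contracted with the
//     upstream gradient.
//
// CKY uses a temperature-scaled logsumexp (a smooth relaxation of max),
// similar to SW.
// =============================================================================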

class SoftCKYCUDAFunction : public torch::autograd::Function<SoftCKYCUDAFunction> {
public:
    /**
     * Inside (cky_forward) + outside (cky_backward) CKY passes.
     *
     * @param ctx           autograd context; receives the tensors backward() needs.
     * @param merge_scores  float32 CUDA tensor of shape [B, n, n, n].
     * @param leaf_scores   float32 CUDA tensor of shape [B, n], same device as merge_scores.
     * @param temperature   tensor with exactly one element (any shape/device); its value
     *                      is read on the host.
     * @return {logZ of shape [B], posteriors (Pjoint) of shape [B, n, n, n]}.
     */
    static torch::autograd::tensor_list forward(
        torch::autograd::AutogradContext *ctx,
        torch::Tensor merge_scores,
        torch::Tensor leaf_scores,
        torch::Tensor temperature
    ) {
        D2P_CHECK_INPUT_CUDA(merge_scores);
        D2P_CHECK_INPUT_CUDA(leaf_scores);
        TORCH_CHECK(merge_scores.dim() == 4, "merge_scores must be 4D [B, n, n, n]");
        TORCH_CHECK(leaf_scores.dim() == 2, "leaf_scores must be 2D [B, n]");
        TORCH_CHECK(merge_scores.dtype() == torch::kFloat32, "merge_scores must be float32");
        TORCH_CHECK(leaf_scores.dtype() == torch::kFloat32, "leaf_scores must be float32");
        TORCH_CHECK(temperature.numel() == 1, "temperature must be a scalar tensor");

        const int B = merge_scores.size(0);
        const int n = merge_scores.size(1);

        // The kernels index both score tensors with the same (B, n): reject shape or
        // device mismatches here instead of reading out of bounds on the GPU.
        TORCH_CHECK(merge_scores.size(2) == n && merge_scores.size(3) == n,
                    "merge_scores must have shape [B, n, n, n], got ", merge_scores.sizes());
        TORCH_CHECK(leaf_scores.size(0) == B && leaf_scores.size(1) == n,
                    "leaf_scores must have shape [B, n] = [", B, ", ", n, "], got ",
                    leaf_scores.sizes());
        TORCH_CHECK(leaf_scores.device() == merge_scores.device(),
                    "merge_scores and leaf_scores must be on the same CUDA device");

        // The kernels only receive raw pointers, so they launch on the *current*
        // device; make that the device the inputs live on.
        const c10::cuda::CUDAGuard device_guard(merge_scores.device());

        // Host read of the scalar (synchronizes if temperature lives on the GPU).
        const float temp_val = temperature.cpu().item<float>();
        auto options = merge_scores.options();

        // Outputs and kernel workspaces
        torch::Tensor Z = torch::zeros({B, n, n}, options);
        torch::Tensor partition = torch::zeros({B}, options);
        torch::Tensor beta = torch::zeros({B, n, n}, options);
        torch::Tensor Pcond = torch::zeros({B, n, n, n}, options);
        torch::Tensor Pjoint = torch::zeros({B, n, n, n}, options);
        torch::Tensor grad_merge = torch::zeros({B, n, n, n}, options);
        torch::Tensor grad_leaf = torch::zeros({B, n}, options);
        torch::Tensor grad_T = torch::zeros({B}, options);

        // Forward pass (inside algorithm)
        cky_forward(
            merge_scores.data_ptr<float>(),
            leaf_scores.data_ptr<float>(),
            Z.data_ptr<float>(),
            partition.data_ptr<float>(),
            B, n, temp_val
        );

        // Backward pass of DP (outside algorithm) - computes posteriors
        cky_backward(
            Z.data_ptr<float>(),
            merge_scores.data_ptr<float>(),
            leaf_scores.data_ptr<float>(),
            partition.data_ptr<float>(),
            beta.data_ptr<float>(),
            Pcond.data_ptr<float>(),
            Pjoint.data_ptr<float>(),
            grad_merge.data_ptr<float>(),
            grad_leaf.data_ptr<float>(),
            grad_T.data_ptr<float>(),
            B, n, temp_val
        );

        // Save for backward. The inputs and the returned logZ are cloned so later
        // in-place edits by the caller cannot change what backward() sees; Z,
        // grad_leaf and grad_T are private to this call and need no copy.
        ctx->save_for_backward({merge_scores.clone(), leaf_scores.clone(), Z,
                                partition.clone(), grad_leaf, grad_T});
        ctx->saved_data["temperature"] = temp_val;
        // Remember how temperature was passed so its gradient can be returned in kind.
        ctx->saved_data["temperature_sizes"] = temperature.sizes().vec();
        ctx->saved_data["temperature_device"] = temperature.device();

        // Return (logZ, posteriors) - both differentiable
        return {partition, Pjoint};
    }

    /**
     * Accumulates gradients flowing into both outputs.
     *
     * @param grad_outputs {grad wrt logZ [B], grad wrt posteriors [B, n, n, n]};
     *                     either may be undefined.
     * @return {grad_merge [B, n, n, n], grad_leaf [B, n], grad_temperature}. The
     *         temperature gradient has the same shape and device as the temperature
     *         input, and is undefined when temperature does not require grad.
     */
    static torch::autograd::tensor_list backward(
        torch::autograd::AutogradContext *ctx,
        torch::autograd::tensor_list grad_outputs
    ) {
        auto saved = ctx->get_saved_variables();
        torch::Tensor merge_scores = saved[0];
        torch::Tensor leaf_scores = saved[1];
        torch::Tensor Z = saved[2];
        torch::Tensor partition = saved[3];
        torch::Tensor grad_leaf_fwd = saved[4];  // grad_leaf written by forward's cky_backward, [B, n]
        torch::Tensor grad_T_fwd = saved[5];     // grad_T written by forward's cky_backward, [B]

        const float temp_val = static_cast<float>(ctx->saved_data["temperature"].toDouble());
        const std::vector<int64_t> temp_sizes =
            ctx->saved_data["temperature_sizes"].toIntVector();
        const c10::Device temp_device = ctx->saved_data["temperature_device"].toDevice();
        // Skip the temperature work (including a full cky_param_grad pass) when unused.
        const bool need_grad_T = ctx->needs_input_grad(2);

        const int B = merge_scores.size(0);
        const int n = merge_scores.size(1);

        const c10::cuda::CUDAGuard device_guard(merge_scores.device());
        auto options = merge_scores.options();

        torch::Tensor grad_logZ = grad_outputs[0];      // [B]
        torch::Tensor grad_Pjoint = grad_outputs[1];    // [B, n, n, n]

        // Initialize accumulated gradients
        torch::Tensor grad_merge = torch::zeros({B, n, n, n}, options);
        torch::Tensor grad_leaf = torch::zeros({B, n}, options);
        torch::Tensor total_grad_T = torch::zeros({1}, options);

        // ============ Gradient from logZ (partition function) ============
        if (grad_logZ.defined() && grad_logZ.numel() > 0) {
            // Pjoint is not kept from forward (presumably to avoid holding another
            // [B, n, n, n] buffer alive); recompute it with the outside pass.
            torch::Tensor beta = torch::zeros({B, n, n}, options);
            torch::Tensor Pcond = torch::zeros({B, n, n, n}, options);
            torch::Tensor Pjoint = torch::zeros({B, n, n, n}, options);
            torch::Tensor gm = torch::zeros({B, n, n, n}, options);
            torch::Tensor gl = torch::zeros({B, n}, options);
            torch::Tensor gT = torch::zeros({B}, options);

            cky_backward(
                Z.data_ptr<float>(),
                merge_scores.data_ptr<float>(),
                leaf_scores.data_ptr<float>(),
                partition.data_ptr<float>(),
                beta.data_ptr<float>(),
                Pcond.data_ptr<float>(),
                Pjoint.data_ptr<float>(),
                gm.data_ptr<float>(),
                gl.data_ptr<float>(),
                gT.data_ptr<float>(),
                B, n, temp_val
            );

            // dZ/d(merge) = Pjoint. reshape (not view) tolerates non-contiguous grads.
            grad_merge += grad_logZ.reshape({B, 1, 1, 1}) * Pjoint;

            // dZ/d(leaf)
            grad_leaf += grad_logZ.reshape({B, 1}) * grad_leaf_fwd;

            // dL/dT via logZ path
            if (need_grad_T) {
                total_grad_T += (grad_logZ * grad_T_fwd).sum().reshape({1});
            }
        }

        // ============ Gradient from Pjoint (posteriors) ============
        if (grad_Pjoint.defined() && grad_Pjoint.numel() > 0) {
            grad_Pjoint = grad_Pjoint.contiguous().to(torch::kFloat32);

            // HVP: d^2Z/d(merge)^2 * grad_Pjoint. The upstream grad has no leaf
            // component, so the leaf direction is all zeros (kept as a named local so
            // its storage outlives the call).
            torch::Tensor zero_v_leaf = torch::zeros({B, n}, options);
            torch::Tensor d_Z = torch::zeros({B, n, n}, options);
            torch::Tensor d_partition = torch::zeros({B}, options);
            torch::Tensor beta = torch::zeros({B, n, n}, options);
            torch::Tensor d_beta = torch::zeros({B, n, n}, options);
            torch::Tensor hvp_merge = torch::zeros({B, n, n, n}, options);
            torch::Tensor hvp_leaf = torch::zeros({B, n}, options);

            cky_hvp(
                Z.data_ptr<float>(),
                merge_scores.data_ptr<float>(),
                leaf_scores.data_ptr<float>(),
                partition.data_ptr<float>(),
                grad_Pjoint.data_ptr<float>(),
                zero_v_leaf.data_ptr<float>(),
                d_Z.data_ptr<float>(),
                d_partition.data_ptr<float>(),
                beta.data_ptr<float>(),
                d_beta.data_ptr<float>(),
                hvp_merge.data_ptr<float>(),
                hvp_leaf.data_ptr<float>(),
                B, n, temp_val
            );

            grad_merge += hvp_merge;
            grad_leaf += hvp_leaf;

            // Param grad: dP/dT
            if (need_grad_T) {
                torch::Tensor U = torch::zeros({B, n, n}, options);
                torch::Tensor beta_ws = torch::zeros({B, n, n}, options);
                torch::Tensor W = torch::zeros({B, n, n}, options);
                torch::Tensor dP_dT_merge = torch::zeros({B, n, n, n}, options);
                torch::Tensor dP_dT_leaf = torch::zeros({B, n}, options);

                cky_param_grad(
                    Z.data_ptr<float>(),
                    merge_scores.data_ptr<float>(),
                    leaf_scores.data_ptr<float>(),
                    partition.data_ptr<float>(),
                    U.data_ptr<float>(),
                    beta_ws.data_ptr<float>(),
                    W.data_ptr<float>(),
                    dP_dT_merge.data_ptr<float>(),
                    dP_dT_leaf.data_ptr<float>(),
                    B, n, temp_val
                );

                total_grad_T += (grad_Pjoint * dP_dT_merge).sum().reshape({1});
            }
        }

        // Autograd validates each returned gradient against its input's shape and
        // device. Return the temperature gradient shaped and placed like the input,
        // e.g. a CPU temperature of shape [1] would otherwise reject a CUDA grad.
        torch::Tensor grad_temperature;  // undefined == "no gradient"
        if (need_grad_T) {
            grad_temperature = total_grad_T.reshape(temp_sizes).to(temp_device);
        }

        return {grad_merge, grad_leaf, grad_temperature};
    }
};

// =============================================================================
// Python Interface Functions
// =============================================================================

// Main autograd entry point (temperature supplied as a tensor).
//
// Thin wrapper around SoftCKYCUDAFunction::apply; the returned list holds
// {logZ, posteriors}, both tracked by autograd.
std::vector<torch::Tensor> soft_cky_cuda(
    torch::Tensor merge_scores,
    torch::Tensor leaf_scores,
    torch::Tensor temperature
) {
    torch::autograd::tensor_list outputs =
        SoftCKYCUDAFunction::apply(merge_scores, leaf_scores, temperature);
    return outputs;
}

// Convenience entry point with a plain floating-point temperature.
//
// The temperature is wrapped in a one-element float32 tensor and forwarded to
// SoftCKYCUDAFunction::apply; returns {logZ, posteriors}. Because this tensor
// never requires grad, no temperature gradient flows back from here.
std::vector<torch::Tensor> soft_cky_float_cuda(
    torch::Tensor merge_scores,
    torch::Tensor leaf_scores,
    double temperature
) {
    // The autograd forward only reads the temperature on the host (.cpu().item()),
    // so build the scalar on the CPU. Placing it on merge_scores' device would cost
    // a host->device copy followed by a synchronizing device->host read-back.
    auto temp_t = torch::tensor({static_cast<float>(temperature)},
                                torch::TensorOptions().dtype(torch::kFloat32));
    return SoftCKYCUDAFunction::apply(merge_scores, leaf_scores, temp_t);
}

// With explicit gradients (for debugging/inspection)
/**
 * Runs the inside (cky_forward) and outside (cky_backward) passes directly,
 * without autograd, and exposes the per-batch gradients the outside pass writes.
 *
 * @param merge_scores float32 CUDA tensor of shape [B, n, n, n].
 * @param leaf_scores  float32 CUDA tensor of shape [B, n], on the same device.
 * @param temperature  temperature, narrowed to float for the kernels.
 * @return (logZ [B], posteriors Pjoint [B, n, n, n], grad_leaf [B, n], grad_T [B]).
 *         The merge-score gradient buffer is filled by cky_backward but not returned.
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
soft_cky_with_grads_cuda(
    torch::Tensor merge_scores,
    torch::Tensor leaf_scores,
    double temperature
) {
    D2P_CHECK_INPUT_CUDA(merge_scores);
    D2P_CHECK_INPUT_CUDA(leaf_scores);
    TORCH_CHECK(merge_scores.dim() == 4, "merge_scores must be 4D [B, n, n, n]");
    TORCH_CHECK(leaf_scores.dim() == 2, "leaf_scores must be 2D [B, n]");
    TORCH_CHECK(merge_scores.dtype() == torch::kFloat32, "merge_scores must be float32");
    TORCH_CHECK(leaf_scores.dtype() == torch::kFloat32, "leaf_scores must be float32");

    const int B = merge_scores.size(0);
    const int n = merge_scores.size(1);

    // The kernels index both tensors with the same (B, n); a mismatch would read
    // out of bounds on the GPU.
    TORCH_CHECK(merge_scores.size(2) == n && merge_scores.size(3) == n,
                "merge_scores must have shape [B, n, n, n], got ", merge_scores.sizes());
    TORCH_CHECK(leaf_scores.size(0) == B && leaf_scores.size(1) == n,
                "leaf_scores must have shape [B, n] = [", B, ", ", n, "], got ",
                leaf_scores.sizes());
    TORCH_CHECK(leaf_scores.device() == merge_scores.device(),
                "merge_scores and leaf_scores must be on the same CUDA device");

    // Raw-pointer kernel launches target the current device; match the inputs.
    const c10::cuda::CUDAGuard device_guard(merge_scores.device());

    auto options = merge_scores.options();
    const float T = static_cast<float>(temperature);

    torch::Tensor Z = torch::zeros({B, n, n}, options);
    torch::Tensor partition = torch::zeros({B}, options);
    torch::Tensor beta = torch::zeros({B, n, n}, options);
    torch::Tensor Pcond = torch::zeros({B, n, n, n}, options);
    torch::Tensor Pjoint = torch::zeros({B, n, n, n}, options);
    torch::Tensor grad_merge = torch::zeros({B, n, n, n}, options);
    torch::Tensor grad_leaf = torch::zeros({B, n}, options);
    torch::Tensor grad_T = torch::zeros({B}, options);

    cky_forward(
        merge_scores.data_ptr<float>(),
        leaf_scores.data_ptr<float>(),
        Z.data_ptr<float>(),
        partition.data_ptr<float>(),
        B, n, T
    );

    cky_backward(
        Z.data_ptr<float>(),
        merge_scores.data_ptr<float>(),
        leaf_scores.data_ptr<float>(),
        partition.data_ptr<float>(),
        beta.data_ptr<float>(),
        Pcond.data_ptr<float>(),
        Pjoint.data_ptr<float>(),
        grad_merge.data_ptr<float>(),
        grad_leaf.data_ptr<float>(),
        grad_T.data_ptr<float>(),
        B, n, T
    );

    return std::make_tuple(partition, Pjoint, grad_leaf, grad_T);
}

// HVP: Hessian-vector product
/**
 * Computes the soft-CKY Hessian-vector product directly, without autograd.
 *
 * @param merge_scores float32 CUDA tensor of shape [B, n, n, n].
 * @param leaf_scores  float32 CUDA tensor of shape [B, n].
 * @param v_merge      float32 CUDA tensor shaped like merge_scores (merge direction).
 * @param v_leaf       float32 CUDA tensor shaped like leaf_scores (leaf direction).
 * @param temperature  temperature, narrowed to float for the kernels.
 * @return the merge-score component of the HVP, shape [B, n, n, n]. The leaf
 *         component (hvp_leaf) is computed by cky_hvp but not returned.
 */
torch::Tensor soft_cky_hvp_cuda(
    torch::Tensor merge_scores,
    torch::Tensor leaf_scores,
    torch::Tensor v_merge,
    torch::Tensor v_leaf,
    double temperature
) {
    D2P_CHECK_INPUT_CUDA(merge_scores);
    D2P_CHECK_INPUT_CUDA(leaf_scores);
    D2P_CHECK_INPUT_CUDA(v_merge);
    D2P_CHECK_INPUT_CUDA(v_leaf);
    TORCH_CHECK(merge_scores.dim() == 4, "merge_scores must be 4D [B, n, n, n]");
    TORCH_CHECK(leaf_scores.dim() == 2, "leaf_scores must be 2D [B, n]");
    TORCH_CHECK(merge_scores.dtype() == torch::kFloat32, "merge_scores must be float32");
    TORCH_CHECK(leaf_scores.dtype() == torch::kFloat32, "leaf_scores must be float32");
    TORCH_CHECK(v_merge.dtype() == torch::kFloat32, "v_merge must be float32");
    TORCH_CHECK(v_leaf.dtype() == torch::kFloat32, "v_leaf must be float32");

    const int B = merge_scores.size(0);
    const int n = merge_scores.size(1);

    // All four tensors are indexed with the same (B, n) inside the kernels; any
    // mismatch would read out of bounds on the GPU.
    TORCH_CHECK(merge_scores.size(2) == n && merge_scores.size(3) == n,
                "merge_scores must have shape [B, n, n, n], got ", merge_scores.sizes());
    TORCH_CHECK(leaf_scores.size(0) == B && leaf_scores.size(1) == n,
                "leaf_scores must have shape [B, n] = [", B, ", ", n, "], got ",
                leaf_scores.sizes());
    TORCH_CHECK(v_merge.sizes() == merge_scores.sizes(),
                "v_merge must have the same shape as merge_scores, got ", v_merge.sizes());
    TORCH_CHECK(v_leaf.sizes() == leaf_scores.sizes(),
                "v_leaf must have the same shape as leaf_scores, got ", v_leaf.sizes());
    const auto device = merge_scores.device();
    TORCH_CHECK(leaf_scores.device() == device && v_merge.device() == device &&
                    v_leaf.device() == device,
                "merge_scores, leaf_scores, v_merge and v_leaf must be on the same CUDA device");

    // Raw-pointer kernel launches target the current device; match the inputs.
    const c10::cuda::CUDAGuard device_guard(device);

    auto options = merge_scores.options();
    const float T = static_cast<float>(temperature);

    // Inside pass: Z and partition are the only state cky_hvp needs.
    torch::Tensor Z = torch::zeros({B, n, n}, options);
    torch::Tensor partition = torch::zeros({B}, options);

    cky_forward(
        merge_scores.data_ptr<float>(),
        leaf_scores.data_ptr<float>(),
        Z.data_ptr<float>(),
        partition.data_ptr<float>(),
        B, n, T
    );

    torch::Tensor d_Z = torch::zeros({B, n, n}, options);
    torch::Tensor d_partition = torch::zeros({B}, options);
    torch::Tensor beta = torch::zeros({B, n, n}, options);
    torch::Tensor d_beta = torch::zeros({B, n, n}, options);
    torch::Tensor hvp_merge = torch::zeros({B, n, n, n}, options);
    torch::Tensor hvp_leaf = torch::zeros({B, n}, options);

    cky_hvp(
        Z.data_ptr<float>(),
        merge_scores.data_ptr<float>(),
        leaf_scores.data_ptr<float>(),
        partition.data_ptr<float>(),
        v_merge.data_ptr<float>(),
        v_leaf.data_ptr<float>(),
        d_Z.data_ptr<float>(),
        d_partition.data_ptr<float>(),
        beta.data_ptr<float>(),
        d_beta.data_ptr<float>(),
        hvp_merge.data_ptr<float>(),
        hvp_leaf.data_ptr<float>(),
        B, n, T
    );

    return hvp_merge;
}

// Param Jacobian: dP/dT
/**
 * Derivative of the posteriors with respect to temperature, computed directly
 * (no autograd).
 *
 * @param merge_scores float32 CUDA tensor of shape [B, n, n, n].
 * @param leaf_scores  float32 CUDA tensor of shape [B, n], on the same device.
 * @param temperature  temperature, narrowed to float for the kernels.
 * @return dP/dT for the merge posteriors, shape [B, n, n, n]. The leaf component
 *         (dP_dT_leaf) is computed by cky_param_grad but not returned.
 */
torch::Tensor soft_cky_param_jacobian_cuda(
    torch::Tensor merge_scores,
    torch::Tensor leaf_scores,
    double temperature
) {
    D2P_CHECK_INPUT_CUDA(merge_scores);
    D2P_CHECK_INPUT_CUDA(leaf_scores);
    TORCH_CHECK(merge_scores.dim() == 4, "merge_scores must be 4D [B, n, n, n]");
    TORCH_CHECK(leaf_scores.dim() == 2, "leaf_scores must be 2D [B, n]");
    TORCH_CHECK(merge_scores.dtype() == torch::kFloat32, "merge_scores must be float32");
    TORCH_CHECK(leaf_scores.dtype() == torch::kFloat32, "leaf_scores must be float32");

    const int B = merge_scores.size(0);
    const int n = merge_scores.size(1);

    // The kernels index both tensors with the same (B, n); a mismatch would read
    // out of bounds on the GPU.
    TORCH_CHECK(merge_scores.size(2) == n && merge_scores.size(3) == n,
                "merge_scores must have shape [B, n, n, n], got ", merge_scores.sizes());
    TORCH_CHECK(leaf_scores.size(0) == B && leaf_scores.size(1) == n,
                "leaf_scores must have shape [B, n] = [", B, ", ", n, "], got ",
                leaf_scores.sizes());
    TORCH_CHECK(leaf_scores.device() == merge_scores.device(),
                "merge_scores and leaf_scores must be on the same CUDA device");

    // Raw-pointer kernel launches target the current device; match the inputs.
    const c10::cuda::CUDAGuard device_guard(merge_scores.device());

    auto options = merge_scores.options();
    const float T = static_cast<float>(temperature);

    // Inside pass: Z and partition feed cky_param_grad.
    torch::Tensor Z = torch::zeros({B, n, n}, options);
    torch::Tensor partition = torch::zeros({B}, options);

    cky_forward(
        merge_scores.data_ptr<float>(),
        leaf_scores.data_ptr<float>(),
        Z.data_ptr<float>(),
        partition.data_ptr<float>(),
        B, n, T
    );

    torch::Tensor U = torch::zeros({B, n, n}, options);
    torch::Tensor beta = torch::zeros({B, n, n}, options);
    torch::Tensor W = torch::zeros({B, n, n}, options);
    torch::Tensor dP_dT_merge = torch::zeros({B, n, n, n}, options);
    torch::Tensor dP_dT_leaf = torch::zeros({B, n}, options);

    cky_param_grad(
        Z.data_ptr<float>(),
        merge_scores.data_ptr<float>(),
        leaf_scores.data_ptr<float>(),
        partition.data_ptr<float>(),
        U.data_ptr<float>(),
        beta.data_ptr<float>(),
        W.data_ptr<float>(),
        dP_dT_merge.data_ptr<float>(),
        dP_dT_leaf.data_ptr<float>(),
        B, n, T
    );

    return dP_dT_merge;
}

// Full backward: returns (grad_merge, grad_leaf, grad_T) given grad_posteriors
/**
 * Backward through the posterior output only, computed without autograd. This
 * mirrors the posterior path of SoftCKYCUDAFunction::backward: an HVP (cky_hvp)
 * for the scores and dP/dT (cky_param_grad) contracted with grad_posteriors for
 * the temperature.
 *
 * @param merge_scores    float32 CUDA tensor of shape [B, n, n, n].
 * @param leaf_scores     float32 CUDA tensor of shape [B, n].
 * @param grad_posteriors CUDA tensor shaped like merge_scores (cast to float32).
 * @param temperature     temperature, narrowed to float for the kernels.
 * @return (grad_merge [B, n, n, n], grad_leaf [B, n], grad_T [1]).
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
soft_cky_backward_full_cuda(
    torch::Tensor merge_scores,
    torch::Tensor leaf_scores,
    torch::Tensor grad_posteriors,
    double temperature
) {
    D2P_CHECK_INPUT_CUDA(merge_scores);
    D2P_CHECK_INPUT_CUDA(leaf_scores);
    D2P_CHECK_INPUT_CUDA(grad_posteriors);
    TORCH_CHECK(merge_scores.dim() == 4, "merge_scores must be 4D [B, n, n, n]");
    TORCH_CHECK(leaf_scores.dim() == 2, "leaf_scores must be 2D [B, n]");
    TORCH_CHECK(merge_scores.dtype() == torch::kFloat32, "merge_scores must be float32");
    TORCH_CHECK(leaf_scores.dtype() == torch::kFloat32, "leaf_scores must be float32");

    const int B = merge_scores.size(0);
    const int n = merge_scores.size(1);

    // The kernels index all three tensors with the same (B, n); a mismatch would
    // read out of bounds on the GPU.
    TORCH_CHECK(merge_scores.size(2) == n && merge_scores.size(3) == n,
                "merge_scores must have shape [B, n, n, n], got ", merge_scores.sizes());
    TORCH_CHECK(leaf_scores.size(0) == B && leaf_scores.size(1) == n,
                "leaf_scores must have shape [B, n] = [", B, ", ", n, "], got ",
                leaf_scores.sizes());
    TORCH_CHECK(grad_posteriors.sizes() == merge_scores.sizes(),
                "grad_posteriors must have the same shape as merge_scores, got ",
                grad_posteriors.sizes());
    const auto device = merge_scores.device();
    TORCH_CHECK(leaf_scores.device() == device && grad_posteriors.device() == device,
                "merge_scores, leaf_scores and grad_posteriors must be on the same CUDA device");

    // Raw-pointer kernel launches target the current device; match the inputs.
    const c10::cuda::CUDAGuard device_guard(device);

    auto options = merge_scores.options();
    const float T = static_cast<float>(temperature);

    grad_posteriors = grad_posteriors.contiguous().to(torch::kFloat32);

    // Inside pass only. cky_hvp and cky_param_grad read just Z and partition; no
    // outside pass (cky_backward) is needed. soft_cky_hvp_cuda and
    // soft_cky_param_jacobian_cuda call those kernels the same way.
    torch::Tensor Z = torch::zeros({B, n, n}, options);
    torch::Tensor partition = torch::zeros({B}, options);

    cky_forward(
        merge_scores.data_ptr<float>(),
        leaf_scores.data_ptr<float>(),
        Z.data_ptr<float>(),
        partition.data_ptr<float>(),
        B, n, T
    );

    // HVP for grad_merge / grad_leaf. The upstream grad has no leaf component, so
    // the leaf direction is all zeros (named so its storage outlives the call).
    torch::Tensor zero_v_leaf = torch::zeros({B, n}, options);
    torch::Tensor d_Z = torch::zeros({B, n, n}, options);
    torch::Tensor d_partition = torch::zeros({B}, options);
    torch::Tensor beta = torch::zeros({B, n, n}, options);
    torch::Tensor d_beta = torch::zeros({B, n, n}, options);
    torch::Tensor grad_merge = torch::zeros({B, n, n, n}, options);
    torch::Tensor grad_leaf = torch::zeros({B, n}, options);

    cky_hvp(
        Z.data_ptr<float>(),
        merge_scores.data_ptr<float>(),
        leaf_scores.data_ptr<float>(),
        partition.data_ptr<float>(),
        grad_posteriors.data_ptr<float>(),
        zero_v_leaf.data_ptr<float>(),
        d_Z.data_ptr<float>(),
        d_partition.data_ptr<float>(),
        beta.data_ptr<float>(),
        d_beta.data_ptr<float>(),
        grad_merge.data_ptr<float>(),
        grad_leaf.data_ptr<float>(),
        B, n, T
    );

    // Param grad for temperature
    torch::Tensor U = torch::zeros({B, n, n}, options);
    torch::Tensor beta_ws = torch::zeros({B, n, n}, options);
    torch::Tensor W = torch::zeros({B, n, n}, options);
    torch::Tensor dP_dT_merge = torch::zeros({B, n, n, n}, options);
    torch::Tensor dP_dT_leaf = torch::zeros({B, n}, options);

    cky_param_grad(
        Z.data_ptr<float>(),
        merge_scores.data_ptr<float>(),
        leaf_scores.data_ptr<float>(),
        partition.data_ptr<float>(),
        U.data_ptr<float>(),
        beta_ws.data_ptr<float>(),
        W.data_ptr<float>(),
        dP_dT_merge.data_ptr<float>(),
        dP_dT_leaf.data_ptr<float>(),
        B, n, T
    );

    torch::Tensor total_grad_T = (grad_posteriors * dP_dT_merge).sum().reshape({1});

    return std::make_tuple(grad_merge, grad_leaf, total_grad_T);
}

// =============================================================================
// TORCH_LIBRARY_IMPL Registration
// =============================================================================

#ifdef USE_TORCH_LIBRARY

// Dispatcher registrations, compiled only when USE_TORCH_LIBRARY is defined.
// NOTE(review): the `d2p` op schemas are presumably declared with
// TORCH_LIBRARY(d2p, ...) in another translation unit. Confirm that the
// signatures there match the C++ functions bound here.

// CUDA-backend implementations for every soft-CKY entry point in this file.
TORCH_LIBRARY_IMPL(d2p, CUDA, m) {
    m.impl("soft_cky", soft_cky_cuda);
    m.impl("soft_cky_float", soft_cky_float_cuda);
    m.impl("soft_cky_with_grads", soft_cky_with_grads_cuda);
    m.impl("soft_cky_hvp", soft_cky_hvp_cuda);
    m.impl("soft_cky_param_jacobian", soft_cky_param_jacobian_cuda);
    m.impl("soft_cky_backward_full", soft_cky_backward_full_cuda);
}

// Autograd-key registrations. These two entry points call
// SoftCKYCUDAFunction::apply, which records its own backward, so they are also
// bound under AutogradCUDA. The other ops are registered only for the CUDA key
// above.
TORCH_LIBRARY_IMPL(d2p, AutogradCUDA, m) {
    m.impl("soft_cky", soft_cky_cuda);
    m.impl("soft_cky_float", soft_cky_float_cuda);
}

#endif // USE_TORCH_LIBRARY
