/**
 * @file torch_cuda.cpp
 * @brief Soft LCS CUDA PyTorch Bindings
 *
 * Provides torch.ops.d2p.soft_lcs* operators for CUDA tensors.
 */

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <vector>
#include <tuple>

#include "kernels.cuh"

// ============================================================================
// Helper Functions
// ============================================================================

/**
 * @brief Build the int32 lengths tensor handed to the soft LCS kernels.
 *
 * When @p lengths is supplied (and non-empty) it is moved to @p device,
 * converted to int32 and made contiguous. Otherwise a [B, 2] tensor is
 * created whose first column is filled with @p L1 and whose second column is
 * filled with @p L2.
 *
 * A defined-but-empty tensor is treated the same as an absent one, matching
 * the "numel() == 0 means no lengths" convention used by soft_lcs_cuda.
 *
 * @param lengths Optional caller-provided lengths; must hold 2 * B elements.
 * @param B       Batch size.
 * @param L1      Size of the first sequence dimension.
 * @param L2      Size of the second sequence dimension.
 * @param device  Device the returned tensor must live on.
 * @return Contiguous int32 tensor on @p device.
 * @throws c10::Error if a non-empty @p lengths does not hold 2 * B elements.
 */
static torch::Tensor make_lengths_tensor(
    const c10::optional<torch::Tensor>& lengths,
    int B, int L1, int L2, torch::Device device
) {
    if (lengths.has_value() && lengths->defined() && lengths->numel() > 0) {
        // The default path below builds a [B, 2] tensor, so a user-provided
        // tensor of any other size would make the kernels read out of bounds.
        const int64_t expected = static_cast<int64_t>(B) * 2;
        TORCH_CHECK(
            lengths->numel() == expected,
            "lengths must hold 2 entries per batch element (expected ",
            expected, " elements, got ", lengths->numel(), ")");
        return lengths->to(device).to(torch::kInt32).contiguous();
    }

    auto lens = torch::empty({B, 2}, torch::dtype(torch::kInt32).device(device));
    lens.select(1, 0).fill_(L1);
    lens.select(1, 1).fill_(L2);
    return lens;
}

// ============================================================================
// Autograd Function
// ============================================================================

/**
 * @brief Autograd wrapper around the soft LCS CUDA kernels.
 *
 * forward() calls d2p::lcs::lcs_forward followed by d2p::lcs::lcs_backward and
 * returns {lcs_score, posteriors}. backward() produces a gradient for
 * `scores` only; `temperature` and `lengths` receive undefined gradients.
 */
class SoftLCSCUDAFunction : public torch::autograd::Function<SoftLCSCUDAFunction> {
public:
    /**
     * @param ctx         Autograd context; stores scores, alpha, lengths,
     *                    posteriors and the temperature for backward().
     * @param scores      3D float32 CUDA tensor [B, L1, L2].
     * @param temperature Temperature; cast to float before reaching the kernels.
     * @param lengths     Optional lengths tensor (see make_lengths_tensor).
     * @return {lcs_score [B], posteriors [B, L1, L2]}.
     */
    static torch::autograd::variable_list forward(
        torch::autograd::AutogradContext* ctx,
        torch::Tensor scores,
        double temperature,
        c10::optional<torch::Tensor> lengths
    ) {
        TORCH_CHECK(scores.dim() == 3, "scores must be 3D [B, L1, L2]");
        TORCH_CHECK(scores.is_cuda(), "scores must be CUDA tensor");
        TORCH_CHECK(scores.scalar_type() == torch::kFloat32, "scores must be float32");

        const int B = static_cast<int>(scores.size(0));
        const int L1 = static_cast<int>(scores.size(1));
        const int L2 = static_cast<int>(scores.size(2));
        const float T = static_cast<float>(temperature);

        c10::cuda::CUDAGuard device_guard(scores.device());

        auto scores_c = scores.contiguous();
        auto lengths_t = make_lengths_tensor(lengths, B, L1, L2, scores.device());

        // Compute the alpha row width in 64 bits: (L1+1)*(L2+1) can overflow int.
        const int64_t alpha_cols =
            (static_cast<int64_t>(L1) + 1) * (static_cast<int64_t>(L2) + 1);

        // Allocate alpha and lcs_score
        auto alpha = torch::empty({B, alpha_cols}, scores.options());
        auto lcs_score = torch::empty({B}, scores.options());

        // Forward pass
        d2p::lcs::lcs_forward(
            scores_c.data_ptr<float>(),
            alpha.data_ptr<float>(),
            lcs_score.data_ptr<float>(),
            lengths_t.data_ptr<int>(),
            B, L1, L2, T
        );

        // Backward pass to get posteriors. grad_T is required by the kernel
        // signature but is not returned from this function.
        auto beta = torch::empty_like(alpha);
        auto posteriors = torch::zeros_like(scores_c);
        auto grad_T = torch::zeros({B}, scores.options());

        d2p::lcs::lcs_backward(
            alpha.data_ptr<float>(),
            scores_c.data_ptr<float>(),
            lcs_score.data_ptr<float>(),
            beta.data_ptr<float>(),
            posteriors.data_ptr<float>(),
            grad_T.data_ptr<float>(),
            lengths_t.data_ptr<int>(),
            B, L1, L2, T
        );

        // Save for backward. posteriors is kept so that the gradient flowing
        // into lcs_score can be propagated to scores.
        ctx->save_for_backward({scores_c, alpha, lengths_t, posteriors});
        ctx->saved_data["temperature"] = temperature;

        return {lcs_score, posteriors};
    }

    /**
     * @param ctx          Autograd context populated by forward().
     * @param grad_outputs {grad w.r.t. lcs_score [B], grad w.r.t. posteriors [B, L1, L2]}.
     * @return {grad_scores, undefined, undefined}.
     */
    static torch::autograd::variable_list backward(
        torch::autograd::AutogradContext* ctx,
        torch::autograd::variable_list grad_outputs
    ) {
        auto saved = ctx->get_saved_variables();
        auto scores = saved[0];
        auto alpha = saved[1];
        auto lengths_t = saved[2];
        auto posteriors = saved[3];
        const float T = static_cast<float>(ctx->saved_data["temperature"].toDouble());

        const int B = static_cast<int>(scores.size(0));
        const int L1 = static_cast<int>(scores.size(1));
        const int L2 = static_cast<int>(scores.size(2));

        c10::cuda::CUDAGuard device_guard(scores.device());

        auto grad_scores = torch::zeros_like(scores);

        // Contribution of the posteriors output: HVP against the incoming gradient.
        const auto& grad_posteriors_in = grad_outputs[1];
        if (grad_posteriors_in.defined()) {
            auto grad_posteriors = grad_posteriors_in.contiguous();

            // Scratch/output buffers required by the lcs_hvp signature.
            auto lcs_score = torch::empty({B}, scores.options());
            auto d_alpha = torch::empty_like(alpha);
            auto d_lcs_score = torch::empty({B}, scores.options());
            auto beta = torch::empty_like(alpha);
            auto d_beta = torch::empty_like(alpha);

            d2p::lcs::lcs_hvp(
                alpha.data_ptr<float>(),
                scores.data_ptr<float>(),
                lcs_score.data_ptr<float>(),
                grad_posteriors.data_ptr<float>(),
                d_alpha.data_ptr<float>(),
                d_lcs_score.data_ptr<float>(),
                beta.data_ptr<float>(),
                d_beta.data_ptr<float>(),
                grad_scores.data_ptr<float>(),
                lengths_t.data_ptr<int>(),
                B, L1, L2, T
            );
        }

        // Contribution of the lcs_score output. Previously this incoming
        // gradient was dropped, so losses that depend on lcs_score received
        // no gradient through it.
        // NOTE(review): assumes posteriors is d(lcs_score)/d(scores), which is
        // what using lcs_hvp as the backward of posteriors implies — confirm
        // against the kernel definitions in kernels.cuh.
        const auto& grad_lcs = grad_outputs[0];
        if (grad_lcs.defined()) {
            grad_scores.add_(posteriors * grad_lcs.reshape({B, 1, 1}));
        }

        return {grad_scores, torch::Tensor(), torch::Tensor()};
    }
};

// ============================================================================
// Operator Implementations
// ============================================================================

/**
 * @brief Tensor-argument entry point for soft LCS.
 *
 * Extracts the temperature as a double from @p temperature and maps an empty
 * @p lengths tensor to "no lengths" before dispatching to
 * SoftLCSCUDAFunction.
 *
 * @param scores      Scores tensor forwarded unchanged.
 * @param temperature Tensor whose single value is read as the temperature.
 * @param lengths     Lengths tensor; numel() == 0 means none were supplied.
 * @return The outputs of SoftLCSCUDAFunction::apply.
 */
std::vector<torch::Tensor> soft_lcs_cuda(
    torch::Tensor scores,
    torch::Tensor temperature,
    torch::Tensor lengths
) {
    const double scalar_temperature = temperature.cpu().item<double>();

    // An empty tensor is the caller's way of saying "no explicit lengths".
    const bool has_lengths = lengths.numel() > 0;
    c10::optional<torch::Tensor> maybe_lengths =
        has_lengths ? c10::optional<torch::Tensor>(lengths) : c10::nullopt;

    return SoftLCSCUDAFunction::apply(scores, scalar_temperature, maybe_lengths);
}

/**
 * @brief Scalar-temperature entry point for soft LCS.
 *
 * Forwards all arguments unchanged to SoftLCSCUDAFunction.
 *
 * @param scores      Scores tensor.
 * @param temperature Temperature as a double.
 * @param lengths     Optional lengths tensor.
 * @return The outputs of SoftLCSCUDAFunction::apply.
 */
std::vector<torch::Tensor> soft_lcs_float_cuda(
    torch::Tensor scores,
    double temperature,
    c10::optional<torch::Tensor> lengths
) {
    return SoftLCSCUDAFunction::apply(scores, temperature, lengths);
}

/**
 * @brief Run the soft LCS forward and backward kernels without autograd.
 *
 * @param scores      3D float32 CUDA tensor [B, L1, L2].
 * @param temperature Temperature; cast to float before reaching the kernels.
 * @param lengths     Optional lengths tensor (see make_lengths_tensor).
 * @return (lcs_score [B], posteriors [B, L1, L2], grad_T [B]), where grad_T is
 *         the buffer passed as the grad_T argument of d2p::lcs::lcs_backward.
 * @throws c10::Error if scores is not a 3D float32 CUDA tensor.
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> soft_lcs_with_grads_cuda(
    torch::Tensor scores,
    double temperature,
    c10::optional<torch::Tensor> lengths
) {
    TORCH_CHECK(scores.dim() == 3, "scores must be 3D [B, L1, L2]");
    TORCH_CHECK(scores.is_cuda(), "scores must be CUDA tensor");
    // Same dtype requirement as SoftLCSCUDAFunction::forward; the kernels
    // consume raw float pointers.
    TORCH_CHECK(scores.scalar_type() == torch::kFloat32, "scores must be float32");

    const int B = static_cast<int>(scores.size(0));
    const int L1 = static_cast<int>(scores.size(1));
    const int L2 = static_cast<int>(scores.size(2));
    const float T = static_cast<float>(temperature);

    c10::cuda::CUDAGuard device_guard(scores.device());

    auto scores_c = scores.contiguous();
    auto lengths_t = make_lengths_tensor(lengths, B, L1, L2, scores.device());

    // Compute the alpha row width in 64 bits: (L1+1)*(L2+1) can overflow int.
    const int64_t alpha_cols =
        (static_cast<int64_t>(L1) + 1) * (static_cast<int64_t>(L2) + 1);

    auto alpha = torch::empty({B, alpha_cols}, scores.options());
    auto lcs_score = torch::empty({B}, scores.options());

    d2p::lcs::lcs_forward(
        scores_c.data_ptr<float>(),
        alpha.data_ptr<float>(),
        lcs_score.data_ptr<float>(),
        lengths_t.data_ptr<int>(),
        B, L1, L2, T
    );

    auto beta = torch::empty_like(alpha);
    auto posteriors = torch::zeros_like(scores_c);
    auto grad_T = torch::zeros({B}, scores.options());

    d2p::lcs::lcs_backward(
        alpha.data_ptr<float>(),
        scores_c.data_ptr<float>(),
        lcs_score.data_ptr<float>(),
        beta.data_ptr<float>(),
        posteriors.data_ptr<float>(),
        grad_T.data_ptr<float>(),
        lengths_t.data_ptr<int>(),
        B, L1, L2, T
    );

    return std::make_tuple(lcs_score, posteriors, grad_T);
}

/**
 * @brief Run the soft LCS forward kernel followed by the HVP kernel.
 *
 * @param scores      3D float32 CUDA tensor [B, L1, L2].
 * @param V           Float32 CUDA tensor with the same shape and device as
 *                    @p scores; passed to d2p::lcs::lcs_hvp as the direction.
 * @param temperature Temperature; cast to float before reaching the kernels.
 * @param lengths     Optional lengths tensor (see make_lengths_tensor).
 * @return Tensor shaped like @p scores holding the lcs_hvp output.
 * @throws c10::Error if the inputs fail any of the checks below.
 */
torch::Tensor soft_lcs_hvp_cuda(
    torch::Tensor scores,
    torch::Tensor V,
    double temperature,
    c10::optional<torch::Tensor> lengths
) {
    TORCH_CHECK(scores.dim() == 3, "scores must be 3D [B, L1, L2]");
    TORCH_CHECK(V.dim() == 3, "V must be 3D [B, L1, L2]");
    TORCH_CHECK(scores.is_cuda() && V.is_cuda(), "tensors must be CUDA");
    TORCH_CHECK(scores.scalar_type() == torch::kFloat32, "scores must be float32");
    TORCH_CHECK(V.scalar_type() == torch::kFloat32, "V must be float32");
    // The kernel is launched with B/L1/L2 taken from scores, so a V of any
    // other shape (or on another device) would be read out of bounds.
    TORCH_CHECK(V.sizes() == scores.sizes(), "V must have the same shape as scores");
    TORCH_CHECK(V.device() == scores.device(), "V must be on the same device as scores");

    const int B = static_cast<int>(scores.size(0));
    const int L1 = static_cast<int>(scores.size(1));
    const int L2 = static_cast<int>(scores.size(2));
    const float T = static_cast<float>(temperature);

    c10::cuda::CUDAGuard device_guard(scores.device());

    auto scores_c = scores.contiguous();
    auto V_c = V.contiguous();
    auto lengths_t = make_lengths_tensor(lengths, B, L1, L2, scores.device());

    // Compute the alpha row width in 64 bits: (L1+1)*(L2+1) can overflow int.
    const int64_t alpha_cols =
        (static_cast<int64_t>(L1) + 1) * (static_cast<int64_t>(L2) + 1);

    // Forward pass
    auto alpha = torch::empty({B, alpha_cols}, scores.options());
    auto lcs_score = torch::empty({B}, scores.options());

    d2p::lcs::lcs_forward(
        scores_c.data_ptr<float>(),
        alpha.data_ptr<float>(),
        lcs_score.data_ptr<float>(),
        lengths_t.data_ptr<int>(),
        B, L1, L2, T
    );

    // HVP; d_alpha, d_lcs_score, beta and d_beta are buffers required by the
    // lcs_hvp signature and are not returned.
    auto d_alpha = torch::empty_like(alpha);
    auto d_lcs_score = torch::empty({B}, scores.options());
    auto beta = torch::empty_like(alpha);
    auto d_beta = torch::empty_like(alpha);
    auto H_scores = torch::zeros_like(scores_c);

    d2p::lcs::lcs_hvp(
        alpha.data_ptr<float>(),
        scores_c.data_ptr<float>(),
        lcs_score.data_ptr<float>(),
        V_c.data_ptr<float>(),
        d_alpha.data_ptr<float>(),
        d_lcs_score.data_ptr<float>(),
        beta.data_ptr<float>(),
        d_beta.data_ptr<float>(),
        H_scores.data_ptr<float>(),
        lengths_t.data_ptr<int>(),
        B, L1, L2, T
    );

    return H_scores;
}

/**
 * @brief Run the soft LCS forward kernel followed by the parameter-gradient kernel.
 *
 * @param scores      3D float32 CUDA tensor [B, L1, L2].
 * @param temperature Temperature; cast to float before reaching the kernels.
 * @param lengths     Optional lengths tensor (see make_lengths_tensor).
 * @return Tensor shaped like @p scores holding the dP_dT output of
 *         d2p::lcs::lcs_param_grad.
 * @throws c10::Error if scores is not a 3D float32 CUDA tensor.
 */
torch::Tensor soft_lcs_param_jacobian_cuda(
    torch::Tensor scores,
    double temperature,
    c10::optional<torch::Tensor> lengths
) {
    TORCH_CHECK(scores.dim() == 3, "scores must be 3D [B, L1, L2]");
    TORCH_CHECK(scores.is_cuda(), "scores must be CUDA tensor");
    // Same dtype requirement as SoftLCSCUDAFunction::forward; the kernels
    // consume raw float pointers.
    TORCH_CHECK(scores.scalar_type() == torch::kFloat32, "scores must be float32");

    const int B = static_cast<int>(scores.size(0));
    const int L1 = static_cast<int>(scores.size(1));
    const int L2 = static_cast<int>(scores.size(2));
    const float T = static_cast<float>(temperature);

    c10::cuda::CUDAGuard device_guard(scores.device());

    auto scores_c = scores.contiguous();
    auto lengths_t = make_lengths_tensor(lengths, B, L1, L2, scores.device());

    // Compute the alpha row width in 64 bits: (L1+1)*(L2+1) can overflow int.
    const int64_t alpha_cols =
        (static_cast<int64_t>(L1) + 1) * (static_cast<int64_t>(L2) + 1);

    // Forward pass
    auto alpha = torch::empty({B, alpha_cols}, scores.options());
    auto lcs_score = torch::empty({B}, scores.options());

    d2p::lcs::lcs_forward(
        scores_c.data_ptr<float>(),
        alpha.data_ptr<float>(),
        lcs_score.data_ptr<float>(),
        lengths_t.data_ptr<int>(),
        B, L1, L2, T
    );

    // Param grad; U, beta and W are buffers required by the lcs_param_grad
    // signature and are not returned.
    auto U = torch::empty_like(alpha);
    auto beta = torch::empty_like(alpha);
    auto W = torch::empty_like(alpha);
    auto dP_dT = torch::zeros_like(scores_c);

    d2p::lcs::lcs_param_grad(
        alpha.data_ptr<float>(),
        scores_c.data_ptr<float>(),
        lcs_score.data_ptr<float>(),
        U.data_ptr<float>(),
        beta.data_ptr<float>(),
        W.data_ptr<float>(),
        dP_dT.data_ptr<float>(),
        lengths_t.data_ptr<int>(),
        B, L1, L2, T
    );

    return dP_dT;
}

/**
 * @brief Compute gradients for scores and temperature from an upstream gradient.
 *
 * Runs lcs_forward, then lcs_hvp with @p grad_output as the direction to
 * obtain grad_scores, then lcs_backward to obtain a per-batch grad_T which is
 * multiplied by the per-batch sum of @p grad_output.
 *
 * @param scores      3D float32 CUDA tensor [B, L1, L2].
 * @param grad_output Float32 CUDA tensor with the same shape and device as
 *                    @p scores.
 * @param temperature Temperature; cast to float before reaching the kernels.
 * @param lengths     Optional lengths tensor (see make_lengths_tensor).
 * @return (grad_scores [B, L1, L2], grad_T [B]).
 * @throws c10::Error if the inputs fail any of the checks below.
 */
std::tuple<torch::Tensor, torch::Tensor> soft_lcs_backward_full_cuda(
    torch::Tensor scores,
    torch::Tensor grad_output,
    double temperature,
    c10::optional<torch::Tensor> lengths
) {
    TORCH_CHECK(scores.dim() == 3 && grad_output.dim() == 3, "tensors must be 3D");
    TORCH_CHECK(scores.is_cuda() && grad_output.is_cuda(), "tensors must be CUDA");
    TORCH_CHECK(scores.scalar_type() == torch::kFloat32, "scores must be float32");
    TORCH_CHECK(grad_output.scalar_type() == torch::kFloat32, "grad_output must be float32");
    // The kernels are launched with B/L1/L2 taken from scores, so a
    // grad_output of any other shape (or on another device) would be read
    // out of bounds.
    TORCH_CHECK(grad_output.sizes() == scores.sizes(),
                "grad_output must have the same shape as scores");
    TORCH_CHECK(grad_output.device() == scores.device(),
                "grad_output must be on the same device as scores");

    const int B = static_cast<int>(scores.size(0));
    const int L1 = static_cast<int>(scores.size(1));
    const int L2 = static_cast<int>(scores.size(2));
    const float T = static_cast<float>(temperature);

    c10::cuda::CUDAGuard device_guard(scores.device());

    auto scores_c = scores.contiguous();
    auto grad_c = grad_output.contiguous();
    auto lengths_t = make_lengths_tensor(lengths, B, L1, L2, scores.device());

    // Compute the alpha row width in 64 bits: (L1+1)*(L2+1) can overflow int.
    const int64_t alpha_cols =
        (static_cast<int64_t>(L1) + 1) * (static_cast<int64_t>(L2) + 1);

    // Forward pass
    auto alpha = torch::empty({B, alpha_cols}, scores.options());
    auto lcs_score = torch::empty({B}, scores.options());

    d2p::lcs::lcs_forward(
        scores_c.data_ptr<float>(),
        alpha.data_ptr<float>(),
        lcs_score.data_ptr<float>(),
        lengths_t.data_ptr<int>(),
        B, L1, L2, T
    );

    // HVP; d_alpha, d_lcs_score, beta and d_beta are buffers required by the
    // lcs_hvp signature and are not returned.
    auto d_alpha = torch::empty_like(alpha);
    auto d_lcs_score = torch::empty({B}, scores.options());
    auto beta = torch::empty_like(alpha);
    auto d_beta = torch::empty_like(alpha);
    auto grad_scores = torch::zeros_like(scores_c);

    d2p::lcs::lcs_hvp(
        alpha.data_ptr<float>(),
        scores_c.data_ptr<float>(),
        lcs_score.data_ptr<float>(),
        grad_c.data_ptr<float>(),
        d_alpha.data_ptr<float>(),
        d_lcs_score.data_ptr<float>(),
        beta.data_ptr<float>(),
        d_beta.data_ptr<float>(),
        grad_scores.data_ptr<float>(),
        lengths_t.data_ptr<int>(),
        B, L1, L2, T
    );

    // Compute grad_T
    auto beta2 = torch::empty_like(alpha);
    auto posteriors = torch::zeros_like(scores_c);
    auto grad_T = torch::zeros({B}, scores.options());

    d2p::lcs::lcs_backward(
        alpha.data_ptr<float>(),
        scores_c.data_ptr<float>(),
        lcs_score.data_ptr<float>(),
        beta2.data_ptr<float>(),
        posteriors.data_ptr<float>(),
        grad_T.data_ptr<float>(),
        lengths_t.data_ptr<int>(),
        B, L1, L2, T
    );

    // Weight by grad_output (per-batch sum over the last two dimensions).
    // NOTE(review): this scales the lcs_backward grad_T by sum(grad_output);
    // confirm this is the intended temperature gradient rather than a
    // contraction of grad_output with the lcs_param_grad output.
    grad_T = (grad_T * grad_c.sum({1, 2}));

    return std::make_tuple(grad_scores, grad_T);
}

// ============================================================================
// Registration
// ============================================================================

// Operator registration is compiled only when USE_TORCH_LIBRARY is defined.
#ifdef USE_TORCH_LIBRARY

// Implementations registered for the d2p namespace under the CUDA dispatch key.
// NOTE(review): the operator schemas are presumably declared elsewhere with
// TORCH_LIBRARY(d2p, ...) — confirm, since no schema definition is visible here.
TORCH_LIBRARY_IMPL(d2p, CUDA, m) {
    m.impl("soft_lcs", soft_lcs_cuda);
    m.impl("soft_lcs_float", soft_lcs_float_cuda);
    m.impl("soft_lcs_with_grads", soft_lcs_with_grads_cuda);
    m.impl("soft_lcs_hvp", soft_lcs_hvp_cuda);
    m.impl("soft_lcs_param_jacobian", soft_lcs_param_jacobian_cuda);
    m.impl("soft_lcs_backward_full", soft_lcs_backward_full_cuda);
}

// The two entry points that go through SoftLCSCUDAFunction are additionally
// registered under the AutogradCUDA dispatch key, using the same functions as
// the CUDA key above.
TORCH_LIBRARY_IMPL(d2p, AutogradCUDA, m) {
    m.impl("soft_lcs", soft_lcs_cuda);
    m.impl("soft_lcs_float", soft_lcs_float_cuda);
}

#endif
