/**
 * @file torch_cuda.cpp
 * @brief Soft OSA CUDA PyTorch Bindings
 *
 * Provides torch.ops.d2p.soft_osa* operators for CUDA tensors.
 */

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <vector>
#include <tuple>

#include "kernels.cuh"

// ============================================================================
// Helper Functions
// ============================================================================

/**
 * @brief Build the int32 length table whose data pointer is passed to the OSA kernels.
 *
 * If @p lengths holds a tensor, it is moved to @p device, converted to int32
 * and made contiguous. Otherwise a new [B, 2] int32 tensor is created on
 * @p device with column 0 filled with @p L1 and column 1 filled with @p L2.
 *
 * @param lengths Optional caller-supplied lengths. Must contain exactly
 *                2 * B elements, matching the layout of the default table.
 * @param B       Batch size.
 * @param L1      Default value written to column 0 when @p lengths is absent.
 * @param L2      Default value written to column 1 when @p lengths is absent.
 * @param device  Device on which the returned tensor must live.
 * @return Contiguous int32 tensor on @p device.
 * @throws c10::Error if @p lengths is given with an unexpected element count.
 */
static torch::Tensor make_lengths_tensor(
    const c10::optional<torch::Tensor>& lengths,
    int B, int L1, int L2, torch::Device device
) {
    if (lengths.has_value()) {
        const torch::Tensor& user_lengths = lengths.value();
        // The kernels receive a raw int* and cannot see the tensor's size, so a
        // wrongly sized table would be read out of bounds on the device.
        TORCH_CHECK(
            user_lengths.numel() == static_cast<int64_t>(B) * 2,
            "lengths must contain 2 entries per batch element (expected ",
            static_cast<int64_t>(B) * 2, " elements, got ", user_lengths.numel(), ")");
        return user_lengths.to(device, torch::kInt32).contiguous();
    }
    auto lens = torch::empty({B, 2}, torch::dtype(torch::kInt32).device(device));
    lens.select(1, 0).fill_(L1);
    lens.select(1, 1).fill_(L2);
    return lens;
}

// ============================================================================
// Autograd Function
// ============================================================================

/**
 * @brief Custom autograd node that drives the Soft OSA CUDA kernels.
 *
 * forward() runs d2p::osa::osa_forward followed by d2p::osa::osa_backward and
 * returns {osa_score, posteriors}. backward() runs d2p::osa::osa_hvp on the
 * incoming gradient of `posteriors` and returns the result as the gradient of
 * `sub_costs`; every other input receives an undefined (empty) gradient.
 *
 * NOTE(review): backward() does not use the incoming gradient of `osa_score`
 * (grad_outputs[0]), so that gradient does not contribute to the value
 * returned for `sub_costs` — confirm that this is intended.
 */
class SoftOSACUDAFunction : public torch::autograd::Function<SoftOSACUDAFunction> {
public:
    /**
     * @brief Compute the Soft OSA score and posteriors.
     *
     * @param ctx         Autograd context used to stash tensors and scalars for backward().
     * @param sub_costs   3D float32 CUDA tensor [B, L1, L2].
     * @param trans_mask  3D float32 CUDA tensor with the same shape and device as @p sub_costs.
     * @param ins_cost    Insertion cost, narrowed to float for the kernels.
     * @param del_cost    Deletion cost, narrowed to float for the kernels.
     * @param trans_cost  Transposition cost, narrowed to float for the kernels.
     * @param temperature Temperature, narrowed to float for the kernels.
     * @param lengths     Optional length table; see make_lengths_tensor.
     * @return {osa_score of shape [B], posteriors shaped like sub_costs}.
     * @throws c10::Error if any of the input checks fails.
     */
    static torch::autograd::variable_list forward(
        torch::autograd::AutogradContext* ctx,
        torch::Tensor sub_costs,
        torch::Tensor trans_mask,
        double ins_cost,
        double del_cost,
        double trans_cost,
        double temperature,
        c10::optional<torch::Tensor> lengths
    ) {
        TORCH_CHECK(sub_costs.dim() == 3, "sub_costs must be 3D [B, L1, L2]");
        TORCH_CHECK(trans_mask.dim() == 3, "trans_mask must be 3D [B, L1, L2]");
        TORCH_CHECK(sub_costs.is_cuda(), "sub_costs must be CUDA tensor");
        TORCH_CHECK(trans_mask.is_cuda(), "trans_mask must be CUDA tensor");
        TORCH_CHECK(sub_costs.scalar_type() == torch::kFloat32, "sub_costs must be float32");
        TORCH_CHECK(trans_mask.scalar_type() == torch::kFloat32, "trans_mask must be float32");
        // Both tensors are indexed through raw pointers using sub_costs' sizes,
        // so a shape or device mismatch would become an invalid device access.
        TORCH_CHECK(trans_mask.sizes() == sub_costs.sizes(),
                    "trans_mask must have the same shape as sub_costs, got ",
                    trans_mask.sizes(), " vs ", sub_costs.sizes());
        TORCH_CHECK(trans_mask.device() == sub_costs.device(),
                    "trans_mask must be on the same device as sub_costs");

        const int B = sub_costs.size(0);
        const int L1 = sub_costs.size(1);
        const int L2 = sub_costs.size(2);
        const float T = static_cast<float>(temperature);
        const float ins = static_cast<float>(ins_cost);
        const float del = static_cast<float>(del_cost);
        const float trans = static_cast<float>(trans_cost);

        c10::cuda::CUDAGuard device_guard(sub_costs.device());

        auto sub_costs_c = sub_costs.contiguous();
        auto trans_mask_c = trans_mask.contiguous();
        auto lengths_t = make_lengths_tensor(lengths, B, L1, L2, sub_costs.device());

        // Allocate alpha and osa_score
        auto alpha = torch::empty({B, (L1+1)*(L2+1)}, sub_costs.options());
        auto osa_score = torch::empty({B}, sub_costs.options());

        // Forward pass
        d2p::osa::osa_forward(
            sub_costs_c.data_ptr<float>(),
            trans_mask_c.data_ptr<float>(),
            alpha.data_ptr<float>(),
            osa_score.data_ptr<float>(),
            lengths_t.data_ptr<int>(),
            ins, del, trans,
            B, L1, L2, T
        );

        // Backward pass to get posteriors. The grad_* buffers are required by
        // the kernel signature but are not returned from this function.
        auto beta = torch::empty_like(alpha);
        auto posteriors = torch::zeros_like(sub_costs_c);
        auto grad_T = torch::zeros({B}, sub_costs.options());
        auto grad_ins = torch::zeros({B}, sub_costs.options());
        auto grad_del = torch::zeros({B}, sub_costs.options());
        auto grad_trans = torch::zeros({B}, sub_costs.options());

        d2p::osa::osa_backward(
            alpha.data_ptr<float>(),
            sub_costs_c.data_ptr<float>(),
            trans_mask_c.data_ptr<float>(),
            osa_score.data_ptr<float>(),
            beta.data_ptr<float>(),
            posteriors.data_ptr<float>(),
            grad_T.data_ptr<float>(),
            grad_ins.data_ptr<float>(),
            grad_del.data_ptr<float>(),
            grad_trans.data_ptr<float>(),
            lengths_t.data_ptr<int>(),
            ins, del, trans,
            B, L1, L2, T
        );

        // Save for backward. A private copy of the score is kept so that
        // backward() can hand osa_hvp the forward result (as the standalone
        // HVP entry points do) without aliasing the tensor returned to the user.
        ctx->save_for_backward({sub_costs_c, trans_mask_c, alpha, lengths_t, osa_score.clone()});
        ctx->saved_data["temperature"] = temperature;
        ctx->saved_data["ins_cost"] = ins_cost;
        ctx->saved_data["del_cost"] = del_cost;
        ctx->saved_data["trans_cost"] = trans_cost;

        return {osa_score, posteriors};
    }

    /**
     * @brief Propagate the gradient of `posteriors` back to `sub_costs`.
     *
     * @param ctx          Context populated by forward().
     * @param grad_outputs Gradients for {osa_score, posteriors}; only index 1 is used.
     * @return Seven entries matching forward()'s inputs: the sub_costs gradient
     *         followed by six undefined tensors.
     */
    static torch::autograd::variable_list backward(
        torch::autograd::AutogradContext* ctx,
        torch::autograd::variable_list grad_outputs
    ) {
        auto saved = ctx->get_saved_variables();
        auto sub_costs = saved[0];
        auto trans_mask = saved[1];
        auto alpha = saved[2];
        auto lengths_t = saved[3];
        // Forward score (private copy). Previously an uninitialised buffer was
        // passed to osa_hvp here.
        // NOTE(review): confirm whether osa_hvp reads or only writes this buffer.
        auto osa_score = saved[4];

        double temperature = ctx->saved_data["temperature"].toDouble();
        double ins_cost = ctx->saved_data["ins_cost"].toDouble();
        double del_cost = ctx->saved_data["del_cost"].toDouble();
        double trans_cost_val = ctx->saved_data["trans_cost"].toDouble();

        float T = static_cast<float>(temperature);
        float ins = static_cast<float>(ins_cost);
        float del = static_cast<float>(del_cost);
        float trans = static_cast<float>(trans_cost_val);

        const int B = sub_costs.size(0);
        const int L1 = sub_costs.size(1);
        const int L2 = sub_costs.size(2);

        c10::cuda::CUDAGuard device_guard(sub_costs.device());

        // Fall back to zeros if no gradient was supplied for posteriors, so a
        // null pointer is never handed to the kernel.
        auto grad_posteriors = grad_outputs[1].defined()
            ? grad_outputs[1].contiguous()
            : torch::zeros_like(sub_costs);

        // HVP to compute gradient of loss w.r.t. sub_costs
        auto d_alpha = torch::empty_like(alpha);
        auto d_osa_score = torch::empty({B}, sub_costs.options());
        auto beta = torch::empty_like(alpha);
        auto d_beta = torch::empty_like(alpha);
        auto grad_sub_costs = torch::zeros_like(sub_costs);

        d2p::osa::osa_hvp(
            alpha.data_ptr<float>(),
            sub_costs.data_ptr<float>(),
            trans_mask.data_ptr<float>(),
            osa_score.data_ptr<float>(),
            grad_posteriors.data_ptr<float>(),
            d_alpha.data_ptr<float>(),
            d_osa_score.data_ptr<float>(),
            beta.data_ptr<float>(),
            d_beta.data_ptr<float>(),
            grad_sub_costs.data_ptr<float>(),
            lengths_t.data_ptr<int>(),
            ins, del, trans,
            B, L1, L2, T
        );

        // Return gradients: sub_costs, trans_mask, ins_cost, del_cost, trans_cost, temperature, lengths
        return {grad_sub_costs, torch::Tensor(), torch::Tensor(), torch::Tensor(), torch::Tensor(), torch::Tensor(), torch::Tensor()};
    }
};

// ============================================================================
// Operator Implementations
// ============================================================================

/**
 * @brief Tensor-argument entry point for the differentiable Soft OSA operator.
 *
 * Each scalar hyper-parameter arrives as a tensor; it is copied to the host and
 * read as a double before being forwarded to SoftOSACUDAFunction::apply.
 * A @p lengths tensor with zero elements is treated as "no lengths given".
 *
 * @return The tensor list produced by SoftOSACUDAFunction::apply.
 */
std::vector<torch::Tensor> soft_osa_cuda(
    torch::Tensor sub_costs,
    torch::Tensor trans_mask,
    torch::Tensor ins_cost_t,
    torch::Tensor del_cost_t,
    torch::Tensor trans_cost_t,
    torch::Tensor temperature_t,
    torch::Tensor lengths
) {
    // Reads a scalar tensor as a host-side double.
    const auto host_scalar = [](const torch::Tensor& scalar) {
        return scalar.cpu().item<double>();
    };

    double insertion = host_scalar(ins_cost_t);
    double deletion = host_scalar(del_cost_t);
    double transposition = host_scalar(trans_cost_t);
    double temp = host_scalar(temperature_t);

    // Stays empty unless the caller actually supplied length data.
    c10::optional<torch::Tensor> maybe_lengths;
    if (lengths.numel() > 0) {
        maybe_lengths = lengths;
    }

    return SoftOSACUDAFunction::apply(
        sub_costs, trans_mask, insertion, deletion, transposition, temp, maybe_lengths);
}

/**
 * @brief Scalar-argument entry point for the differentiable Soft OSA operator.
 *
 * Forwards every argument unchanged to SoftOSACUDAFunction::apply.
 *
 * @return The tensor list produced by SoftOSACUDAFunction::apply.
 */
std::vector<torch::Tensor> soft_osa_float_cuda(
    torch::Tensor sub_costs,
    torch::Tensor trans_mask,
    double ins_cost,
    double del_cost,
    double trans_cost,
    double temperature,
    c10::optional<torch::Tensor> lengths
) {
    return SoftOSACUDAFunction::apply(
        sub_costs, trans_mask,
        ins_cost, del_cost, trans_cost, temperature,
        lengths);
}

/**
 * @brief Run the Soft OSA forward and backward kernels without autograd.
 *
 * Calls d2p::osa::osa_forward and then d2p::osa::osa_backward, returning every
 * buffer the backward kernel fills in.
 *
 * @param sub_costs   3D float32 CUDA tensor [B, L1, L2].
 * @param trans_mask  3D float32 CUDA tensor with the same shape and device as @p sub_costs.
 * @param ins_cost    Insertion cost, narrowed to float for the kernels.
 * @param del_cost    Deletion cost, narrowed to float for the kernels.
 * @param trans_cost  Transposition cost, narrowed to float for the kernels.
 * @param temperature Temperature, narrowed to float for the kernels.
 * @param lengths     Optional length table; see make_lengths_tensor.
 * @return (osa_score [B], posteriors shaped like sub_costs, grad_T [B],
 *         grad_ins [B], grad_del [B], grad_trans [B]).
 * @throws c10::Error if any of the input checks fails.
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
soft_osa_with_grads_cuda(
    torch::Tensor sub_costs,
    torch::Tensor trans_mask,
    double ins_cost,
    double del_cost,
    double trans_cost,
    double temperature,
    c10::optional<torch::Tensor> lengths
) {
    TORCH_CHECK(sub_costs.dim() == 3, "sub_costs must be 3D [B, L1, L2]");
    TORCH_CHECK(trans_mask.dim() == 3, "trans_mask must be 3D [B, L1, L2]");
    TORCH_CHECK(sub_costs.is_cuda() && trans_mask.is_cuda(), "tensors must be CUDA");
    TORCH_CHECK(sub_costs.scalar_type() == torch::kFloat32, "sub_costs must be float32");
    TORCH_CHECK(trans_mask.scalar_type() == torch::kFloat32, "trans_mask must be float32");
    // The kernels index both buffers with sub_costs' sizes through raw
    // pointers, so a shape or device mismatch would be an invalid device access.
    TORCH_CHECK(trans_mask.sizes() == sub_costs.sizes(),
                "trans_mask must have the same shape as sub_costs, got ",
                trans_mask.sizes(), " vs ", sub_costs.sizes());
    TORCH_CHECK(trans_mask.device() == sub_costs.device(),
                "trans_mask must be on the same device as sub_costs");

    const int B = sub_costs.size(0);
    const int L1 = sub_costs.size(1);
    const int L2 = sub_costs.size(2);
    const float T = static_cast<float>(temperature);
    const float ins = static_cast<float>(ins_cost);
    const float del = static_cast<float>(del_cost);
    const float trans = static_cast<float>(trans_cost);

    c10::cuda::CUDAGuard device_guard(sub_costs.device());

    auto sub_costs_c = sub_costs.contiguous();
    auto trans_mask_c = trans_mask.contiguous();
    auto lengths_t = make_lengths_tensor(lengths, B, L1, L2, sub_costs.device());

    // Forward pass: fills alpha and osa_score.
    auto alpha = torch::empty({B, (L1+1)*(L2+1)}, sub_costs.options());
    auto osa_score = torch::empty({B}, sub_costs.options());

    d2p::osa::osa_forward(
        sub_costs_c.data_ptr<float>(),
        trans_mask_c.data_ptr<float>(),
        alpha.data_ptr<float>(),
        osa_score.data_ptr<float>(),
        lengths_t.data_ptr<int>(),
        ins, del, trans,
        B, L1, L2, T
    );

    // Backward pass: output buffers start zero-initialised.
    auto beta = torch::empty_like(alpha);
    auto posteriors = torch::zeros_like(sub_costs_c);
    auto grad_T = torch::zeros({B}, sub_costs.options());
    auto grad_ins_out = torch::zeros({B}, sub_costs.options());
    auto grad_del_out = torch::zeros({B}, sub_costs.options());
    auto grad_trans_out = torch::zeros({B}, sub_costs.options());

    d2p::osa::osa_backward(
        alpha.data_ptr<float>(),
        sub_costs_c.data_ptr<float>(),
        trans_mask_c.data_ptr<float>(),
        osa_score.data_ptr<float>(),
        beta.data_ptr<float>(),
        posteriors.data_ptr<float>(),
        grad_T.data_ptr<float>(),
        grad_ins_out.data_ptr<float>(),
        grad_del_out.data_ptr<float>(),
        grad_trans_out.data_ptr<float>(),
        lengths_t.data_ptr<int>(),
        ins, del, trans,
        B, L1, L2, T
    );

    return std::make_tuple(osa_score, posteriors, grad_T, grad_ins_out, grad_del_out, grad_trans_out);
}

/**
 * @brief Run the Soft OSA forward kernel followed by the HVP kernel for @p V.
 *
 * @param sub_costs   3D float32 CUDA tensor [B, L1, L2].
 * @param trans_mask  3D float32 CUDA tensor with the same shape and device as @p sub_costs.
 * @param V           3D float32 CUDA tensor with the same shape and device as @p sub_costs;
 *                    passed to d2p::osa::osa_hvp as the input vector.
 * @param ins_cost    Insertion cost, narrowed to float for the kernels.
 * @param del_cost    Deletion cost, narrowed to float for the kernels.
 * @param trans_cost  Transposition cost, narrowed to float for the kernels.
 * @param temperature Temperature, narrowed to float for the kernels.
 * @param lengths     Optional length table; see make_lengths_tensor.
 * @return Tensor shaped like sub_costs holding the buffer written by osa_hvp.
 * @throws c10::Error if any of the input checks fails.
 */
torch::Tensor soft_osa_hvp_cuda(
    torch::Tensor sub_costs,
    torch::Tensor trans_mask,
    torch::Tensor V,
    double ins_cost,
    double del_cost,
    double trans_cost,
    double temperature,
    c10::optional<torch::Tensor> lengths
) {
    TORCH_CHECK(sub_costs.dim() == 3, "sub_costs must be 3D [B, L1, L2]");
    TORCH_CHECK(trans_mask.dim() == 3, "trans_mask must be 3D [B, L1, L2]");
    TORCH_CHECK(V.dim() == 3, "V must be 3D [B, L1, L2]");
    TORCH_CHECK(sub_costs.is_cuda() && trans_mask.is_cuda() && V.is_cuda(), "tensors must be CUDA");
    TORCH_CHECK(sub_costs.scalar_type() == torch::kFloat32, "sub_costs must be float32");
    TORCH_CHECK(trans_mask.scalar_type() == torch::kFloat32, "trans_mask must be float32");
    TORCH_CHECK(V.scalar_type() == torch::kFloat32, "V must be float32");
    // All three buffers are indexed with sub_costs' sizes through raw pointers,
    // so a shape or device mismatch would be an invalid device access.
    TORCH_CHECK(trans_mask.sizes() == sub_costs.sizes(),
                "trans_mask must have the same shape as sub_costs, got ",
                trans_mask.sizes(), " vs ", sub_costs.sizes());
    TORCH_CHECK(V.sizes() == sub_costs.sizes(),
                "V must have the same shape as sub_costs, got ",
                V.sizes(), " vs ", sub_costs.sizes());
    TORCH_CHECK(trans_mask.device() == sub_costs.device() && V.device() == sub_costs.device(),
                "trans_mask and V must be on the same device as sub_costs");

    const int B = sub_costs.size(0);
    const int L1 = sub_costs.size(1);
    const int L2 = sub_costs.size(2);
    const float T = static_cast<float>(temperature);
    const float ins = static_cast<float>(ins_cost);
    const float del = static_cast<float>(del_cost);
    const float trans = static_cast<float>(trans_cost);

    c10::cuda::CUDAGuard device_guard(sub_costs.device());

    auto sub_costs_c = sub_costs.contiguous();
    auto trans_mask_c = trans_mask.contiguous();
    auto V_c = V.contiguous();
    auto lengths_t = make_lengths_tensor(lengths, B, L1, L2, sub_costs.device());

    // Forward pass
    auto alpha = torch::empty({B, (L1+1)*(L2+1)}, sub_costs.options());
    auto osa_score = torch::empty({B}, sub_costs.options());

    d2p::osa::osa_forward(
        sub_costs_c.data_ptr<float>(),
        trans_mask_c.data_ptr<float>(),
        alpha.data_ptr<float>(),
        osa_score.data_ptr<float>(),
        lengths_t.data_ptr<int>(),
        ins, del, trans,
        B, L1, L2, T
    );

    // HVP: scratch buffers are uninitialised, the result buffer starts at zero.
    auto d_alpha = torch::empty_like(alpha);
    auto d_osa_score = torch::empty({B}, sub_costs.options());
    auto beta = torch::empty_like(alpha);
    auto d_beta = torch::empty_like(alpha);
    auto H_scores = torch::zeros_like(sub_costs_c);

    d2p::osa::osa_hvp(
        alpha.data_ptr<float>(),
        sub_costs_c.data_ptr<float>(),
        trans_mask_c.data_ptr<float>(),
        osa_score.data_ptr<float>(),
        V_c.data_ptr<float>(),
        d_alpha.data_ptr<float>(),
        d_osa_score.data_ptr<float>(),
        beta.data_ptr<float>(),
        d_beta.data_ptr<float>(),
        H_scores.data_ptr<float>(),
        lengths_t.data_ptr<int>(),
        ins, del, trans,
        B, L1, L2, T
    );

    return H_scores;
}

/**
 * @brief Compute gradients for sub_costs and the scalar parameters in one call.
 *
 * Runs d2p::osa::osa_forward, then d2p::osa::osa_hvp with @p grad_output as the
 * input vector (giving grad_sub_costs), then d2p::osa::osa_backward to obtain
 * per-batch grad_T / grad_ins / grad_del / grad_trans. Those four are finally
 * multiplied by the per-batch sum of @p grad_output over dimensions 1 and 2.
 *
 * @param sub_costs   3D float32 CUDA tensor [B, L1, L2].
 * @param trans_mask  3D float32 CUDA tensor with the same shape and device as @p sub_costs.
 * @param grad_output 3D float32 CUDA tensor with the same shape and device as @p sub_costs.
 * @param ins_cost    Insertion cost, narrowed to float for the kernels.
 * @param del_cost    Deletion cost, narrowed to float for the kernels.
 * @param trans_cost  Transposition cost, narrowed to float for the kernels.
 * @param temperature Temperature, narrowed to float for the kernels.
 * @param lengths     Optional length table; see make_lengths_tensor.
 * @return (grad_sub_costs shaped like sub_costs, grad_T [B], grad_ins [B],
 *         grad_del [B], grad_trans [B]).
 * @throws c10::Error if any of the input checks fails.
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
soft_osa_backward_full_cuda(
    torch::Tensor sub_costs,
    torch::Tensor trans_mask,
    torch::Tensor grad_output,
    double ins_cost,
    double del_cost,
    double trans_cost,
    double temperature,
    c10::optional<torch::Tensor> lengths
) {
    TORCH_CHECK(sub_costs.dim() == 3 && trans_mask.dim() == 3 && grad_output.dim() == 3, "tensors must be 3D");
    TORCH_CHECK(sub_costs.is_cuda() && trans_mask.is_cuda() && grad_output.is_cuda(), "tensors must be CUDA");
    TORCH_CHECK(sub_costs.scalar_type() == torch::kFloat32, "sub_costs must be float32");
    TORCH_CHECK(trans_mask.scalar_type() == torch::kFloat32, "trans_mask must be float32");
    TORCH_CHECK(grad_output.scalar_type() == torch::kFloat32, "grad_output must be float32");
    // All three buffers are indexed with sub_costs' sizes through raw pointers,
    // so a shape or device mismatch would be an invalid device access.
    TORCH_CHECK(trans_mask.sizes() == sub_costs.sizes(),
                "trans_mask must have the same shape as sub_costs, got ",
                trans_mask.sizes(), " vs ", sub_costs.sizes());
    TORCH_CHECK(grad_output.sizes() == sub_costs.sizes(),
                "grad_output must have the same shape as sub_costs, got ",
                grad_output.sizes(), " vs ", sub_costs.sizes());
    TORCH_CHECK(trans_mask.device() == sub_costs.device() && grad_output.device() == sub_costs.device(),
                "trans_mask and grad_output must be on the same device as sub_costs");

    const int B = sub_costs.size(0);
    const int L1 = sub_costs.size(1);
    const int L2 = sub_costs.size(2);
    const float T = static_cast<float>(temperature);
    const float ins = static_cast<float>(ins_cost);
    const float del = static_cast<float>(del_cost);
    const float trans = static_cast<float>(trans_cost);

    c10::cuda::CUDAGuard device_guard(sub_costs.device());

    auto sub_costs_c = sub_costs.contiguous();
    auto trans_mask_c = trans_mask.contiguous();
    auto grad_c = grad_output.contiguous();
    auto lengths_t = make_lengths_tensor(lengths, B, L1, L2, sub_costs.device());

    // Forward pass
    auto alpha = torch::empty({B, (L1+1)*(L2+1)}, sub_costs.options());
    auto osa_score = torch::empty({B}, sub_costs.options());

    d2p::osa::osa_forward(
        sub_costs_c.data_ptr<float>(),
        trans_mask_c.data_ptr<float>(),
        alpha.data_ptr<float>(),
        osa_score.data_ptr<float>(),
        lengths_t.data_ptr<int>(),
        ins, del, trans,
        B, L1, L2, T
    );

    // HVP for grad_sub_costs
    auto d_alpha = torch::empty_like(alpha);
    auto d_osa_score = torch::empty({B}, sub_costs.options());
    auto beta = torch::empty_like(alpha);
    auto d_beta = torch::empty_like(alpha);
    auto grad_sub_costs = torch::zeros_like(sub_costs_c);

    d2p::osa::osa_hvp(
        alpha.data_ptr<float>(),
        sub_costs_c.data_ptr<float>(),
        trans_mask_c.data_ptr<float>(),
        osa_score.data_ptr<float>(),
        grad_c.data_ptr<float>(),
        d_alpha.data_ptr<float>(),
        d_osa_score.data_ptr<float>(),
        beta.data_ptr<float>(),
        d_beta.data_ptr<float>(),
        grad_sub_costs.data_ptr<float>(),
        lengths_t.data_ptr<int>(),
        ins, del, trans,
        B, L1, L2, T
    );

    // Get cost gradients from backward (separate beta buffer from the HVP call).
    auto beta2 = torch::empty_like(alpha);
    auto posteriors = torch::zeros_like(sub_costs_c);
    auto grad_T = torch::zeros({B}, sub_costs.options());
    auto grad_ins_out = torch::zeros({B}, sub_costs.options());
    auto grad_del_out = torch::zeros({B}, sub_costs.options());
    auto grad_trans_out = torch::zeros({B}, sub_costs.options());

    d2p::osa::osa_backward(
        alpha.data_ptr<float>(),
        sub_costs_c.data_ptr<float>(),
        trans_mask_c.data_ptr<float>(),
        osa_score.data_ptr<float>(),
        beta2.data_ptr<float>(),
        posteriors.data_ptr<float>(),
        grad_T.data_ptr<float>(),
        grad_ins_out.data_ptr<float>(),
        grad_del_out.data_ptr<float>(),
        grad_trans_out.data_ptr<float>(),
        lengths_t.data_ptr<int>(),
        ins, del, trans,
        B, L1, L2, T
    );

    // Weight by grad_output. The per-batch reduction is identical for all four
    // gradients, so it is computed once instead of launching it four times.
    const auto grad_weight = grad_c.sum({1, 2});
    grad_T = grad_T * grad_weight;
    grad_ins_out = grad_ins_out * grad_weight;
    grad_del_out = grad_del_out * grad_weight;
    grad_trans_out = grad_trans_out * grad_weight;

    return std::make_tuple(grad_sub_costs, grad_T, grad_ins_out, grad_del_out, grad_trans_out);
}

// ============================================================================
// Registration
// ============================================================================

// Operator registration is compiled only when USE_TORCH_LIBRARY is defined.
#ifdef USE_TORCH_LIBRARY

// Attach the CUDA-dispatch-key implementation of each of the five d2p Soft OSA
// operators.
// NOTE(review): m.impl only binds implementations; the operator schemas are
// presumably declared in a TORCH_LIBRARY(d2p, ...) block elsewhere — confirm.
TORCH_LIBRARY_IMPL(d2p, CUDA, m) {
    m.impl("soft_osa", soft_osa_cuda);
    m.impl("soft_osa_float", soft_osa_float_cuda);
    m.impl("soft_osa_with_grads", soft_osa_with_grads_cuda);
    m.impl("soft_osa_hvp", soft_osa_hvp_cuda);
    m.impl("soft_osa_backward_full", soft_osa_backward_full_cuda);
}

// The two entry points that go through SoftOSACUDAFunction::apply are also
// bound to the AutogradCUDA key, using the same functions as the CUDA key.
// The remaining three operators get no AutogradCUDA implementation here.
TORCH_LIBRARY_IMPL(d2p, AutogradCUDA, m) {
    m.impl("soft_osa", soft_osa_cuda);
    m.impl("soft_osa_float", soft_osa_float_cuda);
}

#endif
