/**
 * @file torch_cpu.cpp
 * @brief Soft OSA CPU PyTorch Bindings
 *
 * Provides torch.ops.d2p.soft_osa* operators for CPU tensors.
 */

#include <torch/extension.h>
#include <vector>
#include <tuple>

#include "kernels_cpu.h"

// ============================================================================
// Helper Macros
// ============================================================================

// Argument-validation helpers built on TORCH_CHECK; on failure the error
// message contains the stringified argument name.
//
// CHECK_CPU requires the tensor to live on the CPU device. A plain
// `!is_cuda()` test would also accept other non-CPU tensors (e.g. MPS), whose
// memory the CPU kernels below cannot read through data_ptr().
#define CHECK_CPU(x) TORCH_CHECK((x).device().is_cpu(), #x " must be a CPU tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Wrapped in do { ... } while (0) so both checks form a single statement and
// the macro stays correct as the body of an unbraced if/else.
#define CHECK_INPUT_CPU(x)   \
    do {                     \
        CHECK_CPU(x);        \
        CHECK_CONTIGUOUS(x); \
    } while (0)

// ============================================================================
// Helper Functions
// ============================================================================

/// Builds the lengths tensor used when a caller omits `lengths`.
///
/// @param B  batch size (number of rows).
/// @param L1 value written to column 0 of every row.
/// @param L2 value written to column 1 of every row.
/// @return a freshly allocated int32 tensor of shape (B, 2) whose every row is
///         (L1, L2), i.e. each batch element uses the full padded extent.
static torch::Tensor make_default_lengths(int B, int L1, int L2) {
    torch::Tensor out = torch::empty({B, 2}, torch::TensorOptions().dtype(torch::kInt32));
    out.select(/*dim=*/1, /*index=*/0).fill_(L1);
    out.select(/*dim=*/1, /*index=*/1).fill_(L2);
    return out;
}

// ============================================================================
// Autograd Function
// ============================================================================

/**
 * Autograd function for Soft OSA on CPU.
 *
 * forward(sub_costs, trans_mask, ins_cost, del_cost, trans_cost, temperature,
 *         lengths) -> {osa_score (B,), posteriors (B, L1, L2)}
 *
 *   - sub_costs, trans_mask: contiguous float32 CPU tensors of identical shape
 *     (B, L1, L2).
 *   - ins_cost, del_cost, trans_cost, temperature: one-element tensors; their
 *     values are read once with item<float>().
 *   - lengths: contiguous int32 CPU tensor of shape (B, 2) holding per-batch
 *     (len1, len2) within [0, L1] x [0, L2].
 *
 * Both outputs are produced eagerly in forward (osa_forward_cpu followed by
 * osa_backward_cpu).
 *
 * backward:
 *   - osa_score path: sub_costs receives grad_score * posteriors, and each of
 *     temperature / ins_cost / del_cost / trans_cost receives
 *     sum_b grad_score[b] * d_param[b], using the per-batch buffers that
 *     osa_backward_cpu filled during forward.
 *   - posteriors path: sub_costs receives the Hessian-vector product computed
 *     by osa_hvp_cpu; this implementation adds no parameter terms from it.
 *   - trans_mask and lengths receive no gradient.
 */
class SoftOSACPUFunction : public torch::autograd::Function<SoftOSACPUFunction> {
public:
    static torch::autograd::tensor_list forward(
        torch::autograd::AutogradContext *ctx,
        torch::Tensor sub_costs,
        torch::Tensor trans_mask,
        torch::Tensor ins_cost_t,
        torch::Tensor del_cost_t,
        torch::Tensor trans_cost_t,
        torch::Tensor temperature,
        torch::Tensor lengths
    ) {
        CHECK_INPUT_CPU(sub_costs);
        CHECK_INPUT_CPU(trans_mask);
        TORCH_CHECK(sub_costs.dim() == 3, "sub_costs must be 3D (B, L1, L2)");
        TORCH_CHECK(trans_mask.dim() == 3, "trans_mask must be 3D (B, L1, L2)");
        // trans_mask is handed to the kernels together with sub_costs' extents
        // (B, max_L1, max_L2), so a differently shaped mask would be indexed
        // out of bounds.
        TORCH_CHECK(trans_mask.sizes() == sub_costs.sizes(),
                    "trans_mask must have the same shape as sub_costs");
        TORCH_CHECK(sub_costs.dtype() == torch::kFloat32, "sub_costs must be float32");
        TORCH_CHECK(trans_mask.dtype() == torch::kFloat32, "trans_mask must be float32");
        TORCH_CHECK(temperature.numel() == 1, "temperature must be a scalar tensor");
        TORCH_CHECK(ins_cost_t.numel() == 1, "ins_cost must be a scalar tensor");
        TORCH_CHECK(del_cost_t.numel() == 1, "del_cost must be a scalar tensor");
        TORCH_CHECK(trans_cost_t.numel() == 1, "trans_cost must be a scalar tensor");

        int B = sub_costs.size(0);
        int max_L1 = sub_costs.size(1);
        int max_L2 = sub_costs.size(2);
        int alpha_size = (max_L1 + 1) * (max_L2 + 1);

        CHECK_INPUT_CPU(lengths);
        TORCH_CHECK(lengths.dim() == 2 && lengths.size(0) == B && lengths.size(1) == 2);
        TORCH_CHECK(lengths.dtype() == torch::kInt32);
        check_lengths_in_range(lengths, max_L1, max_L2);

        float temp_val = temperature.item<float>();
        float ins_cost = ins_cost_t.item<float>();
        float del_cost = del_cost_t.item<float>();
        float trans_cost = trans_cost_t.item<float>();

        auto options = sub_costs.options();
        torch::Tensor alpha = torch::zeros({B, alpha_size}, options);
        torch::Tensor osa_score = torch::zeros({B}, options);
        torch::Tensor beta = torch::zeros({B, alpha_size}, options);
        torch::Tensor posteriors = torch::zeros({B, max_L1, max_L2}, options);
        // Per-batch parameter derivatives filled by osa_backward_cpu. They are
        // treated as d(osa_score[b])/d(param), exactly as grad_T was already
        // used for the temperature gradient.
        // NOTE(review): confirm the same semantics for the three cost buffers.
        torch::Tensor grad_T = torch::zeros({B}, options);
        torch::Tensor grad_ins = torch::zeros({B}, options);
        torch::Tensor grad_del = torch::zeros({B}, options);
        torch::Tensor grad_trans = torch::zeros({B}, options);

        d2p::osa::cpu::osa_forward_cpu(
            sub_costs.data_ptr<float>(),
            trans_mask.data_ptr<float>(),
            alpha.data_ptr<float>(),
            osa_score.data_ptr<float>(),
            lengths.data_ptr<int>(),
            ins_cost, del_cost, trans_cost,
            B, max_L1, max_L2, temp_val
        );

        d2p::osa::cpu::osa_backward_cpu(
            alpha.data_ptr<float>(),
            sub_costs.data_ptr<float>(),
            trans_mask.data_ptr<float>(),
            osa_score.data_ptr<float>(),
            beta.data_ptr<float>(),
            posteriors.data_ptr<float>(),
            grad_T.data_ptr<float>(),
            grad_ins.data_ptr<float>(),
            grad_del.data_ptr<float>(),
            grad_trans.data_ptr<float>(),
            lengths.data_ptr<int>(),
            ins_cost, del_cost, trans_cost,
            B, max_L1, max_L2, temp_val
        );

        // Inputs and returned outputs are saved as clones, so later in-place
        // edits by the caller neither corrupt the saved state nor trip
        // autograd's version-counter check. alpha and the derivative buffers
        // are private to this call and are saved without copying.
        ctx->save_for_backward({
            sub_costs.clone(),   // 0
            trans_mask.clone(),  // 1
            alpha,               // 2
            osa_score.clone(),   // 3
            lengths.clone(),     // 4
            grad_T,              // 5
            posteriors.clone(),  // 6
            grad_ins,            // 7
            grad_del,            // 8
            grad_trans           // 9
        });
        ctx->saved_data["temperature"] = temp_val;
        ctx->saved_data["ins_cost"] = ins_cost;
        ctx->saved_data["del_cost"] = del_cost;
        ctx->saved_data["trans_cost"] = trans_cost;
        // Shapes of the one-element parameter inputs ((), (1,), (1, 1), ...),
        // so backward() can return gradients with exactly matching shapes.
        ctx->saved_data["temperature_shape"] = temperature.sizes().vec();
        ctx->saved_data["ins_cost_shape"] = ins_cost_t.sizes().vec();
        ctx->saved_data["del_cost_shape"] = del_cost_t.sizes().vec();
        ctx->saved_data["trans_cost_shape"] = trans_cost_t.sizes().vec();

        return {osa_score, posteriors};
    }

    static torch::autograd::tensor_list backward(
        torch::autograd::AutogradContext *ctx,
        torch::autograd::tensor_list grad_outputs
    ) {
        auto saved = ctx->get_saved_variables();
        torch::Tensor sub_costs = saved[0];
        torch::Tensor trans_mask = saved[1];
        torch::Tensor alpha = saved[2];
        torch::Tensor osa_score = saved[3];
        torch::Tensor lengths = saved[4];
        torch::Tensor grad_T_fwd = saved[5];
        torch::Tensor posteriors_fwd = saved[6];
        torch::Tensor grad_ins_fwd = saved[7];
        torch::Tensor grad_del_fwd = saved[8];
        torch::Tensor grad_trans_fwd = saved[9];

        float temp_val = static_cast<float>(ctx->saved_data["temperature"].toDouble());
        float ins_cost = static_cast<float>(ctx->saved_data["ins_cost"].toDouble());
        float del_cost = static_cast<float>(ctx->saved_data["del_cost"].toDouble());
        float trans_cost = static_cast<float>(ctx->saved_data["trans_cost"].toDouble());
        const std::vector<int64_t> temperature_shape = ctx->saved_data["temperature_shape"].toIntVector();
        const std::vector<int64_t> ins_shape = ctx->saved_data["ins_cost_shape"].toIntVector();
        const std::vector<int64_t> del_shape = ctx->saved_data["del_cost_shape"].toIntVector();
        const std::vector<int64_t> trans_shape = ctx->saved_data["trans_cost_shape"].toIntVector();

        int B = sub_costs.size(0);
        int max_L1 = sub_costs.size(1);
        int max_L2 = sub_costs.size(2);
        int alpha_size = (max_L1 + 1) * (max_L2 + 1);

        auto options = sub_costs.options();

        torch::Tensor grad_osa_score = grad_outputs[0];
        torch::Tensor grad_posteriors = grad_outputs[1];

        torch::Tensor grad_sub_costs = torch::zeros({B, max_L1, max_L2}, options);
        torch::Tensor grad_temperature = torch::zeros(temperature_shape, options);
        torch::Tensor grad_ins = torch::zeros(ins_shape, options);
        torch::Tensor grad_del = torch::zeros(del_shape, options);
        torch::Tensor grad_trans = torch::zeros(trans_shape, options);

        // Gradient from osa_score path. forward() already ran
        // osa_backward_cpu; its posteriors serve as d(osa_score)/d(sub_costs)
        // here, so they are reused instead of re-running the backward DP.
        if (grad_osa_score.defined() && grad_osa_score.numel() > 0) {
            grad_sub_costs += grad_osa_score.reshape({B, 1, 1}) * posteriors_fwd;
            grad_temperature = param_grad(grad_osa_score, grad_T_fwd, temperature_shape);
            grad_ins = param_grad(grad_osa_score, grad_ins_fwd, ins_shape);
            grad_del = param_grad(grad_osa_score, grad_del_fwd, del_shape);
            grad_trans = param_grad(grad_osa_score, grad_trans_fwd, trans_shape);
        }

        // Gradient from posteriors path (HVP). The kernel reads the incoming
        // gradient as a dense float32 buffer, hence contiguous() + to().
        if (grad_posteriors.defined() && grad_posteriors.numel() > 0) {
            grad_posteriors = grad_posteriors.contiguous().to(torch::kFloat32);

            torch::Tensor d_alpha = torch::zeros({B, alpha_size}, options);
            torch::Tensor d_osa_score = torch::zeros({B}, options);
            torch::Tensor beta = torch::zeros({B, alpha_size}, options);
            torch::Tensor d_beta = torch::zeros({B, alpha_size}, options);
            torch::Tensor hvp_grad_sub_costs = torch::zeros({B, max_L1, max_L2}, options);

            d2p::osa::cpu::osa_hvp_cpu(
                alpha.data_ptr<float>(),
                sub_costs.data_ptr<float>(),
                trans_mask.data_ptr<float>(),
                osa_score.data_ptr<float>(),
                grad_posteriors.data_ptr<float>(),
                d_alpha.data_ptr<float>(),
                d_osa_score.data_ptr<float>(),
                beta.data_ptr<float>(),
                d_beta.data_ptr<float>(),
                hvp_grad_sub_costs.data_ptr<float>(),
                lengths.data_ptr<int>(),
                ins_cost, del_cost, trans_cost,
                B, max_L1, max_L2, temp_val
            );

            grad_sub_costs += hvp_grad_sub_costs;
        }

        // Return gradients in forward's input order: sub_costs, trans_mask,
        // ins_cost, del_cost, trans_cost, temperature, lengths.
        return {grad_sub_costs, torch::Tensor(), grad_ins, grad_del, grad_trans, grad_temperature, torch::Tensor()};
    }

private:
    // Rejects per-batch lengths outside [0, max_L1] x [0, max_L2]. The lengths
    // go to kernels whose buffers are sized from max_L1/max_L2, so larger
    // values would presumably index past those buffers.
    static void check_lengths_in_range(const torch::Tensor& lengths, int max_L1, int max_L2) {
        auto acc = lengths.accessor<int32_t, 2>();
        for (int64_t b = 0; b < lengths.size(0); ++b) {
            TORCH_CHECK(acc[b][0] >= 0 && acc[b][0] <= max_L1 &&
                        acc[b][1] >= 0 && acc[b][1] <= max_L2,
                        "lengths[", b, "] = (", acc[b][0], ", ", acc[b][1],
                        ") must lie within [0, ", max_L1, "] x [0, ", max_L2, "]");
        }
    }

    // sum_b grad_score[b] * d_param[b], reshaped to the parameter's own shape.
    static torch::Tensor param_grad(const torch::Tensor& grad_score,
                                    const torch::Tensor& d_param,
                                    const std::vector<int64_t>& shape) {
        return (grad_score * d_param).sum().reshape(shape);
    }
};

// ============================================================================
// Operator Implementations
// ============================================================================

/**
 * Autograd-aware Soft OSA with tensor-valued costs and temperature.
 *
 * Forwards every argument unchanged to SoftOSACPUFunction::apply, which
 * performs all validation.
 *
 * @return {osa_score (B,), posteriors (B, L1, L2)}
 */
std::vector<torch::Tensor> soft_osa_cpu(
    torch::Tensor sub_costs,
    torch::Tensor trans_mask,
    torch::Tensor ins_cost,
    torch::Tensor del_cost,
    torch::Tensor trans_cost,
    torch::Tensor temperature,
    torch::Tensor lengths
) {
    torch::autograd::tensor_list outputs = SoftOSACPUFunction::apply(
        sub_costs, trans_mask,
        ins_cost, del_cost, trans_cost,
        temperature, lengths);
    return outputs;
}

/**
 * Soft OSA with plain floating-point costs and temperature.
 *
 * Wraps each scalar in a one-element tensor (created with sub_costs' options)
 * and routes through SoftOSACPUFunction, so the result supports autograd
 * w.r.t. sub_costs. When `lengths_opt` is absent, or holds an undefined
 * tensor, every batch element uses the full (L1, L2) extent.
 *
 * @return {osa_score (B,), posteriors (B, L1, L2)}
 */
std::vector<torch::Tensor> soft_osa_cpu_float(
    torch::Tensor sub_costs,
    torch::Tensor trans_mask,
    double ins_cost,
    double del_cost,
    double trans_cost,
    double temperature,
    c10::optional<torch::Tensor> lengths_opt
) {
    // Validate the rank before reading sizes, so a bad input produces a
    // descriptive error instead of a dimension-out-of-range failure from size(2).
    TORCH_CHECK(sub_costs.dim() == 3, "sub_costs must be 3D (B, L1, L2)");

    int B = sub_costs.size(0);
    int L1 = sub_costs.size(1);
    int L2 = sub_costs.size(2);

    auto options = sub_costs.options();
    torch::Tensor temp_t = torch::tensor({static_cast<float>(temperature)}, options);
    torch::Tensor ins_t = torch::tensor({static_cast<float>(ins_cost)}, options);
    torch::Tensor del_t = torch::tensor({static_cast<float>(del_cost)}, options);
    torch::Tensor trans_t = torch::tensor({static_cast<float>(trans_cost)}, options);

    const bool has_lengths = lengths_opt.has_value() && lengths_opt->defined();
    torch::Tensor lengths = has_lengths ? *lengths_opt : make_default_lengths(B, L1, L2);

    return SoftOSACPUFunction::apply(sub_costs, trans_mask, ins_t, del_t, trans_t, temp_t, lengths);
}

/**
 * Non-autograd Soft OSA on CPU.
 *
 * Runs osa_forward_cpu followed by osa_backward_cpu and returns the outputs
 * together with the per-batch derivative buffers the backward kernel fills.
 *
 * @param sub_costs   contiguous float32 CPU tensor, shape (B, L1, L2).
 * @param trans_mask  contiguous float32 CPU tensor, same shape as sub_costs.
 * @param ins_cost, del_cost, trans_cost, temperature  scalars (cast to float).
 * @param lengths_opt optional contiguous int32 CPU tensor (B, 2) of per-batch
 *                    (len1, len2) in [0, L1] x [0, L2]; defaults to (L1, L2).
 * @return (osa_score (B,), posteriors (B, L1, L2), grad_T (B,), grad_ins (B,),
 *          grad_del (B,), grad_trans (B,)).
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
soft_osa_cpu_with_grads(
    torch::Tensor sub_costs,
    torch::Tensor trans_mask,
    double ins_cost,
    double del_cost,
    double trans_cost,
    double temperature,
    c10::optional<torch::Tensor> lengths_opt
) {
    CHECK_INPUT_CPU(sub_costs);
    CHECK_INPUT_CPU(trans_mask);
    TORCH_CHECK(sub_costs.dim() == 3, "sub_costs must be 3D (B, L1, L2)");
    TORCH_CHECK(trans_mask.dim() == 3, "trans_mask must be 3D (B, L1, L2)");
    // The kernels index trans_mask with sub_costs' extents; a mismatched mask
    // would be read out of bounds.
    TORCH_CHECK(trans_mask.sizes() == sub_costs.sizes(),
                "trans_mask must have the same shape as sub_costs");
    TORCH_CHECK(sub_costs.dtype() == torch::kFloat32, "sub_costs must be float32");
    TORCH_CHECK(trans_mask.dtype() == torch::kFloat32, "trans_mask must be float32");

    int B = sub_costs.size(0);
    int max_L1 = sub_costs.size(1);
    int max_L2 = sub_costs.size(2);
    int alpha_size = (max_L1 + 1) * (max_L2 + 1);

    auto options = sub_costs.options();
    torch::Tensor lengths = lengths_opt.has_value() ? lengths_opt.value()
                                                    : make_default_lengths(B, max_L1, max_L2);

    CHECK_INPUT_CPU(lengths);
    TORCH_CHECK(lengths.dim() == 2 && lengths.size(0) == B && lengths.size(1) == 2);
    TORCH_CHECK(lengths.dtype() == torch::kInt32);
    // Per-batch lengths must fit inside the padded extents that size the
    // kernel buffers.
    {
        auto len_acc = lengths.accessor<int32_t, 2>();
        for (int b = 0; b < B; ++b) {
            TORCH_CHECK(len_acc[b][0] >= 0 && len_acc[b][0] <= max_L1 &&
                        len_acc[b][1] >= 0 && len_acc[b][1] <= max_L2,
                        "lengths[", b, "] = (", len_acc[b][0], ", ", len_acc[b][1],
                        ") must lie within [0, ", max_L1, "] x [0, ", max_L2, "]");
        }
    }

    float temp_val = static_cast<float>(temperature);
    float ins = static_cast<float>(ins_cost);
    float del = static_cast<float>(del_cost);
    float trans = static_cast<float>(trans_cost);

    torch::Tensor alpha = torch::zeros({B, alpha_size}, options);
    torch::Tensor osa_score = torch::zeros({B}, options);
    torch::Tensor beta = torch::zeros({B, alpha_size}, options);
    torch::Tensor posteriors = torch::zeros({B, max_L1, max_L2}, options);
    torch::Tensor grad_T = torch::zeros({B}, options);
    torch::Tensor grad_ins_out = torch::zeros({B}, options);
    torch::Tensor grad_del_out = torch::zeros({B}, options);
    torch::Tensor grad_trans_out = torch::zeros({B}, options);

    d2p::osa::cpu::osa_forward_cpu(
        sub_costs.data_ptr<float>(),
        trans_mask.data_ptr<float>(),
        alpha.data_ptr<float>(),
        osa_score.data_ptr<float>(),
        lengths.data_ptr<int>(),
        ins, del, trans,
        B, max_L1, max_L2, temp_val
    );

    d2p::osa::cpu::osa_backward_cpu(
        alpha.data_ptr<float>(),
        sub_costs.data_ptr<float>(),
        trans_mask.data_ptr<float>(),
        osa_score.data_ptr<float>(),
        beta.data_ptr<float>(),
        posteriors.data_ptr<float>(),
        grad_T.data_ptr<float>(),
        grad_ins_out.data_ptr<float>(),
        grad_del_out.data_ptr<float>(),
        grad_trans_out.data_ptr<float>(),
        lengths.data_ptr<int>(),
        ins, del, trans,
        B, max_L1, max_L2, temp_val
    );

    return std::make_tuple(osa_score, posteriors, grad_T, grad_ins_out, grad_del_out, grad_trans_out);
}

/**
 * Hessian-vector product of Soft OSA on CPU.
 *
 * Runs osa_forward_cpu, then osa_hvp_cpu with `tangent` as the vector, and
 * returns the kernel's (B, L1, L2) output buffer.
 *
 * @param sub_costs   contiguous float32 CPU tensor, shape (B, L1, L2).
 * @param trans_mask  contiguous float32 CPU tensor, same shape as sub_costs.
 * @param tangent     contiguous float32 CPU tensor, same shape as sub_costs.
 * @param ins_cost, del_cost, trans_cost, temperature  scalars (cast to float).
 * @param lengths_opt optional contiguous int32 CPU tensor (B, 2) of per-batch
 *                    (len1, len2) in [0, L1] x [0, L2]; defaults to (L1, L2).
 */
torch::Tensor soft_osa_hvp_cpu(
    torch::Tensor sub_costs,
    torch::Tensor trans_mask,
    torch::Tensor tangent,
    double ins_cost,
    double del_cost,
    double trans_cost,
    double temperature,
    c10::optional<torch::Tensor> lengths_opt
) {
    CHECK_INPUT_CPU(sub_costs);
    CHECK_INPUT_CPU(trans_mask);
    CHECK_INPUT_CPU(tangent);
    TORCH_CHECK(sub_costs.dim() == 3, "sub_costs must be 3D");
    TORCH_CHECK(trans_mask.dim() == 3, "trans_mask must be 3D");
    TORCH_CHECK(sub_costs.sizes() == tangent.sizes(), "sub_costs and tangent must have same shape");
    // The kernels index trans_mask with sub_costs' extents; a mismatched mask
    // would be read out of bounds.
    TORCH_CHECK(sub_costs.sizes() == trans_mask.sizes(), "sub_costs and trans_mask must have same shape");
    TORCH_CHECK(sub_costs.dtype() == torch::kFloat32, "sub_costs must be float32");
    TORCH_CHECK(trans_mask.dtype() == torch::kFloat32, "trans_mask must be float32");
    TORCH_CHECK(tangent.dtype() == torch::kFloat32, "tangent must be float32");

    int B = sub_costs.size(0);
    int max_L1 = sub_costs.size(1);
    int max_L2 = sub_costs.size(2);
    int alpha_size = (max_L1 + 1) * (max_L2 + 1);

    auto options = sub_costs.options();
    torch::Tensor lengths = lengths_opt.has_value() ? lengths_opt.value()
                                                    : make_default_lengths(B, max_L1, max_L2);

    CHECK_INPUT_CPU(lengths);
    TORCH_CHECK(lengths.dim() == 2 && lengths.size(0) == B && lengths.size(1) == 2);
    TORCH_CHECK(lengths.dtype() == torch::kInt32);
    // Per-batch lengths must fit inside the padded extents that size the
    // kernel buffers.
    {
        auto len_acc = lengths.accessor<int32_t, 2>();
        for (int b = 0; b < B; ++b) {
            TORCH_CHECK(len_acc[b][0] >= 0 && len_acc[b][0] <= max_L1 &&
                        len_acc[b][1] >= 0 && len_acc[b][1] <= max_L2,
                        "lengths[", b, "] = (", len_acc[b][0], ", ", len_acc[b][1],
                        ") must lie within [0, ", max_L1, "] x [0, ", max_L2, "]");
        }
    }

    float temp_val = static_cast<float>(temperature);
    float ins = static_cast<float>(ins_cost);
    float del = static_cast<float>(del_cost);
    float trans = static_cast<float>(trans_cost);

    torch::Tensor alpha = torch::zeros({B, alpha_size}, options);
    torch::Tensor osa_score = torch::zeros({B}, options);
    torch::Tensor d_alpha = torch::zeros({B, alpha_size}, options);
    torch::Tensor d_osa_score = torch::zeros({B}, options);
    torch::Tensor beta = torch::zeros({B, alpha_size}, options);
    torch::Tensor d_beta = torch::zeros({B, alpha_size}, options);
    torch::Tensor H_scores = torch::zeros({B, max_L1, max_L2}, options);

    d2p::osa::cpu::osa_forward_cpu(
        sub_costs.data_ptr<float>(),
        trans_mask.data_ptr<float>(),
        alpha.data_ptr<float>(),
        osa_score.data_ptr<float>(),
        lengths.data_ptr<int>(),
        ins, del, trans,
        B, max_L1, max_L2, temp_val
    );

    d2p::osa::cpu::osa_hvp_cpu(
        alpha.data_ptr<float>(),
        sub_costs.data_ptr<float>(),
        trans_mask.data_ptr<float>(),
        osa_score.data_ptr<float>(),
        tangent.data_ptr<float>(),
        d_alpha.data_ptr<float>(),
        d_osa_score.data_ptr<float>(),
        beta.data_ptr<float>(),
        d_beta.data_ptr<float>(),
        H_scores.data_ptr<float>(),
        lengths.data_ptr<int>(),
        ins, del, trans,
        B, max_L1, max_L2, temp_val
    );

    return H_scores;
}

/**
 * Non-autograd backward helper for the posteriors output.
 *
 * Runs osa_forward_cpu, then osa_hvp_cpu with `grad_output` as the vector (the
 * same role grad_posteriors plays in SoftOSACPUFunction::backward), then
 * osa_backward_cpu to obtain per-batch parameter derivatives.
 *
 * @param sub_costs   contiguous float32 CPU tensor, shape (B, L1, L2).
 * @param trans_mask  contiguous float32 CPU tensor, same shape as sub_costs.
 * @param grad_output CPU tensor, same shape as sub_costs (converted to
 *                    contiguous float32 internally).
 * @param ins_cost, del_cost, trans_cost, temperature  scalars (cast to float).
 * @param lengths_opt optional contiguous int32 CPU tensor (B, 2) of per-batch
 *                    (len1, len2) in [0, L1] x [0, L2]; defaults to (L1, L2).
 * @return (grad_sub_costs (B, L1, L2), grad_T (B,), grad_ins (B,),
 *          grad_del (B,), grad_trans (B,)).
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
soft_osa_backward_full_cpu(
    torch::Tensor sub_costs,
    torch::Tensor trans_mask,
    torch::Tensor grad_output,
    double ins_cost,
    double del_cost,
    double trans_cost,
    double temperature,
    c10::optional<torch::Tensor> lengths_opt
) {
    CHECK_INPUT_CPU(sub_costs);
    CHECK_INPUT_CPU(trans_mask);
    // grad_output is only made contiguous/float32 below, which keeps its
    // device; a non-CPU tensor would hand a device pointer to the CPU kernel.
    CHECK_CPU(grad_output);
    TORCH_CHECK(sub_costs.dim() == 3 && trans_mask.dim() == 3 && grad_output.dim() == 3, "tensors must be 3D");
    // The kernels index trans_mask and grad_output with sub_costs' extents;
    // mismatched shapes would be read out of bounds.
    TORCH_CHECK(trans_mask.sizes() == sub_costs.sizes(),
                "trans_mask must have the same shape as sub_costs");
    TORCH_CHECK(grad_output.sizes() == sub_costs.sizes(),
                "grad_output must have the same shape as sub_costs");
    TORCH_CHECK(sub_costs.dtype() == torch::kFloat32, "sub_costs must be float32");
    TORCH_CHECK(trans_mask.dtype() == torch::kFloat32, "trans_mask must be float32");

    int B = sub_costs.size(0);
    int max_L1 = sub_costs.size(1);
    int max_L2 = sub_costs.size(2);
    int alpha_size = (max_L1 + 1) * (max_L2 + 1);

    auto options = sub_costs.options();
    torch::Tensor lengths = lengths_opt.has_value() ? lengths_opt.value()
                                                    : make_default_lengths(B, max_L1, max_L2);

    CHECK_INPUT_CPU(lengths);
    TORCH_CHECK(lengths.dim() == 2 && lengths.size(0) == B && lengths.size(1) == 2);
    TORCH_CHECK(lengths.dtype() == torch::kInt32);
    // Per-batch lengths must fit inside the padded extents that size the
    // kernel buffers.
    {
        auto len_acc = lengths.accessor<int32_t, 2>();
        for (int b = 0; b < B; ++b) {
            TORCH_CHECK(len_acc[b][0] >= 0 && len_acc[b][0] <= max_L1 &&
                        len_acc[b][1] >= 0 && len_acc[b][1] <= max_L2,
                        "lengths[", b, "] = (", len_acc[b][0], ", ", len_acc[b][1],
                        ") must lie within [0, ", max_L1, "] x [0, ", max_L2, "]");
        }
    }

    grad_output = grad_output.contiguous().to(torch::kFloat32);

    float temp_val = static_cast<float>(temperature);
    float ins = static_cast<float>(ins_cost);
    float del = static_cast<float>(del_cost);
    float trans = static_cast<float>(trans_cost);

    // Forward pass
    torch::Tensor alpha = torch::zeros({B, alpha_size}, options);
    torch::Tensor osa_score = torch::zeros({B}, options);

    d2p::osa::cpu::osa_forward_cpu(
        sub_costs.data_ptr<float>(),
        trans_mask.data_ptr<float>(),
        alpha.data_ptr<float>(),
        osa_score.data_ptr<float>(),
        lengths.data_ptr<int>(),
        ins, del, trans,
        B, max_L1, max_L2, temp_val
    );

    // HVP for grad_sub_costs
    torch::Tensor d_alpha = torch::zeros({B, alpha_size}, options);
    torch::Tensor d_osa_score = torch::zeros({B}, options);
    torch::Tensor beta = torch::zeros({B, alpha_size}, options);
    torch::Tensor d_beta = torch::zeros({B, alpha_size}, options);
    torch::Tensor grad_sub_costs = torch::zeros({B, max_L1, max_L2}, options);

    d2p::osa::cpu::osa_hvp_cpu(
        alpha.data_ptr<float>(),
        sub_costs.data_ptr<float>(),
        trans_mask.data_ptr<float>(),
        osa_score.data_ptr<float>(),
        grad_output.data_ptr<float>(),
        d_alpha.data_ptr<float>(),
        d_osa_score.data_ptr<float>(),
        beta.data_ptr<float>(),
        d_beta.data_ptr<float>(),
        grad_sub_costs.data_ptr<float>(),
        lengths.data_ptr<int>(),
        ins, del, trans,
        B, max_L1, max_L2, temp_val
    );

    // Get cost gradients from backward
    torch::Tensor beta2 = torch::zeros({B, alpha_size}, options);
    torch::Tensor posteriors = torch::zeros({B, max_L1, max_L2}, options);
    torch::Tensor grad_T = torch::zeros({B}, options);
    torch::Tensor grad_ins_out = torch::zeros({B}, options);
    torch::Tensor grad_del_out = torch::zeros({B}, options);
    torch::Tensor grad_trans_out = torch::zeros({B}, options);

    d2p::osa::cpu::osa_backward_cpu(
        alpha.data_ptr<float>(),
        sub_costs.data_ptr<float>(),
        trans_mask.data_ptr<float>(),
        osa_score.data_ptr<float>(),
        beta2.data_ptr<float>(),
        posteriors.data_ptr<float>(),
        grad_T.data_ptr<float>(),
        grad_ins_out.data_ptr<float>(),
        grad_del_out.data_ptr<float>(),
        grad_trans_out.data_ptr<float>(),
        lengths.data_ptr<int>(),
        ins, del, trans,
        B, max_L1, max_L2, temp_val
    );

    // Weight each per-batch parameter derivative by sum(grad_output[b]).
    // NOTE(review): if grad_output is the cotangent of the posteriors, the exact
    // parameter gradient would be <grad_output[b], d posteriors[b] / d param>,
    // which this product does not compute in general. Confirm this
    // approximation is intended.
    const torch::Tensor grad_output_total = grad_output.sum({1, 2});
    torch::Tensor total_grad_T = grad_T * grad_output_total;
    torch::Tensor total_grad_ins = grad_ins_out * grad_output_total;
    torch::Tensor total_grad_del = grad_del_out * grad_output_total;
    torch::Tensor total_grad_trans = grad_trans_out * grad_output_total;

    return std::make_tuple(grad_sub_costs, total_grad_T, total_grad_ins, total_grad_del, total_grad_trans);
}

// ============================================================================
// Registration
// ============================================================================

// Dispatcher registration; only compiled when USE_TORCH_LIBRARY is defined.
#ifdef USE_TORCH_LIBRARY

// CPU backend kernels for the d2p operators implemented in this file. The
// operator schemas are not declared here; presumably a TORCH_LIBRARY(d2p, ...)
// block elsewhere defines them. TODO confirm.
TORCH_LIBRARY_IMPL(d2p, CPU, m) {
    m.impl("soft_osa", soft_osa_cpu);
    m.impl("soft_osa_float", soft_osa_cpu_float);
    m.impl("soft_osa_with_grads", soft_osa_cpu_with_grads);
    m.impl("soft_osa_hvp", soft_osa_hvp_cpu);
    m.impl("soft_osa_backward_full", soft_osa_backward_full_cpu);
}

// Autograd kernels for CPU. soft_osa and soft_osa_float build their graph via
// SoftOSACPUFunction::apply, so the same functions are registered under the
// AutogradCPU key. The remaining ops (with_grads, hvp, backward_full) call the
// raw kernels directly and are only registered for CPU above.
TORCH_LIBRARY_IMPL(d2p, AutogradCPU, m) {
    m.impl("soft_osa", soft_osa_cpu);
    m.impl("soft_osa_float", soft_osa_cpu_float);
}

#endif // USE_TORCH_LIBRARY
