/**
 * @file torch_cpu.cpp
 * @brief Soft True Damerau-Levenshtein CPU PyTorch Bindings
 *
 * CPU implementations that mirror the CUDA interface.
 * Registered with TORCH_LIBRARY_IMPL for automatic dispatch.
 *
 * Damerau uses SOFTMIN (minimization) with 4-way transitions: substitute, delete, insert, transpose.
 * Unlike OSA, transpositions can span variable distances via precomputed trans_src indices.
 */

#include "kernels_cpu.h"
#include <torch/extension.h>
#include <vector>
#include <tuple>

namespace d2p {
namespace damerau {

// ============================================================================
// Helper Macros and Functions
// ============================================================================

// Input-validation helpers built on TORCH_CHECK. The argument is parenthesized so that
// arbitrary expressions can be passed, and it is stringified into the failure message.
#define CHECK_CPU(x) TORCH_CHECK(!(x).is_cuda(), #x " must be a CPU tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Runs both checks as ONE statement. The do/while(0) wrapper keeps the second check
// guarded when the macro is used as the body of an unbraced if/else.
#define CHECK_INPUT_CPU(x)   \
    do {                     \
        CHECK_CPU(x);        \
        CHECK_CONTIGUOUS(x); \
    } while (0)

/**
 * Build the fallback lengths tensor used when the caller supplies none.
 *
 * @param B  number of rows to produce
 * @param L1 value written to column 0 of every row
 * @param L2 value written to column 1 of every row
 * @return   int32 tensor of shape (B, 2) whose rows are all {L1, L2}
 */
static torch::Tensor make_default_lengths_cpu(int B, int L1, int L2) {
    const auto int32_opts = torch::TensorOptions().dtype(torch::kInt32);
    torch::Tensor lengths = torch::empty({B, 2}, int32_opts);

    // A freshly allocated (B, 2) tensor is row-major, so each row is two adjacent ints.
    int32_t *row = lengths.data_ptr<int32_t>();
    for (int b = 0; b < B; ++b, row += 2) {
        row[0] = L1;
        row[1] = L2;
    }
    return lengths;
}

// ============================================================================
// Autograd Function
// ============================================================================

/**
 * @brief Autograd node wrapping the CPU soft Damerau-Levenshtein kernels.
 *
 * forward() runs the forward kernel and then the backward kernel, returning the
 * per-batch score together with the posteriors buffer produced by the backward kernel.
 * backward() handles gradients arriving on either output:
 *   - score gradient: re-runs the backward kernel and scales its outputs;
 *   - posteriors gradient: runs the HVP kernel, which contributes to sub_costs only.
 */
class SoftDamerauCPUFunction : public torch::autograd::Function<SoftDamerauCPUFunction> {
public:
    /**
     * @param ctx          autograd context used to stash tensors/scalars for backward()
     * @param sub_costs    float32, contiguous, CPU, shape (B, L1, L2)
     * @param trans_src    int32, contiguous, CPU, shape (B, L1, L2, 2)
     * @param ins_cost_t   single-element tensor holding the insertion cost
     * @param del_cost_t   single-element tensor holding the deletion cost
     * @param trans_cost_t single-element tensor holding the transposition cost
     * @param temperature  single-element tensor holding the temperature
     * @param lengths      int32, contiguous, CPU, shape (B, 2)
     * @return {damerau_score of shape (B), posteriors of shape (B, L1, L2)}
     */
    static torch::autograd::tensor_list forward(
        torch::autograd::AutogradContext *ctx,
        torch::Tensor sub_costs,
        torch::Tensor trans_src,
        torch::Tensor ins_cost_t,
        torch::Tensor del_cost_t,
        torch::Tensor trans_cost_t,
        torch::Tensor temperature,
        torch::Tensor lengths
    ) {
        CHECK_INPUT_CPU(sub_costs);
        CHECK_INPUT_CPU(trans_src);
        TORCH_CHECK(sub_costs.dim() == 3, "sub_costs must be 3D (B, L1, L2)");
        TORCH_CHECK(trans_src.dim() == 4, "trans_src must be 4D (B, L1, L2, 2)");
        TORCH_CHECK(sub_costs.dtype() == torch::kFloat32, "sub_costs must be float32");
        TORCH_CHECK(trans_src.dtype() == torch::kInt32, "trans_src must be int32");
        // The kernels receive raw pointers plus (B, L1, L2) taken from sub_costs, so a
        // trans_src with any other extent would be read out of bounds.
        TORCH_CHECK(
            trans_src.size(0) == sub_costs.size(0) &&
            trans_src.size(1) == sub_costs.size(1) &&
            trans_src.size(2) == sub_costs.size(2) &&
            trans_src.size(3) == 2,
            "trans_src must have shape (B, L1, L2, 2) matching sub_costs");
        TORCH_CHECK(temperature.numel() == 1, "temperature must be a scalar tensor");
        TORCH_CHECK(ins_cost_t.numel() == 1, "ins_cost must be a scalar tensor");
        TORCH_CHECK(del_cost_t.numel() == 1, "del_cost must be a scalar tensor");
        TORCH_CHECK(trans_cost_t.numel() == 1, "trans_cost must be a scalar tensor");

        const int B = static_cast<int>(sub_costs.size(0));
        const int max_L1 = static_cast<int>(sub_costs.size(1));
        const int max_L2 = static_cast<int>(sub_costs.size(2));
        const int alpha_size = (max_L1 + 1) * (max_L2 + 1);

        CHECK_INPUT_CPU(lengths);
        TORCH_CHECK(lengths.dim() == 2 && lengths.size(0) == B && lengths.size(1) == 2,
                    "lengths must have shape (B, 2)");
        TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");

        const float temp_val = temperature.item<float>();
        const float ins_cost = ins_cost_t.item<float>();
        const float del_cost = del_cost_t.item<float>();
        const float trans_cost = trans_cost_t.item<float>();

        const auto options = sub_costs.options();
        torch::Tensor alpha = torch::zeros({B, alpha_size}, options);
        torch::Tensor damerau_score = torch::zeros({B}, options);
        torch::Tensor beta = torch::zeros({B, alpha_size}, options);
        torch::Tensor posteriors = torch::zeros({B, max_L1, max_L2}, options);
        torch::Tensor grad_T = torch::zeros({B}, options);
        torch::Tensor grad_ins = torch::zeros({B}, options);
        torch::Tensor grad_del = torch::zeros({B}, options);
        torch::Tensor grad_trans = torch::zeros({B}, options);

        cpu::damerau_forward_cpu(
            sub_costs.data_ptr<float>(),
            trans_src.data_ptr<int>(),
            alpha.data_ptr<float>(),
            damerau_score.data_ptr<float>(),
            lengths.data_ptr<int>(),
            ins_cost, del_cost, trans_cost,
            B, max_L1, max_L2, temp_val
        );

        // The backward kernel is run eagerly because its posteriors buffer is one of
        // the two values this function returns.
        cpu::damerau_backward_cpu(
            alpha.data_ptr<float>(),
            sub_costs.data_ptr<float>(),
            trans_src.data_ptr<int>(),
            damerau_score.data_ptr<float>(),
            beta.data_ptr<float>(),
            posteriors.data_ptr<float>(),
            grad_T.data_ptr<float>(),
            grad_ins.data_ptr<float>(),
            grad_del.data_ptr<float>(),
            grad_trans.data_ptr<float>(),
            lengths.data_ptr<int>(),
            ins_cost, del_cost, trans_cost,
            B, max_L1, max_L2, temp_val
        );

        // Caller-visible tensors are cloned before saving; alpha and grad_T are local
        // to this function, so they are saved directly without an extra copy.
        ctx->save_for_backward({sub_costs.clone(), trans_src.clone(), alpha, damerau_score.clone(), lengths.clone(), grad_T});
        ctx->saved_data["temperature"] = temp_val;
        ctx->saved_data["ins_cost"] = ins_cost;
        ctx->saved_data["del_cost"] = del_cost;
        ctx->saved_data["trans_cost"] = trans_cost;
        // Shapes of the scalar inputs, so backward() can return gradients of the same shape.
        ctx->saved_data["temperature_sizes"] = temperature.sizes();
        ctx->saved_data["ins_cost_sizes"] = ins_cost_t.sizes();
        ctx->saved_data["del_cost_sizes"] = del_cost_t.sizes();
        ctx->saved_data["trans_cost_sizes"] = trans_cost_t.sizes();

        return {damerau_score, posteriors};
    }

    /**
     * @param ctx          context populated by forward()
     * @param grad_outputs {gradient w.r.t. damerau_score, gradient w.r.t. posteriors}
     * @return gradients in forward-argument order: sub_costs, trans_src, ins_cost,
     *         del_cost, trans_cost, temperature, lengths. trans_src and lengths are
     *         integer inputs and always get an undefined gradient.
     */
    static torch::autograd::tensor_list backward(
        torch::autograd::AutogradContext *ctx,
        torch::autograd::tensor_list grad_outputs
    ) {
        auto saved = ctx->get_saved_variables();
        torch::Tensor sub_costs = saved[0];
        torch::Tensor trans_src = saved[1];
        torch::Tensor alpha = saved[2];
        torch::Tensor damerau_score = saved[3];
        torch::Tensor lengths = saved[4];
        torch::Tensor grad_T_fwd = saved[5];

        const float temp_val = static_cast<float>(ctx->saved_data["temperature"].toDouble());
        const float ins_cost = static_cast<float>(ctx->saved_data["ins_cost"].toDouble());
        const float del_cost = static_cast<float>(ctx->saved_data["del_cost"].toDouble());
        const float trans_cost = static_cast<float>(ctx->saved_data["trans_cost"].toDouble());

        const int B = static_cast<int>(sub_costs.size(0));
        const int max_L1 = static_cast<int>(sub_costs.size(1));
        const int max_L2 = static_cast<int>(sub_costs.size(2));
        const int alpha_size = (max_L1 + 1) * (max_L2 + 1);

        const auto options = sub_costs.options();

        torch::Tensor grad_damerau_score = grad_outputs[0];
        torch::Tensor grad_posteriors = grad_outputs[1];

        torch::Tensor grad_sub_costs = torch::zeros({B, max_L1, max_L2}, options);
        // Scalar gradients are accumulated in buffers shaped like the matching input.
        torch::Tensor total_grad_T =
            torch::zeros(ctx->saved_data["temperature_sizes"].toIntVector(), options);
        torch::Tensor total_grad_ins =
            torch::zeros(ctx->saved_data["ins_cost_sizes"].toIntVector(), options);
        torch::Tensor total_grad_del =
            torch::zeros(ctx->saved_data["del_cost_sizes"].toIntVector(), options);
        torch::Tensor total_grad_trans =
            torch::zeros(ctx->saved_data["trans_cost_sizes"].toIntVector(), options);

        // Gradient from damerau_score path
        if (grad_damerau_score.defined() && grad_damerau_score.numel() > 0) {
            torch::Tensor beta = torch::zeros({B, alpha_size}, options);
            torch::Tensor posteriors = torch::zeros({B, max_L1, max_L2}, options);
            torch::Tensor tmp_T = torch::zeros({B}, options);
            torch::Tensor tmp_ins = torch::zeros({B}, options);
            torch::Tensor tmp_del = torch::zeros({B}, options);
            torch::Tensor tmp_trans = torch::zeros({B}, options);

            cpu::damerau_backward_cpu(
                alpha.data_ptr<float>(),
                sub_costs.data_ptr<float>(),
                trans_src.data_ptr<int>(),
                damerau_score.data_ptr<float>(),
                beta.data_ptr<float>(),
                posteriors.data_ptr<float>(),
                tmp_T.data_ptr<float>(),
                tmp_ins.data_ptr<float>(),
                tmp_del.data_ptr<float>(),
                tmp_trans.data_ptr<float>(),
                lengths.data_ptr<int>(),
                ins_cost, del_cost, trans_cost,
                B, max_L1, max_L2, temp_val
            );

            grad_sub_costs += grad_damerau_score.view({B, 1, 1}) * posteriors;
            total_grad_T += (grad_damerau_score * grad_T_fwd).sum();
            // NOTE(review): assumes damerau_backward_cpu fills the ins/del/trans buffers
            // with per-batch d(score)/d(cost), the same convention as its temperature
            // buffer - confirm against kernels_cpu.h.
            total_grad_ins += (grad_damerau_score * tmp_ins).sum();
            total_grad_del += (grad_damerau_score * tmp_del).sum();
            total_grad_trans += (grad_damerau_score * tmp_trans).sum();
        }

        // Gradient from posteriors path (HVP). Only sub_costs receives a contribution
        // here; the scalar inputs get nothing from this path.
        if (grad_posteriors.defined() && grad_posteriors.numel() > 0) {
            grad_posteriors = grad_posteriors.contiguous().to(torch::kFloat32);

            torch::Tensor d_alpha = torch::zeros({B, alpha_size}, options);
            torch::Tensor d_damerau_score = torch::zeros({B}, options);
            torch::Tensor beta = torch::zeros({B, alpha_size}, options);
            torch::Tensor d_beta = torch::zeros({B, alpha_size}, options);
            torch::Tensor hvp_grad_sub_costs = torch::zeros({B, max_L1, max_L2}, options);

            cpu::damerau_hvp_cpu(
                alpha.data_ptr<float>(),
                sub_costs.data_ptr<float>(),
                trans_src.data_ptr<int>(),
                damerau_score.data_ptr<float>(),
                grad_posteriors.data_ptr<float>(),
                d_alpha.data_ptr<float>(),
                d_damerau_score.data_ptr<float>(),
                beta.data_ptr<float>(),
                d_beta.data_ptr<float>(),
                hvp_grad_sub_costs.data_ptr<float>(),
                lengths.data_ptr<int>(),
                ins_cost, del_cost, trans_cost,
                B, max_L1, max_L2, temp_val
            );

            grad_sub_costs += hvp_grad_sub_costs;
        }

        // Return gradients: sub_costs, trans_src, ins_cost, del_cost, trans_cost, temperature, lengths
        return {grad_sub_costs, torch::Tensor(), total_grad_ins, total_grad_del, total_grad_trans, total_grad_T, torch::Tensor()};
    }
};

// ============================================================================
// Operator Implementations
// ============================================================================

/**
 * Operator entry point taking every scalar parameter as a tensor.
 *
 * All arguments are forwarded unchanged to SoftDamerauCPUFunction::apply, which performs
 * the input validation.
 *
 * @return the tensor list produced by SoftDamerauCPUFunction (score, posteriors)
 */
std::vector<torch::Tensor> soft_damerau_cpu(
    torch::Tensor sub_costs,
    torch::Tensor trans_src,
    torch::Tensor ins_cost,
    torch::Tensor del_cost,
    torch::Tensor trans_cost,
    torch::Tensor temperature,
    torch::Tensor lengths
) {
    auto outputs = SoftDamerauCPUFunction::apply(
        sub_costs, trans_src,
        ins_cost, del_cost, trans_cost,
        temperature, lengths);
    return outputs;
}

/**
 * Operator entry point taking the scalar parameters as doubles.
 *
 * Each scalar is narrowed to float and wrapped in a one-element tensor created with
 * sub_costs' tensor options; when no lengths tensor is given, every batch entry uses
 * the full (L1, L2) extent of sub_costs.
 *
 * @return the tensor list produced by SoftDamerauCPUFunction (score, posteriors)
 */
std::vector<torch::Tensor> soft_damerau_cpu_float(
    torch::Tensor sub_costs,
    torch::Tensor trans_src,
    double ins_cost,
    double del_cost,
    double trans_cost,
    double temperature,
    c10::optional<torch::Tensor> lengths_opt
) {
    // Extents are read up front, before any tensor is created.
    const int batch = sub_costs.size(0);
    const int rows = sub_costs.size(1);
    const int cols = sub_costs.size(2);

    const auto scalar_opts = sub_costs.options();
    const auto wrap_scalar = [&scalar_opts](double value) {
        return torch::tensor({static_cast<float>(value)}, scalar_opts);
    };

    torch::Tensor temperature_tensor = wrap_scalar(temperature);
    torch::Tensor ins_tensor = wrap_scalar(ins_cost);
    torch::Tensor del_tensor = wrap_scalar(del_cost);
    torch::Tensor trans_tensor = wrap_scalar(trans_cost);

    torch::Tensor lengths;
    if (lengths_opt.has_value()) {
        lengths = lengths_opt.value();
    } else {
        lengths = make_default_lengths_cpu(batch, rows, cols);
    }

    return SoftDamerauCPUFunction::apply(
        sub_costs, trans_src, ins_tensor, del_tensor, trans_tensor, temperature_tensor, lengths);
}

/**
 * @brief Run the forward and backward kernels once, outside of autograd, and return
 *        every buffer the backward kernel fills.
 *
 * @param sub_costs   float32, contiguous, CPU, shape (B, L1, L2)
 * @param trans_src   int32, contiguous, CPU, shape (B, L1, L2, 2)
 * @param ins_cost    insertion cost (narrowed to float)
 * @param del_cost    deletion cost (narrowed to float)
 * @param trans_cost  transposition cost (narrowed to float)
 * @param temperature temperature (narrowed to float)
 * @param lengths_opt optional int32 (B, 2) tensor; defaults to {L1, L2} for every row
 * @return (damerau_score (B), posteriors (B, L1, L2), grad_T (B), grad_ins (B),
 *          grad_del (B), grad_trans (B))
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
soft_damerau_cpu_with_grads(
    torch::Tensor sub_costs,
    torch::Tensor trans_src,
    double ins_cost,
    double del_cost,
    double trans_cost,
    double temperature,
    c10::optional<torch::Tensor> lengths_opt
) {
    CHECK_INPUT_CPU(sub_costs);
    CHECK_INPUT_CPU(trans_src);
    TORCH_CHECK(sub_costs.dim() == 3, "sub_costs must be 3D (B, L1, L2)");
    TORCH_CHECK(trans_src.dim() == 4, "trans_src must be 4D (B, L1, L2, 2)");
    TORCH_CHECK(sub_costs.dtype() == torch::kFloat32, "sub_costs must be float32");
    TORCH_CHECK(trans_src.dtype() == torch::kInt32, "trans_src must be int32");
    // The kernels index trans_src using the extents of sub_costs; reject anything that
    // would make them read past the end of the buffer.
    TORCH_CHECK(
        trans_src.size(0) == sub_costs.size(0) &&
        trans_src.size(1) == sub_costs.size(1) &&
        trans_src.size(2) == sub_costs.size(2) &&
        trans_src.size(3) == 2,
        "trans_src must have shape (B, L1, L2, 2) matching sub_costs");

    const int B = static_cast<int>(sub_costs.size(0));
    const int max_L1 = static_cast<int>(sub_costs.size(1));
    const int max_L2 = static_cast<int>(sub_costs.size(2));
    const int alpha_size = (max_L1 + 1) * (max_L2 + 1);

    const auto options = sub_costs.options();
    torch::Tensor lengths = lengths_opt.has_value() ? lengths_opt.value()
                                                    : make_default_lengths_cpu(B, max_L1, max_L2);

    CHECK_INPUT_CPU(lengths);
    TORCH_CHECK(lengths.dim() == 2 && lengths.size(0) == B && lengths.size(1) == 2,
                "lengths must have shape (B, 2)");
    TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");

    const float temp_val = static_cast<float>(temperature);
    const float ins = static_cast<float>(ins_cost);
    const float del = static_cast<float>(del_cost);
    const float trans = static_cast<float>(trans_cost);

    torch::Tensor alpha = torch::zeros({B, alpha_size}, options);
    torch::Tensor damerau_score = torch::zeros({B}, options);
    torch::Tensor beta = torch::zeros({B, alpha_size}, options);
    torch::Tensor posteriors = torch::zeros({B, max_L1, max_L2}, options);
    torch::Tensor grad_T = torch::zeros({B}, options);
    torch::Tensor grad_ins_out = torch::zeros({B}, options);
    torch::Tensor grad_del_out = torch::zeros({B}, options);
    torch::Tensor grad_trans_out = torch::zeros({B}, options);

    cpu::damerau_forward_cpu(
        sub_costs.data_ptr<float>(),
        trans_src.data_ptr<int>(),
        alpha.data_ptr<float>(),
        damerau_score.data_ptr<float>(),
        lengths.data_ptr<int>(),
        ins, del, trans,
        B, max_L1, max_L2, temp_val
    );

    // Consumes alpha and the score from the forward pass; beta is scratch space that
    // is not returned to the caller.
    cpu::damerau_backward_cpu(
        alpha.data_ptr<float>(),
        sub_costs.data_ptr<float>(),
        trans_src.data_ptr<int>(),
        damerau_score.data_ptr<float>(),
        beta.data_ptr<float>(),
        posteriors.data_ptr<float>(),
        grad_T.data_ptr<float>(),
        grad_ins_out.data_ptr<float>(),
        grad_del_out.data_ptr<float>(),
        grad_trans_out.data_ptr<float>(),
        lengths.data_ptr<int>(),
        ins, del, trans,
        B, max_L1, max_L2, temp_val
    );

    return std::make_tuple(damerau_score, posteriors, grad_T, grad_ins_out, grad_del_out, grad_trans_out);
}

/**
 * @brief Run the forward kernel followed by the HVP kernel for a given tangent.
 *
 * @param sub_costs   float32, contiguous, CPU, shape (B, L1, L2)
 * @param trans_src   int32, contiguous, CPU, shape (B, L1, L2, 2)
 * @param tangent     float32, contiguous, CPU, same shape as sub_costs
 * @param ins_cost    insertion cost (narrowed to float)
 * @param del_cost    deletion cost (narrowed to float)
 * @param trans_cost  transposition cost (narrowed to float)
 * @param temperature temperature (narrowed to float)
 * @param lengths_opt optional int32 (B, 2) tensor; defaults to {L1, L2} for every row
 * @return the (B, L1, L2) buffer written by damerau_hvp_cpu
 */
torch::Tensor soft_damerau_hvp_cpu_impl(
    torch::Tensor sub_costs,
    torch::Tensor trans_src,
    torch::Tensor tangent,
    double ins_cost,
    double del_cost,
    double trans_cost,
    double temperature,
    c10::optional<torch::Tensor> lengths_opt
) {
    CHECK_INPUT_CPU(sub_costs);
    CHECK_INPUT_CPU(trans_src);
    CHECK_INPUT_CPU(tangent);
    TORCH_CHECK(sub_costs.dim() == 3, "sub_costs must be 3D");
    TORCH_CHECK(trans_src.dim() == 4, "trans_src must be 4D");
    TORCH_CHECK(sub_costs.sizes() == tangent.sizes(), "sub_costs and tangent must have same shape");
    TORCH_CHECK(sub_costs.dtype() == torch::kFloat32, "sub_costs must be float32");
    TORCH_CHECK(trans_src.dtype() == torch::kInt32, "trans_src must be int32");
    TORCH_CHECK(tangent.dtype() == torch::kFloat32, "tangent must be float32");
    // The kernels index trans_src using the extents of sub_costs; reject anything that
    // would make them read past the end of the buffer.
    TORCH_CHECK(
        trans_src.size(0) == sub_costs.size(0) &&
        trans_src.size(1) == sub_costs.size(1) &&
        trans_src.size(2) == sub_costs.size(2) &&
        trans_src.size(3) == 2,
        "trans_src must have shape (B, L1, L2, 2) matching sub_costs");

    const int B = static_cast<int>(sub_costs.size(0));
    const int max_L1 = static_cast<int>(sub_costs.size(1));
    const int max_L2 = static_cast<int>(sub_costs.size(2));
    const int alpha_size = (max_L1 + 1) * (max_L2 + 1);

    const auto options = sub_costs.options();
    torch::Tensor lengths = lengths_opt.has_value() ? lengths_opt.value()
                                                    : make_default_lengths_cpu(B, max_L1, max_L2);

    CHECK_INPUT_CPU(lengths);
    TORCH_CHECK(lengths.dim() == 2 && lengths.size(0) == B && lengths.size(1) == 2,
                "lengths must have shape (B, 2)");
    TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");

    const float temp_val = static_cast<float>(temperature);
    const float ins = static_cast<float>(ins_cost);
    const float del = static_cast<float>(del_cost);
    const float trans = static_cast<float>(trans_cost);

    torch::Tensor alpha = torch::zeros({B, alpha_size}, options);
    torch::Tensor damerau_score = torch::zeros({B}, options);
    // Work buffers for the HVP kernel; only H_scores is returned to the caller.
    torch::Tensor d_alpha = torch::zeros({B, alpha_size}, options);
    torch::Tensor d_damerau_score = torch::zeros({B}, options);
    torch::Tensor beta = torch::zeros({B, alpha_size}, options);
    torch::Tensor d_beta = torch::zeros({B, alpha_size}, options);
    torch::Tensor H_scores = torch::zeros({B, max_L1, max_L2}, options);

    cpu::damerau_forward_cpu(
        sub_costs.data_ptr<float>(),
        trans_src.data_ptr<int>(),
        alpha.data_ptr<float>(),
        damerau_score.data_ptr<float>(),
        lengths.data_ptr<int>(),
        ins, del, trans,
        B, max_L1, max_L2, temp_val
    );

    cpu::damerau_hvp_cpu(
        alpha.data_ptr<float>(),
        sub_costs.data_ptr<float>(),
        trans_src.data_ptr<int>(),
        damerau_score.data_ptr<float>(),
        tangent.data_ptr<float>(),
        d_alpha.data_ptr<float>(),
        d_damerau_score.data_ptr<float>(),
        beta.data_ptr<float>(),
        d_beta.data_ptr<float>(),
        H_scores.data_ptr<float>(),
        lengths.data_ptr<int>(),
        ins, del, trans,
        B, max_L1, max_L2, temp_val
    );

    return H_scores;
}

/**
 * @brief Recompute the forward pass, then produce the sub_costs gradient via the HVP
 *        kernel and the scalar-parameter gradients via the backward kernel.
 *
 * @param sub_costs   float32, contiguous, CPU, shape (B, L1, L2)
 * @param trans_src   int32, contiguous, CPU, shape (B, L1, L2, 2)
 * @param grad_output CPU tensor with the same shape as sub_costs; converted to a
 *                    contiguous float32 tensor before use
 * @param ins_cost    insertion cost (narrowed to float)
 * @param del_cost    deletion cost (narrowed to float)
 * @param trans_cost  transposition cost (narrowed to float)
 * @param temperature temperature (narrowed to float)
 * @param lengths_opt optional int32 (B, 2) tensor; defaults to {L1, L2} for every row
 * @return (grad_sub_costs (B, L1, L2), grad_T (B), grad_ins (B), grad_del (B),
 *          grad_trans (B)); each scalar gradient is the backward kernel's per-batch
 *          value multiplied by that batch entry's summed grad_output
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
soft_damerau_backward_full_cpu_impl(
    torch::Tensor sub_costs,
    torch::Tensor trans_src,
    torch::Tensor grad_output,
    double ins_cost,
    double del_cost,
    double trans_cost,
    double temperature,
    c10::optional<torch::Tensor> lengths_opt
) {
    CHECK_INPUT_CPU(sub_costs);
    CHECK_INPUT_CPU(trans_src);
    // grad_output is handed to the CPU kernel as a raw pointer, so it must live on the CPU.
    CHECK_CPU(grad_output);
    TORCH_CHECK(sub_costs.dim() == 3 && trans_src.dim() == 4 && grad_output.dim() == 3, "invalid tensor dimensions");
    TORCH_CHECK(sub_costs.dtype() == torch::kFloat32, "sub_costs must be float32");
    TORCH_CHECK(trans_src.dtype() == torch::kInt32, "trans_src must be int32");
    // The kernels index trans_src and grad_output using the extents of sub_costs;
    // reject anything that would make them read past the end of either buffer.
    TORCH_CHECK(
        trans_src.size(0) == sub_costs.size(0) &&
        trans_src.size(1) == sub_costs.size(1) &&
        trans_src.size(2) == sub_costs.size(2) &&
        trans_src.size(3) == 2,
        "trans_src must have shape (B, L1, L2, 2) matching sub_costs");
    TORCH_CHECK(grad_output.sizes() == sub_costs.sizes(),
                "grad_output must have the same shape as sub_costs");

    const int B = static_cast<int>(sub_costs.size(0));
    const int max_L1 = static_cast<int>(sub_costs.size(1));
    const int max_L2 = static_cast<int>(sub_costs.size(2));
    const int alpha_size = (max_L1 + 1) * (max_L2 + 1);

    const auto options = sub_costs.options();
    torch::Tensor lengths = lengths_opt.has_value() ? lengths_opt.value()
                                                    : make_default_lengths_cpu(B, max_L1, max_L2);

    CHECK_INPUT_CPU(lengths);
    TORCH_CHECK(lengths.dim() == 2 && lengths.size(0) == B && lengths.size(1) == 2,
                "lengths must have shape (B, 2)");
    TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");

    grad_output = grad_output.contiguous().to(torch::kFloat32);

    const float temp_val = static_cast<float>(temperature);
    const float ins = static_cast<float>(ins_cost);
    const float del = static_cast<float>(del_cost);
    const float trans = static_cast<float>(trans_cost);

    // Forward pass
    torch::Tensor alpha = torch::zeros({B, alpha_size}, options);
    torch::Tensor damerau_score = torch::zeros({B}, options);

    cpu::damerau_forward_cpu(
        sub_costs.data_ptr<float>(),
        trans_src.data_ptr<int>(),
        alpha.data_ptr<float>(),
        damerau_score.data_ptr<float>(),
        lengths.data_ptr<int>(),
        ins, del, trans,
        B, max_L1, max_L2, temp_val
    );

    // HVP for grad_sub_costs
    torch::Tensor d_alpha = torch::zeros({B, alpha_size}, options);
    torch::Tensor d_damerau_score = torch::zeros({B}, options);
    torch::Tensor beta = torch::zeros({B, alpha_size}, options);
    torch::Tensor d_beta = torch::zeros({B, alpha_size}, options);
    torch::Tensor grad_sub_costs = torch::zeros({B, max_L1, max_L2}, options);

    cpu::damerau_hvp_cpu(
        alpha.data_ptr<float>(),
        sub_costs.data_ptr<float>(),
        trans_src.data_ptr<int>(),
        damerau_score.data_ptr<float>(),
        grad_output.data_ptr<float>(),
        d_alpha.data_ptr<float>(),
        d_damerau_score.data_ptr<float>(),
        beta.data_ptr<float>(),
        d_beta.data_ptr<float>(),
        grad_sub_costs.data_ptr<float>(),
        lengths.data_ptr<int>(),
        ins, del, trans,
        B, max_L1, max_L2, temp_val
    );

    // Get cost gradients from backward (uses its own beta buffer, separate from the HVP one)
    torch::Tensor beta2 = torch::zeros({B, alpha_size}, options);
    torch::Tensor posteriors = torch::zeros({B, max_L1, max_L2}, options);
    torch::Tensor grad_T = torch::zeros({B}, options);
    torch::Tensor grad_ins_out = torch::zeros({B}, options);
    torch::Tensor grad_del_out = torch::zeros({B}, options);
    torch::Tensor grad_trans_out = torch::zeros({B}, options);

    cpu::damerau_backward_cpu(
        alpha.data_ptr<float>(),
        sub_costs.data_ptr<float>(),
        trans_src.data_ptr<int>(),
        damerau_score.data_ptr<float>(),
        beta2.data_ptr<float>(),
        posteriors.data_ptr<float>(),
        grad_T.data_ptr<float>(),
        grad_ins_out.data_ptr<float>(),
        grad_del_out.data_ptr<float>(),
        grad_trans_out.data_ptr<float>(),
        lengths.data_ptr<int>(),
        ins, del, trans,
        B, max_L1, max_L2, temp_val
    );

    // Weight by grad_output: the per-batch sum is computed once and reused for all four
    // scalar gradients.
    const torch::Tensor grad_weight = grad_output.sum({1, 2});
    torch::Tensor total_grad_T = grad_T * grad_weight;
    torch::Tensor total_grad_ins = grad_ins_out * grad_weight;
    torch::Tensor total_grad_del = grad_del_out * grad_weight;
    torch::Tensor total_grad_trans = grad_trans_out * grad_weight;

    return std::make_tuple(grad_sub_costs, total_grad_T, total_grad_ins, total_grad_del, total_grad_trans);
}

}  // namespace damerau
}  // namespace d2p

// ============================================================================
// Registration
// ============================================================================

// The registrations below are compiled only when USE_TORCH_LIBRARY is defined.
#ifdef USE_TORCH_LIBRARY

// Implementations bound to the CPU dispatch key of the d2p library, one per operator
// name. No operator schema is defined in this file.
// NOTE(review): presumably the matching schemas are declared in another translation
// unit - confirm.
TORCH_LIBRARY_IMPL(d2p, CPU, m) {
    m.impl("soft_damerau", d2p::damerau::soft_damerau_cpu);
    m.impl("soft_damerau_float", d2p::damerau::soft_damerau_cpu_float);
    m.impl("soft_damerau_with_grads", d2p::damerau::soft_damerau_cpu_with_grads);
    m.impl("soft_damerau_hvp", d2p::damerau::soft_damerau_hvp_cpu_impl);
    m.impl("soft_damerau_backward_full", d2p::damerau::soft_damerau_backward_full_cpu_impl);
}

// The two entry points that go through SoftDamerauCPUFunction are additionally bound
// to the AutogradCPU key, using the same functions as the CPU key above.
TORCH_LIBRARY_IMPL(d2p, AutogradCPU, m) {
    m.impl("soft_damerau", d2p::damerau::soft_damerau_cpu);
    m.impl("soft_damerau_float", d2p::damerau::soft_damerau_cpu_float);
}

#endif // USE_TORCH_LIBRARY
