/**
 * @file torch_cuda.cpp
 * @brief Soft Hamming Distance CUDA PyTorch Bindings
 *
 * Provides GPU-accelerated soft Hamming distance implementation with:
 *   - Simple O(n) sum of mismatch costs
 *   - Trivial gradients (1 for valid positions)
 *   - Zero Hessian (linear function)
 */

#include <torch/extension.h>
#include <vector>

#include "kernels.cuh"

// =============================================================================
// Helper Macros
// =============================================================================

// Fails with a descriptive error unless `x.is_cuda()` is true.
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
// Fails with a descriptive error unless `x.is_contiguous()` is true.
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Runs both checks above. Wrapped in do/while(0) so the macro expands to a
// single statement: without it, `if (cond) CHECK_INPUT(t);` would guard only
// the first check and always execute the second.
#define CHECK_INPUT(x)           \
    do {                         \
        CHECK_CUDA(x);           \
        CHECK_CONTIGUOUS(x);     \
    } while (0)

// =============================================================================
// Helper: Create default lengths tensor
// =============================================================================

namespace {

/**
 * Build the lengths tensor used when the caller does not supply one.
 *
 * @param B        Number of entries in the returned 1-D tensor.
 * @param L        Value stored in every entry.
 * @param options  Only `options.device()` is consulted; it selects where the
 *                 result is allocated.
 * @return int32 tensor of shape (B,) filled with L on `options.device()`.
 */
torch::Tensor make_default_lengths(int B, int L, const torch::TensorOptions& options) {
    // Allocate and fill directly on the target device instead of building the
    // tensor on the CPU and copying it over afterwards.
    const auto lengths_options =
        torch::TensorOptions().dtype(torch::kInt32).device(options.device());
    return torch::full({B}, L, lengths_options);
}

} // anonymous namespace

// =============================================================================
// Hamming CUDA Autograd Function
// =============================================================================

/**
 * Autograd wrapper around the soft Hamming kernels declared in kernels.cuh.
 *
 * forward() validates its inputs, runs d2p::hamming::backward to fill the
 * `distance` and `posteriors` outputs, and records the metadata backward()
 * needs. backward() propagates the incoming distance gradient to `costs`
 * through the mask written by d2p::hamming::posteriors, and returns a zero
 * gradient for `temperature` and no gradient for `lengths`.
 *
 * NOTE(review): no CUDA device guard is installed here; the kernel launchers
 * presumably run on the current device — confirm for multi-GPU callers.
 */
class SoftHammingCUDAFunction : public torch::autograd::Function<SoftHammingCUDAFunction> {
public:
    /**
     * Forward pass.
     *
     * @param ctx          Autograd context used to stash state for backward().
     * @param costs        Contiguous float32 CUDA tensor of shape (B, L).
     * @param temperature  Single-element tensor; its value is read on the host.
     * @param lengths      Contiguous int32 CUDA tensor of shape (B,).
     * @return {distance of shape (B,), posteriors of shape (B, L)}.
     */
    static torch::autograd::tensor_list forward(
        torch::autograd::AutogradContext *ctx,
        torch::Tensor costs,
        torch::Tensor temperature,
        torch::Tensor lengths
    ) {
        CHECK_INPUT(costs);
        TORCH_CHECK(costs.dim() == 2, "costs must be 2D (B, L)");
        TORCH_CHECK(costs.dtype() == torch::kFloat32, "costs must be float32");
        TORCH_CHECK(temperature.numel() == 1, "temperature must be a scalar tensor");

        int B = costs.size(0);
        int L = costs.size(1);

        CHECK_CUDA(lengths);
        CHECK_CONTIGUOUS(lengths);
        TORCH_CHECK(lengths.dim() == 1 && lengths.size(0) == B,
                    "lengths must be 1D with shape (B,)");
        TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");

        // The kernel launcher takes the temperature by value, so fetch it on
        // the host (a no-op copy when the tensor already lives on the CPU).
        float temp_val = temperature.cpu().item<float>();

        auto options = costs.options();
        torch::Tensor distance = torch::zeros({B}, options);
        torch::Tensor posteriors = torch::zeros({B, L}, options);
        // Required by the launcher's signature; not returned from this function.
        torch::Tensor grad_T = torch::zeros({B}, options);

        d2p::hamming::backward(
            costs.data_ptr<float>(),
            distance.data_ptr<float>(),
            posteriors.data_ptr<float>(),
            grad_T.data_ptr<float>(),
            lengths.data_ptr<int>(),
            B, L, temp_val
        );

        // backward() never reads the cost values, only their shape and device,
        // so record that metadata instead of cloning the whole (B, L) tensor.
        // The empty_like placeholder carries the temperature's shape, dtype
        // and device so its gradient can be built to match exactly.
        ctx->save_for_backward({lengths.clone(), torch::empty_like(temperature)});
        ctx->saved_data["B"] = static_cast<int64_t>(B);
        ctx->saved_data["L"] = static_cast<int64_t>(L);
        ctx->saved_data["device"] = c10::IValue(costs.device());

        return {distance, posteriors};
    }

    /**
     * Backward pass.
     *
     * @param ctx           Context populated by forward().
     * @param grad_outputs  {grad wrt distance, grad wrt posteriors}; only the
     *                      first entry is used.
     * @return {grad_costs of shape (B, L), zero gradient shaped like
     *          temperature, undefined tensor for lengths}.
     */
    static torch::autograd::tensor_list backward(
        torch::autograd::AutogradContext *ctx,
        torch::autograd::tensor_list grad_outputs
    ) {
        auto saved = ctx->get_saved_variables();
        torch::Tensor lengths = saved[0];
        torch::Tensor temperature_like = saved[1];

        const int B = static_cast<int>(ctx->saved_data["B"].toInt());
        const int L = static_cast<int>(ctx->saved_data["L"].toInt());

        // forward() enforces float32 costs, so this reproduces costs.options().
        const auto options = torch::TensorOptions()
                                 .dtype(torch::kFloat32)
                                 .device(ctx->saved_data["device"].toDevice());

        torch::Tensor grad_distance = grad_outputs[0];

        // Gradient from distance: d(distance)/d(costs) = 1 for valid positions,
        // so grad_costs[b, :] = grad_distance[b] on those positions.
        torch::Tensor grad_costs;
        if (grad_distance.defined() && grad_distance.numel() > 0) {
            // The posteriors kernel writes the per-position mask; broadcasting
            // grad_distance over it yields the cost gradient.
            torch::Tensor valid_mask = torch::zeros({B, L}, options);
            d2p::hamming::posteriors(valid_mask.data_ptr<float>(), lengths.data_ptr<int>(), B, L);
            grad_costs = valid_mask * grad_distance.unsqueeze(1);
        } else {
            grad_costs = torch::zeros({B, L}, options);
        }

        // Gradient from posteriors: posteriors don't depend on costs
        // (they're constant 1s), so no additional gradient.

        // Match the temperature's own shape/dtype/device; a fixed {1} tensor on
        // the costs' device is rejected by autograd when temperature is a CPU
        // tensor that requires grad.
        torch::Tensor grad_temperature = torch::zeros_like(temperature_like);

        return {grad_costs, grad_temperature, torch::Tensor()};
    }
};

// =============================================================================
// Python-Accessible Functions
// =============================================================================

/**
 * soft_hamming - tensor-parameter entry point routed through autograd.
 *
 * Forwards all three tensors to SoftHammingCUDAFunction::apply and returns
 * its two outputs, in order, as a vector.
 */
std::vector<torch::Tensor> soft_hamming_cuda(
    torch::Tensor costs,
    torch::Tensor temperature,
    torch::Tensor lengths
) {
    torch::autograd::tensor_list outputs =
        SoftHammingCUDAFunction::apply(costs, temperature, lengths);

    std::vector<torch::Tensor> packed;
    packed.reserve(2);
    packed.push_back(outputs[0]);
    packed.push_back(outputs[1]);
    return packed;
}

/**
 * soft_hamming_float - autograd-wrapped version with float parameters
 *
 * @param costs        Contiguous float32 CUDA tensor of shape (B, L).
 * @param temperature  Temperature as a plain number.
 * @param lengths_opt  Optional contiguous CUDA tensor of per-row lengths; when
 *                     absent, every row uses the full length L. Shape and
 *                     dtype are validated by SoftHammingCUDAFunction::forward.
 * @return The two outputs of SoftHammingCUDAFunction::apply, in order.
 */
std::vector<torch::Tensor> soft_hamming_float_cuda(
    torch::Tensor costs,
    double temperature,
    c10::optional<torch::Tensor> lengths_opt
) {
    CHECK_INPUT(costs);
    TORCH_CHECK(costs.dim() == 2, "costs must be 2D (B, L)");
    TORCH_CHECK(costs.dtype() == torch::kFloat32, "costs must be float32");

    int B = costs.size(0);
    int L = costs.size(1);

    auto options = costs.options();

    torch::Tensor lengths;
    if (lengths_opt.has_value()) {
        lengths = lengths_opt.value();
        CHECK_CUDA(lengths);
        CHECK_CONTIGUOUS(lengths);
    } else {
        lengths = make_default_lengths(B, L, options);
    }

    // Create temperature tensor and go through autograd. The autograd forward
    // reads the value back on the host via temperature.cpu(), so keep the
    // one-element tensor on the CPU rather than uploading it to the GPU only
    // to copy it straight back.
    torch::Tensor temp_tensor = torch::tensor(
        {static_cast<float>(temperature)},
        torch::TensorOptions().dtype(torch::kFloat32));
    auto result = SoftHammingCUDAFunction::apply(costs, temp_tensor, lengths);
    return {result[0], result[1]};
}

/**
 * soft_hamming_with_grads - returns distance, posteriors, and grad_T
 *
 * Calls d2p::hamming::backward directly, without autograd.
 *
 * @param costs        Contiguous float32 CUDA tensor of shape (B, L).
 * @param temperature  Temperature forwarded to the kernel launcher.
 * @param lengths_opt  Optional contiguous int32 CUDA tensor of shape (B,);
 *                     when absent, every row uses the full length L.
 * @return (distance of shape (B,), posteriors of shape (B, L),
 *          grad_T of shape (B,)).
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> soft_hamming_with_grads_cuda(
    torch::Tensor costs,
    double temperature,
    c10::optional<torch::Tensor> lengths_opt
) {
    CHECK_INPUT(costs);
    TORCH_CHECK(costs.dim() == 2, "costs must be 2D (B, L)");
    TORCH_CHECK(costs.dtype() == torch::kFloat32, "costs must be float32");

    int B = costs.size(0);
    int L = costs.size(1);

    auto options = costs.options();

    torch::Tensor lengths;
    if (lengths_opt.has_value()) {
        lengths = lengths_opt.value();
        CHECK_CUDA(lengths);
        CHECK_CONTIGUOUS(lengths);
        // The kernel receives a raw pointer, so a short or mistyped lengths
        // tensor must be rejected here rather than read out of bounds.
        TORCH_CHECK(lengths.dim() == 1 && lengths.size(0) == B,
                    "lengths must be 1D with shape (B,)");
        TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
    } else {
        lengths = make_default_lengths(B, L, options);
    }

    torch::Tensor distance = torch::zeros({B}, options);
    torch::Tensor posteriors = torch::zeros({B, L}, options);
    torch::Tensor grad_T = torch::zeros({B}, options);

    d2p::hamming::backward(
        costs.data_ptr<float>(),
        distance.data_ptr<float>(),
        posteriors.data_ptr<float>(),
        grad_T.data_ptr<float>(),
        lengths.data_ptr<int>(),
        B, L, temperature
    );

    return std::make_tuple(distance, posteriors, grad_T);
}

/**
 * soft_hamming_hvp - Hessian-vector product (always zero for Hamming)
 *
 * Allocates a zero-initialised (B, L) buffer and hands it to
 * d2p::hamming::hvp together with the inputs.
 *
 * @param costs        Contiguous float32 CUDA tensor of shape (B, L).
 * @param tangent      Contiguous float32 CUDA tensor with the same shape as costs.
 * @param temperature  Temperature forwarded to the kernel launcher.
 * @param lengths_opt  Optional contiguous int32 CUDA tensor of shape (B,);
 *                     when absent, every row uses the full length L.
 * @return Tensor of shape (B, L) holding the kernel's output.
 */
torch::Tensor soft_hamming_hvp_cuda(
    torch::Tensor costs,
    torch::Tensor tangent,
    double temperature,
    c10::optional<torch::Tensor> lengths_opt
) {
    CHECK_INPUT(costs);
    CHECK_INPUT(tangent);
    TORCH_CHECK(costs.dim() == 2, "costs must be 2D (B, L)");
    TORCH_CHECK(costs.dtype() == torch::kFloat32, "costs must be float32");
    TORCH_CHECK(tangent.dtype() == torch::kFloat32, "tangent must be float32");
    TORCH_CHECK(tangent.sizes() == costs.sizes(), "tangent must have same shape as costs");

    int B = costs.size(0);
    int L = costs.size(1);

    auto options = costs.options();

    torch::Tensor lengths;
    if (lengths_opt.has_value()) {
        lengths = lengths_opt.value();
        CHECK_CUDA(lengths);
        CHECK_CONTIGUOUS(lengths);
        // The kernel receives a raw pointer, so a short or mistyped lengths
        // tensor must be rejected here rather than read out of bounds.
        TORCH_CHECK(lengths.dim() == 1 && lengths.size(0) == B,
                    "lengths must be 1D with shape (B,)");
        TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
    } else {
        lengths = make_default_lengths(B, L, options);
    }

    torch::Tensor hvp = torch::zeros({B, L}, options);

    d2p::hamming::hvp(
        costs.data_ptr<float>(),
        tangent.data_ptr<float>(),
        hvp.data_ptr<float>(),
        lengths.data_ptr<int>(),
        B, L, temperature
    );

    return hvp;
}

/**
 * soft_hamming_backward_full - full backward pass
 *
 * Calls d2p::hamming::backward_full directly, without autograd.
 *
 * @param costs        Contiguous float32 CUDA tensor of shape (B, L).
 * @param grad_output  Contiguous float32 CUDA tensor of shape (B,).
 * @param temperature  Temperature forwarded to the kernel launcher.
 * @param lengths_opt  Optional contiguous int32 CUDA tensor of shape (B,);
 *                     when absent, every row uses the full length L.
 * @return (grad_costs of shape (B, L), grad_T of shape (B,)).
 */
std::tuple<torch::Tensor, torch::Tensor> soft_hamming_backward_full_cuda(
    torch::Tensor costs,
    torch::Tensor grad_output,
    double temperature,
    c10::optional<torch::Tensor> lengths_opt
) {
    CHECK_INPUT(costs);
    CHECK_INPUT(grad_output);
    TORCH_CHECK(costs.dim() == 2, "costs must be 2D (B, L)");
    TORCH_CHECK(costs.dtype() == torch::kFloat32, "costs must be float32");
    TORCH_CHECK(grad_output.dim() == 1, "grad_output must be 1D (B,)");
    TORCH_CHECK(grad_output.dtype() == torch::kFloat32, "grad_output must be float32");

    int B = costs.size(0);
    int L = costs.size(1);

    // The kernel indexes grad_output per batch row through a raw pointer, so a
    // length mismatch would otherwise become an out-of-bounds device read.
    TORCH_CHECK(grad_output.size(0) == B,
                "grad_output must have one entry per row of costs");

    auto options = costs.options();

    torch::Tensor lengths;
    if (lengths_opt.has_value()) {
        lengths = lengths_opt.value();
        CHECK_CUDA(lengths);
        CHECK_CONTIGUOUS(lengths);
        // Same reasoning as for grad_output: validate before taking data_ptr.
        TORCH_CHECK(lengths.dim() == 1 && lengths.size(0) == B,
                    "lengths must be 1D with shape (B,)");
        TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
    } else {
        lengths = make_default_lengths(B, L, options);
    }

    torch::Tensor grad_costs = torch::zeros({B, L}, options);
    torch::Tensor grad_T = torch::zeros({B}, options);

    d2p::hamming::backward_full(
        costs.data_ptr<float>(),
        grad_output.data_ptr<float>(),
        grad_costs.data_ptr<float>(),
        grad_T.data_ptr<float>(),
        lengths.data_ptr<int>(),
        B, L, temperature
    );

    return std::make_tuple(grad_costs, grad_T);
}

// =============================================================================
// TORCH_LIBRARY Registration
// =============================================================================

// The registrations below are compiled only when the build defines
// USE_TORCH_LIBRARY.
#ifdef USE_TORCH_LIBRARY

// Bind each entry point in this file to the operator of the same name in the
// `d2p` library for the CUDA dispatch key.
// NOTE(review): the operator schemas are not declared in this file; presumably
// a TORCH_LIBRARY(d2p, ...) block elsewhere defines them — confirm that the
// schemas match these C++ signatures.
TORCH_LIBRARY_IMPL(d2p, CUDA, m) {
    m.impl("soft_hamming", soft_hamming_cuda);
    m.impl("soft_hamming_float", soft_hamming_float_cuda);
    m.impl("soft_hamming_with_grads", soft_hamming_with_grads_cuda);
    m.impl("soft_hamming_hvp", soft_hamming_hvp_cuda);
    m.impl("soft_hamming_backward_full", soft_hamming_backward_full_cuda);
}

// The two entry points that go through SoftHammingCUDAFunction::apply are also
// bound for the AutogradCUDA dispatch key; the same functions are used for
// both keys.
TORCH_LIBRARY_IMPL(d2p, AutogradCUDA, m) {
    m.impl("soft_hamming", soft_hamming_cuda);
    m.impl("soft_hamming_float", soft_hamming_float_cuda);
}

#endif // USE_TORCH_LIBRARY
