/**
 * @file kernels_cpu.cpp
 * @brief Soft Hamming Distance CPU Kernel Implementations
 *
 * Hamming distance is the simplest edit distance - it counts the number of
 * positions where the two sequences differ (for equal-length sequences).
 *
 * The "soft" version computes:
 * - distance = sum_i costs[i] (where costs = 0 for match, positive for mismatch)
 * - posteriors = 1 (gradient of sum w.r.t. each input is 1)
 */

#include "kernels_cpu.h"

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstddef>

namespace d2p {
namespace hamming {
namespace cpu {

// =============================================================================
// Kahan Summation for Numerical Stability
// =============================================================================

/**
 * @brief Single-precision accumulator using Kahan compensated summation.
 *
 * Alongside the running total it keeps a correction term holding the
 * rounding error of the previous addition, and feeds that error back into
 * the next value that is added.
 */
struct KahanSum {
    float sum = 0.0f;  // running total
    float c = 0.0f;    // rounding error carried over from the previous add()

    /** @brief Adds @p x to the total, correcting for the error carried in c. */
    void add(float x) {
        const float corrected = x - c;           // re-inject what was lost last time
        const float updated = sum + corrected;   // low-order bits of corrected may be dropped here
        c = (updated - sum) - corrected;         // measure exactly what was dropped
        sum = updated;
    }

    /** @brief Returns the current running total. */
    float get() const {
        return sum;
    }
};

// =============================================================================
// CPU Kernel Implementations
// =============================================================================

/**
 * @brief Sums the per-position costs of every batch row into one distance.
 *
 * For each row b, writes distance[b] = sum of costs[b * L + i] over the valid
 * positions of that row. The sum is accumulated with KahanSum.
 *
 * @param costs       Input costs, read as a row-major B x L array.
 * @param distance    Output array; one value is written per batch row.
 * @param lengths     Optional per-row valid length. When null, every row
 *                    uses the full length L.
 * @param B           Number of batch rows.
 * @param L           Row stride and maximum number of positions per row.
 * @param temperature Not used by this kernel.
 */
void forward(
    const float* costs,
    float* distance,
    const int* lengths,
    int B,
    int L,
    float temperature
) {
    (void)temperature;  // the plain sum does not depend on it

    for (int b = 0; b < B; b++) {
        const int requested = lengths ? lengths[b] : L;

        // Clamp to the row width: a length larger than L would otherwise read
        // into the next row (or past the buffer on the last row). This matches
        // posteriors() and backward_full(), which only visit i < L.
        // A negative length simply yields an empty sum.
        const int actual_L = std::min(requested, L);

        // Widen before multiplying so b * L cannot overflow int.
        const std::ptrdiff_t row_start = static_cast<std::ptrdiff_t>(b) * L;

        KahanSum sum;
        for (int i = 0; i < actual_L; i++) {
            sum.add(costs[row_start + i]);
        }

        distance[b] = sum.get();
    }
}

/**
 * @brief Fills the posterior array with a 0/1 validity mask.
 *
 * For each row b, position i receives 1.0f when i is below the row's valid
 * length and 0.0f otherwise. All L positions of every row are written.
 *
 * @param posteriors Output, written as a row-major B x L array.
 * @param lengths    Optional per-row valid length. When null, every row
 *                   uses the full length L.
 * @param B          Number of batch rows.
 * @param L          Row stride and number of positions written per row.
 */
void posteriors(
    float* posteriors,
    const int* lengths,
    int B,
    int L
) {
    for (int b = 0; b < B; b++) {
        const int actual_L = lengths ? lengths[b] : L;

        // Widen before multiplying so b * L cannot overflow int.
        const std::ptrdiff_t row_start = static_cast<std::ptrdiff_t>(b) * L;

        for (int i = 0; i < L; i++) {
            posteriors[row_start + i] = (i < actual_L) ? 1.0f : 0.0f;
        }
    }
}

/**
 * @brief Convenience wrapper: computes distances and the posterior mask.
 *
 * Calls forward() to fill @p distance, then posteriors() to fill
 * @p posteriors_out, forwarding the shared arguments unchanged.
 *
 * @param costs          Input costs, passed to forward().
 * @param distance       Output of forward().
 * @param posteriors_out Output of posteriors().
 * @param lengths        Optional per-row lengths, passed to both calls.
 * @param B              Number of batch rows.
 * @param L              Row stride.
 * @param temperature    Passed through to forward().
 */
void forward_with_posteriors(
    const float* costs,
    float* distance,
    float* posteriors_out,
    const int* lengths,
    int B,
    int L,
    float temperature
) {
    // Step 1: per-row distances.
    forward(costs, distance, lengths, B, L, temperature);

    // Step 2: per-position validity mask.
    posteriors(posteriors_out, lengths, B, L);
}

/**
 * @brief Computes distances and posteriors, and zeroes the temperature gradient.
 *
 * Delegates to forward_with_posteriors() for @p distance and
 * @p posteriors_out, then writes 0.0f into each of the B entries of @p grad_T.
 *
 * @param costs          Input costs, passed to forward_with_posteriors().
 * @param distance       Output distances, one per batch row.
 * @param posteriors_out Output posterior mask.
 * @param grad_T         Output temperature gradient; B zeros are written.
 * @param lengths        Optional per-row lengths.
 * @param B              Number of batch rows.
 * @param L              Row stride.
 * @param temperature    Passed through to forward_with_posteriors().
 */
void backward(
    const float* costs,
    float* distance,
    float* posteriors_out,
    float* grad_T,
    const int* lengths,
    int B,
    int L,
    float temperature
) {
    forward_with_posteriors(costs, distance, posteriors_out, lengths, B, L, temperature);

    // The temperature gradient is identically zero for every row.
    int row = 0;
    while (row < B) {
        grad_T[row] = 0.0f;
        ++row;
    }
}

/**
 * @brief Writes the Hessian-vector product, which is zero everywhere.
 *
 * For a linear function the Hessian is 0, so all B * L entries of
 * @p hvp_out are set to 0.0f. Nothing is written when B or L is not positive.
 *
 * @param costs       Not used by this kernel.
 * @param tangent     Not used by this kernel.
 * @param hvp_out     Output; B * L zeros are written.
 * @param lengths     Not used by this kernel.
 * @param B           Number of batch rows.
 * @param L           Row stride.
 * @param temperature Not used by this kernel.
 */
void hvp(
    const float* costs,
    const float* tangent,
    float* hvp_out,
    const int* lengths,
    int B,
    int L,
    float temperature
) {
    (void)costs;
    (void)tangent;
    (void)lengths;
    (void)temperature;

    // Reject empty or invalid extents up front; two negative dimensions must
    // not multiply into a positive element count.
    if (B <= 0 || L <= 0) {
        return;
    }

    // Compute the element count in size_t so B * L cannot overflow int.
    const std::size_t total =
        static_cast<std::size_t>(B) * static_cast<std::size_t>(L);
    std::fill_n(hvp_out, total, 0.0f);
}

/**
 * @brief Propagates an upstream gradient back to the costs.
 *
 * For each row b, every valid position receives grad_output[b] and every
 * remaining position up to L receives 0.0f. grad_T[b] is set to 0.0f.
 *
 * @param costs       Not used by this kernel.
 * @param grad_output Upstream gradient, one value per batch row.
 * @param grad_costs  Output, written as a row-major B x L array.
 * @param grad_T      Output temperature gradient; B zeros are written.
 * @param lengths     Optional per-row valid length. When null, every row
 *                    uses the full length L.
 * @param B           Number of batch rows.
 * @param L           Row stride and number of positions written per row.
 * @param temperature Not used by this kernel.
 */
void backward_full(
    const float* costs,
    const float* grad_output,
    float* grad_costs,
    float* grad_T,
    const int* lengths,
    int B,
    int L,
    float temperature
) {
    (void)costs;
    (void)temperature;

    for (int b = 0; b < B; b++) {
        const int actual_L = lengths ? lengths[b] : L;

        // Widen before multiplying so b * L cannot overflow int.
        const std::ptrdiff_t row_start = static_cast<std::ptrdiff_t>(b) * L;

        // Gradient w.r.t. costs is grad_output[b] for valid positions.
        for (int i = 0; i < L; i++) {
            grad_costs[row_start + i] = (i < actual_L) ? grad_output[b] : 0.0f;
        }
        grad_T[b] = 0.0f;
    }
}

} // namespace cpu
} // namespace hamming
} // namespace d2p
