/**
 * @file kernels_cpu.h
 * @brief Soft Hamming Distance CPU Kernel Declarations
 *
 * Hamming distance is the simplest edit distance - it counts the number of
 * positions where the two sequences differ (for equal-length sequences).
 *
 * The "soft" version computes:
 * - distance = sum_i costs[i] (where costs = 0 for match, positive for mismatch)
 * - posteriors = 1 (gradient of sum w.r.t. each input is 1)
 */

#pragma once

namespace d2p {
namespace hamming {
namespace cpu {

/**
 * Forward pass: compute the soft Hamming distance for each batch element.
 *
 * Per the file overview, the distance is the sum of the per-position costs,
 * i.e. distance[b] = sum_i costs[b, i].
 *
 * NOTE(review): the [B, L] buffer is presumably contiguous and row-major, and
 * when lengths is non-null presumably only the first lengths[b] positions of
 * row b contribute to the sum -- confirm against the implementation.
 *
 * @param costs     [B, L] input costs (0 = match, positive for mismatch)
 * @param distance  [B] output distances (written by this call)
 * @param lengths   [B] sequence lengths (or nullptr for full length)
 * @param B         batch size
 * @param L         max sequence length
 * @param temperature temperature (unused for forward, kept for API consistency)
 */
void forward(
    const float* costs,
    float* distance,
    const int* lengths,
    int B,
    int L,
    float temperature
);

/**
 * Compute posteriors (all ones for valid positions).
 *
 * Per the file overview, the posterior is the gradient of the summed distance
 * with respect to each input cost, which is 1. No cost buffer is needed, so
 * this function takes none.
 *
 * NOTE(review): the value written to positions at or beyond lengths[b] is not
 * stated in this header (presumably 0) -- confirm against the implementation.
 *
 * @param posteriors [B, L] output posteriors (written by this call)
 * @param lengths    [B] sequence lengths (or nullptr)
 * @param B          batch size
 * @param L          max sequence length
 */
void posteriors(
    float* posteriors,
    const int* lengths,
    int B,
    int L
);

/**
 * Forward pass with posteriors: produces both the per-batch distances and the
 * per-position posteriors in a single call.
 *
 * NOTE(review): presumably equivalent to calling forward() followed by
 * posteriors() with the same arguments -- confirm against the implementation.
 *
 * @param costs      [B, L] input costs (read-only)
 * @param distance   [B] output distances (written by this call)
 * @param posteriors [B, L] output posteriors (written by this call)
 * @param lengths    [B] sequence lengths (or nullptr)
 * @param B          batch size
 * @param L          max sequence length
 * @param temperature temperature (unused)
 */
void forward_with_posteriors(
    const float* costs,
    float* distance,
    float* posteriors,
    const int* lengths,
    int B,
    int L,
    float temperature
);

/**
 * Backward pass with parameter gradients.
 *
 * Despite the name, distance and posteriors are documented as outputs here,
 * alongside the temperature gradient grad_T; only costs and lengths are
 * read-only inputs. grad_T is documented as all zeros.
 *
 * NOTE(review): presumably distance and posteriors are recomputed exactly as
 * in forward_with_posteriors() -- confirm against the implementation.
 *
 * @param costs      [B, L] input costs (read-only)
 * @param distance   [B] output distances
 * @param posteriors [B, L] output posteriors
 * @param grad_T     [B] output temperature gradient (zeros)
 * @param lengths    [B] sequence lengths (or nullptr)
 * @param B          batch size
 * @param L          max sequence length
 * @param temperature temperature
 */
void backward(
    const float* costs,
    float* distance,
    float* posteriors,
    float* grad_T,
    const int* lengths,
    int B,
    int L,
    float temperature
);

/**
 * HVP (Hessian-vector product).
 * For a linear function, the Hessian is 0, so HVP is 0.
 *
 * The distance is a plain sum of the costs (see the file overview), hence
 * linear in the costs; the output buffer is documented as all zeros.
 *
 * NOTE(review): costs, tangent, lengths and temperature presumably do not
 * influence the result and are kept for API consistency; whether padded
 * positions are also zeroed is not stated here -- confirm against the
 * implementation.
 *
 * @param costs    [B, L] input costs
 * @param tangent  [B, L] input tangent vector
 * @param hvp      [B, L] output HVP (zeros)
 * @param lengths  [B] sequence lengths (or nullptr)
 * @param B        batch size
 * @param L        max sequence length
 * @param temperature temperature
 */
void hvp(
    const float* costs,
    const float* tangent,
    float* hvp,
    const int* lengths,
    int B,
    int L,
    float temperature
);

/**
 * Full backward pass: given the upstream gradient of the distance, produce the
 * gradients with respect to the costs and the temperature.
 *
 * NOTE(review): since the posterior of every valid position is 1 (see the file
 * overview), grad_costs[b, i] is presumably grad_output[b] at valid positions,
 * and grad_T is presumably all zeros as in backward(). The value written to
 * padded positions is not stated here. Confirm all three against the
 * implementation.
 *
 * @param costs       [B, L] input costs
 * @param grad_output [B] gradient w.r.t. output distance (read-only)
 * @param grad_costs  [B, L] output gradient w.r.t. costs
 * @param grad_T      [B] output gradient w.r.t. temperature
 * @param lengths     [B] sequence lengths (or nullptr)
 * @param B           batch size
 * @param L           max sequence length
 * @param temperature temperature
 */
void backward_full(
    const float* costs,
    const float* grad_output,
    float* grad_costs,
    float* grad_T,
    const int* lengths,
    int B,
    int L,
    float temperature
);

} // namespace cpu
} // namespace hamming
} // namespace d2p
