/**
 * @file kernels.cuh
 * @brief Soft Hamming Distance CUDA Kernel Declarations
 *
 * Hamming distance is the simplest edit distance - it counts the number of
 * positions where the two sequences differ (for equal-length sequences).
 *
 * The "soft" version computes:
 * - distance = sum_i costs[i] (where costs[i] is 0 for a match and positive for a mismatch)
 * - posteriors[i] = 1 (the gradient of a sum w.r.t. each of its inputs is 1)
 *
 * Since no dynamic programming is involved, the computation is linear, O(n),
 * in the sequence length.
 * The gradient is trivial: dH/dcosts[i] = 1 for all i.
 * The Hessian is 0 (the distance is a linear function of the costs), so HVP = 0.
 */

#pragma once

namespace d2p {
namespace hamming {

// Entry points of the soft Hamming distance operator. All functions below are
// declared without CUDA execution-space qualifiers, i.e. they are ordinary
// host functions; only their prototypes are visible in this header.
// NOTE(review): the `d_` prefix presumably marks device pointers — confirm
// against the definitions.

/**
 * Forward pass: compute the Hamming distance of every sequence in the batch.
 *
 * Per the file header, the distance is the sum of the per-position costs.
 *
 * @param d_costs     [B, L] input costs (0 = match, positive = mismatch); read-only
 * @param d_distance  [B] output distances, one value per batch entry
 * @param d_lengths   [B] sequence lengths, or nullptr to treat every sequence
 *                    as full length
 * @param B           batch size
 * @param L           max sequence length
 * @param temperature temperature (unused for forward, kept for API consistency)
 *
 * NOTE(review): presumably only the first d_lengths[b] positions of batch
 * entry b contribute to its sum — confirm against the implementation.
 */
void forward(
    const float* d_costs,
    float* d_distance,
    const int* d_lengths,
    int B,
    int L,
    float temperature
);

/**
 * Compute posteriors (all ones for valid positions).
 *
 * This function takes no costs: the posteriors depend only on which
 * positions are valid, not on the cost values.
 *
 * @param d_posteriors [B, L] output posteriors
 * @param d_lengths    [B] sequence lengths, or nullptr to treat every sequence
 *                     as full length
 * @param B            batch size
 * @param L            max sequence length
 *
 * NOTE(review): the value written to positions at or beyond a sequence's
 * length is not specified in this header (presumably 0) — confirm.
 */
void posteriors(
    float* d_posteriors,
    const int* d_lengths,
    int B,
    int L
);

/**
 * Forward pass with posteriors: produces both the distances and the
 * posteriors through a single entry point.
 *
 * @param d_costs      [B, L] input costs; read-only
 * @param d_distance   [B] output distances
 * @param d_posteriors [B, L] output posteriors
 * @param d_lengths    [B] sequence lengths, or nullptr to treat every sequence
 *                     as full length
 * @param B            batch size
 * @param L            max sequence length
 * @param temperature  temperature (unused for this operator)
 */
void forward_with_posteriors(
    const float* d_costs,
    float* d_distance,
    float* d_posteriors,
    const int* d_lengths,
    int B,
    int L,
    float temperature
);

/**
 * Backward pass with parameter gradients.
 *
 * Since Hamming distance is just a sum, the gradients are trivial:
 * - d_posteriors = 1 for valid positions
 * - d_grad_T = 0 (the distance doesn't depend on temperature)
 *
 * Unlike backward_full(), this entry point takes no upstream gradient; it
 * also outputs the distances alongside the posteriors.
 *
 * @param d_costs      [B, L] input costs; read-only
 * @param d_distance   [B] output distances
 * @param d_posteriors [B, L] output posteriors
 * @param d_grad_T     [B] output temperature gradient (zeros)
 * @param d_lengths    [B] sequence lengths, or nullptr to treat every sequence
 *                     as full length
 * @param B            batch size
 * @param L            max sequence length
 * @param temperature  temperature
 */
void backward(
    const float* d_costs,
    float* d_distance,
    float* d_posteriors,
    float* d_grad_T,
    const int* d_lengths,
    int B,
    int L,
    float temperature
);

/**
 * HVP (Hessian-vector product).
 * For a linear function, the Hessian is 0, so HVP is 0.
 *
 * @param d_costs     [B, L] input costs; read-only
 * @param d_tangent   [B, L] input tangent vector; read-only
 * @param d_hvp       [B, L] output HVP (zeros)
 * @param d_lengths   [B] sequence lengths, or nullptr to treat every sequence
 *                    as full length
 * @param B           batch size
 * @param L           max sequence length
 * @param temperature temperature
 *
 * NOTE(review): since the documented result is all zeros, d_costs, d_tangent
 * and temperature presumably do not influence the output — confirm.
 */
void hvp(
    const float* d_costs,
    const float* d_tangent,
    float* d_hvp,
    const int* d_lengths,
    int B,
    int L,
    float temperature
);

/**
 * Full backward pass: propagate an upstream gradient on the distances back
 * to the costs and the temperature.
 *
 * @param d_costs       [B, L] input costs; read-only
 * @param d_grad_output [B] gradient w.r.t. output distance; read-only
 * @param d_grad_costs  [B, L] output gradient w.r.t. costs
 * @param d_grad_T      [B] output gradient w.r.t. temperature
 * @param d_lengths     [B] sequence lengths, or nullptr to treat every sequence
 *                      as full length
 * @param B             batch size
 * @param L             max sequence length
 * @param temperature   temperature
 *
 * NOTE(review): given the unit gradient described in the file header,
 * presumably d_grad_costs[b, i] = d_grad_output[b] for valid positions and
 * d_grad_T is all zeros — confirm against the implementation.
 */
void backward_full(
    const float* d_costs,
    const float* d_grad_output,
    float* d_grad_costs,
    float* d_grad_T,
    const int* d_lengths,
    int B,
    int L,
    float temperature
);

} // namespace hamming
} // namespace d2p
