/**
 * @file kernels_cpu.h
 * @brief Soft OSA CPU Kernel Declarations
 *
 * CPU implementation declarations mirroring CUDA interface.
 */

#pragma once

namespace d2p {
namespace osa {
namespace cpu {

// ============================================================================
// Constants
// ============================================================================

// Large finite sentinel (1e30f) used in place of +infinity. It is NOT an IEEE
// infinity: it is below the float maximum, so arithmetic on it stays finite.
// NOTE(review): presumably the "unreachable"/initial cost in the minimization
// recurrences — confirm against the kernel implementation.
constexpr float PINF = 1e30f;   // Stand-in for positive infinity in minimization

// ============================================================================
// CPU Kernel Function Declarations
// ============================================================================

/**
 * @brief Declaration of the CPU forward pass for the soft OSA kernel.
 *
 * Only the signature is visible in this header; buffer sizes, memory layout
 * and indexing conventions are defined by the implementation. Notes marked
 * "presumably" are assumptions derived from parameter names and should be
 * confirmed against the definition.
 *
 * @param sub_costs  Pointer to const float; presumably the substitution costs.
 * @param trans_mask Pointer to const float; presumably flags positions where a
 *                   transposition is permitted.
 * @param alpha      Non-const float pointer; presumably receives the forward
 *                   table computed by this pass.
 * @param osa_score  Non-const float pointer; presumably receives the resulting
 *                   score(s).
 * @param lengths    Pointer to const int; presumably the actual sequence
 *                   lengths per batch item — TODO confirm layout.
 * @param ins_cost   Scalar float; presumably the insertion cost.
 * @param del_cost   Scalar float; presumably the deletion cost.
 * @param trans_cost Scalar float; presumably the transposition cost.
 * @param B          Integer; presumably the batch size.
 * @param max_L1     Integer; presumably the padded maximum first-sequence length.
 * @param max_L2     Integer; presumably the padded maximum second-sequence length.
 * @param T          Scalar float; presumably the smoothing temperature — TODO confirm.
 */
void osa_forward_cpu(
    const float* sub_costs,
    const float* trans_mask,
    float* alpha,
    float* osa_score,
    const int* lengths,
    float ins_cost, float del_cost, float trans_cost,
    int B, int max_L1, int max_L2,
    float T
);

/**
 * @brief Declaration of the CPU backward pass for the soft OSA kernel.
 *
 * Only the signature is visible in this header; buffer sizes, memory layout
 * and whether gradient buffers are overwritten or accumulated into are defined
 * by the implementation. Notes marked "presumably" are assumptions derived
 * from parameter names and should be confirmed against the definition.
 *
 * Here alpha and osa_score are taken as pointers to const, whereas
 * osa_forward_cpu takes them as non-const pointers.
 *
 * @param alpha      Pointer to const float; presumably the forward table
 *                   produced by osa_forward_cpu.
 * @param sub_costs  Pointer to const float; presumably the substitution costs.
 * @param trans_mask Pointer to const float; presumably flags positions where a
 *                   transposition is permitted.
 * @param osa_score  Pointer to const float; presumably the forward score(s).
 * @param beta       Non-const float pointer; presumably receives the backward
 *                   table.
 * @param posteriors Non-const float pointer; presumably receives posterior
 *                   (edge/operation) weights — TODO confirm shape.
 * @param grad_T     Non-const float pointer; presumably gradient w.r.t. T.
 * @param grad_ins   Non-const float pointer; presumably gradient w.r.t. ins_cost.
 * @param grad_del   Non-const float pointer; presumably gradient w.r.t. del_cost.
 * @param grad_trans Non-const float pointer; presumably gradient w.r.t. trans_cost.
 * @param lengths    Pointer to const int; presumably the actual sequence
 *                   lengths per batch item — TODO confirm layout.
 * @param ins_cost   Scalar float; presumably the insertion cost.
 * @param del_cost   Scalar float; presumably the deletion cost.
 * @param trans_cost Scalar float; presumably the transposition cost.
 * @param B          Integer; presumably the batch size.
 * @param max_L1     Integer; presumably the padded maximum first-sequence length.
 * @param max_L2     Integer; presumably the padded maximum second-sequence length.
 * @param T          Scalar float; presumably the smoothing temperature — TODO confirm.
 */
void osa_backward_cpu(
    const float* alpha,
    const float* sub_costs,
    const float* trans_mask,
    const float* osa_score,
    float* beta,
    float* posteriors,
    float* grad_T,
    float* grad_ins,
    float* grad_del,
    float* grad_trans,
    const int* lengths,
    float ins_cost, float del_cost, float trans_cost,
    int B, int max_L1, int max_L2,
    float T
);

/**
 * @brief Declaration of the CPU "hvp" routine for the soft OSA kernel.
 *
 * NOTE(review): "hvp" presumably stands for Hessian-vector product — confirm
 * against the implementation. Only the signature is visible in this header;
 * buffer sizes and layouts are defined by the implementation. Notes marked
 * "presumably" are assumptions derived from parameter names.
 *
 * @param alpha      Pointer to const float; presumably the forward table
 *                   produced by osa_forward_cpu.
 * @param sub_costs  Pointer to const float; presumably the substitution costs.
 * @param trans_mask Pointer to const float; presumably flags positions where a
 *                   transposition is permitted.
 * @param osa_score  Pointer to const float; presumably the forward score(s).
 * @param V          Pointer to const float; presumably the direction vector
 *                   the product is taken with — TODO confirm shape.
 * @param d_alpha    Non-const float pointer; presumably the directional
 *                   derivative of alpha.
 * @param d_score    Non-const float pointer; presumably the directional
 *                   derivative of the score.
 * @param beta       Non-const float pointer; presumably the backward table.
 * @param d_beta     Non-const float pointer; presumably the directional
 *                   derivative of beta.
 * @param H_scores   Non-const float pointer; presumably receives the resulting
 *                   product — TODO confirm shape.
 * @param lengths    Pointer to const int; presumably the actual sequence
 *                   lengths per batch item — TODO confirm layout.
 * @param ins_cost   Scalar float; presumably the insertion cost.
 * @param del_cost   Scalar float; presumably the deletion cost.
 * @param trans_cost Scalar float; presumably the transposition cost.
 * @param B          Integer; presumably the batch size.
 * @param max_L1     Integer; presumably the padded maximum first-sequence length.
 * @param max_L2     Integer; presumably the padded maximum second-sequence length.
 * @param T          Scalar float; presumably the smoothing temperature — TODO confirm.
 */
void osa_hvp_cpu(
    const float* alpha,
    const float* sub_costs,
    const float* trans_mask,
    const float* osa_score,
    const float* V,
    float* d_alpha,
    float* d_score,
    float* beta,
    float* d_beta,
    float* H_scores,
    const int* lengths,
    float ins_cost, float del_cost, float trans_cost,
    int B, int max_L1, int max_L2,
    float T
);

}  // namespace cpu
}  // namespace osa
}  // namespace d2p
