/**
 * @file kernels_cpu.cpp
 * @brief Soft OSA CPU Kernel Implementations
 *
 * CPU implementation of Soft OSA (Optimal String Alignment).
 * Mirrors the CUDA kernel interface exactly for seamless dispatch.
 */

#include <cmath>
#include <algorithm>
#include "kernels_cpu.h"

namespace d2p {
namespace osa {
namespace cpu {

// ============================================================================
// Helper Functions
// ============================================================================

// Range-guarded float exponential.
// Arguments below -88 yield exactly 0; arguments above 88 are clamped to 88
// before exponentiating (exp(88) ~ 1.65e38 is still below FLT_MAX).
inline float safe_exp(float x) {
    constexpr float kUnderflowAt = -88.0f;
    constexpr float kClampAt = 88.0f;
    return (x < kUnderflowAt) ? 0.0f : std::exp(std::min(x, kClampAt));
}

// Running float sum with Kahan compensation.
// `c` holds the (negated) low-order part lost by the previous addition and is
// folded back into the next term.
// NOTE: value-changing FP optimisations (e.g. -ffast-math) may cancel the
// compensation algebraically.
struct KahanAccumulator {
    float sum = 0.0f;  // running total
    float c = 0.0f;    // running compensation term

    // Current compensated total.
    float result() const { return sum; }

    // Adds `value` to the running total, carrying rounding error forward.
    void add(float value) {
        const float corrected = value - c;
        const float next_total = sum + corrected;
        c = (next_total - sum) - corrected;
        sum = next_total;
    }
};

// Softmin for 4 values
inline float softmin4(float a, float b, float c, float d, float T) {
    float m = std::min({a, b, c, d});
    if (m >= PINF) return PINF;

    KahanAccumulator sum;
    if (a < PINF) sum.add(safe_exp(-(a - m) / T));
    if (b < PINF) sum.add(safe_exp(-(b - m) / T));
    if (c < PINF) sum.add(safe_exp(-(c - m) / T));
    if (d < PINF) sum.add(safe_exp(-(d - m) / T));

    float s = sum.result();
    if (s <= 0.0f) return PINF;
    return m - T * std::log(s);
}

// Normalised softmin weights of four candidates at temperature T:
//   w_k = exp(-(v_k - m) / T) / sum_l exp(-(v_l - m) / T),   m = min_k v_k.
// Candidates >= PINF get weight 0. All four weights are 0 when the smallest
// candidate is >= PINF or when the normaliser is not positive.
// Outputs are written in the order wa, wb, wc, wd.
inline void softmin4_weights(float a, float b, float c, float d, float T,
                             float& wa, float& wb, float& wc, float& wd) {
    const float vals[4] = {a, b, c, d};
    float* const outs[4] = {&wa, &wb, &wc, &wd};

    float lowest = vals[0];
    for (int k = 1; k < 4; ++k) {
        if (vals[k] < lowest) {
            lowest = vals[k];
        }
    }

    if (lowest >= PINF) {
        for (float* w : outs) {
            *w = 0.0f;
        }
        return;
    }

    float expo[4];
    float norm = 0.0f;
    for (int k = 0; k < 4; ++k) {
        expo[k] = (vals[k] < PINF) ? safe_exp(-(vals[k] - lowest) / T) : 0.0f;
        norm += expo[k];
    }

    const bool normalisable = norm > 0.0f;
    for (int k = 0; k < 4; ++k) {
        *outs[k] = normalisable ? expo[k] / norm : 0.0f;
    }
}

// ============================================================================
// FORWARD PASS
// ============================================================================

/**
 * @brief Forward pass of Soft OSA (Optimal String Alignment).
 *
 * For every batch element b, with (L1, L2) = (lengths[2b], lengths[2b+1]),
 * the whole (max_L1+1) x (max_L2+1) slice of `alpha` is first set to PINF,
 * then the (L1+1) x (L2+1) prefix is filled with
 *
 *   alpha(0, 0) = 0,  alpha(i, 0) = i * del_cost,  alpha(0, j) = j * ins_cost,
 *   alpha(i, j) = softmin_T( alpha(i-1, j-1) + sub_costs(i-1, j-1),
 *                            alpha(i-1, j)   + del_cost,
 *                            alpha(i,   j-1) + ins_cost,
 *                            alpha(i-2, j-2) + trans_cost )
 *
 * where the transposition candidate is present only when i >= 2, j >= 2 and
 * trans_mask(i-1, j-1) > 0.5 (otherwise it is PINF). The score is
 * osa_score[b] = alpha(L1, L2).
 *
 * Cells are filled row by row. Every dependency of (i, j) lies in an earlier
 * row or earlier in the same row, so this produces exactly the same values as
 * an anti-diagonal (wavefront) sweep.
 */
void osa_forward_cpu(
    const float* sub_costs,
    const float* trans_mask,
    float* alpha,
    float* osa_score,
    const int* lengths,
    float ins_cost, float del_cost, float trans_cost,
    int B, int max_L1, int max_L2,
    float T
) {
    const int cols = max_L2 + 1;
    const size_t table_size = static_cast<size_t>(max_L1 + 1) * (max_L2 + 1);
    const size_t cost_size = static_cast<size_t>(max_L1) * max_L2;

    for (int batch = 0; batch < B; ++batch) {
        const float* sub = sub_costs + batch * cost_size;
        const float* mask = trans_mask + batch * cost_size;
        float* table = alpha + batch * table_size;
        const int n1 = lengths[2 * batch];
        const int n2 = lengths[2 * batch + 1];

        // Everything outside the live prefix stays at PINF.
        std::fill_n(table, table_size, PINF);

        // Boundary: pure deletions down column 0, pure insertions along row 0.
        table[0] = 0.0f;
        for (int r = 1; r <= n1; ++r) {
            table[r * cols] = r * del_cost;
        }
        for (int c = 1; c <= n2; ++c) {
            table[c] = c * ins_cost;
        }

        for (int r = 1; r <= n1; ++r) {
            float* row = table + r * cols;
            const float* prev = row - cols;
            for (int c = 1; c <= n2; ++c) {
                const int cost_idx = (r - 1) * max_L2 + (c - 1);
                const bool can_transpose = r >= 2 && c >= 2 && mask[cost_idx] > 0.5f;
                const float via_trans =
                    can_transpose ? table[(r - 2) * cols + (c - 2)] + trans_cost : PINF;

                row[c] = softmin4(prev[c - 1] + sub[cost_idx],
                                  prev[c] + del_cost,
                                  row[c - 1] + ins_cost,
                                  via_trans, T);
            }
        }

        osa_score[batch] = table[n1 * cols + n2];
    }
}

// ============================================================================
// BACKWARD PASS
// ============================================================================

/**
 * @brief Backward pass of Soft OSA: adjoint (beta) recursion and gradients.
 *
 * For each batch element b, interior cells of the DP table are visited in
 * reverse anti-diagonal order (k = i + j descending), so a cell's beta is
 * complete before it is propagated. At each cell the softmin weights are
 * recomputed from the stored forward table `alpha`, and
 * beta = d(score)/d(alpha) is pushed to the predecessor cells.
 *
 * Per batch element it writes:
 *  - beta       : adjoint table, same layout as alpha ((max_L1+1) x (max_L2+1)).
 *  - posteriors : d(score)/d(sub_costs), layout max_L1 x max_L2.
 *  - grad_T[b]  : d(score)/dT.
 *  - grad_ins[b], grad_del[b], grad_trans[b] : d(score)/d(cost). The insertion
 *    and deletion gradients include the boundary row/column, whose values
 *    alpha(0, j) = j * ins_cost and alpha(i, 0) = i * del_cost also depend on
 *    those costs.
 *
 * NOTE(review): osa_score is not read here.
 */
void osa_backward_cpu(
    const float* alpha,
    const float* sub_costs,
    const float* trans_mask,
    const float* osa_score,
    float* beta,
    float* posteriors,
    float* grad_T,
    float* grad_ins,
    float* grad_del,
    float* grad_trans,
    const int* lengths,
    float ins_cost, float del_cost, float trans_cost,
    int B, int max_L1, int max_L2,
    float T
) {
    const size_t alpha_stride = static_cast<size_t>(max_L1 + 1) * (max_L2 + 1);
    const size_t score_stride = static_cast<size_t>(max_L1) * max_L2;
    const int alpha_cols = max_L2 + 1;

    for (int b = 0; b < B; b++) {
        const float* a = alpha + b * alpha_stride;
        const float* s = sub_costs + b * score_stride;
        const float* tm = trans_mask + b * score_stride;
        float* be = beta + b * alpha_stride;
        float* post = posteriors + b * score_stride;
        const int L1 = lengths[b * 2];
        const int L2 = lengths[b * 2 + 1];

        std::fill_n(post, score_stride, 0.0f);
        std::fill_n(be, alpha_stride, 0.0f);

        // Seed: d(score)/d(alpha(L1, L2)) = 1.
        const int final_idx = L1 * alpha_cols + L2;
        be[final_idx] = 1.0f;

        float sum_T_grad = 0.0f;
        float sum_ins_grad = 0.0f;
        float sum_del_grad = 0.0f;
        float sum_trans_grad = 0.0f;

        // Backward DP over interior cells (i, j >= 1).
        for (int k = L1 + L2; k >= 2; k--) {
            const int i_start = std::max(1, k - L2);
            const int i_end = std::min(L1, k - 1);

            for (int i = i_start; i <= i_end; i++) {
                const int j = k - i;
                if (j < 1 || j > L2) continue;

                const int idx = i * alpha_cols + j;
                const int idx_diag = (i - 1) * alpha_cols + (j - 1);
                const int idx_up = (i - 1) * alpha_cols + j;
                const int idx_left = i * alpha_cols + (j - 1);
                const int score_idx = (i - 1) * max_L2 + (j - 1);

                const float beta_ij = be[idx];
                // No adjoint mass reaches this cell, so it contributes nothing.
                if (beta_ij == 0.0f) continue;

                // Recompute the four candidates exactly as the forward pass does.
                const float trans_valid = (i >= 2 && j >= 2) ? tm[score_idx] : 0.0f;
                const float v_sub = a[idx_diag] + s[score_idx];
                const float v_del = a[idx_up] + del_cost;
                const float v_ins = a[idx_left] + ins_cost;
                float v_trans = PINF;
                int idx_trans = -1;
                if (trans_valid > 0.5f) {
                    idx_trans = (i - 2) * alpha_cols + (j - 2);
                    v_trans = a[idx_trans] + trans_cost;
                }

                float w_sub, w_del, w_ins, w_trans;
                softmin4_weights(v_sub, v_del, v_ins, v_trans, T, w_sub, w_del, w_ins, w_trans);

                // sub_costs(i-1, j-1) only enters this cell's substitution candidate.
                post[score_idx] = beta_ij * w_sub;

                // Temperature gradient. For f = softmin_T(v) = -T log sum_k exp(-v_k / T):
                //   df/dT = (f - sum_k w_k v_k) / T,
                // which is <= 0 because the softmin never exceeds the weighted mean.
                const float alpha_ij = a[idx];
                if (alpha_ij < PINF) {
                    // Skip zero-weight terms so 0 * PINF (NaN if PINF is +inf)
                    // can never enter the expectation.
                    float E_v = 0.0f;
                    if (w_sub > 0.0f) E_v += w_sub * v_sub;
                    if (w_del > 0.0f) E_v += w_del * v_del;
                    if (w_ins > 0.0f) E_v += w_ins * v_ins;
                    if (w_trans > 0.0f) E_v += w_trans * v_trans;
                    sum_T_grad += beta_ij * (alpha_ij - E_v) / T;
                }

                // Direct dependence of this cell's candidates on the edit costs.
                sum_ins_grad += beta_ij * w_ins;
                sum_del_grad += beta_ij * w_del;
                sum_trans_grad += beta_ij * w_trans;

                // Propagate beta to the predecessors.
                if (w_sub > 0.0f) {
                    be[idx_diag] += beta_ij * w_sub;
                }
                if (w_del > 0.0f) {
                    be[idx_up] += beta_ij * w_del;
                }
                if (w_ins > 0.0f) {
                    be[idx_left] += beta_ij * w_ins;
                }
                if (w_trans > 0.0f && idx_trans >= 0) {
                    be[idx_trans] += beta_ij * w_trans;
                }
            }
        }

        // The boundary cells are never visited above, but they depend on the
        // costs: alpha(0, j) = j * ins_cost and alpha(i, 0) = i * del_cost.
        // Their accumulated beta therefore contributes j (resp. i) per unit.
        // This also covers L1 == 0 or L2 == 0, where the final cell itself is
        // on the boundary.
        for (int j = 1; j <= L2; j++) {
            sum_ins_grad += be[j] * static_cast<float>(j);
        }
        for (int i = 1; i <= L1; i++) {
            sum_del_grad += be[i * alpha_cols] * static_cast<float>(i);
        }

        grad_T[b] = sum_T_grad;
        grad_ins[b] = sum_ins_grad;
        grad_del[b] = sum_del_grad;
        grad_trans[b] = sum_trans_grad;
    }
}

// ============================================================================
// HESSIAN-VECTOR PRODUCT (HVP)
// ============================================================================

/**
 * @brief Hessian-vector product of the Soft OSA score with respect to sub_costs.
 *
 * For each batch element b, reusing the forward table `alpha`:
 *   1. Forward tangent pass. d_alpha is the directional derivative of alpha
 *      when sub_costs is perturbed along V. Only the substitution candidate
 *      carries the direct perturbation (`+ v_ij`), and the boundary row/column
 *      keeps tangent 0. d_score[b] is the tangent at the final cell (L1, L2).
 *   2. Reverse pass. Recomputes beta (the adjoint table, as in the backward
 *      pass) together with its tangent d_beta. It writes
 *      H_scores(i-1, j-1) = tangent of (beta_ij * w_sub_ij), i.e. the
 *      directional derivative of the substitution posteriors along V.
 *
 * d_alpha, beta and d_beta use alpha's (max_L1+1) x (max_L2+1) layout.
 * sub_costs, trans_mask, V and H_scores use the max_L1 x max_L2 layout.
 * d_alpha, beta, d_beta and H_scores are zeroed per batch element before use.
 *
 * NOTE(review): osa_score is not read by this function.
 */
void osa_hvp_cpu(
    const float* alpha,
    const float* sub_costs,
    const float* trans_mask,
    const float* osa_score,
    const float* V,
    float* d_alpha,
    float* d_score,
    float* beta,
    float* d_beta,
    float* H_scores,
    const int* lengths,
    float ins_cost, float del_cost, float trans_cost,
    int B, int max_L1, int max_L2,
    float T
) {
    const size_t alpha_stride = (size_t)(max_L1 + 1) * (max_L2 + 1);
    const size_t score_stride = (size_t)max_L1 * max_L2;
    const int alpha_cols = max_L2 + 1;

    for (int b = 0; b < B; b++) {
        const float* a = alpha + b * alpha_stride;
        const float* s = sub_costs + b * score_stride;
        const float* tm = trans_mask + b * score_stride;
        const float* v = V + b * score_stride;
        float* da = d_alpha + b * alpha_stride;
        float* be = beta + b * alpha_stride;
        float* dbe = d_beta + b * alpha_stride;
        float* H = H_scores + b * score_stride;
        int L1 = lengths[b * 2];
        int L2 = lengths[b * 2 + 1];

        // Initialize: zero all tables for this batch element. The boundary
        // entries of da are never written below, so they stay 0.
        for (size_t idx = 0; idx < alpha_stride; idx++) {
            da[idx] = 0.0f;
            be[idx] = 0.0f;
            dbe[idx] = 0.0f;
        }
        for (size_t idx = 0; idx < score_stride; idx++) {
            H[idx] = 0.0f;
        }

        // Forward tangent pass (anti-diagonal order, same as the forward pass):
        // da(i,j) = sum_k w_k * dv_k, with the softmin weights recomputed from alpha.
        for (int k = 2; k <= L1 + L2; k++) {
            int i_start = std::max(1, k - L2);
            int i_end = std::min(L1, k - 1);

            for (int i = i_start; i <= i_end; i++) {
                int j = k - i;
                if (j < 1 || j > L2) continue;

                int idx = i * alpha_cols + j;
                int idx_diag = (i - 1) * alpha_cols + (j - 1);
                int idx_up = (i - 1) * alpha_cols + j;
                int idx_left = i * alpha_cols + (j - 1);
                int score_idx = (i - 1) * max_L2 + (j - 1);

                float sub_cost_val = s[score_idx];
                float v_ij = v[score_idx];
                float trans_valid = (i >= 2 && j >= 2) ? tm[score_idx] : 0.0f;

                float a_diag = a[idx_diag];
                float a_up = a[idx_up];
                float a_left = a[idx_left];

                float val_sub = a_diag + sub_cost_val;
                float val_del = a_up + del_cost;
                float val_ins = a_left + ins_cost;
                float val_trans = PINF;
                float da_trans = 0.0f;
                int idx_trans = -1;
                if (trans_valid > 0.5f) {
                    idx_trans = (i - 2) * alpha_cols + (j - 2);
                    val_trans = a[idx_trans] + trans_cost;
                    da_trans = da[idx_trans];
                }

                float w_sub, w_del, w_ins, w_trans;
                softmin4_weights(val_sub, val_del, val_ins, val_trans, T, w_sub, w_del, w_ins, w_trans);

                // Candidate tangents: only the substitution edge sees V directly.
                float dv_sub = da[idx_diag] + v_ij;
                float dv_del = da[idx_up];
                float dv_ins = da[idx_left];
                float dv_trans = (idx_trans >= 0) ? da_trans : 0.0f;

                da[idx] = w_sub * dv_sub + w_del * dv_del + w_ins * dv_ins + w_trans * dv_trans;
            }
        }

        int final_idx = L1 * alpha_cols + L2;
        d_score[b] = da[final_idx];

        // Reverse pass: beta is seeded with 1 at the final cell. Its tangent
        // starts at 0 because that seed is a constant.
        be[final_idx] = 1.0f;
        dbe[final_idx] = 0.0f;

        for (int k = L1 + L2; k >= 2; k--) {
            int i_start = std::max(1, k - L2);
            int i_end = std::min(L1, k - 1);

            for (int i = i_start; i <= i_end; i++) {
                int j = k - i;
                if (j < 1 || j > L2) continue;

                int idx = i * alpha_cols + j;
                int idx_diag = (i - 1) * alpha_cols + (j - 1);
                int idx_up = (i - 1) * alpha_cols + j;
                int idx_left = i * alpha_cols + (j - 1);
                int score_idx = (i - 1) * max_L2 + (j - 1);

                float beta_ij = be[idx];
                float dbeta_ij = dbe[idx];
                float sub_cost_val = s[score_idx];
                float v_ij = v[score_idx];
                float trans_valid = (i >= 2 && j >= 2) ? tm[score_idx] : 0.0f;

                // Skip cells through which neither beta nor its tangent
                // (beyond a tiny threshold) flows.
                if (beta_ij == 0.0f && std::abs(dbeta_ij) < 1e-20f) continue;

                float a_diag = a[idx_diag];
                float a_up = a[idx_up];
                float a_left = a[idx_left];

                float val_sub = a_diag + sub_cost_val;
                float val_del = a_up + del_cost;
                float val_ins = a_left + ins_cost;
                float val_trans = PINF;
                float da_trans = 0.0f;
                int idx_trans = -1;
                if (trans_valid > 0.5f) {
                    idx_trans = (i - 2) * alpha_cols + (j - 2);
                    val_trans = a[idx_trans] + trans_cost;
                    da_trans = da[idx_trans];
                }

                float w_sub, w_del, w_ins, w_trans;
                softmin4_weights(val_sub, val_del, val_ins, val_trans, T, w_sub, w_del, w_ins, w_trans);

                // Same candidate tangents as in the forward tangent pass.
                float dv_sub = da[idx_diag] + v_ij;
                float dv_del = da[idx_up];
                float dv_ins = da[idx_left];
                float dv_trans = (idx_trans >= 0) ? da_trans : 0.0f;

                float E_dv = w_sub * dv_sub + w_del * dv_del + w_ins * dv_ins + w_trans * dv_trans;

                // Tangents of the softmin weights: dw_k = -w_k * (dv_k - E_w[dv]) / T
                float dw_sub = -w_sub * (dv_sub - E_dv) / T;
                float dw_del = -w_del * (dv_del - E_dv) / T;
                float dw_ins = -w_ins * (dv_ins - E_dv) / T;
                float dw_trans = -w_trans * (dv_trans - E_dv) / T;

                // HVP entry: product-rule tangent of the substitution posterior beta_ij * w_sub
                H[score_idx] = dbeta_ij * w_sub + beta_ij * dw_sub;

                // Propagate beta and (by the product rule) its tangent to the predecessors
                if (w_sub > 0.0f) {
                    be[idx_diag] += beta_ij * w_sub;
                    dbe[idx_diag] += dbeta_ij * w_sub + beta_ij * dw_sub;
                }
                if (w_del > 0.0f) {
                    be[idx_up] += beta_ij * w_del;
                    dbe[idx_up] += dbeta_ij * w_del + beta_ij * dw_del;
                }
                if (w_ins > 0.0f) {
                    be[idx_left] += beta_ij * w_ins;
                    dbe[idx_left] += dbeta_ij * w_ins + beta_ij * dw_ins;
                }
                if (w_trans > 0.0f && idx_trans >= 0) {
                    be[idx_trans] += beta_ij * w_trans;
                    dbe[idx_trans] += dbeta_ij * w_trans + beta_ij * dw_trans;
                }
            }
        }
    }
}

}  // namespace cpu
}  // namespace osa
}  // namespace d2p
