/**
 * @file kernels_cpu.cpp
 * @brief Soft Monotonic Alignment Search (MAS) CPU Kernel Implementations
 *
 * CPU implementation of MAS for TTS/ASR alignment.
 */

#include "kernels_cpu.h"
#include <cmath>
#include <algorithm>
#include <cstring>

namespace d2p {
namespace mas {
namespace cpu {

// =============================================================================
// Helper Functions
// =============================================================================

// Overflow/underflow-guarded exponential: arguments below -88 yield exactly
// 0, arguments above 88 are evaluated as exp(88), anything else is exp(x).
static inline float safe_exp(float x) {
    const float limit = 88.0f;

    if (x < -limit) {
        return 0.0f;
    }

    const float clamped = (x > limit) ? limit : x;
    return std::exp(clamped);
}

// Temperature-scaled log-sum-exp of two values:
//   max(a, b) + T * log(exp((a - max) / T) + exp((b - max) / T)).
// An input that is not strictly greater than NINF contributes nothing to the
// sum. Returns NINF when the larger input is <= NINF or when the accumulated
// sum is not positive.
static inline float softmax2(float a, float b, float T) {
    const float peak = (a < b) ? b : a;
    if (peak <= NINF) {
        return NINF;
    }

    // Accumulate the shifted exponentials of the inputs above the sentinel.
    float mass = 0.0f;
    if (a > NINF) {
        mass += safe_exp((a - peak) / T);
    }
    if (b > NINF) {
        mass += safe_exp((b - peak) / T);
    }

    return (mass <= 0.0f) ? NINF : peak + T * std::log(mass);
}

// Normalised two-way weights at temperature T:
//   wa = exp((a - max) / T) / Z,  wb = exp((b - max) / T) / Z,
// where Z is the sum of the two exponentials and an input that is not
// strictly greater than NINF counts as a zero exponential. Both weights are
// set to 0 when the larger input is <= NINF or when Z is not positive.
static inline void softmax2_weights(
    float a, float b, float T,
    float& wa, float& wb
) {
    // Default result; overwritten below only when a valid normaliser exists.
    wa = 0.0f;
    wb = 0.0f;

    const float peak = (a < b) ? b : a;
    if (peak <= NINF) {
        return;
    }

    float unnorm_a = 0.0f;
    if (a > NINF) {
        unnorm_a = safe_exp((a - peak) / T);
    }
    float unnorm_b = 0.0f;
    if (b > NINF) {
        unnorm_b = safe_exp((b - peak) / T);
    }

    const float z = unnorm_a + unnorm_b;
    if (z > 0.0f) {
        wa = unnorm_a / z;
        wb = unnorm_b / z;
    }
}

// Directional derivative of a two-way weight pair (w1, w2) given the tangents
// (dv1, dv2) of the underlying values:
//   dw_i = w_i * (dv_i - (w1 * dv1 + w2 * dv2)) / T.
static inline void softmax2_tangent(
    float w1, float w2,
    float dv1, float dv2,
    float T,
    float& dw1, float& dw2
) {
    // Weighted mean of the incoming tangents.
    const float mean_dv = w1 * dv1 + w2 * dv2;

    const float centered1 = dv1 - mean_dv;
    const float centered2 = dv2 - mean_dv;

    dw1 = (w1 * centered1) / T;
    dw2 = (w2 * centered2) / T;
}

// =============================================================================
// Forward Pass
// =============================================================================

/**
 * @brief Forward pass: fills the DP table alpha and the per-item final score.
 *
 * For batch item b with T = lengths[2*b] and S = lengths[2*b + 1]:
 *   alpha(0, 0) = score(0, 0)
 *   alpha(t, 0) = alpha(t-1, 0) + score(t, 0)                         (1 <= t < T)
 *   alpha(t, s) = score(t, s)
 *               + softmax2(alpha(t-1, s), alpha(t-1, s-1), temperature)
 *                                                      (1 <= t < T, 1 <= s < S)
 * Every cell of the item's table that the recurrence does not write is NINF.
 * partition[b] receives alpha(T-1, S-1).
 *
 * @param scores      Input; cell (b, t, s) is at b*max_T*max_S + t*max_S + s.
 * @param alpha       Output table with the same layout; fully overwritten.
 * @param partition   Output, one value per batch item.
 * @param lengths     (T, S) pairs, two ints per batch item.
 * @param B           Number of batch items.
 * @param max_T       Number of rows of each item's table.
 * @param max_S       Number of columns of each item's table.
 * @param temperature Forwarded to softmax2.
 *
 * Items whose (T, S) lies outside [1, max_T] x [1, max_S] cannot be indexed
 * safely; they get an all-NINF table and partition[b] = NINF.
 */
void forward(
    const float* scores,
    float* alpha,
    float* partition,
    const int* lengths,
    int B, int max_T, int max_S,
    float temperature
) {
    // All offsets are computed in size_t so large tables do not overflow int.
    const bool table_ok = (max_T > 0) && (max_S > 0);
    const size_t row = table_ok ? static_cast<size_t>(max_S) : 0;
    const size_t stride = table_ok ? static_cast<size_t>(max_T) * row : 0;

    for (int b = 0; b < B; b++) {
        const size_t base = static_cast<size_t>(b) * stride;
        const float* sc = scores + base;
        float* a = alpha + base;
        const int T = lengths[b * 2];
        const int S = lengths[b * 2 + 1];

        // Initialize to NINF
        std::fill(a, a + stride, NINF);

        // Empty or oversized items would read/write outside the table.
        if (T < 1 || S < 1 || T > max_T || S > max_S) {
            partition[b] = NINF;
            continue;
        }

        // Base case: α(0, 0) = score(0, 0)
        a[0] = sc[0];

        // Base case: α(t, 0) = α(t-1, 0) + score(t, 0)
        for (int t = 1; t < T; t++) {
            const size_t cur = static_cast<size_t>(t) * row;
            a[cur] = a[cur - row] + sc[cur];
        }

        // Fill DP table
        for (int t = 1; t < T; t++) {
            const size_t cur = static_cast<size_t>(t) * row;
            const size_t prev = cur - row;

            for (int s = 1; s < S; s++) {
                const size_t col = static_cast<size_t>(s);

                const float stay = a[prev + col];
                const float diag = a[prev + col - 1];

                a[cur + col] = sc[cur + col] + softmax2(stay, diag, temperature);
            }
        }

        // Score
        partition[b] = a[static_cast<size_t>(T - 1) * row + static_cast<size_t>(S - 1)];
    }
}

// =============================================================================
// Backward Pass
// =============================================================================

/**
 * @brief Backward pass: propagates beta from the terminal cell, records the
 *        posteriors and computes the per-item temperature gradient.
 *
 * For batch item b with T = lengths[2*b] and S = lengths[2*b + 1]:
 *   - beta(T-1, S-1) = 1.
 *   - Rows are visited from t = T-1 down to 1, columns from s = S-1 down to 0.
 *     A cell whose beta is below 1e-30 is skipped. Otherwise its beta is
 *     stored as the posterior and distributed to (t-1, s) and, for s >= 1,
 *     (t-1, s-1) using the weights from softmax2_weights over the two alpha
 *     predecessors.
 *   - The s = 0 column of the posteriors is then copied from beta for every
 *     t in [0, T).
 *   - grad_T[b] = (partition[b] - sum_{t,s} P(t,s) * score(t,s)) / temperature.
 *
 * @param alpha       Input table; cell (b, t, s) is at
 *                    b*max_T*max_S + t*max_S + s.
 * @param scores      Input scores, same layout.
 * @param partition   Input, one value per batch item.
 * @param beta        Output/scratch table, same layout; fully overwritten.
 * @param posteriors  Output table, same layout; fully overwritten.
 * @param grad_T      Output, one value per batch item.
 * @param lengths     (T, S) pairs, two ints per batch item.
 * @param B           Number of batch items.
 * @param max_T       Number of rows of each item's table.
 * @param max_S       Number of columns of each item's table.
 * @param temperature Forwarded to softmax2_weights; divisor of grad_T.
 *
 * Items whose (T, S) lies outside [1, max_T] x [1, max_S] cannot be indexed
 * safely; their beta/posteriors stay zero and grad_T[b] is set to 0.
 */
void backward(
    const float* alpha,
    const float* scores,
    const float* partition,
    float* beta,
    float* posteriors,
    float* grad_T,
    const int* lengths,
    int B, int max_T, int max_S,
    float temperature
) {
    // All offsets are computed in size_t so large tables do not overflow int.
    const bool table_ok = (max_T > 0) && (max_S > 0);
    const size_t row = table_ok ? static_cast<size_t>(max_S) : 0;
    const size_t stride = table_ok ? static_cast<size_t>(max_T) * row : 0;

    for (int b = 0; b < B; b++) {
        const size_t base = static_cast<size_t>(b) * stride;
        const float* a = alpha + base;
        const float* sc = scores + base;
        float* be = beta + base;
        float* P = posteriors + base;
        const int T = lengths[b * 2];
        const int S = lengths[b * 2 + 1];

        // Initialize
        std::fill(be, be + stride, 0.0f);
        std::fill(P, P + stride, 0.0f);

        // Empty or oversized items would read/write outside the table.
        if (T < 1 || S < 1 || T > max_T || S > max_S) {
            grad_T[b] = 0.0f;
            continue;
        }

        // β(T-1, S-1) = 1
        be[static_cast<size_t>(T - 1) * row + static_cast<size_t>(S - 1)] = 1.0f;

        // Backward pass
        for (int t = T - 1; t >= 1; t--) {
            const size_t cur = static_cast<size_t>(t) * row;
            const size_t prev = cur - row;

            for (int s = S - 1; s >= 0; s--) {
                const size_t col = static_cast<size_t>(s);
                const float beta_ts = be[cur + col];

                // Cells carrying (numerically) no mass contribute nothing.
                if (beta_ts < 1e-30f) continue;

                // Posteriors
                P[cur + col] = beta_ts;

                // Recompute weights; column 0 has no diagonal predecessor.
                const bool has_diag = (s >= 1);
                const float stay = a[prev + col];
                const float diag = has_diag ? a[prev + col - 1] : NINF;

                float w_stay, w_diag;
                softmax2_weights(stay, diag, temperature, w_stay, w_diag);

                // Propagate beta
                be[prev + col] += beta_ts * w_stay;
                if (has_diag) {
                    be[prev + col - 1] += beta_ts * w_diag;
                }
            }
        }

        // First-column (s = 0) posteriors, including row 0 which the loop
        // above never visits.
        for (int t = 0; t < T; t++) {
            const size_t cur = static_cast<size_t>(t) * row;
            P[cur] = be[cur];
        }

        // Temperature gradient
        float expected_score = 0.0f;
        for (int t = 0; t < T; t++) {
            const size_t cur = static_cast<size_t>(t) * row;
            for (int s = 0; s < S; s++) {
                const size_t idx = cur + static_cast<size_t>(s);
                expected_score += P[idx] * sc[idx];
            }
        }
        grad_T[b] = (partition[b] - expected_score) / temperature;
    }
}

// =============================================================================
// HVP (Hessian-Vector Product)
// =============================================================================

/**
 * @brief Forward-tangent / backward-tangent sweep in direction V.
 *
 * For batch item b with T = lengths[2*b] and S = lengths[2*b + 1]:
 *   1. Forward tangent d_alpha:
 *        d_alpha(0, 0) = V(0, 0)
 *        d_alpha(t, 0) = d_alpha(t-1, 0) + V(t, 0)
 *        d_alpha(t, s) = V(t, s) + w_stay * d_alpha(t-1, s)
 *                                + w_diag * d_alpha(t-1, s-1)
 *      with (w_stay, w_diag) from softmax2_weights over the alpha predecessors.
 *      d_score[b] receives d_alpha(T-1, S-1).
 *   2. Backward sweep from beta(T-1, S-1) = 1 (d_beta starts at 0), rows
 *      t = T-1 .. 1, columns s = S-1 .. 0. Each visited cell stores its d_beta
 *      into H_scores and pushes beta and d_beta to its predecessors:
 *        beta(prev)   += beta * w
 *        d_beta(prev) += d_beta * w + beta * dw
 *      where dw comes from softmax2_tangent applied to the d_alpha
 *      predecessors. Cells with beta < 1e-30 and |d_beta| < 1e-30 are skipped.
 *   3. The s = 0 column of H_scores is copied from d_beta for every t in [0, T).
 *
 * @param alpha       Input table; cell (b, t, s) is at
 *                    b*max_T*max_S + t*max_S + s.
 * @param scores      Not read by this function.
 * @param V           Input direction, same layout as alpha.
 * @param d_alpha     Output/scratch table, same layout; fully overwritten.
 * @param d_score     Output, one value per batch item.
 * @param beta        Output/scratch table, same layout; fully overwritten.
 * @param d_beta      Output/scratch table, same layout; fully overwritten.
 * @param H_scores    Output table, same layout; fully overwritten.
 * @param lengths     (T, S) pairs, two ints per batch item.
 * @param B           Number of batch items.
 * @param max_T       Number of rows of each item's table.
 * @param max_S       Number of columns of each item's table.
 * @param temperature Forwarded to softmax2_weights and softmax2_tangent.
 *
 * Items whose (T, S) lies outside [1, max_T] x [1, max_S] cannot be indexed
 * safely; their tables stay zero and d_score[b] is set to 0.
 */
void hvp(
    const float* alpha,
    const float* scores,
    const float* V,
    float* d_alpha,
    float* d_score,
    float* beta,
    float* d_beta,
    float* H_scores,
    const int* lengths,
    int B, int max_T, int max_S,
    float temperature
) {
    // All offsets are computed in size_t so large tables do not overflow int.
    const bool table_ok = (max_T > 0) && (max_S > 0);
    const size_t row = table_ok ? static_cast<size_t>(max_S) : 0;
    const size_t stride = table_ok ? static_cast<size_t>(max_T) * row : 0;

    for (int b = 0; b < B; b++) {
        const size_t base = static_cast<size_t>(b) * stride;
        const float* a = alpha + base;
        const float* v = V + base;
        float* da = d_alpha + base;
        float* be = beta + base;
        float* dbe = d_beta + base;
        float* H = H_scores + base;
        const int T = lengths[b * 2];
        const int S = lengths[b * 2 + 1];

        // Initialize
        std::fill(da, da + stride, 0.0f);
        std::fill(be, be + stride, 0.0f);
        std::fill(dbe, dbe + stride, 0.0f);
        std::fill(H, H + stride, 0.0f);

        // Empty or oversized items would read/write outside the table.
        if (T < 1 || S < 1 || T > max_T || S > max_S) {
            d_score[b] = 0.0f;
            continue;
        }

        // Forward tangent: first column
        da[0] = v[0];
        for (int t = 1; t < T; t++) {
            const size_t cur = static_cast<size_t>(t) * row;
            da[cur] = da[cur - row] + v[cur];
        }

        // Forward tangent: main DP
        for (int t = 1; t < T; t++) {
            const size_t cur = static_cast<size_t>(t) * row;
            const size_t prev = cur - row;

            for (int s = 1; s < S; s++) {
                const size_t col = static_cast<size_t>(s);

                const float stay = a[prev + col];
                const float diag = a[prev + col - 1];

                float w_stay, w_diag;
                softmax2_weights(stay, diag, temperature, w_stay, w_diag);

                da[cur + col] = v[cur + col]
                              + w_stay * da[prev + col]
                              + w_diag * da[prev + col - 1];
            }
        }

        const size_t terminal =
            static_cast<size_t>(T - 1) * row + static_cast<size_t>(S - 1);

        // d_score
        d_score[b] = da[terminal];

        // Initialize beta at terminal
        be[terminal] = 1.0f;

        // Backward tangent pass
        for (int t = T - 1; t >= 1; t--) {
            const size_t cur = static_cast<size_t>(t) * row;
            const size_t prev = cur - row;

            for (int s = S - 1; s >= 0; s--) {
                const size_t col = static_cast<size_t>(s);
                const float beta_ts = be[cur + col];
                const float dbeta_ts = dbe[cur + col];

                // Nothing to propagate from a cell with no mass and no tangent.
                if (beta_ts < 1e-30f && std::abs(dbeta_ts) < 1e-30f) continue;

                H[cur + col] = dbeta_ts;

                // Column 0 has no diagonal predecessor.
                const bool has_diag = (s >= 1);
                const float stay = a[prev + col];
                const float diag = has_diag ? a[prev + col - 1] : NINF;

                float w_stay, w_diag;
                softmax2_weights(stay, diag, temperature, w_stay, w_diag);

                const float da_stay = da[prev + col];
                const float da_diag = has_diag ? da[prev + col - 1] : 0.0f;

                float dw_stay, dw_diag;
                softmax2_tangent(w_stay, w_diag, da_stay, da_diag, temperature, dw_stay, dw_diag);

                be[prev + col] += beta_ts * w_stay;
                dbe[prev + col] += dbeta_ts * w_stay + beta_ts * dw_stay;

                if (has_diag) {
                    be[prev + col - 1] += beta_ts * w_diag;
                    dbe[prev + col - 1] += dbeta_ts * w_diag + beta_ts * dw_diag;
                }
            }
        }

        // First column (s = 0), including row 0 which the loop above never
        // visits.
        for (int t = 0; t < T; t++) {
            const size_t cur = static_cast<size_t>(t) * row;
            H[cur] = dbe[cur];
        }
    }
}

// =============================================================================
// Parameter Gradient (∂P/∂T)
// =============================================================================

// Entropy contribution -w * log(w) of a single transition weight; a weight
// that is not positive contributes 0.
static inline float param_grad_entropy_term(float w) {
    return (w > 0.0f) ? -w * std::log(w) : 0.0f;
}

// w * x, but exactly 0 when the weight is not positive, so that a zero weight
// never multiplies a sentinel value (0 * NINF would be NaN if NINF is
// infinite).
static inline float param_grad_weighted(float w, float x) {
    return (w > 0.0f) ? w * x : 0.0f;
}

/**
 * @brief Temperature derivative of the backward quantities (∂P/∂T).
 *
 * For batch item b with T = lengths[2*b] and S = lengths[2*b + 1]:
 *   1. Forward U pass (1 <= t < T, 1 <= s < S), with U = 0 elsewhere:
 *        U(t, s) = w_stay * U(t-1, s) + w_diag * U(t-1, s-1)
 *                - w_stay * log(w_stay) - w_diag * log(w_diag)
 *      where (w_stay, w_diag) come from softmax2_weights over the alpha
 *      predecessors. The last two terms are the explicit temperature
 *      derivative of Tau * log(exp(x/Tau) + exp(y/Tau)) at fixed x, y
 *      (Tau = temperature), i.e. the entropy of the two weights.
 *      NOTE(review): the previous local term,
 *        w_stay * (E_v - stay) / Tau^2 + w_diag * (E_v - diag) / Tau^2,
 *      equals E_v * (w_stay + w_diag - 1) / Tau^2 and therefore vanishes
 *      whenever the weights sum to 1, which left U identically zero. Confirm
 *      the corrected recursion against the project's reference / GPU kernel.
 *   2. Backward W pass from beta(T-1, S-1) = 1 (W starts at 0), rows
 *      t = T-1 .. 1, columns s = S-1 .. 0. Each visited cell stores its W into
 *      dP_dT and pushes to its predecessors:
 *        beta(prev) += beta * w
 *        W(prev)    += W * w + beta * dw
 *      where dw = softmax2_tangent(w, U predecessors) + w * (E_v - value) / Tau^2
 *      and E_v is the weight-averaged predecessor alpha. Cells with
 *      beta < 1e-30 and |W| < 1e-30 are skipped.
 *
 * @param alpha       Input table; cell (b, t, s) is at
 *                    b*max_T*max_S + t*max_S + s.
 * @param scores      Not read by this function.
 * @param U           Output/scratch table, same layout; fully overwritten.
 * @param beta        Output/scratch table, same layout; fully overwritten.
 * @param W           Output/scratch table, same layout; fully overwritten.
 * @param dP_dT       Output table, same layout; fully overwritten.
 * @param lengths     (T, S) pairs, two ints per batch item.
 * @param B           Number of batch items.
 * @param max_T       Number of rows of each item's table.
 * @param max_S       Number of columns of each item's table.
 * @param temperature Forwarded to the softmax2 helpers; also used as 1/Tau^2.
 *
 * Items whose (T, S) lies outside [1, max_T] x [1, max_S] cannot be indexed
 * safely; all of their tables stay zero.
 */
void param_grad(
    const float* alpha,
    const float* scores,
    float* U,
    float* beta,
    float* W,
    float* dP_dT,
    const int* lengths,
    int B, int max_T, int max_S,
    float temperature
) {
    // All offsets are computed in size_t so large tables do not overflow int.
    const bool table_ok = (max_T > 0) && (max_S > 0);
    const size_t row = table_ok ? static_cast<size_t>(max_S) : 0;
    const size_t stride = table_ok ? static_cast<size_t>(max_T) * row : 0;

    // Loop-invariant factor of the explicit temperature derivative.
    const float inv_T2 = 1.0f / (temperature * temperature);

    for (int b = 0; b < B; b++) {
        const size_t base = static_cast<size_t>(b) * stride;
        const float* a = alpha + base;
        float* u = U + base;
        float* be = beta + base;
        float* w = W + base;
        float* dP = dP_dT + base;
        const int T = lengths[b * 2];
        const int S = lengths[b * 2 + 1];

        // Initialize
        std::fill(u, u + stride, 0.0f);
        std::fill(be, be + stride, 0.0f);
        std::fill(w, w + stride, 0.0f);
        std::fill(dP, dP + stride, 0.0f);

        // Empty or oversized items would read/write outside the table.
        if (T < 1 || S < 1 || T > max_T || S > max_S) {
            continue;
        }

        // Forward U pass
        for (int t = 1; t < T; t++) {
            const size_t cur = static_cast<size_t>(t) * row;
            const size_t prev = cur - row;

            for (int s = 1; s < S; s++) {
                const size_t col = static_cast<size_t>(s);

                const float stay = a[prev + col];
                const float diag = a[prev + col - 1];

                float w_stay, w_diag;
                softmax2_weights(stay, diag, temperature, w_stay, w_diag);

                // Chain rule through the predecessors plus the explicit
                // temperature derivative (entropy of the two weights).
                u[cur + col] = w_stay * u[prev + col]
                             + w_diag * u[prev + col - 1]
                             + param_grad_entropy_term(w_stay)
                             + param_grad_entropy_term(w_diag);
            }
        }

        // Initialize beta at terminal
        be[static_cast<size_t>(T - 1) * row + static_cast<size_t>(S - 1)] = 1.0f;

        // Backward W pass
        for (int t = T - 1; t >= 1; t--) {
            const size_t cur = static_cast<size_t>(t) * row;
            const size_t prev = cur - row;

            for (int s = S - 1; s >= 0; s--) {
                const size_t col = static_cast<size_t>(s);
                const float beta_ts = be[cur + col];
                const float w_ts = w[cur + col];

                // Nothing to propagate from a cell with no mass and no tangent.
                if (beta_ts < 1e-30f && std::abs(w_ts) < 1e-30f) continue;

                dP[cur + col] = w_ts;

                // Column 0 has no diagonal predecessor.
                const bool has_diag = (s >= 1);
                const float stay = a[prev + col];
                const float diag = has_diag ? a[prev + col - 1] : NINF;

                float wt_stay, wt_diag;
                softmax2_weights(stay, diag, temperature, wt_stay, wt_diag);

                const float u_stay = u[prev + col];
                const float u_diag = has_diag ? u[prev + col - 1] : 0.0f;

                // Weight change caused by the temperature dependence of alpha.
                float dw_stay, dw_diag;
                softmax2_tangent(wt_stay, wt_diag, u_stay, u_diag, temperature, dw_stay, dw_diag);

                // Explicit temperature dependence of the weights. Zero-weight
                // branches are skipped so a sentinel predecessor cannot turn
                // the result into NaN.
                const float E_v = param_grad_weighted(wt_stay, stay)
                                + param_grad_weighted(wt_diag, diag);
                dw_stay += param_grad_weighted(wt_stay, E_v - stay) * inv_T2;
                dw_diag += param_grad_weighted(wt_diag, E_v - diag) * inv_T2;

                be[prev + col] += beta_ts * wt_stay;
                w[prev + col] += w_ts * wt_stay + beta_ts * dw_stay;

                if (has_diag) {
                    be[prev + col - 1] += beta_ts * wt_diag;
                    w[prev + col - 1] += w_ts * wt_diag + beta_ts * dw_diag;
                }
            }
        }
    }
}

} // namespace cpu
} // namespace mas
} // namespace d2p
