/**
 * @file kernels_cpu.cpp
 * @brief Soft Eisner CPU Kernel Implementations
 *
 * Pure CPU kernels for differentiable projective dependency parsing.
 * Mirrors the CUDA kernel interface for seamless dispatch.
 */

#include "kernels_cpu.h"
#include <cmath>
#include <algorithm>
#include <vector>

namespace d2p {
namespace eisner {
namespace cpu {

// =============================================================================
// Helper Functions
// =============================================================================

/// Exponential with a guarded argument range.
///
/// Arguments below -88 yield exactly 0; arguments above 88 are clamped to 88
/// before calling std::exp, which keeps the float result finite.
/// Anything in between is passed to std::exp unchanged.
inline float safe_exp(float x) {
    constexpr float kLimit = 88.0f;
    if (x < -kLimit) {
        return 0.0f;
    }
    return std::exp(x > kLimit ? kLimit : x);
}

/// Running float total using Kahan (compensated) summation: the rounding
/// error of each addition is carried in `c` and fed back into the next one.
struct KahanAccumulator {
    float sum = 0.0f;  ///< running total
    float c = 0.0f;    ///< compensation carried over from the previous add()

    /// Fold `value` into the total, correcting for the previously lost low bits.
    void add(float value) {
        const float corrected = value - c;
        const float updated = sum + corrected;
        // (updated - sum) is what was actually added; the difference to
        // `corrected` is the part that got rounded away this step.
        c = (updated - sum) - corrected;
        sum = updated;
    }

    /// Clear both the total and the compensation term.
    void reset() {
        c = 0.0f;
        sum = 0.0f;
    }

    /// Current value of the running total.
    float result() const { return sum; }
};

// Temperature-scaled logsumexp
inline float logsumexp_T(const std::vector<float>& vals, float T) {
    if (vals.empty()) return NINF;
    float max_v = *std::max_element(vals.begin(), vals.end());
    if (max_v <= NINF) return NINF;

    KahanAccumulator sum;
    for (float v : vals) {
        if (v > NINF) {
            sum.add(safe_exp((v - max_v) / T));
        }
    }
    return max_v + T * std::log(sum.result());
}

// Compute softmax weights
inline void softmax_T(const std::vector<float>& vals, float T, std::vector<float>& weights) {
    weights.resize(vals.size());
    if (vals.empty()) return;

    float max_v = *std::max_element(vals.begin(), vals.end());
    if (max_v <= NINF) {
        std::fill(weights.begin(), weights.end(), 0.0f);
        return;
    }

    KahanAccumulator sum;
    for (size_t i = 0; i < vals.size(); i++) {
        if (vals[i] > NINF) {
            weights[i] = safe_exp((vals[i] - max_v) / T);
            sum.add(weights[i]);
        } else {
            weights[i] = 0.0f;
        }
    }

    float total = sum.result();
    if (total > 0) {
        for (size_t i = 0; i < vals.size(); i++) {
            weights[i] /= total;
        }
    }
}

// =============================================================================
// Forward Pass
// =============================================================================

/**
 * @brief Inside (forward) pass of the temperature-smoothed Eisner recurrences.
 *
 * All tables are batches of row-major n x n matrices; batch b starts at
 * offset b * n * n and cell [i,j] lives at i * n + j.
 *
 * Recurrences, for every span 0 <= i < j < seq_len (LSE is logsumexp_T):
 *   S        = LSE_{k in [i,j)} ( C_R[i,k] + C_L[k+1,j] )
 *   I_R[i,j] = arc[i,j] + S
 *   I_L[i,j] = arc[j,i] + S
 *   C_R[i,j] = LSE_{k in [i,j)} ( C_R[i,k] + I_R[k,j] )
 *   C_L[i,j] = LSE_{k in (i,j]} ( I_L[i,k] + C_L[k,j] )
 * with C_R[i,i] = C_L[i,i] = 0. Cells outside the active sequence stay NINF.
 *
 * @param arc_scores  Arc scores, B x n x n (presumably arc[h,d] scores head h
 *                    taking dependent d -- confirm against the caller).
 * @param C_R,C_L     Output: complete-span tables, B x n x n (fully overwritten).
 * @param I_R,I_L     Output: incomplete-span tables, B x n x n (fully overwritten).
 * @param partition   Output: B values, C_R[0, seq_len-1] of each batch element.
 * @param lengths     Per-element sequence lengths, or nullptr to use n for all.
 *                    Values are clamped to [0, n].
 * @param B           Batch size.
 * @param n           Table dimension (maximum sequence length).
 * @param T           Temperature passed to logsumexp_T.
 *
 * An element whose (clamped) length is 0 gets partition 0 (the empty sum in
 * log space is taken over a single empty structure) and all-NINF tables,
 * instead of reading cell [0,-1] as the unguarded code would.
 */
void forward(
    const float* arc_scores,
    float* C_R,
    float* C_L,
    float* I_R,
    float* I_L,
    float* partition,
    const int* lengths,
    int B, int n, float T
) {
    const size_t stride = n > 0 ? static_cast<size_t>(n) * static_cast<size_t>(n) : 0;

    // Flat offset of cell [row, col] inside one n x n table, computed in
    // size_t so that large n cannot overflow int arithmetic.
    const auto at = [n](int row, int col) -> size_t {
        return static_cast<size_t>(row) * static_cast<size_t>(n) + static_cast<size_t>(col);
    };

    // Scratch buffer for the candidate split scores; reused across all spans.
    std::vector<float> terms;
    if (n > 0) {
        terms.reserve(static_cast<size_t>(n));
    }

    for (int b = 0; b < B; b++) {
        const float* arc = arc_scores + b * stride;
        float* cr = C_R + b * stride;
        float* cl = C_L + b * stride;
        float* ir = I_R + b * stride;
        float* il = I_L + b * stride;

        // Clamp so that a zero/negative or oversized length can never index
        // outside this batch element's tables.
        const int seq_len = std::max(0, std::min(lengths ? lengths[b] : n, n));

        // Initialize tables
        for (size_t idx = 0; idx < stride; idx++) {
            cr[idx] = NINF;
            cl[idx] = NINF;
            ir[idx] = NINF;
            il[idx] = NINF;
        }

        if (seq_len == 0) {
            partition[b] = 0.0f;
            continue;
        }

        // Base case: C[i,i] = 0
        for (int i = 0; i < seq_len; i++) {
            cr[at(i, i)] = 0.0f;
            cl[at(i, i)] = 0.0f;
        }

        // Process spans by increasing length so every sub-span is ready.
        for (int len = 1; len < seq_len; len++) {
            for (int i = 0; i + len < seq_len; i++) {
                const int j = i + len;

                // Incomplete spans share the same split sum and differ only
                // in which arc score is added.
                terms.clear();
                for (int k = i; k < j; k++) {
                    terms.push_back(cr[at(i, k)] + cl[at(k + 1, j)]);
                }
                const float lse = logsumexp_T(terms, T);
                ir[at(i, j)] = arc[at(i, j)] + lse;
                il[at(i, j)] = arc[at(j, i)] + lse;

                // C_R[i,j] = LSE_k{ C_R[i,k] + I_R[k,j] }  for k in [i, j)
                // (k == i reads I_R[i,j] written just above.)
                terms.clear();
                for (int k = i; k < j; k++) {
                    terms.push_back(cr[at(i, k)] + ir[at(k, j)]);
                }
                cr[at(i, j)] = logsumexp_T(terms, T);

                // C_L[i,j] = LSE_k{ I_L[i,k] + C_L[k,j] }  for k in (i, j]
                // (k == j reads I_L[i,j] written just above.)
                terms.clear();
                for (int k = i + 1; k <= j; k++) {
                    terms.push_back(il[at(i, k)] + cl[at(k, j)]);
                }
                cl[at(i, j)] = logsumexp_T(terms, T);
            }
        }

        // Partition = C_R[0, seq_len-1]
        partition[b] = cr[at(0, seq_len - 1)];
    }
}

// =============================================================================
// Backward Pass
// =============================================================================

/**
 * @brief Outside (backward) pass: propagates beta mass top-down through the
 *        recurrences used by forward() and accumulates the temperature gradient.
 *
 * All tables are batches of row-major n x n matrices (batch b at offset
 * b * n * n, cell [i,j] at i * n + j). C_R/C_L/I_R/I_L are read-only inputs,
 * expected to be the tables produced by forward() with the same `lengths`
 * and `T`.
 *
 * For each LSE node with incoming beta, the softmax weights over its split
 * points distribute `beta * w_k` to both operands of split k, and
 * `beta * (LSE - sum_k w_k * term_k) / T` is added to the temperature gradient.
 *
 * @param arc_scores  Not read by this function; kept so the signature matches
 *                    the other kernels.
 * @param C_R,C_L,I_R,I_L  Inside tables, B x n x n.
 * @param beta_C_R,beta_C_L,beta_I_R,beta_I_L
 *                    Output: beta tables, B x n x n (fully overwritten).
 * @param marginals   Output: B x n x n; [i,j] receives beta_I_R[i,j] and [j,i]
 *                    receives beta_I_L[i,j] for i < j, everything else is 0.
 * @param grad_T      Output: B values, accumulated temperature gradient.
 * @param lengths     Per-element sequence lengths, or nullptr to use n for all.
 *                    Values are clamped to [0, n].
 * @param B           Batch size.
 * @param n           Table dimension.
 * @param T           Temperature.
 *
 * An element whose (clamped) length is 0 gets all-zero outputs and
 * grad_T = 0, instead of writing the root beta to cell [0,-1].
 */
void backward(
    const float* arc_scores,
    const float* C_R,
    const float* C_L,
    const float* I_R,
    const float* I_L,
    float* beta_C_R,
    float* beta_C_L,
    float* beta_I_R,
    float* beta_I_L,
    float* marginals,
    float* grad_T,
    const int* lengths,
    int B, int n, float T
) {
    (void)arc_scores;  // arc scores are already folded into I_R / I_L

    const size_t stride = n > 0 ? static_cast<size_t>(n) * static_cast<size_t>(n) : 0;

    // Flat offset of cell [row, col] inside one n x n table, computed in
    // size_t so that large n cannot overflow int arithmetic.
    const auto at = [n](int row, int col) -> size_t {
        return static_cast<size_t>(row) * static_cast<size_t>(n) + static_cast<size_t>(col);
    };

    // Scratch buffers reused across all spans.
    std::vector<float> terms;
    std::vector<float> weights;
    if (n > 0) {
        terms.reserve(static_cast<size_t>(n));
        weights.reserve(static_cast<size_t>(n));
    }

    for (int b = 0; b < B; b++) {
        const float* cr = C_R + b * stride;
        const float* cl = C_L + b * stride;
        const float* ir = I_R + b * stride;
        const float* il = I_L + b * stride;
        float* bcr = beta_C_R + b * stride;
        float* bcl = beta_C_L + b * stride;
        float* bir = beta_I_R + b * stride;
        float* bil = beta_I_L + b * stride;
        float* marg = marginals + b * stride;

        // Clamp so that a zero/negative or oversized length can never index
        // outside this batch element's tables.
        const int seq_len = std::max(0, std::min(lengths ? lengths[b] : n, n));

        // Initialize
        for (size_t idx = 0; idx < stride; idx++) {
            bcr[idx] = 0.0f;
            bcl[idx] = 0.0f;
            bir[idx] = 0.0f;
            bil[idx] = 0.0f;
            marg[idx] = 0.0f;
        }

        if (seq_len == 0) {
            grad_T[b] = 0.0f;
            continue;
        }

        // Root span beta = 1 (the partition is C_R[0, seq_len-1]).
        bcr[at(0, seq_len - 1)] = 1.0f;

        float sum_grad_T = 0.0f;

        // Top-down by decreasing span length. Within one span the complete
        // items are handled first because they feed I_R[i,j] / I_L[i,j] of
        // the same span (split k == i and k == j respectively).
        for (int len = seq_len - 1; len >= 1; len--) {
            for (int i = 0; i + len < seq_len; i++) {
                const int j = i + len;

                // Backward for C_R[i,j] = LSE_k{ C_R[i,k] + I_R[k,j] }
                const float beta_cr_ij = bcr[at(i, j)];
                if (beta_cr_ij != 0.0f) {
                    terms.clear();
                    for (int k = i; k < j; k++) {
                        terms.push_back(cr[at(i, k)] + ir[at(k, j)]);
                    }
                    softmax_T(terms, T, weights);

                    const float Zij = cr[at(i, j)];
                    float E_term = 0.0f;
                    for (int k = i; k < j; k++) {
                        const float w = weights[k - i];
                        const float mass = beta_cr_ij * w;
                        bcr[at(i, k)] += mass;
                        bir[at(k, j)] += mass;
                        E_term += w * terms[k - i];
                    }
                    sum_grad_T += beta_cr_ij * (Zij - E_term) / T;
                }

                // Backward for C_L[i,j] = LSE_k{ I_L[i,k] + C_L[k,j] }
                const float beta_cl_ij = bcl[at(i, j)];
                if (beta_cl_ij != 0.0f) {
                    terms.clear();
                    for (int k = i + 1; k <= j; k++) {
                        terms.push_back(il[at(i, k)] + cl[at(k, j)]);
                    }
                    softmax_T(terms, T, weights);

                    const float Zij = cl[at(i, j)];
                    float E_term = 0.0f;
                    for (int k = i + 1; k <= j; k++) {
                        const float w = weights[k - i - 1];
                        const float mass = beta_cl_ij * w;
                        bil[at(i, k)] += mass;
                        bcl[at(k, j)] += mass;
                        E_term += w * terms[k - i - 1];
                    }
                    sum_grad_T += beta_cl_ij * (Zij - E_term) / T;
                }

                // Backward for incomplete spans
                const float beta_ir_ij = bir[at(i, j)];
                const float beta_il_ij = bil[at(i, j)];

                // Arc marginals: the beta that reached each incomplete item.
                marg[at(i, j)] = beta_ir_ij;
                marg[at(j, i)] = beta_il_ij;

                // I_R[i,j] and I_L[i,j] share one split sum, so their betas
                // are pushed through it together.
                const float beta_combined = beta_ir_ij + beta_il_ij;
                if (beta_combined != 0.0f) {
                    terms.clear();
                    for (int k = i; k < j; k++) {
                        terms.push_back(cr[at(i, k)] + cl[at(k + 1, j)]);
                    }
                    softmax_T(terms, T, weights);

                    // The split sum itself is not stored by forward(), so it
                    // is recomputed here.
                    const float lse = logsumexp_T(terms, T);
                    float E_term = 0.0f;
                    for (int k = i; k < j; k++) {
                        const float w = weights[k - i];
                        const float mass = beta_combined * w;
                        bcr[at(i, k)] += mass;
                        bcl[at(k + 1, j)] += mass;
                        E_term += w * terms[k - i];
                    }
                    sum_grad_T += beta_combined * (lse - E_term) / T;
                }
            }
        }

        grad_T[b] = sum_grad_T;
    }
}

// =============================================================================
// Hessian-Vector Product (HVP)
// =============================================================================

/**
 * @brief Hessian-vector product of the partition w.r.t. the arc scores, in
 *        direction V, computed as forward-mode tangents over the
 *        inside/outside passes.
 *
 * All tables are batches of row-major n x n matrices (batch b at offset
 * b * n * n, cell [i,j] at i * n + j). The temperature T is held fixed.
 *
 * Phase 1 (bottom-up): tangents d_C_R/d_C_L/d_I_R/d_I_L of the inside tables,
 *   where the tangent of each LSE node is the softmax-weighted mean of its
 *   operands' tangents and V enters at the incomplete items.
 * Phase 2 (top-down): the beta tables are recomputed exactly as in backward(),
 *   and in parallel their tangents d_beta_* are propagated using
 *   d(beta * w_k) = d_beta * w_k + beta * w_k * (d_term_k - E[d_term]) / T.
 *
 * @param arc_scores  Not read by this function; kept so the signature matches
 *                    the other kernels.
 * @param V           Direction, B x n x n, same layout as the arc scores.
 * @param C_R,C_L,I_R,I_L  Inside tables, expected to come from forward() with
 *                    the same `lengths` and `T`.
 * @param d_C_R,d_C_L,d_I_R,d_I_L          Output: inside tangents (overwritten).
 * @param beta_C_R,beta_C_L,beta_I_R,beta_I_L
 *                    Output: beta tables (overwritten, recomputed here).
 * @param d_beta_C_R,d_beta_C_L,d_beta_I_R,d_beta_I_L
 *                    Output: beta tangents (overwritten).
 * @param HVP         Output: B x n x n; [i,j] receives d_beta_I_R[i,j] and
 *                    [j,i] receives d_beta_I_L[i,j] for i < j, the rest is 0.
 * @param lengths     Per-element sequence lengths, or nullptr to use n for all.
 *                    Values are clamped to [0, n].
 * @param B           Batch size.
 * @param n           Table dimension.
 * @param T           Temperature.
 *
 * An element whose (clamped) length is 0 gets all-zero outputs, instead of
 * writing the root beta to cell [0,-1].
 */
void hvp(
    const float* arc_scores,
    const float* V,
    const float* C_R,
    const float* C_L,
    const float* I_R,
    const float* I_L,
    float* d_C_R,
    float* d_C_L,
    float* d_I_R,
    float* d_I_L,
    float* beta_C_R,
    float* beta_C_L,
    float* beta_I_R,
    float* beta_I_L,
    float* d_beta_C_R,
    float* d_beta_C_L,
    float* d_beta_I_R,
    float* d_beta_I_L,
    float* HVP,
    const int* lengths,
    int B, int n, float T
) {
    (void)arc_scores;  // arc scores are already folded into I_R / I_L

    const size_t stride = n > 0 ? static_cast<size_t>(n) * static_cast<size_t>(n) : 0;

    // Flat offset of cell [row, col] inside one n x n table, computed in
    // size_t so that large n cannot overflow int arithmetic.
    const auto at = [n](int row, int col) -> size_t {
        return static_cast<size_t>(row) * static_cast<size_t>(n) + static_cast<size_t>(col);
    };

    // Scratch buffers reused across all spans.
    std::vector<float> terms;
    std::vector<float> weights;
    std::vector<float> d_terms;
    if (n > 0) {
        terms.reserve(static_cast<size_t>(n));
        weights.reserve(static_cast<size_t>(n));
        d_terms.reserve(static_cast<size_t>(n));
    }

    for (int b = 0; b < B; b++) {
        const float* v = V + b * stride;
        const float* cr = C_R + b * stride;
        const float* cl = C_L + b * stride;
        const float* ir = I_R + b * stride;
        const float* il = I_L + b * stride;
        float* dcr = d_C_R + b * stride;
        float* dcl = d_C_L + b * stride;
        float* dir = d_I_R + b * stride;
        float* dil = d_I_L + b * stride;
        float* bcr = beta_C_R + b * stride;
        float* bcl = beta_C_L + b * stride;
        float* bir = beta_I_R + b * stride;
        float* bil = beta_I_L + b * stride;
        float* dbcr = d_beta_C_R + b * stride;
        float* dbcl = d_beta_C_L + b * stride;
        float* dbir = d_beta_I_R + b * stride;
        float* dbil = d_beta_I_L + b * stride;
        float* hvp_out = HVP + b * stride;

        // Clamp so that a zero/negative or oversized length can never index
        // outside this batch element's tables.
        const int seq_len = std::max(0, std::min(lengths ? lengths[b] : n, n));

        // Initialize. This also covers the base case: the diagonal tangents
        // d_C_R[i,i] = d_C_L[i,i] = 0 because C[i,i] is a constant.
        for (size_t idx = 0; idx < stride; idx++) {
            dcr[idx] = 0.0f;
            dcl[idx] = 0.0f;
            dir[idx] = 0.0f;
            dil[idx] = 0.0f;
            bcr[idx] = 0.0f;
            bcl[idx] = 0.0f;
            bir[idx] = 0.0f;
            bil[idx] = 0.0f;
            dbcr[idx] = 0.0f;
            dbcl[idx] = 0.0f;
            dbir[idx] = 0.0f;
            dbil[idx] = 0.0f;
            hvp_out[idx] = 0.0f;
        }

        if (seq_len == 0) {
            continue;
        }

        // Root span beta = 1; its tangent stays 0.
        bcr[at(0, seq_len - 1)] = 1.0f;

        // ---- Phase 1: forward pass for tangents (increasing span length) ----
        for (int len = 1; len < seq_len; len++) {
            for (int i = 0; i + len < seq_len; i++) {
                const int j = i + len;

                // Tangent for incomplete spans
                terms.clear();
                for (int k = i; k < j; k++) {
                    terms.push_back(cr[at(i, k)] + cl[at(k + 1, j)]);
                }
                softmax_T(terms, T, weights);

                float d_lse = 0.0f;
                for (int k = i; k < j; k++) {
                    d_lse += weights[k - i] * (dcr[at(i, k)] + dcl[at(k + 1, j)]);
                }
                dir[at(i, j)] = v[at(i, j)] + d_lse;
                dil[at(i, j)] = v[at(j, i)] + d_lse;

                // Tangent for C_R[i,j]
                terms.clear();
                for (int k = i; k < j; k++) {
                    terms.push_back(cr[at(i, k)] + ir[at(k, j)]);
                }
                softmax_T(terms, T, weights);

                d_lse = 0.0f;
                for (int k = i; k < j; k++) {
                    d_lse += weights[k - i] * (dcr[at(i, k)] + dir[at(k, j)]);
                }
                dcr[at(i, j)] = d_lse;

                // Tangent for C_L[i,j]
                terms.clear();
                for (int k = i + 1; k <= j; k++) {
                    terms.push_back(il[at(i, k)] + cl[at(k, j)]);
                }
                softmax_T(terms, T, weights);

                d_lse = 0.0f;
                for (int k = i + 1; k <= j; k++) {
                    d_lse += weights[k - i - 1] * (dil[at(i, k)] + dcl[at(k, j)]);
                }
                dcl[at(i, j)] = d_lse;
            }
        }

        // ---- Phase 2: backward pass (decreasing span length) ----
        for (int len = seq_len - 1; len >= 1; len--) {
            for (int i = 0; i + len < seq_len; i++) {
                const int j = i + len;

                // Backward for C_R[i,j]
                const float beta_cr_ij = bcr[at(i, j)];
                const float d_beta_cr_ij = dbcr[at(i, j)];

                if (beta_cr_ij != 0.0f || d_beta_cr_ij != 0.0f) {
                    terms.clear();
                    d_terms.clear();
                    for (int k = i; k < j; k++) {
                        terms.push_back(cr[at(i, k)] + ir[at(k, j)]);
                        d_terms.push_back(dcr[at(i, k)] + dir[at(k, j)]);
                    }
                    softmax_T(terms, T, weights);

                    // Softmax-weighted mean of the operand tangents.
                    float E_d_term = 0.0f;
                    for (int k = i; k < j; k++) {
                        E_d_term += weights[k - i] * d_terms[k - i];
                    }

                    for (int k = i; k < j; k++) {
                        const float w = weights[k - i];
                        const float d_w = w * (d_terms[k - i] - E_d_term) / T;

                        const float mass = beta_cr_ij * w;
                        const float d_mass = d_beta_cr_ij * w + beta_cr_ij * d_w;

                        bcr[at(i, k)] += mass;
                        bir[at(k, j)] += mass;
                        dbcr[at(i, k)] += d_mass;
                        dbir[at(k, j)] += d_mass;
                    }
                }

                // Backward for C_L[i,j]
                const float beta_cl_ij = bcl[at(i, j)];
                const float d_beta_cl_ij = dbcl[at(i, j)];

                if (beta_cl_ij != 0.0f || d_beta_cl_ij != 0.0f) {
                    terms.clear();
                    d_terms.clear();
                    for (int k = i + 1; k <= j; k++) {
                        terms.push_back(il[at(i, k)] + cl[at(k, j)]);
                        d_terms.push_back(dil[at(i, k)] + dcl[at(k, j)]);
                    }
                    softmax_T(terms, T, weights);

                    float E_d_term = 0.0f;
                    for (int k = i + 1; k <= j; k++) {
                        E_d_term += weights[k - i - 1] * d_terms[k - i - 1];
                    }

                    for (int k = i + 1; k <= j; k++) {
                        const float w = weights[k - i - 1];
                        const float d_w = w * (d_terms[k - i - 1] - E_d_term) / T;

                        const float mass = beta_cl_ij * w;
                        const float d_mass = d_beta_cl_ij * w + beta_cl_ij * d_w;

                        bil[at(i, k)] += mass;
                        bcl[at(k, j)] += mass;
                        dbil[at(i, k)] += d_mass;
                        dbcl[at(k, j)] += d_mass;
                    }
                }

                // Backward for incomplete spans
                const float beta_ir = bir[at(i, j)];
                const float beta_il = bil[at(i, j)];
                const float d_beta_ir = dbir[at(i, j)];
                const float d_beta_il = dbil[at(i, j)];

                // HVP output: tangent of the beta that reached each incomplete item.
                hvp_out[at(i, j)] = d_beta_ir;
                hvp_out[at(j, i)] = d_beta_il;

                // I_R[i,j] and I_L[i,j] share one split sum, so their betas
                // (and beta tangents) are pushed through it together.
                const float beta_combined = beta_ir + beta_il;
                const float d_beta_combined = d_beta_ir + d_beta_il;

                if (beta_combined != 0.0f || d_beta_combined != 0.0f) {
                    terms.clear();
                    d_terms.clear();
                    for (int k = i; k < j; k++) {
                        terms.push_back(cr[at(i, k)] + cl[at(k + 1, j)]);
                        d_terms.push_back(dcr[at(i, k)] + dcl[at(k + 1, j)]);
                    }
                    softmax_T(terms, T, weights);

                    float E_d_term = 0.0f;
                    for (int k = i; k < j; k++) {
                        E_d_term += weights[k - i] * d_terms[k - i];
                    }

                    for (int k = i; k < j; k++) {
                        const float w = weights[k - i];
                        const float d_w = w * (d_terms[k - i] - E_d_term) / T;

                        const float mass = beta_combined * w;
                        const float d_mass = d_beta_combined * w + beta_combined * d_w;

                        bcr[at(i, k)] += mass;
                        bcl[at(k + 1, j)] += mass;
                        dbcr[at(i, k)] += d_mass;
                        dbcl[at(k + 1, j)] += d_mass;
                    }
                }
            }
        }
    }
}

} // namespace cpu
} // namespace eisner
} // namespace d2p
