/**
 * @file kernels.cu
 * @brief Soft Eisner CUDA Kernel Implementations
 *
 * Differentiable Eisner algorithm for projective dependency parsing.
 * Uses wavefront parallelization over span lengths.
 */

#include "kernels.cuh"
#include <cuda_runtime.h>
#include <math.h>

namespace d2p {
namespace eisner {

// ============================================================================
// Device Helpers
// ============================================================================

#define WARP_SIZE 32

/**
 * Exponential with a bounded argument.
 *
 * Inputs below -88 return exactly zero; inputs above 88 are clamped to 88
 * before exp() is evaluated. NaN fails both comparisons and is forwarded to
 * exp() unchanged.
 */
template<typename T>
__device__ __forceinline__ T safe_exp(T x) {
    if (x < -88.0f) {
        return (T)0.0f;
    }
    const T bounded = (x > 88.0f) ? (T)88.0f : x;
    return exp(bounded);
}

/**
 * Sums `v` across a warp with shuffle-down steps of 16, 8, 4, 2, 1.
 *
 * The shuffles use the full mask 0xffffffff, so every lane of the warp is
 * named as a participant. After the last step lane 0 holds the sum of all
 * lanes; the values left in the other lanes are partial sums.
 */
template<typename T>
__device__ __forceinline__ T warp_reduce_sum(T v) {
    #pragma unroll
    for (int delta = WARP_SIZE / 2; delta >= 1; delta /= 2) {
        const T from_above = __shfl_down_sync(0xffffffff, v, delta);
        v += from_above;
    }
    return v;
}

/**
 * Sums `v` over all threads of a 1-D block.
 *
 * Each warp is reduced first; lane 0 of every warp stores its total in a
 * 32-slot shared scratch array, and warp 0 then reduces those totals. Only
 * thread 0 is guaranteed to return the block-wide sum. Must be reached by
 * every thread of the block because it contains __syncthreads().
 *
 * NOTE(review): the scratch array is not protected by a trailing barrier, so
 * calling this twice in a row from one kernel would need an extra
 * __syncthreads() between the calls.
 */
template<typename T>
__device__ __forceinline__ T block_reduce_sum(T v) {
    __shared__ T warp_totals[32];

    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    // Stage 1: per-warp totals, published by each warp's first lane.
    const T warp_total = warp_reduce_sum(v);
    if (lane_id == 0) {
        warp_totals[warp_id] = warp_total;
    }
    __syncthreads();

    // Stage 2: the first `warp_count` threads pick up one total each, every
    // other thread contributes zero, and warp 0 folds them together.
    const int warp_count = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
    T result = (T)0.0f;
    if (threadIdx.x < warp_count) {
        result = warp_totals[lane_id];
    }
    if (warp_id == 0) {
        result = warp_reduce_sum(result);
    }
    return result;
}

// ============================================================================
// Forward Pass Kernels
// ============================================================================

/**
 * Fills the four forward charts with their starting values.
 *
 * One thread per chart cell over a flat 1-D grid covering B * n * n
 * elements. Diagonal cells (i == j) inside the sequence length get 0 in both
 * complete charts; every other complete cell and every incomplete cell gets
 * NINF. When `lengths` is null each sequence is treated as having length n.
 */
__global__ void init_kernel(
    float* __restrict__ C_R,
    float* __restrict__ C_L,
    float* __restrict__ I_R,
    float* __restrict__ I_L,
    const int* __restrict__ lengths,
    int B, int n
) {
    const size_t plane = (size_t)n * n;
    const size_t gid = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (gid >= (size_t)B * plane) return;

    // Decompose the flat index into (batch, row, col).
    const size_t batch = gid / plane;
    const size_t cell = gid - batch * plane;
    const int row = cell / n;
    const int col = cell % n;

    const int len = lengths ? lengths[batch] : n;

    const bool valid_diagonal = (row == col) && (row < len);
    const float complete_init = valid_diagonal ? 0.0f : NINF;

    C_R[gid] = complete_init;
    C_L[gid] = complete_init;
    I_R[gid] = NINF;
    I_L[gid] = NINF;
}

/**
 * Forward step for incomplete spans of width `span_len`.
 *
 * Launch layout: one block per batch element (blockIdx.x is the batch index);
 * the block's threads stride over the span start positions.
 *
 * For each span [i, j] with j = i + span_len inside the sequence:
 *   S        = T * log( sum_{k=i}^{j-1} exp((C_R[i,k] + C_L[k+1,j]) / T) )
 *   I_R[i,j] = arc[i,j] + S
 *   I_L[i,j] = arc[j,i] + S
 * If no split has a score above NINF, both outputs are set to NINF.
 */
__global__ void forward_incomplete_kernel(
    const float* __restrict__ arc_scores,
    const float* __restrict__ C_R,
    const float* __restrict__ C_L,
    float* __restrict__ I_R,
    float* __restrict__ I_L,
    const int* __restrict__ lengths,
    int B, int n, float T,
    int span_len
) {
    const int batch = blockIdx.x;
    const size_t plane = (size_t)n * n;

    const float* arc = arc_scores + batch * plane;
    const float* cr = C_R + batch * plane;
    const float* cl = C_L + batch * plane;
    float* ir = I_R + batch * plane;
    float* il = I_L + batch * plane;

    const int len = lengths ? lengths[batch] : n;
    const int span_count = len - span_len;
    if (span_count <= 0) return;

    for (int left = threadIdx.x; left < span_count; left += blockDim.x) {
        const int right = left + span_len;
        if (right >= len) continue;

        const int row = left * n;
        const int cell = row + right;

        // Pass 1: largest split score, used to stabilise the exponentials.
        float best = NINF;
        for (int k = left; k < right; k++) {
            const float score = cr[row + k] + cl[(k + 1) * n + right];
            best = fmaxf(best, score);
        }

        if (best <= NINF) {
            ir[cell] = NINF;
            il[cell] = NINF;
            continue;
        }

        // Pass 2: temperature-scaled sum of exponentials relative to `best`.
        float acc = 0.0f;
        for (int k = left; k < right; k++) {
            const float score = cr[row + k] + cl[(k + 1) * n + right];
            acc += safe_exp((score - best) / T);
        }

        const float soft_max = best + T * logf(acc);
        ir[cell] = arc[cell] + soft_max;
        il[cell] = arc[right * n + left] + soft_max;
    }
}

/**
 * Forward step for complete spans of width `span_len`.
 *
 * Launch layout: one block per batch element (blockIdx.x is the batch index);
 * the block's threads stride over the span start positions.
 *
 * For each span [i, j] with j = i + span_len inside the sequence:
 *   C_R_out[i,j] = T * log( sum_{k=i}^{j-1} exp((C_R[i,k] + I_R[k,j]) / T) )
 *   C_L_out[i,j] = T * log( sum_{k=i+1}^{j} exp((I_L[i,k] + C_L[k,j]) / T) )
 * A chart cell is set to NINF when none of its splits scores above NINF.
 *
 * Only cells of width `span_len` are written. The C_R / C_L reads touch
 * narrower spans or the diagonal.
 */
__global__ void forward_complete_kernel(
    const float* __restrict__ C_R,
    const float* __restrict__ C_L,
    const float* __restrict__ I_R,
    const float* __restrict__ I_L,
    float* __restrict__ C_R_out,
    float* __restrict__ C_L_out,
    const int* __restrict__ lengths,
    int B, int n, float T,
    int span_len
) {
    const int batch = blockIdx.x;
    const size_t plane = (size_t)n * n;

    const float* cr = C_R + batch * plane;
    const float* cl = C_L + batch * plane;
    const float* ir = I_R + batch * plane;
    const float* il = I_L + batch * plane;
    float* cr_dst = C_R_out + batch * plane;
    float* cl_dst = C_L_out + batch * plane;

    const int len = lengths ? lengths[batch] : n;
    const int span_count = len - span_len;
    if (span_count <= 0) return;

    for (int left = threadIdx.x; left < span_count; left += blockDim.x) {
        const int right = left + span_len;
        if (right >= len) continue;

        const int row = left * n;
        const int cell = row + right;

        // ---- C_R[left,right]: splits k in [left, right) ----
        float best_r = NINF;
        for (int k = left; k < right; k++) {
            const float score = cr[row + k] + ir[k * n + right];
            best_r = fmaxf(best_r, score);
        }

        if (best_r <= NINF) {
            cr_dst[cell] = NINF;
        } else {
            float acc_r = 0.0f;
            for (int k = left; k < right; k++) {
                const float score = cr[row + k] + ir[k * n + right];
                acc_r += safe_exp((score - best_r) / T);
            }
            cr_dst[cell] = best_r + T * logf(acc_r);
        }

        // ---- C_L[left,right]: splits k in (left, right] ----
        float best_l = NINF;
        for (int k = left + 1; k <= right; k++) {
            const float score = il[row + k] + cl[k * n + right];
            best_l = fmaxf(best_l, score);
        }

        if (best_l <= NINF) {
            cl_dst[cell] = NINF;
        } else {
            float acc_l = 0.0f;
            for (int k = left + 1; k <= right; k++) {
                const float score = il[row + k] + cl[k * n + right];
                acc_l += safe_exp((score - best_l) / T);
            }
            cl_dst[cell] = best_l + T * logf(acc_l);
        }
    }
}

/**
 * Copies the final chart cell C_R[0, seq_len - 1] of every batch element into
 * `partition`.
 *
 * Launch layout: flat 1-D grid, one thread per batch element.
 *
 * @param C_R        Complete-right chart, B contiguous n x n planes.
 * @param partition  Output, one value per batch element.
 * @param lengths    Optional per-batch sequence lengths; null means length n.
 * @param B          Batch size.
 * @param n          Chart side length.
 *
 * A sequence length below 1 has no final cell: `seq_len - 1` would index the
 * element before the plane (out of bounds). Such entries are given 0, the
 * same value a length-1 sequence produces from the diagonal that init_kernel
 * sets to 0.
 */
__global__ void extract_partition_kernel(
    const float* __restrict__ C_R,
    float* __restrict__ partition,
    const int* __restrict__ lengths,
    int B, int n
) {
    int b = blockIdx.x * blockDim.x + threadIdx.x;
    if (b >= B) return;

    int seq_len = lengths ? lengths[b] : n;
    if (seq_len < 1) {
        // Empty sequence: avoid reading C_R at offset -1.
        partition[b] = 0.0f;
        return;
    }

    // Row 0, column seq_len - 1 of this batch element's plane.
    partition[b] = C_R[(size_t)b * n * n + (seq_len - 1)];
}

// ============================================================================
// Backward Pass Kernels
// ============================================================================

/**
 * Resets the backward-pass buffers.
 *
 * One thread per chart cell over a flat 1-D grid covering B * n * n
 * elements. beta_C_R is set to 1 at cell (0, seq_len - 1) of every batch
 * element and to 0 elsewhere; the other three beta charts and `marginals`
 * are zeroed. The first B threads of the grid also zero grad_T[0..B).
 * When `lengths` is null each sequence is treated as having length n.
 */
__global__ void init_beta_kernel(
    float* __restrict__ beta_C_R,
    float* __restrict__ beta_C_L,
    float* __restrict__ beta_I_R,
    float* __restrict__ beta_I_L,
    float* __restrict__ marginals,
    float* __restrict__ grad_T,
    const int* __restrict__ lengths,
    int B, int n
) {
    const size_t plane = (size_t)n * n;
    const size_t gid = (size_t)blockIdx.x * blockDim.x + threadIdx.x;

    // Per-batch temperature gradient accumulator, cleared by threads 0..B-1.
    if (gid < (size_t)B) {
        grad_T[gid] = 0.0f;
    }

    if (gid >= (size_t)B * plane) return;

    const size_t batch = gid / plane;
    const size_t cell = gid - batch * plane;
    const int row = cell / n;
    const int col = cell % n;

    const int len = lengths ? lengths[batch] : n;

    // Seed: the cell that extract_partition_kernel reads gets gradient 1.
    const bool is_root_cell = (row == 0) && (col == len - 1);
    beta_C_R[gid] = is_root_cell ? 1.0f : 0.0f;

    beta_C_L[gid] = 0.0f;
    beta_I_R[gid] = 0.0f;
    beta_I_L[gid] = 0.0f;
    marginals[gid] = 0.0f;
}

/**
 * Backward step through the complete-span recurrences for width `span_len`.
 *
 * Launch layout: one block per batch element (blockIdx.x is the batch index);
 * the block's threads stride over the span start positions. Every thread of
 * the block must reach the final block_reduce_sum (it contains a barrier);
 * the early return on `num_spans <= 0` is uniform across the block.
 *
 * For each span [i, j] the incoming gradient beta_C_R[i,j] (resp.
 * beta_C_L[i,j]) is distributed over the splits k with the softmax weights
 * w_k of the forward recurrence:
 *   C_R[i,j] <- C_R[i,k] + I_R[k,j],  k in [i, j)
 *   C_L[i,j] <- I_L[i,k] + C_L[k,j],  k in (i, j]
 * The per-span temperature term beta * (Z - sum_k w_k v_k) / T is reduced
 * over the block and added to grad_T[b].
 *
 * @param C_R,C_L,I_R,I_L   Forward charts (read only).
 * @param beta_*            Gradient charts, updated with atomicAdd.
 * @param grad_T            Per-batch temperature gradient accumulator.
 * @param lengths           Optional per-batch lengths; null means length n.
 * @param B,n,T,span_len    Batch size, chart side, temperature, span width.
 *
 * Splits whose weight is exactly zero are skipped. Their gradient share is
 * zero anyway, and skipping prevents `0 * v` from turning into NaN when v is
 * -infinity (e.g. a masked arc, or NINF being -infinity; NINF is defined in
 * kernels.cuh).
 */
__global__ void backward_complete_kernel(
    const float* __restrict__ C_R,
    const float* __restrict__ C_L,
    const float* __restrict__ I_R,
    const float* __restrict__ I_L,
    float* __restrict__ beta_C_R,
    float* __restrict__ beta_C_L,
    float* __restrict__ beta_I_R,
    float* __restrict__ beta_I_L,
    float* __restrict__ grad_T,
    const int* __restrict__ lengths,
    int B, int n, float T,
    int span_len
) {
    int b = blockIdx.x;
    size_t stride = (size_t)n * n;

    const float* cr = C_R + b * stride;
    const float* cl = C_L + b * stride;
    const float* ir = I_R + b * stride;
    const float* il = I_L + b * stride;
    float* bcr = beta_C_R + b * stride;
    float* bcl = beta_C_L + b * stride;
    float* bir = beta_I_R + b * stride;
    float* bil = beta_I_L + b * stride;

    int seq_len = lengths ? lengths[b] : n;
    int num_spans = seq_len - span_len;
    if (num_spans <= 0) return;

    // Per-thread partial of the temperature gradient for this batch element.
    float local_grad_T = 0.0f;

    for (int t = threadIdx.x; t < num_spans; t += blockDim.x) {
        int i = t;
        int j = i + span_len;

        if (j >= seq_len) continue;

        // Backward for C_R[i,j]
        float beta_cr_ij = bcr[i * n + j];
        if (beta_cr_ij != 0.0f) {
            // Recompute the forward maximum for a stable softmax.
            float max_v = NINF;
            for (int k = i; k < j; k++) {
                float v = cr[i * n + k] + ir[k * n + j];
                max_v = fmaxf(max_v, v);
            }

            if (max_v > NINF) {
                float sum_exp = 0.0f;
                for (int k = i; k < j; k++) {
                    float v = cr[i * n + k] + ir[k * n + j];
                    sum_exp += safe_exp((v - max_v) / T);
                }
                float inv_sum = (sum_exp > 1e-20f) ? 1.0f / sum_exp : 0.0f;

                float Zij = cr[i * n + j];   // forward value of this cell
                float E_term = 0.0f;         // sum_k w_k * v_k

                for (int k = i; k < j; k++) {
                    float v = cr[i * n + k] + ir[k * n + j];
                    float w = safe_exp((v - max_v) / T) * inv_sum;

                    // Zero weight: no gradient flows, and 0 * (-inf) must
                    // not reach E_term.
                    if (w == 0.0f) continue;

                    float mass = beta_cr_ij * w;

                    atomicAdd(&bcr[i * n + k], mass);
                    atomicAdd(&bir[k * n + j], mass);

                    E_term += w * v;
                }

                local_grad_T += beta_cr_ij * (Zij - E_term) / T;
            }
        }

        // Backward for C_L[i,j]
        float beta_cl_ij = bcl[i * n + j];
        if (beta_cl_ij != 0.0f) {
            float max_v = NINF;
            for (int k = i + 1; k <= j; k++) {
                float v = il[i * n + k] + cl[k * n + j];
                max_v = fmaxf(max_v, v);
            }

            if (max_v > NINF) {
                float sum_exp = 0.0f;
                for (int k = i + 1; k <= j; k++) {
                    float v = il[i * n + k] + cl[k * n + j];
                    sum_exp += safe_exp((v - max_v) / T);
                }
                float inv_sum = (sum_exp > 1e-20f) ? 1.0f / sum_exp : 0.0f;

                float Zij = cl[i * n + j];   // forward value of this cell
                float E_term = 0.0f;         // sum_k w_k * v_k

                for (int k = i + 1; k <= j; k++) {
                    float v = il[i * n + k] + cl[k * n + j];
                    float w = safe_exp((v - max_v) / T) * inv_sum;

                    // Same zero-weight guard as in the C_R branch.
                    if (w == 0.0f) continue;

                    float mass = beta_cl_ij * w;

                    atomicAdd(&bil[i * n + k], mass);
                    atomicAdd(&bcl[k * n + j], mass);

                    E_term += w * v;
                }

                local_grad_T += beta_cl_ij * (Zij - E_term) / T;
            }
        }
    }

    // One atomic per block: thread 0 holds the block-wide sum.
    float block_grad_T = block_reduce_sum(local_grad_T);
    if (threadIdx.x == 0) {
        atomicAdd(&grad_T[b], block_grad_T);
    }
}

/**
 * Backward step through the incomplete-span recurrence for width `span_len`.
 *
 * Launch layout: one block per batch element (blockIdx.x is the batch index);
 * the block's threads stride over the span start positions. Every thread of
 * the block must reach the final block_reduce_sum (it contains a barrier);
 * the early return on `num_spans <= 0` is uniform across the block.
 *
 * For each span [i, j]:
 *   - marginals[i,j] = beta_I_R[i,j] and marginals[j,i] = beta_I_L[i,j];
 *   - the combined gradient beta_I_R + beta_I_L is distributed over the
 *     splits k in [i, j) of  C_R[i,k] + C_L[k+1,j]  with softmax weights and
 *     accumulated into beta_C_R / beta_C_L with atomicAdd;
 *   - the temperature term beta * (lse - sum_k w_k v_k) / T is reduced over
 *     the block and added to grad_T[b].
 *
 * @param arc_scores, I_R, I_L  Not read by this kernel; kept so the
 *                              signature is unchanged.
 * @param C_R, C_L              Forward complete charts (read only).
 * @param beta_C_R, beta_C_L    Gradient charts, updated with atomicAdd.
 * @param beta_I_R, beta_I_L    Gradients of the incomplete charts (read).
 * @param marginals             Output, written at [i,j] and [j,i].
 * @param grad_T                Per-batch temperature gradient accumulator.
 * @param lengths               Optional per-batch lengths; null means n.
 *
 * Splits whose weight is exactly zero are skipped. Their gradient share is
 * zero anyway, and skipping prevents `0 * v` from turning into NaN when v is
 * -infinity (e.g. NINF being -infinity; NINF is defined in kernels.cuh).
 */
__global__ void backward_incomplete_kernel(
    const float* __restrict__ arc_scores,
    const float* __restrict__ C_R,
    const float* __restrict__ C_L,
    const float* __restrict__ I_R,
    const float* __restrict__ I_L,
    float* __restrict__ beta_C_R,
    float* __restrict__ beta_C_L,
    const float* __restrict__ beta_I_R,
    const float* __restrict__ beta_I_L,
    float* __restrict__ marginals,
    float* __restrict__ grad_T,
    const int* __restrict__ lengths,
    int B, int n, float T,
    int span_len
) {
    int b = blockIdx.x;
    size_t stride = (size_t)n * n;

    const float* cr = C_R + b * stride;
    const float* cl = C_L + b * stride;
    float* bcr = beta_C_R + b * stride;
    float* bcl = beta_C_L + b * stride;
    const float* bir = beta_I_R + b * stride;
    const float* bil = beta_I_L + b * stride;
    float* marg = marginals + b * stride;

    int seq_len = lengths ? lengths[b] : n;
    int num_spans = seq_len - span_len;
    if (num_spans <= 0) return;

    // Per-thread partial of the temperature gradient for this batch element.
    float local_grad_T = 0.0f;

    for (int t = threadIdx.x; t < num_spans; t += blockDim.x) {
        int i = t;
        int j = i + span_len;

        if (j >= seq_len) continue;

        float beta_ir_ij = bir[i * n + j];
        float beta_il_ij = bil[i * n + j];

        // Arc marginals
        marg[i * n + j] = beta_ir_ij;
        marg[j * n + i] = beta_il_ij;

        // I_R and I_L share the same inner log-sum-exp, so their gradients
        // flow through it together.
        float beta_combined = beta_ir_ij + beta_il_ij;

        if (beta_combined != 0.0f) {
            float max_v = NINF;
            for (int k = i; k < j; k++) {
                float v = cr[i * n + k] + cl[(k + 1) * n + j];
                max_v = fmaxf(max_v, v);
            }

            if (max_v > NINF) {
                float sum_exp = 0.0f;
                for (int k = i; k < j; k++) {
                    float v = cr[i * n + k] + cl[(k + 1) * n + j];
                    sum_exp += safe_exp((v - max_v) / T);
                }
                float inv_sum = (sum_exp > 1e-20f) ? 1.0f / sum_exp : 0.0f;

                float lse = max_v + T * logf(sum_exp);
                float E_term = 0.0f;   // sum_k w_k * v_k

                for (int k = i; k < j; k++) {
                    float v = cr[i * n + k] + cl[(k + 1) * n + j];
                    float w = safe_exp((v - max_v) / T) * inv_sum;

                    // Zero weight: no gradient flows, and 0 * (-inf) must
                    // not reach E_term.
                    if (w == 0.0f) continue;

                    float mass = beta_combined * w;

                    atomicAdd(&bcr[i * n + k], mass);
                    atomicAdd(&bcl[(k + 1) * n + j], mass);

                    E_term += w * v;
                }

                local_grad_T += beta_combined * (lse - E_term) / T;
            }
        }
    }

    // One atomic per block: thread 0 holds the block-wide sum.
    float block_grad_T = block_reduce_sum(local_grad_T);
    if (threadIdx.x == 0) {
        atomicAdd(&grad_T[b], block_grad_T);
    }
}

// ============================================================================
// HVP Kernels
// ============================================================================

/**
 * Resets every buffer used by the Hessian-vector-product pass.
 *
 * Launch layout: flat 1-D grid, one thread per chart cell over B * n * n
 * elements.
 *
 * - d_C_R, d_C_L, d_I_R, d_I_L (directional derivatives of the forward
 *   charts) are zeroed everywhere.
 * - beta_C_R is set to 1 at cell (0, seq_len - 1) of every batch element and
 *   to 0 elsewhere; the other beta charts are zeroed.
 * - d_beta_* and HVP are zeroed.
 *
 * @param V        Not read by this kernel; kept so the signature is unchanged.
 * @param lengths  Optional per-batch lengths; null means length n.
 * @param B        Batch size.
 * @param n        Chart side length.
 */
__global__ void hvp_init_kernel(
    const float* __restrict__ V,
    float* __restrict__ d_C_R,
    float* __restrict__ d_C_L,
    float* __restrict__ d_I_R,
    float* __restrict__ d_I_L,
    float* __restrict__ beta_C_R,
    float* __restrict__ beta_C_L,
    float* __restrict__ beta_I_R,
    float* __restrict__ beta_I_L,
    float* __restrict__ d_beta_C_R,
    float* __restrict__ d_beta_C_L,
    float* __restrict__ d_beta_I_R,
    float* __restrict__ d_beta_I_L,
    float* __restrict__ HVP,
    const int* __restrict__ lengths,
    int B, int n
) {
    size_t stride = (size_t)n * n;
    size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    size_t total = (size_t)B * stride;

    if (idx >= total) return;

    size_t b = idx / stride;
    size_t rem = idx - b * stride;
    int i = rem / n;
    int j = rem % n;

    int seq_len = lengths ? lengths[b] : n;

    // Forward tangents start at zero for every cell, diagonal included.
    d_C_R[idx] = 0.0f;
    d_C_L[idx] = 0.0f;
    d_I_R[idx] = 0.0f;
    d_I_L[idx] = 0.0f;

    // Seed the backward pass at the cell holding the partition value.
    beta_C_R[idx] = (i == 0 && j == seq_len - 1) ? 1.0f : 0.0f;
    beta_C_L[idx] = 0.0f;
    beta_I_R[idx] = 0.0f;
    beta_I_L[idx] = 0.0f;

    d_beta_C_R[idx] = 0.0f;
    d_beta_C_L[idx] = 0.0f;
    d_beta_I_R[idx] = 0.0f;
    d_beta_I_L[idx] = 0.0f;

    HVP[idx] = 0.0f;
}

/**
 * Forward-mode (tangent) step for incomplete spans of width `span_len`.
 *
 * Launch layout: one block per batch element (blockIdx.x is the batch index);
 * the block's threads stride over the span start positions.
 *
 * With w_k the softmax weights of the splits C_R[i,k] + C_L[k+1,j] at
 * temperature T, for k in [i, j):
 *   d_I_R[i,j] = V[i,j] + sum_k w_k * (d_C_R[i,k] + d_C_L[k+1,j])
 *   d_I_L[i,j] = V[j,i] + sum_k w_k * (d_C_R[i,k] + d_C_L[k+1,j])
 * Both outputs are 0 when no split scores above NINF. `arc_scores` is not
 * read by this kernel.
 */
__global__ void hvp_forward_incomplete_kernel(
    const float* __restrict__ arc_scores,
    const float* __restrict__ V,
    const float* __restrict__ C_R,
    const float* __restrict__ C_L,
    const float* __restrict__ d_C_R,
    const float* __restrict__ d_C_L,
    float* __restrict__ d_I_R,
    float* __restrict__ d_I_L,
    const int* __restrict__ lengths,
    int B, int n, float T,
    int span_len
) {
    const int batch = blockIdx.x;
    const size_t plane = (size_t)n * n;

    const float* dirn = V + batch * plane;
    const float* cr = C_R + batch * plane;
    const float* cl = C_L + batch * plane;
    const float* dcr = d_C_R + batch * plane;
    const float* dcl = d_C_L + batch * plane;
    float* dir_out = d_I_R + batch * plane;
    float* dil_out = d_I_L + batch * plane;

    const int len = lengths ? lengths[batch] : n;
    const int span_count = len - span_len;
    if (span_count <= 0) return;

    for (int left = threadIdx.x; left < span_count; left += blockDim.x) {
        const int right = left + span_len;
        if (right >= len) continue;

        const int row = left * n;
        const int cell = row + right;

        // Largest split score, for a stable softmax.
        float best = NINF;
        for (int k = left; k < right; k++) {
            const float score = cr[row + k] + cl[(k + 1) * n + right];
            best = fmaxf(best, score);
        }

        if (best <= NINF) {
            dir_out[cell] = 0.0f;
            dil_out[cell] = 0.0f;
            continue;
        }

        // Softmax normaliser.
        float denom = 0.0f;
        for (int k = left; k < right; k++) {
            const float score = cr[row + k] + cl[(k + 1) * n + right];
            denom += safe_exp((score - best) / T);
        }
        const float inv_denom = (denom > 1e-20f) ? 1.0f / denom : 0.0f;

        // Weighted average of the incoming tangents.
        float tangent = 0.0f;
        for (int k = left; k < right; k++) {
            const float score = cr[row + k] + cl[(k + 1) * n + right];
            const float w = safe_exp((score - best) / T) * inv_denom;
            tangent += w * (dcr[row + k] + dcl[(k + 1) * n + right]);
        }

        dir_out[cell] = dirn[cell] + tangent;
        dil_out[cell] = dirn[right * n + left] + tangent;
    }
}

/**
 * Forward-mode (tangent) step for complete spans of width `span_len`.
 *
 * Launch layout: one block per batch element (blockIdx.x is the batch index);
 * the block's threads stride over the span start positions.
 *
 * With softmax weights w_k at temperature T over the forward splits:
 *   d_C_R[i,j] = sum_{k in [i,j)} w_k * (d_C_R_in[i,k] + d_I_R[k,j])
 *                 (weights from C_R[i,k] + I_R[k,j])
 *   d_C_L[i,j] = sum_{k in (i,j]} w_k * (d_I_L[i,k] + d_C_L_in[k,j])
 *                 (weights from I_L[i,k] + C_L[k,j])
 * An output is 0 when none of its splits scores above NINF. Only cells of
 * width `span_len` are written.
 */
__global__ void hvp_forward_complete_kernel(
    const float* __restrict__ C_R,
    const float* __restrict__ C_L,
    const float* __restrict__ I_R,
    const float* __restrict__ I_L,
    const float* __restrict__ d_C_R_in,
    const float* __restrict__ d_C_L_in,
    const float* __restrict__ d_I_R,
    const float* __restrict__ d_I_L,
    float* __restrict__ d_C_R,
    float* __restrict__ d_C_L,
    const int* __restrict__ lengths,
    int B, int n, float T,
    int span_len
) {
    const int batch = blockIdx.x;
    const size_t plane = (size_t)n * n;

    const float* cr = C_R + batch * plane;
    const float* cl = C_L + batch * plane;
    const float* ir = I_R + batch * plane;
    const float* il = I_L + batch * plane;
    const float* dcr_src = d_C_R_in + batch * plane;
    const float* dcl_src = d_C_L_in + batch * plane;
    const float* dir = d_I_R + batch * plane;
    const float* dil = d_I_L + batch * plane;
    float* dcr_dst = d_C_R + batch * plane;
    float* dcl_dst = d_C_L + batch * plane;

    const int len = lengths ? lengths[batch] : n;
    const int span_count = len - span_len;
    if (span_count <= 0) return;

    for (int left = threadIdx.x; left < span_count; left += blockDim.x) {
        const int right = left + span_len;
        if (right >= len) continue;

        const int row = left * n;
        const int cell = row + right;

        // ---- d_C_R[left,right]: splits k in [left, right) ----
        float best_r = NINF;
        for (int k = left; k < right; k++) {
            const float score = cr[row + k] + ir[k * n + right];
            best_r = fmaxf(best_r, score);
        }

        if (best_r <= NINF) {
            dcr_dst[cell] = 0.0f;
        } else {
            float denom = 0.0f;
            for (int k = left; k < right; k++) {
                const float score = cr[row + k] + ir[k * n + right];
                denom += safe_exp((score - best_r) / T);
            }
            const float inv_denom = (denom > 1e-20f) ? 1.0f / denom : 0.0f;

            float tangent = 0.0f;
            for (int k = left; k < right; k++) {
                const float score = cr[row + k] + ir[k * n + right];
                const float w = safe_exp((score - best_r) / T) * inv_denom;
                tangent += w * (dcr_src[row + k] + dir[k * n + right]);
            }
            dcr_dst[cell] = tangent;
        }

        // ---- d_C_L[left,right]: splits k in (left, right] ----
        float best_l = NINF;
        for (int k = left + 1; k <= right; k++) {
            const float score = il[row + k] + cl[k * n + right];
            best_l = fmaxf(best_l, score);
        }

        if (best_l <= NINF) {
            dcl_dst[cell] = 0.0f;
        } else {
            float denom = 0.0f;
            for (int k = left + 1; k <= right; k++) {
                const float score = il[row + k] + cl[k * n + right];
                denom += safe_exp((score - best_l) / T);
            }
            const float inv_denom = (denom > 1e-20f) ? 1.0f / denom : 0.0f;

            float tangent = 0.0f;
            for (int k = left + 1; k <= right; k++) {
                const float score = il[row + k] + cl[k * n + right];
                const float w = safe_exp((score - best_l) / T) * inv_denom;
                tangent += w * (dil[row + k] + dcl_src[k * n + right]);
            }
            dcl_dst[cell] = tangent;
        }
    }
}

/**
 * Backward step of the HVP pass through the complete-span recurrences for
 * width `span_len`, propagating both the gradients (beta) and their
 * directional derivatives (d_beta).
 *
 * Launch layout: one block per batch element (blockIdx.x is the batch index);
 * the block's threads stride over the span start positions.
 *
 * For each split k with softmax weight w_k and incoming tangent d_k:
 *   dw_k   = w_k * (d_k - sum_m w_m d_m) / T
 *   mass   = beta * w_k
 *   d_mass = d_beta * w_k + beta * dw_k
 * `mass` is added to the two beta cells feeding the split and `d_mass` to
 * the matching d_beta cells, all with atomicAdd. A span is skipped when both
 * its beta and d_beta are zero, or when no split scores above NINF.
 */
__global__ void hvp_backward_complete_kernel(
    const float* __restrict__ C_R,
    const float* __restrict__ C_L,
    const float* __restrict__ I_R,
    const float* __restrict__ I_L,
    const float* __restrict__ d_C_R,
    const float* __restrict__ d_C_L,
    const float* __restrict__ d_I_R,
    const float* __restrict__ d_I_L,
    float* __restrict__ beta_C_R,
    float* __restrict__ beta_C_L,
    float* __restrict__ beta_I_R,
    float* __restrict__ beta_I_L,
    float* __restrict__ d_beta_C_R,
    float* __restrict__ d_beta_C_L,
    float* __restrict__ d_beta_I_R,
    float* __restrict__ d_beta_I_L,
    const int* __restrict__ lengths,
    int B, int n, float T,
    int span_len
) {
    const int batch = blockIdx.x;
    const size_t plane = (size_t)n * n;

    const float* cr = C_R + batch * plane;
    const float* cl = C_L + batch * plane;
    const float* ir = I_R + batch * plane;
    const float* il = I_L + batch * plane;
    const float* dcr = d_C_R + batch * plane;
    const float* dcl = d_C_L + batch * plane;
    const float* dir = d_I_R + batch * plane;
    const float* dil = d_I_L + batch * plane;
    float* bcr = beta_C_R + batch * plane;
    float* bcl = beta_C_L + batch * plane;
    float* bir = beta_I_R + batch * plane;
    float* bil = beta_I_L + batch * plane;
    float* dbcr = d_beta_C_R + batch * plane;
    float* dbcl = d_beta_C_L + batch * plane;
    float* dbir = d_beta_I_R + batch * plane;
    float* dbil = d_beta_I_L + batch * plane;

    const int len = lengths ? lengths[batch] : n;
    const int span_count = len - span_len;
    if (span_count <= 0) return;

    for (int left = threadIdx.x; left < span_count; left += blockDim.x) {
        const int right = left + span_len;
        if (right >= len) continue;

        const int row = left * n;
        const int cell = row + right;

        // ---- C_R[left,right]: splits k in [left, right) ----
        const float g_cr = bcr[cell];
        const float dg_cr = dbcr[cell];

        if (!(g_cr == 0.0f && dg_cr == 0.0f)) {
            float best = NINF;
            for (int k = left; k < right; k++) {
                const float score = cr[row + k] + ir[k * n + right];
                best = fmaxf(best, score);
            }

            if (best > NINF) {
                float denom = 0.0f;
                for (int k = left; k < right; k++) {
                    const float score = cr[row + k] + ir[k * n + right];
                    denom += safe_exp((score - best) / T);
                }
                const float inv_denom = (denom > 1e-20f) ? 1.0f / denom : 0.0f;

                // Softmax-weighted mean of the incoming tangents.
                float mean_tangent = 0.0f;
                for (int k = left; k < right; k++) {
                    const float score = cr[row + k] + ir[k * n + right];
                    const float w = safe_exp((score - best) / T) * inv_denom;
                    mean_tangent += w * (dcr[row + k] + dir[k * n + right]);
                }

                for (int k = left; k < right; k++) {
                    const float score = cr[row + k] + ir[k * n + right];
                    const float w = safe_exp((score - best) / T) * inv_denom;
                    const float tangent = dcr[row + k] + dir[k * n + right];
                    const float dw = w * (tangent - mean_tangent) / T;

                    const float flow = g_cr * w;
                    const float d_flow = dg_cr * w + g_cr * dw;

                    atomicAdd(&bcr[row + k], flow);
                    atomicAdd(&bir[k * n + right], flow);
                    atomicAdd(&dbcr[row + k], d_flow);
                    atomicAdd(&dbir[k * n + right], d_flow);
                }
            }
        }

        // ---- C_L[left,right]: splits k in (left, right] ----
        const float g_cl = bcl[cell];
        const float dg_cl = dbcl[cell];

        if (!(g_cl == 0.0f && dg_cl == 0.0f)) {
            float best = NINF;
            for (int k = left + 1; k <= right; k++) {
                const float score = il[row + k] + cl[k * n + right];
                best = fmaxf(best, score);
            }

            if (best > NINF) {
                float denom = 0.0f;
                for (int k = left + 1; k <= right; k++) {
                    const float score = il[row + k] + cl[k * n + right];
                    denom += safe_exp((score - best) / T);
                }
                const float inv_denom = (denom > 1e-20f) ? 1.0f / denom : 0.0f;

                // Softmax-weighted mean of the incoming tangents.
                float mean_tangent = 0.0f;
                for (int k = left + 1; k <= right; k++) {
                    const float score = il[row + k] + cl[k * n + right];
                    const float w = safe_exp((score - best) / T) * inv_denom;
                    mean_tangent += w * (dil[row + k] + dcl[k * n + right]);
                }

                for (int k = left + 1; k <= right; k++) {
                    const float score = il[row + k] + cl[k * n + right];
                    const float w = safe_exp((score - best) / T) * inv_denom;
                    const float tangent = dil[row + k] + dcl[k * n + right];
                    const float dw = w * (tangent - mean_tangent) / T;

                    const float flow = g_cl * w;
                    const float d_flow = dg_cl * w + g_cl * dw;

                    atomicAdd(&bil[row + k], flow);
                    atomicAdd(&bcl[k * n + right], flow);
                    atomicAdd(&dbil[row + k], d_flow);
                    atomicAdd(&dbcl[k * n + right], d_flow);
                }
            }
        }
    }
}

/**
 * Backward step of the HVP pass through the incomplete-span recurrence for
 * width `span_len`.
 *
 * Launch layout: one block per batch element (blockIdx.x is the batch index);
 * the block's threads stride over the span start positions.
 *
 * For each span [i, j]:
 *   - HVP[i,j] = d_beta_I_R[i,j] and HVP[j,i] = d_beta_I_L[i,j];
 *   - the combined beta (I_R + I_L) and combined d_beta are pushed through
 *     the softmax over splits C_R[i,k] + C_L[k+1,j], k in [i, j):
 *       dw_k   = w_k * (d_k - sum_m w_m d_m) / T
 *       mass   = beta * w_k
 *       d_mass = d_beta * w_k + beta * dw_k
 *     and accumulated into beta_C_R / beta_C_L and d_beta_C_R / d_beta_C_L
 *     with atomicAdd.
 * `V` is not read by this kernel.
 */
__global__ void hvp_backward_incomplete_kernel(
    const float* __restrict__ V,
    const float* __restrict__ C_R,
    const float* __restrict__ C_L,
    const float* __restrict__ d_C_R,
    const float* __restrict__ d_C_L,
    float* __restrict__ beta_C_R,
    float* __restrict__ beta_C_L,
    const float* __restrict__ beta_I_R,
    const float* __restrict__ beta_I_L,
    float* __restrict__ d_beta_C_R,
    float* __restrict__ d_beta_C_L,
    const float* __restrict__ d_beta_I_R,
    const float* __restrict__ d_beta_I_L,
    float* __restrict__ HVP,
    const int* __restrict__ lengths,
    int B, int n, float T,
    int span_len
) {
    const int batch = blockIdx.x;
    const size_t plane = (size_t)n * n;

    const float* cr = C_R + batch * plane;
    const float* cl = C_L + batch * plane;
    const float* dcr = d_C_R + batch * plane;
    const float* dcl = d_C_L + batch * plane;
    float* bcr = beta_C_R + batch * plane;
    float* bcl = beta_C_L + batch * plane;
    const float* bir = beta_I_R + batch * plane;
    const float* bil = beta_I_L + batch * plane;
    float* dbcr = d_beta_C_R + batch * plane;
    float* dbcl = d_beta_C_L + batch * plane;
    const float* dbir = d_beta_I_R + batch * plane;
    const float* dbil = d_beta_I_L + batch * plane;
    float* result = HVP + batch * plane;

    const int len = lengths ? lengths[batch] : n;
    const int span_count = len - span_len;
    if (span_count <= 0) return;

    for (int left = threadIdx.x; left < span_count; left += blockDim.x) {
        const int right = left + span_len;
        if (right >= len) continue;

        const int row = left * n;
        const int cell = row + right;

        const float g_ir = bir[cell];
        const float g_il = bil[cell];
        const float dg_ir = dbir[cell];
        const float dg_il = dbil[cell];

        // HVP output: one entry per arc direction.
        result[cell] = dg_ir;
        result[right * n + left] = dg_il;

        // I_R and I_L share the same inner log-sum-exp.
        const float g_sum = g_ir + g_il;
        const float dg_sum = dg_ir + dg_il;

        if (g_sum == 0.0f && dg_sum == 0.0f) continue;

        float best = NINF;
        for (int k = left; k < right; k++) {
            const float score = cr[row + k] + cl[(k + 1) * n + right];
            best = fmaxf(best, score);
        }

        if (best > NINF) {
            float denom = 0.0f;
            for (int k = left; k < right; k++) {
                const float score = cr[row + k] + cl[(k + 1) * n + right];
                denom += safe_exp((score - best) / T);
            }
            const float inv_denom = (denom > 1e-20f) ? 1.0f / denom : 0.0f;

            // Softmax-weighted mean of the incoming tangents.
            float mean_tangent = 0.0f;
            for (int k = left; k < right; k++) {
                const float score = cr[row + k] + cl[(k + 1) * n + right];
                const float w = safe_exp((score - best) / T) * inv_denom;
                mean_tangent += w * (dcr[row + k] + dcl[(k + 1) * n + right]);
            }

            for (int k = left; k < right; k++) {
                const float score = cr[row + k] + cl[(k + 1) * n + right];
                const float w = safe_exp((score - best) / T) * inv_denom;
                const float tangent = dcr[row + k] + dcl[(k + 1) * n + right];
                const float dw = w * (tangent - mean_tangent) / T;

                const float flow = g_sum * w;
                const float d_flow = dg_sum * w + g_sum * dw;

                atomicAdd(&bcr[row + k], flow);
                atomicAdd(&bcl[(k + 1) * n + right], flow);
                atomicAdd(&dbcr[row + k], d_flow);
                atomicAdd(&dbcl[(k + 1) * n + right], d_flow);
            }
        }
    }
}

// ============================================================================
// Host API
// ============================================================================

/**
 * Runs the forward (inside) pass and extracts the per-batch partition value.
 *
 * Kernels are launched on the default stream in this order: init_kernel,
 * then for each span width 1..n-1 the incomplete and the complete kernel,
 * then extract_partition_kernel. The call blocks on cudaDeviceSynchronize()
 * before returning.
 *
 * @param arc_scores   Arc scores, B planes of n x n (device pointer).
 * @param C_R,C_L      Complete charts, B planes of n x n (output).
 * @param I_R,I_L      Incomplete charts, B planes of n x n (output).
 * @param partition    B output values.
 * @param lengths      Optional per-batch lengths; null means length n.
 * @param B            Batch size.
 * @param n            Chart side length.
 * @param temperature  Softmax temperature passed to the kernels.
 *
 * If B <= 0 or n <= 0 the function returns immediately and leaves every
 * output untouched: a zero-block grid is an invalid launch configuration.
 * Launch and synchronisation errors are not inspected here.
 */
void forward(
    const float* arc_scores,
    float* C_R,
    float* C_L,
    float* I_R,
    float* I_L,
    float* partition,
    const int* lengths,
    int B, int n, float temperature
) {
    // Empty problem: nothing to launch.
    if (B <= 0 || n <= 0) return;

    const int threads = 256;
    const size_t elems = (size_t)B * n * n;
    const int blocks_init = (int)((elems + threads - 1) / threads);

    init_kernel<<<blocks_init, threads>>>(C_R, C_L, I_R, I_L, lengths, B, n);

    // Wavefront over span widths: width `len` only reads narrower spans
    // (plus the incomplete cells of the same width computed just before).
    for (int len = 1; len < n; len++) {
        forward_incomplete_kernel<<<B, threads>>>(
            arc_scores, C_R, C_L, I_R, I_L, lengths, B, n, temperature, len
        );
        // The complete charts are updated in place: input and output
        // pointers are the same buffers.
        forward_complete_kernel<<<B, threads>>>(
            C_R, C_L, I_R, I_L, C_R, C_L, lengths, B, n, temperature, len
        );
    }

    const int blocks_extract = (B + 255) / 256;
    extract_partition_kernel<<<blocks_extract, 256>>>(C_R, partition, lengths, B, n);

    cudaDeviceSynchronize();
}

/**
 * Runs the backward (outside) pass, producing arc marginals and the
 * per-batch temperature gradient.
 *
 * Kernels are launched on the default stream in this order: init_beta_kernel,
 * then for each span width n-1 down to 1 the complete and the incomplete
 * backward kernel. The call blocks on cudaDeviceSynchronize() before
 * returning.
 *
 * @param arc_scores        Arc scores, B planes of n x n (device pointer).
 * @param C_R,C_L,I_R,I_L   Forward charts produced by forward().
 * @param beta_*            Gradient charts, B planes of n x n (scratch/output).
 * @param marginals         Output, B planes of n x n.
 * @param grad_T            Output, B values.
 * @param lengths           Optional per-batch lengths; null means length n.
 * @param B                 Batch size.
 * @param n                 Chart side length.
 * @param temperature       Softmax temperature passed to the kernels.
 *
 * If B <= 0 or n <= 0 the function returns immediately and leaves every
 * output untouched: a zero-block grid is an invalid launch configuration.
 * Launch and synchronisation errors are not inspected here.
 */
void backward(
    const float* arc_scores,
    const float* C_R,
    const float* C_L,
    const float* I_R,
    const float* I_L,
    float* beta_C_R,
    float* beta_C_L,
    float* beta_I_R,
    float* beta_I_L,
    float* marginals,
    float* grad_T,
    const int* lengths,
    int B, int n, float temperature
) {
    // Empty problem: nothing to launch.
    if (B <= 0 || n <= 0) return;

    const int threads = 256;
    const size_t elems = (size_t)B * n * n;
    const int blocks_init = (int)((elems + threads - 1) / threads);

    init_beta_kernel<<<blocks_init, threads>>>(
        beta_C_R, beta_C_L, beta_I_R, beta_I_L, marginals, grad_T, lengths, B, n
    );

    // Reverse wavefront: widest spans first, so every cell's gradient is
    // complete before it is distributed to narrower spans.
    for (int len = n - 1; len >= 1; len--) {
        backward_complete_kernel<<<B, threads>>>(
            C_R, C_L, I_R, I_L,
            beta_C_R, beta_C_L, beta_I_R, beta_I_L,
            grad_T, lengths, B, n, temperature, len
        );
        backward_incomplete_kernel<<<B, threads>>>(
            arc_scores, C_R, C_L, I_R, I_L,
            beta_C_R, beta_C_L, beta_I_R, beta_I_L,
            marginals, grad_T, lengths, B, n, temperature, len
        );
    }

    cudaDeviceSynchronize();
}

/**
 * Computes a Hessian-vector product in direction V by a tangent forward
 * sweep followed by a backward sweep that carries both gradients and their
 * directional derivatives.
 *
 * Kernels are launched on the default stream in this order: hvp_init_kernel;
 * for span widths 1..n-1 the tangent incomplete and complete kernels; for
 * span widths n-1..1 the HVP backward complete and incomplete kernels. The
 * call blocks on cudaDeviceSynchronize() before returning.
 *
 * @param arc_scores        Arc scores, B planes of n x n (device pointer).
 * @param V                 Direction, B planes of n x n.
 * @param C_R,C_L,I_R,I_L   Forward charts produced by forward().
 * @param d_C_*, d_I_*      Tangent charts (scratch, overwritten).
 * @param beta_*            Gradient charts (scratch, overwritten).
 * @param d_beta_*          Gradient-tangent charts (scratch, overwritten).
 * @param HVP               Output, B planes of n x n.
 * @param lengths           Optional per-batch lengths; null means length n.
 * @param B                 Batch size.
 * @param n                 Chart side length.
 * @param temperature       Softmax temperature passed to the kernels.
 *
 * If B <= 0 or n <= 0 the function returns immediately and leaves every
 * output untouched: a zero-block grid is an invalid launch configuration.
 * Launch and synchronisation errors are not inspected here.
 */
void hvp(
    const float* arc_scores,
    const float* V,
    const float* C_R,
    const float* C_L,
    const float* I_R,
    const float* I_L,
    float* d_C_R,
    float* d_C_L,
    float* d_I_R,
    float* d_I_L,
    float* beta_C_R,
    float* beta_C_L,
    float* beta_I_R,
    float* beta_I_L,
    float* d_beta_C_R,
    float* d_beta_C_L,
    float* d_beta_I_R,
    float* d_beta_I_L,
    float* HVP,
    const int* lengths,
    int B, int n, float temperature
) {
    // Empty problem: nothing to launch.
    if (B <= 0 || n <= 0) return;

    const int threads = 256;
    const size_t elems = (size_t)B * n * n;
    const int blocks_init = (int)((elems + threads - 1) / threads);

    hvp_init_kernel<<<blocks_init, threads>>>(
        V, d_C_R, d_C_L, d_I_R, d_I_L,
        beta_C_R, beta_C_L, beta_I_R, beta_I_L,
        d_beta_C_R, d_beta_C_L, d_beta_I_R, d_beta_I_L,
        HVP, lengths, B, n
    );

    // Tangent forward sweep, narrowest spans first.
    for (int len = 1; len < n; len++) {
        hvp_forward_incomplete_kernel<<<B, threads>>>(
            arc_scores, V, C_R, C_L, d_C_R, d_C_L,
            d_I_R, d_I_L, lengths, B, n, temperature, len
        );
        // d_C_R / d_C_L are updated in place: input and output pointers
        // are the same buffers.
        hvp_forward_complete_kernel<<<B, threads>>>(
            C_R, C_L, I_R, I_L, d_C_R, d_C_L, d_I_R, d_I_L,
            d_C_R, d_C_L, lengths, B, n, temperature, len
        );
    }

    // Backward sweep, widest spans first.
    for (int len = n - 1; len >= 1; len--) {
        hvp_backward_complete_kernel<<<B, threads>>>(
            C_R, C_L, I_R, I_L, d_C_R, d_C_L, d_I_R, d_I_L,
            beta_C_R, beta_C_L, beta_I_R, beta_I_L,
            d_beta_C_R, d_beta_C_L, d_beta_I_R, d_beta_I_L,
            lengths, B, n, temperature, len
        );
        hvp_backward_incomplete_kernel<<<B, threads>>>(
            V, C_R, C_L, d_C_R, d_C_L,
            beta_C_R, beta_C_L, beta_I_R, beta_I_L,
            d_beta_C_R, d_beta_C_L, d_beta_I_R, d_beta_I_L,
            HVP, lengths, B, n, temperature, len
        );
    }

    cudaDeviceSynchronize();
}

} // namespace eisner
} // namespace d2p
