/**
 * @file kernels.cu
 * @brief Soft CKY CUDA Kernel Implementations
 *
 * Differentiable CKY parsing with span-based parallelization.
 * Uses a temperature-scaled logsumexp as a soft maximum over split points
 * (as in SW/NW; unlike DTW, which takes a soft minimum).
 */

#include "kernels.cuh"
#include "common/numerics.cuh"
#include "common/reduce.cuh"

using namespace d2p::common;

// ============================================================================
// FORWARD PASS (Inside Algorithm)
// ============================================================================

// Seed the inside chart Z before the span-length sweep.
//
// Launch: 1-D grid covering at least B*n*n threads; each thread owns one flat
// element of Z ([B, n, n]). Length-1 spans (row == col) receive the matching
// leaf score; every other cell starts at NINF (the negative-infinity sentinel
// from common/numerics.cuh). Cells with row < col are later overwritten by
// cky_forward_diag_kernel.
__global__ void cky_init_kernel(
    const float* __restrict__ leaf_scores,  // [B, n]
    float* __restrict__ Z,                  // [B, n, n]
    int B, int n
) {
    const size_t cells_per_batch = (size_t)n * n;
    const size_t flat = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (flat < (size_t)B * cells_per_batch) {
        const size_t batch = flat / cells_per_batch;
        const size_t cell = flat - batch * cells_per_batch;
        const int row = cell / n;
        const int col = cell % n;
        Z[flat] = (row == col) ? leaf_scores[batch * n + row] : NINF;
    }
}

// Inside-pass (forward DP) update for all spans of one length.
//
// Launch: gridDim.x == B (one block per batch element). The block's threads
// stride over the n - span_len + 1 spans of length span_len. Every span
// shorter than span_len must already be final, which is why the host sweeps
// span_len = 2, 3, ..., n with one launch per length.
//
// For span (i, j) and split point k in [i, j), the split score is
//     s_k = Z[i, k] + Z[k+1, j] + merge_scores[i, k, j]
// and the span value is the temperature-T soft maximum
//     Z[i, j] = m + T * log(sum_k exp((s_k - m) / T)),   m = max_k s_k,
// which equals T * logsumexp(s / T). If no split score exceeds NINF, the span
// is set to NINF.
__global__ void cky_forward_diag_kernel(
    const float* __restrict__ merge_scores,  // [B, n, n, n]
    float* __restrict__ Z,                   // [B, n, n]
    int B, int n, float T,
    int span_len
) {
    const int batch = blockIdx.x;
    float* chart = Z + batch * ((size_t)n * n);
    const float* mrg = merge_scores + batch * ((size_t)n * n * n);

    const int span_count = n - span_len + 1;
    if (span_count <= 0) return;

    for (int lo = threadIdx.x; lo < span_count; lo += blockDim.x) {
        const int hi = lo + span_len - 1;

        // Score of splitting [lo, hi] into [lo, k] and [k+1, hi].
        auto split_score = [&](int k) {
            return chart[lo * n + k] + chart[(k + 1) * n + hi]
                 + mrg[lo * n * n + k * n + hi];
        };

        float peak = NINF;
        for (int k = lo; k < hi; ++k) {
            peak = fmaxf(peak, split_score(k));
        }

        float* cell = &chart[lo * n + hi];
        if (peak <= NINF) {
            *cell = NINF;
        } else {
            // Shift by the peak before exponentiating for numerical stability.
            float total = 0.0f;
            for (int k = lo; k < hi; ++k) {
                total += safe_exp((split_score(k) - peak) / T);
            }
            *cell = peak + T * logf(total);
        }
    }
}

// ============================================================================
// BACKWARD PASS (Outside Algorithm + Gradients)
// ============================================================================

// Reset all backward-pass buffers and seed the outside chart.
//
// Launch: 1-D grid covering at least B*n*n*n threads (the largest buffer).
// Flat index tid is handled in every buffer large enough to contain it:
//   beta                          [B, n, n]     1 at root span (0, n-1), else 0
//   Pcond, Pjoint, grad_merge     [B, n, n, n]  0
//   grad_leaf                     [B, n]        0
//   grad_T                        [B]           0
__global__ void cky_init_beta_kernel(
    float* __restrict__ beta,
    float* __restrict__ Pcond,
    float* __restrict__ Pjoint,
    float* __restrict__ grad_merge,
    float* __restrict__ grad_leaf,
    float* __restrict__ grad_T,
    int B, int n
) {
    const size_t tid = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    const size_t chart_size = (size_t)n * n;
    const size_t cube_size = chart_size * n;

    if (tid < (size_t)B * cube_size) {
        Pcond[tid] = 0.0f;
        Pjoint[tid] = 0.0f;
        grad_merge[tid] = 0.0f;
    }

    if (tid < (size_t)B * chart_size) {
        const size_t batch = tid / chart_size;
        const size_t cell = tid - batch * chart_size;
        const int row = cell / n;
        const int col = cell % n;
        const bool is_root = (row == 0) && (col == n - 1);
        beta[tid] = is_root ? 1.0f : 0.0f;
    }

    if (tid < (size_t)B * n) grad_leaf[tid] = 0.0f;
    if (tid < (size_t)B) grad_T[tid] = 0.0f;
}

// Split-point posteriors for all spans of one length.
//
// Launch: gridDim.x == B; threads stride over the spans of length span_len.
// Z must hold the finished inside chart. This kernel only reads Z and writes
// Pcond, so the order in which span lengths are processed does not matter.
//
// For span (i, j) it writes
//     Pcond[i, k, j] = exp((s_k - m) / T) / sum_k' exp((s_k' - m) / T)
// where s_k = Z[i,k] + Z[k+1,j] + merge_scores[i,k,j] and m = max_k s_k.
// This is the softmax(s / T) probability of split k given span (i, j).
// Spans whose best split score is <= NINF are left untouched. A normalizer at
// or below 1e-20 yields all-zero weights.
__global__ void cky_compute_pcond_kernel(
    const float* __restrict__ Z,
    const float* __restrict__ merge_scores,
    float* __restrict__ Pcond,
    int B, int n, float T,
    int span_len
) {
    const int batch = blockIdx.x;
    const size_t cube = (size_t)n * n * n;
    const float* chart = Z + batch * ((size_t)n * n);
    const float* mrg = merge_scores + batch * cube;
    float* post = Pcond + batch * cube;

    const int span_count = n - span_len + 1;
    if (span_count <= 0) return;

    for (int lo = threadIdx.x; lo < span_count; lo += blockDim.x) {
        const int hi = lo + span_len - 1;

        auto split_score = [&](int k) {
            return chart[lo * n + k] + chart[(k + 1) * n + hi]
                 + mrg[lo * n * n + k * n + hi];
        };

        float peak = NINF;
        for (int k = lo; k < hi; ++k) peak = fmaxf(peak, split_score(k));
        if (peak <= NINF) continue;  // no valid split: nothing to normalize

        float total = 0.0f;
        for (int k = lo; k < hi; ++k) total += safe_exp((split_score(k) - peak) / T);
        const float norm = (total > 1e-20f) ? 1.0f / total : 0.0f;

        for (int k = lo; k < hi; ++k) {
            post[lo * n * n + k * n + hi] = safe_exp((split_score(k) - peak) / T) * norm;
        }
    }
}

// Outside pass and gradients for all spans of one length.
//
// Launch: gridDim.x == B; threads stride over the spans of length span_len.
// The host calls this for span_len = n, n-1, ..., 2 after Pcond is filled.
// beta for the current length is therefore final before it is pushed down;
// every child span written here is strictly shorter than span_len. Children
// are shared by several parents, hence the atomicAdd.
//
// For span (i, j) with outside weight beta_ij and split k:
//     mass_k = beta_ij * Pcond[i, k, j]
//     Pjoint[i, k, j] = grad_merge[i, k, j] = mass_k
//     beta[i, k] += mass_k,   beta[k+1, j] += mass_k
// The span's local temperature derivative, holding the children fixed, is
//     dZ_ij/dT = (Z_ij - sum_k Pcond_k * s_k) / T,
//     s_k = Z[i,k] + Z[k+1,j] + merge_scores[i,k,j].
// It is weighted by beta_ij, reduced across the block and added to grad_T[b].
//
// Splits with Pcond == 0 are excluded from the expectation. This keeps a
// split score of -inf (e.g. a masked merge score) from producing
// 0 * -inf = NaN; it assumes safe_exp maps a -inf argument to 0. Spans whose
// Z is <= NINF were pinned to the constant NINF by the forward pass, so they
// contribute no temperature gradient.
//
// block_reduce_sum (common/reduce.cuh) is called unconditionally by every
// thread. It is presumably a block-wide collective, so keep it out of
// divergent code. The early return above it depends only on n and span_len
// and is therefore block-uniform.
__global__ void cky_backward_diag_kernel(
    const float* __restrict__ Z,
    const float* __restrict__ merge_scores,
    const float* __restrict__ Pcond,
    float* __restrict__ beta,
    float* __restrict__ Pjoint,
    float* __restrict__ grad_merge,
    float* __restrict__ grad_T,
    int B, int n, float T,
    int span_len
) {
    int b = blockIdx.x;
    size_t Z_stride = (size_t)n * n;
    size_t merge_stride = (size_t)n * n * n;

    const float* z = Z + b * Z_stride;
    const float* merge = merge_scores + b * merge_stride;
    const float* pc = Pcond + b * merge_stride;
    float* be = beta + b * Z_stride;
    float* pj = Pjoint + b * merge_stride;
    float* gm = grad_merge + b * merge_stride;

    int num_spans = n - span_len + 1;
    if (num_spans <= 0) return;  // block-uniform

    float local_grad_T = 0.0f;

    for (int t = threadIdx.x; t < num_spans; t += blockDim.x) {
        int i = t;
        int j = i + span_len - 1;

        float beta_ij = be[i * n + j];
        if (beta_ij == 0.0f) continue;  // no outside mass to propagate

        float E_term = 0.0f;  // sum_k Pcond_k * s_k
        for (int k = i; k < j; k++) {
            int m_idx = i * n * n + k * n + j;
            float p = pc[m_idx];
            float mass = beta_ij * p;

            pj[m_idx] = mass;
            gm[m_idx] = mass;

            atomicAdd(&be[i * n + k], mass);
            atomicAdd(&be[(k + 1) * n + j], mass);

            // Zero-probability splits contribute nothing; skipping them avoids
            // 0 * -inf = NaN when the split score is -inf.
            if (p > 0.0f) {
                E_term += p * (z[i * n + k] + z[(k + 1) * n + j] + merge[m_idx]);
            }
        }

        float Zij = z[i * n + j];
        if (Zij > NINF) {
            local_grad_T += beta_ij * (Zij - E_term) / T;
        }
    }

    float block_grad_T = block_reduce_sum(local_grad_T);
    if (threadIdx.x == 0) {
        atomicAdd(&grad_T[b], block_grad_T);
    }
}

// Copy the diagonal of the outside chart into grad_leaf.
//
// One thread per (b, i) in [B, n]: grad_leaf[b, i] = beta[b, i, i].
// Index arithmetic is done in size_t, matching the [B, n, n] strides used by
// the other kernels, so b * n * n cannot overflow int for large batches.
__global__ void cky_finalize_grad_leaf_kernel(
    const float* __restrict__ beta,
    float* __restrict__ grad_leaf,
    int B, int n
) {
    size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= (size_t)B * n) return;

    size_t b = idx / n;
    size_t i = idx % n;

    grad_leaf[idx] = beta[b * n * n + i * n + i];
}

// ============================================================================
// HVP (Hessian-Vector Product)
// ============================================================================

// Initialize the buffers of the Hessian-vector-product sweep.
//
// Launch: 1-D grid covering at least B*n*n*n threads. Flat index tid is
// handled in every buffer large enough to contain it:
//   d_Z        [B, n, n]     V_leaf[b, i] on the diagonal, 0 elsewhere
//   beta       [B, n, n]     1 at the root span (0, n-1), 0 elsewhere
//   d_beta     [B, n, n]     0
//   HVP_merge  [B, n, n, n]  0
//   HVP_leaf   [B, n]        0
__global__ void cky_hvp_init_kernel(
    const float* __restrict__ V_leaf,
    float* __restrict__ d_Z,
    float* __restrict__ beta,
    float* __restrict__ d_beta,
    float* __restrict__ HVP_merge,
    float* __restrict__ HVP_leaf,
    int B, int n
) {
    const size_t tid = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    const size_t chart_size = (size_t)n * n;

    if (tid < (size_t)B * n) HVP_leaf[tid] = 0.0f;
    if (tid < (size_t)B * (chart_size * n)) HVP_merge[tid] = 0.0f;

    if (tid >= (size_t)B * chart_size) return;

    const size_t batch = tid / chart_size;
    const size_t cell = tid - batch * chart_size;
    const int row = cell / n;
    const int col = cell % n;

    d_Z[tid] = (row == col) ? V_leaf[batch * n + row] : 0.0f;
    beta[tid] = (row == 0 && col == n - 1) ? 1.0f : 0.0f;
    d_beta[tid] = 0.0f;
}

// Forward-mode (tangent) sweep of the inside pass for one span length.
//
// Launch: gridDim.x == B; threads stride over the spans of length span_len.
// The host sweeps span_len = 2..n, so the tangents of all shorter spans are
// final. With split weights
//     w_k = softmax_k(s_k / T),   s_k = Z[i,k] + Z[k+1,j] + merge_scores[i,k,j],
// the directional derivative of Z[i, j] along (V_merge, V_leaf) is
//     d_Z[i, j] = sum_k w_k * (d_Z[i,k] + d_Z[k+1,j] + V_merge[i,k,j]).
// Spans whose best split score is <= NINF get a zero tangent.
__global__ void cky_hvp_forward_diag_kernel(
    const float* __restrict__ Z,
    const float* __restrict__ merge_scores,
    const float* __restrict__ V_merge,
    float* __restrict__ d_Z,
    int B, int n, float T,
    int span_len
) {
    const int batch = blockIdx.x;
    const size_t square = (size_t)n * n;
    const size_t cube = (size_t)n * n * n;

    const float* chart = Z + batch * square;
    const float* mrg = merge_scores + batch * cube;
    const float* dir = V_merge + batch * cube;
    float* tangent = d_Z + batch * square;

    const int span_count = n - span_len + 1;
    if (span_count <= 0) return;

    for (int lo = threadIdx.x; lo < span_count; lo += blockDim.x) {
        const int hi = lo + span_len - 1;

        auto split_score = [&](int k) {
            return chart[lo * n + k] + chart[(k + 1) * n + hi]
                 + mrg[lo * n * n + k * n + hi];
        };

        float peak = NINF;
        for (int k = lo; k < hi; ++k) peak = fmaxf(peak, split_score(k));

        if (peak <= NINF) {
            tangent[lo * n + hi] = 0.0f;
            continue;
        }

        float total = 0.0f;
        for (int k = lo; k < hi; ++k) total += safe_exp((split_score(k) - peak) / T);
        const float norm = (total > 1e-20f) ? 1.0f / total : 0.0f;

        float acc = 0.0f;
        for (int k = lo; k < hi; ++k) {
            const float weight = safe_exp((split_score(k) - peak) / T) * norm;
            acc += weight * (tangent[lo * n + k] + tangent[(k + 1) * n + hi]
                             + dir[lo * n * n + k * n + hi]);
        }
        tangent[lo * n + hi] = acc;
    }
}

// Reverse sweep of the HVP for one span length (outside pass plus its tangent).
//
// Launch: gridDim.x == B; threads stride over the spans of length span_len.
// The host sweeps span_len = n..2 after cky_hvp_forward_diag_kernel. At that
// point beta and d_beta of the current length are final and d_Z holds every
// tangent.
//
// For span (i, j) and split k, with w_k = softmax_k(s_k / T) and tangent
// t_k = d_Z[i,k] + d_Z[k+1,j] + V_merge[i,k,j]:
//     dw_k   = w_k * (t_k - sum_k' w_k' t_k') / T
//     mass   = beta_ij * w_k
//     d_mass = d_beta_ij * w_k + beta_ij * dw_k
//     HVP_merge[i,k,j] = d_mass
//     beta[i,k], beta[k+1,j]     += mass     (atomic: children are shared)
//     d_beta[i,k], d_beta[k+1,j] += d_mass
// Spans whose best split score is <= NINF are skipped.
__global__ void cky_hvp_backward_diag_kernel(
    const float* __restrict__ Z,
    const float* __restrict__ merge_scores,
    const float* __restrict__ V_merge,
    const float* __restrict__ d_Z,
    float* __restrict__ beta,
    float* __restrict__ d_beta,
    float* __restrict__ HVP_merge,
    int B, int n, float T,
    int span_len
) {
    const int batch = blockIdx.x;
    const size_t square = (size_t)n * n;
    const size_t cube = (size_t)n * n * n;

    const float* chart = Z + batch * square;
    const float* mrg = merge_scores + batch * cube;
    const float* dir = V_merge + batch * cube;
    const float* tangent = d_Z + batch * square;
    float* outside = beta + batch * square;
    float* d_outside = d_beta + batch * square;
    float* hvp = HVP_merge + batch * cube;

    const int span_count = n - span_len + 1;
    if (span_count <= 0) return;

    for (int lo = threadIdx.x; lo < span_count; lo += blockDim.x) {
        const int hi = lo + span_len - 1;
        const float out_w = outside[lo * n + hi];
        const float d_out_w = d_outside[lo * n + hi];

        auto split_score = [&](int k) {
            return chart[lo * n + k] + chart[(k + 1) * n + hi]
                 + mrg[lo * n * n + k * n + hi];
        };
        auto split_tangent = [&](int k) {
            return tangent[lo * n + k] + tangent[(k + 1) * n + hi]
                 + dir[lo * n * n + k * n + hi];
        };

        float peak = NINF;
        for (int k = lo; k < hi; ++k) peak = fmaxf(peak, split_score(k));
        if (peak <= NINF) continue;

        float total = 0.0f;
        for (int k = lo; k < hi; ++k) total += safe_exp((split_score(k) - peak) / T);
        const float norm = (total > 1e-20f) ? 1.0f / total : 0.0f;

        auto weight = [&](int k) {
            return safe_exp((split_score(k) - peak) / T) * norm;
        };

        // Expected tangent under the split distribution.
        float mean_tangent = 0.0f;
        for (int k = lo; k < hi; ++k) mean_tangent += weight(k) * split_tangent(k);

        for (int k = lo; k < hi; ++k) {
            const float w = weight(k);
            const float d_w = w * (split_tangent(k) - mean_tangent) / T;
            const float mass = out_w * w;
            const float d_mass = d_out_w * w + out_w * d_w;

            hvp[lo * n * n + k * n + hi] = d_mass;

            atomicAdd(&outside[lo * n + k], mass);
            atomicAdd(&outside[(k + 1) * n + hi], mass);
            atomicAdd(&d_outside[lo * n + k], d_mass);
            atomicAdd(&d_outside[(k + 1) * n + hi], d_mass);
        }
    }
}

// Copy the diagonal of the d_beta chart into HVP_leaf.
//
// One thread per (b, i) in [B, n]: HVP_leaf[b, i] = d_beta[b, i, i].
// Index arithmetic is done in size_t, matching the [B, n, n] strides used by
// the other kernels, so b * n * n cannot overflow int for large batches.
__global__ void cky_hvp_finalize_leaf_kernel(
    const float* __restrict__ d_beta,
    float* __restrict__ HVP_leaf,
    int B, int n
) {
    size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= (size_t)B * n) return;

    size_t b = idx / n;
    size_t i = idx % n;

    HVP_leaf[idx] = d_beta[b * n * n + i * n + i];
}

// ============================================================================
// PARAM GRAD (dP/dT)
// ============================================================================

// Initialize the buffers of the dP/dT (temperature sensitivity) sweep.
//
// Launch: 1-D grid covering at least B*n*n*n threads. Flat index tid is
// handled in every buffer large enough to contain it:
//   U, W         [B, n, n]     0
//   beta         [B, n, n]     1 at the root span (0, n-1), 0 elsewhere
//   dP_dT_merge  [B, n, n, n]  0
//   dP_dT_leaf   [B, n]        0
__global__ void cky_param_grad_init_kernel(
    float* __restrict__ U,
    float* __restrict__ beta,
    float* __restrict__ W,
    float* __restrict__ dP_dT_merge,
    float* __restrict__ dP_dT_leaf,
    int B, int n
) {
    const size_t tid = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    const size_t chart_size = (size_t)n * n;

    if (tid < (size_t)B * (chart_size * n)) dP_dT_merge[tid] = 0.0f;
    if (tid < (size_t)B * n) dP_dT_leaf[tid] = 0.0f;

    if (tid < (size_t)B * chart_size) {
        const size_t cell = tid % chart_size;
        const int row = cell / n;
        const int col = cell % n;
        U[tid] = 0.0f;
        W[tid] = 0.0f;
        beta[tid] = (row == 0 && col == n - 1) ? 1.0f : 0.0f;
    }
}

// Forward sweep of dZ/dT for one span length.
//
// Launch: gridDim.x == B; threads stride over the spans of length span_len.
// The host sweeps span_len = 2..n, so U is final for all shorter spans. With
// s_k = Z[i,k] + Z[k+1,j] + merge_scores[i,k,j] and w_k = softmax_k(s_k / T):
//     U[i, j] = (Z[i,j] - sum_k w_k s_k) / T + sum_k w_k (U[i,k] + U[k+1,j])
// The first term is the explicit T-derivative of the span's soft maximum; the
// second propagates the children's derivatives through the split weights.
// Leaves (U[i,i]) keep the 0 written by cky_param_grad_init_kernel. Spans
// whose best split score is <= NINF get U = 0.
//
// Splits with w_k == 0 are skipped because they contribute nothing. Skipping
// them also keeps a -inf split score (e.g. a masked merge score) from turning
// the expectation into 0 * -inf = NaN.
__global__ void cky_param_grad_forward_diag_kernel(
    const float* __restrict__ Z,
    const float* __restrict__ merge_scores,
    float* __restrict__ U,
    int B, int n, float T,
    int span_len
) {
    int b = blockIdx.x;
    size_t Z_stride = (size_t)n * n;
    size_t merge_stride = (size_t)n * n * n;

    const float* z = Z + b * Z_stride;
    const float* merge = merge_scores + b * merge_stride;
    float* u = U + b * Z_stride;

    int num_spans = n - span_len + 1;
    if (num_spans <= 0) return;

    for (int t = threadIdx.x; t < num_spans; t += blockDim.x) {
        int i = t;
        int j = i + span_len - 1;

        float max_v = NINF;
        for (int k = i; k < j; k++) {
            float v = z[i * n + k] + z[(k + 1) * n + j] + merge[i * n * n + k * n + j];
            max_v = fmaxf(max_v, v);
        }

        if (max_v <= NINF) {
            u[i * n + j] = 0.0f;
            continue;
        }

        float sum_exp = 0.0f;
        for (int k = i; k < j; k++) {
            float v = z[i * n + k] + z[(k + 1) * n + j] + merge[i * n * n + k * n + j];
            sum_exp += safe_exp((v - max_v) / T);
        }
        float inv_sum = (sum_exp > 1e-20f) ? 1.0f / sum_exp : 0.0f;

        float Zij = z[i * n + j];
        float E_term = 0.0f;  // sum_k w_k * s_k
        float E_U = 0.0f;     // sum_k w_k * (U[i,k] + U[k+1,j])

        for (int k = i; k < j; k++) {
            float v = z[i * n + k] + z[(k + 1) * n + j] + merge[i * n * n + k * n + j];
            float w = safe_exp((v - max_v) / T) * inv_sum;
            if (w == 0.0f) continue;  // avoid 0 * -inf = NaN for impossible splits

            float u_left = u[i * n + k];
            float u_right = u[(k + 1) * n + j];

            E_term += w * v;
            E_U += w * (u_left + u_right);
        }

        u[i * n + j] = (Zij - E_term) / T + E_U;
    }
}

// Reverse sweep of d(beta)/dT and dP/dT for one span length.
//
// Launch: gridDim.x == B; threads stride over the spans of length span_len.
// The host sweeps span_len = n..2 after the forward dZ/dT sweep. At that
// point beta and W (= d beta / dT) are final for the current length and U
// holds every dZ/dT.
//
// For span (i, j) and split k, with s_k, w_k = softmax_k(s_k / T) and
// t_k = U[i,k] + U[k+1,j]:
//     dw_k/dT             = w_k * ( -(s_k - E[s]) / T^2 + (t_k - E[t]) / T )
//     dP_dT_merge[i,k,j]  = W_ij * w_k + beta_ij * dw_k/dT
//     beta[i,k], beta[k+1,j] += beta_ij * w_k          (atomic: shared children)
//     W[i,k],    W[k+1,j]    += dP_dT_merge[i,k,j]
//
// Splits with w_k == 0 are excluded from the expectations and get dP/dT = 0.
// This keeps a -inf split score (e.g. a masked merge score) from producing
// 0 * inf = NaN through (s_k - E[s]). Their beta/W contributions would be
// exactly 0, so the atomics are skipped for them.
__global__ void cky_param_grad_backward_diag_kernel(
    const float* __restrict__ Z,
    const float* __restrict__ merge_scores,
    const float* __restrict__ U,
    float* __restrict__ beta,
    float* __restrict__ W,
    float* __restrict__ dP_dT_merge,
    int B, int n, float T,
    int span_len
) {
    int b = blockIdx.x;
    size_t Z_stride = (size_t)n * n;
    size_t merge_stride = (size_t)n * n * n;

    const float* z = Z + b * Z_stride;
    const float* merge = merge_scores + b * merge_stride;
    const float* u = U + b * Z_stride;
    float* be = beta + b * Z_stride;
    float* w_buf = W + b * Z_stride;
    float* dpm = dP_dT_merge + b * merge_stride;

    int num_spans = n - span_len + 1;
    if (num_spans <= 0) return;

    for (int t = threadIdx.x; t < num_spans; t += blockDim.x) {
        int i = t;
        int j = i + span_len - 1;

        float beta_ij = be[i * n + j];
        float w_ij = w_buf[i * n + j];

        float max_v = NINF;
        for (int k = i; k < j; k++) {
            float v = z[i * n + k] + z[(k + 1) * n + j] + merge[i * n * n + k * n + j];
            max_v = fmaxf(max_v, v);
        }

        if (max_v <= NINF) continue;

        float sum_exp = 0.0f;
        for (int k = i; k < j; k++) {
            float v = z[i * n + k] + z[(k + 1) * n + j] + merge[i * n * n + k * n + j];
            sum_exp += safe_exp((v - max_v) / T);
        }
        float inv_sum = (sum_exp > 1e-20f) ? 1.0f / sum_exp : 0.0f;

        // Expectations of s_k and t_k under the split distribution.
        float E_term = 0.0f;
        float E_d_term = 0.0f;
        for (int k = i; k < j; k++) {
            float v = z[i * n + k] + z[(k + 1) * n + j] + merge[i * n * n + k * n + j];
            float w = safe_exp((v - max_v) / T) * inv_sum;
            if (w == 0.0f) continue;

            E_term += w * v;
            E_d_term += w * (u[i * n + k] + u[(k + 1) * n + j]);
        }

        for (int k = i; k < j; k++) {
            int m_idx = i * n * n + k * n + j;
            float v = z[i * n + k] + z[(k + 1) * n + j] + merge[m_idx];
            float w = safe_exp((v - max_v) / T) * inv_sum;
            if (w == 0.0f) {
                dpm[m_idx] = 0.0f;
                continue;
            }

            float d_term = u[i * n + k] + u[(k + 1) * n + j];
            float diff = v - E_term;
            float dw_dT = w * (-diff / (T * T) + (d_term - E_d_term) / T);

            float mass = beta_ij * w;
            float d_mass = w_ij * w + beta_ij * dw_dT;  // d(beta_ij * w_k)/dT

            dpm[m_idx] = d_mass;

            atomicAdd(&be[i * n + k], mass);
            atomicAdd(&be[(k + 1) * n + j], mass);
            atomicAdd(&w_buf[i * n + k], d_mass);
            atomicAdd(&w_buf[(k + 1) * n + j], d_mass);
        }
    }
}

// Copy the diagonal of W (d beta / dT) into dP_dT_leaf.
//
// One thread per (b, i) in [B, n]: dP_dT_leaf[b, i] = W[b, i, i].
// Index arithmetic is done in size_t, matching the [B, n, n] strides used by
// the other kernels, so b * n * n cannot overflow int for large batches.
__global__ void cky_param_grad_finalize_leaf_kernel(
    const float* __restrict__ W,
    float* __restrict__ dP_dT_leaf,
    int B, int n
) {
    size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= (size_t)B * n) return;

    size_t b = idx / n;
    size_t i = idx % n;

    dP_dT_leaf[idx] = W[b * n * n + i * n + i];
}

// ============================================================================
// THERMODYNAMICS
// ============================================================================

// Initialize the moment charts for the thermodynamics sweep.
//
// One thread per element of [B, n, n]. On the diagonal (length-1 spans), the
// first and second score moments are the leaf score and its square. All
// other cells start at 0.
__global__ void cky_thermo_init_kernel(
    const float* __restrict__ leaf_scores,
    float* __restrict__ M1,
    float* __restrict__ M2,
    int B, int n
) {
    const size_t cells_per_batch = (size_t)n * n;
    const size_t flat = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (flat >= (size_t)B * cells_per_batch) return;

    const size_t batch = flat / cells_per_batch;
    const size_t cell = flat - batch * cells_per_batch;
    const int row = cell / n;
    const int col = cell % n;

    // Off-diagonal cells use 0, and 0 * 0 == 0, so one formula covers both.
    const float value = (row == col) ? leaf_scores[batch * n + row] : 0.0f;
    M1[flat] = value;
    M2[flat] = value * value;
}

// First and second moments of the tree score for all spans of one length.
//
// Launch: gridDim.x == B; threads stride over the spans of length span_len.
// The host sweeps span_len = 2..n. The split distribution is
// w_k = softmax_k(s_k / T), with s_k = Z[i,k] + Z[k+1,j] + ms_k and
// ms_k = merge_scores[i,k,j]. Under it:
//     M1[i,j] = sum_k w_k (M1[i,k] + M1[k+1,j] + ms_k)
//     M2[i,j] = sum_k w_k (M2[i,k] + M2[k+1,j] + ms_k^2
//                          + 2 ms_k (M1[i,k] + M1[k+1,j])
//                          + 2 M1[i,k] M1[k+1,j])
// The cross term uses M1[i,k] * M1[k+1,j], i.e. it treats the two child
// subtrees as independent given the split. Spans whose best split score is
// <= NINF get zero moments.
//
// Splits with w_k == 0 are skipped. A masked merge score (-inf, or a huge
// negative sentinel whose square overflows to inf) would otherwise turn
// 0 * inf into NaN and poison every enclosing span.
__global__ void cky_thermo_diag_kernel(
    const float* __restrict__ Z,
    const float* __restrict__ merge_scores,
    float* __restrict__ M1,
    float* __restrict__ M2,
    int B, int n, float T,
    int span_len
) {
    int b = blockIdx.x;
    size_t Z_stride = (size_t)n * n;
    size_t merge_stride = (size_t)n * n * n;

    const float* z = Z + b * Z_stride;
    const float* merge = merge_scores + b * merge_stride;
    float* m1 = M1 + b * Z_stride;
    float* m2 = M2 + b * Z_stride;

    int num_spans = n - span_len + 1;
    if (num_spans <= 0) return;

    for (int t = threadIdx.x; t < num_spans; t += blockDim.x) {
        int i = t;
        int j = i + span_len - 1;

        float max_v = NINF;
        for (int k = i; k < j; k++) {
            float v = z[i * n + k] + z[(k + 1) * n + j] + merge[i * n * n + k * n + j];
            max_v = fmaxf(max_v, v);
        }

        if (max_v <= NINF) {
            m1[i * n + j] = 0.0f;
            m2[i * n + j] = 0.0f;
            continue;
        }

        float sum_exp = 0.0f;
        for (int k = i; k < j; k++) {
            float v = z[i * n + k] + z[(k + 1) * n + j] + merge[i * n * n + k * n + j];
            sum_exp += safe_exp((v - max_v) / T);
        }
        float inv_sum = (sum_exp > 1e-20f) ? 1.0f / sum_exp : 0.0f;

        float new_m1 = 0.0f;
        float new_m2 = 0.0f;

        for (int k = i; k < j; k++) {
            float ms = merge[i * n * n + k * n + j];
            float v = z[i * n + k] + z[(k + 1) * n + j] + ms;
            float w = safe_exp((v - max_v) / T) * inv_sum;
            if (w == 0.0f) continue;  // impossible split: avoid 0 * inf = NaN

            float left_m1 = m1[i * n + k];
            float right_m1 = m1[(k + 1) * n + j];
            float left_m2 = m2[i * n + k];
            float right_m2 = m2[(k + 1) * n + j];

            float score_k = left_m1 + right_m1 + ms;
            new_m1 += w * score_k;

            float score2_k = left_m2 + right_m2 + ms * ms
                           + 2.0f * ms * (left_m1 + right_m1)
                           + 2.0f * left_m1 * right_m1;
            new_m2 += w * score2_k;
        }

        m1[i * n + j] = new_m1;
        m2[i * n + j] = new_m2;
    }
}

// Per-batch thermodynamic summary read at the root span (0, n-1).
//
// One thread per batch element. `partition[b]` is presumably the root of the
// inside chart from cky_forward, i.e. T * logsumexp(score / T) — T times the
// log partition function, not the bare log (confirm with the caller).
// With E = M1 (mean score), E2 = M2 and Var = E2 - E^2, it writes:
//     F = -partition,  E,  E2,  S = (partition - E) / T,  C = Var / T^2
// Var is clamped at 0. It is a variance, and a negative value can only come
// from float cancellation in E2 - E*E, which would report a negative C. NaN
// is deliberately left to propagate.
// Nothing is written when n <= 0 (there is no root span to read).
__global__ void cky_thermo_finalize_kernel(
    const float* __restrict__ M1,
    const float* __restrict__ M2,
    const float* __restrict__ partition,
    float* __restrict__ F_out,
    float* __restrict__ E_out,
    float* __restrict__ E2_out,
    float* __restrict__ S_out,
    float* __restrict__ C_out,
    int B, int n, float T
) {
    int b = blockIdx.x * blockDim.x + threadIdx.x;
    if (b >= B || n <= 0) return;

    size_t root_idx = (size_t)b * n * n + (size_t)(n - 1);
    float E = M1[root_idx];
    float E2 = M2[root_idx];
    float logZ = partition[b];  // T * log(partition function), see header
    float Var = E2 - E * E;
    if (Var < 0.0f) Var = 0.0f;  // cancellation guard; NaN < 0 is false

    F_out[b] = -logZ;
    E_out[b] = E;
    E2_out[b] = E2;
    S_out[b] = (logZ - E) / T;
    C_out[b] = Var / (T * T);
}

// ============================================================================
// Host Wrappers
// ============================================================================

// Inside pass (soft CKY forward) on the default stream.
//
// Inputs (device pointers): d_merge_scores [B, n, n, n], d_leaf_scores [B, n].
// Outputs: d_Z [B, n, n] inside chart, and d_partition [B] = Z[b, 0, n-1].
// Kernels and the strided root copy run on the default stream; the call then
// blocks on cudaDeviceSynchronize(). CUDA errors are not checked here, so
// callers should query cudaGetLastError() afterwards.
//
// Returns immediately when B <= 0 or n <= 0. In that case there is no root
// span, a zero-block launch is an invalid configuration, and d_Z + (n - 1)
// would point before the buffer when n == 0.
void cky_forward(
    const float* d_merge_scores,
    const float* d_leaf_scores,
    float* d_Z,
    float* d_partition,
    int B, int n, float T
) {
    if (B <= 0 || n <= 0) return;

    int threads = 256;
    size_t Z_elems = (size_t)B * n * n;
    int blocks_init = (Z_elems + threads - 1) / threads;

    cky_init_kernel<<<blocks_init, threads>>>(d_leaf_scores, d_Z, B, n);

    // Each span length depends on all shorter ones: one launch per length.
    for (int len = 2; len <= n; len++) {
        cky_forward_diag_kernel<<<B, threads>>>(d_merge_scores, d_Z, B, n, T, len);
    }

    // Gather the root cell Z[b, 0, n-1] of every batch element (one float per
    // row, source rows n*n floats apart) into the contiguous d_partition.
    cudaMemcpy2D(
        d_partition,
        sizeof(float),
        d_Z + (n - 1),
        (size_t)n * n * sizeof(float),
        sizeof(float),
        B,
        cudaMemcpyDeviceToDevice
    );

    cudaDeviceSynchronize();
}

// Outside pass and gradients of the root score, on the default stream.
//
// Inputs (device pointers): d_Z (presumably the inside chart from cky_forward)
// and d_merge_scores [B, n, n, n]. d_leaf_scores and d_partition are accepted
// but not read here.
// Outputs: d_beta [B, n, n] outside weights; d_Pcond, d_Pjoint and
// d_grad_merge [B, n, n, n]; d_grad_leaf [B, n]; d_grad_T [B].
// Sequence:
//   1. Zero the buffers and seed beta at the root.
//   2. Compute split posteriors for every span length.
//   3. Run the outside sweep from the root length n down to 2.
//   4. Copy the leaf diagonal into d_grad_leaf.
// The call blocks on cudaDeviceSynchronize(); CUDA errors are not checked here.
//
// Returns immediately when B <= 0 or n <= 0: zero-block launches are an
// invalid configuration, and there is no root span.
void cky_backward(
    const float* d_Z,
    const float* d_merge_scores,
    const float* d_leaf_scores,
    const float* d_partition,
    float* d_beta,
    float* d_Pcond,
    float* d_Pjoint,
    float* d_grad_merge,
    float* d_grad_leaf,
    float* d_grad_T,
    int B, int n, float T
) {
    if (B <= 0 || n <= 0) return;

    int threads = 256;
    size_t merge_elems = (size_t)B * n * n * n;
    int blocks_init = (merge_elems + threads - 1) / threads;

    cky_init_beta_kernel<<<blocks_init, threads>>>(
        d_beta, d_Pcond, d_Pjoint, d_grad_merge, d_grad_leaf, d_grad_T, B, n
    );

    for (int len = 2; len <= n; len++) {
        cky_compute_pcond_kernel<<<B, threads>>>(d_Z, d_merge_scores, d_Pcond, B, n, T, len);
    }

    // Root first: each length's beta must be complete before it is pushed
    // down to shorter child spans.
    for (int len = n; len >= 2; len--) {
        cky_backward_diag_kernel<<<B, threads>>>(
            d_Z, d_merge_scores, d_Pcond, d_beta, d_Pjoint, d_grad_merge, d_grad_T,
            B, n, T, len
        );
    }

    int blocks_leaf = (B * n + threads - 1) / threads;
    cky_finalize_grad_leaf_kernel<<<blocks_leaf, threads>>>(d_beta, d_grad_leaf, B, n);

    cudaDeviceSynchronize();
}

// Hessian-vector product of the root score along (d_V_merge, d_V_leaf).
//
// Sequence, all on the default stream:
//   1. Initialize the buffers.
//   2. Run a forward tangent sweep (lengths 2..n) producing d_d_Z.
//   3. Copy the root tangents Z'[b, 0, n-1] into d_d_partition [B].
//   4. Run a reverse sweep (lengths n..2) producing d_beta, d_d_beta and
//      d_HVP_merge [B, n, n, n].
//   5. Copy the d_beta diagonal into d_HVP_leaf [B, n].
// d_leaf_scores and d_partition are accepted but not read here. The call
// blocks on cudaDeviceSynchronize(); CUDA errors are not checked here.
//
// Returns immediately when B <= 0 or n <= 0: zero-block launches are an
// invalid configuration, and d_d_Z + (n - 1) would point before the buffer
// when n == 0.
void cky_hvp(
    const float* d_Z,
    const float* d_merge_scores,
    const float* d_leaf_scores,
    const float* d_partition,
    const float* d_V_merge,
    const float* d_V_leaf,
    float* d_d_Z,
    float* d_d_partition,
    float* d_beta,
    float* d_d_beta,
    float* d_HVP_merge,
    float* d_HVP_leaf,
    int B, int n, float T
) {
    if (B <= 0 || n <= 0) return;

    int threads = 256;
    size_t merge_elems = (size_t)B * n * n * n;
    int blocks_init = (merge_elems + threads - 1) / threads;

    cky_hvp_init_kernel<<<blocks_init, threads>>>(
        d_V_leaf, d_d_Z, d_beta, d_d_beta, d_HVP_merge, d_HVP_leaf, B, n
    );

    for (int len = 2; len <= n; len++) {
        cky_hvp_forward_diag_kernel<<<B, threads>>>(
            d_Z, d_merge_scores, d_V_merge, d_d_Z, B, n, T, len
        );
    }

    // Gather the root tangent d_Z[b, 0, n-1] of every batch element.
    cudaMemcpy2D(
        d_d_partition,
        sizeof(float),
        d_d_Z + (n - 1),
        (size_t)n * n * sizeof(float),
        sizeof(float),
        B,
        cudaMemcpyDeviceToDevice
    );

    for (int len = n; len >= 2; len--) {
        cky_hvp_backward_diag_kernel<<<B, threads>>>(
            d_Z, d_merge_scores, d_V_merge, d_d_Z, d_beta, d_d_beta, d_HVP_merge,
            B, n, T, len
        );
    }

    int blocks_leaf = (B * n + threads - 1) / threads;
    cky_hvp_finalize_leaf_kernel<<<blocks_leaf, threads>>>(d_d_beta, d_HVP_leaf, B, n);

    cudaDeviceSynchronize();
}

// Sensitivity of the split/leaf marginals to the temperature T.
//
// Sequence, all on the default stream:
//   1. Initialize the buffers.
//   2. Run a forward sweep (lengths 2..n) of U = dZ/dT.
//   3. Run a reverse sweep (lengths n..2) of beta, W = d beta / dT and
//      d_dP_dT_merge [B, n, n, n].
//   4. Copy the W diagonal into d_dP_dT_leaf [B, n].
// d_leaf_scores and d_partition are accepted but not read here. The call
// blocks on cudaDeviceSynchronize(); CUDA errors are not checked here.
//
// Returns immediately when B <= 0 or n <= 0: zero-block launches are an
// invalid configuration, and there is no root span.
void cky_param_grad(
    const float* d_Z,
    const float* d_merge_scores,
    const float* d_leaf_scores,
    const float* d_partition,
    float* d_U,
    float* d_beta,
    float* d_W,
    float* d_dP_dT_merge,
    float* d_dP_dT_leaf,
    int B, int n, float T
) {
    if (B <= 0 || n <= 0) return;

    int threads = 256;
    size_t merge_elems = (size_t)B * n * n * n;
    int blocks_init = (merge_elems + threads - 1) / threads;

    cky_param_grad_init_kernel<<<blocks_init, threads>>>(
        d_U, d_beta, d_W, d_dP_dT_merge, d_dP_dT_leaf, B, n
    );

    for (int len = 2; len <= n; len++) {
        cky_param_grad_forward_diag_kernel<<<B, threads>>>(d_Z, d_merge_scores, d_U, B, n, T, len);
    }

    for (int len = n; len >= 2; len--) {
        cky_param_grad_backward_diag_kernel<<<B, threads>>>(
            d_Z, d_merge_scores, d_U, d_beta, d_W, d_dP_dT_merge, B, n, T, len
        );
    }

    int blocks_leaf = (B * n + threads - 1) / threads;
    cky_param_grad_finalize_leaf_kernel<<<blocks_leaf, threads>>>(d_W, d_dP_dT_leaf, B, n);

    cudaDeviceSynchronize();
}

// Thermodynamic summary (F, E, E2, S, C) of the tree distribution per batch.
//
// Sequence, all on the default stream:
//   1. Seed the score-moment charts M1/M2 from the leaves.
//   2. Propagate them over span lengths 2..n.
//   3. Read the root (0, n-1) together with d_partition [B] to fill
//      d_F, d_E, d_E2, d_S and d_C [B].
// d_beta and d_Pjoint are accepted but not read here. The call blocks on
// cudaDeviceSynchronize(); CUDA errors are not checked here.
//
// Returns immediately when B <= 0 or n <= 0: zero-block launches are an
// invalid configuration, and the root index n - 1 would be negative.
void cky_thermodynamics(
    const float* d_Z,
    const float* d_merge_scores,
    const float* d_leaf_scores,
    const float* d_partition,
    const float* d_beta,
    const float* d_Pjoint,
    float* d_M1,
    float* d_M2,
    float* d_F,
    float* d_E,
    float* d_E2,
    float* d_S,
    float* d_C,
    int B, int n, float T
) {
    if (B <= 0 || n <= 0) return;

    int threads = 256;
    size_t Z_elems = (size_t)B * n * n;
    int blocks_init = (Z_elems + threads - 1) / threads;

    cky_thermo_init_kernel<<<blocks_init, threads>>>(d_leaf_scores, d_M1, d_M2, B, n);

    for (int len = 2; len <= n; len++) {
        cky_thermo_diag_kernel<<<B, threads>>>(d_Z, d_merge_scores, d_M1, d_M2, B, n, T, len);
    }

    int blocks_final = (B + threads - 1) / threads;
    cky_thermo_finalize_kernel<<<blocks_final, threads>>>(
        d_M1, d_M2, d_partition, d_F, d_E, d_E2, d_S, d_C, B, n, T
    );

    cudaDeviceSynchronize();
}
