/**
 * @file kernels.cu
 * @brief Soft Hamming Distance CUDA Kernel Implementations
 *
 * Hamming distance is the simplest edit distance - it counts the number of
 * positions where the two sequences differ (for equal-length sequences).
 *
 * The "soft" version computes:
 * - distance = sum_i costs[i] over the valid positions i < length
 *   (where costs = 0 for match, positive for mismatch)
 * - posteriors = 1 for valid positions and 0 for padded positions past the
 *   sequence length (the gradient of the sum w.r.t. each included input is 1)
 *
 * Since there's no dynamic programming involved, this is O(n) linear.
 * The gradient is trivial: dH/dcosts[i] = 1 for every valid i (0 for padding).
 * The Hessian is 0 (linear function), so HVP = 0.
 */

#include "kernels.cuh"
#include <cuda_runtime.h>

namespace d2p {
namespace hamming {

// =============================================================================
// Device Kernels
// =============================================================================

/**
 * Kernel to compute Hamming distance (sum of costs) for each batch element.
 *
 * Launch layout: one block per batch element (block b handles row b). Any
 * blockDim.x is accepted, including non-powers of two. Dynamic shared memory
 * must hold blockDim.x floats.
 *
 * @param costs     Row-major [B, L] per-position costs.
 * @param distance  Output [B]; distance[b] is the sum of the first
 *                  lengths[b] entries of row b (all L entries when lengths
 *                  is null).
 * @param lengths   Optional [B] valid lengths per row. Each value is clamped
 *                  to [0, L] so an out-of-range length can never read outside
 *                  row b.
 * @param B         Batch size (number of rows).
 * @param L         Row stride / maximum sequence length.
 */
__global__ void forward_kernel(
    const float* __restrict__ costs,    // [B, L]
    float* __restrict__ distance,       // [B]
    const int* __restrict__ lengths,    // [B] or nullptr
    int B,
    int L
) {
    int b = blockIdx.x;
    // b is uniform across the block, so this early return cannot strand
    // other threads at the __syncthreads() barriers below.
    if (b >= B) return;

    int actual_L = lengths ? lengths[b] : L;
    // Clamp so a length larger than L (or negative) cannot index past row b.
    actual_L = min(max(actual_L, 0), L);

    // Per-thread partial sums, one float per thread.
    extern __shared__ float sdata[];

    // 64-bit row offset: b * L can exceed INT_MAX for large batches.
    const float* row = costs + static_cast<size_t>(b) * L;

    float sum = 0.0f;
    for (int i = threadIdx.x; i < actual_L; i += blockDim.x) {
        sum += row[i];
    }

    sdata[threadIdx.x] = sum;
    __syncthreads();

    // Tree reduction valid for any block size. Each step folds the upper
    // floor(n/2) live entries onto the lower ones, leaving ceil(n/2) live
    // entries. For power-of-two n this is exactly the classic
    // "s = n/2, n/4, ..." reduction.
    unsigned int n = blockDim.x;
    while (n > 1) {
        unsigned int half = (n + 1) / 2;
        if (threadIdx.x < n / 2) {
            sdata[threadIdx.x] += sdata[threadIdx.x + half];
        }
        __syncthreads();
        n = half;
    }

    if (threadIdx.x == 0) {
        distance[b] = sdata[0];
    }
}

/**
 * Kernel to set posteriors to 1 (or 0 for out-of-bounds positions).
 *
 * Writes posteriors[b * L + i] = 1 when i < lengths[b] (or i < L when
 * lengths is null), else 0. It uses a grid-stride loop with 64-bit indices,
 * so it is correct for any 1-D grid and for B * L beyond INT_MAX.
 *
 * @param posteriors  Output, row-major [B, L].
 * @param lengths     Optional [B] valid lengths per row.
 * @param B           Batch size.
 * @param L           Row stride / maximum sequence length.
 */
__global__ void posteriors_kernel(
    float* __restrict__ posteriors,     // [B, L]
    const int* __restrict__ lengths,    // [B] or nullptr
    int B,
    int L
) {
    // Empty or malformed shapes: nothing to write (no barriers below, so an
    // early return is safe).
    if (B <= 0 || L <= 0) return;

    const size_t row_len = static_cast<size_t>(L);
    const size_t total = static_cast<size_t>(B) * row_len;
    const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;

    for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < total;
         idx += stride) {
        const size_t b = idx / row_len;
        const int i = static_cast<int>(idx % row_len);
        const int actual_L = lengths ? lengths[b] : L;

        // Posterior is 1 for valid positions, 0 for padding.
        posteriors[idx] = (i < actual_L) ? 1.0f : 0.0f;
    }
}

/**
 * Zero-fill kernel for the Hessian-vector product.
 *
 * The objective is a plain sum of costs, so its Hessian vanishes and every
 * HVP entry is 0. Expects a flat 1-D launch with one thread per output
 * element. Threads whose index is >= total write nothing.
 */
__global__ void hvp_kernel(
    float* __restrict__ hvp,            // [B, L]
    int total
) {
    const int tid = threadIdx.x + blockDim.x * blockIdx.x;
    if (tid < total) {
        hvp[tid] = 0.0f;
    }
}

/**
 * Writes the gradient of the distance w.r.t. temperature, which is
 * identically zero: distance = sum(costs) never reads the temperature.
 *
 * Expects a flat 1-D launch with one thread per batch element. Threads whose
 * index is >= B write nothing.
 */
__global__ void param_jacobian_kernel(
    float* __restrict__ grad_T,         // [B]
    int B
) {
    const int batch_idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (batch_idx < B) {
        grad_T[batch_idx] = 0.0f;
    }
}

// =============================================================================
// Host Functions
// =============================================================================

/**
 * Launch forward_kernel: d_distance[b] = sum of the valid costs in row b.
 *
 * All pointers are device pointers. The launch is asynchronous on the
 * default stream. Launch errors are left for the caller to observe via
 * cudaGetLastError().
 *
 * @param d_costs      [B, L] costs.
 * @param d_distance   Output [B].
 * @param d_lengths    Optional [B] valid lengths (nullptr = all L valid).
 * @param B            Batch size; nothing is launched when B <= 0.
 * @param L            Row stride / maximum sequence length.
 * @param temperature  Unused: the Hamming sum does not depend on it.
 */
void forward(
    const float* d_costs,
    float* d_distance,
    const int* d_lengths,
    int B,
    int L,
    float temperature
) {
    (void)temperature;

    // A zero-block grid is an invalid launch configuration, and an empty
    // batch has nothing to compute anyway.
    if (B <= 0) return;

    // One block per batch element; the kernel needs one float of dynamic
    // shared memory per thread for its reduction.
    const int threads = 256;
    const size_t shared_mem = static_cast<size_t>(threads) * sizeof(float);
    forward_kernel<<<B, threads, shared_mem>>>(
        d_costs, d_distance, d_lengths, B, L
    );
}

/**
 * Launch posteriors_kernel to fill d_posteriors ([B, L], device memory) with
 * 1 for valid positions and 0 for padding past d_lengths[b].
 *
 * The launch is asynchronous on the default stream. Nothing is launched for
 * empty shapes (B <= 0 or L <= 0), because a zero-block grid is an invalid
 * launch configuration.
 */
void posteriors(
    float* d_posteriors,
    const int* d_lengths,
    int B,
    int L
) {
    if (B <= 0 || L <= 0) return;

    // 64-bit element count: B * L can overflow int for large inputs.
    const long long total = static_cast<long long>(B) * L;
    const int threads = 256;
    const long long blocks = (total + threads - 1) / threads;
    posteriors_kernel<<<static_cast<unsigned int>(blocks), threads>>>(
        d_posteriors, d_lengths, B, L
    );
}

/**
 * Convenience wrapper that runs the forward pass and then the posterior fill.
 * Both launches go to the default stream in that order.
 *
 * Outputs: d_distance ([B]) and d_posteriors ([B, L]), both device memory.
 */
void forward_with_posteriors(
    const float* d_costs,
    float* d_distance,
    float* d_posteriors,
    const int* d_lengths,
    int B,
    int L,
    float temperature
) {
    // Step 1: per-row distance = sum of valid costs.
    forward(d_costs, d_distance, d_lengths,
            B, L, temperature);

    // Step 2: d(distance)/d(costs), i.e. the 0/1 validity mask.
    posteriors(d_posteriors, d_lengths,
               B, L);
}

/**
 * Forward pass plus posteriors plus the temperature gradient.
 *
 * Writes d_distance [B], d_posteriors [B, L] and d_grad_T [B] (all zeros,
 * since the distance does not depend on temperature). All launches are
 * asynchronous on the default stream. An empty batch (B <= 0) launches
 * nothing, which avoids an invalid zero-block grid.
 */
void backward(
    const float* d_costs,
    float* d_distance,
    float* d_posteriors,
    float* d_grad_T,
    const int* d_lengths,
    int B,
    int L,
    float temperature
) {
    if (B <= 0) return;

    forward_with_posteriors(d_costs, d_distance, d_posteriors, d_lengths, B, L, temperature);

    // Temperature gradient is zero.
    // Ceil-div written as (B - 1) / threads + 1 so it cannot overflow for B
    // near INT_MAX (valid because B > 0 here).
    const int threads = 256;
    const int blocks = (B - 1) / threads + 1;
    param_jacobian_kernel<<<blocks, threads>>>(d_grad_T, B);
}

/**
 * Hessian-vector product for the Hamming objective.
 *
 * The distance is linear in the costs, so the HVP is identically zero
 * whatever the tangent, lengths or temperature. This zero-fills d_hvp
 * ([B, L], device memory) asynchronously on the default stream. Nothing is
 * done for empty shapes (B <= 0 or L <= 0).
 */
void hvp(
    const float* d_costs,
    const float* d_tangent,
    float* d_hvp,
    const int* d_lengths,
    int B,
    int L,
    float temperature
) {
    (void)d_costs;
    (void)d_tangent;
    (void)d_lengths;
    (void)temperature;

    if (B <= 0 || L <= 0) return;

    // An all-zero byte pattern is +0.0f, so a byte-wise memset is a valid
    // float zero-fill. Sizing in size_t avoids the int overflow of B * L.
    // Any error is recorded in the runtime's last-error state, just as a
    // failed kernel launch would be.
    const size_t bytes =
        static_cast<size_t>(B) * static_cast<size_t>(L) * sizeof(float);
    cudaMemsetAsync(d_hvp, 0, bytes);
}

/**
 * Backward pass w.r.t. costs and temperature.
 *
 * NOTE: d_grad_costs receives the raw posteriors (1 for valid positions, 0
 * for padding past d_lengths[b]). They are NOT yet scaled by grad_output:
 * d_grad_output and d_costs are currently unused here. The original code
 * states the binding layer performs the multiply by grad_output[b]. Confirm
 * against the caller before moving that multiply into this function.
 *
 * d_grad_T [B] is zero-filled, because the distance does not depend on
 * temperature. All work is asynchronous on the default stream. Empty shapes
 * launch nothing, which avoids invalid zero-block grids.
 */
void backward_full(
    const float* d_costs,
    const float* d_grad_output,
    float* d_grad_costs,
    float* d_grad_T,
    const int* d_lengths,
    int B,
    int L,
    float temperature
) {
    (void)d_costs;
    (void)d_grad_output;
    (void)temperature;

    if (B <= 0) return;

    // dDistance/dCosts: the 0/1 validity mask (multiplied by grad_output in
    // the binding layer, per the note above).
    if (L > 0) {
        posteriors(d_grad_costs, d_lengths, B, L);
    }

    // Temperature gradient is zero. The ceil-div form cannot overflow for
    // B > 0.
    const int threads = 256;
    const int blocks_B = (B - 1) / threads + 1;
    param_jacobian_kernel<<<blocks_B, threads>>>(d_grad_T, B);
}

} // namespace hamming
} // namespace d2p
