#include "./ag_to_numerator.h"


namespace npeff {
namespace gpu {
namespace ops {
namespace custom {


///////////////////////////////////////////////////////////////////////////////
// AgToNumerator


void AgToNumerator::call_async() {
    ctx.set_device();

    // Squared L2 norm of every length-n_classes vector in AG, expressed as a
    // strided batch of 1x1 GEMMs: out[b] = a_b^T * a_b, where a_b starts at
    // AG.data + b * n_classes and out[b] is W_update_numerator.data[b].
    //
    // NOTE(review): this treats AG.data as n_examples * AG.n_cols vectors
    // stored back to back (i.e. AG column-major with n_examples * n_classes
    // rows) — confirm against the header.
    const auto vec_len = n_classes;
    const int64_t ag_n_cols = AG.n_cols;
    const int64_t n_vectors = n_examples * ag_n_cols;

    float* ag = (float *) AG.data;
    float* numerator = (float *) W_update_numerator.data;

    // ctx.dev_1f / ctx.dev_0f are passed as alpha / beta; presumably device
    // scalars 1.0f and 0.0f used with a device-pointer-mode handle — confirm.
    CUBLAS_CALL(cublasSgemmStridedBatched(
        ctx.dense_handle,
        CUBLAS_OP_T, CUBLAS_OP_N,
        /* m = */ 1, /* n = */ 1, /* k = */ vec_len,
        ctx.dev_1f,
        /* A = */ ag, /* lda = */ vec_len, /* strideA = */ vec_len,
        /* B = */ ag, /* ldb = */ vec_len, /* strideB = */ vec_len,
        ctx.dev_0f,
        /* C = */ numerator, /* ldc = */ 1, /* strideC = */ 1,
        n_vectors
    ));
}


///////////////////////////////////////////////////////////////////////////////
// AgToNumeratorLvrm


// Per-(example, column) sum of squares over a variable-length range of rows.
//
// Launch layout: 2D grid with 2D blocks; the x dimension indexes examples and
// the y dimension indexes columns of AG. Each in-range thread writes exactly
// one output entry. No shared memory, no intra-block synchronization.
//
// Params:
//   n_examples            number of examples (output entries per column).
//   n_rows, n_cols        dimensions of AG. AG is read column-major with
//                         leading dimension n_rows: entry (i, c) is at
//                         d_AG[c * n_rows + i].
//   d_AG                  device buffer; only read.
//   d_example_row_offsets n_examples + 1 entries; rows
//                         [offsets[e], offsets[e + 1]) of AG belong to
//                         example e (presumably nondecreasing — confirm).
//   d_W_update_numerator  output; entry (e, c) is written at
//                         c * n_examples + e.
__global__
void AgToNumeratorLvrm_Kernel(
    int64_t n_examples, int64_t n_rows, int64_t n_cols,
    float* d_AG, int64_t* d_example_row_offsets,
    float* d_W_update_numerator
) {
    // Widen before multiplying: blockIdx.x * blockDim.x is otherwise computed
    // in 32-bit unsigned arithmetic and silently wraps for very large grids.
    const int64_t example_index =
        static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const int64_t col_index =
        static_cast<int64_t>(blockIdx.y) * blockDim.y + threadIdx.y;

    // Tail threads of the last blocks in x / y have nothing to do. Returning
    // early is safe because this kernel contains no barriers.
    if (example_index >= n_examples || col_index >= n_cols) {
        return;
    }

    const int64_t start_row = d_example_row_offsets[example_index];
    const int64_t end_row = d_example_row_offsets[example_index + 1];

    // Base of column col_index, hoisted out of the row loop.
    const float* col = d_AG + col_index * n_rows;

    float out_entry = 0.0f;
    for (int64_t i = start_row; i < end_row; i++) {
        const float value = col[i];
        out_entry += value * value;
    }

    d_W_update_numerator[col_index * n_examples + example_index] = out_entry;
}


void AgToNumeratorLvrm::call_async() {
    ctx.set_device();

    // Empty problem: the output has n_examples * n_cols == 0 entries, so there
    // is nothing to compute. Launching anyway would use a zero-sized grid
    // dimension, which fails with cudaErrorInvalidConfiguration. That error
    // would then be reported by the next, unrelated cudaGetLastError() check.
    if (n_examples <= 0 || n_cols <= 0) {
        return;
    }

    // One thread per (example, column) output entry: x covers examples,
    // y covers columns, each rounded up to whole blocks.
    // NOTE(review): block_size * block_size must not exceed the per-block
    // thread limit (1024), and n_blocks_y must fit gridDim.y (<= 65535).
    int64_t n_blocks_x = (n_examples + block_size - 1) / block_size;
    int64_t n_blocks_y = (n_cols + block_size - 1) / block_size;

    dim3 n_blocks(n_blocks_x, n_blocks_y);
    dim3 block_sizes(block_size, block_size);

    // Enqueued on ctx.stream; the launch is asynchronous and not error-checked
    // here (NOTE(review): consider checking cudaGetLastError() after launch).
    AgToNumeratorLvrm_Kernel<<<n_blocks, block_sizes, 0, ctx.stream>>>(
        n_examples, n_rows, n_cols,
        (float*) AG.data, d_example_row_offsets, (float*) W_update_numerator.data
    );
}


}  // custom
}  // ops
}  // gpu
}  // npeff
