#include "./multiply_and_add_identity.h"



namespace npeff {
namespace gpu {
namespace ops {
namespace custom {


// In-place update of a square r x r matrix: data <- multiply_factor * data + I.
//
// Every entry is scaled by multiply_factor; entries on the main diagonal
// additionally get 1 added. Flat indices are decoded as column-major
// (row = index % r, col = index / r).
//
// Parameters:
//   r                - number of rows (and columns) of the matrix.
//   data             - buffer of r * r floats, overwritten with the result.
//   multiply_factor  - scalar applied to every entry.
//
// NOTE(review): INDEX_STRIDE_1D is defined outside this file; it is presumably
// a strided loop over flat indices in [0, r * r) -- confirm against the macro.
__global__
void MultiplyAndAddIdentity_InPlace_Kernel(const int64_t r, float* data, const float multiply_factor) {
    const int64_t total_entries = r * r;
    INDEX_STRIDE_1D(total_entries, flat_idx) {
        const int64_t row_idx = flat_idx % r;
        const int64_t col_idx = flat_idx / r;
        const bool off_diagonal = (row_idx != col_idx);
        if (off_diagonal) {
            // Plain scaling for everything outside the main diagonal.
            data[flat_idx] *= multiply_factor;
        } else {
            // Diagonal entry: scale, then add the identity's contribution.
            data[flat_idx] = multiply_factor * data[flat_idx] + 1.0f;
        }
    }
}


// Enqueues the in-place `mat <- multiply_factor * mat + I` kernel on ctx.stream.
//
// The launch is only issued here; this function performs no stream
// synchronization, so the result is not guaranteed to be ready on return.
//
// The kernel receives only mat.n_rows and treats the buffer as an
// n_rows x n_rows matrix, so `mat` is expected to be square.
// NOTE(review): mat.data is cast to float* -- confirm the matrix always holds
// 32-bit floats.
void MultiplyAndAddIdentity_InPlace::call_async() {
    // Presumably makes ctx's device current for the launch below -- defined elsewhere.
    ctx.set_device();

    // Ceil-div: enough blocks to give each matrix entry its own thread.
    long n_blocks = (mat.n_entries + block_size - 1) / block_size;

    // An empty matrix yields a zero-block grid, which CUDA rejects as an
    // invalid launch configuration. There is nothing to compute, so skip the
    // launch instead of leaving a pending error on the runtime.
    if (n_blocks <= 0) {
        return;
    }

    MultiplyAndAddIdentity_InPlace_Kernel<<<n_blocks, block_size, 0, ctx.stream>>>(
        mat.n_rows, (float*) mat.data, multiply_factor
    );
}


}  // custom
}  // ops
}  // gpu
}  // npeff
