#include "./multiplicative_update.h"


namespace npeff {
namespace gpu {
namespace ops {
namespace custom {


// Element-wise multiplicative update, applied in place:
//     F[i] = F[i] * (numer[i] / (denom[i] + eps))   for each index i in [0, n).
//
// Args:
//   n:     number of elements in each of F, numer and denom.
//   F:     array rescaled in place.
//   numer: numerator values (only read).
//   denom: denominator values (only read); eps is added to each before dividing.
//   eps:   offset added to every denominator (presumably to avoid division by
//          zero — confirm with callers).
//
// All three pointers are dereferenced on the device, so they must refer to
// device-accessible memory. Index iteration is delegated to the project macro
// INDEX_STRIDE_1D (defined elsewhere). NOTE(review): presumably a grid-stride
// loop bounded by n — confirm against its definition.
__global__
void MultiplicativeUpdate_Kernel(
    long n, float* F, const float* numer, const float* denom, float eps
) {
    INDEX_STRIDE_1D(n, idx) {
        const float shifted_denom = denom[idx] + eps;
        const float ratio = numer[idx] / shifted_denom;
        F[idx] = F[idx] * ratio;
    }
}


// Enqueues MultiplicativeUpdate_Kernel on ctx.stream, computing element-wise
//     out *= numer / (denom + eps)
// The call is asynchronous with respect to the host: it returns once the
// kernel launch has been issued.
//
// ctx.set_device() is called first so the launch targets ctx's device.
// An empty `out` (n_entries <= 0) returns without launching anything:
// a launch with a grid of zero blocks is an invalid configuration in CUDA
// and would leave cudaErrorInvalidConfiguration in the runtime's last-error
// slot, to be reported by some unrelated later check.
//
// NOTE(review): assumes numer and denom hold float data with at least
// out.n_entries elements each; sizes are not validated here. Launch errors
// are also not checked here — confirm the caller checks cudaGetLastError()
// or the stream status.
void MultiplicativeUpdate::call_async() {
    ctx.set_device();
    const long n = out.n_entries;
    if (n <= 0) {
        // Nothing to update; avoid an invalid zero-block launch.
        return;
    }
    // Ceil-divide so the tail elements are covered by a final partial block.
    const long n_blocks = (n + block_size - 1) / block_size;

    MultiplicativeUpdate_Kernel<<<n_blocks, block_size, 0, ctx.stream>>>(
        n,
        (float*) out.data,
        // numer/denom are only read; keep them const-qualified.
        (const float*) numer.data,
        (const float*) denom.data,
        eps
    );
}


}  // custom
}  // ops
}  // gpu
}  // npeff
