#include "./gradient_descent.h"

namespace npeff {
namespace gpu {
namespace ops {
namespace custom {


// In-place gradient-descent step with an additional term proportional to the
// current value:
//
//     data[i] <- data[i] - (learning_rate * gradient[i]
//                           + learning_rate * alpha * data[i])
//
// With alpha == 0 this reduces to the plain update
// data[i] -= learning_rate * gradient[i].
//
// Parameters:
//   n             - element count handed to INDEX_STRIDE_1D as the loop bound.
//   data          - values updated in place.
//   gradient      - per-element gradient; only read by this kernel.
//   learning_rate - step size.
//   alpha         - coefficient of the extra data[i] term (defaults to 0).
//
// NOTE(review): INDEX_STRIDE_1D is defined outside this file; presumably it
// iterates i over [0, n) across the launched threads -- confirm in its header.
__global__
void GradientDescentUpdate_Kernel(
    int64_t n, float* data, float* gradient, const float learning_rate, const float alpha = 0.0f
) {
    INDEX_STRIDE_1D(n, i) {
        const float current = data[i];
        const float grad = gradient[i];

        // Same expression shape as the compound-assignment form, so the
        // floating-point evaluation order is unchanged.
        data[i] = current - (learning_rate * grad + learning_rate * alpha * current);
    }
}


// Enqueues GradientDescentUpdate_Kernel on ctx.stream for the n_elements
// values in `data`, using `gradient`, `learning_rate` and `alpha` from this
// object. The call only issues the launch; it does not synchronize the stream.
//
// When there are no elements to update, no kernel is launched: a grid
// dimension of 0 is an invalid launch configuration and would leave an error
// behind for a later, unrelated CUDA call to trip over.
void GradientDescentUpdate::call_async() {
    ctx.set_device();

    // Empty input: nothing to do, and a zero-block launch is invalid.
    if (n_elements <= 0) {
        return;
    }

    // Ceiling division so a trailing partial block is still covered.
    // int64_t (rather than long) keeps the width fixed on every platform and
    // matches the kernel's element-count parameter.
    const int64_t n_blocks = (n_elements + block_size - 1) / block_size;

    GradientDescentUpdate_Kernel<<<n_blocks, block_size, 0, ctx.stream>>>(
        n_elements, (float*) data, (float*) gradient, learning_rate, alpha
    );
}


}  // custom
}  // ops
}  // gpu
}  // npeff
