#pragma once
// Multiplicatively updates parameters given the numerator and denominator.

#include <cstdint>

#include <cuda_runtime.h>

#include <gpu/contexts/device_context.h>
#include <gpu/containers/dense_matrix.h>


namespace npeff {
namespace gpu {
namespace ops {
namespace custom {


// Operation object for a multiplicative update driven by a numerator matrix
// and a denominator matrix.
//
// The object only holds references to its operands; it does not own or copy
// them, so every referenced object must outlive this instance. Because it has
// reference members and a const member, instances are not copy-assignable.
//
// NOTE(review): the exact arithmetic (presumably something like
// out *= numer / (denom + eps)) lives in the out-of-line definition of
// call_async(), which is not part of this header -- confirm there.
class MultiplicativeUpdate {
private:
    // Device context this operation is associated with.
    DeviceContext& ctx;

    // Operands. All three are required to have identical dimensions; this is
    // enforced once, at construction time.
    DenseMatrix& out;
    DenseMatrix& numer;
    DenseMatrix& denom;

    // Small constant supplied by the caller; presumably guards the division
    // against a zero denominator -- TODO confirm against call_async().
    float eps;

    // TODO: Figure out how to set this.
    // Presumably the launch block size used by call_async() -- confirm.
    const int64_t block_size = 256;

public:
    // Binds the operation to its context and operands.
    //
    // context     - device context stored as `ctx`.
    // output      - matrix stored as `out`.
    // numerator   - matrix stored as `numer`; must match `output` in shape.
    // denominator - matrix stored as `denom`; must match `output` in shape.
    // epsilon     - scalar stored as `eps`.
    //
    // Each shape requirement is checked with THROW_IF_FALSE, row counts first
    // and column counts second.
    MultiplicativeUpdate(
        DeviceContext& context,
        DenseMatrix& output,
        DenseMatrix& numerator,
        DenseMatrix& denominator,
        float epsilon
    )
        : ctx(context),
          out(output),
          numer(numerator),
          denom(denominator),
          eps(epsilon)
    {
        // The members are already bound to the arguments at this point, so
        // checking through them is equivalent to checking the arguments.
        // Row counts must agree with the output...
        THROW_IF_FALSE(numer.n_rows == out.n_rows);
        THROW_IF_FALSE(denom.n_rows == out.n_rows);
        // ...and so must column counts.
        THROW_IF_FALSE(numer.n_cols == out.n_cols);
        THROW_IF_FALSE(denom.n_cols == out.n_cols);
    }

    // Runs the update. Declared here and defined out of line; the name
    // suggests the work is enqueued without waiting for completion --
    // NOTE(review): confirm the synchronization contract at the definition.
    void call_async();
};


}  // custom
}  // ops
}  // gpu
}  // npeff
