#pragma once
// Part of the G-update step.


#include <cstdint>

#include <cuda_runtime.h>
#include "cublas_v2.h"

#include <gpu/contexts/device_context.h>
#include <gpu/containers/dense_matrix.h>

#include <gpu/macros.h>
#include <util/macros.h>

namespace npeff {
namespace gpu {
namespace ops {
namespace custom {


// Produces W_AG from W and AG; the actual work is implemented out of line in
// call_async(). NOTE(review): the exact operation combining W and AG is not
// visible from this header — confirm against the implementation.
//
// AG is required to hold a whole number of rows per class, i.e.
// AG.n_rows == n_examples * n_classes. Whether AG's rows are grouped by
// example or by class is not established here — TODO confirm.
class Compute_W_AG {
    DeviceContext& ctx;

    DenseMatrix& W;
    DenseMatrix& AG;
    DenseMatrix& W_AG;

    int64_t n_classes;
    // AG.n_rows / n_classes, computed only once n_classes is known to be > 0.
    int64_t n_examples;

public:
    // Stores references (not copies) to ctx, W, AG and W_AG, so those objects
    // must outlive this instance.
    //
    // Rejects, via THROW_MSG:
    //   - n_classes <= 0;
    //   - n_classes that does not evenly divide AG.n_rows.
    Compute_W_AG(
        DeviceContext& ctx,
        DenseMatrix& W,
        DenseMatrix& AG,
        DenseMatrix& W_AG,
        int64_t n_classes
    ) :
        ctx(ctx), W(W), AG(AG), W_AG(W_AG),
        n_classes(n_classes),
        // Member initializers run before the constructor body, so the division
        // must be guarded here; otherwise n_classes == 0 would divide by zero
        // before the validation below gets a chance to reject it.
        n_examples(n_classes > 0 ? AG.n_rows / n_classes : 0)
    {
        if (n_classes <= 0) {
            THROW_MSG("The number of classes must be positive.");
        }
        if ((AG.n_rows % n_classes) != 0) {
            THROW_MSG("The number of classes must divide the number of rows.");
        }
    }

    // Defined out of line. NOTE(review): the name suggests the work is
    // enqueued asynchronously (presumably on a stream owned by ctx) — confirm
    // before reading W_AG on the host without synchronizing.
    void call_async();

};


///////////////////////////////////////////////////////////////////////////////


// Variant of the W_AG computation driven by per-example row offsets.
// NOTE(review): "Lvrm" is not expanded anywhere in this header; the reading
// "variable number of AG rows per example" is inferred from the existence of
// d_example_row_offsets — confirm against call_async()'s implementation.
class Compute_W_AG_Lvrm {
    DeviceContext& ctx;

    // Operand shapes, with rank == n_cols:
    //   W    -> [n_examples, rank]
    //   AG   -> [n_rows,     rank]
    //   W_AG -> [n_rows,     rank]
    DenseMatrix& W;
    DenseMatrix& AG;
    DenseMatrix& W_AG;

    // Row offsets per example. The d_ prefix suggests a device pointer, and the
    // expected length is presumably n_examples + 1; neither is verified here —
    // TODO confirm.
    int64_t* d_example_row_offsets;

    // Dimensions captured at construction: n_examples from W.n_rows,
    // n_rows / n_cols from AG.
    const int64_t n_examples;
    const int64_t n_rows;
    const int64_t n_cols;


    // TODO: Figure out how to set this.
    const int64_t block_size = 16;

public:
    // Binds references to the operands (they must outlive this object) and
    // validates the shapes documented above via THROW_IF_FALSE.
    Compute_W_AG_Lvrm(
        DeviceContext& ctx,
        DenseMatrix& W,
        DenseMatrix& AG,
        DenseMatrix& W_AG,
        int64_t* d_example_row_offsets
    )
        : ctx(ctx),
          W(W),
          AG(AG),
          W_AG(W_AG),
          d_example_row_offsets(d_example_row_offsets),
          n_examples(W.n_rows),
          n_rows(AG.n_rows),
          n_cols(AG.n_cols)
    {
        // Column agreement: every operand must have rank (== AG.n_cols)
        // columns. The AG check is trivially true since n_cols comes from AG.
        THROW_IF_FALSE(W.n_cols == n_cols);
        THROW_IF_FALSE(AG.n_cols == n_cols);
        THROW_IF_FALSE(W_AG.n_cols == n_cols);

        // Row agreement between AG and W_AG. Again the AG check is trivially
        // true since n_rows comes from AG.
        THROW_IF_FALSE(AG.n_rows == n_rows);
        THROW_IF_FALSE(W_AG.n_rows == n_rows);
    }

    // Defined out of line. NOTE(review): presumably enqueues the computation
    // asynchronously (e.g. on ctx's stream) — confirm before reading W_AG on
    // the host without a synchronization point.
    void call_async();

};

}  // custom
}  // ops
}  // gpu
}  // npeff
