#pragma once
// Converts the AG matrix into the numerator of the W-update step.
// This is equivalent to squaring its entries and then summing
// along the "classes" dimension.

#include <cstdint>

#include <cuda_runtime.h>
#include "cublas_v2.h"

#include <gpu/contexts/device_context.h>
#include <gpu/containers/dense_matrix.h>

#include <gpu/macros.h>
#include <util/macros.h>

namespace npeff {
namespace gpu {
namespace ops {
namespace custom {


class AgToNumerator {
    DeviceContext& ctx;

    DenseMatrix& AG;
    DenseMatrix& W_update_numerator;

    int64_t n_classes;
    int64_t n_examples;

public:
    AgToNumerator(
        DeviceContext& ctx,
        DenseMatrix& AG,
        DenseMatrix& W_update_numerator,
        int64_t n_classes
    ) :
        ctx(ctx), AG(AG), W_update_numerator(W_update_numerator),
        n_classes(n_classes), n_examples(AG.n_rows / n_classes)
    {
        if((AG.n_rows % n_classes) != 0) {
            THROW_MSG("The number of classes must divide the number of rows.");
        }
    }

    void call_async();
};


///////////////////////////////////////////////////////////////////////////////


// Variant of the AG-to-numerator operation that takes an additional array of
// per-example row offsets instead of a fixed number of classes.
//
// The object only stores references/pointers to the arguments it is given, so
// all of them must outlive it.
class AgToNumeratorLvrm {
    DeviceContext& ctx;

    // AG.shape = [n_rows, rank]
    DenseMatrix& AG;

    // W_update_numerator.shape = [n_examples, rank]
    DenseMatrix& W_update_numerator;

    // Caller-owned array of row offsets; not validated here. The `d_` prefix
    // suggests device memory -- TODO confirm, along with its expected length.
    int64_t* d_example_row_offsets;

    // Dimensions captured once at construction time.
    const int64_t n_examples;  // W_update_numerator.n_rows
    const int64_t n_rows;      // AG.n_rows
    const int64_t n_cols;      // AG.n_cols

    // TODO: Figure out how to set this.
    const int64_t block_size = 16;

public:
    // Parameters:
    //   ctx                    device context stored for later use by call_async().
    //   AG                     source matrix, shape [n_rows, rank].
    //   W_update_numerator     numerator matrix, shape [n_examples, rank].
    //   d_example_row_offsets  caller-owned offsets array (stored as-is).
    //
    // Throws (via THROW_IF_FALSE) if the two matrices do not have the same
    // number of columns.
    AgToNumeratorLvrm(
        DeviceContext& ctx,
        DenseMatrix& AG,
        DenseMatrix& W_update_numerator,
        int64_t* d_example_row_offsets
    ) :
        ctx(ctx), AG(AG), W_update_numerator(W_update_numerator),
        d_example_row_offsets(d_example_row_offsets),
        n_examples(W_update_numerator.n_rows), n_rows(AG.n_rows), n_cols(AG.n_cols)
    {
        // Matrix shape check. n_cols was just initialized from AG.n_cols, so
        // comparing AG against it would be vacuous; only the numerator's
        // column count actually needs to be verified.
        THROW_IF_FALSE(W_update_numerator.n_cols == n_cols);
    }

    // Runs the operation. Declared here and defined elsewhere; the name
    // suggests the work is enqueued asynchronously -- confirm synchronization
    // requirements against the implementation.
    void call_async();
};


}  // custom
}  // ops
}  // gpu
}  // npeff
