
#include "./device_worker.h"

#include <cstdint>
#include <stdexcept>


namespace npeff {
namespace factorizations {
namespace dn_lrm_factorization {


///////////////////////////////////////////////////////////////////////////////
// Initialization stuff.


void DeviceWorker::set_up_work() {
    // The device context must be ready first; every step below goes through dctx.
    dctx.initialize();

    // Small standalone device allocations: loss-term outputs and constants.
    write_scalars_to_gpu();

    // Allocate backing buffers and wrap them as matrices, then populate the
    // input matrix A and the parameters W and G.
    allocate_and_create_device_matrices();
    initialize_device_matrices();
}


// Overwrites this worker's W with the W held by `src_worker`, using an
// ncclBroadcast enqueued on this worker's stream.
//
// NOTE(review): ncclBroadcast's `root` argument is a rank within dctx.comm.
// Passing the source's device index assumes device index == communicator
// rank for every worker -- confirm against how dctx.comm is created.
void DeviceWorker::nccl_broadcast_of_W(DeviceWorker& src_worker) {
    auto send_buffer = src_worker.get_W_data_ptr();
    auto recv_buffer = get_W_data_ptr();
    auto root = src_worker.get_device();

    NCCL_CALL(ncclBroadcast(send_buffer, recv_buffer, d_ms.W->n_entries, ncclFloat,
                            root, dctx.comm, dctx.stream));
}


///////////////////////////////////////////////////////////////////////////////
// General stuff.


// Enqueues the per-device products AG and GG, computed from this device's
// partition of A only (cross-device summation happens separately).
void DeviceWorker::compute_local_AG_GG_async() {
    auto& A = *d_ms.A;
    auto& G = *d_ms.G;

    // AG: A (n_rows x n_cols) with G (rank x n_cols), second operand flagged
    // as transposed -> (n_rows x rank).
    DnDnMatMul(dctx, A, G, *d_ms.AG, false, true).call_async();

    // GG: G with itself, second operand flagged as transposed -> (rank x rank).
    DnDnMatMul(dctx, G, G, *d_ms.GG, false, true).call_async();
}


void DeviceWorker::nccl_all_reduce_AG_GG() {
    NCCL_CALL(
        ncclAllReduce(
            (float*) d_ms.AG->data,
            (float*) d_ms.AG->data,
            d_ms.AG->n_entries,
            ncclFloat,
            ncclSum,
            dctx.comm,
            dctx.stream)
    );
    NCCL_CALL(
        ncclAllReduce(
            (float*) d_ms.GG->data,
            (float*) d_ms.GG->data,
            d_ms.GG->n_entries,
            ncclFloat,
            ncclSum,
            dctx.comm,
            dctx.stream)
    );
}


///////////////////////////////////////////////////////////////////////////////
// W-update stuff.


void DeviceWorker::update_local_W_after_all_reduces_async() {
    // Must be called after AG and GG have been all-reduced across all
    // devices (or after those all-reduces have been enqueued on the stream).
    auto& W = *d_ms.W;
    auto& numerator = *d_ms.W_update_numerator;
    auto& denominator = *d_ms.W_update_denominator;

    // Numerator of the multiplicative update, derived from AG.
    gpu::ops::custom::AgToNumerator(dctx, *d_ms.AG, numerator, config.n_classes)
        .call_async();

    // HH is GG squared; the denominator is W * HH.
    gpu::ops::custom::ElwiseSquare(dctx, *d_ms.GG, *d_ms.HH).call_async();
    DnDnMatMul(dctx, W, *d_ms.HH, denominator, false, false).call_async();

    // Apply the update to this device's copy of W.
    gpu::ops::custom::MultiplicativeUpdate(dctx, W, numerator, denominator, config.mu_eps)
        .call_async();
}


///////////////////////////////////////////////////////////////////////////////
// G-update stuff.


// Enqueues one gradient-descent step on this device's G. Expects AG and GG
// to already hold their all-reduced values.
void DeviceWorker::update_local_G_after_all_reduces_async(float learning_rate_G) {
    auto& W = *d_ms.W;
    auto& G = *d_ms.G;
    auto& gradient = *d_ms.G_gradient;

    // First term, written into the gradient buffer:
    //   WW = W^T W;  WW_GG = Hadamard(WW, GG);  gradient = WW_GG * G.
    // WW and WW_GG share the same device buffer, so the Hadamard product
    // overwrites WW.
    DnDnMatMul(dctx, W, W, *d_ms.WW, true, false).call_async();
    gpu::ops::custom::HadamardProduct(dctx, *d_ms.WW, *d_ms.GG, *d_ms.WW_GG).call_async();
    DnDnMatMul(dctx, *d_ms.WW_GG, G, gradient, false, false).call_async();

    // Second term, accumulated onto the gradient buffer: W_AG^T * A, using the
    // device scalars minus_1 (-1) and dctx.dev_1f, which presumably act as the
    // GEMM alpha/beta coefficients.
    gpu::ops::custom::Compute_W_AG(dctx, W, *d_ms.AG, *d_ms.W_AG, config.n_classes)
        .call_async();
    DnDnMatMul(dctx, *d_ms.W_AG, *d_ms.A, gradient, true, false,
               d_ptrs.minus_1, dctx.dev_1f)
        .call_async();

    // Step G along the gradient. The factor of 4 belongs to the true gradient
    // but is left out of the computation above, so it is folded in here.
    const float step_size = 4.0f * learning_rate_G;
    gpu::ops::custom::GradientDescentUpdate(dctx, G, gradient, step_size).call_async();
}


///////////////////////////////////////////////////////////////////////////////
// Loss computation stuff.


// Enqueues computation of the two loss terms into the device scalars
// d_ptrs.d_tr_WW_HH and d_ptrs.d_tr_WHX. Expects AG and GG to already be
// all-reduced.
void DeviceWorker::compute_loss_after_all_reduces_async() {
    auto& W = *d_ms.W;
    auto& WW = *d_ms.WW;
    auto& HH = *d_ms.HH;
    auto& numerator = *d_ms.W_update_numerator;

    // tr_WW_HH: Frobenius inner product of W^T W with HH (GG squared).
    gpu::ops::custom::ElwiseSquare(dctx, *d_ms.GG, HH).call_async();
    DnDnMatMul(dctx, W, W, WW, true, false).call_async();
    FrobeniousInnerProduct(dctx, WW, HH, d_ptrs.d_tr_WW_HH).call_async();

    // tr_WHX: Frobenius inner product of W with the AG-derived numerator,
    // reusing the W-update numerator buffer as scratch.
    gpu::ops::custom::AgToNumerator(dctx, *d_ms.AG, numerator, config.n_classes)
        .call_async();
    FrobeniousInnerProduct(dctx, W, numerator, d_ptrs.d_tr_WHX).call_async();
}


// Copies the two device-side loss terms to the host, blocks on this worker's
// stream until they arrive, and returns tr_WW_HH - 2 * tr_WHX.
float DeviceWorker::read_loss_term_from_device() {
    dctx.set_device();

    float host_tr_WW_HH;
    float host_tr_WHX;
    dctx.copy_to_host_async<float>(&host_tr_WW_HH, d_ptrs.d_tr_WW_HH, 1);
    dctx.copy_to_host_async<float>(&host_tr_WHX, d_ptrs.d_tr_WHX, 1);

    // Both copies must land before the host locals can be read.
    dctx.synchronize_stream();

    return -2.0f * host_tr_WHX + host_tr_WW_HH;
}


///////////////////////////////////////////////////////////////////////////////
// Reading data back to host stuff.


void DeviceWorker::read_W_from_gpu_async(float* host_write_location) {
    read_matrix_from_gpu_async(*d_ms.W, host_write_location);
}


void DeviceWorker::read_G_from_gpu_async(float* host_write_location) {
    read_matrix_from_gpu_async(*d_ms.G, host_write_location);
}


///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// Private member functions.


void DeviceWorker::write_scalars_to_gpu() {
    // Two adjacent device floats that receive the loss terms.
    float* d_loss_terms = dctx.dmalloc<float>(2);
    d_ptrs.d_tr_WW_HH = &d_loss_terms[0];
    d_ptrs.d_tr_WHX = &d_loss_terms[1];

    // Constant device scalars. Currently only -1, which is passed to
    // DnDnMatMul in the G update.
    const int64_t n_constants = 1;
    float* d_constants = dctx.dmalloc<float>(n_constants);
    d_ptrs.minus_1 = &d_constants[0];

    float host_constants[n_constants] = {-1.0f,};
    // NOTE(review): host_constants lives on the stack. This is only safe if
    // copy_to_device_async has finished reading it before it returns (as
    // cudaMemcpyAsync does for pageable host memory) -- confirm in dctx.
    dctx.copy_to_device_async(d_constants, host_constants, n_constants);
}


// Allocates every device buffer this worker needs and wraps the buffers as
// gpu::DenseMatrix views. Several intermediates deliberately share storage
// (noted below), so the order of GPU ops that use them matters.
//
// Throws std::invalid_argument if n_classes is not positive or does not
// evenly divide the number of rows of the host matrix partition.
void DeviceWorker::allocate_and_create_device_matrices() {
    const int64_t n_rows = host_matrix_partition->n_rows;
    const int64_t n_cols = host_matrix_partition->n_cols;

    const int64_t n_classes = config.n_classes;
    const int64_t rank = config.rank;

    // AG and W are sized from n_examples = n_rows / n_classes, while AG is
    // computed from A, which has n_rows rows. The shapes only agree when
    // n_classes evenly divides n_rows. Checking first also avoids dividing
    // by zero.
    if (n_classes <= 0 || n_rows % n_classes != 0) {
        throw std::invalid_argument(
            "DeviceWorker: n_classes must be positive and evenly divide the "
            "number of rows of the matrix partition.");
    }
    const int64_t n_examples = n_rows / n_classes;

    // curandGenerateNormal rejects odd output lengths for pseudorandom
    // generators (Box-Muller produces values in pairs). G's buffer is
    // therefore padded by one float when rank * n_cols is odd. Only the first
    // rank * n_cols entries belong to G.
    const int64_t n_G_entries = rank * n_cols;
    const int64_t n_G_allocated = n_G_entries + (n_G_entries % 2);

    /////////////////////////////////////////
    // Allocate the memory on the device.

    d_ptrs.d_A = dctx.dmalloc<float>(n_rows * n_cols);

    d_ptrs.d_W = dctx.dmalloc<float>(n_examples * rank);
    d_ptrs.d_G = dctx.dmalloc<float>(n_G_allocated);

    d_ptrs.d_ncr1 = dctx.dmalloc<float>(n_classes * n_examples * rank);
    d_ptrs.d_ncr2 = dctx.dmalloc<float>(n_classes * n_examples * rank);
    d_ptrs.d_nr = dctx.dmalloc<float>(n_examples * rank);
    d_ptrs.d_rr1 = dctx.dmalloc<float>(rank * rank);
    d_ptrs.d_rr2 = dctx.dmalloc<float>(rank * rank);
    d_ptrs.d_rr3 = dctx.dmalloc<float>(rank * rank);
    d_ptrs.d_rm = dctx.dmalloc<float>(rank * n_cols);

    /////////////////////////////////////////
    // Create the matrices.

    // Input/parameter matrices.
    d_ms.A = gpu::DenseMatrix::make_unique_ptr(n_rows, n_cols, d_ptrs.d_A);

    d_ms.W = gpu::DenseMatrix::make_unique_ptr(n_examples, rank, d_ptrs.d_W);
    d_ms.G = gpu::DenseMatrix::make_unique_ptr(rank, n_cols, d_ptrs.d_G);

    // Simple/common intermediate matrices.
    // WW shares d_rr1 with WW_GG below.
    d_ms.WW = gpu::DenseMatrix::make_unique_ptr(rank, rank, d_ptrs.d_rr1);

    d_ms.GG = gpu::DenseMatrix::make_unique_ptr(rank, rank, d_ptrs.d_rr2);
    d_ms.HH = gpu::DenseMatrix::make_unique_ptr(rank, rank, d_ptrs.d_rr3);

    d_ms.AG = gpu::DenseMatrix::make_unique_ptr(n_classes * n_examples, rank, d_ptrs.d_ncr1);

    // W-step specific intermediate matrices.
    d_ms.W_update_numerator = gpu::DenseMatrix::make_unique_ptr(n_examples, rank, d_ptrs.d_nr);
    // The denominator shares d_ncr2 with W_AG below.
    d_ms.W_update_denominator = gpu::DenseMatrix::make_unique_ptr(n_examples, rank, d_ptrs.d_ncr2);

    // G-step specific intermediate matrices.
    d_ms.WW_GG = gpu::DenseMatrix::make_unique_ptr(rank, rank, d_ptrs.d_rr1);
    d_ms.W_AG = gpu::DenseMatrix::make_unique_ptr(n_classes * n_examples, rank, d_ptrs.d_ncr2);
    d_ms.G_gradient = gpu::DenseMatrix::make_unique_ptr(rank, n_cols, d_ptrs.d_rm);
}


// Copies A to the device, randomly initializes the parameters, and then
// releases the host copy of A.
//
// W is filled (uniform distribution) only on partition 0. G is filled on
// every partition from a normal distribution with mean 0 and stddev
// 1 / config.compute_inv_g_initialization_scale_factor().
void DeviceWorker::initialize_device_matrices() {
    dctx.set_device();

    // Move A onto the GPU.
    dctx.copy_to_device_async(*d_ms.A, *host_matrix_partition);

    // If we are the first partition, initialize W with a uniform random
    // distribution.
    if(partition_index == 0) {
        auto& W = d_ms.W;
        CURAND_CALL(
            curandGenerateUniform(dctx.rand_gen, (float*) W->data, W->n_rows * W->n_cols)
        );
    }

    // Generate an even number of normals, because pseudorandom cuRAND
    // generators reject odd lengths here. G's buffer is padded by one float
    // when its entry count is odd, so the extra sample lands in the padding.
    auto& G = d_ms.G;
    const int64_t n_G_entries = G->n_rows * G->n_cols;
    const int64_t n_G_generated = n_G_entries + (n_G_entries % 2);
    double inv_g_factor = config.compute_inv_g_initialization_scale_factor();
    CURAND_CALL(
        curandGenerateNormal(dctx.rand_gen, (float*) G->data, n_G_generated,
                             0.0f, static_cast<float>(1.0 / inv_g_factor))
    );

    // Synchronize the stream so that A has been fully copied onto the GPU,
    // then free the memory associated with A on the host.
    dctx.synchronize_stream();
    host_matrix_partition.reset();
}


// Enqueues a device-to-host copy of all of `matrix`'s entries into
// `host_write_location` on this worker's stream. The caller must synchronize
// the stream before reading the host buffer.
void DeviceWorker::read_matrix_from_gpu_async(gpu::DenseMatrix& matrix, float* host_write_location) {
    // Make this worker's device current before enqueuing the copy.
    dctx.set_device();

    auto n_values = matrix.n_entries;
    dctx.copy_to_host_async(host_write_location, (float*) matrix.data, n_values);
}




}  // dn_lrm_factorization
}  // factorizations
}  // npeff
