
#include "./worker.h"


namespace npeff {
namespace factorizations {
namespace dn_lrm_factorization2 {


// Local shorthands for the GPU ops that are used repeatedly in this file.
// (The spelling "Frobenious" is the name exported by gpu::ops and is kept as-is.)
using DnDnMatMul = gpu::ops::DnDnMatMul;
using FrobeniousInnerProduct = gpu::ops::FrobeniousInnerProduct;


///////////////////////////////////////////////////////////////////////////////
// Initialization stuff.


/// One-time device-side setup for this worker.
///
/// Runs four steps in a fixed order: initialize the device context, create
/// the scalar buffers, allocate and wrap the matrix buffers, and finally give
/// the matrices their initial contents.
void DeviceWorker::set_up_work() {
    // Bring up the device context before any other device work is issued.
    dctx.initialize();

    // Scalar buffers (trace outputs and the constant -1).
    write_scalars_to_gpu();

    // The matrices have to exist before their contents can be initialized.
    allocate_and_create_device_matrices();
    initialize_device_matrices();
}


/// Optionally overwrites the device copy of G with a host-provided matrix.
///
/// @param initial_G Host matrix to upload. A null pointer turns the call into
///     a no-op. Its row and column counts must equal those of the device-side
///     G, otherwise THROW_IF_FALSE fires.
///
/// NOTE(review): the upload goes through copy_to_device_async and is not
/// followed by a synchronization here, so the caller presumably has to keep
/// `initial_G` alive until the stream has been synchronized -- confirm.
void DeviceWorker::set_G_from_host(npeff::DenseMatrix<float>* initial_G) {
    if (!initial_G) {
        return;
    }

    // Reject a host matrix whose shape differs from the device-side G.
    THROW_IF_FALSE(initial_G->n_rows == d_ms.G->n_rows);
    THROW_IF_FALSE(initial_G->n_cols == d_ms.G->n_cols);

    dctx.set_device();

    float* device_G = (float*) d_ms.G->data;
    auto* host_G = initial_G->data.get();
    dctx.copy_to_device_async(device_G, host_G, initial_G->n_entries);
}


/// Enqueues this worker's part of an NCCL broadcast of G.
///
/// @param src_worker Worker whose G is the source of the broadcast. Its G
///     pointer is passed as the send buffer and its device id as the root.
///
/// This worker's own G is the receive buffer. The call is issued on this
/// worker's communicator and stream.
///
/// NOTE(review): the root argument of ncclBroadcast is a rank; passing
/// `get_device()` assumes rank and device id coincide -- confirm against how
/// the communicators are created.
void DeviceWorker::nccl_broadcast_of_G(DeviceWorker& src_worker) {
    auto send_buffer = src_worker.get_G_data_ptr();
    auto recv_buffer = get_G_data_ptr();
    const auto root = src_worker.get_device();

    NCCL_CALL(ncclBroadcast(
        send_buffer, recv_buffer, d_ms.G->n_entries, ncclFloat, root, dctx.comm, dctx.stream));
}


///////////////////////////////////////////////////////////////////////////////
// General stuff.


/// Enqueues the dense product of A and G, written into AG.
///
/// Both boolean flags are set. Given the shapes created in
/// allocate_and_create_device_matrices (A stored as parameters x rows,
/// G as rank x parameters, AG as rows x rank) they are read here as
/// per-operand transpose flags.
void DeviceWorker::compute_partition_AG_async() {
    constexpr bool transpose_lhs = true;
    constexpr bool transpose_rhs = true;

    DnDnMatMul(dctx, *d_ms.A, *d_ms.G, *d_ms.AG, transpose_lhs, transpose_rhs)
        .call_async();
}

/// Enqueues the dense product of G with itself, written into GG.
///
/// Only the second flag is set; with G stored as rank x parameters and GG as
/// rank x rank, the flags are read here as per-operand transpose flags.
void DeviceWorker::compute_partition_GG_async() {
    constexpr bool transpose_lhs = false;
    constexpr bool transpose_rhs = true;

    DnDnMatMul(dctx, *d_ms.G, *d_ms.G, *d_ms.GG, transpose_lhs, transpose_rhs)
        .call_async();
}

/// Enqueues the dense product of W with itself, written into WW.
///
/// Only the first flag is set; with W stored as examples x rank and WW as
/// rank x rank, the flags are read here as per-operand transpose flags.
/// W holds only this partition's rows, so the result is the local
/// contribution to WW.
void DeviceWorker::compute_local_WW_async() {
    constexpr bool transpose_lhs = true;
    constexpr bool transpose_rhs = false;

    DnDnMatMul(dctx, *d_ms.W, *d_ms.W, *d_ms.WW, transpose_lhs, transpose_rhs)
        .call_async();
}

/// Enqueues an NCCL sum all-reduce over the WW buffer.
///
/// The same device pointer is passed as send and receive buffer, so the
/// reduction overwrites WW in place. The call is issued on this worker's
/// communicator and stream.
void DeviceWorker::nccl_all_reduce_WW() {
    float* ww_buffer = (float*) d_ms.WW->data;

    NCCL_CALL(ncclAllReduce(
        ww_buffer, ww_buffer, d_ms.WW->n_entries, ncclFloat, ncclSum, dctx.comm, dctx.stream));
}

/// Enqueues an NCCL sum all-reduce over the G-gradient buffer.
///
/// The same device pointer is passed as send and receive buffer, so the
/// reduction overwrites G_gradient in place. The call is issued on this
/// worker's communicator and stream.
void DeviceWorker::nccl_all_reduce_G_gradient() {
    float* gradient_buffer = (float*) d_ms.G_gradient->data;

    NCCL_CALL(ncclAllReduce(
        gradient_buffer, gradient_buffer, d_ms.G_gradient->n_entries,
        ncclFloat, ncclSum, dctx.comm, dctx.stream));
}


///////////////////////////////////////////////////////////////////////////////
// W-update stuff.

/// Enqueues one multiplicative update of this partition's W.
///
/// @param opts Selects whether AG and/or GG are recomputed first. When a flag
///     is off, the current contents of the corresponding buffer are used.
///
/// Steps, in the order they are enqueued:
///   1. numerator   <- AgToNumerator(AG, n_classes)
///   2. HH          <- element-wise square of GG
///   3. denominator <- W * HH
///   4. W           <- MultiplicativeUpdate(W, numerator, denominator, mu_eps)
void DeviceWorker::update_partition_W_async(WUpdateOptions opts) {
    if (opts.recompute_partition_AG) {
        compute_partition_AG_async();
    }
    if (opts.recompute_partition_GG) {
        compute_partition_GG_async();
    }

    auto& numerator = *d_ms.W_update_numerator;
    auto& denominator = *d_ms.W_update_denominator;

    // Numerator of the update, derived from AG.
    gpu::ops::custom::AgToNumerator(dctx, *d_ms.AG, numerator, config.n_classes)
        .call_async();

    // HH is GG squared entry by entry; the denominator is W times HH.
    gpu::ops::custom::ElwiseSquare(dctx, *d_ms.GG, *d_ms.HH).call_async();
    DnDnMatMul(dctx, *d_ms.W, *d_ms.HH, denominator, false, false).call_async();

    // Apply the update to the local partition of W.
    gpu::ops::custom::MultiplicativeUpdate(dctx, *d_ms.W, numerator, denominator, config.mu_eps)
        .call_async();
}


///////////////////////////////////////////////////////////////////////////////
// G-update stuff.


/// Enqueues the part of a G update that runs before the all-reduces.
///
/// @param opts Selects which of WW, GG and AG are recomputed first.
///
/// After the optional recomputations, this partition's contribution to the
/// second G-gradient term is formed: W_AG is computed from W and AG, and the
/// product of W_AG and A is stored into G_gradient with the scalars
/// `minus_1` and `dev_0f`.
///
/// NOTE(review): the two trailing scalar pointers are presumably the
/// GEMM-style alpha and beta (result = -1 * product + 0 * previous contents),
/// i.e. G_gradient is overwritten rather than accumulated -- confirm against
/// DnDnMatMul.
void DeviceWorker::update_G_step_pre_all_reduces_async(GUpdateOptions opts) {
    if (opts.recompute_WW) {
        compute_local_WW_async();
    }
    if (opts.recompute_partition_GG) {
        compute_partition_GG_async();
    }
    // AG feeds the local contribution to the second gradient term below.
    if (opts.recompute_partition_AG) {
        compute_partition_AG_async();
    }

    auto& W_AG = *d_ms.W_AG;

    gpu::ops::custom::Compute_W_AG(dctx, *d_ms.W, *d_ms.AG, W_AG, config.n_classes)
        .call_async();

    // Both operand flags are set for this product.
    DnDnMatMul(dctx, W_AG, *d_ms.A, *d_ms.G_gradient, true, true, d_ptrs.minus_1, dctx.dev_0f)
        .call_async();
}

/// Enqueues the all-reduces that belong to a G update.
///
/// @param opts WW is all-reduced only when `recompute_WW` is set; the
///     G-gradient is all-reduced unconditionally.
void DeviceWorker::update_G_step_all_reduces_async(GUpdateOptions opts) {
    if (opts.recompute_WW) {
        nccl_all_reduce_WW();
    }

    nccl_all_reduce_G_gradient();
}

/// Enqueues the part of a G update that runs after the all-reduces.
///
/// @param opts Supplies the learning rate of the gradient step.
///
/// The first gradient term is built from the Hadamard product of WW and GG
/// multiplied by G, and is added onto the contents already in G_gradient
/// (both scalar arguments are `dev_1f`). G is then moved along the gradient.
void DeviceWorker::update_G_step_post_all_reduces_async(GUpdateOptions opts) {
    auto& WW_GG = *d_ms.WW_GG;
    auto& gradient = *d_ms.G_gradient;

    // First gradient term, accumulated into the gradient buffer.
    gpu::ops::custom::HadamardProduct(dctx, *d_ms.WW, *d_ms.GG, WW_GG).call_async();
    DnDnMatMul(dctx, WW_GG, *d_ms.G, gradient, false, false, dctx.dev_1f, dctx.dev_1f)
        .call_async();

    // The true gradient carries a factor of 4 that the computation above
    // leaves out, so it is folded into the step size here.
    const auto step_size = 4.0f * opts.learning_rate;
    gpu::ops::custom::GradientDescentUpdate(dctx, *d_ms.G, gradient, step_size).call_async();
}

///////////////////////////////////////////////////////////////////////////////
// Loss computation stuff.


/// Enqueues the computation of this partition's two loss scalars.
///
/// @param opts Selects whether AG and/or GG are recomputed first.
///
/// The results are written to the device scalars `d_tr_WW_HH` and `d_tr_WHX`.
///
/// Side effect: WW is overwritten with this partition's local product, so it
/// no longer holds an all-reduced value afterwards. W_update_numerator is
/// used as scratch space and overwritten as well.
void DeviceWorker::compute_partition_loss_async(LossComputationOptions opts) {
    if (opts.recompute_partition_AG) {
        compute_partition_AG_async();
    }
    if (opts.recompute_partition_GG) {
        compute_partition_GG_async();
    }

    // tr_WW_HH: Frobenius inner product of the local WW with HH, where HH is
    // GG squared entry by entry.
    gpu::ops::custom::ElwiseSquare(dctx, *d_ms.GG, *d_ms.HH).call_async();
    // NOTE: from here on WW contains only this partition's contribution.
    compute_local_WW_async();
    FrobeniousInnerProduct(dctx, *d_ms.WW, *d_ms.HH, d_ptrs.d_tr_WW_HH).call_async();

    // tr_WHX: Frobenius inner product of W with the numerator derived from AG.
    auto& numerator = *d_ms.W_update_numerator;
    gpu::ops::custom::AgToNumerator(dctx, *d_ms.AG, numerator, config.n_classes)
        .call_async();
    FrobeniousInnerProduct(dctx, *d_ms.W, numerator, d_ptrs.d_tr_WHX).call_async();
}


/// Reads this partition's loss scalars back from the device and adds the
/// resulting loss contribution to a shared accumulator.
///
/// @param mtx  Mutex guarding `*loss`; held only while the accumulator is
///             updated. Must not be null.
/// @param loss Accumulator that receives `-2 * tr_WHX + tr_WW_HH`. Must not
///             be null.
///
/// Blocks until the stream is synchronized, because the values are needed on
/// the host before they can be added.
void DeviceWorker::read_partition_loss_from_device(std::mutex* mtx, float* loss) {
    dctx.set_device();

    // Zero-initialized so the host values are well defined even before the
    // asynchronous copies have landed.
    float tr_WW_HH = 0.0f;
    float tr_WHX = 0.0f;
    dctx.copy_to_host_async<float>(&tr_WW_HH, d_ptrs.d_tr_WW_HH, 1);
    dctx.copy_to_host_async<float>(&tr_WHX, d_ptrs.d_tr_WHX, 1);
    // The copies target stack variables, so they must finish before use.
    dctx.synchronize_stream();

    const float partition_loss = -2.0f * tr_WHX + tr_WW_HH;

    // RAII lock: released on every exit path, unlike a manual lock()/unlock() pair.
    std::lock_guard<std::mutex> guard(*mtx);
    *loss += partition_loss;
}


///////////////////////////////////////////////////////////////////////////////
// Reading data back to host stuff.


/// Enqueues a copy of this partition's W from the device to the host.
///
/// @param host_write_location Destination for the entries of W. The copy is
///     asynchronous, so the data is only valid after the stream has been
///     synchronized.
void DeviceWorker::read_W_from_gpu_async(float* host_write_location) {
    auto& device_W = *d_ms.W;
    read_matrix_from_gpu_async(device_W, host_write_location);
}


/// Enqueues a copy of G from the device to the host.
///
/// @param host_write_location Destination for the entries of G. The copy is
///     asynchronous, so the data is only valid after the stream has been
///     synchronized.
void DeviceWorker::read_G_from_gpu_async(float* host_write_location) {
    auto& device_G = *d_ms.G;
    read_matrix_from_gpu_async(device_G, host_write_location);
}



///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// Private member functions.


/// Creates the device-side scalars used by the worker.
///
/// Two kinds of scalars are set up:
///   - result slots for the loss terms: a single two-float allocation whose
///     first element is `d_tr_WW_HH` and whose second element is `d_tr_WHX`;
///   - constants: currently only -1, reachable through `d_ptrs.minus_1`.
void DeviceWorker::write_scalars_to_gpu() {
    // Result slots: both traces share one allocation, side by side.
    d_ptrs.d_tr_WW_HH = dctx.dmalloc<float>(2);
    d_ptrs.d_tr_WHX = d_ptrs.d_tr_WW_HH + 1;

    // Constants: allocate the device array and upload the host values.
    constexpr int64_t n_constants = 1;
    float host_constants[n_constants] = {-1.0f};

    float* d_constants = dctx.dmalloc<float>(n_constants);
    d_ptrs.minus_1 = d_constants;

    // NOTE(review): the source of this asynchronous copy is a stack array that
    // goes out of scope when the function returns. This relies on the copy
    // having consumed the host buffer by then -- confirm that
    // copy_to_device_async guarantees this for non-pinned host memory.
    dctx.copy_to_device_async(d_constants, host_constants, n_constants);
}


/// Allocates every device buffer the worker needs and wraps each buffer in a
/// gpu::DenseMatrix of the appropriate shape.
///
/// Sizes come from the host partition (examples, classes, parameters) and
/// from `config.rank`. Nothing is written into the buffers here.
///
/// Buffer sharing: `W_update_denominator` (examples x rank) and `W_AG`
/// (classes*examples x rank) are both views onto `d_ncr2`, so only one of
/// them holds meaningful data at any time.
void DeviceWorker::allocate_and_create_device_matrices() {
    const int64_t n_examples = host_partition->n_partition_examples();
    const int64_t n_classes = host_partition->n_classes();
    const int64_t n_parameters = host_partition->n_parameters();
    const int64_t rank = config.rank;

    // One row per (example, class) pair.
    const int64_t n_partition_rows = n_examples * n_classes;

    // Entry counts of the distinct buffer shapes.
    const int64_t n_A_entries = n_partition_rows * n_parameters;
    const int64_t n_nr_entries = n_examples * rank;
    const int64_t n_ncr_entries = n_classes * n_examples * rank;
    const int64_t n_rr_entries = rank * rank;
    const int64_t n_rm_entries = rank * n_parameters;

    /////////////////////////////////////////
    // Device allocations.

    d_ptrs.d_A = dctx.dmalloc<float>(n_A_entries);

    d_ptrs.d_W = dctx.dmalloc<float>(n_nr_entries);
    d_ptrs.d_G = dctx.dmalloc<float>(n_rm_entries);

    d_ptrs.d_ncr1 = dctx.dmalloc<float>(n_ncr_entries);
    d_ptrs.d_ncr2 = dctx.dmalloc<float>(n_ncr_entries);
    d_ptrs.d_nr = dctx.dmalloc<float>(n_nr_entries);
    d_ptrs.d_rr1 = dctx.dmalloc<float>(n_rr_entries);
    d_ptrs.d_rr2 = dctx.dmalloc<float>(n_rr_entries);
    d_ptrs.d_rr3 = dctx.dmalloc<float>(n_rr_entries);
    d_ptrs.d_rr4 = dctx.dmalloc<float>(n_rr_entries);
    d_ptrs.d_rm = dctx.dmalloc<float>(n_rm_entries);

    /////////////////////////////////////////
    // Matrix views over the allocations.

    // Inputs and parameters.
    // NOTE: A is described with its dimensions swapped relative to the
    // original matrix (parameters x rows instead of rows x parameters).
    d_ms.A = gpu::DenseMatrix::make_unique_ptr(n_parameters, n_partition_rows, d_ptrs.d_A);
    d_ms.W = gpu::DenseMatrix::make_unique_ptr(n_examples, rank, d_ptrs.d_W);
    d_ms.G = gpu::DenseMatrix::make_unique_ptr(rank, n_parameters, d_ptrs.d_G);

    // Intermediates shared between the steps.
    d_ms.WW = gpu::DenseMatrix::make_unique_ptr(rank, rank, d_ptrs.d_rr1);
    d_ms.GG = gpu::DenseMatrix::make_unique_ptr(rank, rank, d_ptrs.d_rr2);
    d_ms.HH = gpu::DenseMatrix::make_unique_ptr(rank, rank, d_ptrs.d_rr3);
    d_ms.AG = gpu::DenseMatrix::make_unique_ptr(n_partition_rows, rank, d_ptrs.d_ncr1);

    // Intermediates of the W step. The denominator only occupies the leading
    // examples*rank entries of the larger d_ncr2 buffer.
    d_ms.W_update_numerator = gpu::DenseMatrix::make_unique_ptr(n_examples, rank, d_ptrs.d_nr);
    d_ms.W_update_denominator = gpu::DenseMatrix::make_unique_ptr(n_examples, rank, d_ptrs.d_ncr2);

    // Intermediates of the G step. W_AG reuses d_ncr2 as well.
    d_ms.WW_GG = gpu::DenseMatrix::make_unique_ptr(rank, rank, d_ptrs.d_rr4);
    d_ms.W_AG = gpu::DenseMatrix::make_unique_ptr(n_partition_rows, rank, d_ptrs.d_ncr2);
    d_ms.G_gradient = gpu::DenseMatrix::make_unique_ptr(rank, n_parameters, d_ptrs.d_rm);
}


/// Gives the freshly allocated device matrices their starting contents.
///
///   - A is uploaded from the host partition's PEFs matrix.
///   - W is filled with uniform random values (curandGenerateUniform).
///   - G is filled with normal random values (mean 0, standard deviation
///     1 / inv_g_factor), but only on the worker whose partition_index is 0.
///
/// The stream is synchronized before returning.
void DeviceWorker::initialize_device_matrices() {
    dctx.set_device();


    // Move A onto the GPU.
    auto pefs_matrix = host_partition->pefs_matrix();
    dctx.copy_to_device_async(*d_ms.A, pefs_matrix);

    // The bare #warning directives below make the compiler emit a diagnostic
    // so that the open TODO between them is not forgotten.
    #warning
    // TODO: The random generators on all devices will probably produce the same
    // sequence for W. Find an initialization that gives the same overall W
    // regardless of the number of devices.
    #warning

    // Fill this partition's W with uniform random values.
    auto& W = d_ms.W;
    CURAND_CALL(
        curandGenerateUniform(dctx.rand_gen, (float*) W->data, W->n_rows * W->n_cols)
    );

    // If we are the first partition, initialize G.
    // NOTE(review): cuRAND documents that normal generation with pseudorandom
    // generators requires an even element count; an odd rank * n_parameters
    // would presumably make this call fail -- confirm for the generator in use.
    if(partition_index == 0) {
        auto& G = d_ms.G;
        double inv_g_factor = config.compute_inv_g_initialization_scale_factor();
        CURAND_CALL(
            curandGenerateNormal(dctx.rand_gen, (float*) G->data, G->n_rows * G->n_cols, 0.0f, 1.0 / inv_g_factor)
        );
    }

    // Synchronize the stream to ensure that everything associated with A
    // has been copied onto the GPU before the host-side source goes away.
    dctx.synchronize_stream();
}


/// Enqueues a device-to-host copy of all entries of a device matrix.
///
/// @param matrix              Device matrix to read; `n_entries` floats are copied.
/// @param host_write_location Host destination with room for that many floats.
///
/// The copy is asynchronous; the host data is only valid once the stream has
/// been synchronized.
void DeviceWorker::read_matrix_from_gpu_async(gpu::DenseMatrix& matrix, float* host_write_location) {
    dctx.set_device();

    float* device_entries = (float*) matrix.data;
    dctx.copy_to_host_async(host_write_location, device_entries, matrix.n_entries);
}


}  // dn_lrm_factorization2
}  // factorizations
}  // npeff
