
#include <vector>
#include "./manager.h"


namespace npeff {
namespace factorizations {
namespace dn_lrm_factorization {


// Owning handle to a single DeviceWorker; the manager stores these in `workers`.
using DeviceWorkerPtr = std::unique_ptr<DeviceWorker>;
// Owning handle to a host-side dense float matrix.
using DenseMatrixPtr = std::unique_ptr<npeff::DenseMatrix<float>>;


///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// Public member functions.


// Reads the W matrix held by one worker back into a freshly allocated host
// matrix.
//
// src_device: index into `workers` of the worker to read from. Must satisfy
//     0 <= src_device < workers.size(); otherwise THROW_IF_FALSE fires.
// Returns: a host matrix constructed as DenseMatrix<float>(n_examples, config.rank),
//     handed back only after the source worker's stream has been synchronized.
DenseMatrixPtr MultiGpuManager::read_W_from_gpu(int64_t src_device) {
    // Indexing `workers` out of range is undefined behavior, so reject bad
    // device indices before touching the vector.
    THROW_IF_FALSE(src_device >= 0 && src_device < static_cast<int64_t>(workers.size()));

    DeviceWorkerPtr& src_worker = workers[src_device];
    DenseMatrixPtr W(new DenseMatrix<float>(n_examples, config.rank));

    // The read is issued through the worker's `_async` entry point, so wait on
    // that worker's stream before giving the buffer to the caller.
    src_worker->read_W_from_gpu_async(W->data.get());
    src_worker->synchronize_stream();
    return W;
}


// Gathers G from every worker into one host matrix constructed as
// DenseMatrix<float>(config.rank, n_cols_total).
//
// Worker i is handed a destination pointer that starts right after the
// config.rank * n_cols_per_partition[j] floats reserved for each earlier
// worker j < i. All streams are synchronized before the matrix is returned.
DenseMatrixPtr MultiGpuManager::read_G_from_gpu() {
    DenseMatrixPtr G(new DenseMatrix<float>(config.rank, n_cols_total));

    float* write_cursor = G->data.get();
    for (int64_t partition = 0; partition < n_partitions; ++partition) {
        workers[partition]->read_G_from_gpu_async(write_cursor);
        // Advance past the slice that belongs to this partition.
        write_cursor += config.rank * n_cols_per_partition[partition];
    }

    synchronize_all_streams();
    return G;
}


// Runs the whole optimization: prepares the workers, then runs the G-only
// phase followed by the joint phase, in that order.
void MultiGpuManager::run() {
    get_workers_ready();
    run_G_only();
    run_joint();
}


///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// Private member functions.


// First phase: performs config.n_iters_G_only G updates (each one recomputing
// AG/GG, using config.learning_rate_G_G_only), logging the loss into
// `losses_G_only` on the steps selected by should_log_loss_at_step().
void MultiGpuManager::run_G_only() {
    // Start the timer used by log_loss() for its ms/step report.
    t_step_start = std::chrono::high_resolution_clock::now();

    int64_t step = 0;
    while (step < config.n_iters_G_only) {
        G_update_step(true, config.learning_rate_G_G_only);
        if (should_log_loss_at_step(step)) {
            log_loss("G only", step, losses_G_only);
        }
        ++step;
    }
}


// Second phase: performs config.n_iters_joint iterations, each consisting of a
// W update (recomputing AG/GG) followed by a G update that does not recompute
// them (using config.learning_rate_G_joint). The loss is logged into
// `losses_joint` on the steps selected by should_log_loss_at_step().
void MultiGpuManager::run_joint() {
    // Start the timer used by log_loss() for its ms/step report.
    t_step_start = std::chrono::high_resolution_clock::now();

    for (int64_t step = 0; step < config.n_iters_joint; ++step) {
        W_update_step(true);
        G_update_step(false, config.learning_rate_G_joint);

        if (!should_log_loss_at_step(step)) {
            continue;
        }
        log_loss("joint", step, losses_joint);
    }
}


// Decides whether the loss should be logged after the given step.
//
// step: zero-based step index within the current phase.
// Returns: true when (step + 1) is a multiple of config.log_loss_frequency.
//     A frequency of zero is treated as "never log" rather than being used as
//     a divisor, since integer modulo by zero is undefined behavior.
bool MultiGpuManager::should_log_loss_at_step(int64_t step) const {
    if (config.log_loss_frequency == 0) {
        return false;
    }
    return (step + 1) % config.log_loss_frequency == 0;
}


// Computes the current loss, appends it to `losses`, and prints one progress
// line with the average wall-clock time per step since the timer was last reset.
//
// prefix: label printed at the start of the line (e.g. the phase name).
// step:   zero-based step index; printed as step + 1.
// losses: history vector that receives the newly computed loss.
void MultiGpuManager::log_loss(const std::string& prefix, int64_t step, std::vector<float>& losses) {
    using Clock = std::chrono::high_resolution_clock;
    using Milliseconds = std::chrono::duration<double, std::milli>;

    // Stop the clock before evaluating the loss so that the loss computation
    // is not counted in the reported step time.
    const Clock::time_point t_end = Clock::now();
    const double elapsed_ms = Milliseconds(t_end - t_step_start).count();
    const double ms_per_step = elapsed_ms / static_cast<double>(config.log_loss_frequency);

    // NOTE: For `run_joint`, the all-reduces could probably be skipped when
    // computing the loss right after the W update step. For simplicity and
    // robustness, the all-reduces are always performed for now.
    const double loss = compute_loss(true);
    losses.push_back(loss);

    std::cout << prefix << " step " << (step + 1) << ": " << loss
              << " [" << ms_per_step << " ms/step]\n";

    // Restart the timer only after logging, so logging time is excluded too.
    t_step_start = Clock::now();
}


///////////////////////////////////////////////////////////////////////////////
// Set up and sub-steps of the optimization.


// Prepares every worker for optimization: runs DeviceWorker::set_up_work on
// all workers, then broadcasts W and waits for all streams to finish.
void MultiGpuManager::get_workers_ready() {
    call_on_workers_then_join(&DeviceWorker::set_up_work);

    // Replicate the W from device 0 on all of the devices. This
    // ensures consistency of the shared parameters across devices.
    // NOTE(review): the source device comes from the default argument of
    // broadcast_W_async declared in the header — confirm it is 0.
    broadcast_W_async();
    synchronize_all_streams();
}


// Performs one update of W on every worker.
//
// recompute_AG_GG: when true, each worker first runs compute_local_AG_GG_async
//     and the results are combined with all_reduce_AG_GG_async(); when false
//     the update runs against whatever AG/GG the workers already hold.
// All streams are synchronized before returning.
void MultiGpuManager::W_update_step(bool recompute_AG_GG) {
    if(recompute_AG_GG) {
        call_on_workers_then_join(&DeviceWorker::compute_local_AG_GG_async);
        all_reduce_AG_GG_async();
    }
    call_on_workers_then_join(&DeviceWorker::update_local_W_after_all_reduces_async);
    synchronize_all_streams();
}


// Performs one update of G on every worker.
//
// recompute_AG_GG: when true, each worker first runs compute_local_AG_GG_async
//     and the results are combined with all_reduce_AG_GG_async(); when false
//     the update runs against whatever AG/GG the workers already hold.
// learning_rate_G: forwarded unchanged to each worker's
//     update_local_G_after_all_reduces_async.
// All streams are synchronized before returning.
void MultiGpuManager::G_update_step(bool recompute_AG_GG, float learning_rate_G) {
    if(recompute_AG_GG) {
        call_on_workers_then_join(&DeviceWorker::compute_local_AG_GG_async);
        all_reduce_AG_GG_async();
    }
    call_on_workers_then_join(&DeviceWorker::update_local_G_after_all_reduces_async, learning_rate_G);
    synchronize_all_streams();
}


// Computes the current loss value.
//
// recompute_AG_GG: when true, each worker first runs compute_local_AG_GG_async
//     and the results are combined with all_reduce_AG_GG_async().
// Returns: the loss term read back from worker 0 plus config.tr_xx.
float MultiGpuManager::compute_loss(bool recompute_AG_GG) {
    if(recompute_AG_GG) {
        call_on_workers_then_join(&DeviceWorker::compute_local_AG_GG_async);
        all_reduce_AG_GG_async();
    }
    // We only need to compute the loss using a single GPU.
    auto& loss_worker = workers[0];
    loss_worker->compute_loss_after_all_reduces_async();
    // NOTE(review): the loss is read before synchronize_all_streams() below.
    // This is only correct if read_loss_term_from_device() itself waits for the
    // async loss computation to finish — confirm against DeviceWorker.
    float loss_term = loss_worker->read_loss_term_from_device();
    synchronize_all_streams();
    return loss_term + config.tr_xx;
}


///////////////////////////////////////////////////////////////////////////////
// Initializations.


int64_t MultiGpuManager::compute_n_examples(const std::vector<DenseMatrixPtr>& column_partitions) const {
    int64_t n_rows = column_partitions[0]->n_rows;
    THROW_IF_FALSE((n_rows % config.n_classes) == 0);
    return n_rows / config.n_classes;
}


int64_t MultiGpuManager::compute_n_cols_total(const std::vector<DenseMatrixPtr>& column_partitions) const {
    int64_t n_cols = 0;
    for(auto& mat : column_partitions) { n_cols += mat->n_cols; }
    return n_cols;
}


std::vector<int64_t> MultiGpuManager::compute_n_cols_per_partition(const std::vector<DenseMatrixPtr>& column_partitions) const {
    std::vector<int64_t> ret;
    for(auto& mat : column_partitions) {
        ret.push_back(mat->n_cols);
    }
    return ret;
}


void MultiGpuManager::initialize_nccl() {
    comms = std::unique_ptr<ncclComm_t>(new ncclComm_t[n_partitions]);
    NCCL_CALL(ncclCommInitAll(comms.get(), n_partitions, NULL));
}


void MultiGpuManager::create_workers(std::vector<DenseMatrixPtr>& column_partitions) {
    for (int64_t i=0; i<n_partitions; i++) {
        workers.push_back(
            DeviceWorkerPtr(new DeviceWorker(
                std::move(column_partitions[i]),
                config,
                i,
                n_partitions,
                comms.get()[i]
            )));
    }
}


///////////////////////////////////////////////////////////////////////////////
// Utilities for dealing with the multiple workers.


void MultiGpuManager::broadcast_W_async(int64_t src_device) {
    DeviceWorkerPtr& src_worker = workers[src_device];
    NCCL_CALL(ncclGroupStart());
    for(auto& worker : workers) {
        worker->nccl_broadcast_of_W(*src_worker);
    }
    NCCL_CALL(ncclGroupEnd());
}


void MultiGpuManager::all_reduce_AG_GG_async() {
    NCCL_CALL(ncclGroupStart());
    for(auto& worker : workers) {
        worker->nccl_all_reduce_AG_GG();
    }
    NCCL_CALL(ncclGroupEnd());
}


void MultiGpuManager::synchronize_all_streams() {
    for(auto& worker : workers) {
        worker->synchronize_stream();
    }
}


// Joins every thread in `threads`, in order.
//
// threads: the threads to wait for; each one must be joinable.
void MultiGpuManager::join_threads(std::vector<std::thread>& threads) {
    for (auto it = threads.begin(); it != threads.end(); ++it) {
        it->join();
    }
}


}  // dn_lrm_factorization
}  // factorizations
}  // npeff
