#include <mutex>

#include "./manager.h"


namespace npeff {
namespace factorizations {
namespace dn_lrm_factorization2 {


// Shorthand for a uniquely-owned dense float matrix; used for the host-side
// copies of W and G that this file hands back to callers.
using DenseMatrixPtr = std::unique_ptr<npeff::DenseMatrix<float>>;


///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// Public member functions.


// Sets up every worker and makes all devices agree on the shared matrix G.
//
// initial_G: optional host-side starting value for G. When it is nullptr, no
//     upload happens and whatever G the source device of the broadcast holds
//     after worker set-up is what gets replicated.
//
// After this returns, `initialized` is true and run() may be called.
void MultiGpuManager::initialize(npeff::DenseMatrix<float>* initial_G) {
    call_on_workers_then_join(&DeviceWorker::set_up_work);

    const bool has_host_G = (initial_G != nullptr);
    if (has_host_G) {
        // Upload to the first device only; the broadcast below hands the
        // value to the remaining devices.
        DeviceWorkerPtr& first_worker = workers[0];
        first_worker->set_G_from_host(initial_G);
        synchronize_all_streams();
    }

    // Replicate G across all of the devices so that the shared parameters
    // start out consistent everywhere.
    broadcast_G_async();
    synchronize_all_streams();

    initialized = true;
}


// Runs the whole optimization: first the G-only phase, then the joint phase.
// Reports an error through THROW_MSG when initialize() has not been called yet.
void MultiGpuManager::run() {
    if (!initialized) {
        THROW_MSG("The initialize() method must be called before the run() method on a MultiGpuManager.");
    }

    run_G_only();
    run_joint();
}


std::unique_ptr<outputs::WPartitions> MultiGpuManager::read_W_from_gpu() {
    std::vector<DenseMatrixPtr> W_partitions;

    for(int64_t i=0; i<n_partitions; i++) {
        DenseMatrixPtr W = std::unique_ptr<DenseMatrix<float>>(
            new DenseMatrix<float>(output_info.n_examples_per_partition[i], output_info.rank));
        workers[i]->read_W_from_gpu_async(W->data.get());
        W_partitions.push_back(std::move(W));
    }

    synchronize_all_streams();

    return std::unique_ptr<outputs::WPartitions>(new outputs::WPartitions(std::move(W_partitions)));
}


// Reads G back from a single worker.
//
// src_device: index into `workers` of the worker to read from. It is not
//     range-checked here.
//
// Returns a host matrix constructed with (output_info.rank,
// output_info.n_parameters); the source worker's stream has been synchronized
// by the time it is returned.
DenseMatrixPtr MultiGpuManager::read_G_from_gpu(int64_t src_device) {
    DenseMatrixPtr host_G(new DenseMatrix<float>(output_info.rank, output_info.n_parameters));

    DeviceWorker& source = *workers[src_device];
    source.read_G_from_gpu_async(host_G->data.get());
    // Only the source worker has work in flight, so only its stream is awaited.
    source.synchronize_stream();

    return host_G;
}


///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// Private member functions.


void MultiGpuManager::run_G_only() {
    // TODO: Having the recompute_partition_{AG, GG} set to true for both the
    // loss computation and G update leads to them being needlessly recomputed
    // after each loss update.
    // 
    // I'm currently doing this to make timing more robust, but I might want to
    // change this in the future.
    DeviceWorker::LossComputationOptions opts_loss;
    opts_loss.recompute_partition_AG = true;
    opts_loss.recompute_partition_GG = true;

    DeviceWorker::GUpdateOptions opts_G;
    opts_G.learning_rate = config.learning_rate_G_G_only;
    opts_G.recompute_partition_AG = true;
    opts_G.recompute_partition_GG = true;
    // Needed for at least the first iteration.
    opts_G.recompute_WW = true;

    t_step_start = std::chrono::high_resolution_clock::now();
    for (int64_t step = 0; step < config.n_iters_G_only; step++) {

        G_update_step(opts_G);

        if(should_log_loss_at_step(step)) {
            log_loss("G only", step, losses_G_only, opts_loss);
            // The loss computation invalidates the WW matrix, so we need to recompute it.
            // TODO: It's probably possible to get rid of this by using some unused buffer
            // store the valid copy of WW.
            opts_G.recompute_WW = true;
        } else {
            // Since the coefficients W are not updated and the WW matrix was not invalidated,
            // we do not need to update it for the next run.
            opts_G.recompute_WW = false;
        }
    }
}


// Second optimization phase: performs config.n_iters_joint iterations, each a
// G-update step followed by a W-update step. The loss is logged into
// losses_joint on the steps selected by should_log_loss_at_step().
void MultiGpuManager::run_joint() {
    DeviceWorker::LossComputationOptions opts_loss;
    // These are set to false as they will have been computed during the W-update
    // step preceding the loss computation.
    opts_loss.recompute_partition_AG = false;
    opts_loss.recompute_partition_GG = false;

    DeviceWorker::GUpdateOptions opts_G;
    opts_G.learning_rate = config.learning_rate_G_joint;
    opts_G.recompute_WW = true;
    // Needed for only the first iteration since the W-update step will otherwise
    // take care of this.
    opts_G.recompute_partition_AG = true;
    opts_G.recompute_partition_GG = true;

    DeviceWorker::WUpdateOptions opts_W;
    // The W-update runs right after G has changed, so it is the step that must
    // refresh the per-partition AG and GG products. These flags were previously
    // written to opts_loss by mistake, which silently undid the `false` settings
    // above and left opts_W at its defaults.
    opts_W.recompute_partition_AG = true;
    opts_W.recompute_partition_GG = true;

    // NOTE: Unlike the original version, I'm doing the G-update step before the W-update step.
    // This is so that we do not induce an extra all-reduce to recompute the WW after it gets
    // invalidated during the W-update step.
    t_step_start = std::chrono::high_resolution_clock::now();
    for (int64_t step = 0; step < config.n_iters_joint; step++) {
        G_update_step(opts_G);
        W_update_step(opts_W);
        if(should_log_loss_at_step(step)) {
            log_loss("joint", step, losses_joint, opts_loss);
        }
        // From the second iteration on, the preceding W-update step has already
        // refreshed AG and GG for the current G.
        opts_G.recompute_partition_AG = false;
        opts_G.recompute_partition_GG = false;
    }
}


// Decides whether the loss should be logged after the given step.
//
// step: zero-based index of the step that just finished.
//
// Returns true when (step + 1) is a multiple of config.log_loss_frequency.
// A frequency of zero means "never log"; without this guard the modulo below
// would be an integer division by zero, which is undefined behaviour.
bool MultiGpuManager::should_log_loss_at_step(int64_t step) const {
    if (config.log_loss_frequency == 0) {
        return false;
    }
    return (step + 1) % config.log_loss_frequency == 0;
}


// Computes the current loss, records it, and prints a progress line.
//
// prefix: label printed at the start of the line.
// step:   zero-based step index; printed one-based.
// losses: history vector that the new loss value is appended to.
// opts:   options forwarded to compute_loss().
//
// The reported ms/step is the time since t_step_start divided by
// config.log_loss_frequency. t_step_start is reset at the end so that the
// next measurement starts after the loss computation and printing.
void MultiGpuManager::log_loss(
    const std::string& prefix, int64_t step, std::vector<float>& losses, DeviceWorker::LossComputationOptions opts
) {
    using Clock = std::chrono::high_resolution_clock;
    using Millis = std::chrono::duration<double, std::milli>;

    // Stop the timer before computing the loss so that the loss computation
    // is not counted towards the step time.
    const Clock::time_point t_end = Clock::now();
    const double elapsed_ms = Millis(t_end - t_step_start).count();

    const double loss = compute_loss(opts);
    losses.push_back(loss);

    const double ms_per_step = elapsed_ms / static_cast<double>(config.log_loss_frequency);
    std::cout << prefix << " step " << (step + 1) << ": " << loss << " [" << ms_per_step << " ms/step]\n";

    t_step_start = Clock::now();
}



///////////////////////////////////////////////////////////////////////////////
// Sub-steps of the optimization.

// One W-update: dispatches update_partition_W_async to every worker with the
// given options, then waits for all of the worker streams.
void MultiGpuManager::W_update_step(DeviceWorker::WUpdateOptions opts) {
    this->call_on_workers_then_join(&DeviceWorker::update_partition_W_async, opts);

    // Do not return until each worker's stream has drained.
    this->synchronize_all_streams();
}

// One G-update, carried out in three stages that all receive the same options:
//   1. per-worker work that precedes the all-reduces,
//   2. the all-reduces themselves, issued together for all workers,
//   3. per-worker work that follows the all-reduces.
// All of the worker streams are synchronized before returning.
void MultiGpuManager::G_update_step(DeviceWorker::GUpdateOptions opts) {
    this->call_on_workers_then_join(&DeviceWorker::update_G_step_pre_all_reduces_async, opts);

    this->G_update_all_reduce_async(opts);

    this->call_on_workers_then_join(&DeviceWorker::update_G_step_post_all_reduces_async, opts);

    this->synchronize_all_streams();
}

// Computes the loss over all partitions.
//
// opts: options forwarded to every worker's compute_partition_loss_async.
//
// The accumulator starts at config.tr_xx. Each worker's
// read_partition_loss_from_device is then handed a pointer to the accumulator
// together with a mutex; presumably each worker adds its partition's
// contribution under that lock (confirm against DeviceWorker).
float MultiGpuManager::compute_loss(DeviceWorker::LossComputationOptions opts) {
    call_on_workers_then_join(&DeviceWorker::compute_partition_loss_async, opts);

    float total_loss = config.tr_xx;
    std::mutex total_loss_guard;
    call_on_workers_then_join(
        &DeviceWorker::read_partition_loss_from_device, &total_loss_guard, &total_loss);

    return total_loss;
}


///////////////////////////////////////////////////////////////////////////////
// Initializations.


// Collects the sizes needed to allocate host-side output matrices.
//
// partitions: the data partitions, one per worker. Must be non-empty, because
//     the parameter count is taken from the first partition.
//
// Returns a struct holding the total example count and rank from `config`,
// the parameter count of partitions[0], and the example count of every
// partition in order.
//
// Reports an error through THROW_MSG when `partitions` is empty; previously
// this case indexed element 0 of an empty vector.
internal::MultiGpuManagerOutputInfo MultiGpuManager::make_output_info(const std::vector<DnLrmPefsPartitionPtr>& partitions) const {
    if (partitions.empty()) {
        THROW_MSG("At least one partition is required to construct the output info of a MultiGpuManager.");
    }

    internal::MultiGpuManagerOutputInfo ret;
    ret.n_examples_total = config.n_examples_total;
    ret.rank = config.rank;
    ret.n_parameters = partitions[0]->n_parameters();
    for(const auto& p : partitions) {
        ret.n_examples_per_partition.push_back(p->n_partition_examples());
    }
    return ret;
}


void MultiGpuManager::initialize_nccl() {
    comms = std::unique_ptr<ncclComm_t>(new ncclComm_t[n_partitions]);
    NCCL_CALL(ncclCommInitAll(comms.get(), n_partitions, NULL));
}


void MultiGpuManager::create_workers(std::vector<DnLrmPefsPartitionPtr>& partitions) {
    for (int64_t i=0; i<n_partitions; i++) {
        workers.push_back(
            DeviceWorkerPtr(new DeviceWorker(
                std::move(partitions[i]),
                config,
                i,
                n_partitions,
                comms.get()[i]
            )));
    }
}


///////////////////////////////////////////////////////////////////////////////
// Utilities for dealing with the multiple workers.


void MultiGpuManager::broadcast_G_async(int64_t src_device) {
    DeviceWorkerPtr& src_worker = workers[src_device];
    NCCL_CALL(ncclGroupStart());
    for(auto& worker : workers) {
        worker->nccl_broadcast_of_G(*src_worker);
    }
    NCCL_CALL(ncclGroupEnd());
}


void MultiGpuManager::all_reduce_WW_async() {
    NCCL_CALL(ncclGroupStart());
    for(auto& worker : workers) {
        worker->nccl_all_reduce_WW();
    }
    NCCL_CALL(ncclGroupEnd());
}

// Enqueues the all-reduces of a G-update step on every worker as a single
// NCCL group.
//
// opts: the G-update options, forwarded unchanged to each worker.
//
// No stream is synchronized here.
void MultiGpuManager::G_update_all_reduce_async(DeviceWorker::GUpdateOptions opts) {
    NCCL_CALL(ncclGroupStart());
    for (DeviceWorkerPtr& participant : workers) {
        participant->update_G_step_all_reduces_async(opts);
    }
    NCCL_CALL(ncclGroupEnd());
}


void MultiGpuManager::synchronize_all_streams() {
    for(auto& worker : workers) {
        worker->synchronize_stream();
    }
}


// Joins every thread in `threads`, in order. Each thread must be joinable;
// std::thread::join throws std::system_error otherwise.
void MultiGpuManager::join_threads(std::vector<std::thread>& threads) {
    for (auto it = threads.begin(); it != threads.end(); ++it) {
        it->join();
    }
}


}  // dn_lrm_factorization2
}  // factorizations
}  // npeff
