#pragma once

#include <chrono>
#include <cstddef>
#include <cstdint>
#include <ctime>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <vector>

#include "nccl.h"

#include <util/macros.h>
#include <containers/dense_matrix.h>
#include <containers/sparse_matrix.h>

#include "./config.h"
#include "./device_worker.h"

namespace npeff {
namespace expansion {


template<typename IndT>
class MultiGpuManager {
    using DeviceWorkerPtr = std::unique_ptr<DeviceWorker<IndT>>;
    using CsrMatrixPtr = std::unique_ptr<CsrMatrix<IndT>>;
    using DenseMatrixPtr = std::unique_ptr<DenseMatrix<float>>;

    ExpansionConfig config;
    int64_t n_partitions;

    // Can be empty. Will be cleared once the data has been moved to the devices.
    DenseMatrixPtr initial_W_f;

    // The array of comms used for nccl.
    std::unique_ptr<ncclComm_t> comms;

    // A vector of workers. Each is assigned to a unique device.
    std::vector<DeviceWorkerPtr> workers;

    // Information about the size of the matrix and its partitions.
    const int64_t n_examples;
    const int64_t n_cols_total;
    const std::vector<int64_t> n_cols_per_partition;

public:

    // Collects losses whenever they get computed.
    std::vector<float> losses_G_only;
    std::vector<float> losses_joint;

    MultiGpuManager(
        std::vector<CsrMatrixPtr>& column_partitions,
        std::vector<DenseMatrixPtr>& frozen_G_partitions,
        DenseMatrixPtr initial_W_f,
        ExpansionConfig config
    ) :
        config(config),
        initial_W_f(std::move(initial_W_f)),
        n_partitions(column_partitions.size()),
        n_examples(compute_n_examples(column_partitions)),
        n_cols_total(compute_n_cols_total(column_partitions)),
        n_cols_per_partition(compute_n_cols_per_partition(column_partitions))
    {
        verify_compatability_of_partitions(column_partitions, frozen_G_partitions);
        initialize_nccl();
        create_workers(column_partitions, frozen_G_partitions);
    }

    ~MultiGpuManager() {
        // Have this check in place so that we do not try to
        // pass junk addresses to ncclCommDestroy in case we
        // did not allocate the comms array at the time of
        // desctruction.
        if(comms) {
            for (int64_t i=0; i<n_partitions; i++) {
                ncclCommDestroy(comms.get()[i]);
            }
        }
    }


    DenseMatrixPtr read_W_from_gpu(int64_t src_device = 0) {
        DeviceWorkerPtr& src_worker = workers[src_device];
        auto* d_W = src_worker->get_W_ptr();
        DenseMatrixPtr W = std::unique_ptr<DenseMatrix<float>>(
            new DenseMatrix<float>(n_examples, config.total_rank()));
        src_worker->dctx.copy_to_host_async(W->data.get(), (float*) d_W->data, d_W->n_entries);
        src_worker->synchronize_stream();
        return W;
    }

    DenseMatrixPtr read_G_from_gpu() {
        DenseMatrixPtr G = std::unique_ptr<DenseMatrix<float>>(
            new DenseMatrix<float>(config.total_rank(), n_cols_total));

        int64_t col_offset = 0;
        for(int64_t i=0; i<n_partitions; i++) {
            auto& worker = workers[i];

            copy_to_host_into_submatrix_async(
                worker->dctx,
                *worker->get_G_e_ptr(),
                *G,
                0,
                col_offset
            );

            copy_to_host_into_submatrix_async(
                worker->dctx,
                *worker->get_G_f_ptr(),
                *G,
                config.rank_expansion,
                col_offset
            );

            col_offset += n_cols_per_partition[i];
        }

        synchronize_all_streams();
        return G;
    }

    /////////////////////////////////////////////////////////////////
    // Highest level function for "main loop" of optimization.

    void run() {
        get_workers_ready();
        std::cout << "Starting to precompute constants on workers.\n";
        precompute_constants_on_workers();
        std::cout << "Constants precomputed on workers.\n";
        run_G_only();
        run_expansion_only();
        run_joint();
    }

protected:

    // For computing step times.
    std::chrono::high_resolution_clock::time_point t_step_start;

    void run_G_only() {
        t_step_start = std::chrono::high_resolution_clock::now();
        for (int64_t step = 0; step < config.n_iters_G_only; step++) {
            G_update_step(true, config.learning_rate_G_G_only);
            if(should_log_loss_at_step(step)) {
                log_loss("G only", step, losses_G_only);
            }
        }
    }

    void run_expansion_only() {
        t_step_start = std::chrono::high_resolution_clock::now();
        for (int64_t step = 0; step < config.n_iters_joint_expansion_only; step++) {
            W_update_step(true, true);
            G_update_step(false, config.learning_rate_G_joint);
            if(should_log_loss_at_step(step)) {
                log_loss("joint_expansion_only", step, losses_joint);
            }
        }
    }

    void run_joint() {
        t_step_start = std::chrono::high_resolution_clock::now();
        for (int64_t step = 0; step < config.n_iters_joint; step++) {
            W_update_step(true, false);
            G_update_step(false, config.learning_rate_G_joint);
            if(should_log_loss_at_step(step)) {
                log_loss("joint", step, losses_joint);
            }
        }
    }

    bool should_log_loss_at_step(int64_t step) {
        return (step + 1) % config.log_loss_frequency == 0;
    }

    void log_loss(const std::string& prefix, int64_t step, std::vector<float>& losses) {
        auto t_end = std::chrono::high_resolution_clock::now();
        double elapsed_ms = std::chrono::duration<double, std::milli>(t_end-t_step_start).count();

        // NOTE: For `run_joint`, I think we could save the all-reduces
        // computing the loss after the W update step. For simplicity
        // and robustness, I'm always doing the all-reduces for now.
        double loss = compute_loss(true);
        losses.push_back(loss);

        std::cout << prefix << " step " << step + 1 << ": " << loss << " [" << elapsed_ms / (double) config.log_loss_frequency << " ms/step]\n";

        t_step_start = std::chrono::high_resolution_clock::now();
    }

    /////////////////////////////////////////////////////////////////
    // Set up and sub-steps of the optimization.

    void get_workers_ready() {
        call_on_workers_then_join(&DeviceWorker<IndT>::set_up_work, initial_W_f.get());

        // Replicate the W from device 0 on all of the devices. This
        // ensures consistency of the shared parameters across devices.
        broadcast_W_async();
        synchronize_all_streams();
        initial_W_f.reset();
    }

    void precompute_constants_on_workers() {
        call_on_workers_then_join(&DeviceWorker<IndT>::precompute_local_constants_async);

        // synchronize_all_streams(); std::cout << "precompute_local_constants_async\n";

        NCCL_CALL(ncclGroupStart());
        for(auto& worker : workers) {
            worker->nccl_all_reduce_precomputed_constants();
        }
        NCCL_CALL(ncclGroupEnd());

        // synchronize_all_streams(); std::cout << "nccl_all_reduce_precomputed_constants\n";

        call_on_workers_then_join(&DeviceWorker<IndT>::finish_precomputing_constants_after_all_reduce_async);

        // synchronize_all_streams(); std::cout << "finish_precomputing_constants_after_all_reduce_async\n";
        
        synchronize_all_streams();
    }

    void W_update_step(bool recompute_AG_e_GG_ee_ef, bool update_only_expansion = false) {
        if(recompute_AG_e_GG_ee_ef) {
            call_on_workers_then_join(&DeviceWorker<IndT>::compute_local_AG_e_GG_ee_ef_async);
            all_reduce_AG_e_GG_ee_ef_async();
        }
        call_on_workers_then_join(&DeviceWorker<IndT>::update_local_W_after_all_reduces_async, update_only_expansion);
        synchronize_all_streams();
    }

    void G_update_step(bool recompute_AG_e_GG_ee_ef, float learning_rate_G) {
        if(recompute_AG_e_GG_ee_ef) {
            call_on_workers_then_join(&DeviceWorker<IndT>::compute_local_AG_e_GG_ee_ef_async);
            all_reduce_AG_e_GG_ee_ef_async();
        }
        call_on_workers_then_join(&DeviceWorker<IndT>::update_local_G_after_all_reduces_async, learning_rate_G);
        synchronize_all_streams();
    }

    float compute_loss(bool recompute_AG_GG) {
        if(recompute_AG_GG) {
            call_on_workers_then_join(&DeviceWorker<IndT>::compute_local_AG_e_GG_ee_ef_async);
            all_reduce_AG_e_GG_ee_ef_async();
        }
        // We only need to compute the loss using a single GPU.
        auto& loss_worker = workers[0];
        loss_worker->compute_loss_after_all_reduces_async();
        float loss_term = loss_worker->read_loss_term_from_device();
        synchronize_all_streams();
        return loss_term + config.tr_xx;
    }

    /////////////////////////////////////////////////////////////////
    // Initializations.

    int64_t compute_n_examples(std::vector<CsrMatrixPtr>& column_partitions) {
        int64_t n_rows = column_partitions[0]->n_rows;
        THROW_IF_FALSE((n_rows % config.n_classes) == 0);
        return n_rows / config.n_classes;
    }

    int64_t compute_n_cols_total(std::vector<CsrMatrixPtr>& column_partitions) {
        int64_t n_cols = 0;
        for(auto& mat : column_partitions) { n_cols += mat->n_cols; }
        return n_cols;
    }

    std::vector<int64_t> compute_n_cols_per_partition(std::vector<CsrMatrixPtr>& column_partitions) {
        std::vector<int64_t> ret;
        for(auto& mat : column_partitions) {
            ret.push_back(mat->n_cols);
        }
        return ret;
    }

    void verify_compatability_of_partitions(
        std::vector<CsrMatrixPtr>& column_partitions,
        std::vector<DenseMatrixPtr>& frozen_G_partitions
    ) {
        // TODO: More verifications that the columns partitions are compatible with frozen_G_partitions.
        THROW_IF_FALSE(column_partitions.size() == frozen_G_partitions.size());
        for(int64_t i=0; i<column_partitions.size(); i++) {
            THROW_IF_FALSE(column_partitions[i]->n_cols == frozen_G_partitions[i]->n_cols);
        }
    }

    void initialize_nccl() {
        comms = std::unique_ptr<ncclComm_t>(new ncclComm_t[n_partitions]);
        NCCL_CALL(ncclCommInitAll(comms.get(), n_partitions, NULL));
    }

    void create_workers(
        std::vector<CsrMatrixPtr>& column_partitions,
        std::vector<DenseMatrixPtr>& frozen_G_partitions
    ) {
        for (int64_t i=0; i<n_partitions; i++) {
            workers.push_back(
                DeviceWorkerPtr(new DeviceWorker<IndT>(
                    std::move(column_partitions[i]),
                    std::move(frozen_G_partitions[i]),
                    config,
                    i,
                    n_partitions,
                    comms.get()[i]
                )));
        }
    }

    /////////////////////////////////////////////////////////////////
    // Utilities for dealing the multiple workers.

    void broadcast_W_async(int64_t src_device = 0) {
        DeviceWorkerPtr& src_worker = workers[src_device];
        NCCL_CALL(ncclGroupStart());
        for(auto& worker : workers) {
            worker->nccl_broadcast_of_W(*src_worker);
        }
        NCCL_CALL(ncclGroupEnd());
    }

    void all_reduce_AG_e_GG_ee_ef_async() {
        NCCL_CALL(ncclGroupStart());
        for(auto& worker : workers) {
            worker->nccl_all_reduce_AG_e_GG_ee_ef();
        }
        NCCL_CALL(ncclGroupEnd());
    }

    void synchronize_all_streams() {
        for(auto& worker : workers) {
            worker->synchronize_stream();
        }
    }

    void join_threads(std::vector<std::thread>& threads) {
        for (auto& thread : threads) { thread.join(); }
    }

    template<typename... Args>
    void call_on_workers_then_join(void (DeviceWorker<IndT>::*method)(Args...), Args... args) {
        std::vector<std::thread> threads;
        for(auto& worker : workers) {
            threads.emplace_back(method, worker.get(), args...);
        }
        join_threads(threads);
    }

};


}  // expansion
}  // npeff
