#pragma once

#include <chrono>
#include <cstdint>
#include <ctime>
#include <functional>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "nccl.h"

#include <util/macros.h>

#include "./config.h"
#include "./device_worker.h"


namespace npeff {
namespace factorizations {
namespace dn_lrm_factorization {


class MultiGpuManager {
    using DeviceWorkerPtr = std::unique_ptr<DeviceWorker>;
    using DenseMatrixPtr = std::unique_ptr<npeff::DenseMatrix<float>>;

    const FactorizationConfig config;
    const int64_t n_partitions;

    // The array of comms used for nccl.
    std::unique_ptr<ncclComm_t> comms;

    // A vector of workers. Each is assigned to a unique device.
    std::vector<DeviceWorkerPtr> workers;

    // Information about the size of the matrix and its partitions.
    const int64_t n_examples;
    const int64_t n_cols_total;
    const std::vector<int64_t> n_cols_per_partition;

public:
    
    // Collects losses whenever they get computed.
    std::vector<float> losses_G_only;
    std::vector<float> losses_joint;


    MultiGpuManager(
        std::vector<DenseMatrixPtr>& column_partitions,
        FactorizationConfig config
    ) :
        config(config),
        n_partitions(column_partitions.size()),
        n_examples(compute_n_examples(column_partitions)),
        n_cols_total(compute_n_cols_total(column_partitions)),
        n_cols_per_partition(compute_n_cols_per_partition(column_partitions))
    {
        initialize_nccl();
        create_workers(column_partitions);
    }

    ~MultiGpuManager() {
        // Have this check in place so that we do not try to
        // pass junk addresses to ncclCommDestroy in case we
        // did not allocate the comms array at the time of
        // desctruction.
        if(comms) {
            for (int64_t i=0; i<n_partitions; i++) {
                ncclCommDestroy(comms.get()[i]);
            }
        }
    }

    // Returned matrix has shape [n_examples, rank] and is in column-major format.
    DenseMatrixPtr read_W_from_gpu(int64_t src_device = 0);

    // Returned matrix has shape [rank, n_parameters] and is in column-major format.
    DenseMatrixPtr read_G_from_gpu();

    // Highest level function for "main loop" of optimization.
    void run();

private:
    // For computing step times.
    std::chrono::high_resolution_clock::time_point t_step_start;

    void run_G_only();
    void run_joint();

    bool should_log_loss_at_step(int64_t step) const;

    // For some reason, trying to define this in the cc file leads to the matching
    // declaration not being found.
    void log_loss(const std::string& prefix, int64_t step, std::vector<float>& losses);

    /////////////////////////////////////////////////////////////////
    // Set up and sub-steps of the optimization.

    void get_workers_ready();

    void W_update_step(bool recompute_AG_GG);
    void G_update_step(bool recompute_AG_GG, float learning_rate_G);

    float compute_loss(bool recompute_AG_GG);

    /////////////////////////////////////////////////////////////////
    // Initializations.

    int64_t compute_n_examples(const std::vector<DenseMatrixPtr>& column_partitions) const;
    int64_t compute_n_cols_total(const std::vector<DenseMatrixPtr>& column_partitions) const;
    std::vector<int64_t> compute_n_cols_per_partition(const std::vector<DenseMatrixPtr>& column_partitions) const;

    void initialize_nccl();
    void create_workers(std::vector<DenseMatrixPtr>& column_partitions);

    /////////////////////////////////////////////////////////////////
    // Utilities for dealing the multiple workers.

    void broadcast_W_async(int64_t src_device = 0);

    void all_reduce_AG_GG_async();

    void synchronize_all_streams();

    void join_threads(std::vector<std::thread>& threads);

    template<typename... Args>
    void call_on_workers_then_join(void (DeviceWorker::*method)(Args...), Args... args) {
        std::vector<std::thread> threads;
        for(auto& worker : workers) {
            threads.emplace_back(method, worker.get(), args...);
        }
        join_threads(threads);
    }
};


}  // dn_lrm_factorization
}  // factorizations
}  // npeff
