#pragma once

#include <chrono>
#include <cstdint>
#include <ctime>
#include <functional>
#include <memory>
#include <string>
#include <thread>
#include <vector>

#include "nccl.h"

#include <inputs/dn_pefs/dn_lrm_pefs.h>
#include <outputs/W_partitions.h>
#include <util/macros.h>

#include "./config.h"
#include "./worker.h"


namespace npeff {
namespace factorizations {
namespace dn_lrm_factorization2 {


// Utility/helper stuff.
namespace internal {


// Information to help with the construction of outputs.
//
// Produced by MultiGpuManager::make_output_info() and stored in the manager's
// const `output_info` member.
struct MultiGpuManagerOutputInfo {
    // Presumably the number of examples summed over all partitions — TODO confirm.
    int64_t n_examples_total;
    // Number of parameters; matches the column count of the [rank, n_parameters]
    // G matrix returned by MultiGpuManager::read_G_from_gpu().
    int64_t n_parameters;
    // Factorization rank; matches the row count of that same G matrix.
    int64_t rank;
    // Presumably one entry per partition, in partition order — TODO confirm.
    std::vector<int64_t> n_examples_per_partition;
};


}  // internal



class MultiGpuManager {
    using DnLrmPefsPartitionPtr = std::unique_ptr<inputs::DnLrmPefsPartition>;
    using DeviceWorkerPtr = std::unique_ptr<DeviceWorker>;
    using DenseMatrixPtr = std::unique_ptr<npeff::DenseMatrix<float>>;

    const FactorizationConfig config;
    const int64_t n_partitions;

    // Information about the output
    const internal::MultiGpuManagerOutputInfo output_info;

    // The array of comms used for nccl.
    std::unique_ptr<ncclComm_t> comms;

    // A vector of workers. Each is assigned to a unique device.
    std::vector<DeviceWorkerPtr> workers;

    // Set to true in the initialize method. Used as a fail-safe to prevent
    // accidently calling run without initializing the workers first.
    bool initialized = false;

public:
    
    // Collects losses whenever they get computed.
    std::vector<float> losses_G_only;
    std::vector<float> losses_joint;


    MultiGpuManager(
        std::vector<DnLrmPefsPartitionPtr>& partitions,
        FactorizationConfig config
    ) :
        config(config),
        n_partitions(partitions.size()),
        output_info(make_output_info(partitions))
    {
        initialize_nccl();
        create_workers(partitions);
    }

    ~MultiGpuManager() {
        // Have this check in place so that we do not try to
        // pass junk addresses to ncclCommDestroy in case we
        // did not allocate the comms array at the time of
        // desctruction.
        if(comms) {
            for (int64_t i=0; i<n_partitions; i++) {
                ncclCommDestroy(comms.get()[i]);
            }
        }
    }

    // Sets up and initializes the workers, which includes copying the relevant host
    // memory to the GPUs.
    void initialize(npeff::DenseMatrix<float>* initial_G);

    // Highest level function for "main loop" of optimization.
    // 
    // NOTE: Unlike the original version of the dense LRM-NPEFF, this method
    // does NOT set up the workers. This is so that we can free host memory
    // associated with the PEFs after copying it to the GPU within the initialize() method.
    void run();

    std::unique_ptr<outputs::WPartitions> read_W_from_gpu();

    // Returned matrix has shape [rank, n_parameters] and is in column-major format.
    DenseMatrixPtr read_G_from_gpu(int64_t src_device = 0);

private:
    // For computing step times.
    std::chrono::high_resolution_clock::time_point t_step_start;

    void run_G_only();
    void run_joint();

    bool should_log_loss_at_step(int64_t step) const;
    void log_loss(const std::string& prefix, int64_t step, std::vector<float>& losses, DeviceWorker::LossComputationOptions opts);

    /////////////////////////////////////////////////////////////////
    // Sub-steps of the optimization.

    void W_update_step(DeviceWorker::WUpdateOptions opts);
    void G_update_step(DeviceWorker::GUpdateOptions opts);

    float compute_loss(DeviceWorker::LossComputationOptions opts);

    /////////////////////////////////////////////////////////////////
    // Initializations.

    internal::MultiGpuManagerOutputInfo make_output_info(const std::vector<DnLrmPefsPartitionPtr>& partitions) const;

    void initialize_nccl();
    void create_workers(std::vector<DnLrmPefsPartitionPtr>& partitions);

    // /////////////////////////////////////////////////////////////////
    // // Utilities for dealing the multiple workers.

    void broadcast_G_async(int64_t src_device = 0);

    void all_reduce_WW_async();
    void G_update_all_reduce_async(DeviceWorker::GUpdateOptions opts);

    void synchronize_all_streams();

    void join_threads(std::vector<std::thread>& threads);

    template<typename... Args>
    void call_on_workers_then_join(void (DeviceWorker::*method)(Args...), Args... args) {
        std::vector<std::thread> threads;
        for(auto& worker : workers) {
            threads.emplace_back(method, worker.get(), args...);
        }
        join_threads(threads);
    }
};


}  // dn_lrm_factorization2
}  // factorizations
}  // npeff
