#pragma once

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include <inputs/dn_pefs/dn_lrm_pefs.h>

#include <factorizations/dn_lrm_factorization2/config.h>

namespace npeff {
namespace factorizations {
namespace dn_lrm_factorization2 {


// Settings for building a run context that are supplied separately from the
// FactorizationConfig (see create_run_context).
struct AdditionalRunContextConfig {
    // File path for the output. NOTE(review): presumably forwarded to
    // RunContext::output_filepath -- confirm in the definition.
    std::string output_filepath;

    // Defaults to 1. NOTE(review): presumably the number of CPU threads used
    // while preprocessing the inputs -- confirm in the definition.
    std::int64_t n_preprocess_cpu_threads{1};

    // Defaults to false. NOTE(review): presumably replaces non-finite norm
    // values with zeros when enabled -- confirm in the definition.
    bool non_finite_norms_to_zeros{false};
};


struct RunContext {
    // Owning-pointer aliases for the types held by the members below.
    using Float3TensorPtr = std::unique_ptr<npeff::Dense3Tensor<float>>;
    using FloatMatrixPtr = std::unique_ptr<npeff::DenseMatrix<float>>;
    using DnLrmPefsPartitionPtr = std::unique_ptr<inputs::DnLrmPefsPartition>;

    std::string output_filepath;

    FactorizationConfig config;

    // The partitions are only views onto full_pefs / full_norms below.
    std::vector<DnLrmPefsPartitionPtr> partitions;

    // These come from the full DnLrmPefs object. Because the partitions are
    // just views onto this data, the run context takes ownership of it so
    // that the associated memory can be freed once it has been copied to the
    // GPUs.
    Float3TensorPtr full_pefs{nullptr};
    FloatMatrixPtr full_norms{nullptr};

    // Optional: may be provided to initialize G. The run context owns it so
    // that its memory can be freed once it has been copied to the GPUs.
    FloatMatrixPtr initial_G{nullptr};

    // Executes the run.
    void run();

private:
    // Releases the data held by full_pefs and full_norms after it has been
    // copied to the device memories.
    void clear_inputs_host_memory();
};


// Creates a RunContext from the given PEFs and configuration objects.
//
// Note that this function can modify its inputs: pefs, partial_config, and
// additional_config are all taken by non-const reference.
//
// initial_G is taken by value as a unique_ptr, so the caller hands over
// ownership of it. NOTE(review): presumably it is stored in
// RunContext::initial_G and may be null when no initial G is supplied --
// confirm against the definition.
RunContext create_run_context(
    npeff::inputs::DnLrmPefs& pefs,
    FactorizationConfig& partial_config,
    AdditionalRunContextConfig& additional_config,
    std::unique_ptr<npeff::DenseMatrix<float>> initial_G
);



}  // dn_lrm_factorization2
}  // factorizations
}  // npeff
