
#include <gpu/gpu_info.h>

#include <preprocess/construct_dn_pefs_partition_matrices.h>
#include <preprocess/pef_normalization.h>

#include <outputs/lrm_npeff_decomposition.h>

#include "./compute_tr_xx.h"
#include "./manager.h"
#include "./io_util.h"

namespace npeff {
namespace factorizations {
namespace dn_lrm_factorization {


///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// RunContext member functions.


void RunContext::run() {
    std::cout << "Starting to run factorization.\n";

    MultiGpuManager manager(partitions, config);
    manager.run();
    std::cout << "Factorization finished. Writing output to disc...\n";

    // NOTE: These will be made row major by the decomposition saving wrapper class.
    auto W = manager.read_W_from_gpu();
    auto G = manager.read_G_from_gpu();
    
    // Make and save the output.
    npeff::outputs::DenseLrmNpeffDecomposition output;
    output.set_W(std::move(W));
    output.set_G(std::move(G));
    output.set_n_parameters(config.n_cols_total);
    output.set_n_classes(config.n_classes);

    output.set_log_loss_frequency(config.log_loss_frequency);
    output.set_losses_G_only(manager.losses_G_only);
    output.set_losses_joint(manager.losses_joint);

    output.save(output_filepath, /*null_new_to_old_col_indices_ok=*/true);
}


///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// Functions.


// Builds a RunContext from a set of dense PEFs and the user-supplied configs.
//
// WARNING: this function modifies its inputs:
//   - `pefs` is normalized in place, and its underlying data buffer
//     (pefs.pefs->data) is freed while the partition matrices are built.
//   - `partial_config` has n_classes, n_cols_total and tr_xx filled in.
//
// The number of partitions equals the GPU device count reported by
// npeff::gpu::get_device_count().
RunContext create_run_context(
    npeff::inputs::DnLrmPefs& pefs,
    FactorizationConfig& partial_config,
    AdditionalRunContextConfig& additional_config
) {
    int64_t n_devices = npeff::gpu::get_device_count();

    // Scale every PEF to unit norm; the helper returns the count of non-finite norms.
    const int64_t non_finite_count = preprocessing::normalize_dn_pefs_in_place(
        pefs, additional_config.non_finite_norms_to_zeros);
    std::cout << "n_non_finite_norms: " << non_finite_count << "\n"
              << "Finished normalizing the PEFs.\n";

    // tr(XX^T) has to be computed now, before partitioning releases the PEF data.
    const double trace_xxt = compute_tr_xx(*pefs.pefs, additional_config.n_preprocess_cpu_threads);
    std::cout << "Finished computing tr(XX^T).\n";

    // NOTE: this call frees pefs.pefs->data. The pefs.pefs 3-tensor object itself
    // stays valid, which the n_classes()/n_parameters() calls below rely on
    // (presumably they only read shape metadata -- confirm if that changes).
    auto partition_matrices = preprocessing::construct_dn_pefs_partition_matrices(pefs, n_devices);
    std::cout << "Finished constructing the partitions of the PEF matrix.\n";

    // Complete the caller's partial config in place.
    partial_config.n_classes = pefs.n_classes();
    partial_config.n_cols_total = pefs.n_parameters();
    partial_config.tr_xx = trace_xxt;

    RunContext context;
    context.partitions = std::move(partition_matrices);
    context.config = partial_config;
    context.output_filepath = additional_config.output_filepath;
    return context;
}



}  // dn_lrm_factorization
}  // factorizations
}  // npeff
