#include <gpu/gpu_info.h>

#include <preprocess/pef_normalization.h>
#include <preprocess/construct_dn_lrm_pefs_datawise_partitions.h>
#include <factorizations/dn_lrm_factorization/compute_tr_xx.h>
#include <outputs/lrm_npeff_decomposition.h>

#include "./io_util.h"
#include "./manager.h"


namespace npeff {
namespace factorizations {
namespace dn_lrm_factorization2 {


void RunContext::run() {
    std::cout << "Starting to run factorization.\n";

    MultiGpuManager manager(partitions, config);

    manager.initialize(initial_G.get());
    clear_inputs_host_memory();

    manager.run();

    std::cout << "Factorization finished. Writing output to disc...\n";

    // NOTE: These will be made row major by the decomposition saving wrapper class.
    auto W = manager.read_W_from_gpu();
    auto G = manager.read_G_from_gpu();
    
    // Make and save the output.
    npeff::outputs::DenseLrmNpeffDecomposition output;
    output.set_W(std::move(W));
    output.set_G(std::move(G));
    output.set_n_parameters(config.n_cols_total);
    output.set_n_classes(config.n_classes);

    output.set_log_loss_frequency(config.log_loss_frequency);
    output.set_losses_G_only(manager.losses_G_only);
    output.set_losses_joint(manager.losses_joint);

    output.save(output_filepath, /*null_new_to_old_col_indices_ok=*/true);
}


// Releases ownership of the host-side input buffers held by this context.
//
// After this call the `partitions` views, which point into the memory owned by
// `full_pefs` / `full_norms`, must no longer be read.
void RunContext::clear_inputs_host_memory() {
    full_pefs.reset();
    full_norms.reset();

    // Resetting an already-empty smart pointer is a no-op, so no null check is needed.
    initial_G = nullptr;
}


///////////////////////////////////////////////////////////////////////////////


// Note that this function can modify its inputs.
// Builds a RunContext for a data-wise partitioned, multi-GPU DN-LRM factorization.
//
// Note that this function can modify its inputs:
//   - `pefs` is normalized in place, and ownership of its `pefs` and
//     `pef_frobenius_norms` buffers is moved into the returned context. The
//     partitions are views into that memory, so the context must keep it alive.
//   - `partial_config` has n_classes, n_cols_total, n_examples_total and tr_xx
//     filled in from the data.
//
// @param pefs              Input PEFs (normalized in place; buffers are moved out).
// @param partial_config    Factorization config to complete; copied into the result.
// @param additional_config Supplies output_filepath, non_finite_norms_to_zeros and
//                          n_preprocess_cpu_threads.
// @param initial_G         Optional initial G. If non-null it must have
//                          n_rows == partial_config.rank and n_cols == pefs.n_parameters().
// @return The assembled RunContext, with one partition per visible GPU device.
// Fails via THROW_IF_FALSE on an initial_G shape mismatch or when no GPU device is visible.
RunContext create_run_context(
    npeff::inputs::DnLrmPefs& pefs,
    FactorizationConfig& partial_config,
    AdditionalRunContextConfig& additional_config,
    std::unique_ptr<npeff::DenseMatrix<float>> initial_G
) {
    // Validate the initial_G if it is provided.
    if (initial_G) {
        THROW_IF_FALSE(initial_G->n_rows == partial_config.rank);
        THROW_IF_FALSE(initial_G->n_cols == pefs.n_parameters());
    }

    // One data-wise partition per visible GPU. Check this up front so that a
    // machine without GPUs fails here, instead of after the normalization and
    // tr(XX^T) passes with a zero-partition request.
    int64_t n_partitions = npeff::gpu::get_device_count();
    THROW_IF_FALSE(n_partitions > 0);

    // Normalize the PEFs to unit norm.
    int64_t n_non_finite_norms = preprocessing::normalize_dn_pefs_in_place(pefs, additional_config.non_finite_norms_to_zeros);
    std::cout << "n_non_finite_norms: " << n_non_finite_norms << "\n";
    std::cout << "Finished normalizing the PEFs.\n";

    // Compute tr(XX^T) on the normalized PEFs.
    double tr_xx = dn_lrm_factorization::compute_tr_xx(*pefs.pefs, additional_config.n_preprocess_cpu_threads);
    std::cout << "Finished computing tr(XX^T).\n";

    // Construct the partitions.
    // NOTE: We'll have to keep the memory associated with the pefs and norms on the
    // `pefs` input alive as long as we want to read their entries through the
    // corresponding views. Ownership of that memory is moved into the context below.
    auto partitions = preprocessing::construct_dn_lrm_pefs_datawise_partitions(pefs, n_partitions);
    std::cout << "Finished constructing the partitions of the PEF matrix.\n";

    // Fill in the data-dependent parts of the config.
    partial_config.n_classes = pefs.n_classes();
    partial_config.n_cols_total = pefs.n_parameters();
    partial_config.n_examples_total = pefs.n_examples();
    partial_config.tr_xx = tr_xx;

    RunContext ret;
    ret.output_filepath = additional_config.output_filepath;
    ret.config = partial_config;
    ret.partitions = std::move(partitions);
    ret.full_pefs = std::move(pefs.pefs);
    ret.full_norms = std::move(pefs.pef_frobenius_norms);
    ret.initial_G = std::move(initial_G);

    return ret;
}



}  // dn_lrm_factorization2
}  // factorizations
}  // npeff
