// Same factorization as run_dn_m_npeff.cu except that tensors are split
// differently across multiple GPUs.
// 
// This method better supports cases where the number of examples is large while
// the product of the dimension and rank of the PEFs is relatively small.
// 
// NOTE: Some other differences might arise between this and run_dn_m_npeff.cu as I make
// changes to one of them without making the corresponding change to the other.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <gflags/gflags.h>
#include <util/flag_util.h>

#include <containers/dense_matrix.h>
#include <containers/sparse_matrix.h>

#include <inputs/dn_pefs/dn_lrm_pefs.h>
#include <inputs/dn_pefs/dn_lrm_pefs_loader.h>
#include <inputs/lrm_npeff_decomposition.h>

#include <factorizations/dn_lrm_factorization2/config.h>
#include <factorizations/dn_lrm_factorization2/io_util.h>

//////////////////////////////////////////////////////////////////////////////
// Flag definitions

// --- Output ---
DEFINE_string(output_filepath, "", "Filepath where the HDF5 file containing the output will be written to.");

// --- Input PEFs ---
DEFINE_string(pef_filepaths, "", "Comma-separated list of filepaths of the HDF5 files containing the per-example Fishers.");
DEFINE_string(n_examples_per_pef, "",
              "Comma-separated list of integers indicating the number of examples to use from each PEF file. "
              "If provided, the list must be the same length as the --pef_filepaths list. "
              "Leave empty to use all examples from all PEFs. "
              "Use a value of -1 for a particular PEF to use all examples from that particular PEF.");

// --- Factorization size and iteration budgets ---
DEFINE_int64(n_components, -1, "The number of elements in the learned dictionary.");

DEFINE_int64(n_iters_G_only, -1, "The number of iterations to learn an initial G given a fixed W.");
DEFINE_int64(n_iters_joint, -1, "The maximum number of iterations to run the factorization algorithm for, learning both W and G.");

// --- Optimizer hyperparameters ---
DEFINE_double(learning_rate_G, 1e-3, "The learning rate to be used for the G-update steps.");
// Non-positive values are treated as "unset" by read_partial_config_from_flags().
DEFINE_double(learning_rate_G_G_only, -1.0, "The learning rate to be used for the G-update steps in G only training. If not set, defaults to --learning_rate_G.");

DEFINE_double(mu_eps, 1e-9, "Epsilon for the multiplicative update on W.");

// --- Run bookkeeping ---
DEFINE_int64(rand_gen_seed, 48230, "Seed to use for random number generation.");

DEFINE_int64(log_loss_frequency, 10, "Compute and log the loss every this number of steps.");

// --- Preprocessing ---
DEFINE_int64(n_preprocess_cpu_threads, 1, "The number of threads to use for preprocessing on the CPU.");

DEFINE_bool(non_finite_norms_to_zeros, false, "Whether to set examples with non-finite PEF norms to zeros.");


DEFINE_int64(pefs_rank, -1,
             "If negative, then read all rows from each LRM-PEF. If positive, then take the first --pefs_rank "
             "rows from each LRM-PEF.");

// --- Optional warm start ---
DEFINE_string(initial_G_filepath, "",
              "Filepath to an NPEFF decomposition containing the initial value to use for G. Leave empty "
              "to randomly initialize G. If provided, the G in the file must match the number of components "
              "in this decomposition and the dimension of the PEFs.");



//////////////////////////////////////////////////////////////////////////////
using namespace npeff::factorizations::dn_lrm_factorization2;
using DnLrmPefsFileInfo = npeff::inputs::DnLrmPefsFileInfo;
using DnLrmPefs = npeff::inputs::DnLrmPefs;
using FloatMatrixPtr = std::unique_ptr<npeff::DenseMatrix<float>>;
//////////////////////////////////////////////////////////////////////////////


// Reads the config from flags. Note that not all fields can be
// set directly from flags, so those will need to be written later.
// Builds a FactorizationConfig from the command-line flags.
//
// Only the fields that map directly onto flags are filled in here. The caller
// is responsible for populating the remaining ones later:
//   - config.n_classes
//   - config.tr_xx
//   - config.n_examples_total
//   - config.n_cols_total
//
// Throws (via THROW_MSG) when --n_components is not positive, when
// --n_iters_G_only is negative, or when --n_iters_joint is not positive.
FactorizationConfig read_partial_config_from_flags() {
    FactorizationConfig config;

    // Dictionary size and iteration budgets.
    config.rank = FLAGS_n_components;
    config.n_iters_G_only = FLAGS_n_iters_G_only;
    config.n_iters_joint = FLAGS_n_iters_joint;

    // Optimizer hyperparameters.
    config.mu_eps = FLAGS_mu_eps;
    config.learning_rate_G_joint = FLAGS_learning_rate_G;
    config.learning_rate_G_G_only = FLAGS_learning_rate_G_G_only;
    // A non-positive G-only learning rate means "not set": fall back to the
    // joint-phase rate. The test is done on the stored field so that it sees
    // exactly the value the config will hold.
    if (config.learning_rate_G_G_only <= 0.0) {
        config.learning_rate_G_G_only = config.learning_rate_G_joint;
    }

    // Run bookkeeping.
    config.rand_gen_seed = FLAGS_rand_gen_seed;
    config.log_loss_frequency = FLAGS_log_loss_frequency;

    // Validate after assignment; the first failing check determines the error.
    if (config.rank <= 0) {
        THROW_MSG("Must set the --n_components flag to a positive integer.");
    }
    if (config.n_iters_G_only < 0) {
        THROW_MSG("Must set the --n_iters_G_only flag to a non-negative integer.");
    }
    if (config.n_iters_joint <= 0) {
        THROW_MSG("Must set the --n_iters_joint flag to a positive integer.");
    }

    return config;
}


// Loads the optional warm-start G named by --initial_G_filepath.
//
// Returns the G read from that NPEFF decomposition file, or a null pointer when
// the flag is empty (meaning G should be randomly initialized). No shape
// checking is done here.
FloatMatrixPtr read_initial_G_from_flags() {
    if (!FLAGS_initial_G_filepath.empty()) {
        npeff::inputs::DenseLrmNpeffDecompositionFromFile decomposition(FLAGS_initial_G_filepath);
        return decomposition.read_G();
    }
    return nullptr;
}


// Collects the run-context settings that live outside FactorizationConfig:
// the output location and the CPU-side PEF preprocessing options.
AdditionalRunContextConfig read_additional_run_context_config_from_flags() {
    AdditionalRunContextConfig run_ctx_config;

    // Preprocessing behavior.
    run_ctx_config.non_finite_norms_to_zeros = FLAGS_non_finite_norms_to_zeros;
    run_ctx_config.n_preprocess_cpu_threads = FLAGS_n_preprocess_cpu_threads;

    // Where results are written.
    run_ctx_config.output_filepath = FLAGS_output_filepath;

    return run_ctx_config;
}


std::vector<DnLrmPefsFileInfo> read_pef_file_infos_from_flags() {
    std::vector<std::string> pef_filepaths = npeff::util::flags::parse_string_list(FLAGS_pef_filepaths);
    std::vector<int64_t> n_examples_per_pef = npeff::util::flags::parse_int64_list(FLAGS_n_examples_per_pef);

    if (pef_filepaths.empty()) {
        THROW_MSG("Please provide at least one PEF file to --pef_filepaths.");
    }
    if (!n_examples_per_pef.empty() && pef_filepaths.size() != n_examples_per_pef.size()) {
        THROW_MSG("If --n_examples_per_pef is provided, its list of entries must match that of --pef_filepaths.");
    }

    std::vector<DnLrmPefsFileInfo> ret;
    for(int32_t i=0; i<pef_filepaths.size(); i++) {
        DnLrmPefsFileInfo info;
        info.filepath = pef_filepaths[i];
        if (!n_examples_per_pef.empty()) {
            info.n_examples = n_examples_per_pef[i];
        }
        ret.push_back(info);
    }

    return ret;
}



// Entry point: parses flags, loads the inputs, and runs the factorization.
//
// All flag validation and the optional initial-G read happen before the
// LRM-PEFs are loaded. That way a bad flag or an unreadable
// --initial_G_filepath is reported without first paying for the PEF load.
// G is presumably far smaller than the PEFs, so holding it during that load
// costs little; it is kept alive until the run context is created anyway.
int main(int argc, char *argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, true);

    if(FLAGS_output_filepath.empty()) {
        THROW_MSG("Please provide a valid --output_filepath flag value.");
    }

    // Cheap flag parsing and validation first.
    FactorizationConfig partial_config = read_partial_config_from_flags();
    AdditionalRunContextConfig additional_config = read_additional_run_context_config_from_flags();
    std::vector<DnLrmPefsFileInfo> pef_file_infos = read_pef_file_infos_from_flags();

    // Optional warm start; read before the PEFs so a bad path fails fast.
    FloatMatrixPtr initial_G = read_initial_G_from_flags();
    if (initial_G) {
        std::cout << "Initial G loaded from disk.\n";
    }

    // Expensive input load last.
    DnLrmPefs pefs = npeff::inputs::load_dn_lrm_pefs(pef_file_infos, FLAGS_pefs_rank);
    std::cout << "LRM-PEFS loaded from disk.\n";

    auto ctx = create_run_context(pefs, partial_config, additional_config, std::move(initial_G));
    ctx.run();

    return 0;
}
