#include <cstddef>
#include <exception>
#include <thread>

#include <util/h5_util.h>

#include "./dn_lrm_pefs_loader.h"

namespace npeff {
namespace inputs {

///////////////////////////////////////////////////////////////////////////////
// Names of the HDF5 datasets read from every PEF input file.
const std::string PEFS_DS_NAME("data/pefs");
const std::string NORMS_DS_NAME("data/pef_frobenius_norms");
///////////////////////////////////////////////////////////////////////////////

// Reports whether n_examples holds a non-negative value.
bool DnLrmPefsFileInfo::has_n_examples() const {
    const bool is_negative = n_examples < 0;
    return !is_negative;
}

///////////////////////////////////////////////////////////////////////////////


// Shape information gathered from a PEF input file (or aggregated over
// several files). Members are zero-initialised so that a default-constructed
// instance never carries indeterminate values.
struct PefsFileExtraInfo {
    // Number of examples counted for the file, or the sum over all files.
    int64_t n_examples_actual = 0;
    // Extent of dimension 1 of the PEFs dataset.
    int64_t pefs_rank = 0;
    // Extent of dimension 2 of the PEFs dataset.
    int64_t d_pefs = 0;
};

// Throws an error if something invalid is found.
std::vector<PefsFileExtraInfo> read_extra_infos_and_validate(const std::vector<DnLrmPefsFileInfo>& infos) {
    int64_t rank = -1;
    int64_t d_pefs = -1;

    std::vector<PefsFileExtraInfo> ret;

    for(const DnLrmPefsFileInfo& info : infos) {
        hid_t file = H5Fopen(info.filepath.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
        const auto dims = util::h5::read_dataset_dims(file, PEFS_DS_NAME);

        if (dims.size() != 3) { THROW_MSG("PEFs dataset in input file must have 3 dimensions.");}

        PefsFileExtraInfo extra_info;

        int64_t n_examples_total = dims[0];
        if (info.examples_offset >= n_examples_total) {
            THROW_MSG("The examples_offset for a PEF input file was greater than the number of examples in that file.");
        }
        if (info.has_n_examples()) {
            if (info.examples_offset + info.n_examples > n_examples_total) {
                THROW_MSG("The selected range of examples for a PEF input file exceeds the number of examples in that file.");
            }
            extra_info.n_examples_actual = info.n_examples;
        } else {
            extra_info.n_examples_actual = n_examples_total;
        }

        int64_t f_rank = dims[1];
        if(rank != -1 && f_rank != rank) { THROW_MSG("Ranks of PEFs are not consistent within the input PEF files."); }
        extra_info.pefs_rank = f_rank;
        rank = f_rank;

        int64_t f_d_pefs = dims[2];
        if(d_pefs != -1 && f_d_pefs != d_pefs) { THROW_MSG("Dimensions of PEFs are not consistent within the input PEF files."); }
        extra_info.d_pefs = f_d_pefs;
        d_pefs = f_d_pefs;

        ret.push_back(extra_info);

        H5Fclose(file);
    }

    return ret;
}

// Collapses the per-file infos into one summary: rank and PEF size are taken
// from the first entry, and the example counts are summed over all entries.
// Raises an error when the list is empty.
PefsFileExtraInfo aggregate_extra_infos(const std::vector<PefsFileExtraInfo>& extra_infos) {
    if (extra_infos.size() == 0) { THROW_MSG("At least one PEF input file must be provided."); }

    int64_t total_examples = 0;
    for (auto it = extra_infos.begin(); it != extra_infos.end(); ++it) {
        total_examples += it->n_examples_actual;
    }

    const PefsFileExtraInfo& first = extra_infos.front();

    PefsFileExtraInfo summary;
    summary.n_examples_actual = total_examples;
    summary.pefs_rank = first.pefs_rank;
    summary.d_pefs = first.d_pefs;
    return summary;
}

///////////////////////////////////////////////////////////////////////////////

class DnLrmPefsFileLoader {
    const DnLrmPefsFileInfo& info;
    const PefsFileExtraInfo& extra_info;
    const int64_t output_pefs_rank;
    const int64_t output_example_offset;
    Dense3Tensor<float>* pefs;
    DenseMatrix<float>* norms;
public:
    DnLrmPefsFileLoader(
        const DnLrmPefsFileInfo& info, const PefsFileExtraInfo& extra_info,
        int64_t output_pefs_rank, int64_t output_example_offset,
        Dense3Tensor<float>* pefs, DenseMatrix<float>* norms
    ) :
        info(info), extra_info(extra_info),
        output_pefs_rank(output_pefs_rank), output_example_offset(output_example_offset),
        pefs(pefs), norms(norms)
    {}

    void operator()() {
        hid_t file = H5Fopen(info.filepath.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);

        float* pefs_write_ptr = &pefs->get_entry(output_example_offset, 0, 0);
        util::h5::read_dataset_to_ptr(pefs_write_ptr, file, PEFS_DS_NAME, info.n_examples, info.examples_offset, output_pefs_rank);

        float* norms_write_ptr = &norms->get_entry(0, output_example_offset);
        util::h5::read_dataset_to_ptr(norms_write_ptr, file, NORMS_DS_NAME, info.n_examples, info.examples_offset);

        H5Fclose(file);
    }
};

///////////////////////////////////////////////////////////////////////////////

// Loads the PEFs and their Frobenius norms from all `infos` files into one
// tensor / matrix, with the files' examples laid out back to back in the order
// of `infos`. Each file is read on its own thread.
//
// `rank`: number of PEF rows to keep per example. A negative value keeps the
// full rank found in the files. A value larger than the files' rank raises an
// error.
//
// Raises an error if validation fails, if no file is given, or if any loader
// thread fails; a loader failure is rethrown here after all threads have been
// joined.
//
// NOTE(review): the loaders call into HDF5 concurrently; this presumably needs
// a thread-safe HDF5 build — confirm.
DnLrmPefs load_dn_lrm_pefs(const std::vector<DnLrmPefsFileInfo>& infos, int64_t rank) {
    const auto extra_infos = read_extra_infos_and_validate(infos);
    const PefsFileExtraInfo agg_extra_info = aggregate_extra_infos(extra_infos);

    // Use all of the PEF rows present in the file(s) unless a rank was given.
    int64_t output_pefs_rank = agg_extra_info.pefs_rank;
    if (rank >= 0) {
        if (rank > agg_extra_info.pefs_rank) {
            THROW_MSG("The supplied PEFs rank must be less than or equal to the rank of the PEFs in the file(s).");
        }
        output_pefs_rank = rank;
    }

    std::unique_ptr<Dense3Tensor<float>> pefs(
        new Dense3Tensor<float>(agg_extra_info.n_examples_actual, output_pefs_rank, agg_extra_info.d_pefs));
    std::unique_ptr<DenseMatrix<float>> norms(new DenseMatrix<float>(1, agg_extra_info.n_examples_actual));

    const std::size_t n_files = infos.size();

    // One slot per thread, sized up front so the slot addresses stay valid and
    // no two threads ever write to the same element.
    std::vector<std::exception_ptr> errors(n_files);

    std::vector<std::thread> loader_threads;
    loader_threads.reserve(n_files);

    try {
        int64_t output_example_offset = 0;
        for (std::size_t i = 0; i < n_files; i++) {
            // Not const: the lambda's copy must be able to call the non-const operator().
            DnLrmPefsFileLoader loader(
                infos[i], extra_infos[i], output_pefs_rank, output_example_offset, pefs.get(), norms.get());
            std::exception_ptr* const error_slot = &errors[i];

            // An exception escaping a thread's entry point would terminate the
            // process, so it is captured here and rethrown on the calling thread.
            loader_threads.emplace_back([loader, error_slot]() mutable {
                try {
                    loader();
                } catch (...) {
                    *error_slot = std::current_exception();
                }
            });

            output_example_offset += extra_infos[i].n_examples_actual;
        }
    } catch (...) {
        // Launching a thread failed: wait for the ones already running before
        // the buffers they write to are destroyed.
        for (auto& thread : loader_threads) { thread.join(); }
        throw;
    }

    for (auto& thread : loader_threads) { thread.join(); }

    for (const std::exception_ptr& error : errors) {
        if (error) { std::rethrow_exception(error); }
    }

    DnLrmPefs ret;
    ret.pefs = std::move(pefs);
    ret.pef_frobenius_norms = std::move(norms);

    return ret;
}


// Single-file convenience overload: wraps `info` in a one-element vector and
// forwards to the multi-file loader with the same `rank`.
DnLrmPefs load_dn_lrm_pefs(const DnLrmPefsFileInfo& info, int64_t rank) {
    const std::vector<DnLrmPefsFileInfo> single_file(1, info);
    return load_dn_lrm_pefs(single_file, rank);
}



}  // inputs
}  // npeff
