#include <thread>

#include "./W_partitions.h"

namespace npeff {
namespace outputs {

// Uniquely-owning pointer to a float-valued dense matrix.
using DenseMatrixPtr = std::unique_ptr<npeff::DenseMatrix<float>>;



///////////////////////////////////////////////////////////////////////////////

// Local helper class.
class RowMajorMaker {
    // Input, will be column major.
    npeff::DenseMatrix<float>* partition;
    const int64_t example_offset;
    // Output, will be row major.
    npeff::DenseMatrix<float>* W;

public:
    RowMajorMaker(
        npeff::DenseMatrix<float>* partition,
        int64_t example_offset,
        npeff::DenseMatrix<float>* W
    ) :
        partition(partition), example_offset(example_offset), W(W)
    {}

    void operator()() {
        int64_t rank = W->n_cols;
        float* start_sub_mat = W->data.get() + (example_offset * rank);
        partition->convert_to_row_major_onto_buffer(start_sub_mat);
        partition->data.reset();
    }
};


///////////////////////////////////////////////////////////////////////////////


// Builds the full W matrix in row-major order from all partitions.
//
// Allocates an (n_examples_total x rank) matrix and launches one worker
// thread per partition. Each worker (a RowMajorMaker) writes its partition,
// converted to row-major, into the rows starting at that partition's running
// row offset, and then resets the partition's own data buffer. As a result the
// partitions no longer hold their data once this function returns.
//
// Precondition: `partitions` is non-empty, because compute_rank() reads
// partitions[0].
//
// Returns: owning pointer to the assembled matrix.
//
// Exception safety: if starting a thread (or growing the thread vector)
// throws, every worker that was already started is joined before the
// exception propagates. Without that, destroying a joinable std::thread would
// call std::terminate, and the workers would be writing into W while it is
// being destroyed.
DenseMatrixPtr WPartitions::get_full_row_major_W() {
    const int64_t n_examples_total = compute_n_examples_total();
    const int64_t rank = compute_rank();

    // NOTE: This will be row-major.
    DenseMatrixPtr W = std::unique_ptr<DenseMatrix<float>>(
        new DenseMatrix<float>(n_examples_total, rank));

    std::vector<std::thread> maker_threads;

    // Scope guard: joins every still-joinable worker when the scope is left,
    // whether normally or via an exception. Declared after W and
    // maker_threads so that it runs before either of them is destroyed.
    struct JoinOnExit {
        std::vector<std::thread>& threads;

        explicit JoinOnExit(std::vector<std::thread>& threads) : threads(threads) {}

        ~JoinOnExit() {
            for(auto& thread : threads) {
                if(thread.joinable()) { thread.join(); }
            }
        }
    };
    JoinOnExit join_guard(maker_threads);

    maker_threads.reserve(partitions.size());

    int64_t example_offset = 0;
    for(const auto& p : partitions) {
        // std::thread stores its own copy of the callable, so the functor does
        // not need to be kept alive anywhere else.
        maker_threads.emplace_back(
            RowMajorMaker(p.get(), example_offset, W.get()));

        example_offset += p->n_rows;
    }

    // Wait for all workers explicitly so W is fully populated before it is
    // handed back; the guard then finds nothing left to join.
    for(auto& thread : maker_threads) { thread.join(); }

    return W;
}


///////////////////////////////////////////////////////////////////////////////


// Returns the sum of `n_rows` over every partition; 0 when there are no
// partitions.
int64_t WPartitions::compute_n_examples_total() const {
    int64_t total_rows = 0;
    for(auto it = partitions.begin(); it != partitions.end(); ++it) {
        total_rows += (*it)->n_rows;
    }
    return total_rows;
}


// Returns the column count of the first partition, which is used as the rank
// of the assembled W. Only partitions[0] is inspected, so the caller must
// ensure at least one partition exists.
int64_t WPartitions::compute_rank() const {
    const auto& first_partition = partitions[0];
    return first_partition->n_cols;
}


}  // outputs
}  // npeff
