#include <util/macros.h>

#include "./construct_dn_pefs_partition_matrices.h"


namespace npeff {
namespace preprocessing {


// Private helper.
// Returns the number of elements assigned to partition `partition_index` when
// `total_size` elements are split into `n_partitions` contiguous partitions.
// Every partition gets total_size / n_partitions elements (integer division),
// except the last one, which additionally absorbs the remainder
// total_size % n_partitions.
//
// Invokes THROW (from util/macros.h) when n_partitions is not positive or when
// partition_index is outside [0, n_partitions).
int64_t compute_partition_size(int64_t total_size, int64_t n_partitions, int64_t partition_index) {
    // Validate before dividing: n_partitions == 0 would otherwise be a division
    // by zero, and a negative index would silently be given base_size.
    if (n_partitions <= 0 || partition_index < 0 || partition_index >= n_partitions) {
        THROW;
    }

    int64_t base_size = total_size / n_partitions;
    if (partition_index == n_partitions - 1) {
        // Only the last partition carries the leftover elements.
        return base_size + (total_size % n_partitions);
    }
    return base_size;
}

// Private helper.
int64_t compute_partition_start_index(int64_t total_size, int64_t n_partitions, int64_t partition_index) {
    // This works because only the last partition can potentially have extra elements.
    int64_t base_size = total_size / n_partitions;
    return partition_index * base_size;
}


///////////////////////////////////////////////////////////////////////////////

// Private helper.
// Builds the dense matrix for one column-wise partition of `pefs.pefs`, which is
// read as pefs.pefs->get_entry(example, class, parameter).
//   * Rows enumerate (example, class) pairs in example-major order, i.e.
//     row = example * n_classes + class.
//   * Columns are the contiguous parameter slice [start, start + width) owned by
//     `partition_index`. `width` and `start` come from compute_partition_size and
//     compute_partition_start_index respectively.
// Returns a newly allocated matrix owned by the caller.
std::unique_ptr<DenseMatrix<float>>
create_columnwise_uniform_partition(const inputs::DnLrmPefs& pefs, int64_t n_partitions, int64_t partition_index) {
    const int64_t num_examples = pefs.n_examples();
    const int64_t num_classes = pefs.n_classes();
    const int64_t num_params = pefs.n_parameters();

    const int64_t width = compute_partition_size(num_params, n_partitions, partition_index);
    std::unique_ptr<DenseMatrix<float>> out(new DenseMatrix<float>(num_examples * num_classes, width));

    const int64_t col_offset = compute_partition_start_index(num_params, n_partitions, partition_index);

    // `row` advances once per (example, class) pair, matching the
    // example-major row layout described above.
    int64_t row = 0;
    for (int64_t ex = 0; ex < num_examples; ++ex) {
        for (int64_t cls = 0; cls < num_classes; ++cls, ++row) {
            for (int64_t col = 0; col < width; ++col) {
                out->get_entry(row, col) = pefs.pefs->get_entry(ex, cls, col_offset + col);
            }
        }
    }

    return out;
}

///////////////////////////////////////////////////////////////////////////////


// Splits the entries of `pefs.pefs` column-wise (along the parameter axis) into
// `n_partitions` dense matrices, one per contiguous parameter range. The last
// partition also receives the leftover n_parameters % n_partitions columns.
// Partition i is built by create_columnwise_uniform_partition(pefs, n_partitions, i).
//
// Side effect: after all partitions are built, the storage held by
// pefs.pefs->data is released via reset(). Callers should not expect to read
// the original array through `pefs` afterwards.
//
// Invokes THROW (from util/macros.h) if n_partitions is not positive. The check
// runs before the data is released, so the input is never discarded without
// any partitions being produced.
std::vector<std::unique_ptr<DenseMatrix<float>>>
construct_dn_pefs_partition_matrices(inputs::DnLrmPefs& pefs, int64_t n_partitions) {
    if (n_partitions <= 0) {
        THROW;
    }

    std::vector<std::unique_ptr<DenseMatrix<float>>> partitions;
    // The final count is known up front; avoid reallocations while filling.
    partitions.reserve(static_cast<size_t>(n_partitions));
    for (int64_t i = 0; i < n_partitions; i++) {
        partitions.push_back(create_columnwise_uniform_partition(pefs, n_partitions, i));
    }

    // Free the memory associated with the pefs.pefs array.
    pefs.pefs->data.reset();

    return partitions;
}



}  // preprocessing
}  // npeff
