#include <util/macros.h>

#include "./construct_dn_lrm_pefs_datawise_partitions.h"


namespace npeff {
namespace preprocessing {


// Private helper.
// Returns the number of elements assigned to partition `partition_index` when
// `total_size` elements are split into `n_partitions` contiguous partitions.
// Every partition receives floor(total_size / n_partitions) elements; the last
// partition additionally receives the remainder (total_size % n_partitions),
// so it is the only one that may be larger than the others.
//
// Invokes THROW if `n_partitions` is not positive or if `partition_index` is
// outside [0, n_partitions).
int64_t compute_partition_size(int64_t total_size, int64_t n_partitions, int64_t partition_index) {
    // Validate before dividing: n_partitions == 0 would be an integer division
    // by zero (undefined behavior), and a negative index is never meaningful.
    if (n_partitions <= 0 || partition_index < 0 || partition_index >= n_partitions) {
        THROW;
    }

    int64_t base_size = total_size / n_partitions;
    if (partition_index == n_partitions - 1) {
        // The last partition absorbs the leftover elements.
        return base_size + (total_size % n_partitions);
    }
    return base_size;
}

// Private helper.
// Returns the index of the first element that belongs to partition
// `partition_index` when `total_size` elements are split into `n_partitions`
// contiguous partitions. Every partition except the last holds exactly
// total_size / n_partitions elements (only the last absorbs the remainder), so
// the offset is simply that per-partition count times the partition index.
int64_t compute_partition_start_index(int64_t total_size, int64_t n_partitions, int64_t partition_index) {
    const int64_t elements_per_partition = total_size / n_partitions;
    const int64_t first_element = elements_per_partition * partition_index;
    return first_element;
}


///////////////////////////////////////////////////////////////////////////////

// Private helper.
// Builds partition `partition_index` (out of `n_partitions`) of `pefs` along
// the example axis. The partition's example count and starting example come
// from compute_partition_size / compute_partition_start_index.
//
// The tensor and norm views are constructed from raw pointers into the
// storage of `pefs`. NOTE(review): they are presumably non-owning views, in
// which case `pefs` must outlive the returned partition; confirm against the
// Dense3TensorContiguousView / DenseMatrixContiguousView definitions.
std::unique_ptr<inputs::DnLrmPefsPartition>
create_datawise_partition(const inputs::DnLrmPefs& pefs, int64_t n_partitions, int64_t partition_index) {
    // Extent of this partition's slice of examples.
    const int64_t example_count =
        compute_partition_size(pefs.n_examples(), n_partitions, partition_index);
    const int64_t first_example =
        compute_partition_start_index(pefs.n_examples(), n_partitions, partition_index);

    // View over the per-example PEF tensor, starting at the slice's first example.
    float* const pef_slice_begin = &pefs.pefs->get_entry(first_example, 0, 0);
    std::unique_ptr<Dense3TensorContiguousView<float>> pef_view(
        new Dense3TensorContiguousView<float>(
            example_count, pefs.n_classes(), pefs.n_parameters(), pef_slice_begin));

    // View over the matching columns of the (1 x n_examples) norms matrix.
    float* const norm_slice_begin = &pefs.pef_frobenius_norms->get_entry(0, first_example);
    std::unique_ptr<DenseMatrixContiguousView<float>> norm_view(
        new DenseMatrixContiguousView<float>(1, example_count, norm_slice_begin));

    std::unique_ptr<inputs::DnLrmPefsPartition> result(new inputs::DnLrmPefsPartition());
    result->pefs = std::move(pef_view);
    result->pef_frobenius_norms = std::move(norm_view);
    return result;
}

///////////////////////////////////////////////////////////////////////////////

// Splits `pefs` example-wise into `n_partitions` contiguous partitions and
// returns them in order (partition i at index i). Each partition is built by
// create_datawise_partition: every partition gets n_examples / n_partitions
// examples, and the last one also takes the remainder.
// A non-positive `n_partitions` yields an empty vector.
//
// NOTE(review): the partitions are built from raw pointers into `pefs`, so
// `pefs` presumably has to outlive the returned partitions; confirm.
std::vector<std::unique_ptr<inputs::DnLrmPefsPartition>>
construct_dn_lrm_pefs_datawise_partitions(const inputs::DnLrmPefs& pefs, int64_t n_partitions) {
    std::vector<std::unique_ptr<inputs::DnLrmPefsPartition>> partitions;
    if (n_partitions <= 0) {
        return partitions;
    }

    // The final size is known up front; allocate once instead of growing.
    partitions.reserve(static_cast<decltype(partitions)::size_type>(n_partitions));

    for (int64_t i = 0; i < n_partitions; ++i) {
        partitions.push_back(create_datawise_partition(pefs, n_partitions, i));
    }

    return partitions;
}



}  // preprocessing
}  // npeff
