#pragma once

#include <cstdint>
#include <memory>

#include <containers/dense_matrix.h>
#include <containers/dense_3_tensor.h>


namespace npeff {
namespace inputs {

// The dense LRM-PEF inputs held in host memory.
//
// Unlike DnLrmPefsPartition (which only holds views), this struct holds the
// PEF tensor and the norms matrix through owning std::unique_ptrs.
struct DnLrmPefs {
    // pefs.shape = [n_examples, n_classes, n_parameters]
    std::unique_ptr<npeff::Dense3Tensor<float>> pefs = nullptr;

    // pef_frobenius_norms.shape = [1, n_examples]
    std::unique_ptr<npeff::DenseMatrix<float>> pef_frobenius_norms = nullptr;

    // Dimension accessors. Each reads one extent of `pefs` directly, so
    // `pefs` must be non-null before any of these is called (no check is made).
    int64_t n_examples() const { return pefs->d1; }
    int64_t n_classes() const { return pefs->d2; }
    int64_t n_parameters() const { return pefs->d3; }

};


// One partition of the dense LRM-PEFs, living in host memory.
//
// NOTE: The unique_ptrs below own only the *view* objects. The pefs/norms
// array memory those views refer to is NOT owned by this struct, so it has to
// stay alive for as long as this partition is being used.
struct DnLrmPefsPartition {
    // View over the PEFs of this partition.
    // Shape: [n_partition_examples, n_classes, n_parameters].
    std::unique_ptr<npeff::Dense3TensorContiguousView<float>> pefs;

    // View over the PEF Frobenius norms of this partition.
    // Shape: [1, n_partition_examples].
    std::unique_ptr<npeff::DenseMatrixContiguousView<float>> pef_frobenius_norms;

    // Extent accessors; all of them read from `pefs`, which must be non-null.
    int64_t n_partition_examples() const {
        return pefs->d1;
    }

    int64_t n_classes() const {
        return pefs->d2;
    }

    int64_t n_parameters() const {
        return pefs->d3;
    }

    // Presents the pefs as a matrix of shape
    // [n_parameters, n_partition_examples * n_classes].
    npeff::DenseMatrixContiguousView<float> pefs_matrix() const;
};


}  // inputs
}  // npeff
