/**
 * Test to check that the loss computed during the NMF computation is
 * correct, using a reference value computed from the dense matrices.
 *
 * Build and run:

nvcc tests/sparse/nmf_loss_computation_test.cu -I./src -I/usr/local/cuda/include -I/usr/lib/x86_64-linux-gnu/hdf5/serial/include -L/usr/lib/x86_64-linux-gnu/hdf5/serial/lib -L/usr/local/cuda/lib64 -lnccl -lcublas -lcurand -lcusparse -lhdf5_cpp -lhdf5 -o build/nmf_loss_computation_test; \
./build/nmf_loss_computation_test

 */

#include <cmath>
#include <cstdint>
#include <exception>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>

#include "util/cuda_system.h"
#include "util/matrices.h"
#include "nmf/sparse/multi_mu_dense_factors_nmf1.h"
#include "util/sparse_util.h"

/**
 * ASSERT(x, msg): checks that the condition `x` holds.
 *
 * If `x` is false, prints `msg` followed by the line and file of the failing
 * assertion to std::cout, flushes the stream, and throws std::runtime_error
 * whose message contains the stringized condition.
 *
 * `msg` is spliced directly into a `std::cout << ...` chain, so it may itself
 * be a chain such as `"got " << value`.
 *
 * The body is wrapped in do { ... } while (0) so the macro expands to a single
 * statement and can be used safely in an unbraced if/else.
 *
 * A real exception is thrown instead of a bare `throw;`: with no exception in
 * flight, a bare `throw;` calls std::terminate and can never be caught.
 */
#define ASSERT(x, msg) \
    do { \
        if (!(x)) { \
            std::cout << "Assertion Failed: " << msg << "\n"; \
            std::cout << "    (line " << __LINE__ << " of file " << __FILE__ << ")\n" << std::flush; \
            throw std::runtime_error("Assertion failed: " #x); \
        } \
    } while (0)

// Density used for the random sparse matrix when the caller does not set one.
// #define DEFAULT_DENSITY 0.001
#define DEFAULT_DENSITY 0.01

/**
 * Configuration for one loss test run.
 *
 * Every field has a default so that a default-constructed instance never
 * holds indeterminate values; callers are still expected to set the sizes,
 * rank and seed explicitly.
 */
struct LossTestParams {
    long n_rows = 0;                  // passed as the first size argument to random_csr_matrix
    long n_cols = 0;                  // passed as the second size argument to random_csr_matrix
    long rank = 0;                    // passed as the rank argument to the MuNmf constructor
    long seed = 0;                    // seeds the engine that generates the per-component seeds
    float density = DEFAULT_DENSITY;  // passed as the density argument to random_csr_matrix
};


/**
 * Compares the loss reported by MuNmf with a reference value computed from
 * the dense form of the same matrices.
 *
 * IndT is the index type forwarded to ElCsrMatrix, random_csr_matrix and
 * MuNmf.
 */
template <typename IndT>
class LossTest {
public:
    LossTestParams p;

    // Source of the seeds handed to the matrix generator and to the NMF
    // object. For a given standard library implementation the sequence is
    // fully determined by p.seed.
    std::default_random_engine seedGenerator;
    std::uniform_int_distribution<long> seedDistribution;

    // These values don't matter for this test.
    float eps = 1e-12f;
    float maxOutputMag = 1e2f;
    long maxIters = 100;

    // Relative tolerance used when comparing the computed loss with the
    // reference loss. Both are single-precision results obtained by different
    // code paths, so bit-exact equality is not expected.
    // NOTE(review): 1e-3 is a guess; tighten or loosen after observing real runs.
    float lossRelTol = 1e-3f;

    /**
     * Stores the parameters and seeds the seed generator from p.seed.
     * The seed distribution covers the full range of `long`.
     */
    LossTest(LossTestParams p)
        : p(p),
          seedGenerator(p.seed),
          seedDistribution(std::numeric_limits<long>::min(), std::numeric_limits<long>::max()) {}

    /**
     * Runs the test once.
     *
     * Builds a random CSR matrix A, constructs and initializes MuNmf on it,
     * reads the loss it reports, then recomputes the loss as
     * MeMatrix::subtract(denseA, W * H).frobeniusNorm() from the factors
     * loaded back from the NMF object. Both values are printed.
     *
     * Throws std::runtime_error if the two values are not finite or differ
     * by more than lossRelTol (relative to the larger magnitude).
     *
     * NOTE(review): this assumes computeUncachedLoss() reports the same
     * quantity as the reference (the Frobenius norm of A - WH, not e.g. its
     * square) — confirm against the MuNmf implementation.
     */
    void run() {
        ElCsrMatrix<IndT> A = random_csr_matrix<IndT>(p.n_rows, p.n_cols, p.density, getSeed());
        MeMatrix denseA = A.asDenseMatrix();
        MuNmf<IndT> nmf(&A, p.rank, maxIters, eps, getSeed(), maxOutputMag);
        // This initializes everything on the GPU(s).
        nmf._initializeBeforeRun();

        const float computedLoss = nmf.computeUncachedLoss();

        // Reference value, computed from the factors loaded back to the host.
        MeMatrix W = nmf.loadWToHostSync();
        MeMatrix H = nmf.loadHToHostSync();
        MeMatrix WH = MeMatrix::multiply(W, H);
        const float expectedLoss = MeMatrix::subtract(denseA, WH).frobeniusNorm();

        std::cout << "Expected loss: " << expectedLoss << "\n";
        std::cout << "Computed loss: " << computedLoss << "\n\n";

        if (!lossesAgree(expectedLoss, computedLoss)) {
            std::cout << "Assertion Failed: computed loss does not match expected loss"
                      << " (relative tolerance " << lossRelTol << ")\n"
                      << std::flush;
            throw std::runtime_error("computed loss does not match expected loss");
        }
    }

private:
    /** Draws the next seed from the seed generator. */
    long getSeed() {
        return seedDistribution(seedGenerator);
    }

    /**
     * Returns true when both values are finite and
     * |expected - computed| <= lossRelTol * max(|expected|, |computed|).
     * Two exact zeros therefore agree; any NaN or infinity does not.
     */
    bool lossesAgree(float expected, float computed) const {
        if (!std::isfinite(expected) || !std::isfinite(computed)) {
            return false;
        }
        const float absExpected = std::fabs(expected);
        const float absComputed = std::fabs(computed);
        const float scale = (absExpected > absComputed) ? absExpected : absComputed;
        return std::fabs(expected - computed) <= lossRelTol * scale;
    }
};

/**
 * Entry point: runs the loss test once with a fixed configuration.
 *
 * Returns 0 when the run completes and 1 when it ends with a std::exception.
 * Catching here (rather than letting the exception escape main) guarantees
 * that the stack is unwound, so the objects created inside the test are
 * destroyed before the process exits.
 */
int main(int argc, char *argv[]) {
    // Command-line arguments are not used.
    (void)argc;
    (void)argv;

    LossTestParams p;
    p.n_rows = 1000;
    p.n_cols = 10000;
    p.rank = 32;
    p.seed = 42069;

    try {
        LossTest<int32_t> test(p);
        test.run();
    } catch (const std::exception &e) {
        std::cerr << "Test failed: " << e.what() << "\n";
        return 1;
    }

    return 0;
}
