/**
 * Tests whether removeSmallestL0NormColumns_inPlace (and related column-removal
 * functions) behave correctly.
 *
 * Build and run (paths are relative to the repository root):

nvcc tests/sparse/remove_sparsest_columns_test.cu -I./src -I/usr/local/cuda/include -I/usr/lib/x86_64-linux-gnu/hdf5/serial/include -L/usr/lib/x86_64-linux-gnu/hdf5/serial/lib -L/usr/local/cuda/lib64 -lnccl -lcublas -lcurand -lcusparse -lhdf5_cpp -lhdf5 -o build/remove_sparsest_columns_test; \
./build/remove_sparsest_columns_test


 */

#include <cstdint>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>

#include "util/cuda_system.h"
#include "util/matrices.h"
#include "util/sparse_util.h"


/**
 * Test assertion: if `x` evaluates to false, prints `msg` plus the source
 * location to stderr and throws std::runtime_error describing the condition.
 *
 * - Wrapped in do { } while (0) so `ASSERT(...);` is a single statement and
 *   works inside unbraced if/else.
 * - Throws a real exception rather than a bare `throw;`: with no active
 *   exception, `throw;` calls std::terminate, and inside a catch handler it
 *   would rethrow an unrelated exception.
 * - Writes to std::cerr (unit-buffered) so the diagnostic is not lost if the
 *   exception goes uncaught and the process aborts.
 */
#define ASSERT(x, msg) \
    do { \
        if (!(x)) { \
            std::cerr << "Assertion Failed: " << msg << "\n"; \
            std::cerr << "    (line " << __LINE__ << " of file " << __FILE__ << ")\n"; \
            throw std::runtime_error("Assertion failed: " #x); \
        } \
    } while (0)


/**
 * Parameters for MinL0NormTest.
 *
 * Every field has a zero default so a default-constructed instance never
 * carries indeterminate values, even if a caller forgets to set a field.
 */
struct MinL0NormTestParams {
    long n_rows = 0;       // row count passed to random_csr_matrix
    long n_cols = 0;       // column count passed to random_csr_matrix
    long seed = 0;         // seeds MinL0NormTest's seed generator
    float density = 0.0f;  // density argument passed to random_csr_matrix
    int minL0Norm = 0;     // NOTE(review): not read by the visible test code
};


template <typename IndT>
class MinL0NormTest {
public:
    MinL0NormTestParams p;

    std::default_random_engine seedGenerator;
    std::uniform_int_distribution<long> seedDistribution;

    MinL0NormTest(MinL0NormTestParams p) {
        this->p = p;

        this->seedGenerator = std::default_random_engine(p.seed);
        this->seedDistribution = std::uniform_int_distribution<long>(std::numeric_limits<long>::min(), std::numeric_limits<long>::max());
    }
 

    void run_compareToRemoveAllZeroColumns() {
        long seed = getSeed();

        // TODO: These aren't the same.
        ElCsrMatrix<IndT> A1 = random_csr_matrix<IndT>(p.n_rows, p.n_cols, p.density, seed);
        ElCsrMatrix<IndT> A2 = A1.clone();
        ASSERT(ElCsrMatrix<IndT>::areEqual(A1, A2), "Error in assumptions, test is badly coded.");

        IndT *newToOgIndex1, *newToOgIndex2;

        removeAllZeroColumns_inPlace(&A1, &newToOgIndex1);
        removeSmallestL0NormColumns_inPlace(&A2, &newToOgIndex2, 1);

        ASSERT(ElCsrMatrix<IndT>::areEqual(A1, A2), "Matrices are different.");

        ASSERT(areArraysEqual(newToOgIndex1, newToOgIndex2, A1.n_cols), "The newToOgIndex arrays are different.");

        std::cout << "Test passed.\n";
    }

private:
    long getSeed() {
        return seedDistribution(seedGenerator);
    }
};

/**
 * Entry point: runs MinL0NormTest<int32_t>::run_compareToRemoveAllZeroColumns
 * on a 100 x 1000 random CSR matrix.
 *
 * Returns 0 on success and 1 if the test throws an exception.
 */
int main(int argc, char *argv[]) {
    (void) argc;
    (void) argv;

    // Value-initialize so no field (e.g. minL0Norm) is left indeterminate.
    MinL0NormTestParams p{};
    p.n_rows = 100;
    p.n_cols = 1000;
    p.seed = 42069;
    // 20 / (n_rows * n_cols): roughly 20 nonzeros in total, assuming density
    // is the fraction of nonzero entries (confirm in random_csr_matrix).
    p.density = 20.0 / ((double) p.n_rows * (double) p.n_cols);
    // Mirrors the threshold hard-coded in run_compareToRemoveAllZeroColumns.
    p.minL0Norm = 1;

    try {
        MinL0NormTest<int32_t> test(p);
        test.run_compareToRemoveAllZeroColumns();
    } catch (const std::exception &e) {
        std::cerr << "Test failed: " << e.what() << "\n";
        return 1;
    } catch (...) {
        std::cerr << "Test failed with an unknown exception.\n";
        return 1;
    }

    return 0;
}
