#include <iostream>
#include "util/cuda_system.h"
#include "util/matrices.h"
#include "nmf/sparse/multi_mu_dense_factors_nmf1.h"
#include "util/sparse_util.h"
#include "io/pef_fisher_h5.h"

// Driver: loads a sparse Fisher matrix from an HDF5 file, normalizes it,
// strips its all-zero columns, and runs multiplicative-update NMF on the result.
//
// Usage: <binary> [path/to/fishers.h5]
// When no path is given, the developer-local default below is used, so the
// zero-argument behavior is unchanged.
int main(int argc, char* argv[]) {

    const std::string defaultFilepath = "/home/owner/Desktop/projects_data/extract_merge1/feather_berts_0.hans_lone.no_embeddings.5k.32k.h5";
    const std::string filepath = (argc > 1) ? std::string(argv[1]) : defaultFilepath;

    PefSparseFishers pef = PefSparseFishers::read(filepath);
    pef.normalizeFishers_inPlace();

    // Out-parameter filled by removeAllZeroColumns_inPlace. Initialized to
    // nullptr so that the delete[] below is a well-defined no-op if the callee
    // ever leaves it unset, instead of freeing an indeterminate pointer.
    // NOTE(review): presumably maps each remaining column index to its index
    // in the original matrix, and is allocated with new[] -- confirm against
    // util/sparse_util.h.
    int* newToOgIndex = nullptr;
    removeAllZeroColumns_inPlace(pef.fishers, &newToOgIndex);

    // TODO: See if speed changes depending on whether the columns of fishers are sorted.

    // NOTE: Just freeing up memory for local testing. The index map is not
    // used by anything after this point.
    delete[] newToOgIndex;
    newToOgIndex = nullptr;

    std::cout << "Starting NMF.\n";

    // Factorization settings. Other values tried during experiments:
    // rank 4 / 1024, eps 1e-6, max_iters 8 / 10000.
    long rank = 64;

    // NOTE(review): this literal does not fit in 32 bits, so it relies on
    // `long` being 64-bit on the target platform.
    long seed = 4319043202;
    float eps = 1e-9f;
    int max_iters = 1000;

    MuNmf nmf(pef.fishers, rank, max_iters, eps, seed);
    nmf.run();

    return 0;
}

// int main() {
//     long inputMatSeed = 98138924323;
//     // float density = 0.0001f;
//     float density = 0.01f;

//     // int n_rows = 8000;
//     // // int n_cols = 120000;
//     // int n_cols = 80000;

//     int n_rows = 800;
//     int n_cols = 1200;

//     // int n_rows = 8000;
//     // int n_cols = 12000;

//     ElCsrMatrix A = random_csr_matrix(n_rows, n_cols, density, inputMatSeed);

//     std::cout << "A.nnz = " << A.nnz << "\n";

//     std::cout << "Made matrix, starting NMF.\n";

//     long rank = 1024;
//     // long rank = 32;

//     long seed = 4319043202;
//     float eps = 1e-9;
//     // float eps = 1e-6;
//     // int max_iters = 10000;
//     int max_iters = 8;
//     // int max_iters = 1000;

//     MuNmf nmf(&A, rank, max_iters, eps, seed);
//     nmf.run();

//     // delete[] newToOgIndex;

//     return 0;
// }
