"""Performs NMF and saves results to disk."""
import os
import time

from absl import app
from absl import flags
from absl import logging

# NOTE: The `import matplotlib.pyplot as plt` statement must stay here, or
# directly below the numpy import. Otherwise the `torch.sparse_coo_tensor`
# call causes a segfault, even though pyplot is never used in this file. The
# cause is unknown.
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import torch
from torchnmf.nmf import NMF as TorchNMF

from em.fishers import per_example
from em.tools.nmf import nmf_common
from em.util import vat_da_faak_vpn

FLAGS = flags.FLAGS

# Input/output flags. `output_path` and `per_example_fishers` default to None
# and are both passed through `os.path.expanduser` in `main`.

flags.DEFINE_string("output_path", None, "Path to h5 file to write output to.")

flags.DEFINE_string("per_example_fishers", None, "Path to file containing per-example Fishers.")

flags.DEFINE_integer(
    "n_examples", None,
    "Passed as `n_examples` to PerExampleFlatFishers.load when reading the "
    "per-example Fishers file.")

flags.DEFINE_integer(
    "start_fisher_index", None,
    "Passed as `start_fisher_index` to PerExampleFlatFishers.load when "
    "reading the per-example Fishers file.")
flags.DEFINE_integer(
    "end_fisher_index", None,
    "Passed as `end_fisher_index` to PerExampleFlatFishers.load when reading "
    "the per-example Fishers file.")

flags.DEFINE_integer(
    "reduce_threshold", 0,
    "If nonzero, sparse Fishers are reduced with this threshold before NMF. "
    "0 disables reduction. Must be 0 when the Fishers are dense.")


# NMF flags:

flags.DEFINE_integer(
    "nmf_n_components", None,
    "Rank of the NMF factorization (passed as `rank` to torchnmf's NMF).")

flags.DEFINE_integer(
    "nmf_max_iter", 200,
    "Maximum number of NMF iterations (passed as `max_iter` to fit).")
flags.DEFINE_float(
    "nmf_tol", 1e-6,
    "Tolerance of the NMF stopping condition (passed as `tol` to fit).")

flags.DEFINE_float(
    "nmf_alpha", 0.0,
    "Constant multiplying the NMF regularization terms (passed as `alpha` to "
    "fit).")
flags.DEFINE_float(
    "nmf_beta", 1.0,
    "Beta-divergence parameter minimized by NMF (passed as `beta` to fit).")
flags.DEFINE_float(
    "nmf_l1_ratio", 0.0,
    "NMF regularization mixing parameter (passed as `l1_ratio` to fit).")


# TODO: Add option to do this per-layer or create a new file that does this.


def main(_):
    """Runs NMF on saved per-example Fishers and saves the decomposition.

    Loads the per-example Fishers named by `--per_example_fishers`, moves them
    (transposed) onto the GPU as a torch tensor, fits a torchnmf NMF model to
    them, and writes the resulting `NmfDecomposition` to `--output_path`.

    Args:
        _: Unused positional command-line arguments supplied by `app.run`.

    Raises:
        ValueError: If `--per_example_fishers` or `--output_path` is not set,
            or if `--reduce_threshold` is nonzero while the loaded Fishers are
            dense.
    """
    # Fail fast on missing paths. Without this check a missing --output_path
    # only surfaces as a TypeError from `os.path.expanduser(None)` after the
    # whole NMF fit has already run, throwing the result away.
    required_paths = {
        'per_example_fishers': FLAGS.per_example_fishers,
        'output_path': FLAGS.output_path,
    }
    missing = [name for name, value in required_paths.items() if value is None]
    if missing:
        raise ValueError(
            'Missing required flag(s): '
            + ', '.join('--' + name for name in missing))

    # Keep tensorflow from allocating all GPU memory to allow torchnmf to
    # use GPU.
    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    #######################################################

    print('Starting to load saved per-example Fishers.')
    start = time.time()
    pe_fishers_data = per_example.PerExampleFlatFishers.load(
        os.path.expanduser(FLAGS.per_example_fishers),
        n_examples=FLAGS.n_examples,
        start_fisher_index=FLAGS.start_fisher_index,
        end_fisher_index=FLAGS.end_fisher_index,
    )
    print('Load saved per-example Fishers time: ', time.time() - start)

    # Only set when the sparse Fishers are reduced; stored with the
    # decomposition so it is saved alongside W and H.
    reduce_kept_indices = None

    if pe_fishers_data.is_sparse():
        if FLAGS.reduce_threshold == 0:
            indices, values, dense_shape = pe_fishers_data.to_sparse_fishers()
        else:
            # TODO: Find way to specify dense vs sparse reduced Fishers.
            reduce_kept_indices, (indices, values, dense_shape) = (
                pe_fishers_data.to_sparse_reduced_fishers(
                    FLAGS.reduce_threshold))

        # TODO: Directly create the transpose, should be fairly straightforward.
        torch_fishers = torch.sparse_coo_tensor(indices, values, dense_shape).cuda()
        # In-place transpose of the sparse tensor.
        torch_fishers = torch_fishers.t_()
    else:
        if FLAGS.reduce_threshold != 0:
            raise ValueError('Reduced fishers not (yet?) supported for dense fishers.')
        torch_fishers = torch.from_numpy(pe_fishers_data.fishers.T).cuda()

    print('Starting NMF decomposition.')
    start = time.time()
    nmf_model = TorchNMF(torch_fishers.shape, rank=FLAGS.nmf_n_components).cuda()
    nmf_model.fit(
        torch_fishers,
        verbose=True,
        max_iter=FLAGS.nmf_max_iter,
        tol=FLAGS.nmf_tol,
        alpha=FLAGS.nmf_alpha,
        beta=FLAGS.nmf_beta,
        l1_ratio=FLAGS.nmf_l1_ratio,
    )
    print('NMF time: ', time.time() - start)

    # Pull the learned factors back to host memory as numpy arrays.
    W = nmf_model.W.detach().cpu().numpy()
    Ht = nmf_model.H.detach().cpu().numpy()

    # The model was fit on the transposed Fishers, so the model's H is
    # transposed before being stored as the decomposition's H.
    # NOTE(review): presumably this yields fishers ~= W @ H in the original
    # (untransposed) orientation -- confirm against torchnmf's convention.
    decomp = nmf_common.NmfDecomposition(
        W=W,
        H=Ht.T,
        reduce_kept_indices=reduce_kept_indices,
        full_dense_size=pe_fishers_data.fisher_dense_size,
    )

    print('Starting to save NMF decomposition.')
    start = time.time()
    decomp.save(os.path.expanduser(FLAGS.output_path))
    print('Save NMF decomposition time: ', time.time() - start)


# Script entry point: absl's `app.run` parses the command-line flags and then
# invokes `main`.
if __name__ == "__main__":
    app.run(main)
