"""Performs NMF and saves results to disk."""
import os
import time

from absl import app
from absl import flags
from absl import logging

# NOTE: For some reason, the `import matplotlib.pyplot as plt` statement has
# to be placed here or right below the numpy import. Otherwise the
# `torch.sparse_coo_tensor` call causes a segfault, even though pyplot is not
# used anywhere in this file. Do not reorder or remove this import.
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import torch
from torchnmf.nmf import NMF as TorchNMF

from em.fishers import per_example
from em.models import divis_models
from em.models import transformer_model_vars as tmv
from em.tools.nmf import nmf_common
from em.util import flat_pack
from em.util import sparse_util
from em.util import vat_da_faak_vpn

# Command-line flags for this script. Most of them are forwarded unchanged to
# helpers in other modules; the comments below record where each one is used.
FLAGS = flags.FLAGS

# Must end in ".h5". `main` strips the extension and writes one file per
# subset, named `<base>.ssi<subset_index>.h5`.
flags.DEFINE_string("output_path", None, "Path to h5 file to write output to.")

# Passed (after `~` expansion) to `per_example.PerExampleFlatFishers.load`.
flags.DEFINE_string("per_example_fishers", None, "Path to file containing per-example Fishers.")

# Forwarded to `PerExampleFlatFishers.load` as `n_examples`.
flags.DEFINE_integer("n_examples", None, '')

# Forwarded to `PerExampleFlatFishers.load` as `start_fisher_index` and
# `end_fisher_index`; their exact semantics are defined by that loader.
flags.DEFINE_integer("start_fisher_index", None, '')
flags.DEFINE_integer("end_fisher_index", None, '')

# Forwarded as `threshold` to `sparse_util.remove_always_zero_indices`.
flags.DEFINE_integer("reduce_threshold", 1, "")

# Per-subset flags

# Note that the "per_res_block" style will treat non-res_block layers as "per_layer".
# `main` currently only accepts 'per_layer'.
flags.DEFINE_enum('subset_style', None, ['per_layer', 'per_res_block'], '')

# Passed to `divis_models.load_model_from_file`.
flags.DEFINE_string("model", None, "Path to h5 file containing model info and weights.")

# Comma-separated list; each entry is parsed with `int()` and used as a tensor
# index into the `FlatPacker` built from the model's homogenized variables.
flags.DEFINE_list("subset_indices", None, '')
# TODO: Have other ways of specifying what subsets to compute for.

# NMF flags:

# Passed as `rank` to the `TorchNMF` constructor.
flags.DEFINE_integer("nmf_n_components", None, '')

# Forwarded to `TorchNMF.fit` as `max_iter` and `tol`.
flags.DEFINE_integer("nmf_max_iter", 200, '')
flags.DEFINE_float("nmf_tol", 1e-6, '')

# Forwarded to `TorchNMF.fit` as `alpha`, `beta` and `l1_ratio`.
flags.DEFINE_float("nmf_alpha", 0.0, '')
flags.DEFINE_float("nmf_beta", 1.0, '')
flags.DEFINE_float("nmf_l1_ratio", 0.0, '')


def get_values_for_subset(pe_fishers_data, packer: flat_pack.FlatPacker, subset_index: int):
    """Extracts, per example, the sparse Fisher entries that belong to one subset.

    The subset is the flat index range `[start, end)` that `packer` reports
    for the tensor at `subset_index`.

    Args:
        pe_fishers_data: Object exposing parallel sequences `fishers` (values)
            and `fisher_indices` (flat indices), one entry per example.
        packer: Provides `get_range_for_tensor_by_index`.
        subset_index: Index of the tensor whose range is selected.

    Returns:
        A `(subset_values, subset_indices)` pair of lists with one entry per
        example. The indices are shifted so that they are relative to the
        start of the subset's range.
    """
    range_start, range_end = packer.get_range_for_tensor_by_index(subset_index)

    def _restrict_to_range(example_values, example_indices):
        # Keep only entries whose flat index lies in [range_start, range_end).
        in_range = (example_indices >= range_start) & (example_indices < range_end)
        return example_values[in_range], example_indices[in_range] - range_start

    per_example = [
        _restrict_to_range(example_values, example_indices)
        for example_values, example_indices in zip(pe_fishers_data.fishers, pe_fishers_data.fisher_indices)
    ]

    subset_values = [kept_values for kept_values, _ in per_example]
    subset_indices = [kept_indices for _, kept_indices in per_example]
    return subset_values, subset_indices


def get_reduced_sparse_fishers(pe_fishers_data, packer: flat_pack.FlatPacker, subset_index: int):
    """Builds the index-reduced sparse Fisher tensor for one subset on the GPU.

    The per-example entries of the subset are passed through
    `sparse_util.remove_always_zero_indices` (using `--reduce_threshold` as
    its `threshold`) and then assembled into a torch sparse COO tensor.

    Args:
        pe_fishers_data: Per-example Fishers, as accepted by
            `get_values_for_subset`.
        packer: Provides the flat index range of the subset.
        subset_index: Index of the tensor (subset) to extract.

    Returns:
        A `(reduction_info, torch_fishers)` pair. `reduction_info` is the
        first value returned by `remove_always_zero_indices`. `torch_fishers`
        is a CUDA sparse tensor created with shape
        `[len(reduced indices), reduction_info.reduced_size]` and then
        transposed in place.
    """
    per_example_values, per_example_indices = get_values_for_subset(pe_fishers_data, packer, subset_index)

    range_start, range_end = packer.get_range_for_tensor_by_index(subset_index)

    reduction_info, reduced = sparse_util.remove_always_zero_indices(
        per_example_values,
        per_example_indices,
        dense_size=range_end - range_start,
        threshold=FLAGS.reduce_threshold,
    )
    kept_values, kept_indices = reduced

    coo_indices = sparse_util.to_torch_coo_indices(kept_indices)
    coo_values = np.concatenate(list(kept_values), axis=-1)
    # One row per entry of the reduced index list, one column per kept index.
    dense_shape = [len(kept_indices), reduction_info.reduced_size]

    # TODO: Directly create the transpose, should be fairly straightforward.
    fishers = torch.sparse_coo_tensor(coo_indices, coo_values, dense_shape)
    fishers = fishers.cuda()

    # `t_` transposes in place and returns the same tensor.
    return reduction_info, fishers.t_()


def get_subset_indices():
    """Returns the subset indices requested via --subset_indices as ints.

    Returns:
        A list with one int per entry of the --subset_indices list flag.

    Raises:
        ValueError: If --subset_indices was not provided, or if an entry
            cannot be parsed by `int()`.
    """
    raw_indices = FLAGS.subset_indices
    # Raise explicitly instead of using `assert`: assertions are stripped
    # under `python -O`, which would turn a missing flag into an opaque
    # "'NoneType' object is not iterable" TypeError.
    if raw_indices is None:
        raise ValueError('--subset_indices must be provided.')
    return [int(s) for s in raw_indices]


def main(_):
    """Runs NMF on the per-example Fishers of each requested subset and saves the results.

    For every index returned by `get_subset_indices()`, the reduced sparse
    Fisher tensor of that subset is factorized with `TorchNMF` and the
    decomposition is written to `<output_base>.ssi<subset_index>.h5`, where
    `<output_base>` is --output_path without its `.h5` extension.

    Args:
        _: Unused positional argument supplied by `app.run`.

    Raises:
        NotImplementedError: If --subset_style is not 'per_layer', or if the
            loaded per-example Fishers are not sparse.
        ValueError: If --output_path is missing or does not end in '.h5'.
    """
    # Keep tensorflow from allocating all GPU memory to allow torchnmf to
    # use GPU.
    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    #######################################################
    # Validate the cheap-to-check flags before any expensive loading. Explicit
    # raises are used instead of `assert` so the checks survive `python -O`.

    if FLAGS.subset_style != 'per_layer':
        raise NotImplementedError('TODO: Support other subset styles')

    if FLAGS.output_path is None or not FLAGS.output_path.endswith('.h5'):
        raise ValueError(f"--output_path must end with '.h5', got {FLAGS.output_path!r}")
    # Remove the file extension.
    output_base_path = os.path.expanduser(FLAGS.output_path[:-3])

    # Parse the subset indices up front so a missing or malformed flag fails
    # immediately rather than after the model and Fishers have been loaded.
    indices_of_subsets = get_subset_indices()

    #######################################################

    # Only the model is needed here; the accompanying config is unused.
    model, _model_config = divis_models.load_model_from_file(FLAGS.model)

    print('Starting to load saved per-example Fishers.')
    start = time.time()
    pe_fishers_data = per_example.PerExampleFlatFishers.load(
        os.path.expanduser(FLAGS.per_example_fishers),
        n_examples=FLAGS.n_examples,
        start_fisher_index=FLAGS.start_fisher_index,
        end_fisher_index=FLAGS.end_fisher_index,
    )
    print('Load saved per-example Fishers time: ', time.time() - start)

    #######################################################

    variables = model.trainable_variables

    # The packer maps a tensor index to its range in the flattened parameter
    # vector, based on the shapes of the homogenized variables.
    homogenized_variables = tmv.homogenize_kernel_biases(variables)
    packer = flat_pack.FlatPacker([v.shape for v in homogenized_variables])

    if not pe_fishers_data.is_sparse():
        raise NotImplementedError('TODO: Handle non-sparse fishers')

    #######################################################

    for subset_index in indices_of_subsets:
        print(f'Starting to run NMF for subset with index {subset_index}')

        reduction_info, torch_fishers = get_reduced_sparse_fishers(pe_fishers_data, packer, subset_index)

        print('Starting NMF decomposition.')
        start = time.time()
        nmf_model = TorchNMF(torch_fishers.shape, rank=FLAGS.nmf_n_components).cuda()
        nmf_model.fit(
            torch_fishers,
            verbose=True,
            max_iter=FLAGS.nmf_max_iter,
            tol=FLAGS.nmf_tol,
            alpha=FLAGS.nmf_alpha,
            beta=FLAGS.nmf_beta,
            l1_ratio=FLAGS.nmf_l1_ratio,
        )
        print('NMF time: ', time.time() - start)

        # Copy the learned factors back to host memory as numpy arrays.
        W = nmf_model.W.detach().cpu().numpy()
        Ht = nmf_model.H.detach().cpu().numpy()

        # The model's H factor is transposed before being stored (hence the
        # local name `Ht`). The reduction info is saved alongside so the
        # reduced indices can be related back to the full dense size.
        decomp = nmf_common.NmfDecomposition(
            W=W,
            H=Ht.T,
            reduce_kept_indices=reduction_info.kept_original_indices,
            full_dense_size=reduction_info.original_size,
        )

        print('Starting to save NMF decomposition.')
        start = time.time()
        decomp.save(f'{output_base_path}.ssi{subset_index}.h5')
        print('Save NMF decomposition time: ', time.time() - start)

        # Remove references to these out of paranoia that their memory
        # might not get cleared otherwise.
        del torch_fishers
        del nmf_model

        # Clear GPU memory so hopefully we can run again without OOMing.
        torch.cuda.empty_cache()


# Script entry point: absl's `app.run` parses the command-line flags defined
# in this module and then calls `main`.
if __name__ == "__main__":
    app.run(main)
