"""Performs NMF and saves results to disk."""
import os
import time

from absl import app
from absl import flags
from absl import logging

# NOTE: The `import matplotlib.pyplot as plt` statement must be placed here or
# right below the numpy import. Otherwise the `torch.sparse_coo_tensor` call
# causes a segfault, even though pyplot is never used in this file. The root
# cause is unknown.
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import torch
from torchnmf.nmf import NMF as TorchNMF
from transformers import TFAutoModelForSequenceClassification

from em.fishers import per_example
from em.models import transformer_model_vars as tmv
from em.tools.nmf import nmf_common
from em.util import flat_pack
from em.util import hf_util
from em.util import sparse_util
from em.util import vat_da_faak_vpn

FLAGS = flags.FLAGS

# A block means what is commonly referred to as a transformer layer. A sub-block
# consists of either the attention layers or the FFW layers within a single block.
# The embeddings and the pooling layer are each their own blocks and sub-blocks.
_SUBSET_STYLES = ['per_block', 'per_sub_block']

# `main` asserts that this ends with '.h5'. One file is written per subset,
# named `<output_path without .h5>.ssi<subset_index>.h5`.
flags.DEFINE_string("output_path", None, "Path to h5 file to write output to.")

flags.DEFINE_string("per_example_fishers", None, "Path to file containing per-example Fishers.")

# Forwarded unchanged to `per_example.PerExampleFlatFishers.load`.
flags.DEFINE_integer("n_examples", None, '')

# Both are forwarded unchanged to `per_example.PerExampleFlatFishers.load`.
flags.DEFINE_integer("start_fisher_index", None, '')
flags.DEFINE_integer("end_fisher_index", None, '')

# Passed as `threshold` to `sparse_util.remove_always_zero_indices`.
flags.DEFINE_integer("reduce_threshold", 1, "")

# Selects between `tmv.group_by_blocks` and `tmv.group_by_sub_blocks` when
# grouping the model variables into subsets.
flags.DEFINE_enum('subset_style', None, _SUBSET_STYLES, '')
# TODO: Have other ways of specifying what subsets to compute for.
# Each entry is parsed with `int` and used as a subset index into the packer.
flags.DEFINE_list("subset_indices", None, '')

# Both are passed to `TFAutoModelForSequenceClassification.from_pretrained`.
flags.DEFINE_string("model", None, "Path to HuggingFace model file (or a model with the same variables).")
flags.DEFINE_bool("from_pt", True, "")

# This should only be used if the Fishers were computed with the sparse_dynamic_metric_derived
# style of sparsity.
# When set, the Fishers are loaded with `normalize_fishers=False`, divided by
# `dense_metric_derived_norms` in `main`, and scaled by `sq_parameter_deltas`
# when a subset is extracted.
flags.DEFINE_bool("use_metric_derived_distances", False, "")


nmf_common.add_nmf_flags()

# Prefix under which the variable-filter flags are registered; `get_packer`
# uses the same prefix to read them back.
PEF_TMV_PREFIX = 'pef'
tmv.add_variable_filter_flags(PEF_TMV_PREFIX)


def get_values_for_subset(pe_fishers_data, packer: flat_pack.FlatPacker, subset_index: int):
    """Pulls out each example's Fisher entries that belong to one subset.

    The subset's half-open flat range `[start, end)` comes from
    `packer.get_range_for_tensor_by_index(subset_index)`. For every example,
    only the entries whose flat index lies inside that range are kept.

    When `--use_metric_derived_distances` is set, each kept value is multiplied
    by the entry of `pe_fishers_data.sq_parameter_deltas` at its flat index.

    Args:
        pe_fishers_data: Object exposing parallel per-example sequences
            `fishers` and `fisher_indices`, plus `sq_parameter_deltas`.
        packer: Packer that maps a subset index to its flat index range.
        subset_index: Index of the subset to extract.

    Returns:
        A `(values, indices)` pair of lists with one entry per example. The
        indices are relative to the start of the subset's range.
    """
    range_start, range_end = packer.get_range_for_tensor_by_index(subset_index)

    kept_values_per_example = []
    kept_indices_per_example = []

    example_pairs = zip(pe_fishers_data.fishers, pe_fishers_data.fisher_indices)
    for example_values, example_indices in example_pairs:
        in_range = (range_start <= example_indices) & (example_indices < range_end)
        flat_indices = example_indices[in_range]
        kept_values = example_values[in_range]

        if FLAGS.use_metric_derived_distances:
            kept_values = kept_values * pe_fishers_data.sq_parameter_deltas[flat_indices]

        kept_values_per_example.append(kept_values)
        # Rebase so that indices start at zero within the subset.
        kept_indices_per_example.append(flat_indices - range_start)

    return kept_values_per_example, kept_indices_per_example


def get_reduced_sparse_fishers(pe_fishers_data, packer: flat_pack.FlatPacker, subset_index: int):
    """Builds the sparse CUDA matrix of reduced Fishers for one subset.

    The subset's per-example values are extracted, passed through
    `sparse_util.remove_always_zero_indices` (with `--reduce_threshold` as the
    threshold), assembled into a torch sparse COO tensor with one row per
    example, moved to the GPU and finally transposed in place.

    Args:
        pe_fishers_data: Per-example Fisher data to read from.
        packer: Packer that maps a subset index to its flat index range.
        subset_index: Index of the subset to process.

    Returns:
        A `(reduction_info, torch_fishers)` pair, where `torch_fishers` is the
        transposed sparse tensor living on the GPU.
    """
    values, indices = get_values_for_subset(pe_fishers_data, packer, subset_index)

    range_start, range_end = packer.get_range_for_tensor_by_index(subset_index)
    subset_size = range_end - range_start

    reduction_info, reduced = sparse_util.remove_always_zero_indices(
        values,
        indices,
        dense_size=subset_size,
        threshold=FLAGS.reduce_threshold,
    )
    reduced_values, reduced_indices = reduced

    coo_indices = sparse_util.to_torch_coo_indices(reduced_indices)
    coo_values = np.concatenate(list(reduced_values), axis=-1)
    n_examples = len(reduced_indices)
    dense_shape = [n_examples, reduction_info.reduced_size]

    # TODO: Directly create the transpose, should be fairly straightforward.
    sparse_fishers = torch.sparse_coo_tensor(coo_indices, coo_values, dense_shape)
    sparse_fishers = sparse_fishers.cuda()

    return reduction_info, sparse_fishers.t_()


def get_subset_indices():
    """Returns the entries of `--subset_indices` converted to ints.

    The flag must have been provided; this is checked with an assertion.
    """
    raw_indices = FLAGS.subset_indices
    assert raw_indices is not None
    return list(map(int, raw_indices))


def group_variables(variables):
    """Groups `variables` into subsets according to `--subset_style`.

    Args:
        variables: The model variables to group.

    Returns:
        The result of `tmv.group_by_blocks` for 'per_block', or of
        `tmv.group_by_sub_blocks` for 'per_sub_block'.

    Raises:
        ValueError: If `--subset_style` is neither of the two known styles.
    """
    subset_style = FLAGS.subset_style

    if subset_style == 'per_block':
        grouper = tmv.group_by_blocks
    elif subset_style == 'per_sub_block':
        grouper = tmv.group_by_sub_blocks
    else:
        raise ValueError(f'Invalid subset style: {subset_style}')

    return grouper(variables)


def get_packer():
    """Creates a `FlatPacker` with one single-dimension entry per variable subset.

    Loads the model named by `--model`, takes its mergeable variables, applies
    the variable filter configured through the `PEF_TMV_PREFIX` flags, and
    groups what remains into subsets via `group_variables`. Each subset
    contributes one shape `[total number of parameters in the subset]`.

    Returns:
        A `flat_pack.FlatPacker` built from those per-subset shapes.
    """
    model_path = os.path.expanduser(FLAGS.model)
    model = TFAutoModelForSequenceClassification.from_pretrained(
        model_path,
        from_pt=FLAGS.from_pt,
    )

    mergeable_variables = hf_util.get_mergeable_variables(model)
    variable_filter = tmv.get_variable_filter_from_flags(PEF_TMV_PREFIX)
    kept_variables = variable_filter.filter_parallel_lists(mergeable_variables)

    subsets = group_variables(kept_variables)
    print('Number of subsets:', len(subsets))

    # We don't care about the actual shapes, just their number of parameters.
    shapes = []
    for subset in subsets:
        n_parameters = sum(tf.size(v) for v in subset)
        shapes.append([n_parameters])

    return flat_pack.FlatPacker(shapes)


def main(_):
    """Runs NMF on each requested subset and saves one decomposition per subset.

    Steps:
      1. Enable memory growth on every GPU TensorFlow sees.
      2. Load the per-example Fishers named by `--per_example_fishers`.
      3. Build the packer that maps subset indices to flat index ranges.
      4. For every index in `--subset_indices`, build the reduced sparse
         Fishers, fit a `TorchNMF` model on the GPU, and save the result to
         `<output_path without .h5>.ssi<subset_index>.h5`.

    Args:
        _: Unused positional argument supplied by `app.run`.
    """
    # Keep tensorflow from allocating all GPU memory to allow torchnmf to
    # use GPU.
    for device in tf.config.experimental.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(device, True)

    # The per-subset file names are derived from the path with its '.h5'
    # extension removed.
    assert FLAGS.output_path.endswith('.h5')
    path_without_extension = FLAGS.output_path[:-len('.h5')]
    output_base_path = os.path.expanduser(path_without_extension)

    print('Starting to load saved per-example Fishers.')
    load_started = time.time()
    fishers_path = os.path.expanduser(FLAGS.per_example_fishers)
    pe_fishers_data = per_example.PerExampleFlatFishers.load(
        fishers_path,
        n_examples=FLAGS.n_examples,
        start_fisher_index=FLAGS.start_fisher_index,
        end_fisher_index=FLAGS.end_fisher_index,
        normalize_fishers=not FLAGS.use_metric_derived_distances,
    )

    # Normalize differently in this case.
    if FLAGS.use_metric_derived_distances:
        norms = pe_fishers_data.dense_metric_derived_norms[:, None]
        pe_fishers_data.fishers /= (norms + 1e-12)

    print('Load saved per-example Fishers time: ', time.time() - load_started)

    packer = get_packer()

    for subset_index in get_subset_indices():
        print(f'Starting to run NMF for subset with index {subset_index}')

        reduction_info, torch_fishers = get_reduced_sparse_fishers(
            pe_fishers_data, packer, subset_index)

        print('Starting NMF decomposition.')
        fit_started = time.time()
        init_kwargs = nmf_common.get_nmf_init_kwargs_from_flags()
        nmf_model = TorchNMF(torch_fishers.shape, **init_kwargs).cuda()
        fit_kwargs = nmf_common.get_nmf_fit_kwargs_from_flags()
        nmf_model.fit(torch_fishers, verbose=True, **fit_kwargs)
        print('NMF time: ', time.time() - fit_started)

        # Copy both factors off the GPU as numpy arrays.
        w_factor = nmf_model.W.detach().cpu().numpy()
        h_factor_transposed = nmf_model.H.detach().cpu().numpy()

        decomp = nmf_common.NmfDecomposition(
            W=w_factor,
            H=h_factor_transposed.T,
            reduce_kept_indices=reduction_info.kept_original_indices,
            full_dense_size=reduction_info.original_size,
        )

        print('Starting to save NMF decomposition.')
        save_started = time.time()
        subset_output_path = f'{output_base_path}.ssi{subset_index}.h5'
        decomp.save(subset_output_path)
        print('Save NMF decomposition time: ', time.time() - save_started)

        # Remove references to these out of paranoia that their memory
        # might not get cleared otherwise.
        del torch_fishers, nmf_model

        # Clear GPU memory so hopefully we can run again without OOMing.
        torch.cuda.empty_cache()


# Script entry point: hand `main` to absl's `app.run`.
if __name__ == "__main__":
    app.run(main)
