"""Computes the coefficients given a pre-computed NMF."""
import dataclasses
import os
import time
from typing import Optional

from absl import app
from absl import flags
from absl import logging

import numpy as np
from transformers import TFAutoModelForSequenceClassification
import tensorflow as tf

from em.fishers import per_example
from em.models import transformer_model_vars as tmv
from em.tools.nmf import nmf_common
from em.tools.nmf import nmf_transform
from em.util import flat_pack
from em.util import hf_util
from em.util import vat_da_faak_vpn


# Allowed values for the --subset_style flag. The empty string means that no
# per-subset processing is requested.
#
# Terminology: a "block" is what is commonly referred to as a transformer
# layer. A "sub-block" consists of either the attention layers or the FFW
# layers within a single block. The embeddings and the pooling layer are each
# their own block and their own sub-block.
_SUBSET_STYLES = ['', 'per_block', 'per_sub_block']


# Shorthand for absl's global flag-values object.
FLAGS = flags.FLAGS


# Where main_per_subset() writes its results. That function requires the value
# to end in ".h5" and writes one file per subset, named
# "<output_path without .h5>.ssi<index>.h5".
flags.DEFINE_string("output_path", None, "Path to h5 file(s) to write output to.")

# Input path read by load_pe_fishers().
flags.DEFINE_string("per_example_fishers", None, "Path to file containing per-example Fishers.")
# Input path read by load_nmf_decomp(), which inserts ".ssi<index>" before the
# ".h5" extension when it loads the decomposition for a particular subset.
flags.DEFINE_string(
    "decomposition",
    None,
    "Path to file NMF decomposition. If doing per-subset, this will be with the .ssi# removed."
)

# These three are forwarded unchanged to PerExampleFlatFishers.load() by
# load_pe_fishers(); their exact semantics are defined there, not in this file.
flags.DEFINE_integer("n_examples", None, '')
flags.DEFINE_integer("start_fisher_index", None, '')
flags.DEFINE_integer("end_fisher_index", None, '')

# NOTE(review): presumably registers the nmf_alpha, nmf_beta, nmf_l1_ratio,
# nmf_tol and nmf_max_iter flags that main_per_subset() reads -- confirm in
# nmf_common.
nmf_common.add_nmf_flags()


# An empty string means that per-subset processing is not requested. main()
# currently raises NotImplementedError in that case.
flags.DEFINE_enum('subset_style', '', _SUBSET_STYLES, '')

# The flags below are only needed when processing per-subset.

# If subset_style is not '', leaving this as None processes every subset.
# Otherwise each entry is converted with int() in get_subset_indices().
flags.DEFINE_list("subset_indices", None, '')

# Both are passed to TFAutoModelForSequenceClassification.from_pretrained() in
# get_packer().
flags.DEFINE_string("model", None, "Path to HuggingFace model file (or a model with the same variables).")
flags.DEFINE_bool("from_pt", True, "")

# Prefix handed to the tmv variable-filter flag helpers; get_packer() passes
# the same prefix to tmv.get_variable_filter_from_flags() to build the filter.
PEF_TMV_PREFIX = 'pef'
tmv.add_variable_filter_flags(PEF_TMV_PREFIX)


def load_pe_fishers():
    """Load the saved per-example Fishers selected by the command-line flags.

    The file path comes from ``--per_example_fishers`` (with ``~`` expanded).
    ``--n_examples``, ``--start_fisher_index`` and ``--end_fisher_index`` are
    forwarded unchanged to ``PerExampleFlatFishers.load``. Progress and the
    elapsed wall-clock time are printed to stdout.

    Returns:
        Whatever ``per_example.PerExampleFlatFishers.load`` returns.
    """
    print('Starting to load saved per-example Fishers.')
    began = time.time()

    fishers_path = os.path.expanduser(FLAGS.per_example_fishers)
    loaded = per_example.PerExampleFlatFishers.load(
        fishers_path,
        n_examples=FLAGS.n_examples,
        start_fisher_index=FLAGS.start_fisher_index,
        end_fisher_index=FLAGS.end_fisher_index,
    )

    elapsed = time.time() - began
    print('Load saved per-example Fishers time: ', elapsed)
    return loaded


def load_nmf_decomp(subset_index: Optional[int] = None):
    """Load a saved NMF decomposition and normalize its components.

    The base path comes from ``--decomposition`` (with ``~`` expanded). When
    ``subset_index`` is given, ``.ssi<subset_index>`` is inserted before the
    trailing ``.h5`` so that the per-subset file is loaded instead. Progress
    and the elapsed wall-clock time are printed to stdout.

    Args:
        subset_index: Index of the subset whose decomposition should be
            loaded, or ``None`` to load the path exactly as given.

    Returns:
        The object returned by ``nmf_common.NmfDecomposition.load`` after
        ``normalize_components_to_unit_norm()`` has been called on it.

    Raises:
        ValueError: If ``subset_index`` is given and the decomposition path
            does not end in ``.h5``.
    """
    print('Starting to load saved NMF decomposition.')

    h5_suffix = '.h5'
    filepath = os.path.expanduser(FLAGS.decomposition)

    if subset_index is not None:
        # Validate explicitly instead of with `assert`: asserts are removed
        # under `python -O`, after which the slice below would silently chop
        # off three arbitrary characters and produce a wrong path.
        if not filepath.endswith(h5_suffix):
            raise ValueError(
                f'--decomposition must end in "{h5_suffix}" when loading '
                f'per-subset decompositions, got: {filepath!r}'
            )
        base_path = filepath[:-len(h5_suffix)]
        filepath = f"{base_path}.ssi{subset_index}{h5_suffix}"

    start = time.time()
    decomp = nmf_common.NmfDecomposition.load(filepath)
    # The return value is ignored and `decomp` itself is returned.
    decomp.normalize_components_to_unit_norm()
    print('Load saved NMF decomposition time: ', time.time() - start)

    return decomp


###############################################################################

def group_variables(variables):
    """Group ``variables`` into subsets according to ``--subset_style``.

    Args:
        variables: The variables to group; passed straight through to the
            selected ``tmv`` grouping function.

    Returns:
        The result of ``tmv.group_by_blocks`` for ``'per_block'`` or of
        ``tmv.group_by_sub_blocks`` for ``'per_sub_block'``.

    Raises:
        ValueError: If ``--subset_style`` is any other value, including the
            empty string.
    """
    style = FLAGS.subset_style

    if style not in ('per_block', 'per_sub_block'):
        raise ValueError(f'Invalid subset style: {style}')

    if style == 'per_block':
        grouper = tmv.group_by_blocks
    else:
        grouper = tmv.group_by_sub_blocks
    return grouper(variables)


def get_packer():
    """Build a ``FlatPacker`` with one flat entry per variable subset.

    Loads the model named by ``--model`` (honouring ``--from_pt``), takes its
    mergeable variables, applies the variable filter configured through the
    ``pef``-prefixed flags, and groups what remains with ``group_variables``.
    The number of subsets is printed to stdout.

    Returns:
        A ``flat_pack.FlatPacker`` constructed from one single-element shape
        per subset, where that element is the summed ``tf.size`` of the
        subset's variables.
    """
    model_path = os.path.expanduser(FLAGS.model)
    model = TFAutoModelForSequenceClassification.from_pretrained(
        model_path,
        from_pt=FLAGS.from_pt,
    )

    mergeable = hf_util.get_mergeable_variables(model)
    variable_filter = tmv.get_variable_filter_from_flags(PEF_TMV_PREFIX)
    kept = variable_filter.filter_parallel_lists(mergeable)

    subsets = group_variables(kept)
    print('Number of subsets:', len(subsets))

    # Only the parameter count of each subset matters here, not the shapes of
    # the individual variables, so every subset becomes a 1-D "shape".
    shapes = []
    for subset in subsets:
        n_params = sum(tf.size(v) for v in subset)
        shapes.append([n_params])
    return flat_pack.FlatPacker(shapes)


def get_subset_indices(n_subsets: int):
    """Return the subset indices that should be processed.

    Args:
        n_subsets: Total number of subsets; only used when
            ``--subset_indices`` is unset.

    Returns:
        ``range(n_subsets)`` when ``--subset_indices`` is ``None``; otherwise
        a list holding each flag entry converted with ``int``.
    """
    requested = FLAGS.subset_indices
    if requested is not None:
        return list(map(int, requested))
    return range(n_subsets)


def get_values_for_subset(pe_fishers_data, packer, subset_index: int):
    """Restrict every per-example Fisher to the entries of a single subset.

    The subset's half-open index range ``[start, end)`` is obtained from
    ``packer.get_range_for_tensor_by_index(subset_index)``. For each example,
    only the entries whose index falls inside that range are kept, and the
    kept indices are shifted so that they are relative to the range start.

    Args:
        pe_fishers_data: Object exposing parallel iterables ``fishers`` and
            ``fisher_indices`` (one pair per example). The entries are
            presumably NumPy arrays, since elementwise comparison and
            boolean-mask indexing are applied to them -- confirm against
            ``PerExampleFlatFishers``.
        packer: Object providing ``get_range_for_tensor_by_index``.
        subset_index: Which subset to extract.

    Returns:
        A ``(subset_values, subset_indices)`` pair of lists, each with one
        entry per example.
    """
    lo, hi = packer.get_range_for_tensor_by_index(subset_index)

    def _restrict(values, inds):
        # Keep only the entries with lo <= index < hi, re-based to start at 0.
        keep = (inds >= lo) & (inds < hi)
        return values[keep], inds[keep] - lo

    restricted = [
        _restrict(values, inds)
        for values, inds in zip(pe_fishers_data.fishers, pe_fishers_data.fisher_indices)
    ]

    subset_values = [values for values, _ in restricted]
    subset_indices = [inds for _, inds in restricted]
    return subset_values, subset_indices


def main_per_subset():
    """Compute and save NMF coefficients separately for each variable subset.

    For every selected subset this loads the subset's saved NMF decomposition,
    restricts the per-example Fishers to the subset's index range, runs
    ``nmf_transform.transform`` with the ``nmf_*`` flag values, and saves a
    copy of the decomposition whose ``W`` field is replaced by the resulting
    coefficients. Output for subset ``i`` goes to
    ``<output_path without .h5>.ssi<i>.h5``. Progress and timings are printed
    to stdout.

    Raises:
        ValueError: If ``--output_path`` is unset or does not end in ``.h5``.
    """
    h5_suffix = '.h5'
    output_path = FLAGS.output_path
    # Validate explicitly instead of with `assert`: asserts are removed under
    # `python -O`, and an unset flag would otherwise surface as an opaque
    # AttributeError on None.
    if output_path is None or not output_path.endswith(h5_suffix):
        raise ValueError(
            f'--output_path must be set and end in "{h5_suffix}", '
            f'got: {output_path!r}'
        )
    # Remove the file extension.
    output_base_path = os.path.expanduser(output_path[:-len(h5_suffix)])

    packer = get_packer()
    n_subsets = len(packer.tensor_shapes)

    pe_fishers_data = load_pe_fishers()

    for ssi in get_subset_indices(n_subsets):
        decomp = load_nmf_decomp(ssi)

        # Use names distinct from the iterable being looped over; rebinding
        # that name inside the loop only worked because `for` keeps its own
        # iterator.
        values_in_subset, indices_in_subset = get_values_for_subset(
            pe_fishers_data, packer, ssi)

        print(f'Starting NMF transform for subset {ssi}.')
        start = time.time()
        coeffs = nmf_transform.transform(
            decomp,
            values_in_subset,
            indices_in_subset,
            alpha=FLAGS.nmf_alpha,
            beta=FLAGS.nmf_beta,
            l1_ratio=FLAGS.nmf_l1_ratio,
            tol=FLAGS.nmf_tol,
            max_iter=FLAGS.nmf_max_iter,
        )
        print('NMF transform time: ', time.time() - start)

        # Same decomposition, but with W swapped for the coefficients just
        # computed for the new data.
        decomp_for_new_data = dataclasses.replace(decomp, W=coeffs)

        print(f'Starting to save NMF decomposition for subset {ssi}')
        start = time.time()
        decomp_for_new_data.save(f'{output_base_path}.ssi{ssi}{h5_suffix}')
        print('Save NMF decomposition time: ', time.time() - start)


###############################################################################


def main(_):
    """Dispatch to the per-subset pipeline; the positional argument is ignored.

    Raises:
        NotImplementedError: If ``--subset_style`` is empty, because only
            per-subset processing is implemented.
    """
    if not FLAGS.subset_style:
        raise NotImplementedError('TODO: Handle non-per-subset stuff.')
    main_per_subset()


# Script entry point: hand main() to absl's app.run when executed directly.
if __name__ == "__main__":
    app.run(main)
