R"""Creates approximate Fishers for seemingly correct components.

# create_nmf_components_fishers.py

"""
import itertools
import os
from typing import Sequence

from absl import app
from absl import flags

import numpy as np
import tensorflow as tf
from tqdm import tqdm
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em.fishers import diagonal
from em.models import transformer_model_vars as tmv
from em.util import hf_util
from em.util.color_util import cu

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf


# A block means what is commonly referred to as a transformer layer. A sub-block
# consists of either the attention layers or the FFW layers within a single block.
# The embeddings and the pooling layer are each their own blocks and sub-blocks.
# These are the accepted values of --subset_style; group_variables() dispatches
# on them.
_SUBSET_STYLES = ['per_block', 'per_sub_block']


# Accepted values of --tuning_indicator (the first entry is the flag default).
# get_indicator() maps them to the container's correct-prediction or
# incorrect-prediction indicator, respectively.
_TUNING_INDICATORS = ['correct', 'incorrect']


# Command-line flags for this script. Everything below registers flags on the
# global absl flag registry at import time.
FLAGS = flags.FLAGS

# Must end in '.h5' (asserted in main() and get_output_filepath()). The actual
# files written get a per-selection-parameter identifier inserted before the
# '.h5' suffix.
flags.DEFINE_string("output_path", None, "Path to h5 file(s) to write output to.")

# Name or path handed (after ~ expansion) to
# TFAutoModelForSequenceClassification.from_pretrained, together with
# `from_pt` as the keyword argument of the same name.
flags.DEFINE_string("model", None, "")
flags.DEFINE_bool("from_pt", True, "")

# Name or path handed (after ~ expansion) to AutoTokenizer.from_pretrained.
flags.DEFINE_string("tokenizer", None, "Defaults to `model` if not set.")

# Inputs of am.load_pef_nmf_analysis_container (both ~-expanded in main()).
flags.DEFINE_string("per_example_fishers", None, "Path to file containing per-example Fishers.")
flags.DEFINE_string(
    "decomposition",
    None,
    "Path to file NMF decomposition. If doing per-subset, this will be with the .ssi# removed."
)

# Selects which prediction indicator get_indicator() fetches from the container.
flags.DEFINE_enum('tuning_indicator', _TUNING_INDICATORS[0], _TUNING_INDICATORS, '')

# Selects how group_variables() partitions the model variables. There is no
# default: leaving it unset makes group_variables() raise ValueError.
flags.DEFINE_enum('subset_style', None, _SUBSET_STYLES, '')

# Prefix under which the variable-filter flags are registered; the same prefix
# is used in get_grouped_variables() to build the filter from those flags.
PEF_TMV_PREFIX = 'pef'
tmv.add_variable_filter_flags(PEF_TMV_PREFIX)

# Forwarded to am.load_pef_nmf_analysis_container; when true, main() also
# re-derives the container's labels/examples from the per-example Fishers.
flags.DEFINE_bool('shift_labels', True, '')

# When set, main() marks every example at index >= n_labeled as unlabeled.
flags.DEFINE_integer(
    "n_labeled",
    None,
    "Kind of a hack to handle some unlabelled examples. Shouldn't be needed for newer per-example Fishers."
)


# Selection parameter flags.
# The three list flags are parsed as floats and combined as a Cartesian product
# in get_selection_parameters(); main() asserts that all three are provided.
# `selection_max_examples` is shared by every combination and may stay unset.
flags.DEFINE_list('selection_coeff_factors', None, 'List[float]')
flags.DEFINE_list('selection_frac_thresholds', None, 'List[float]')
flags.DEFINE_list('selection_p_value_thresholds', None, 'List[float]')
flags.DEFINE_integer('selection_max_examples', None, '')


# When true, main() additionally computes batch Fishers over the selected
# labeled examples and over the remaining labeled examples.
flags.DEFINE_bool('also_coalesce_via_batch_fishers', False, '')

# Only needed if --also_coalesce_via_batch_fishers=true.
# Batch size of the tf.data.Dataset built in get_ds_from_example_indices().
flags.DEFINE_integer('batch_fisher_batch_size', 2, '')


def group_variables(variables):
    """Partition `variables` into subsets according to `--subset_style`.

    Args:
        variables: The variables to group; forwarded unchanged to the
            `transformer_model_vars` grouping helper.

    Returns:
        The result of `tmv.group_by_blocks` for 'per_block', or of
        `tmv.group_by_sub_blocks` for 'per_sub_block'.

    Raises:
        ValueError: If `--subset_style` is neither of the two styles above,
            which includes the flag being left at its default of None.
    """
    style = FLAGS.subset_style

    if style == 'per_sub_block':
        return tmv.group_by_sub_blocks(variables)
    if style == 'per_block':
        return tmv.group_by_blocks(variables)

    raise ValueError(f'Invalid subset style: {style}')


def get_grouped_variables() -> tuple:
    """Load the model named by `--model` and collect its variables.

    The previous return annotation described only the grouped variables, but
    the function returns three values, which callers unpack as a tuple.

    Returns:
        A 3-tuple `(model, variables, grouped_variables)` where
          * `model` is the sequence-classification model loaded with
            `TFAutoModelForSequenceClassification.from_pretrained`,
          * `variables` are the model's mergeable variables after applying the
            variable filter configured through the `pef`-prefixed flags,
          * `grouped_variables` is `group_variables(variables)`, i.e. the same
            variables partitioned according to `--subset_style`.

    Raises:
        ValueError: Propagated from `group_variables` when `--subset_style` is
            not a supported style.
    """
    model = TFAutoModelForSequenceClassification.from_pretrained(
        os.path.expanduser(FLAGS.model),
        from_pt=FLAGS.from_pt,
    )
    variables = hf_util.get_mergeable_variables(model)

    # Restrict to the variables selected by the `pef`-prefixed filter flags.
    pef_variable_filter = tmv.get_variable_filter_from_flags(PEF_TMV_PREFIX)
    variables = pef_variable_filter.filter_parallel_lists(variables)

    return model, variables, group_variables(variables)


def get_selection_parameters():
    """Expand the `--selection_*` flags into a grid of selection parameters.

    The three list flags are parsed as floats and combined as a Cartesian
    product (coefficient factor varying slowest, p-value threshold fastest).
    `--selection_max_examples` is attached unchanged to every combination.

    Returns:
        A list of `ncf.SelectionParameters`, one per combination.
    """
    def _as_floats(raw_values):
        return [float(raw) for raw in raw_values]

    grid = itertools.product(
        _as_floats(FLAGS.selection_coeff_factors),
        _as_floats(FLAGS.selection_frac_thresholds),
        _as_floats(FLAGS.selection_p_value_thresholds),
    )
    example_cap = FLAGS.selection_max_examples

    parameters = []
    for coeff_factor, frac_threshold, p_value_threshold in grid:
        parameters.append(
            ncf.SelectionParameters(
                coeff_factor=coeff_factor,
                frac_threshold=frac_threshold,
                p_value_threshold=p_value_threshold,
                max_examples=example_cap,
            )
        )
    return parameters


def _to_filepath_identifier(sp):
    """Return a filename-safe string identifying the selection parameters `sp`.

    The coefficient factor and fraction threshold are rendered with three
    decimals, the p-value threshold with four, and `max_examples` is appended
    only when it is not None. Every '.' is then replaced by '_' so the result
    contains no dots. Because of the rounding, distinct parameters can map to
    the same identifier.
    """
    pieces = [
        f'cf{sp.coeff_factor:.3f}',
        f'ft{sp.frac_threshold:.3f}',
        f'pvt{sp.p_value_threshold:.4f}',
    ]
    if sp.max_examples is not None:
        pieces.append(f'me{sp.max_examples}')
    return '_'.join(pieces).replace('.', '_')


def will_output_filepaths_be_unique(selection_parameters):
    """Return True iff no two selection parameters share a filepath identifier."""
    identifiers = list(map(_to_filepath_identifier, selection_parameters))
    n_distinct = len(frozenset(identifiers))
    return n_distinct == len(identifiers)


def get_output_filepath(sp):
    """Return the output file path for the selection parameters `sp`.

    The identifier of `sp` is inserted between the ~-expanded `--output_path`
    (minus its '.h5' suffix) and a fresh '.h5' suffix.
    """
    suffix = '.h5'
    assert FLAGS.output_path.endswith(suffix)

    expanded = os.path.expanduser(FLAGS.output_path)
    stem = expanded[:-len(suffix)]
    return '.'.join([stem, _to_filepath_identifier(sp)]) + suffix


def get_indicator(container):
    """Fetch the prediction indicator selected by `--tuning_indicator`.

    Args:
        container: Analysis container exposing
            `get_correct_prediction_indicator()` and
            `get_incorrect_prediction_indicator()`.

    Returns:
        The container's incorrect-prediction indicator for 'incorrect', or its
        correct-prediction indicator for 'correct'.

    Raises:
        ValueError: If `--tuning_indicator` has any other value.
    """
    choice = FLAGS.tuning_indicator

    if choice == 'incorrect':
        return container.get_incorrect_prediction_indicator()
    if choice == 'correct':
        return container.get_correct_prediction_indicator()

    raise ValueError(f'Invalid tuning indicator: {choice}')

###############################################################################


def _hack_get_bert_token_type_ids(tokenizer, input_ids):
    """Reconstruct token-type ids from `input_ids` using the SEP token.

    Along the last axis, a position gets id 1 when exactly one SEP token
    occurs strictly before it, and id 0 otherwise. So ids are 1 from the token
    after the first SEP up to and including the second SEP; if there is no
    second SEP, they stay 1 through the end of the sequence.

    Args:
        tokenizer: Object exposing `sep_token_id`.
        input_ids: Array of token ids; the sequence axis is the last one.

    Returns:
        An array shaped like `input_ids` and cast to its dtype.
    """
    print(cu.hly('Assuming that we are using a BERT-tokenizer for creating token-type ids.'))

    is_sep = input_ids == tokenizer.sep_token_id
    seps_seen_inclusive = np.cumsum(is_sep.astype(np.int32), axis=-1)
    # True from the first SEP (inclusive) up to, but excluding, the second SEP.
    one_sep_seen = seps_seen_inclusive == 1

    # Shift right by one position so the first SEP itself is excluded and the
    # second SEP is included; the vacated first position is filled with False.
    leading_false = np.zeros_like(one_sep_seen[..., :1])
    shifted = np.concatenate([leading_false, one_sep_seen[..., :-1]], axis=-1)

    return shifted.astype(input_ids.dtype)


def get_ds_from_example_indices(container, example_indices):
    """Build a batched `tf.data.Dataset` from the selected examples.

    Args:
        container: Analysis container exposing `pef` (with `labels` and
            `input_ids`) and `tokenizer`.
        example_indices: Indices of the examples to include; converted to an
            int32 array before indexing.

    Returns:
        A dataset of `({'input_ids', 'token_type_ids'}, labels)` elements,
        batched with `--batch_fisher_batch_size`.
    """
    indices = np.asarray(example_indices, dtype=np.int32)
    per_example = container.pef
    tokenizer = container.tokenizer

    selected_labels = per_example.labels[indices]
    selected_input_ids = per_example.input_ids[indices]

    # Token-type ids are not read from the container; they are rebuilt from
    # the SEP positions in the input ids.
    features = {
        'input_ids': selected_input_ids,
        'token_type_ids': _hack_get_bert_token_type_ids(tokenizer, selected_input_ids),
    }

    dataset = tf.data.Dataset.from_tensor_slices((features, selected_labels))
    return dataset.batch(FLAGS.batch_fisher_batch_size)


def compute_batch_fisher(container, model, variables, example_indices):
    """Compute a diagonal Fisher for `model` over the selected examples.

    Args:
        container: Analysis container the examples are read from.
        model: Model passed to `diagonal.compute_fisher_for_model`.
        variables: Variables to compute the Fisher for.
        example_indices: Indices of the examples to use.

    Returns:
        A list with the `.numpy()` value of each entry returned by
        `diagonal.compute_fisher_for_model`.
    """
    dataset = get_ds_from_example_indices(container, example_indices)
    fisher_tensors = diagonal.compute_fisher_for_model(
        model,
        dataset,
        variables=variables,
        use_tqdm=True,
    )

    as_numpy = []
    for tensor in fisher_tensors:
        as_numpy.append(tensor.numpy())
    return as_numpy


###############################################################################


def main(_):
    """Compute and save one tuned-component Fisher per selection-parameter combination.

    Steps:
      1. Expand the `--selection_*` flags into selection parameters and refuse
         to continue if two of them would map to the same output file.
      2. Load the model, its filtered variables and their grouping.
      3. Load the tokenizer and the per-example-Fisher / NMF analysis
         container, applying the label-related workarounds requested by flags.
      4. For every selection parameter: compute the Fisher via
         `ncf.get_apparently_tuned_fisher`, attach variable shapes and accuracy
         info, optionally add batch Fishers, save it to its own `.h5` file and
         print a short accuracy summary.

    Args:
        _: Unused positional arguments handed over by `absl.app.run`.

    Raises:
        AssertionError: If `--output_path` does not end in '.h5' or one of the
            three `--selection_*` list flags is missing.
        ValueError: If the generated output file paths would collide.
    """
    assert FLAGS.output_path.endswith('.h5')

    assert FLAGS.selection_coeff_factors is not None
    assert FLAGS.selection_frac_thresholds is not None
    assert FLAGS.selection_p_value_thresholds is not None

    selection_parameters = get_selection_parameters()

    if not will_output_filepaths_be_unique(selection_parameters):
        raise ValueError('The output filepaths will not be unique with how I am generating them now. Need to update code!')

    #############################################

    model, variables, grouped_variables = get_grouped_variables()

    # The --tokenizer help text promises a fallback to --model. Without it an
    # unset flag would reach os.path.expanduser(None) and raise TypeError.
    tokenizer_name = FLAGS.tokenizer if FLAGS.tokenizer is not None else FLAGS.model
    tokenizer = AutoTokenizer.from_pretrained(os.path.expanduser(tokenizer_name))

    container = am.load_pef_nmf_analysis_container(
        pef_filepath=os.path.expanduser(FLAGS.per_example_fishers),
        nmf_filepath=os.path.expanduser(FLAGS.decomposition),
        n_nmfs=len(grouped_variables),
        tokenizer=tokenizer,
        shift_labels=FLAGS.shift_labels,
    )

    if FLAGS.shift_labels:
        # Do this kinda hack to get labels correct. Having shift_labels=True
        # and then doing this gets the string labels and the integer labels
        # properly matched.
        container.labels = container.pef.labels
        container.examples = container._make_nli_examples()

    if FLAGS.n_labeled is not None:
        # Do this kinda hack since we don't know the number of examples beforehand.
        # Everything from index n_labeled onwards is treated as unlabeled.
        n_examples = container.pef.fishers.shape[0]
        unlabeled_indicator = np.ones([n_examples], dtype=bool)
        unlabeled_indicator[:FLAGS.n_labeled] = False

        container.unlabeled_indicator = unlabeled_indicator
        container.examples = container._make_nli_examples()

    container.nmfs.force_load_all()

    #############################################

    indicator = get_indicator(container)

    for sp in tqdm(selection_parameters):
        # ac_fisher = ncf.get_apparently_correct_fisher(container, sp)
        ac_fisher = ncf.get_apparently_tuned_fisher(container, indicator, sp)

        # Record the shapes of each group's variables on the matching subset info.
        for group, ssi in zip(grouped_variables, ac_fisher.subset_infos):
            ssi.variable_shapes = [v.shape for v in group]

        ac_fisher.compute_examples_accuracy_info(container)

        if FLAGS.also_coalesce_via_batch_fishers:
            correct_labeled_indices = ac_fisher.get_all_labeled_example_indices()
            # All labeled examples that were not selected above.
            erroring_labeled_indices = set(np.nonzero(~container.unlabeled_indicator)[0]) - set(correct_labeled_indices)

            ac_fisher.batch_correct_fishers = compute_batch_fisher(container, model, variables, list(correct_labeled_indices))
            ac_fisher.batch_erroring_fishers = compute_batch_fisher(container, model, variables, list(erroring_labeled_indices))

        ac_fisher.save(get_output_filepath(sp))

        # TODO: Make this cleaner.
        # Summary: total examples, fraction covered by the components, and the
        # accuracy on the component examples vs. on the remaining examples.
        print(ac_fisher.examples_accuracy_info.n_total_examples)
        print(ac_fisher.examples_accuracy_info.n_component_examples / ac_fisher.examples_accuracy_info.n_total_examples)
        print(ac_fisher.examples_accuracy_info.component_examples_accuracy)
        print(ac_fisher.examples_accuracy_info.remaining_examples_accuracy)
        print()


# Script entry point: absl parses the command-line flags, then calls main().
if __name__ == "__main__":
    app.run(main)
