R"""Computes and saves per-example Fishers to disk.

Generating Fishers on the fly is often quite slow, so doing it once
and saving to disk can greatly speed development of downstream methods.

TODO: Add support for writing sharded files.
"""
import collections
import functools
import os
import time
from typing import Dict, List, Optional, Sequence, Union

from absl import app
from absl import flags
from absl import logging
import h5py
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import sparse_diagonal
from em.models import transformer_model_vars as tmv

from em.util import hf_util
from em.util import sparse_tf_util
from em.util import vat_da_faak_vpn

FLAGS = flags.FLAGS


# Supported output flavors (see `what_we_want` and `_compute_batch_fishers`):
#   - 'full': the dense per-example Fisher over the filtered variables.
#   - 'sparse_static': only the entries at fixed indices taken from the
#     sparse Fisher given by --sparse_fisher.
#   - 'sparse_dynamic_raw': per example, the top-k Fisher entries by value.
#   - 'sparse_dynamic_metric_derived': per example, the top-k entries ranked
#     by Fisher value times squared (finetuned - pretrained) parameter delta.
_FLAVORS = ['full', 'sparse_static', 'sparse_dynamic_raw', 'sparse_dynamic_metric_derived']

flags.DEFINE_string("output_path", None, "Path to h5 file to write output to.")

# Model whose per-example Fishers are computed. The value is passed through
# os.path.expanduser and then to
# TFAutoModelForSequenceClassification.from_pretrained.
flags.DEFINE_string("trained_model", None, "")
# Forwarded as `from_pt` when loading --trained_model.
flags.DEFINE_bool("from_pt_trained", True, "")

flags.DEFINE_string("tokenizer", None, "Defaults to the value of --trained_model if not set.")

# Task name handed to `em_datasets.load`.
flags.DEFINE_string("task", None, "")
#
flags.DEFINE_list("split", ["train"], "If multiple splits are provided, will run on their concatenation.")
# Number of examples taken from the (repeated) dataset; also the number of
# rows of every per-example dataset written to the output file.
flags.DEFINE_integer("n_examples", None, "")
# Batch size of the tf.data pipeline.
flags.DEFINE_integer("batch_size", 16, "")
flags.DEFINE_integer("n_sub_batches", 1, "Possibly temporary measure to reduce the performance impact of the top-k call.")
# Passed to `em_datasets.load`; also the width of the saved `input_ids`.
flags.DEFINE_integer("sequence_length", 128, "")
#
# If set, the dataset is shuffled (buffer size 1000) before repeat/take.
flags.DEFINE_bool("shuffle", False, "")
# If set, this many leading examples are skipped before shuffling.
flags.DEFINE_integer("skip", None, "")

# TODO: When saving sparse fishers, we lose the information of
# the magnitude of the entire per-example Fisher. If it is desirable
# to have this information, save this information.
# flags.DEFINE_bool("normalize_fishers", False, "")

# TODO: Support saving per-class Fishers, perhaps only doing a subset
# of classes.
# Forwarded to `diagonal.compute_exact_fisher_for_batch`.
flags.DEFINE_bool("expectation_wrt_logits", True, "")

# Which of `_FLAVORS` to compute and save.
flags.DEFINE_enum('flavor', None, _FLAVORS, '')

flags.DEFINE_bool("ds_force_deterministic", False, "Only has effects for some datasets.")

# Flavor-specific flags:

# Required for "sparse_static":
# Path to a saved `SparseDiagonalFisher`; its sparse indices select which
# Fisher entries are stored.
flags.DEFINE_string("sparse_fisher", None, "")

# Required for "sparse_dynamic_metric_derived":
# Pretrained model used to compute the squared parameter deltas.
flags.DEFINE_string("pretrained_model", None, "")
# Forwarded as `from_pt` when loading --pretrained_model.
flags.DEFINE_bool("from_pt_pretrained", True, "")

# Required for "sparse_dynamic_*":
# TODO: Allow multiple ways of determining how many values to keep
# from each example's Fisher, perhaps supporting keeping a different
# number of values for each example.
# The `k` of the per-example top-k selection.
flags.DEFINE_integer("n_fisher_values_per_example", None, "")


# Prefix under which the variable-filter flags are registered; the same prefix
# is used in `what_we_want` to rebuild the filter from the parsed flags.
TMV_PREFIX = 'include'
tmv.add_variable_filter_flags(TMV_PREFIX)


def create_dataset_for_per_example_fishers(tokenizer):
    """Builds the batched dataset the per-example Fishers are computed on.

    Every split in --split is loaded for --task and the splits are
    concatenated in the order given. The result is optionally skipped into
    and shuffled, then repeated, truncated to --n_examples, cached and
    batched with --batch_size.

    Args:
        tokenizer: Tokenizer forwarded to `em_datasets.load`.

    Returns:
        The batched dataset.
    """
    load_kwargs = {}
    if FLAGS.task.startswith('winogrande/') and FLAGS.ds_force_deterministic:
        load_kwargs['force_deterministic'] = True

    combined = None
    for split_name in FLAGS.split:
        piece = em_datasets.load(
            FLAGS.task,
            split=split_name,
            tokenizer=tokenizer,
            sequence_length=FLAGS.sequence_length,
            **load_kwargs,
        )
        combined = piece if combined is None else combined.concatenate(piece)

    if FLAGS.skip is not None:
        combined = combined.skip(FLAGS.skip)
    if FLAGS.shuffle:
        combined = combined.shuffle(1000)

    # Repeat before take so that --n_examples may exceed the dataset size.
    limited = combined.repeat().take(FLAGS.n_examples)
    return limited.cache().batch(FLAGS.batch_size)


def get_predictions(trained_model, examples):
    """Returns, per example, the index of the largest logit of the model."""
    logits = trained_model(examples).logits
    return tf.argmax(logits, axis=-1)


def get_squared_parameter_deltas(trained_model, variable_filter) -> List[tf.Tensor]:
    """Returns squared deltas from pretrained to finetuned variables.

    The pretrained model is loaded from --pretrained_model (honoring
    --from_pt_pretrained). Both variable lists are restricted by
    `variable_filter`, keyed on the variables of `trained_model`, so the
    returned list lines up one-to-one with the filtered variables of
    `trained_model`.

    Args:
        trained_model: The finetuned model.
        variable_filter: Filter exposing `filter_parallel_lists`; its first
            argument is the variable list the filtering is keyed on.

    Returns:
        One tensor `(finetuned - pretrained)**2` per variable kept by the
        filter.

    Raises:
        ValueError: If the two models expose a different number of variables,
            in which case they cannot be treated as parallel lists.
    """
    # Fetched locally: this function previously referred to a
    # `model_variables` name that only exists inside `what_we_want`.
    model_variables = hf_util.get_all_variables(trained_model)

    pretrained_model_str = os.path.expanduser(FLAGS.pretrained_model)
    pretrained_model = TFAutoModelForSequenceClassification.from_pretrained(
        pretrained_model_str,
        from_pt=FLAGS.from_pt_pretrained,
    )
    pretrained_variables = hf_util.get_all_variables(pretrained_model)

    # The lists must be parallel *before* filtering, since the filter decides
    # what to keep from the first list and applies it to the second.
    if len(pretrained_variables) != len(model_variables):
        raise ValueError(
            'Pretrained and finetuned models have different numbers of variables: '
            f'{len(pretrained_variables)} vs {len(model_variables)}.')

    finetuned_variables, pretrained_variables = variable_filter.filter_parallel_lists(
        model_variables, pretrained_variables)

    return [
        (ft - pt)**2
        for ft, pt in zip(finetuned_variables, pretrained_variables)
    ]


def _flatten(tensors: Sequence[tf.Tensor]) -> tf.Tensor:
    """Reshapes every tensor to rank 1 and concatenates them in order."""
    flat_pieces = []
    for tensor in tensors:
        flat_pieces.append(tf.reshape(tensor, [-1]))
    return tf.concat(flat_pieces, axis=0)


def load_static_sparse_indices() -> List[tf.Tensor]:
    """Loads the sparse Fisher at --sparse_fisher and returns its indices.

    Returns:
        One index tensor per entry of the loaded Fisher's `fishers` list.
    """
    path = os.path.expanduser(FLAGS.sparse_fisher)
    loaded = sparse_diagonal.SparseDiagonalFisher.load(path)
    return list(map(sparse_tf_util.get_sparse_indices, loaded.fishers))


@tf.function
def _compute_batch_fishers(
    trained_model,
    variables,
    batch,
    sparse_indices=None,
    flat_sq_parameter_deltas=None,
):
    """Computes the per-example Fishers of one batch in the requested flavor.

    The batch is processed in --n_sub_batches slices and the per-example
    Fishers of the slices are concatenated back together along axis 0.

    Args:
        trained_model: Model the Fishers are computed for.
        variables: Variables to compute the Fisher with respect to.
        batch: Dict of batched model inputs; every value is sliced along its
            first axis.
        sparse_indices: Optional list parallel to `variables`. When given,
            only the entries at these indices are kept for each example.
        flat_sq_parameter_deltas: Squared parameter deltas flattened in the
            same order as the Fishers; only used by the
            "sparse_dynamic_metric_derived" flavor.

    Returns:
        A tuple `(dense_fisher_norms, dense_metric_derived_norms, fishers)`:
          - `dense_fisher_norms`: per-example L2 norm of the Fisher over all
            of `variables`, computed before any sparsification.
          - `dense_metric_derived_norms`: per-example L2 norm of the Fisher
            times the squared deltas for the metric-derived flavor, else None.
          - `fishers`: a `(values, indices)` tuple from the top-k selection
            for the "sparse_dynamic_*" flavors, otherwise a single tensor
            with one flattened row per example.
    """
    # NOTE(review): FLAGS are read as plain Python values inside a
    # tf.function, so they are presumably fixed when the function is traced.
    n_sub_batches = FLAGS.n_sub_batches
    sub_batch_size = FLAGS.batch_size // n_sub_batches
    n_batch_examples = diagonal.batch_size_from_batch(batch)

    # One list of sub-batch results per variable.
    batch_fishers = [[] for v in variables]
    for i in range(n_sub_batches):
        # The upper bound is clamped to the number of examples actually in the
        # batch. NOTE(review): for a batch with fewer than --batch_size
        # examples the later slices can be empty -- confirm that
        # compute_exact_fisher_for_batch tolerates empty sub-batches.
        sub_batch = {
            k: v[i * sub_batch_size : tf.minimum((i + 1) * sub_batch_size, n_batch_examples)]
            for k, v in batch.items()
        }
        sbfs = diagonal.compute_exact_fisher_for_batch(
            sub_batch,
            trained_model,
            variables,
            expectation_wrt_logits=FLAGS.expectation_wrt_logits,
            per_example=True,
        )
        for bfv, sbfv in zip(batch_fishers, sbfs):
            bfv.append(sbfv)

    # Stitch the sub-batches back together along the example axis.
    batch_fishers = [tf.concat(f, axis=0) for f in batch_fishers]

    #

    # Per-variable sums of squares per example, then summed across variables
    # and square-rooted to give one norm per example.
    dense_fisher_norms = [
        tf.reduce_sum(tf.reshape(f**2, [n_batch_examples, -1]), axis=-1)
        for f in batch_fishers
    ]
    dense_fisher_norms = tf.sqrt(tf.reduce_sum(dense_fisher_norms, axis=0))

    if sparse_indices is not None:
        # Apply the same gather to every example of each variable's Fisher.
        batch_fishers = [
            tf.vectorized_map(functools.partial(tf.gather_nd, indices=inds), f)
            for f, inds in zip(batch_fishers, sparse_indices)
        ]

    # Flatten each variable's Fisher to one row per example and join the
    # variables along the last axis.
    batch_fishers = [
        tf.reshape(f, [n_batch_examples, -1])
        for f in batch_fishers
    ]
    batch_fishers = tf.concat(batch_fishers, axis=-1)

    # The topk seems to be taking up the most time. Bigger batch sizes help.
    # TODO: Try approx top k.
    if FLAGS.flavor == "sparse_dynamic_raw":
        # See https://gist.github.com/redzhepdx/1e65bd2a721ff7fbf532 for possible Cuda implementation of quickselect.
        # NOTE: I think that is for doing a bunch of quick selects at once, which isn't really what I want.
        #
        # https://hpc.fau.de/files/2021/03/2021-03-23_ribizel.pdf
        # https://icl.utk.edu/files/publications/2019/icl-utk-1230-2019.pdf
        values, indices = tf.math.top_k(batch_fishers, k=FLAGS.n_fisher_values_per_example)
        return dense_fisher_norms, None, (values, indices)

    elif FLAGS.flavor == "sparse_dynamic_metric_derived":
        # Rank entries by Fisher * squared delta, but store the raw Fisher
        # values at the selected positions.
        fisher_times_sq_delta = batch_fishers * flat_sq_parameter_deltas
        _, indices = tf.math.top_k(
            fisher_times_sq_delta,
            k=FLAGS.n_fisher_values_per_example)
        values = tf.gather(batch_fishers, indices, batch_dims=1)

        dense_metric_derived_norms = tf.sqrt(tf.reduce_sum(fisher_times_sq_delta**2, axis=-1))

        return dense_fisher_norms, dense_metric_derived_norms, (values, indices)

    return dense_fisher_norms, None, batch_fishers


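# NOTE: Earlier version of `_compute_batch_fishers` that processes the whole
# batch at once (no sub-batching). Kept commented out for reference.
# @tf.function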
# def _compute_batch_fishers(
#     trained_model,
#     variables,
#     batch,
#     sparse_indices=None,
#     flat_sq_parameter_deltas=None,
# ):
#     n_batch_examples = diagonal.batch_size_from_batch(batch)

#     batch_fishers = diagonal.compute_exact_fisher_for_batch(
#         batch,
#         trained_model,
#         variables,
#         expectation_wrt_logits=FLAGS.expectation_wrt_logits,
#         per_example=True,
#     )

#     dense_fisher_norms = [
#         tf.reduce_sum(tf.reshape(f**2, [n_batch_examples, -1]), axis=-1)
#         for f in batch_fishers
#     ]
#     dense_fisher_norms = tf.sqrt(tf.reduce_sum(dense_fisher_norms, axis=0))

#     if sparse_indices is not None:
#         batch_fishers = [
#             tf.vectorized_map(functools.partial(tf.gather_nd, indices=inds), f)
#             for f, inds in zip(batch_fishers, sparse_indices)
#         ]

#     batch_fishers = [
#         tf.reshape(f, [n_batch_examples, -1])
#         for f in batch_fishers
#     ]
#     batch_fishers = tf.concat(batch_fishers, axis=-1)

#     # The topk seems to be taking up the most time. Bigger batch sizes help.
#     # TODO: Try approx top k.
#     if FLAGS.flavor == "sparse_dynamic_raw":
#         # See https://gist.github.com/redzhepdx/1e65bd2a721ff7fbf532 for possible Cuda implementation of quickselect.
#         # NOTE: I think that is for doing a bunch of quick selects at once, which isn't really what I want.
#         #
#         # https://hpc.fau.de/files/2021/03/2021-03-23_ribizel.pdf
#         # https://icl.utk.edu/files/publications/2019/icl-utk-1230-2019.pdf
#         values, indices = tf.math.top_k(batch_fishers, k=FLAGS.n_fisher_values_per_example)
#         return dense_fisher_norms, None, (values, indices)

#     elif FLAGS.flavor == "sparse_dynamic_metric_derived":
#         fisher_times_sq_delta = batch_fishers * flat_sq_parameter_deltas
#         _, indices = tf.math.top_k(
#             fisher_times_sq_delta,
#             k=FLAGS.n_fisher_values_per_example)
#         values = tf.gather(batch_fishers, indices, batch_dims=1)

#         dense_metric_derived_norms = tf.sqrt(tf.reduce_sum(fisher_times_sq_delta**2, axis=-1))

#         return dense_fisher_norms, dense_metric_derived_norms, (values, indices)

#     return dense_fisher_norms, None, batch_fishers


def what_we_want(
    trained_model: TFAutoModelForSequenceClassification,
    pe_ds: tf.data.Dataset,
    file,
):
    """Computes per-example Fishers over `pe_ds` and writes them to `file`.

    Everything is written under a `data` group whose `flavor` attribute
    records --flavor. The group always contains `dense_fisher_norms`,
    `input_ids`, `labels` and `predicted_logits`, each with --n_examples rows.
    Depending on the flavor it also contains:
      - "full" / "sparse_static": a `fishers` dataset.
      - "sparse_dynamic_*": a `fisher` group with `values` and `indices`
        datasets and a `dense_fisher_size` attribute.
      - "sparse_dynamic_metric_derived": additionally `sq_parameter_deltas`
        and `dense_metric_derived_norms`.

    Args:
        trained_model: Model the Fishers are computed for.
        pe_ds: Dataset yielding `(examples, labels)` batches, where `examples`
            is a dict containing at least `input_ids`.
        file: Open h5py file (or group) to create the `data` group in.

    Raises:
        ValueError: If --flavor is not one of the supported flavors.
    """
    n_examples = FLAGS.n_examples
    variable_filter = tmv.get_variable_filter_from_flags(TMV_PREFIX)

    model_variables = hf_util.get_all_variables(trained_model)
    sparse_indices = None
    flat_sq_parameter_deltas = None

    # Per flavor: restrict the variables with the filter and work out how many
    # values are stored per example (`fisher_size`).
    if FLAGS.flavor == 'sparse_static':
        sparse_indices = load_static_sparse_indices()
        assert len(sparse_indices) == len(model_variables)
        model_variables, sparse_indices = variable_filter.filter_parallel_lists(model_variables, sparse_indices)
        fisher_size = tf.reduce_sum([si.shape[0] for si in sparse_indices]).numpy()

    elif FLAGS.flavor == 'full':
        model_variables = variable_filter.filter_parallel_lists(model_variables)
        fisher_size = tf.reduce_sum([tf.size(v) for v in model_variables]).numpy()

    elif FLAGS.flavor == 'sparse_dynamic_raw':
        model_variables = variable_filter.filter_parallel_lists(model_variables)

        dense_fisher_size = tf.reduce_sum([tf.size(v) for v in model_variables]).numpy()
        fisher_size = FLAGS.n_fisher_values_per_example

    elif FLAGS.flavor == 'sparse_dynamic_metric_derived':
        # Must be called before `model_variables` is filtered below.
        sq_parameter_deltas = get_squared_parameter_deltas(trained_model, variable_filter)
        # model_variables, sq_parameter_deltas = variable_filter.filter_parallel_lists(model_variables, sq_parameter_deltas)
        model_variables = variable_filter.filter_parallel_lists(model_variables)
        flat_sq_parameter_deltas = _flatten(sq_parameter_deltas)
        # This is so I don't accidentally refer to this variable later on in this function
        # and cause really annoying silent errors.
        del sq_parameter_deltas

        dense_fisher_size = tf.reduce_sum([tf.size(v) for v in model_variables]).numpy()
        fisher_size = FLAGS.n_fisher_values_per_example

        # Sanity check.
        assert flat_sq_parameter_deltas.shape[0] == dense_fisher_size

    else:
        raise ValueError(FLAGS.flavor)

    data_grp = file.create_group('data')
    data_grp.attrs['flavor'] = FLAGS.flavor

    # The deltas are the same for every example, so they are stored once.
    if flat_sq_parameter_deltas is not None:
        sq_parameter_deltas_ds = data_grp.create_dataset(
            'sq_parameter_deltas',
            [dense_fisher_size],
            dtype=np.float32)
        sq_parameter_deltas_ds[:] = flat_sq_parameter_deltas

    if FLAGS.flavor == 'sparse_dynamic_metric_derived':
        dense_metric_derived_norms_ds = data_grp.create_dataset(
            'dense_metric_derived_norms',
            [n_examples],
            dtype=np.float32)

    # Dynamic sparse flavors store (values, indices) pairs; the other flavors
    # store a single matrix with one row per example.
    if FLAGS.flavor in ('sparse_dynamic_raw', 'sparse_dynamic_metric_derived'):
        fisher_grp = data_grp.create_group('fisher')
        fisher_grp.attrs['dense_fisher_size'] = dense_fisher_size

        fisher_values_ds = fisher_grp.create_dataset(
            'values',
            [n_examples, fisher_size],
            dtype=np.float32)
        fisher_indices_ds = fisher_grp.create_dataset(
            'indices',
            [n_examples, fisher_size],
            dtype=np.int32)

    else:
        fishers_ds = data_grp.create_dataset(
            'fishers',
            [n_examples, fisher_size],
            dtype=np.float32)

    # Norms of the entire dense diagonal Fisher.
    dense_fisher_norms_ds = data_grp.create_dataset(
        'dense_fisher_norms',
        [n_examples],
        dtype=np.float32)

    input_ids_ds = data_grp.create_dataset(
        'input_ids',
        [n_examples, FLAGS.sequence_length],
        dtype=np.int32)
    labels_ds = data_grp.create_dataset(
        'labels',
        [n_examples],
        dtype=np.int32)
    predicted_logits_ds = data_grp.create_dataset(
        'predicted_logits',
        [n_examples, trained_model.num_labels],
        dtype=np.float32)

    n_examples_processed = 0
    for examples, labels in tqdm(pe_ds):
        if n_examples_processed >= n_examples:
            break
        n_batch_examples = diagonal.batch_size_from_batch(examples)

        # NOTE: I think I can get this while computing the batch fishers.
        # start = time.time()
        predicted_logits = trained_model(examples, training=False).logits
        # print('get_logits_time', time.time() - start)

        # start = time.time()
        dense_fisher_norms, dense_metric_derived_norms, batch_fishers = _compute_batch_fishers(
            trained_model,
            model_variables,
            examples,
            sparse_indices,
            flat_sq_parameter_deltas,
        )
        # print('compute_batch_fishers', time.time() - start)

        # Rows [i1, i2) of the output receive this batch; `k` trims the batch
        # if it would run past `n_examples`.
        i1 = n_examples_processed
        i2 = min(n_examples_processed + n_batch_examples, n_examples)
        k = i2 - i1

        # A tuple means the dynamic sparse flavors' (values, indices) pair.
        if isinstance(batch_fishers, tuple):
            fisher_values, fisher_indices = batch_fishers
            fisher_values_ds[i1:i2] = fisher_values[:k].numpy().astype(np.float32)
            fisher_indices_ds[i1:i2] = fisher_indices[:k].numpy().astype(np.int32)
        else:
            fishers_ds[i1:i2] = batch_fishers[:k].numpy().astype(np.float32)

        if dense_metric_derived_norms is not None:
            dense_metric_derived_norms_ds[i1:i2] = dense_metric_derived_norms[:k].numpy().astype(np.float32)

        dense_fisher_norms_ds[i1:i2] = dense_fisher_norms[:k].numpy().astype(np.float32)
        input_ids_ds[i1:i2] = examples['input_ids'][:k].numpy().astype(np.int32)
        labels_ds[i1:i2] = labels[:k].numpy().astype(np.int32)
        predicted_logits_ds[i1:i2] = predicted_logits[:k].numpy().astype(np.float32)

        n_examples_processed += n_batch_examples


def main(_):
    """Validates the flags, then computes and saves the per-example Fishers.

    Loads the trained model and tokenizer, builds the dataset and writes the
    result to the h5 file at --output_path.

    Raises:
        app.UsageError: If a required flag is missing or the batch and
            sub-batch sizes are inconsistent.
    """
    # Validate everything up front: this fails before the expensive model
    # load and before the output file is opened in "w" mode, which would
    # truncate an existing file.
    always_required = ['output_path', 'trained_model', 'task', 'n_examples', 'flavor']
    flavor_required = {
        'sparse_static': ['sparse_fisher'],
        'sparse_dynamic_raw': ['n_fisher_values_per_example'],
        'sparse_dynamic_metric_derived': ['n_fisher_values_per_example', 'pretrained_model'],
    }
    required = always_required + flavor_required.get(FLAGS.flavor, [])
    missing = [name for name in required if getattr(FLAGS, name) is None]
    if missing:
        raise app.UsageError(
            'Missing required flag(s): ' + ', '.join('--' + name for name in missing))

    if FLAGS.n_sub_batches < 1:
        raise app.UsageError('--n_sub_batches must be at least 1.')
    if FLAGS.batch_size % FLAGS.n_sub_batches != 0:
        raise app.UsageError(
            f'--batch_size ({FLAGS.batch_size}) must be divisible by '
            f'--n_sub_batches ({FLAGS.n_sub_batches}).')

    trained_model_str = os.path.expanduser(FLAGS.trained_model)
    trained_model = TFAutoModelForSequenceClassification.from_pretrained(
        trained_model_str,
        from_pt=FLAGS.from_pt_trained,
    )

    tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer or trained_model_str)
    pe_ds = create_dataset_for_per_example_fishers(tokenizer)

    # TODO: Save additional data to this file.
    output_path = os.path.expanduser(FLAGS.output_path)
    with h5py.File(output_path, "w") as f:
        what_we_want(trained_model, pe_ds, f)


"""
Stuff:
- store flat vs not, maybe always store flat but include variable shapes and masks (and names)
  so that they can be unflattened.
Modes:
- Normal dense.
- Fixed sparse mask (probably from already computed Fisher)
- Per-example sparse mask
    - Maybe sub-modes for determining sparsity by raw Fisher value vs metric-derived value.
"""


if __name__ == "__main__":
    app.run(main)
