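R"""Script for computing the diagonal Fisher of a model.

Loads a sequence classification model and its tokenizer, computes the diagonal
Fisher over a slice of the `--task` dataset, and saves the result to the HDF5
file given by `--fisher_path`.
"""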
import os

from absl import app
from absl import flags
from absl import logging
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.models import transformer_model_vars as tmv
from em.util import hdf5_util
from em.util import vat_da_faak_vpn

FLAGS = flags.FLAGS

# Model, task, and output location (all three are required, see below).
flags.DEFINE_string(
    "model",
    None,
    "Name of a model in HuggingFace's repository or path to a local model. "
    "A leading ~ in a path is expanded.",
)
flags.DEFINE_string(
    "task", None, "Name of the task whose dataset is loaded to compute the Fisher."
)
flags.DEFINE_string("fisher_path", None, "Path of hdf5 file to save Fisher to.")

flags.DEFINE_string("tokenizer", None, "Defaults to the value of --model if not set.")

flags.DEFINE_bool(
    "from_pt",
    True,
    "Value of the from_pt argument passed to from_pretrained when loading the "
    "model, i.e. whether to load the weights from a PyTorch checkpoint.",
)

# Which slice of the dataset the Fisher is computed over.
flags.DEFINE_string("split", "train", "Dataset split to compute the Fisher on.")
flags.DEFINE_integer(
    "n_examples",
    4096,
    "Number of examples to take from the dataset (after applying --skip).",
)
flags.DEFINE_integer(
    "batch_size", 2, "Batch size used when batching the examples."
)
flags.DEFINE_integer(
    "sequence_length", 128, "Sequence length passed to the dataset loader."
)

flags.DEFINE_integer(
    "skip",
    None,
    "Number of examples to skip at the start of the dataset before taking "
    "--n_examples. Nothing is skipped if not set.",
)

# Prefix handed to tmv when registering the variable-filter flags here and
# again when reading them back in main(); the two must stay in sync.
TMV_PREFIX = 'include'
tmv.add_variable_filter_flags(TMV_PREFIX)


flags.mark_flags_as_required(["model", "task", "fisher_path"])


def main(_):
    """Compute the diagonal Fisher for the model given by flags and save it.

    Steps, in order: load the model and tokenizer, load the `--task` dataset
    and restrict it to the requested examples, compute the diagonal Fisher
    for the variables selected by the variable-filter flags, and write the
    result to the HDF5 file at `--fisher_path`.

    Args:
        _: Positional command-line arguments forwarded by `app.run`; unused.
    """
    # Expand the model just in case it is a path rather than
    # the name of a model from HuggingFace's repository.
    model_str = os.path.expanduser(FLAGS.model)
    model = TFAutoModelForSequenceClassification.from_pretrained(
        model_str, from_pt=FLAGS.from_pt
    )
    # --tokenizer may be a local path as well, so give it the same `~`
    # expansion as --model. Strings without a leading `~` are returned
    # unchanged by expanduser, so repository names are unaffected.
    if FLAGS.tokenizer:
        tokenizer_str = os.path.expanduser(FLAGS.tokenizer)
    else:
        tokenizer_str = model_str
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_str)
    logging.info("Model loaded")

    ds = em_datasets.load(
        FLAGS.task,
        split=FLAGS.split,
        tokenizer=tokenizer,
        sequence_length=FLAGS.sequence_length,
    )
    # Skip first (if requested), then take, so that --skip/--n_examples
    # together select a contiguous window of examples.
    if FLAGS.skip is not None:
        ds = ds.skip(FLAGS.skip)
    ds = ds.take(FLAGS.n_examples).batch(FLAGS.batch_size)
    logging.info("Dataset loaded")

    logging.info("Starting Fisher computation")
    # Only the trainable variables that pass the flag-configured filter are
    # handed to the Fisher computation.
    variable_filter = tmv.get_variable_filter_from_flags(TMV_PREFIX)
    variables = variable_filter.filter_parallel_lists(model.trainable_variables)
    fisher_diag = diagonal.compute_fisher_for_model(
        model, ds, variables=variables, use_tqdm=True
    )

    logging.info("Fisher computed. Saving to file...")
    fisher_path = os.path.expanduser(FLAGS.fisher_path)
    hdf5_util.save_variables_to_hdf5(fisher_diag, fisher_path)
    logging.info("Fisher saved to file")


if __name__ == "__main__":
    # Hand control to absl, which parses the command-line flags and then
    # calls main.
    app.run(main)
