R"""Script for computing and saving LRM-PEFs to disk.

Only computes and saves the PEFs for examples on which the model
makes an incorrect prediction.
"""
import os

from absl import app
from absl import flags

from em import datasets as em_datasets
from em.models import em_models
from em.fishers import lrm_pefs
from em.fishers import lrm_pefs2
from em.util import vat_da_faak_vpn

FLAGS = flags.FLAGS

# --- Output -----------------------------------------------------------------
# Expanded with `os.path.expanduser` in `main` before being handed to the saver.
flags.DEFINE_string(
    name="output_path",
    default=None,
    help="Path to h5 file to write output to.",
)

# --- Model / tokenizer ------------------------------------------------------
# `--model` is user-expanded and passed to `em_models.from_pretrained`;
# `--from_pt` is forwarded to that call as its `from_pt` argument.
flags.DEFINE_string(name="model", default=None, help="")
flags.DEFINE_bool(name="from_pt", default=True, help="")

flags.DEFINE_string(
    name="tokenizer",
    default=None,
    help="Defaults to the value of --model if not set. Not used for image models.",
)

# --- Dataset selection and shaping ------------------------------------------
# `--task` is passed to `em_datasets.load`; `main` rejects tasks whose name
# contains 'snli' or 'mnli'.
flags.DEFINE_string(name="task", default=None, help="")

flags.DEFINE_list(
    name="split",
    default=["train"],
    help="If multiple splits are provided, will run on their concatenation.",
)
flags.DEFINE_integer(
    name="n_examples_to_process",
    default=None,
    help="Number of examples to process. Number saved will be smaller if any are correctly predicted.",
)
flags.DEFINE_integer(name="batch_size", default=16, help="")
flags.DEFINE_integer(
    name="sequence_length",
    default=128,
    help="Note that this is used to set image sizes as well.",
)

# `--shuffle` enables a shuffle with a buffer of 1000 in `create_dataset`;
# `--skip` drops that many leading examples before shuffling.
flags.DEFINE_bool(name="shuffle", default=False, help="")
flags.DEFINE_integer(name="skip", default=None, help="")

# Only consulted by `create_dataset` when the task name starts with
# 'winogrande/'.
flags.DEFINE_bool(
    name="ds_force_deterministic",
    default=False,
    help="Only has effects for some datasets.",
)

# --- PEF computation --------------------------------------------------------
# Forwarded to `SparseLrmPefComputer` as `n_values_per_example`.
flags.DEFINE_integer(name="n_fisher_values_per_example", default=None, help="")

flags.DEFINE_integer(
    name="top_k_classes",
    default=None,
    help="Only compute PEFs for this many top classes. Leave undefined to do all.",
)


###############################################################################


def create_dataset(tokenizer):
    """Build the batched dataset of examples to run the PEF computation on.

    Every split named in --split is loaded for --task via `em_datasets.load`
    and the results are concatenated in the order given. The combined dataset
    then goes through, in this order: the optional --skip, the optional
    --shuffle (buffer of 1000), an unbounded repeat truncated to
    --n_examples_to_process, caching, and batching by --batch_size.

    Args:
        tokenizer: Forwarded unchanged to `em_datasets.load`.

    Returns:
        The batched dataset.

    Raises:
        ValueError: If --split is empty or --n_examples_to_process is unset.
    """
    # Validate before loading anything: loading may be expensive, and both of
    # these conditions otherwise surface as opaque failures further down.
    if not FLAGS.split:
        raise ValueError("--split must name at least one split.")
    if FLAGS.n_examples_to_process is None:
        # The dataset is repeated without bound below, so the take() count is
        # the only thing that makes it finite.
        raise ValueError(
            "--n_examples_to_process must be set; the dataset is repeated "
            "indefinitely and then truncated to that many examples."
        )

    if FLAGS.task.startswith('winogrande/') and FLAGS.ds_force_deterministic:
        extra_kwargs = {'force_deterministic': True}
    else:
        extra_kwargs = {}

    ds = None
    for split in FLAGS.split:
        split_ds = em_datasets.load(
            FLAGS.task,
            split=split,
            tokenizer=tokenizer,
            sequence_length=FLAGS.sequence_length,
            **extra_kwargs,
        )
        ds = split_ds if ds is None else ds.concatenate(split_ds)

    if FLAGS.skip is not None:
        ds = ds.skip(FLAGS.skip)

    if FLAGS.shuffle:
        ds = ds.shuffle(1000)

    # NOTE(review): because of repeat(), asking for more examples than the
    # dataset holds wraps around and yields duplicates rather than stopping
    # early -- confirm this is intended.
    ds = ds.repeat().take(FLAGS.n_examples_to_process).cache()
    ds = ds.batch(FLAGS.batch_size)
    return ds


def main(_):
    """Load the model and data, then compute and save LRM-PEFs.

    Steps, in order: reject SNLI/MNLI tasks, load the model named by --model,
    load the tokenizer (--tokenizer, falling back to the model path), build
    the dataset with `create_dataset`, and stream the PEFs to --output_path
    through a `WrongsOnlyStreamingLrmPefSaver`.

    Args:
        _: Unused positional argument supplied by `app.run`.

    Raises:
        NotImplementedError: If --task contains 'snli' or 'mnli'.
    """
    task_name = FLAGS.task
    if any(nli in task_name for nli in ('snli', 'mnli')):
        raise NotImplementedError("TODO: Support the option of fixing MNLI/SNLI labels.")

    model_path = os.path.expanduser(FLAGS.model)
    model = em_models.from_pretrained(model_path, from_pt=FLAGS.from_pt)

    # NOTE: For transformers, trainable_variables and variables are the same. I am
    # not sure what is correct for vision models.
    trainable_vars = model.trainable_variables

    # An unset (or empty) --tokenizer falls back to the model path.
    tokenizer_name = FLAGS.tokenizer if FLAGS.tokenizer else model_path
    dataset = create_dataset(em_models.load_tokenizer(tokenizer_name))

    pef_computer = lrm_pefs.SparseLrmPefComputer(
        model=model,
        variables=trainable_vars,
        n_values_per_example=FLAGS.n_fisher_values_per_example,
        top_k_classes=FLAGS.top_k_classes,
    )
    pef_saver = lrm_pefs2.WrongsOnlyStreamingLrmPefSaver(
        fisher_computer=pef_computer,
        use_tqdm=True,
    )

    # TODO: Maybe add some safety in case n_examples is greater than the size of the dataset?
    destination = os.path.expanduser(FLAGS.output_path)
    pef_saver.compute_and_save_pefs(destination, dataset)


if __name__ == "__main__":
    # These flags default to None and `main` cannot run without them
    # (`os.path.expanduser(None)` and `'snli' in None` raise TypeError).
    # Marking them required here, rather than at module import time, gives a
    # clear usage error when run as a script without forcing them on any other
    # program that merely imports this module.
    flags.mark_flags_as_required(["output_path", "model", "task"])
    app.run(main)
