"""Script for computing and saving LVRM-PEFs to disk."""
import os

from absl import app
from absl import flags

from em import datasets as em_datasets
from em.models import em_models
from em.fishers import lvrm_pefs
from em.util import vat_da_faak_vpn

FLAGS = flags.FLAGS

# --- Output -------------------------------------------------------------------
flags.DEFINE_string(
    name="output_path",
    default=None,
    help="Path to h5 file to write output to.",
)

# --- Model and tokenizer ------------------------------------------------------
# --model is user-expanded and handed to em_models.from_pretrained in main(),
# together with --from_pt.
flags.DEFINE_string(name="model", default=None, help="")
flags.DEFINE_bool(name="from_pt", default=True, help="")
flags.DEFINE_string(
    name="tokenizer",
    default=None,
    help="Defaults to the value of --model if not set. Not used for image models.",
)

# --- Dataset ------------------------------------------------------------------
# These are consumed by create_dataset(); --n_examples is also passed to the
# saver in main().
flags.DEFINE_string(name="task", default=None, help="")
flags.DEFINE_list(
    name="split",
    default=["train"],
    help="If multiple splits are provided, will run on their concatenation.",
)
flags.DEFINE_integer(name="n_examples", default=None, help="")
flags.DEFINE_integer(
    name="sequence_length",
    default=128,
    help="Note that this is used to set image sizes as well.",
)

# --- Fisher computation -------------------------------------------------------
# Forwarded as-is to lvrm_pefs.SparseLvrmPefComputer in main().
flags.DEFINE_integer(name="n_fisher_values_per_example", default=None, help="")
flags.DEFINE_float(name="min_prob_class", default=None, help="")
flags.DEFINE_integer(name="max_classes", default=None, help="")

###############################################################################


def create_dataset(tokenizer):
    """Build the dataset described by the --task/--split/--n_examples flags.

    Every split named in --split is loaded through `em_datasets.load` with the
    given tokenizer and --sequence_length. The splits are concatenated in the
    order they were listed. When --n_examples is set, the concatenation is then
    truncated with `take`; when it is unset, the full concatenation is returned.

    Args:
        tokenizer: Tokenizer forwarded unchanged to `em_datasets.load`.

    Returns:
        The concatenated (and, if requested, truncated) dataset.

    Raises:
        ValueError: If --split is empty, since there is then nothing to load.
    """
    if not FLAGS.split:
        raise ValueError("At least one split must be provided via --split.")

    ds = None
    for split in FLAGS.split:
        split_ds = em_datasets.load(
            FLAGS.task,
            split=split,
            tokenizer=tokenizer,
            sequence_length=FLAGS.sequence_length,
        )
        ds = split_ds if ds is None else ds.concatenate(split_ds)

    # --n_examples defaults to None. Only truncate when a count was actually
    # given; a None count is not a valid argument to `take` (assuming the
    # loaded object follows tf.data.Dataset semantics -- TODO confirm).
    if FLAGS.n_examples is not None:
        ds = ds.take(FLAGS.n_examples)

    return ds


def main(_):
    """Load the model and data named by the flags, then compute and save PEFs.

    Steps, in order: load the model from --model, load the tokenizer (falling
    back to the model path when --tokenizer is unset), build the dataset with
    `create_dataset`, wire up a `SparseLvrmPefComputer` and a `LvrmPefSaver`,
    and finally write the results to --output_path.

    Args:
        _: Unused positional argument supplied by `app.run`.
    """
    model_path = os.path.expanduser(FLAGS.model)
    model = em_models.from_pretrained(model_path, from_pt=FLAGS.from_pt)

    # NOTE: For transformers, trainable_variables and variables are the same. I am
    # not sure what is correct for vision models.
    model_variables = model.trainable_variables

    # An unset/empty --tokenizer means "use the same path as the model".
    tokenizer_name = FLAGS.tokenizer or model_path
    tokenizer = em_models.load_tokenizer(tokenizer_name)
    dataset = create_dataset(tokenizer)

    pef_computer = lvrm_pefs.SparseLvrmPefComputer(
        model=model,
        variables=model_variables,
        n_values_per_example=FLAGS.n_fisher_values_per_example,
        min_prob_class=FLAGS.min_prob_class,
        max_classes=FLAGS.max_classes,
    )
    pef_saver = lvrm_pefs.LvrmPefSaver(
        fisher_computer=pef_computer,
        n_examples=FLAGS.n_examples,
        use_tqdm=True,
    )

    # TODO: Maybe add some safety in case n_examples is greater than the size of the dataset?
    destination = os.path.expanduser(FLAGS.output_path)
    pef_saver.compute_and_save_pefs(destination, dataset)


if __name__ == "__main__":
    # --model and --output_path default to None, but main() passes both straight
    # to os.path.expanduser, which raises TypeError on None. Declaring them
    # required makes absl reject the command line up front with a clear message
    # instead of failing partway through the run.
    flags.mark_flags_as_required(["model", "output_path"])
    app.run(main)
