R"""Script for computing and saving LRM-PEFs to disk."""
import os

from absl import app
from absl import flags

from em import datasets as em_datasets
from em.models import em_models
from em.fishers import lrm_pefs
from em.util import vat_da_faak_vpn

FLAGS = flags.FLAGS


# --- Output ------------------------------------------------------------------
flags.DEFINE_string("output_path", None, "Path to h5 file to write output to.")

# --- Model -------------------------------------------------------------------
flags.DEFINE_string("model", None, "Model name or path passed to em_models.from_pretrained; '~' is expanded.")
flags.DEFINE_bool("from_pt", True, "Forwarded as `from_pt` to em_models.from_pretrained.")

flags.DEFINE_string("tokenizer", None, "Defaults to the value of --model if not set. Not used for image models.")

# --- Dataset -----------------------------------------------------------------
flags.DEFINE_string("task", None, "Task name passed to em_datasets.load (e.g. a 'winogrande/...' task).")
flags.DEFINE_list("split", ["train"], "If multiple splits are provided, will run on their concatenation.")
flags.DEFINE_integer("n_examples", None, "Number of examples to compute PEFs for. The data is repeated if it has fewer examples.")
flags.DEFINE_integer("batch_size", 16, "Number of examples per batch.")
flags.DEFINE_integer("sequence_length", 128, "Note that this is used to set image sizes as well.")

# --- Dataset ordering ----------------------------------------------------------
flags.DEFINE_bool("shuffle", False, "Shuffle the examples (buffer of 1000) before taking --n_examples.")
flags.DEFINE_integer("skip", None, "Number of examples to skip at the start of the concatenated splits.")

flags.DEFINE_bool("ds_force_deterministic", False, "Only has effects for some datasets (currently winogrande/* tasks).")

# --- PEF computation -----------------------------------------------------------
flags.DEFINE_integer("n_fisher_values_per_example", None, "Forwarded as n_values_per_example to SparseLrmPefComputer.")

flags.DEFINE_integer("top_k_classes", None, "Only compute PEFs for this many top classes. Leave undefined to do all.")

# These flags default to None but are dereferenced unconditionally
# (os.path.expanduser, str.startswith, Dataset.take), so a missing value would
# otherwise crash deep inside main() -- possibly after a slow model load.
# Making them required lets absl reject the command line up front.
flags.mark_flags_as_required(["output_path", "model", "task", "n_examples"])


###############################################################################


def create_dataset(tokenizer):
    """Build the batched input dataset described by the command-line flags.

    Loads every split in --split for --task and concatenates them in order.
    It then optionally skips --skip examples and shuffles the result. Next it
    repeats the data and takes exactly --n_examples examples. If the splits
    hold fewer examples than that, some are reused. Finally it caches the
    taken examples and batches them by --batch_size.

    Args:
        tokenizer: Forwarded to em_datasets.load; presumably ignored for
            image tasks (see the --tokenizer help) -- TODO confirm.

    Returns:
        The dataset returned by em_datasets.load with the transformations
        above applied.

    Raises:
        ValueError: If --split names no splits.
    """
    if not FLAGS.split:
        # Without this check `ds` would stay None and fail later with an
        # opaque "'NoneType' object has no attribute ..." error.
        raise ValueError("--split must name at least one dataset split.")

    # Only winogrande tasks get `force_deterministic`; presumably the other
    # loaders do not accept this keyword -- verify before widening.
    extra_kwargs = {}
    if FLAGS.task.startswith('winogrande/') and FLAGS.ds_force_deterministic:
        extra_kwargs['force_deterministic'] = True

    ds = None
    for split in FLAGS.split:
        split_ds = em_datasets.load(
            FLAGS.task,
            split=split,
            tokenizer=tokenizer,
            sequence_length=FLAGS.sequence_length,
            **extra_kwargs,
        )
        ds = split_ds if ds is None else ds.concatenate(split_ds)

    if FLAGS.skip is not None:
        ds = ds.skip(FLAGS.skip)

    if FLAGS.shuffle:
        ds = ds.shuffle(1000)

    # repeat() makes the stream unbounded so take() always yields exactly
    # n_examples elements. If the data is shorter, examples are revisited
    # silently. cache() keeps the taken examples after the first pass.
    ds = ds.repeat().take(FLAGS.n_examples).cache()
    ds = ds.batch(FLAGS.batch_size)
    return ds


def main(_):
    """Compute LRM-PEFs for --model on --task and stream them to --output_path.

    Args:
        _: Leftover command-line arguments passed by absl's app.run (unused).
    """
    # Resolve the output path before the (potentially slow) model load, so
    # that a bad value fails immediately.
    output_path = os.path.expanduser(FLAGS.output_path)

    model_str = os.path.expanduser(FLAGS.model)
    model = em_models.from_pretrained(model_str, from_pt=FLAGS.from_pt)

    # NOTE: For transformers, trainable_variables and variables are the same. I am
    # not sure what is correct for vision models.
    variables = model.trainable_variables

    # Expand '~' in --tokenizer just like --model, so both flags accept the
    # same kinds of paths. An unset or empty --tokenizer falls back to the model.
    tokenizer_str = os.path.expanduser(FLAGS.tokenizer) if FLAGS.tokenizer else model_str
    tokenizer = em_models.load_tokenizer(tokenizer_str)
    ds = create_dataset(tokenizer)

    fisher_computer = lrm_pefs.SparseLrmPefComputer(
        model=model,
        variables=variables,
        n_values_per_example=FLAGS.n_fisher_values_per_example,
        top_k_classes=FLAGS.top_k_classes,
        # vectorized=False,
    )

    saver = lrm_pefs.StreamingLrmPefSaver(
        fisher_computer=fisher_computer,
        n_examples=FLAGS.n_examples,
        use_tqdm=True,
    )

    # TODO: create_dataset repeats the data, so if n_examples exceeds the
    # dataset size some examples are silently reused. Consider warning or
    # failing in that case.
    saver.compute_and_save_pefs(output_path, ds)


if __name__ == "__main__":
    # absl parses the flags defined above, then calls main with leftover argv.
    app.run(main)
