R"""Second version of the script that computes and saves per-example Fishers (PEFs).

I started writing this version to support image models. Compared to the
original script, it also drops some code that was no longer being used.
"""
import collections
import functools
import os
import time
from typing import Dict, List, Optional, Sequence, Union

from absl import app
from absl import flags
import h5py
import numpy as np
import tensorflow as tf
from tqdm import tqdm

from em import datasets as em_datasets
from em.models import em_models

from em.fishers import diagonal2
from em.fishers import per_example2

from em.util import sparse_tf_util
from em.util import vat_da_faak_vpn

FISHER_ALGORITHMS = diagonal2.FISHER_ALGORITHMS
FISHER_FLAVORS = per_example2.FISHER_FLAVORS


FLAGS = flags.FLAGS

# NOTE(review): not referenced in this file; --flavor takes its choices from
# per_example2.FISHER_FLAVORS instead. Kept in case other code imports it.
_FLAVORS = ['full', 'sparse_static', 'sparse_dynamic_raw', 'sparse_dynamic_metric_derived']

flags.DEFINE_string("output_path", None, "Path to h5 file to write output to.")

# --- Model -------------------------------------------------------------------
flags.DEFINE_string(
    "model", None,
    "Pretrained model to load with em_models.from_pretrained ('~' is expanded). "
    "Also the default for --tokenizer.")
flags.DEFINE_bool("from_pt", True, "Passed as `from_pt` to em_models.from_pretrained.")

flags.DEFINE_string("tokenizer", None, "Defaults to the value of --model if not set. Not used for image models.")

# --- Data --------------------------------------------------------------------
flags.DEFINE_string("task", None, "Task name passed to em_datasets.load.")
flags.DEFINE_list("split", ["train"], "If multiple splits are provided, will run on their concatenation.")
flags.DEFINE_integer(
    "n_examples", None,
    "Number of examples to compute per-example Fishers for. The data is "
    "repeated if it contains fewer examples than this.")
flags.DEFINE_integer("batch_size", 16, "Batch size used when iterating over the examples.")
flags.DEFINE_integer(
    "sequence_length", 128,
    "Sequence length passed to em_datasets.load. Note that this is used to set image sizes as well.")

flags.DEFINE_bool(
    "shuffle", False,
    "Whether to shuffle the examples (buffer size 1000) before taking --n_examples.")
flags.DEFINE_integer(
    "skip", None,
    "Number of examples to skip at the start of the concatenated splits, before any shuffling.")

# --- Fisher computation ------------------------------------------------------
flags.DEFINE_integer("logits_batch_size", None,
                     "Number of output classes to compute gradients for at a time before summing. "
                     "Leaving this set to None means to do all labels before summing.")

# NOTE: expectation_wrt_logits is intentionally not a flag; main() always
# passes expectation_wrt_logits=True to diagonal2.FisherComputer.

flags.DEFINE_enum('flavor', None, list(FISHER_FLAVORS), 'Flavor passed to per_example2.StreamingPefSaver.')

flags.DEFINE_enum('fisher_algorithm', FISHER_ALGORITHMS[0], list(FISHER_ALGORITHMS),
                  'Algorithm passed to diagonal2.FisherComputer.')
flags.DEFINE_float("min_prob_class", 0.0, "Passed as `min_prob_class` to diagonal2.FisherComputer.")


flags.DEFINE_bool(
    "ds_force_deterministic", False,
    "Only has an effect for some datasets (currently only winogrande/* tasks receive it).")

# Required for "sparse_dynamic_*":
# TODO: Allow multiple ways of determining how many values to keep
# from each example's Fisher, perhaps supporting keeping a different
# number of values for each example.
flags.DEFINE_integer(
    "n_fisher_values_per_example", None,
    "Number of values to keep from each example's Fisher. Required for the sparse_dynamic_* flavors.")


def create_dataset(tokenizer):
    """Builds the batched dataset of examples to compute Fishers on.

    Loads every split listed in --split for --task and concatenates them in
    order. It then optionally skips --skip examples and shuffles. Next it takes
    exactly --n_examples examples, repeating the data if it is shorter, and
    caches them. Finally it batches them into --batch_size batches.

    Args:
        tokenizer: Forwarded to `em_datasets.load` for every split.

    Returns:
        The batched `tf.data.Dataset`.

    Raises:
        app.UsageError: If --split names no splits.
    """
    if not FLAGS.split:
        # Without this check the code below would fail with an unrelated
        # AttributeError on a `None` dataset.
        raise app.UsageError('--split must name at least one split.')

    # --ds_force_deterministic is only forwarded for winogrande tasks.
    if FLAGS.task.startswith('winogrande/') and FLAGS.ds_force_deterministic:
        extra_kwargs = {'force_deterministic': True}
    else:
        extra_kwargs = {}

    split_datasets = [
        em_datasets.load(
            FLAGS.task,
            split=split,
            tokenizer=tokenizer,
            sequence_length=FLAGS.sequence_length,
            **extra_kwargs,
        )
        for split in FLAGS.split
    ]
    # Left fold preserves the order in which the splits were given.
    ds = functools.reduce(lambda acc, nxt: acc.concatenate(nxt), split_datasets)

    if FLAGS.skip is not None:
        ds = ds.skip(FLAGS.skip)

    if FLAGS.shuffle:
        ds = ds.shuffle(1000)
    # repeat() + take() yields exactly n_examples even if the data is shorter
    # (duplicating examples in that case). cache() is applied after take() so
    # only the selected examples are cached.
    ds = ds.repeat().take(FLAGS.n_examples).cache()
    ds = ds.batch(FLAGS.batch_size)
    return ds


###############################################################################

def main(_):
    """Computes per-example Fishers of --model on --task and saves them.

    Computes the Fishers over --n_examples examples. The results are written by
    per_example2.StreamingPefSaver to --output_path ('~' is expanded).

    Raises:
        app.UsageError: If --flavor=sparse_dynamic_raw is used without
            --n_fisher_values_per_example.
    """
    # Validate flags before loading the model. A usage error is raised rather
    # than using `assert`, which is stripped when Python runs with -O.
    # NOTE(review): the comment on --n_fisher_values_per_example says it is
    # required for all "sparse_dynamic_*" flavors, but only
    # sparse_dynamic_raw is checked here -- confirm whether
    # sparse_dynamic_metric_derived needs it too.
    if FLAGS.flavor == 'sparse_dynamic_raw' and FLAGS.n_fisher_values_per_example is None:
        raise app.UsageError(
            '--n_fisher_values_per_example is required for --flavor=sparse_dynamic_raw.')

    model_str = os.path.expanduser(FLAGS.model)
    model = em_models.from_pretrained(model_str, from_pt=FLAGS.from_pt)

    # NOTE: For transformers, trainable_variables and variables are the same. I am
    # not sure what is correct for vision models.
    variables = model.trainable_variables

    # The tokenizer falls back to the (expanded) model path when --tokenizer
    # is not set.
    tokenizer = em_models.load_tokenizer(FLAGS.tokenizer or model_str)
    ds = create_dataset(tokenizer)

    fisher_computer = diagonal2.FisherComputer(
        model=model,
        variables=variables,
        per_example=True,
        expectation_wrt_logits=True,
        logits_batch_size=FLAGS.logits_batch_size,
        algorithm=FLAGS.fisher_algorithm,
        min_prob_class=FLAGS.min_prob_class,
    )

    saver = per_example2.StreamingPefSaver(
        fisher_computer=fisher_computer,
        flavor=FLAGS.flavor,
        n_fisher_values_per_example=FLAGS.n_fisher_values_per_example,
        use_tqdm=True,
    )

    # TODO: Maybe add some safety in case n_examples is greater than the size of the dataset?
    # (create_dataset repeats the data, so examples are silently duplicated then.)
    output_path = os.path.expanduser(FLAGS.output_path)
    saver.compute_and_save_pefs(output_path, ds, n_examples=FLAGS.n_examples)


if __name__ == "__main__":
    # These flags default to None, and the script cannot run without them.
    # A None value fails later with an unrelated error: os.path.expanduser(None),
    # None.startswith(...), or Dataset.take(None). Rejecting them up front gives
    # a clear usage message instead.
    flags.mark_flags_as_required(["output_path", "model", "task", "n_examples"])
    app.run(main)
