R"""Computes and saves per-example Fishers to disk.

Generating Fishers on the fly is often quite slow, so doing it once
and saving to disk can greatly speed development of downstream methods.

TODO: Add support for writing sharded files.
"""
import os

from absl import app
from absl import flags
from absl import logging

import h5py
import tensorflow as tf

from em.datasets import divisibility as divis_ds
from em.fishers import generate_per_example as gpe
from em.models import divis_models

FLAGS = flags.FLAGS

# Input/output locations. Both default to None, so they must be supplied on
# the command line.
flags.DEFINE_string("output_path", None, "Path to h5 file to write output to.")
flags.DEFINE_string("model", None, "Path to h5 file containing model info and weights.")

# Divisor/dividend bounds forwarded to the divisibility dataset config.
flags.DEFINE_integer(
    "min_divisor", 2, "Minimum divisor used by the divisibility dataset."
)
flags.DEFINE_integer(
    "max_divisor", 13, "Maximum divisor used by the divisibility dataset."
)
flags.DEFINE_integer(
    "min_dividend", 100, "Minimum dividend used by the divisibility dataset."
)
flags.DEFINE_integer(
    "max_dividend", 999_999_999, "Maximum dividend used by the divisibility dataset."
)

flags.DEFINE_integer(
    "n_examples", None, "Number of examples to generate per-example Fishers for."
)
# A batch must contain at least one example; reject nonsensical sizes at
# flag-parsing time instead of failing somewhere inside the generator.
flags.DEFINE_integer(
    "batch_size",
    16,
    "Batch size used when generating the Fishers.",
    lower_bound=1,
)

# NOTE: We only support the sparse_dynamic_raw flavor at the moment.
flags.DEFINE_integer(
    "n_fisher_values_per_example",
    None,
    "Number of Fisher values per example for the sparse_dynamic_raw flavor.",
)

flags.DEFINE_bool(
    "expectation_wrt_logits",
    True,
    "Whether the Fisher expectation is taken w.r.t. the logits.",
)

# Buffer size handed to `divis_ds.create_ds` when building the dataset.
_DS_BUFFER_SIZE = 4 * 1024 * 1024


def get_dataset_config():
    """Build the divisibility dataset config from the command-line flags.

    Returns:
        A `divis_ds.DivisibilityDatasetConfig` whose divisor and dividend
        bounds come from the `--min_divisor`, `--max_divisor`,
        `--min_dividend` and `--max_dividend` flags.
    """
    # Each config field shares its name with the flag that supplies it.
    bound_names = ("min_divisor", "max_divisor", "min_dividend", "max_dividend")
    bounds = {name: getattr(FLAGS, name) for name in bound_names}
    return divis_ds.DivisibilityDatasetConfig(**bounds)


def main(_):
    """Generate per-example Fishers for a saved model and write them to h5.

    Loads the model named by `--model`, builds the divisibility dataset from
    the flag-derived config, and hands both to
    `gpe.DivisibilityGeneratorAndSaver`, which writes into the h5 file at
    `--output_path` (opened in "w" mode, so an existing file is overwritten).

    Args:
        _: Positional command-line arguments supplied by `app.run`; unused.
    """
    # The loader also returns the model's config, which this script ignores.
    model, _unused_model_config = divis_models.load_model_from_file(FLAGS.model)

    dataset_config = get_dataset_config()
    dataset = divis_ds.create_ds(dataset_config, buffer_size=_DS_BUFFER_SIZE)

    # TODO: Support other flavors of per-example Fishers.
    sparse_flavor = gpe.SparseDynamicRawConfig(
        n_fisher_values_per_example=FLAGS.n_fisher_values_per_example,
    )

    # Quick cheese: attach `num_labels` to the loaded model by hand.
    model.num_labels = 2

    output_path = os.path.expanduser(FLAGS.output_path)
    with h5py.File(output_path, "w") as out_file:
        saver_kwargs = dict(
            model=model,
            variables=model.trainable_variables,
            dataset=dataset,
            n_examples=FLAGS.n_examples,
            batch_size=FLAGS.batch_size,
            flavor_config=sparse_flavor,
            file=out_file,
            expectation_wrt_logits=FLAGS.expectation_wrt_logits,
            dataset_config=dataset_config,
        )
        generator = gpe.DivisibilityGeneratorAndSaver(**saver_kwargs)
        generator.generate_and_save(use_tqdm=True)


if __name__ == "__main__":
    # `--output_path` and `--model` default to None and the script cannot run
    # without them (e.g. `os.path.expanduser(None)` raises TypeError in
    # `main`). Marking them required makes absl report a usage error for a
    # missing value before any work starts. Doing it here, not at import
    # time, keeps the module importable without these flags.
    flags.mark_flags_as_required(["output_path", "model"])
    app.run(main)
