R"""Computes the dense Fisher for a model."""

# # (Maybe) temporary workaround for some SSL certificate stuff.
# if True:
#     import ssl
#     ssl._create_default_https_context = ssl._create_unverified_context

import os

from absl import app
from absl import flags

import tensorflow as tf
from tqdm import tqdm

from em import datasets as em_datasets
from em.models import em_models

from em.fishers import diagonal2
from em.util import hdf5_util

FLAGS = flags.FLAGS

# Algorithm names exposed by diagonal2; the first entry is the default for
# --fisher_algorithm.
FISHER_ALGORITHMS = diagonal2.FISHER_ALGORITHMS

# Destination .h5 file; "~" is expanded in main() before writing.
flags.DEFINE_string("output_path", None, "Path to h5 file to write output to.")

# Model name or path handed to em_models.from_pretrained ("~" is expanded).
flags.DEFINE_string("model", None, "")
# Forwarded as from_pt= to em_models.from_pretrained.
flags.DEFINE_bool("from_pt", True, "")

flags.DEFINE_string("tokenizer", None, "Defaults to the value of --model if not set. Not used for image models.")

# Task name handed to em_datasets.load.
flags.DEFINE_string("task", None, "")

flags.DEFINE_list("split", ["train"], "If multiple splits are provided, will run on their concatenation.")
# Number of examples to draw; the dataset is repeated as needed to reach it.
flags.DEFINE_integer("n_examples", None, "")
flags.DEFINE_integer("batch_size", 16, "")
flags.DEFINE_integer("sequence_length", 128, "Note that this is used to set image sizes as well.")
# Ordering controls, applied in create_dataset before the example-count limit.
# --shuffle uses a 1000-element shuffle buffer; --skip drops that many leading
# examples from the concatenated splits.
flags.DEFINE_bool("shuffle", False, "")
flags.DEFINE_integer("skip", None, "")

flags.DEFINE_integer("logits_batch_size", None,
                     "Number of output classes to compute gradients for at a time before summing. "
                     "Leaving this set to None means to do all labels before summing.")

# expectation_wrt_logits is currently hard-coded to True where main() builds the
# FisherComputer; this flag is kept here only for reference.
# flags.DEFINE_bool("expectation_wrt_logits", True, "")

flags.DEFINE_enum('fisher_algorithm', FISHER_ALGORITHMS[0], list(FISHER_ALGORITHMS), '')
# Forwarded unchanged to diagonal2.FisherComputer(min_prob_class=...).
flags.DEFINE_float("min_prob_class", 0.0, "")


# Currently only consulted by create_dataset for tasks starting with 'winogrande/'.
flags.DEFINE_bool("ds_force_deterministic", False, "Only has effects for some datasets.")

###############################################################################


def create_dataset(tokenizer):
    """Builds the batched tf.data pipeline fed to the Fisher computation.

    Every split in --split is loaded with em_datasets.load and concatenated in
    the given order.  Then --skip and --shuffle are applied, the example count
    is limited, and the result is batched to --batch_size.

    Args:
        tokenizer: Tokenizer forwarded to em_datasets.load.

    Returns:
        A batched tf.data.Dataset.  When --n_examples is set, the data is
        repeated as needed and exactly that many examples are taken.  When it
        is None, a single pass over the (possibly skipped/shuffled) data is
        used.  Previously, None crashed inside Dataset.take.

    Raises:
        ValueError: If --split is empty.
    """
    if not FLAGS.split:
        # Without at least one split there is nothing to concatenate and `ds`
        # would stay None.
        raise ValueError("--split must name at least one dataset split.")

    # force_deterministic is only passed for winogrande tasks; per the
    # --ds_force_deterministic help text, other loaders may not support it.
    extra_kwargs = {}
    if FLAGS.task.startswith('winogrande/') and FLAGS.ds_force_deterministic:
        extra_kwargs['force_deterministic'] = True

    ds = None
    for split in FLAGS.split:
        split_ds = em_datasets.load(
            FLAGS.task,
            split=split,
            tokenizer=tokenizer,
            sequence_length=FLAGS.sequence_length,
            **extra_kwargs,
        )
        ds = split_ds if ds is None else ds.concatenate(split_ds)

    if FLAGS.skip is not None:
        ds = ds.skip(FLAGS.skip)

    if FLAGS.shuffle:
        ds = ds.shuffle(1000)

    if FLAGS.n_examples is not None:
        # Repeat so that n_examples may exceed the size of the underlying data.
        ds = ds.repeat().take(FLAGS.n_examples)
    # With n_examples unset, keep a single finite pass.  Dataset.take(None) is
    # rejected by tf.data, and take(-1) after repeat() would never terminate.

    ds = ds.cache()
    ds = ds.batch(FLAGS.batch_size)
    return ds


def compute_fishers(fisher_computer, ds):
    """Accumulates batch Fishers over ``ds`` and returns their per-example mean.

    Args:
        fisher_computer: A diagonal2.FisherComputer.  One zero-initialised
            accumulator is created per entry of its ``variables``, with the
            same shape as that entry.
        ds: Iterable of 2-tuples.  The first element is the model batch and
            the second is ignored.

    Returns:
        List of non-trainable tf.Variables aligned with
        ``fisher_computer.variables``.  Each holds the summed batch Fishers
        divided by the total number of examples seen.

    Raises:
        ValueError: If ``ds`` yields no examples.  Dividing the zero
            accumulators by zero would otherwise silently produce NaNs.
    """
    fishers = [
        tf.Variable(tf.zeros(w.shape), trainable=False, name=f"fisher/{w.name}")
        for w in fisher_computer.variables
    ]

    n_examples = 0

    for batch, _ in tqdm(ds):
        n_examples += diagonal2.batch_size_from_batch(batch)
        # Only the Fisher contribution is used; the logits are discarded.
        _, batch_fishers = fisher_computer.compute_exact_fisher_and_logits_for_batch(batch)
        for f, bf in zip(fishers, batch_fishers):
            f.assign_add(bf)

    # float() works both for a Python int and for an eager scalar tensor
    # returned by batch_size_from_batch.
    total = float(n_examples)
    if total == 0:
        raise ValueError(
            "Dataset yielded no examples; cannot average the Fisher.")

    for f in fishers:
        f.assign(f / total)

    return fishers


###############################################################################

def main(_):
    """Script entry point: compute the Fisher for --model on --task and save it.

    Loads the model and tokenizer, builds the dataset with create_dataset, and
    accumulates the Fisher with a diagonal2.FisherComputer over the model's
    trainable variables.  The result is written to --output_path with
    hdf5_util.save_variables_to_hdf5.
    """
    model_path = os.path.expanduser(FLAGS.model)
    model = em_models.from_pretrained(model_path, from_pt=FLAGS.from_pt)

    # NOTE: for transformer models trainable_variables and variables coincide;
    # which set is right for vision models has not been verified.
    params = model.trainable_variables

    # The tokenizer falls back to the (expanded) model path when unset.
    dataset = create_dataset(em_models.load_tokenizer(FLAGS.tokenizer or model_path))

    computer_kwargs = dict(
        model=model,
        variables=params,
        per_example=False,
        expectation_wrt_logits=True,
        logits_batch_size=FLAGS.logits_batch_size,
        algorithm=FLAGS.fisher_algorithm,
        min_prob_class=FLAGS.min_prob_class,
    )
    computer = diagonal2.FisherComputer(**computer_kwargs)

    # Arguments are evaluated left to right: the Fisher is computed before the
    # output path is expanded, exactly as before.
    hdf5_util.save_variables_to_hdf5(
        compute_fishers(computer, dataset),
        os.path.expanduser(FLAGS.output_path),
    )


if __name__ == "__main__":
    # These flags default to None, but the script dereferences them
    # unconditionally.  main() calls os.path.expanduser on --model and
    # --output_path, and create_dataset calls --task.startswith.  Requiring
    # them makes absl fail fast with a usage error.  This matters most for
    # --output_path, which would otherwise fail only after the whole Fisher
    # computation.
    flags.mark_flags_as_required(["model", "task", "output_path"])
    app.run(main)
